forked from
microcosm.blue/microcosm-rs
Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm
1use atrium_api::types::string::Did;
2use atrium_oauth::OAuthClientMetadata;
3use axum::{
4 Router,
5 extract::{FromRef, Json as ExtractJson, Query, State},
6 http::{
7 StatusCode,
8 header::{CONTENT_SECURITY_POLICY, CONTENT_TYPE, HeaderMap, ORIGIN, REFERER},
9 },
10 response::{IntoResponse, Json, Redirect, Response},
11 routing::{get, post},
12};
13use axum_extra::extract::cookie::{Cookie, Expiration, Key, SameSite, SignedCookieJar};
14use axum_template::{RenderHtml, engine::Engine};
15use handlebars::{Handlebars, handlebars_helper};
16use jose_jwk::JwkSet;
17use std::path::PathBuf;
18
19use serde::Deserialize;
20use serde_json::{Value, json};
21use std::collections::HashSet;
22use std::sync::Arc;
23use std::time::{Duration, SystemTime};
24use tokio::net::TcpListener;
25use tokio_util::sync::CancellationToken;
26use url::Url;
27
28use crate::{
29 ExpiringTaskMap, OAuth, OAuthCallbackParams, OAuthCompleteError, ResolveHandleError, Tokens,
30};
31
32const FAVICON: &[u8] = include_bytes!("../static/favicon.ico");
33const STYLE_CSS: &str = include_str!("../static/style.css");
34
35const HELLO_COOKIE_KEY: &str = "hello-who-am-i";
36const DID_COOKIE_KEY: &str = "did";
37
38const COOKIE_EXPIRATION: Duration = Duration::from_secs(30 * 86_400);
39
40type AppEngine = Engine<Handlebars<'static>>;
41
42#[derive(Clone)]
43struct AppState {
44 pub key: Key,
45 pub allowed_hosts: Arc<HashSet<String>>,
46 pub engine: AppEngine,
47 pub oauth: Arc<OAuth>,
48 pub resolve_handles: ExpiringTaskMap<Result<String, ResolveHandleError>>,
49 pub shutdown: CancellationToken,
50 pub tokens: Arc<Tokens>,
51}
52
53impl FromRef<AppState> for Key {
54 fn from_ref(state: &AppState) -> Self {
55 state.key.clone()
56 }
57}
58
59#[allow(clippy::too_many_arguments)]
60pub async fn serve(
61 shutdown: CancellationToken,
62 app_secret: String,
63 oauth_private_key: Option<PathBuf>,
64 tokens: Tokens,
65 base: String,
66 bind: String,
67 allowed_hosts: Vec<String>,
68 dev: bool,
69) {
70 let mut hbs = Handlebars::new();
71 hbs.set_dev_mode(dev);
72 hbs.register_templates_directory("templates", Default::default())
73 .unwrap();
74
75 handlebars_helper!(json: |v: Value| serde_json::to_string(&v).unwrap());
76 hbs.register_helper("json", Box::new(json));
77
78 // clients have to pick up their identity-resolving tasks within this period
79 let task_pickup_expiration = Duration::from_secs(15);
80
81 let oauth = OAuth::new(oauth_private_key, base).unwrap();
82
83 let state = AppState {
84 engine: Engine::new(hbs),
85 key: Key::from(app_secret.as_bytes()), // TODO: via config
86 allowed_hosts: Arc::new(HashSet::from_iter(allowed_hosts)),
87 oauth: Arc::new(oauth),
88 resolve_handles: ExpiringTaskMap::new(task_pickup_expiration),
89 shutdown: shutdown.clone(),
90 tokens: Arc::new(tokens),
91 };
92
93 let app = Router::new()
94 .route("/", get(hello))
95 .route("/favicon.ico", get(favicon)) // todo MIME
96 .route("/style.css", get(css))
97 .route("/prompt", get(prompt))
98 .route("/user-info", post(user_info))
99 .route("/client-metadata.json", get(client_metadata))
100 .route("/auth", get(start_oauth))
101 .route("/authorized", get(complete_oauth))
102 .route("/disconnect", post(disconnect))
103 .route("/.well-known/jwks.json", get(jwks))
104 .with_state(state);
105
106 eprintln!("starting server at http://{bind}");
107 let listener = TcpListener::bind(bind)
108 .await
109 .expect("listener binding to work");
110
111 axum::serve(listener, app)
112 .with_graceful_shutdown(async move { shutdown.cancelled().await })
113 .await
114 .unwrap();
115}
116
117#[derive(Debug, Deserialize)]
118struct HelloQuery {
119 auth_reload: Option<String>,
120 auth_failed: Option<String>,
121}
122async fn hello(
123 State(AppState {
124 engine,
125 resolve_handles,
126 shutdown,
127 oauth,
128 ..
129 }): State<AppState>,
130 Query(params): Query<HelloQuery>,
131 mut jar: SignedCookieJar,
132) -> Response {
133 let is_auth_reload = params.auth_reload.is_some();
134 let auth_failed = params.auth_failed.is_some();
135 let no_cookie = jar.get(HELLO_COOKIE_KEY).is_none();
136 jar = jar.add(hello_cookie());
137
138 let info = if let Some(did) = jar.get(DID_COOKIE_KEY) {
139 if let Ok(did) = Did::new(did.value_trimmed().to_string()) {
140 // push cookie expiry
141 jar = jar.add(cookie(&did));
142 let fetch_key = resolve_handles.dispatch(
143 {
144 let oauth = oauth.clone();
145 let did = did.clone();
146 async move { oauth.resolve_handle(did.clone()).await }
147 },
148 shutdown.child_token(),
149 );
150 json!({
151 "did": did,
152 "fetch_key": fetch_key,
153 "is_auth_reload": is_auth_reload,
154 "auth_failed": auth_failed,
155 "no_cookie": no_cookie,
156 })
157 } else {
158 jar = jar.remove(DID_COOKIE_KEY);
159 json!({
160 "is_auth_reload": is_auth_reload,
161 "auth_failed": auth_failed,
162 "no_cookie": no_cookie,
163 })
164 }
165 } else {
166 json!({
167 "is_auth_reload": is_auth_reload,
168 "auth_failed": auth_failed,
169 "no_cookie": no_cookie,
170 })
171 };
172 let frame_headers = [(CONTENT_SECURITY_POLICY, "frame-ancestors 'none'")];
173 (frame_headers, jar, RenderHtml("hello", engine, info)).into_response()
174}
175
176async fn css() -> impl IntoResponse {
177 let headers = [
178 (CONTENT_TYPE, "text/css"),
179 // (CACHE_CONTROL, "") // TODO
180 ];
181 (headers, STYLE_CSS)
182}
183
184async fn favicon() -> impl IntoResponse {
185 ([(CONTENT_TYPE, "image/x-icon")], FAVICON)
186}
187
188fn hello_cookie() -> Cookie<'static> {
189 Cookie::build((HELLO_COOKIE_KEY, "hiiii"))
190 .http_only(true)
191 .secure(true)
192 .same_site(SameSite::None)
193 .expires(Expiration::DateTime(
194 (SystemTime::now() + COOKIE_EXPIRATION).into(),
195 )) // wtf safari needs this to not be a session cookie??
196 .max_age(COOKIE_EXPIRATION.try_into().unwrap())
197 .path("/")
198 .into()
199}
200
201fn cookie(did: &Did) -> Cookie<'static> {
202 Cookie::build((DID_COOKIE_KEY, did.to_string()))
203 .http_only(true)
204 .secure(true)
205 .same_site(SameSite::None)
206 .expires(Expiration::DateTime(
207 (SystemTime::now() + COOKIE_EXPIRATION).into(),
208 )) // wtf safari needs this to not be a session cookie??
209 .max_age(COOKIE_EXPIRATION.try_into().unwrap())
210 .path("/")
211 .into()
212}
213
214#[derive(Debug, Deserialize)]
215struct PromptQuery {
216 // this must *ONLY* be used for the postmessage target origin
217 app: Option<String>,
218}
219async fn prompt(
220 State(AppState {
221 allowed_hosts,
222 engine,
223 oauth,
224 resolve_handles,
225 shutdown,
226 tokens,
227 ..
228 }): State<AppState>,
229 Query(params): Query<PromptQuery>,
230 jar: SignedCookieJar,
231 headers: HeaderMap,
232) -> impl IntoResponse {
233 let err = |reason, check_frame, detail| {
234 metrics::counter!("whoami_auth_prompt", "ok" => "false", "reason" => reason).increment(1);
235 let info = json!({
236 "reason": reason,
237 "check_frame": check_frame,
238 "detail": detail,
239 });
240 let html = RenderHtml("prompt-error", engine.clone(), info);
241 (StatusCode::BAD_REQUEST, html).into_response()
242 };
243
244 let Some(parent) = headers.get(ORIGIN).or_else(|| {
245 eprintln!("referrer fallback");
246 // TODO: referer should only be used for localhost??
247 headers.get(REFERER)
248 }) else {
249 return err("Missing origin and no referrer for fallback", true, None);
250 };
251 let Ok(parent) = parent.to_str() else {
252 return err("Unreadable origin or referrer", true, None);
253 };
254 eprintln!(
255 "rolling with parent: {parent:?} (from origin? {})",
256 headers.get(ORIGIN).is_some()
257 );
258 let Ok(url) = Url::parse(parent) else {
259 return err("Bad origin or referrer", true, None);
260 };
261 let Some(parent_host) = url.host_str() else {
262 return err("Origin or referrer missing host", true, None);
263 };
264 if !allowed_hosts.contains(parent_host) {
265 return err(
266 "Login is not allowed on this page",
267 false,
268 Some(parent_host),
269 );
270 }
271 if let Some(ref app) = params.app
272 && !allowed_hosts.contains(app)
273 {
274 return err("Login is not allowed for this app", false, Some(app));
275 }
276 let parent_origin = url.origin().ascii_serialization();
277 if parent_origin == "null" {
278 return err("Origin or referrer header value is opaque", true, None);
279 }
280
281 let all_allowed = allowed_hosts
282 .iter()
283 .map(|h| format!("https://{h}"))
284 .collect::<Vec<_>>()
285 .join(" ");
286 let csp = format!("frame-ancestors 'self' {parent_origin} {all_allowed}");
287 let frame_headers = [(CONTENT_SECURITY_POLICY, &csp)];
288
289 if let Some(did) = jar.get(DID_COOKIE_KEY) {
290 let Ok(did) = Did::new(did.value_trimmed().to_string()) else {
291 return err("Bad cookie", false, None);
292 };
293
294 // push cookie expiry
295 let jar = jar.add(cookie(&did));
296
297 let token = match tokens.mint(&*did) {
298 Ok(t) => t,
299 Err(e) => {
300 eprintln!("failed to create JWT: {e:?}");
301 return err("failed to create JWT", false, None);
302 }
303 };
304
305 let fetch_key = resolve_handles.dispatch(
306 {
307 let oauth = oauth.clone();
308 let did = did.clone();
309 async move { oauth.resolve_handle(did.clone()).await }
310 },
311 shutdown.child_token(),
312 );
313
314 metrics::counter!("whoami_auth_prompt", "ok" => "true", "known" => "true").increment(1);
315 let info = json!({
316 "did": did,
317 "token": token,
318 "fetch_key": fetch_key,
319 "parent_host": parent_host,
320 "parent_origin": parent_origin,
321 "parent_target": params.app.map(|h| format!("https://{h}")),
322 });
323 (frame_headers, jar, RenderHtml("prompt", engine, info)).into_response()
324 } else {
325 metrics::counter!("whoami_auth_prompt", "ok" => "true", "known" => "false").increment(1);
326 let info = json!({
327 "parent_host": parent_host,
328 "parent_origin": parent_origin,
329 });
330 (frame_headers, RenderHtml("prompt", engine, info)).into_response()
331 }
332}
333
334#[derive(Debug, Deserialize)]
335struct UserInfoParams {
336 fetch_key: String,
337}
338async fn user_info(
339 State(AppState {
340 resolve_handles, ..
341 }): State<AppState>,
342 ExtractJson(params): ExtractJson<UserInfoParams>,
343) -> impl IntoResponse {
344 let err = |status, reason: &str| {
345 metrics::counter!("whoami_user_info", "found" => "false", "reason" => reason.to_string())
346 .increment(1);
347 (status, Json(json!({ "reason": reason }))).into_response()
348 };
349
350 let Some(task_handle) = resolve_handles.take(¶ms.fetch_key) else {
351 return err(StatusCode::NOT_FOUND, "fetch key does not exist or expired");
352 };
353
354 match task_handle.await {
355 Err(task_err) => {
356 eprintln!("task join error? {task_err:?}");
357 err(StatusCode::INTERNAL_SERVER_ERROR, "server errored")
358 }
359 Ok(Err(ResolveHandleError::ResolutionFailed(atrium_identity::Error::NotFound))) => {
360 err(StatusCode::NOT_FOUND, "handle not found")
361 }
362 Ok(Err(ResolveHandleError::ResolutionFailed(e))) => {
363 eprintln!("handle resolution failed: {e:?}");
364 err(
365 StatusCode::INTERNAL_SERVER_ERROR,
366 "handle resolution failed",
367 )
368 }
369 Ok(Err(ResolveHandleError::NoHandle)) => err(
370 StatusCode::INTERNAL_SERVER_ERROR,
371 "resolved identity but did not find a handle",
372 ),
373 Ok(Err(ResolveHandleError::InvalidHandle(_h, reason))) => err(
374 StatusCode::INTERNAL_SERVER_ERROR,
375 &format!("handle appears invalid: {reason}"),
376 ),
377 Ok(Ok(handle)) => {
378 metrics::counter!("whoami_user_info", "found" => "true").increment(1);
379 Json(json!({ "handle": handle })).into_response()
380 }
381 }
382}
383
384async fn client_metadata(
385 State(AppState { oauth, .. }): State<AppState>,
386) -> Json<OAuthClientMetadata> {
387 Json(oauth.client_metadata())
388}
389
390#[derive(Debug, Deserialize)]
391struct BeginOauthParams {
392 handle: String,
393}
394async fn start_oauth(
395 State(AppState { oauth, engine, .. }): State<AppState>,
396 Query(params): Query<BeginOauthParams>,
397 jar: SignedCookieJar,
398) -> Response {
399 // if any existing session was active, clear it first
400 // ...this might help a confusion attack w multiple sign-in flows or smth
401 let jar = jar.remove(DID_COOKIE_KEY);
402
403 use atrium_identity::Error as IdError;
404 use atrium_oauth::Error as OAuthError;
405
406 let err = |code, reason: &str| {
407 metrics::counter!("whoami_auth_start", "ok" => "false", "reason" => reason.to_string())
408 .increment(1);
409 let info = json!({
410 "result": "fail",
411 "reason": reason,
412 });
413 (code, RenderHtml("auth-fail", engine.clone(), info)).into_response()
414 };
415
416 match oauth.begin(¶ms.handle).await {
417 Err(OAuthError::Identity(
418 IdError::NotFound | IdError::HttpStatus(StatusCode::NOT_FOUND),
419 )) => err(StatusCode::NOT_FOUND, "handle not found"),
420 Err(OAuthError::Identity(IdError::AtIdentifier(r))) => err(StatusCode::BAD_REQUEST, &r),
421 Err(e) => {
422 eprintln!("begin auth failed: {e:?}");
423 err(StatusCode::INTERNAL_SERVER_ERROR, "unknown")
424 }
425 Ok(auth_url) => {
426 metrics::counter!("whoami_auth_start", "ok" => "true").increment(1);
427 (jar, Redirect::to(&auth_url)).into_response()
428 }
429 }
430}
431
432async fn complete_oauth(
433 State(AppState {
434 engine,
435 resolve_handles,
436 oauth,
437 shutdown,
438 tokens,
439 ..
440 }): State<AppState>,
441 Query(params): Query<OAuthCallbackParams>,
442 jar: SignedCookieJar,
443) -> Response {
444 let err = |code, result, reason: &str| {
445 metrics::counter!("whoami_auth_complete", "ok" => "false", "reason" => reason.to_string())
446 .increment(1);
447 let info = json!({
448 "result": result,
449 "reason": reason,
450 });
451 (code, RenderHtml("auth-fail", engine.clone(), info)).into_response()
452 };
453
454 let did = match oauth.complete(params).await {
455 Ok(did) => did,
456 Err(e) => {
457 return match e {
458 OAuthCompleteError::Denied { description, .. } => {
459 let desc = description.unwrap_or("permission to share was denied".to_string());
460 err(StatusCode::FORBIDDEN, "deny", desc.as_str())
461 }
462 OAuthCompleteError::Failed { .. } => {
463 eprintln!("auth completion failed: {e:?}");
464 err(
465 StatusCode::INTERNAL_SERVER_ERROR,
466 "fail",
467 "failed to complete",
468 )
469 }
470 OAuthCompleteError::CallbackFailed(e) => {
471 eprintln!("auth callback failed: {e:?}");
472 err(
473 StatusCode::INTERNAL_SERVER_ERROR,
474 "fail",
475 "failed to complete callback",
476 )
477 }
478 OAuthCompleteError::NoDid => err(StatusCode::BAD_REQUEST, "fail", "no DID found"),
479 };
480 }
481 };
482
483 let jar = jar.add(cookie(&did));
484
485 let token = match tokens.mint(&*did) {
486 Ok(t) => t,
487 Err(e) => {
488 eprintln!("failed to create JWT: {e:?}");
489 return err(
490 StatusCode::INTERNAL_SERVER_ERROR,
491 "fail",
492 "failed to create JWT",
493 );
494 }
495 };
496
497 let fetch_key = resolve_handles.dispatch(
498 {
499 let oauth = oauth.clone();
500 let did = did.clone();
501 async move { oauth.resolve_handle(did.clone()).await }
502 },
503 shutdown.child_token(),
504 );
505
506 metrics::counter!("whoami_auth_complete", "ok" => "true").increment(1);
507 let info = json!({
508 "did": did,
509 "token": token,
510 "fetch_key": fetch_key,
511 });
512 (jar, RenderHtml("authorized", engine, info)).into_response()
513}
514
515async fn disconnect(jar: SignedCookieJar) -> impl IntoResponse {
516 metrics::counter!("whoami_disconnect").increment(1);
517 let jar = jar.remove(DID_COOKIE_KEY);
518 (jar, Json(json!({ "ok": true })))
519}
520
521async fn jwks(State(AppState { oauth, tokens, .. }): State<AppState>) -> Json<JwkSet> {
522 let mut jwks = oauth.jwks();
523 jwks.keys.push(tokens.jwk());
524 Json(jwks)
525}