forked from
microcosm.blue/microcosm-rs
Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm
1use askama::Template;
2use axum::{
3 extract::{Query, Request},
4 http::{self, header},
5 middleware::{self, Next},
6 response::{IntoResponse, Response},
7 routing::get,
8 Router,
9};
10use axum_metrics::{ExtraMetricLabels, MetricLayer};
11use bincode::Options;
12use serde::{Deserialize, Serialize};
13use serde_with::serde_as;
14use std::collections::{HashMap, HashSet};
15use std::time::{Duration, UNIX_EPOCH};
16use tokio::net::{TcpListener, ToSocketAddrs};
17use tokio::task::spawn_blocking;
18use tokio_util::sync::CancellationToken;
19
20use crate::storage::{LinkReader, StorageStats};
21use crate::{CountsByCount, Did, RecordId};
22
23mod acceptable;
24mod filters;
25
26use acceptable::{acceptable, ExtractAccept};
27
/// Page size used when a request does not supply its own `limit`.
const DEFAULT_CURSOR_LIMIT: u64 = 16;
/// Largest `limit` a request may ask for; the paged handlers in this file
/// answer 400 Bad Request when a request exceeds it.
const DEFAULT_CURSOR_LIMIT_MAX: u64 = 100;
30
/// Returns [`DEFAULT_CURSOR_LIMIT`].
///
/// Exists as a function so it can be named in `#[serde(default = "...")]`
/// on the `limit` fields of the query structs.
fn get_default_cursor_limit() -> u64 {
    DEFAULT_CURSOR_LIMIT
}
34
35fn to500(e: tokio::task::JoinError) -> http::StatusCode {
36 eprintln!("handler error: {e}");
37 http::StatusCode::INTERNAL_SERVER_ERROR
38}
39
/// Builds the HTTP API router for `store` and serves it on `addr` until
/// `stay_alive` is cancelled.
///
/// Every data route clones `store` into its handler closure and runs the
/// (synchronous) handler function through `spawn_blocking`; if joining that
/// task fails, `to500` turns the error into a 500 response.
///
/// # Errors
///
/// Returns an error if binding `addr` fails, if the listener's local address
/// cannot be read, or if `axum::serve` itself returns an error.
pub async fn serve<S, A>(store: S, addr: A, stay_alive: CancellationToken) -> anyhow::Result<()>
where
    S: LinkReader,
    A: ToSocketAddrs,
{
    let app = Router::new()
        .route("/robots.txt", get(robots))
        .route(
            "/",
            get({
                // each route gets its own clone of the store to move into the handler
                let store = store.clone();
                move |accept| async {
                    spawn_blocking(|| hello(accept, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            "/xrpc/blue.microcosm.links.getManyToManyCounts",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| get_many_to_many_counts(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            "/links/count",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| count_links(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            "/links/count/distinct-dids",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| count_distinct_dids(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            "/xrpc/blue.microcosm.links.getBacklinks",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| get_backlinks(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            "/links",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| get_links(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            "/links/distinct-dids",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| get_distinct_dids(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            // deprecated
            "/links/all/count",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| count_all_links(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        .route(
            "/links/all",
            get({
                let store = store.clone();
                move |accept, query| async {
                    spawn_blocking(|| explore_links(accept, query, store))
                        .await
                        .map_err(to500)?
                }
            }),
        )
        // cross-cutting layers: permissive CORS, per-request metric labels, request metrics
        .layer(tower_http::cors::CorsLayer::permissive())
        .layer(middleware::from_fn(add_lables))
        .layer(MetricLayer::default());

    let listener = TcpListener::bind(addr).await?;
    println!("api: listening at http://{:?}", listener.local_addr()?);
    axum::serve(listener, app)
        // the shutdown future resolves once the token is cancelled
        .with_graceful_shutdown(async move { stay_alive.cancelled().await })
        .await?;

    Ok(())
}
159
160async fn add_lables(request: Request, next: Next) -> Response {
161 let origin = request
162 .headers()
163 .get(header::ORIGIN)
164 .and_then(|o| o.to_str().map(|v| v.to_owned()).ok());
165 let user_agent = request.headers().get(header::USER_AGENT).and_then(|ua| {
166 ua.to_str()
167 .map(|v| {
168 if v.starts_with("Mozilla/") {
169 "Mozilla/...".into()
170 } else {
171 v.to_owned()
172 }
173 })
174 .ok()
175 });
176
177 let mut res = next.run(request).await;
178
179 let mut labels = Vec::new();
180 if let Some(o) = origin {
181 labels.push(metrics::Label::new("origin", o));
182 }
183 if let Some(ua) = user_agent {
184 labels.push(metrics::Label::new("user_agent", ua));
185 }
186 res.extensions_mut().insert(ExtraMetricLabels(labels));
187 res
188}
189
/// Serves the static `robots.txt` body, which asks all user agents not to
/// crawl `/links` or anything beneath `/links/`.
async fn robots() -> &'static str {
    // The `\` right after the opening quote is a line continuation, so the
    // body starts directly at `User-agent` with no leading newline.
    "\
User-agent: *
Disallow: /links
Disallow: /links/
 "
}
197
/// Payload for the `/` landing route, rendered through the `hello.html.j2`
/// template or serialized via serde.
#[derive(Template, Serialize, Deserialize)]
#[template(path = "hello.html.j2")]
struct HelloReponse {
    /// Static hint telling API clients how to reach the human-readable docs.
    help: &'static str,
    /// Whole days elapsed since the store's `started_at` timestamp, if it has one.
    days_indexed: Option<u64>,
    /// Storage statistics as returned by `LinkReader::get_stats`.
    stats: StorageStats,
}
205fn hello(
206 accept: ExtractAccept,
207 store: impl LinkReader,
208) -> Result<impl IntoResponse, http::StatusCode> {
209 let stats = store
210 .get_stats()
211 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
212 let days_indexed = stats
213 .started_at
214 .map(|c| (UNIX_EPOCH + Duration::from_micros(c)).elapsed())
215 .transpose()
216 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?
217 .map(|d| d.as_secs() / 86_400);
218 Ok(acceptable(accept, HelloReponse {
219 help: "open this URL in a web browser (or request with Accept: text/html) for information about this API.",
220 days_indexed,
221 stats,
222 }))
223}
224
/// Query parameters for `blue.microcosm.links.getManyToManyCounts`.
///
/// Parameter names are camelCase on the wire (`pathToOther`, `otherSubject`).
#[derive(Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GetManyToManyCountsQuery {
    /// the link target whose linking records are examined
    subject: String,
    /// link source in `collection:path` form; the handler splits it at the first `:`
    source: String,
    /// path to the secondary link in the linking record
    path_to_other: String,
    /// filter to linking records (join of the m2m) by these DIDs
    #[serde(default)]
    did: Vec<String>,
    /// filter to specific secondary records
    #[serde(default)]
    other_subject: Vec<String>,
    /// opaque cursor from a previous page; decoded as an `ApiKeyedCursor`
    cursor: Option<OpaqueApiCursor>,
    /// Set the max number of links to return per page of results
    #[serde(default = "get_default_cursor_limit")]
    limit: u64,
}
/// One row of a many-to-many counts page, built from a
/// `(subject, total, distinct)` tuple returned by the store.
#[derive(Serialize)]
struct OtherSubjectCount {
    /// the secondary ("other") subject this row counts for
    subject: String,
    // NOTE(review): presumably `total` is the number of linking records and
    // `distinct` the number of distinct linking DIDs -- confirm against
    // `LinkReader::get_many_to_many_counts`.
    total: u64,
    distinct: u64,
}
/// Response for `blue.microcosm.links.getManyToManyCounts`, rendered through
/// `get-many-to-many-counts.html.j2` or serialized via serde.
#[derive(Template, Serialize)]
#[template(path = "get-many-to-many-counts.html.j2")]
struct GetManyToManyCountsResponse {
    /// one entry per item in the store's page of results
    counts_by_other_subject: Vec<OtherSubjectCount>,
    /// cursor for the next page; `None` when the store reports no `next` key
    cursor: Option<OpaqueApiCursor>,
    /// the originating query; excluded from serialized output
    /// (presumably kept for the HTML template -- confirm in the template)
    #[serde(skip_serializing)]
    query: GetManyToManyCountsQuery,
}
257fn get_many_to_many_counts(
258 accept: ExtractAccept,
259 query: axum_extra::extract::Query<GetManyToManyCountsQuery>,
260 store: impl LinkReader,
261) -> Result<impl IntoResponse, http::StatusCode> {
262 let cursor_key = query
263 .cursor
264 .clone()
265 .map(|oc| ApiKeyedCursor::try_from(oc).map_err(|_| http::StatusCode::BAD_REQUEST))
266 .transpose()?
267 .map(|c| c.next);
268
269 let limit = query.limit;
270 if limit > DEFAULT_CURSOR_LIMIT_MAX {
271 return Err(http::StatusCode::BAD_REQUEST);
272 }
273
274 let filter_dids: HashSet<Did> = HashSet::from_iter(
275 query
276 .did
277 .iter()
278 .map(|d| d.trim())
279 .filter(|d| !d.is_empty())
280 .map(|d| Did(d.to_string())),
281 );
282
283 let filter_other_subjects: HashSet<String> = HashSet::from_iter(
284 query
285 .other_subject
286 .iter()
287 .map(|s| s.trim().to_string())
288 .filter(|s| !s.is_empty()),
289 );
290
291 let Some((collection, path)) = query.source.split_once(':') else {
292 return Err(http::StatusCode::BAD_REQUEST);
293 };
294 let path = format!(".{path}");
295
296 let path_to_other = format!(".{}", query.path_to_other);
297
298 let paged = store
299 .get_many_to_many_counts(
300 &query.subject,
301 collection,
302 &path,
303 &path_to_other,
304 limit,
305 cursor_key,
306 &filter_dids,
307 &filter_other_subjects,
308 )
309 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
310
311 let cursor = paged.next.map(|next| ApiKeyedCursor { next }.into());
312
313 let items = paged
314 .items
315 .into_iter()
316 .map(|(subject, total, distinct)| OtherSubjectCount {
317 subject,
318 total,
319 distinct,
320 })
321 .collect();
322
323 Ok(acceptable(
324 accept,
325 GetManyToManyCountsResponse {
326 counts_by_other_subject: items,
327 cursor,
328 query: (*query).clone(),
329 },
330 ))
331}
332
/// Query parameters for `/links/count`; all three are passed straight to
/// `LinkReader::get_count`.
#[derive(Clone, Deserialize)]
struct GetLinksCountQuery {
    /// the link target to count links to
    target: String,
    /// collection of the linking records
    collection: String,
    /// path of the link within the linking record
    path: String,
}
/// Response for `/links/count`, rendered through `links-count.html.j2` or
/// serialized via serde.
#[derive(Template, Serialize)]
#[template(path = "links-count.html.j2")]
struct GetLinksCountResponse {
    /// the count returned by `LinkReader::get_count`
    total: u64,
    /// the originating query; excluded from serialized output
    #[serde(skip_serializing)]
    query: GetLinksCountQuery,
}
346fn count_links(
347 accept: ExtractAccept,
348 query: Query<GetLinksCountQuery>,
349 store: impl LinkReader,
350) -> Result<impl IntoResponse, http::StatusCode> {
351 let total = store
352 .get_count(&query.target, &query.collection, &query.path)
353 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
354 Ok(acceptable(
355 accept,
356 GetLinksCountResponse {
357 total,
358 query: (*query).clone(),
359 },
360 ))
361}
362
/// Query parameters for `/links/count/distinct-dids`; all three are passed
/// straight to `LinkReader::get_distinct_did_count`.
#[derive(Clone, Deserialize)]
struct GetDidsCountQuery {
    /// the link target to count linking DIDs for
    target: String,
    /// collection of the linking records
    collection: String,
    /// path of the link within the linking record
    path: String,
}
/// Response for `/links/count/distinct-dids`, rendered through
/// `dids-count.html.j2` or serialized via serde.
#[derive(Template, Serialize)]
#[template(path = "dids-count.html.j2")]
struct GetDidsCountResponse {
    /// the count returned by `LinkReader::get_distinct_did_count`
    total: u64,
    /// the originating query; excluded from serialized output
    #[serde(skip_serializing)]
    query: GetDidsCountQuery,
}
376fn count_distinct_dids(
377 accept: ExtractAccept,
378 query: Query<GetDidsCountQuery>,
379 store: impl LinkReader,
380) -> Result<impl IntoResponse, http::StatusCode> {
381 let total = store
382 .get_distinct_did_count(&query.target, &query.collection, &query.path)
383 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
384 Ok(acceptable(
385 accept,
386 GetDidsCountResponse {
387 total,
388 query: (*query).clone(),
389 },
390 ))
391}
392
/// Query parameters for `blue.microcosm.links.getBacklinks`.
#[derive(Clone, Deserialize)]
struct GetBacklinksQuery {
    /// The link target
    ///
    /// can be an AT-URI, plain DID, or regular URI
    subject: String,
    /// Filter links only from this link source
    ///
    /// eg.: `app.bsky.feed.like:subject.uri`
    ///
    /// The handler splits this at the first `:` into collection and path.
    source: String,
    /// Opaque cursor from a previous page; decoded as an `ApiCursor`
    cursor: Option<OpaqueApiCursor>,
    /// Filter links only from these DIDs
    ///
    /// include multiple times to filter by multiple source DIDs
    #[serde(default)]
    did: Vec<String>,
    /// Set the max number of links to return per page of results
    #[serde(default = "get_default_cursor_limit")]
    limit: u64,
    // TODO: allow reverse (er, forward) order as well
}
/// Response for `blue.microcosm.links.getBacklinks`, rendered through
/// `get-backlinks.html.j2` or serialized via serde.
#[derive(Template, Serialize)]
#[template(path = "get-backlinks.html.j2")]
struct GetBacklinksResponse {
    /// `total` as reported by the store for this query
    total: u64,
    /// the page of linking records returned by the store
    records: Vec<RecordId>,
    /// cursor for the next page; `None` when the store reports no `next`
    cursor: Option<OpaqueApiCursor>,
    /// the originating query; excluded from serialized output
    #[serde(skip_serializing)]
    query: GetBacklinksQuery,
    /// collection half of the parsed `source`; excluded from serialized output
    #[serde(skip_serializing)]
    collection: String,
    /// path half of the parsed `source`, with the leading `.` the handler adds;
    /// excluded from serialized output
    #[serde(skip_serializing)]
    path: String,
}
427fn get_backlinks(
428 accept: ExtractAccept,
429 query: axum_extra::extract::Query<GetBacklinksQuery>, // supports multiple param occurrences
430 store: impl LinkReader,
431) -> Result<impl IntoResponse, http::StatusCode> {
432 let until = query
433 .cursor
434 .clone()
435 .map(|oc| ApiCursor::try_from(oc).map_err(|_| http::StatusCode::BAD_REQUEST))
436 .transpose()?
437 .map(|c| c.next);
438
439 let limit = query.limit;
440 if limit > DEFAULT_CURSOR_LIMIT_MAX {
441 return Err(http::StatusCode::BAD_REQUEST);
442 }
443
444 let filter_dids: HashSet<Did> = HashSet::from_iter(
445 query
446 .did
447 .iter()
448 .map(|d| d.trim())
449 .filter(|d| !d.is_empty())
450 .map(|d| Did(d.to_string())),
451 );
452
453 let Some((collection, path)) = query.source.split_once(':') else {
454 return Err(http::StatusCode::BAD_REQUEST);
455 };
456 let path = format!(".{path}");
457
458 let paged = store
459 .get_links(
460 &query.subject,
461 collection,
462 &path,
463 limit,
464 until,
465 &filter_dids,
466 )
467 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
468
469 let cursor = paged.next.map(|next| {
470 ApiCursor {
471 version: paged.version,
472 next,
473 }
474 .into()
475 });
476
477 Ok(acceptable(
478 accept,
479 GetBacklinksResponse {
480 total: paged.total,
481 records: paged.items,
482 cursor,
483 query: (*query).clone(),
484 collection: collection.to_string(),
485 path,
486 },
487 ))
488}
489
/// Query parameters for `/links`.
#[derive(Clone, Deserialize)]
struct GetLinkItemsQuery {
    /// The link target
    target: String,
    /// Collection of the linking records
    collection: String,
    /// Path of the link within the linking record
    path: String,
    /// Opaque cursor from a previous page; decoded as an `ApiCursor`
    cursor: Option<OpaqueApiCursor>,
    /// Filter links only from these DIDs
    ///
    /// include multiple times to filter by multiple source DIDs
    #[serde(default)]
    did: Vec<String>,
    /// [deprecated] Filter links only from these DIDs
    ///
    /// format: comma-separated sequence of DIDs
    ///
    /// errors: if `did` parameter is also present
    ///
    /// deprecated: use `did`, which can be repeated multiple times
    from_dids: Option<String>, // comma separated: gross
    /// Set the max number of links to return per page of results
    #[serde(default = "get_default_cursor_limit")]
    limit: u64,
    // TODO: allow reverse (er, forward) order as well
}
/// Response for `/links`, rendered through `links.html.j2` or serialized via
/// serde.
#[derive(Template, Serialize)]
#[template(path = "links.html.j2")]
struct GetLinkItemsResponse {
    // what does staleness mean?
    // - new links have appeared. would be nice to offer a `since` cursor to fetch these. and/or,
    // - links have been deleted. hmm.
    /// `total` as reported by the store for this query
    total: u64,
    /// the page of linking records returned by the store
    linking_records: Vec<RecordId>,
    /// cursor for the next page; `None` when the store reports no `next`
    cursor: Option<OpaqueApiCursor>,
    /// the originating query; excluded from serialized output
    #[serde(skip_serializing)]
    query: GetLinkItemsQuery,
}
525fn get_links(
526 accept: ExtractAccept,
527 query: axum_extra::extract::Query<GetLinkItemsQuery>, // supports multiple param occurrences
528 store: impl LinkReader,
529) -> Result<impl IntoResponse, http::StatusCode> {
530 let until = query
531 .cursor
532 .clone()
533 .map(|oc| ApiCursor::try_from(oc).map_err(|_| http::StatusCode::BAD_REQUEST))
534 .transpose()?
535 .map(|c| c.next);
536
537 let limit = query.limit;
538 if limit > DEFAULT_CURSOR_LIMIT_MAX {
539 return Err(http::StatusCode::BAD_REQUEST);
540 }
541
542 let mut filter_dids: HashSet<Did> = HashSet::from_iter(
543 query
544 .did
545 .iter()
546 .map(|d| d.trim())
547 .filter(|d| !d.is_empty())
548 .map(|d| Did(d.to_string())),
549 );
550
551 if let Some(comma_joined) = &query.from_dids {
552 if !filter_dids.is_empty() {
553 return Err(http::StatusCode::BAD_REQUEST);
554 }
555 for did in comma_joined.split(',') {
556 filter_dids.insert(Did(did.to_string()));
557 }
558 }
559
560 let paged = store
561 .get_links(
562 &query.target,
563 &query.collection,
564 &query.path,
565 limit,
566 until,
567 &filter_dids,
568 )
569 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
570
571 let cursor = paged.next.map(|next| {
572 ApiCursor {
573 version: paged.version,
574 next,
575 }
576 .into()
577 });
578
579 Ok(acceptable(
580 accept,
581 GetLinkItemsResponse {
582 total: paged.total,
583 linking_records: paged.items,
584 cursor,
585 query: (*query).clone(),
586 },
587 ))
588}
589
/// Query parameters for `/links/distinct-dids`.
#[derive(Clone, Deserialize)]
struct GetDidItemsQuery {
    /// the link target
    target: String,
    /// collection of the linking records
    collection: String,
    /// path of the link within the linking record
    path: String,
    /// opaque cursor from a previous page; decoded as an `ApiCursor`
    cursor: Option<OpaqueApiCursor>,
    /// page size; the handler falls back to `DEFAULT_CURSOR_LIMIT` when absent
    limit: Option<u64>,
    // TODO: allow reverse (er, forward) order as well
}
/// Response for `/links/distinct-dids`, rendered through `dids.html.j2` or
/// serialized via serde.
#[derive(Template, Serialize)]
#[template(path = "dids.html.j2")]
struct GetDidItemsResponse {
    // what does staleness mean?
    // - new links have appeared. would be nice to offer a `since` cursor to fetch these. and/or,
    // - links have been deleted. hmm.
    /// `total` as reported by the store for this query
    total: u64,
    /// the page of linking DIDs returned by the store
    linking_dids: Vec<Did>,
    /// cursor for the next page; `None` when the store reports no `next`
    cursor: Option<OpaqueApiCursor>,
    /// the originating query; excluded from serialized output
    #[serde(skip_serializing)]
    query: GetDidItemsQuery,
}
611fn get_distinct_dids(
612 accept: ExtractAccept,
613 query: Query<GetDidItemsQuery>,
614 store: impl LinkReader,
615) -> Result<impl IntoResponse, http::StatusCode> {
616 let until = query
617 .cursor
618 .clone()
619 .map(|oc| ApiCursor::try_from(oc).map_err(|_| http::StatusCode::BAD_REQUEST))
620 .transpose()?
621 .map(|c| c.next);
622
623 let limit = query.limit.unwrap_or(DEFAULT_CURSOR_LIMIT);
624 if limit > DEFAULT_CURSOR_LIMIT_MAX {
625 return Err(http::StatusCode::BAD_REQUEST);
626 }
627
628 let paged = store
629 .get_distinct_dids(&query.target, &query.collection, &query.path, limit, until)
630 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
631
632 let cursor = paged.next.map(|next| {
633 ApiCursor {
634 version: paged.version,
635 next,
636 }
637 .into()
638 });
639
640 Ok(acceptable(
641 accept,
642 GetDidItemsResponse {
643 total: paged.total,
644 linking_dids: paged.items,
645 cursor,
646 query: (*query).clone(),
647 },
648 ))
649}
650
/// Query parameters for the deprecated `/links/all/count` route.
#[derive(Clone, Deserialize)]
struct GetAllLinksQuery {
    /// the link target, passed to `LinkReader::get_all_record_counts`
    target: String,
}
/// Response for the deprecated `/links/all/count` route, rendered through
/// `links-all-count.html.j2` or serialized via serde.
#[derive(Template, Serialize)]
#[template(path = "links-all-count.html.j2")]
struct GetAllLinksResponse {
    /// nested counts exactly as returned by `LinkReader::get_all_record_counts`
    // NOTE(review): presumably keyed collection -> path -> count; confirm against the store.
    links: HashMap<String, HashMap<String, u64>>,
    /// the originating query; excluded from serialized output
    #[serde(skip_serializing)]
    query: GetAllLinksQuery,
}
662fn count_all_links(
663 accept: ExtractAccept,
664 query: Query<GetAllLinksQuery>,
665 store: impl LinkReader,
666) -> Result<impl IntoResponse, http::StatusCode> {
667 let links = store
668 .get_all_record_counts(&query.target)
669 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
670 Ok(acceptable(
671 accept,
672 GetAllLinksResponse {
673 links,
674 query: (*query).clone(),
675 },
676 ))
677}
678
/// Query parameters for `/links/all`.
#[derive(Clone, Deserialize)]
struct ExploreLinksQuery {
    /// the link target, passed to `LinkReader::get_all_counts`
    target: String,
}
/// Response for `/links/all`, rendered through `explore-links.html.j2` or
/// serialized via serde.
#[derive(Template, Serialize)]
#[template(path = "explore-links.html.j2")]
struct ExploreLinksResponse {
    /// nested counts exactly as returned by `LinkReader::get_all_counts`
    // NOTE(review): presumably keyed collection -> path -> counts; confirm against the store.
    links: HashMap<String, HashMap<String, CountsByCount>>,
    /// the originating query; excluded from serialized output
    #[serde(skip_serializing)]
    query: ExploreLinksQuery,
}
690fn explore_links(
691 accept: ExtractAccept,
692 query: Query<ExploreLinksQuery>,
693 store: impl LinkReader,
694) -> Result<impl IntoResponse, http::StatusCode> {
695 let links = store
696 .get_all_counts(&query.target)
697 .map_err(|_| http::StatusCode::INTERNAL_SERVER_ERROR)?;
698 Ok(acceptable(
699 accept,
700 ExploreLinksResponse {
701 links,
702 query: (*query).clone(),
703 },
704 ))
705}
706
/// Pagination cursor as exposed to API clients: raw bytes that serde reads
/// and writes as a hex string.
///
/// The bytes are the bincode encoding of either an `ApiCursor` or an
/// `ApiKeyedCursor`, depending on the endpoint that issued the cursor.
#[serde_as]
#[derive(Clone, Serialize, Deserialize)] // for json
struct OpaqueApiCursor(#[serde_as(as = "serde_with::hex::Hex")] Vec<u8>);
710
/// Decoded form of the cursor used by the link and DID listing endpoints.
///
/// Encoded with bincode, so changing the field order or types changes the
/// byte layout of cursors.
#[derive(Serialize, Deserialize)] // for bincode
struct ApiCursor {
    version: (u64, u64), // (collection length, deleted item count)
    /// position handed back to the store as `until` on the next request
    next: u64,
}
716
717impl TryFrom<OpaqueApiCursor> for ApiCursor {
718 type Error = bincode::Error;
719
720 fn try_from(item: OpaqueApiCursor) -> Result<Self, Self::Error> {
721 bincode::DefaultOptions::new().deserialize(&item.0)
722 }
723}
724
725impl From<ApiCursor> for OpaqueApiCursor {
726 fn from(item: ApiCursor) -> Self {
727 OpaqueApiCursor(bincode::DefaultOptions::new().serialize(&item).unwrap())
728 }
729}
730
/// Decoded form of the cursor used by the many-to-many counts endpoint,
/// which pages by key rather than by numeric position.
#[derive(Serialize, Deserialize)] // for bincode
struct ApiKeyedCursor {
    next: String, // the key
}
735
736impl TryFrom<OpaqueApiCursor> for ApiKeyedCursor {
737 type Error = bincode::Error;
738
739 fn try_from(item: OpaqueApiCursor) -> Result<Self, Self::Error> {
740 bincode::DefaultOptions::new().deserialize(&item.0)
741 }
742}
743
744impl From<ApiKeyedCursor> for OpaqueApiCursor {
745 fn from(item: ApiKeyedCursor) -> Self {
746 OpaqueApiCursor(bincode::DefaultOptions::new().serialize(&item).unwrap())
747 }
748}