Tracks lexicons and how many times they appeared on the jetstream.
1use std::{
2 fmt::Display,
3 net::SocketAddr,
4 ops::{Bound, Deref, RangeBounds},
5 time::Duration,
6};
7
8use ahash::AHashMap;
9use anyhow::anyhow;
10use axum::{
11 Json, Router,
12 extract::{Query, State},
13 http::Request,
14 response::Response,
15 routing::get,
16};
17use axum_tws::{Message, WebSocketUpgrade};
18use rclite::Arc;
19use serde::{Deserialize, Serialize};
20use smol_str::SmolStr;
21use tokio_util::sync::CancellationToken;
22use tower_http::{
23 classify::ServerErrorsFailureClass,
24 compression::CompressionLayer,
25 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
26 trace::TraceLayer,
27};
28use tracing::{Instrument, Span, field};
29
30use crate::{
31 db::Db,
32 error::{AppError, AppResult},
33};
34
/// A latency measured in whole milliseconds, rendered as `<n>ms`.
struct LatencyMillis(u128);

impl From<Duration> for LatencyMillis {
    /// Wraps the duration's whole-millisecond count (`Duration::as_millis`).
    fn from(elapsed: Duration) -> Self {
        Self(elapsed.as_millis())
    }
}

impl Display for LatencyMillis {
    /// Writes the millisecond count followed by the literal suffix `ms`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(millis) = self;
        write!(f, "{millis}ms")
    }
}
48
49pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> {
50 let app = Router::new()
51 .route("/events", get(events))
52 .route("/stream_events", get(stream_events))
53 .route("/hits", get(hits))
54 .route("/since", get(since))
55 .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true))
56 .route_layer(PropagateRequestIdLayer::x_request_id())
57 .route_layer(
58 TraceLayer::new_for_http()
59 .make_span_with(|request: &Request<_>| {
60 let span = tracing::info_span!(
61 "request",
62 method = %request.method(),
63 uri = %request.uri(),
64 id = field::Empty,
65 ip = field::Empty,
66 );
67 if let Some(id) = request.headers().get("x-request-id") {
68 span.record("id", String::from_utf8_lossy(id.as_bytes()).deref());
69 }
70 if let Some(real_ip) = request.headers().get("x-real-ip") {
71 span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref());
72 }
73 span
74 })
75 .on_request(|_request: &Request<_>, span: &Span| {
76 let _ = span.enter();
77 tracing::info!("processing")
78 })
79 .on_response(|response: &Response<_>, latency: Duration, span: &Span| {
80 let _ = span.enter();
81 tracing::info!({code = %response.status().as_u16(), latency = %LatencyMillis::from(latency)}, "processed")
82 })
83 .on_eos(())
84 .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| {
85 let _ = span.enter();
86 if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error()) || matches!(error, ServerErrorsFailureClass::Error(_)) {
87 tracing::error!("server error: {}", error.to_string().to_lowercase());
88 };
89 }),
90 )
91 .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
92 .with_state(db);
93
94 let addr = SocketAddr::from((
95 [0, 0, 0, 0],
96 std::env::var("PORT")
97 .ok()
98 .and_then(|s| s.parse::<u16>().ok())
99 .unwrap_or(3713),
100 ));
101 let listener = tokio::net::TcpListener::bind(addr).await?;
102
103 tracing::info!("starting serve on {addr}");
104 tokio::select! {
105 res = axum::serve(listener, app) => res.map_err(AppError::from),
106 _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()),
107 }
108}
109
110#[derive(Serialize)]
111struct NsidCount {
112 count: u128,
113 deleted_count: u128,
114 last_seen: u64,
115}
116
117#[derive(Serialize)]
118struct Events {
119 per_second: usize,
120 events: AHashMap<SmolStr, NsidCount>,
121}
122
123async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> {
124 let mut events = AHashMap::new();
125 for result in db.get_counts() {
126 let (nsid, counts) = result?;
127 events.insert(
128 nsid,
129 NsidCount {
130 count: counts.count,
131 deleted_count: counts.deleted_count,
132 last_seen: counts.last_seen,
133 },
134 );
135 }
136 Ok(Json(Events {
137 events,
138 per_second: db.eps(),
139 }))
140}
141
142#[derive(Debug, Deserialize)]
143struct HitsQuery {
144 nsid: SmolStr,
145 from: Option<u64>,
146 to: Option<u64>,
147}
148
149#[derive(Debug, Serialize)]
150struct Hit {
151 timestamp: u64,
152 deleted: bool,
153}
154
/// Upper limit on the number of hits returned by a single `/hits` request.
const MAX_HITS: usize = 100_000;
156
/// A `u64` range with independently chosen lower and upper bounds, usable
/// anywhere a `RangeBounds<u64>` is expected.
#[derive(Debug)]
struct HitsRange {
    from: Bound<u64>,
    to: Bound<u64>,
}

impl RangeBounds<u64> for HitsRange {
    /// Lower bound (`from`), borrowed.
    fn start_bound(&self) -> Bound<&u64> {
        match &self.from {
            Bound::Included(value) => Bound::Included(value),
            Bound::Excluded(value) => Bound::Excluded(value),
            Bound::Unbounded => Bound::Unbounded,
        }
    }

    /// Upper bound (`to`), borrowed.
    fn end_bound(&self) -> Bound<&u64> {
        match &self.to {
            Bound::Included(value) => Bound::Included(value),
            Bound::Excluded(value) => Bound::Excluded(value),
            Bound::Unbounded => Bound::Unbounded,
        }
    }
}
172
173async fn hits(
174 State(db): State<Arc<Db>>,
175 Query(params): Query<HitsQuery>,
176) -> AppResult<Json<Vec<Hit>>> {
177 let from = params.to.map(Bound::Included).unwrap_or(Bound::Unbounded);
178 let to = params.from.map(Bound::Included).unwrap_or(Bound::Unbounded);
179
180 db.get_hits(¶ms.nsid, HitsRange { from, to }, MAX_HITS)
181 .take(MAX_HITS)
182 .try_fold(Vec::with_capacity(MAX_HITS), |mut acc, hit| {
183 let hit = hit?;
184 let hit_data = hit.deser()?;
185
186 acc.push(Hit {
187 timestamp: hit.timestamp,
188 deleted: hit_data.deleted,
189 });
190 Ok(acc)
191 })
192 .map(Json)
193}
194
195async fn stream_events(db: State<Arc<Db>>, ws: WebSocketUpgrade) -> Response {
196 let span = tracing::info_span!(parent: Span::current(), "ws");
197 ws.on_upgrade(move |mut socket| {
198 (async move {
199 let mut listener = db.new_listener();
200 let mut data = Events {
201 events: AHashMap::<SmolStr, NsidCount>::with_capacity(10),
202 per_second: 0,
203 };
204 let mut updates = 0;
205 while let Ok((nsid, counts)) = listener.recv().await {
206 data.events.insert(
207 nsid,
208 NsidCount {
209 count: counts.count,
210 deleted_count: counts.deleted_count,
211 last_seen: counts.last_seen,
212 },
213 );
214 updates += 1;
215 // send 20 times every second max
216 data.per_second = db.eps();
217 if updates >= data.per_second / 16 {
218 let msg = serde_json::to_string(&data).unwrap();
219 let res = socket.send(Message::text(msg)).await;
220 data.events.clear();
221 updates = 0;
222 if let Err(err) = res {
223 tracing::error!("error sending event: {err}");
224 break;
225 }
226 }
227 }
228 })
229 .instrument(span)
230 })
231}
232
233#[derive(Debug, Serialize)]
234struct Since {
235 since: u64,
236}
237
238async fn since(db: State<Arc<Db>>) -> AppResult<Json<Since>> {
239 Ok(Json(Since {
240 since: db.tracking_since()?,
241 }))
242}