Tracks lexicons and how many times they appear on the Jetstream.
1use std::{
2 collections::HashMap,
3 fmt::Display,
4 net::SocketAddr,
5 ops::{Bound, Deref, RangeBounds},
6 time::Duration,
7};
8
9use anyhow::anyhow;
10use axum::{
11 Json, Router,
12 extract::{Query, State},
13 http::Request,
14 response::Response,
15 routing::get,
16};
17use axum_tws::{Message, WebSocketUpgrade};
18use rclite::Arc;
19use serde::{Deserialize, Serialize};
20use smol_str::SmolStr;
21use tokio_util::sync::CancellationToken;
22use tower_http::{
23 classify::ServerErrorsFailureClass,
24 compression::CompressionLayer,
25 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
26 trace::TraceLayer,
27};
28use tracing::{Instrument, Span, field};
29
30use crate::{
31 db::Db,
32 error::{AppError, AppResult},
33};
34
/// Request latency expressed in whole milliseconds, used as a log field.
///
/// Converting from a [`Duration`] truncates any sub-millisecond remainder.
struct LatencyMillis(u128);

impl From<Duration> for LatencyMillis {
    fn from(elapsed: Duration) -> Self {
        Self(elapsed.as_millis())
    }
}

impl Display for LatencyMillis {
    /// Renders the value with an `ms` suffix, e.g. `42ms`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(millis) = self;
        write!(f, "{millis}ms")
    }
}
48
49pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> {
50 let app = Router::new()
51 .route("/events", get(events))
52 .route("/stream_events", get(stream_events))
53 .route("/hits", get(hits))
54 .route("/since", get(since))
55 .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true))
56 .route_layer(PropagateRequestIdLayer::x_request_id())
57 .route_layer(
58 TraceLayer::new_for_http()
59 .make_span_with(|request: &Request<_>| {
60 let span = tracing::info_span!(
61 "request",
62 method = %request.method(),
63 uri = %request.uri(),
64 id = field::Empty,
65 ip = field::Empty,
66 );
67 if let Some(id) = request.headers().get("x-request-id") {
68 span.record("id", String::from_utf8_lossy(id.as_bytes()).deref());
69 }
70 if let Some(real_ip) = request.headers().get("x-real-ip") {
71 span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref());
72 }
73 span
74 })
75 .on_request(|_request: &Request<_>, span: &Span| {
76 let _ = span.enter();
77 tracing::info!("processing")
78 })
79 .on_response(|response: &Response<_>, latency: Duration, span: &Span| {
80 let _ = span.enter();
81 tracing::info!({code = %response.status().as_u16(), latency = %LatencyMillis::from(latency)}, "processed")
82 })
83 .on_eos(())
84 .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| {
85 let _ = span.enter();
86 if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error()) || matches!(error, ServerErrorsFailureClass::Error(_)) {
87 tracing::error!("server error: {}", error.to_string().to_lowercase());
88 };
89 }),
90 )
91 .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
92 .with_state(db);
93
94 let addr = SocketAddr::from((
95 [0, 0, 0, 0],
96 std::env::var("PORT")
97 .ok()
98 .and_then(|s| s.parse::<u16>().ok())
99 .unwrap_or(3713),
100 ));
101 let listener = tokio::net::TcpListener::bind(addr).await?;
102
103 tracing::info!("starting serve on {addr}");
104 tokio::select! {
105 res = axum::serve(listener, app) => res.map_err(AppError::from),
106 _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()),
107 }
108}
109
110#[derive(Serialize)]
111struct NsidCount {
112 count: u128,
113 deleted_count: u128,
114 last_seen: u64,
115}
116
117#[derive(Serialize)]
118struct Events {
119 per_second: usize,
120 events: HashMap<SmolStr, NsidCount>,
121}
122
123async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> {
124 let mut events = HashMap::new();
125 for result in db.get_counts() {
126 let (nsid, counts) = result?;
127 events.insert(
128 nsid,
129 NsidCount {
130 count: counts.count,
131 deleted_count: counts.deleted_count,
132 last_seen: counts.last_seen,
133 },
134 );
135 }
136 Ok(Json(Events {
137 events,
138 per_second: db.eps(),
139 }))
140}
141
142#[derive(Debug, Deserialize)]
143struct HitsQuery {
144 nsid: SmolStr,
145 from: Option<u64>,
146 to: Option<u64>,
147}
148
149#[derive(Debug, Serialize)]
150struct Hit {
151 timestamp: u64,
152 deleted: bool,
153}
154
155const MAX_HITS: usize = 100_000;
156
/// Timestamp window used to query hits, with an independent bound on each side.
///
/// Implements [`RangeBounds<u64>`] so it can be handed to range-based lookups.
#[derive(Debug)]
struct HitsRange {
    /// Lower end of the window.
    from: Bound<u64>,
    /// Upper end of the window.
    to: Bound<u64>,
}

impl HitsRange {
    /// Borrows the endpoint held by `bound` while keeping its inclusive/exclusive kind.
    fn borrow_bound(bound: &Bound<u64>) -> Bound<&u64> {
        match bound {
            Bound::Included(value) => Bound::Included(value),
            Bound::Excluded(value) => Bound::Excluded(value),
            Bound::Unbounded => Bound::Unbounded,
        }
    }
}

impl RangeBounds<u64> for HitsRange {
    fn start_bound(&self) -> Bound<&u64> {
        Self::borrow_bound(&self.from)
    }

    fn end_bound(&self) -> Bound<&u64> {
        Self::borrow_bound(&self.to)
    }
}
172
173async fn hits(
174 State(db): State<Arc<Db>>,
175 Query(params): Query<HitsQuery>,
176) -> AppResult<Json<Vec<Hit>>> {
177 let from = params.to.map(Bound::Included).unwrap_or(Bound::Unbounded);
178 let to = params.from.map(Bound::Included).unwrap_or(Bound::Unbounded);
179 let maybe_hits = db
180 .get_hits(¶ms.nsid, HitsRange { from, to })
181 .take(MAX_HITS);
182 let mut hits = Vec::with_capacity(maybe_hits.size_hint().0);
183
184 for maybe_hit in maybe_hits {
185 let hit = maybe_hit?;
186 let hit_data = hit.deser()?;
187
188 hits.push(Hit {
189 timestamp: hit.timestamp,
190 deleted: hit_data.deleted,
191 });
192 }
193
194 Ok(Json(hits))
195}
196
197async fn stream_events(db: State<Arc<Db>>, ws: WebSocketUpgrade) -> Response {
198 let span = tracing::info_span!(parent: Span::current(), "ws");
199 ws.on_upgrade(move |mut socket| {
200 (async move {
201 let mut listener = db.new_listener();
202 let mut data = Events {
203 events: HashMap::<SmolStr, NsidCount>::with_capacity(10),
204 per_second: 0,
205 };
206 let mut updates = 0;
207 while let Ok((nsid, counts)) = listener.recv().await {
208 data.events.insert(
209 nsid,
210 NsidCount {
211 count: counts.count,
212 deleted_count: counts.deleted_count,
213 last_seen: counts.last_seen,
214 },
215 );
216 updates += 1;
217 // send 20 times every second max
218 data.per_second = db.eps();
219 if updates >= data.per_second / 16 {
220 let msg = serde_json::to_string(&data).unwrap();
221 let res = socket.send(Message::text(msg)).await;
222 data.events.clear();
223 updates = 0;
224 if let Err(err) = res {
225 tracing::error!("error sending event: {err}");
226 break;
227 }
228 }
229 }
230 })
231 .instrument(span)
232 })
233}
234
235#[derive(Debug, Serialize)]
236struct Since {
237 since: u64,
238}
239
240async fn since(db: State<Arc<Db>>) -> AppResult<Json<Since>> {
241 Ok(Json(Since {
242 since: db.tracking_since()?,
243 }))
244}