Tracks lexicons and how many times they appeared on the jetstream.
1use std::{
2 collections::HashMap,
3 fmt::Display,
4 net::SocketAddr,
5 ops::Deref,
6 sync::Arc,
7 time::{Duration, UNIX_EPOCH},
8};
9
10use anyhow::anyhow;
11use axum::{
12 Json, Router,
13 extract::{Query, State},
14 http::Request,
15 response::Response,
16 routing::get,
17};
18use axum_tws::{Message, WebSocketUpgrade};
19use serde::{Deserialize, Serialize};
20use smol_str::SmolStr;
21use tokio_util::sync::CancellationToken;
22use tower_http::{
23 classify::ServerErrorsFailureClass,
24 compression::CompressionLayer,
25 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
26 trace::TraceLayer,
27};
28use tracing::{Instrument, Span, field};
29
30use crate::{
31 db::Db,
32 error::{AppError, AppResult},
33 utils::time_now,
34};
35
/// A latency expressed in whole milliseconds, displayed with an `ms` suffix.
struct LatencyMillis(u128);

impl From<Duration> for LatencyMillis {
    /// Wraps the value of `Duration::as_millis`, i.e. the whole number of
    /// milliseconds contained in `elapsed`.
    fn from(elapsed: Duration) -> Self {
        let millis = elapsed.as_millis();
        Self(millis)
    }
}

impl Display for LatencyMillis {
    /// Writes the millisecond count immediately followed by `ms` (e.g. `42ms`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(millis) = self;
        write!(f, "{millis}ms")
    }
}
49
50pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> {
51 let app = Router::new()
52 .route("/events", get(events))
53 .route("/stream_events", get(stream_events))
54 .route("/hits", get(hits))
55 .route("/since", get(since))
56 .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true))
57 .route_layer(PropagateRequestIdLayer::x_request_id())
58 .route_layer(
59 TraceLayer::new_for_http()
60 .make_span_with(|request: &Request<_>| {
61 let span = tracing::info_span!(
62 "request",
63 method = %request.method(),
64 uri = %request.uri(),
65 id = field::Empty,
66 ip = field::Empty,
67 );
68 if let Some(id) = request.headers().get("x-request-id") {
69 span.record("id", String::from_utf8_lossy(id.as_bytes()).deref());
70 }
71 if let Some(real_ip) = request.headers().get("x-real-ip") {
72 span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref());
73 }
74 span
75 })
76 .on_request(|_request: &Request<_>, span: &Span| {
77 let _ = span.enter();
78 tracing::info!("processing")
79 })
80 .on_response(|response: &Response<_>, latency: Duration, span: &Span| {
81 let _ = span.enter();
82 tracing::info!({code = %response.status().as_u16(), latency = %LatencyMillis::from(latency)}, "processed")
83 })
84 .on_eos(())
85 .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| {
86 let _ = span.enter();
87 if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error()) || matches!(error, ServerErrorsFailureClass::Error(_)) {
88 tracing::error!("server error: {}", error.to_string().to_lowercase());
89 };
90 }),
91 )
92 .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
93 .with_state(db);
94
95 let addr = SocketAddr::from((
96 [0, 0, 0, 0],
97 std::env::var("PORT")
98 .ok()
99 .and_then(|s| s.parse::<u16>().ok())
100 .unwrap_or(3713),
101 ));
102 let listener = tokio::net::TcpListener::bind(addr).await?;
103
104 tracing::info!("starting serve on {addr}");
105 tokio::select! {
106 res = axum::serve(listener, app) => res.map_err(AppError::from),
107 _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()),
108 }
109}
110
111#[derive(Serialize)]
112struct NsidCount {
113 count: u128,
114 deleted_count: u128,
115 last_seen: u64,
116}
117
118#[derive(Serialize)]
119struct Events {
120 per_second: usize,
121 events: HashMap<SmolStr, NsidCount>,
122}
123
124async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> {
125 let mut events = HashMap::new();
126 for result in db.get_counts() {
127 let (nsid, counts) = result?;
128 events.insert(
129 nsid,
130 NsidCount {
131 count: counts.count,
132 deleted_count: counts.deleted_count,
133 last_seen: counts.last_seen,
134 },
135 );
136 }
137 Ok(Json(Events {
138 events,
139 per_second: db.eps(),
140 }))
141}
142
143#[derive(Debug, Deserialize)]
144struct HitsQuery {
145 nsid: SmolStr,
146 from: Option<u64>,
147 to: Option<u64>,
148}
149
150#[derive(Serialize)]
151struct Hit {
152 timestamp: u64,
153 deleted: bool,
154}
155
/// Upper bound on the number of hits a single `/hits` response may contain.
const MAX_HITS: usize = 100_000;
157
158async fn hits(
159 State(db): State<Arc<Db>>,
160 Query(params): Query<HitsQuery>,
161) -> AppResult<Json<Vec<Hit>>> {
162 let maybe_hits = db
163 .get_hits(
164 ¶ms.nsid,
165 params.to.unwrap_or(0)..params.from.unwrap_or(time_now()),
166 )
167 .take(MAX_HITS);
168 let mut hits = Vec::with_capacity(maybe_hits.size_hint().0);
169
170 for maybe_hit in maybe_hits {
171 let hit = maybe_hit?;
172 let hit_data = hit.access();
173
174 hits.push(Hit {
175 timestamp: hit.timestamp,
176 deleted: hit_data.deleted,
177 });
178 }
179
180 Ok(Json(hits))
181}
182
183async fn stream_events(db: State<Arc<Db>>, ws: WebSocketUpgrade) -> Response {
184 let span = tracing::info_span!(parent: Span::current(), "ws");
185 ws.on_upgrade(move |mut socket| {
186 (async move {
187 let mut listener = db.new_listener();
188 let mut data = Events {
189 events: HashMap::<SmolStr, NsidCount>::with_capacity(10),
190 per_second: 0,
191 };
192 let mut updates = 0;
193 while let Ok((nsid, counts)) = listener.recv().await {
194 data.events.insert(
195 nsid,
196 NsidCount {
197 count: counts.count,
198 deleted_count: counts.deleted_count,
199 last_seen: counts.last_seen,
200 },
201 );
202 updates += 1;
203 // send 20 times every second max
204 data.per_second = db.eps();
205 if updates >= data.per_second / 16 {
206 let msg = serde_json::to_string(&data).unwrap();
207 let res = socket.send(Message::text(msg)).await;
208 data.events.clear();
209 updates = 0;
210 if let Err(err) = res {
211 tracing::error!("error sending event: {err}");
212 break;
213 }
214 }
215 }
216 })
217 .instrument(span)
218 })
219}
220
221#[derive(Debug, Serialize)]
222struct Since {
223 since: u64,
224}
225
226async fn since(db: State<Arc<Db>>) -> AppResult<Json<Since>> {
227 Ok(Json(Since {
228 since: db.tracking_since()?,
229 }))
230}