Tracks lexicons and how many times each one has appeared on the Jetstream.
1use std::{
2 collections::HashMap,
3 fmt::Display,
4 net::SocketAddr,
5 ops::Deref,
6 sync::Arc,
7 time::{Duration, UNIX_EPOCH},
8};
9
10use anyhow::anyhow;
11use axum::{
12 Json, Router,
13 extract::{Query, State},
14 http::Request,
15 response::Response,
16 routing::get,
17};
18use axum_tws::{Message, WebSocketUpgrade};
19use serde::{Deserialize, Serialize};
20use smol_str::SmolStr;
21use tokio_util::sync::CancellationToken;
22use tower_http::{
23 classify::ServerErrorsFailureClass,
24 compression::CompressionLayer,
25 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
26 trace::TraceLayer,
27};
28use tracing::{Instrument, Span, field};
29
30use crate::{
31 db::Db,
32 error::{AppError, AppResult},
33};
34
/// A request latency truncated to whole milliseconds; displays as e.g. `42ms`.
struct LatencyMillis(u128);

impl From<Duration> for LatencyMillis {
    /// Keeps only the whole-millisecond part of `elapsed`.
    fn from(elapsed: Duration) -> Self {
        let millis = elapsed.as_millis();
        Self(millis)
    }
}

impl Display for LatencyMillis {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(millis) = self;
        write!(formatter, "{millis}ms")
    }
}
48
49pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> {
50 let app = Router::new()
51 .route("/events", get(events))
52 .route("/stream_events", get(stream_events))
53 .route("/hits", get(hits))
54 .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true))
55 .route_layer(PropagateRequestIdLayer::x_request_id())
56 .route_layer(
57 TraceLayer::new_for_http()
58 .make_span_with(|request: &Request<_>| {
59 let span = tracing::info_span!(
60 "request",
61 method = %request.method(),
62 uri = %request.uri(),
63 id = field::Empty,
64 ip = field::Empty,
65 );
66 if let Some(id) = request.headers().get("x-request-id") {
67 span.record("id", String::from_utf8_lossy(id.as_bytes()).deref());
68 }
69 if let Some(real_ip) = request.headers().get("x-real-ip") {
70 span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref());
71 }
72 span
73 })
74 .on_request(|_request: &Request<_>, span: &Span| {
75 let _ = span.enter();
76 tracing::info!("processing")
77 })
78 .on_response(|response: &Response<_>, latency: Duration, span: &Span| {
79 let _ = span.enter();
80 tracing::info!({code = %response.status().as_u16(), latency = %LatencyMillis::from(latency)}, "processed")
81 })
82 .on_eos(())
83 .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| {
84 let _ = span.enter();
85 if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error()) || matches!(error, ServerErrorsFailureClass::Error(_)) {
86 tracing::error!("server error: {}", error.to_string().to_lowercase());
87 };
88 }),
89 )
90 .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
91 .with_state(db);
92
93 let addr = SocketAddr::from((
94 [0, 0, 0, 0],
95 std::env::var("PORT")
96 .ok()
97 .and_then(|s| s.parse::<u16>().ok())
98 .unwrap_or(3713),
99 ));
100 let listener = tokio::net::TcpListener::bind(addr).await?;
101
102 tracing::info!("starting serve on {addr}");
103 tokio::select! {
104 res = axum::serve(listener, app) => res.map_err(AppError::from),
105 _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()),
106 }
107}
108
109#[derive(Serialize)]
110struct NsidCount {
111 count: u128,
112 deleted_count: u128,
113 last_seen: u64,
114}
115
116#[derive(Serialize)]
117struct Events {
118 per_second: usize,
119 events: HashMap<SmolStr, NsidCount>,
120}
121
122async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> {
123 let mut events = HashMap::new();
124 for result in db.get_counts() {
125 let (nsid, counts) = result?;
126 events.insert(
127 nsid,
128 NsidCount {
129 count: counts.count,
130 deleted_count: counts.deleted_count,
131 last_seen: counts.last_seen,
132 },
133 );
134 }
135 Ok(Json(Events {
136 events,
137 per_second: db.eps(),
138 }))
139}
140
141#[derive(Debug, Deserialize)]
142struct HitsQuery {
143 nsid: SmolStr,
144 from: Option<u64>,
145 to: Option<u64>,
146}
147
148#[derive(Serialize)]
149struct Hit {
150 timestamp: u64,
151 deleted: bool,
152}
153
/// Upper bound on the number of hits a single `GET /hits` response contains.
const MAX_HITS: usize = 100 * 1_000;
155
156async fn hits(
157 State(db): State<Arc<Db>>,
158 Query(params): Query<HitsQuery>,
159) -> AppResult<Json<Vec<Hit>>> {
160 let maybe_hits = db
161 .get_hits(
162 ¶ms.nsid,
163 params.to.unwrap_or(0)
164 ..params.from.unwrap_or(
165 std::time::SystemTime::now()
166 .duration_since(UNIX_EPOCH)
167 .expect("oops")
168 .as_micros() as u64,
169 ),
170 )?
171 .take(MAX_HITS);
172 let mut hits = Vec::with_capacity(maybe_hits.size_hint().0);
173
174 for maybe_hit in maybe_hits {
175 let (timestamp, hit) = maybe_hit?;
176 hits.push(Hit {
177 timestamp,
178 deleted: hit.deleted,
179 });
180 }
181
182 Ok(Json(hits))
183}
184
185async fn stream_events(db: State<Arc<Db>>, ws: WebSocketUpgrade) -> Response {
186 let span = tracing::info_span!(parent: Span::current(), "ws");
187 ws.on_upgrade(move |mut socket| {
188 (async move {
189 let mut listener = db.new_listener();
190 let mut data = Events {
191 events: HashMap::<SmolStr, NsidCount>::with_capacity(10),
192 per_second: 0,
193 };
194 let mut updates = 0;
195 while let Ok((nsid, counts)) = listener.recv().await {
196 data.events.insert(
197 nsid,
198 NsidCount {
199 count: counts.count,
200 deleted_count: counts.deleted_count,
201 last_seen: counts.last_seen,
202 },
203 );
204 updates += 1;
205 // send 20 times every second max
206 data.per_second = db.eps();
207 if updates >= data.per_second / 16 {
208 let msg = serde_json::to_string(&data).unwrap();
209 let res = socket.send(Message::text(msg)).await;
210 data.events.clear();
211 updates = 0;
212 if let Err(err) = res {
213 tracing::error!("error sending event: {err}");
214 break;
215 }
216 }
217 }
218 })
219 .instrument(span)
220 })
221}