Tracks lexicons and how many times each one has appeared on the Jetstream.
1use std::{ 2 collections::HashMap, 3 fmt::Display, 4 net::SocketAddr, 5 ops::Deref, 6 sync::Arc, 7 time::{Duration, UNIX_EPOCH}, 8}; 9 10use anyhow::anyhow; 11use axum::{ 12 Json, Router, 13 extract::{Query, State}, 14 http::Request, 15 response::Response, 16 routing::get, 17}; 18use axum_tws::{Message, WebSocketUpgrade}; 19use serde::{Deserialize, Serialize}; 20use smol_str::SmolStr; 21use tokio_util::sync::CancellationToken; 22use tower_http::{ 23 classify::ServerErrorsFailureClass, 24 compression::CompressionLayer, 25 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}, 26 trace::TraceLayer, 27}; 28use tracing::{Instrument, Span, field}; 29 30use crate::{ 31 db::Db, 32 error::{AppError, AppResult}, 33}; 34 35struct LatencyMillis(u128); 36 37impl From<Duration> for LatencyMillis { 38 fn from(duration: Duration) -> Self { 39 LatencyMillis(duration.as_millis()) 40 } 41} 42 43impl Display for LatencyMillis { 44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 45 write!(f, "{}ms", self.0) 46 } 47} 48 49pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> { 50 let app = Router::new() 51 .route("/events", get(events)) 52 .route("/stream_events", get(stream_events)) 53 .route("/hits", get(hits)) 54 .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true)) 55 .route_layer(PropagateRequestIdLayer::x_request_id()) 56 .route_layer( 57 TraceLayer::new_for_http() 58 .make_span_with(|request: &Request<_>| { 59 let span = tracing::info_span!( 60 "request", 61 method = %request.method(), 62 uri = %request.uri(), 63 id = field::Empty, 64 ip = field::Empty, 65 ); 66 if let Some(id) = request.headers().get("x-request-id") { 67 span.record("id", String::from_utf8_lossy(id.as_bytes()).deref()); 68 } 69 if let Some(real_ip) = request.headers().get("x-real-ip") { 70 span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref()); 71 } 72 span 73 }) 74 .on_request(|_request: 
&Request<_>, span: &Span| { 75 let _ = span.enter(); 76 tracing::info!("processing") 77 }) 78 .on_response(|response: &Response<_>, latency: Duration, span: &Span| { 79 let _ = span.enter(); 80 tracing::info!({code = %response.status().as_u16(), latency = %LatencyMillis::from(latency)}, "processed") 81 }) 82 .on_eos(()) 83 .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| { 84 let _ = span.enter(); 85 if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error()) || matches!(error, ServerErrorsFailureClass::Error(_)) { 86 tracing::error!("server error: {}", error.to_string().to_lowercase()); 87 }; 88 }), 89 ) 90 .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) 91 .with_state(db); 92 93 let addr = SocketAddr::from(( 94 [0, 0, 0, 0], 95 std::env::var("PORT") 96 .ok() 97 .and_then(|s| s.parse::<u16>().ok()) 98 .unwrap_or(3713), 99 )); 100 let listener = tokio::net::TcpListener::bind(addr).await?; 101 102 tracing::info!("starting serve on {addr}"); 103 tokio::select! 
{ 104 res = axum::serve(listener, app) => res.map_err(AppError::from), 105 _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()), 106 } 107} 108 109#[derive(Serialize)] 110struct NsidCount { 111 count: u128, 112 deleted_count: u128, 113 last_seen: u64, 114} 115 116#[derive(Serialize)] 117struct Events { 118 per_second: usize, 119 events: HashMap<SmolStr, NsidCount>, 120} 121 122async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> { 123 let mut events = HashMap::new(); 124 for result in db.get_counts() { 125 let (nsid, counts) = result?; 126 events.insert( 127 nsid, 128 NsidCount { 129 count: counts.count, 130 deleted_count: counts.deleted_count, 131 last_seen: counts.last_seen, 132 }, 133 ); 134 } 135 Ok(Json(Events { 136 events, 137 per_second: db.eps(), 138 })) 139} 140 141#[derive(Debug, Deserialize)] 142struct HitsQuery { 143 nsid: SmolStr, 144 from: Option<u64>, 145 to: Option<u64>, 146} 147 148#[derive(Serialize)] 149struct Hit { 150 timestamp: u64, 151 deleted: bool, 152} 153 154const MAX_HITS: usize = 100_000; 155 156async fn hits( 157 State(db): State<Arc<Db>>, 158 Query(params): Query<HitsQuery>, 159) -> AppResult<Json<Vec<Hit>>> { 160 let maybe_hits = db 161 .get_hits( 162 &params.nsid, 163 params.to.unwrap_or(0) 164 ..params.from.unwrap_or( 165 std::time::SystemTime::now() 166 .duration_since(UNIX_EPOCH) 167 .expect("oops") 168 .as_micros() as u64, 169 ), 170 )? 
171 .take(MAX_HITS); 172 let mut hits = Vec::with_capacity(maybe_hits.size_hint().0); 173 174 for maybe_hit in maybe_hits { 175 let (timestamp, hit) = maybe_hit?; 176 hits.push(Hit { 177 timestamp, 178 deleted: hit.deleted, 179 }); 180 } 181 182 Ok(Json(hits)) 183} 184 185async fn stream_events(db: State<Arc<Db>>, ws: WebSocketUpgrade) -> Response { 186 let span = tracing::info_span!(parent: Span::current(), "ws"); 187 ws.on_upgrade(move |mut socket| { 188 (async move { 189 let mut listener = db.new_listener(); 190 let mut data = Events { 191 events: HashMap::<SmolStr, NsidCount>::with_capacity(10), 192 per_second: 0, 193 }; 194 let mut updates = 0; 195 while let Ok((nsid, counts)) = listener.recv().await { 196 data.events.insert( 197 nsid, 198 NsidCount { 199 count: counts.count, 200 deleted_count: counts.deleted_count, 201 last_seen: counts.last_seen, 202 }, 203 ); 204 updates += 1; 205 // send 20 times every second max 206 data.per_second = db.eps(); 207 if updates >= data.per_second / 16 { 208 let msg = serde_json::to_string(&data).unwrap(); 209 let res = socket.send(Message::text(msg)).await; 210 data.events.clear(); 211 updates = 0; 212 if let Err(err) = res { 213 tracing::error!("error sending event: {err}"); 214 break; 215 } 216 } 217 } 218 }) 219 .instrument(span) 220 }) 221}