Tracks lexicons and how many times each one has appeared on the Jetstream.
1use std::{ 2 collections::HashMap, 3 fmt::Display, 4 net::SocketAddr, 5 ops::Deref, 6 sync::Arc, 7 time::{Duration, UNIX_EPOCH}, 8}; 9 10use anyhow::anyhow; 11use axum::{ 12 Json, Router, 13 extract::{Query, State}, 14 http::Request, 15 response::Response, 16 routing::get, 17}; 18use axum_tws::{Message, WebSocketUpgrade}; 19use serde::{Deserialize, Serialize}; 20use smol_str::SmolStr; 21use tokio_util::sync::CancellationToken; 22use tower_http::{ 23 classify::ServerErrorsFailureClass, 24 compression::CompressionLayer, 25 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}, 26 trace::TraceLayer, 27}; 28use tracing::{Instrument, Span, field}; 29 30use crate::{ 31 db::Db, 32 error::{AppError, AppResult}, 33 utils::time_now, 34}; 35 36struct LatencyMillis(u128); 37 38impl From<Duration> for LatencyMillis { 39 fn from(duration: Duration) -> Self { 40 LatencyMillis(duration.as_millis()) 41 } 42} 43 44impl Display for LatencyMillis { 45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 46 write!(f, "{}ms", self.0) 47 } 48} 49 50pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> { 51 let app = Router::new() 52 .route("/events", get(events)) 53 .route("/stream_events", get(stream_events)) 54 .route("/hits", get(hits)) 55 .route("/since", get(since)) 56 .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true)) 57 .route_layer(PropagateRequestIdLayer::x_request_id()) 58 .route_layer( 59 TraceLayer::new_for_http() 60 .make_span_with(|request: &Request<_>| { 61 let span = tracing::info_span!( 62 "request", 63 method = %request.method(), 64 uri = %request.uri(), 65 id = field::Empty, 66 ip = field::Empty, 67 ); 68 if let Some(id) = request.headers().get("x-request-id") { 69 span.record("id", String::from_utf8_lossy(id.as_bytes()).deref()); 70 } 71 if let Some(real_ip) = request.headers().get("x-real-ip") { 72 span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref()); 73 
} 74 span 75 }) 76 .on_request(|_request: &Request<_>, span: &Span| { 77 let _ = span.enter(); 78 tracing::info!("processing") 79 }) 80 .on_response(|response: &Response<_>, latency: Duration, span: &Span| { 81 let _ = span.enter(); 82 tracing::info!({code = %response.status().as_u16(), latency = %LatencyMillis::from(latency)}, "processed") 83 }) 84 .on_eos(()) 85 .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| { 86 let _ = span.enter(); 87 if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error()) || matches!(error, ServerErrorsFailureClass::Error(_)) { 88 tracing::error!("server error: {}", error.to_string().to_lowercase()); 89 }; 90 }), 91 ) 92 .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) 93 .with_state(db); 94 95 let addr = SocketAddr::from(( 96 [0, 0, 0, 0], 97 std::env::var("PORT") 98 .ok() 99 .and_then(|s| s.parse::<u16>().ok()) 100 .unwrap_or(3713), 101 )); 102 let listener = tokio::net::TcpListener::bind(addr).await?; 103 104 tracing::info!("starting serve on {addr}"); 105 tokio::select! 
{ 106 res = axum::serve(listener, app) => res.map_err(AppError::from), 107 _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()), 108 } 109} 110 111#[derive(Serialize)] 112struct NsidCount { 113 count: u128, 114 deleted_count: u128, 115 last_seen: u64, 116} 117 118#[derive(Serialize)] 119struct Events { 120 per_second: usize, 121 events: HashMap<SmolStr, NsidCount>, 122} 123 124async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> { 125 let mut events = HashMap::new(); 126 for result in db.get_counts() { 127 let (nsid, counts) = result?; 128 events.insert( 129 nsid, 130 NsidCount { 131 count: counts.count, 132 deleted_count: counts.deleted_count, 133 last_seen: counts.last_seen, 134 }, 135 ); 136 } 137 Ok(Json(Events { 138 events, 139 per_second: db.eps(), 140 })) 141} 142 143#[derive(Debug, Deserialize)] 144struct HitsQuery { 145 nsid: SmolStr, 146 from: Option<u64>, 147 to: Option<u64>, 148} 149 150#[derive(Serialize)] 151struct Hit { 152 timestamp: u64, 153 deleted: bool, 154} 155 156const MAX_HITS: usize = 100_000; 157 158async fn hits( 159 State(db): State<Arc<Db>>, 160 Query(params): Query<HitsQuery>, 161) -> AppResult<Json<Vec<Hit>>> { 162 let maybe_hits = db 163 .get_hits( 164 &params.nsid, 165 params.to.unwrap_or(0)..params.from.unwrap_or(time_now()), 166 ) 167 .take(MAX_HITS); 168 let mut hits = Vec::with_capacity(maybe_hits.size_hint().0); 169 170 for maybe_hit in maybe_hits { 171 let hit = maybe_hit?; 172 let hit_data = hit.access(); 173 174 hits.push(Hit { 175 timestamp: hit.timestamp, 176 deleted: hit_data.deleted, 177 }); 178 } 179 180 Ok(Json(hits)) 181} 182 183async fn stream_events(db: State<Arc<Db>>, ws: WebSocketUpgrade) -> Response { 184 let span = tracing::info_span!(parent: Span::current(), "ws"); 185 ws.on_upgrade(move |mut socket| { 186 (async move { 187 let mut listener = db.new_listener(); 188 let mut data = Events { 189 events: HashMap::<SmolStr, NsidCount>::with_capacity(10), 190 per_second: 0, 191 }; 192 let mut 
updates = 0; 193 while let Ok((nsid, counts)) = listener.recv().await { 194 data.events.insert( 195 nsid, 196 NsidCount { 197 count: counts.count, 198 deleted_count: counts.deleted_count, 199 last_seen: counts.last_seen, 200 }, 201 ); 202 updates += 1; 203 // send 20 times every second max 204 data.per_second = db.eps(); 205 if updates >= data.per_second / 16 { 206 let msg = serde_json::to_string(&data).unwrap(); 207 let res = socket.send(Message::text(msg)).await; 208 data.events.clear(); 209 updates = 0; 210 if let Err(err) = res { 211 tracing::error!("error sending event: {err}"); 212 break; 213 } 214 } 215 } 216 }) 217 .instrument(span) 218 }) 219} 220 221#[derive(Debug, Serialize)] 222struct Since { 223 since: u64, 224} 225 226async fn since(db: State<Arc<Db>>) -> AppResult<Json<Since>> { 227 Ok(Json(Since { 228 since: db.tracking_since()?, 229 })) 230}