Tracks lexicons and how many times they have appeared on the Jetstream.
1use std::{ 2 fmt::Display, 3 net::SocketAddr, 4 ops::{Bound, Deref, RangeBounds}, 5 time::Duration, 6}; 7 8use ahash::AHashMap; 9use anyhow::anyhow; 10use axum::{ 11 Json, Router, 12 extract::{Query, State}, 13 http::Request, 14 response::Response, 15 routing::get, 16}; 17use axum_tws::{Message, WebSocketUpgrade}; 18use rclite::Arc; 19use serde::{Deserialize, Serialize}; 20use smol_str::SmolStr; 21use tokio_util::sync::CancellationToken; 22use tower_http::{ 23 classify::ServerErrorsFailureClass, 24 compression::CompressionLayer, 25 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}, 26 trace::TraceLayer, 27}; 28use tracing::{Instrument, Span, field}; 29 30use crate::{ 31 db::Db, 32 error::{AppError, AppResult}, 33}; 34 35struct LatencyMillis(u128); 36 37impl From<Duration> for LatencyMillis { 38 fn from(duration: Duration) -> Self { 39 LatencyMillis(duration.as_millis()) 40 } 41} 42 43impl Display for LatencyMillis { 44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 45 write!(f, "{}ms", self.0) 46 } 47} 48 49pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> { 50 let app = Router::new() 51 .route("/events", get(events)) 52 .route("/stream_events", get(stream_events)) 53 .route("/hits", get(hits)) 54 .route("/since", get(since)) 55 .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true)) 56 .route_layer(PropagateRequestIdLayer::x_request_id()) 57 .route_layer( 58 TraceLayer::new_for_http() 59 .make_span_with(|request: &Request<_>| { 60 let span = tracing::info_span!( 61 "request", 62 method = %request.method(), 63 uri = %request.uri(), 64 id = field::Empty, 65 ip = field::Empty, 66 ); 67 if let Some(id) = request.headers().get("x-request-id") { 68 span.record("id", String::from_utf8_lossy(id.as_bytes()).deref()); 69 } 70 if let Some(real_ip) = request.headers().get("x-real-ip") { 71 span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref()); 72 } 73 
span 74 }) 75 .on_request(|_request: &Request<_>, span: &Span| { 76 let _ = span.enter(); 77 tracing::info!("processing") 78 }) 79 .on_response(|response: &Response<_>, latency: Duration, span: &Span| { 80 let _ = span.enter(); 81 tracing::info!({code = %response.status().as_u16(), latency = %LatencyMillis::from(latency)}, "processed") 82 }) 83 .on_eos(()) 84 .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| { 85 let _ = span.enter(); 86 if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error()) || matches!(error, ServerErrorsFailureClass::Error(_)) { 87 tracing::error!("server error: {}", error.to_string().to_lowercase()); 88 }; 89 }), 90 ) 91 .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) 92 .with_state(db); 93 94 let addr = SocketAddr::from(( 95 [0, 0, 0, 0], 96 std::env::var("PORT") 97 .ok() 98 .and_then(|s| s.parse::<u16>().ok()) 99 .unwrap_or(3713), 100 )); 101 let listener = tokio::net::TcpListener::bind(addr).await?; 102 103 tracing::info!("starting serve on {addr}"); 104 tokio::select! 
{ 105 res = axum::serve(listener, app) => res.map_err(AppError::from), 106 _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()), 107 } 108} 109 110#[derive(Serialize)] 111struct NsidCount { 112 count: u128, 113 deleted_count: u128, 114 last_seen: u64, 115} 116 117#[derive(Serialize)] 118struct Events { 119 per_second: usize, 120 events: AHashMap<SmolStr, NsidCount>, 121} 122 123async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> { 124 let mut events = AHashMap::new(); 125 for result in db.get_counts() { 126 let (nsid, counts) = result?; 127 events.insert( 128 nsid, 129 NsidCount { 130 count: counts.count, 131 deleted_count: counts.deleted_count, 132 last_seen: counts.last_seen, 133 }, 134 ); 135 } 136 Ok(Json(Events { 137 events, 138 per_second: db.eps(), 139 })) 140} 141 142#[derive(Debug, Deserialize)] 143struct HitsQuery { 144 nsid: SmolStr, 145 from: Option<u64>, 146 to: Option<u64>, 147} 148 149#[derive(Debug, Serialize)] 150struct Hit { 151 timestamp: u64, 152 deleted: bool, 153} 154 155const MAX_HITS: usize = 100_000; 156 157#[derive(Debug)] 158struct HitsRange { 159 from: Bound<u64>, 160 to: Bound<u64>, 161} 162 163impl RangeBounds<u64> for HitsRange { 164 fn start_bound(&self) -> Bound<&u64> { 165 self.from.as_ref() 166 } 167 168 fn end_bound(&self) -> Bound<&u64> { 169 self.to.as_ref() 170 } 171} 172 173async fn hits( 174 State(db): State<Arc<Db>>, 175 Query(params): Query<HitsQuery>, 176) -> AppResult<Json<Vec<Hit>>> { 177 let from = params.to.map(Bound::Included).unwrap_or(Bound::Unbounded); 178 let to = params.from.map(Bound::Included).unwrap_or(Bound::Unbounded); 179 180 db.get_hits(&params.nsid, HitsRange { from, to }, MAX_HITS) 181 .take(MAX_HITS) 182 .try_fold(Vec::with_capacity(MAX_HITS), |mut acc, hit| { 183 let hit = hit?; 184 let hit_data = hit.deser()?; 185 186 acc.push(Hit { 187 timestamp: hit.timestamp, 188 deleted: hit_data.deleted, 189 }); 190 Ok(acc) 191 }) 192 .map(Json) 193} 194 195async fn stream_events(db: 
State<Arc<Db>>, ws: WebSocketUpgrade) -> Response { 196 let span = tracing::info_span!(parent: Span::current(), "ws"); 197 ws.on_upgrade(move |mut socket| { 198 (async move { 199 let mut listener = db.new_listener(); 200 let mut data = Events { 201 events: AHashMap::<SmolStr, NsidCount>::with_capacity(10), 202 per_second: 0, 203 }; 204 let mut updates = 0; 205 while let Ok((nsid, counts)) = listener.recv().await { 206 data.events.insert( 207 nsid, 208 NsidCount { 209 count: counts.count, 210 deleted_count: counts.deleted_count, 211 last_seen: counts.last_seen, 212 }, 213 ); 214 updates += 1; 215 // send 20 times every second max 216 data.per_second = db.eps(); 217 if updates >= data.per_second / 16 { 218 let msg = serde_json::to_string(&data).unwrap(); 219 let res = socket.send(Message::text(msg)).await; 220 data.events.clear(); 221 updates = 0; 222 if let Err(err) = res { 223 tracing::error!("error sending event: {err}"); 224 break; 225 } 226 } 227 } 228 }) 229 .instrument(span) 230 }) 231} 232 233#[derive(Debug, Serialize)] 234struct Since { 235 since: u64, 236} 237 238async fn since(db: State<Arc<Db>>) -> AppResult<Json<Since>> { 239 Ok(Json(Since { 240 since: db.tracking_since()?, 241 })) 242}