use std::{
    fmt::Display,
    net::SocketAddr,
    ops::{Bound, Deref, RangeBounds},
    time::Duration,
};

use ahash::AHashMap;
use anyhow::anyhow;
use axum::{
    Json, Router,
    extract::{Query, State},
    http::Request,
    response::Response,
    routing::get,
};
use axum_tws::{Message, WebSocketUpgrade};
use rclite::Arc;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use tokio_util::sync::CancellationToken;
use tower_http::{
    classify::ServerErrorsFailureClass,
    compression::CompressionLayer,
    request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
    trace::TraceLayer,
};
use tracing::{Instrument, Span, field};

use crate::{
    db::Db,
    error::{AppError, AppResult},
};

/// Wrapper that formats a request latency as whole milliseconds.
struct LatencyMillis(u128);

impl From<Duration> for LatencyMillis {
    fn from(duration: Duration) -> Self {
        LatencyMillis(duration.as_millis())
    }
}

impl Display for LatencyMillis {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}ms", self.0)
    }
}

/// Builds the router and serves it until the cancellation token fires.
pub async fn serve(db: Arc<Db>, cancel_token: CancellationToken) -> AppResult<()> {
    let app = Router::new()
        .route("/events", get(events))
        .route("/stream_events", get(stream_events))
        .route("/hits", get(hits))
        .route("/since", get(since))
        .route_layer(CompressionLayer::new().br(true).deflate(true).gzip(true).zstd(true))
        .route_layer(PropagateRequestIdLayer::x_request_id())
        .route_layer(
            TraceLayer::new_for_http()
                .make_span_with(|request: &Request<_>| {
                    let span = tracing::info_span!(
                        "request",
                        method = %request.method(),
                        uri = %request.uri(),
                        id = field::Empty,
                        ip = field::Empty,
                    );
                    if let Some(id) = request.headers().get("x-request-id") {
                        span.record("id", String::from_utf8_lossy(id.as_bytes()).deref());
                    }
                    if let Some(real_ip) = request.headers().get("x-real-ip") {
                        span.record("ip", String::from_utf8_lossy(real_ip.as_bytes()).deref());
                    }
                    span
                })
                .on_request(|_request: &Request<_>, span: &Span| {
                    let _enter = span.enter();
                    tracing::info!("processing");
                })
                .on_response(|response: &Response<_>, latency: Duration, span: &Span| {
                    let _enter = span.enter();
                    tracing::info!(
                        code = %response.status().as_u16(),
                        latency = %LatencyMillis::from(latency),
                        "processed"
                    );
                })
                .on_eos(())
                .on_failure(|error: ServerErrorsFailureClass, _: Duration, span: &Span| {
                    let _enter = span.enter();
                    if matches!(error, ServerErrorsFailureClass::StatusCode(status_code) if status_code.is_server_error())
                        || matches!(error, ServerErrorsFailureClass::Error(_))
                    {
                        tracing::error!("server error: {}", error.to_string().to_lowercase());
                    }
                }),
        )
        .route_layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
        .with_state(db);

    let addr = SocketAddr::from((
        [0, 0, 0, 0],
        std::env::var("PORT")
            .ok()
            .and_then(|s| s.parse::<u16>().ok())
            .unwrap_or(3713),
    ));
    let listener = tokio::net::TcpListener::bind(addr).await?;
    tracing::info!("starting server on {addr}");

    tokio::select! {
        res = axum::serve(listener, app) => res.map_err(AppError::from),
        _ = cancel_token.cancelled() => Err(anyhow!("cancelled").into()),
    }
}

#[derive(Serialize)]
struct NsidCount {
    count: u128,
    deleted_count: u128,
    last_seen: u64,
}

#[derive(Serialize)]
struct Events {
    per_second: usize,
    events: AHashMap<SmolStr, NsidCount>,
}

/// Returns the current per-NSID counts as a single JSON snapshot.
async fn events(db: State<Arc<Db>>) -> AppResult<Json<Events>> {
    let mut events = AHashMap::new();
    for result in db.get_counts() {
        let (nsid, counts) = result?;
        events.insert(
            nsid,
            NsidCount {
                count: counts.count,
                deleted_count: counts.deleted_count,
                last_seen: counts.last_seen,
            },
        );
    }
    Ok(Json(Events {
        events,
        per_second: db.eps(),
    }))
}

#[derive(Debug, Deserialize)]
struct HitsQuery {
    nsid: SmolStr,
    from: Option<u64>,
    to: Option<u64>,
}

#[derive(Debug, Serialize)]
struct Hit {
    timestamp: u64,
    deleted: bool,
}

/// Upper bound on the number of hits returned by a single query.
const MAX_HITS: usize = 100_000;

#[derive(Debug)]
struct HitsRange {
    from: Bound<u64>,
    to: Bound<u64>,
}

impl RangeBounds<u64> for HitsRange {
    fn start_bound(&self) -> Bound<&u64> {
        self.from.as_ref()
    }

    fn end_bound(&self) -> Bound<&u64> {
        self.to.as_ref()
    }
}

/// Returns up to `MAX_HITS` hits for an NSID within an optional timestamp range.
async fn hits(
    State(db): State<Arc<Db>>,
    Query(params): Query<HitsQuery>,
) -> AppResult<Json<Vec<Hit>>> {
    let from = params.from.map(Bound::Included).unwrap_or(Bound::Unbounded);
    let to = params.to.map(Bound::Included).unwrap_or(Bound::Unbounded);
    db.get_hits(&params.nsid, HitsRange { from, to }, MAX_HITS)
        .take(MAX_HITS)
        .try_fold(Vec::with_capacity(MAX_HITS), |mut acc, hit| {
            let hit = hit?;
            let hit_data = hit.deser()?;
            acc.push(Hit {
                timestamp: hit.timestamp,
                deleted: hit_data.deleted,
            });
            Ok(acc)
        })
        .map(Json)
}

/// Streams batched count updates over a WebSocket connection.
async fn stream_events(db: State<Arc<Db>>, ws: WebSocketUpgrade) -> Response {
    let span = tracing::info_span!(parent: Span::current(), "ws");
    ws.on_upgrade(move |mut socket| {
        (async move {
            let mut listener = db.new_listener();
            let mut data = Events {
                events: AHashMap::<SmolStr, NsidCount>::with_capacity(10),
                per_second: 0,
            };
            let mut updates = 0;
            while let Ok((nsid, counts)) = listener.recv().await {
                data.events.insert(
                    nsid,
                    NsidCount {
                        count: counts.count,
                        deleted_count: counts.deleted_count,
                        last_seen: counts.last_seen,
                    },
                );
                updates += 1;
                // throttle: batch updates and flush at most ~16 times per second
                data.per_second = db.eps();
                if updates >= data.per_second / 16 {
                    let msg = serde_json::to_string(&data).unwrap();
                    let res = socket.send(Message::text(msg)).await;
                    data.events.clear();
                    updates = 0;
                    if let Err(err) = res {
                        tracing::error!("error sending event: {err}");
                        break;
                    }
                }
            }
        })
        .instrument(span)
    })
}

#[derive(Debug, Serialize)]
struct Since {
    since: u64,
}

/// Returns the timestamp since which events have been tracked.
async fn since(db: State<Arc<Db>>) -> AppResult<Json<Since>> {
    Ok(Json(Since {
        since: db.tracking_since()?,
    }))
}