forked from
microcosm.blue/microcosm-rs
Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm
1use crate::error::SubscriberUpdateError;
2use crate::server::MultiSubscribeQuery;
3use crate::{ClientMessage, FilterableProperties, SubscriberSourcedMessage};
4use dropshot::WebsocketConnectionRaw;
5use futures::SinkExt;
6use futures::StreamExt;
7use std::error::Error;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::broadcast::{self, error::RecvError};
11use tokio::time::interval;
12use tokio_tungstenite::{WebSocketStream, tungstenite::Message};
13use tokio_util::sync::CancellationToken;
14
/// How often a keep-alive ping is sent to each subscriber.
///
/// The same period doubles as the pong deadline: if the previous ping is still
/// unanswered when the next tick fires, the subscriber is dropped.
const PING_PERIOD: Duration = Duration::from_secs(30);
16
17pub struct Subscriber {
18 query: MultiSubscribeQuery,
19 shutdown: CancellationToken,
20}
21
22impl Subscriber {
23 pub fn new(query: MultiSubscribeQuery, shutdown: CancellationToken) -> Self {
24 Self { query, shutdown }
25 }
26
27 pub async fn start(
28 mut self,
29 ws: WebSocketStream<WebsocketConnectionRaw>,
30 mut receiver: broadcast::Receiver<Arc<ClientMessage>>,
31 ) -> Result<(), Box<dyn Error>> {
32 let mut ping_state = None;
33 let (mut ws_sender, mut ws_receiver) = ws.split();
34 let mut ping_interval = interval(PING_PERIOD);
35 let _guard = self.shutdown.clone().drop_guard();
36
37 // TODO: do we need to timeout ws sends??
38
39 metrics::counter!("subscribers_connected_total").increment(1);
40 metrics::gauge!("subscribers_connected").increment(1);
41
42 loop {
43 tokio::select! {
44 l = receiver.recv() => match l {
45 Ok(link) => if self.filter(&link.properties) {
46 if let Err(e) = ws_sender.send(link.message.clone()).await {
47 log::warn!("failed to send link, dropping subscriber: {e:?}");
48 break;
49 }
50 },
51 Err(RecvError::Closed) => self.shutdown.cancel(),
52 Err(RecvError::Lagged(n)) => {
53 log::warn!("dropping lagging subscriber (missed {n} messages already)");
54 self.shutdown.cancel();
55 }
56 },
57 cm = ws_receiver.next() => match cm {
58 Some(Ok(Message::Ping(state))) => {
59 if let Err(e) = ws_sender.send(Message::Pong(state)).await {
60 log::error!("failed to reply pong to subscriber: {e:?}");
61 break;
62 }
63 }
64 Some(Ok(Message::Pong(state))) => {
65 if let Some(expected_state) = ping_state {
66 if *state == expected_state {
67 ping_state = None; // good
68 } else {
69 log::error!("subscriber returned a pong with the wrong state, dropping");
70 self.shutdown.cancel();
71 }
72 } else {
73 log::error!("subscriber sent a pong when none was expected");
74 self.shutdown.cancel();
75 }
76 }
77 Some(Ok(Message::Text(raw))) => {
78 if let Err(e) = self.query.update_from_raw(&raw) {
79 log::error!("subscriber options could not be updated, dropping: {e:?}");
80 // TODO: send client an explanation
81 self.shutdown.cancel();
82 }
83 log::trace!("subscriber updated with opts: {:?}", self.query);
84 },
85 Some(Ok(m)) => log::trace!("subscriber sent an unexpected message: {m:?}"),
86 Some(Err(e)) => {
87 log::error!("failed to receive subscriber message: {e:?}");
88 break;
89 }
90 None => {
91 log::trace!("end of subscriber messages. bye!");
92 break;
93 }
94 },
95 _ = ping_interval.tick() => {
96 if ping_state.is_some() {
97 log::warn!("did not recieve pong within {PING_PERIOD:?}, dropping subscriber");
98 self.shutdown.cancel();
99 } else {
100 let new_state: [u8; 8] = rand::random();
101 let ping = new_state.to_vec().into();
102 ping_state = Some(new_state);
103 if let Err(e) = ws_sender.send(Message::Ping(ping)).await {
104 log::error!("failed to send ping to subscriber, dropping: {e:?}");
105 self.shutdown.cancel();
106 }
107 }
108 }
109 _ = self.shutdown.cancelled() => {
110 log::info!("subscriber shutdown requested, bye!");
111 if let Err(e) = ws_sender.close().await {
112 log::warn!("failed to close subscriber: {e:?}");
113 }
114 break;
115 },
116 }
117 }
118 log::trace!("end of subscriber. bye!");
119 metrics::gauge!("subscribers_connected").decrement(1);
120 Ok(())
121 }
122
123 fn filter(&self, properties: &FilterableProperties) -> bool {
124 let query = &self.query;
125
126 // subject + subject DIDs are logical OR
127 if !(query.wanted_subjects.is_empty() && query.wanted_subject_dids.is_empty()
128 || query.wanted_subjects.contains(&properties.subject)
129 || properties
130 .subject_did
131 .as_ref()
132 .map(|did| query.wanted_subject_dids.contains(did))
133 .unwrap_or(false))
134 {
135 // wowwww ^^ fix that
136 return false;
137 }
138
139 // subjects together with sources are logical AND
140 if !(query.wanted_sources.is_empty() || query.wanted_sources.contains(&properties.source)) {
141 return false;
142 }
143
144 true
145 }
146}
147
148impl MultiSubscribeQuery {
149 pub fn update_from_raw(&mut self, s: &str) -> Result<(), SubscriberUpdateError> {
150 let SubscriberSourcedMessage::OptionsUpdate(opts) =
151 serde_json::from_str(s).map_err(SubscriberUpdateError::FailedToParseMessage)?;
152 if opts.wanted_sources.len() > 1_000 {
153 return Err(SubscriberUpdateError::TooManySourcesWanted);
154 }
155 if opts.wanted_subject_dids.len() > 10_000 {
156 return Err(SubscriberUpdateError::TooManyDidsWanted);
157 }
158 if opts.wanted_subjects.len() > 50_000 {
159 return Err(SubscriberUpdateError::TooManySubjectsWanted);
160 }
161 *self = opts;
162 Ok(())
163 }
164}