use super::{LinkReader, LinkStorage, PagedAppendingCollection, StorageStats};
use crate::{ActionableEvent, CountsByCount, Did, RecordId};
use anyhow::Result;
use links::CollectedLink;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};

// hopefully-correct simple hashmap version, intended only for tests to verify disk impl
#[derive(Debug, Clone)]
pub struct MemStorage(Arc<Mutex<MemStorageData>>);

type Linkers = Vec<Option<(Did, RKey)>>; // optional because we replace with None for deleted links to keep cursors stable

#[derive(Debug, Default)]
struct MemStorageData {
    dids: HashMap<Did, bool>, // bool: active or nah
    targets: HashMap<Target, HashMap<Source, Linkers>>, // target -> (collection, path) -> (did, rkey)?[]
    links: HashMap<Did, HashMap<RepoId, Vec<(RecordPath, Target)>>>, // did -> collection:rkey -> (path, target)[]
}

impl MemStorage {
    pub fn new() -> Self {
        Self(Arc::new(Mutex::new(MemStorageData::default())))
    }

    fn add_links(&mut self, record_id: &RecordId, links: &[CollectedLink]) {
        let mut data = self.0.lock().unwrap();
        for link in links {
            data.dids.entry(record_id.did()).or_insert(true); // if they are inserting a link, presumably they are active
            data.targets
                .entry(Target::new(link.target.as_str()))
                .or_default()
                .entry(Source::new(&record_id.collection, &link.path))
                .or_default()
                .push(Some((record_id.did(), RKey(record_id.rkey()))));
            data.links
                .entry(record_id.did())
                .or_default()
                .entry(RepoId::from_record_id(record_id))
                .or_insert_with(|| Vec::with_capacity(1))
                .push((
                    RecordPath::new(&link.path),
                    Target::new(link.target.as_str()),
                ));
        }
    }

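    // Deleted links are not removed from their target's Linkers vec; the matching
    // entry is tombstoned to None (via `.take()`) so indices, and therefore the
    // cursors handed out by the paged readers below, stay stable.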
    fn remove_links(&mut self, record_id: &RecordId) {
        let mut data = self.0.lock().unwrap();
        let repo_id = RepoId::from_record_id(record_id);
        if let Some(Some(link_targets)) = data.links.get(&record_id.did).map(|cr| cr.get(&repo_id))
        {
            let link_targets = link_targets.clone(); // satisfy borrowck
            for (record_path, target) in link_targets {
                data.targets
                    .get_mut(&target)
                    .expect("must have the target if we have a link saved")
                    .get_mut(&Source::new(&record_id.collection, &record_path.0))
                    .expect("must have the target at this path if we have a link to it saved")
                    .iter_mut()
                    .rfind(|d| **d == Some((record_id.did(), RKey(record_id.rkey()))))
                    .expect("must be in the linkers list if we have a link to it saved")
                    .take();
            }
        }
        data.links
            .get_mut(&record_id.did)
            .map(|cr| cr.remove(&repo_id));
    }

    fn update_links(&mut self, record_id: &RecordId, new_links: &[CollectedLink]) {
        self.remove_links(record_id);
        self.add_links(record_id, new_links);
    }

    fn set_account(&mut self, did: &Did, active: bool) {
        let mut data = self.0.lock().unwrap();
        if let Some(account) = data.dids.get_mut(did) {
            *account = active;
        }
    }

    fn delete_account(&mut self, did: &Did) {
        let mut data = self.0.lock().unwrap();
        if let Some(links) = data.links.get(did) {
            let links = links.clone(); // satisfy borrowck
            for (repo_id, targets) in links {
                for (record_path, target) in targets {
                    data.targets
                        .get_mut(&target)
                        .expect("must have the target if we have a link saved")
                        .get_mut(&Source::new(&repo_id.collection, &record_path.0))
                        .expect("must have the target at this path if we have a link to it saved")
                        .iter_mut()
                        .find(|d| **d == Some((did.clone(), repo_id.rkey.clone())))
                        .expect("must be in the linkers list if we have a link to it saved")
                        .take();
                }
            }
        }
        data.links.remove(did); // nb: this is removing by a whole prefix in kv context
        data.dids.remove(did);
    }
}

impl Default for MemStorage {
    fn default() -> Self {
        Self::new()
    }
}

impl LinkStorage for MemStorage {
    fn push(&mut self, event: &ActionableEvent, _cursor: u64) -> Result<()> {
        match event {
            ActionableEvent::CreateLinks { record_id, links } => self.add_links(record_id, links),
            ActionableEvent::UpdateLinks {
                record_id,
                new_links,
            } => self.update_links(record_id, new_links),
            ActionableEvent::DeleteRecord(record_id) => self.remove_links(record_id),
            ActionableEvent::ActivateAccount(did) => self.set_account(did, true),
            ActionableEvent::DeactivateAccount(did) => self.set_account(did, false),
            ActionableEvent::DeleteAccount(did) => self.delete_account(did),
        }
        Ok(())
    }

    fn to_readable(&mut self) -> impl LinkReader {
        self.clone()
    }
}

impl LinkReader for MemStorage {
    fn get_count(&self, target: &str, collection: &str, path: &str) -> Result<u64> {
        let data = self.0.lock().unwrap();
        let Some(paths) = data.targets.get(&Target::new(target)) else {
            return Ok(0);
        };
        let Some(linkers) = paths.get(&Source::new(collection, path)) else {
            return Ok(0);
        };
        Ok(linkers.iter().flatten().count() as u64)
    }

    fn get_distinct_did_count(&self, target: &str, collection: &str, path: &str) -> Result<u64> {
        let data = self.0.lock().unwrap();
        let Some(paths) = data.targets.get(&Target::new(target)) else {
            return Ok(0);
        };
        let Some(linkers) = paths.get(&Source::new(collection, path)) else {
            return Ok(0);
        };
        Ok(linkers
            .iter()
            .flatten()
            .map(|(did, _)| did)
            .collect::<HashSet<_>>()
            .len() as u64)
    }

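    // Paging contract shared by get_links and get_distinct_dids: `until` is an
    // exclusive index into the append-only linkers vec (newest entries at the end),
    // `limit` items are taken walking backwards from it, and `next` is the cursor
    // to pass as `until` on the following call (None once the start is reached).
    // `version` is (entries ever appended, tombstoned entries) so callers can
    // detect deletions that happened between pages.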
    fn get_links(
        &self,
        target: &str,
        collection: &str,
        path: &str,
        limit: u64,
        until: Option<u64>,
    ) -> Result<PagedAppendingCollection<RecordId>> {
        let data = self.0.lock().unwrap();
        let Some(paths) = data.targets.get(&Target::new(target)) else {
            return Ok(PagedAppendingCollection {
                version: (0, 0),
                items: Vec::new(),
                next: None,
                total: 0,
            });
        };
        let Some(did_rkeys) = paths.get(&Source::new(collection, path)) else {
            return Ok(PagedAppendingCollection {
                version: (0, 0),
                items: Vec::new(),
                next: None,
                total: 0,
            });
        };

        let total = did_rkeys.len();
        let end = until
            .map(|u| std::cmp::min(u as usize, total))
            .unwrap_or(total);
        let begin = end.saturating_sub(limit as usize);
        let next = if begin == 0 { None } else { Some(begin as u64) };

        let alive = did_rkeys.iter().flatten().count();
        let gone = total - alive;

        let items: Vec<_> = did_rkeys[begin..end]
            .iter()
            .rev()
            .flatten()
            .filter(|(did, _)| *data.dids.get(did).expect("did must be in dids"))
            .map(|(did, rkey)| RecordId {
                did: did.clone(),
                rkey: rkey.0.clone(),
                collection: collection.to_string(),
            })
            .collect();

        Ok(PagedAppendingCollection {
            version: (total as u64, gone as u64),
            items,
            next,
            total: alive as u64,
        })
    }

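    // Same cursor scheme as get_links, but each did is kept only at its first
    // occurrence in the linkers vec (repeats map to None), so positions, and
    // therefore `until`/`next` cursors, still index into the same vec.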
    fn get_distinct_dids(
        &self,
        target: &str,
        collection: &str,
        path: &str,
        limit: u64,
        until: Option<u64>,
    ) -> Result<PagedAppendingCollection<Did>> {
        let data = self.0.lock().unwrap();
        let Some(paths) = data.targets.get(&Target::new(target)) else {
            return Ok(PagedAppendingCollection {
                version: (0, 0),
                items: Vec::new(),
                next: None,
                total: 0,
            });
        };
        let Some(did_rkeys) = paths.get(&Source::new(collection, path)) else {
            return Ok(PagedAppendingCollection {
                version: (0, 0),
                items: Vec::new(),
                next: None,
                total: 0,
            });
        };

        let dids: Vec<Option<Did>> = {
            let mut seen = HashSet::new();
            did_rkeys
                .iter()
                .map(|o| {
                    o.clone().and_then(|(did, _)| {
                        if seen.contains(&did) {
                            None
                        } else {
                            seen.insert(did.clone());
                            Some(did)
                        }
                    })
                })
                .collect()
        };

        let total = dids.len();
        let end = until
            .map(|u| std::cmp::min(u as usize, total))
            .unwrap_or(total);
        let begin = end.saturating_sub(limit as usize);
        let next = if begin == 0 { None } else { Some(begin as u64) };

        let alive = dids.iter().flatten().count();
        let gone = total - alive;

        let items: Vec<Did> = dids[begin..end]
            .iter()
            .rev()
            .flatten()
            .filter(|did| *data.dids.get(did).expect("did must be in dids"))
            .cloned()
            .collect();

        Ok(PagedAppendingCollection {
            version: (total as u64, gone as u64),
            items,
            next,
            total: alive as u64,
        })
    }

    fn get_all_record_counts(&self, target: &str) -> Result<HashMap<String, HashMap<String, u64>>> {
        let data = self.0.lock().unwrap();
        let mut out: HashMap<String, HashMap<String, u64>> = HashMap::new();
        if let Some(sources) = data.targets.get(&Target::new(target)) {
            for (Source { collection, path }, linkers) in sources {
                let count = linkers.iter().flatten().count() as u64;
                out.entry(collection.to_string())
                    .or_default()
                    .insert(path.to_string(), count);
            }
        }
        Ok(out)
    }

    fn get_all_counts(
        &self,
        target: &str,
    ) -> Result<HashMap<String, HashMap<String, CountsByCount>>> {
        let data = self.0.lock().unwrap();
        let mut out: HashMap<String, HashMap<String, CountsByCount>> = HashMap::new();
        if let Some(sources) = data.targets.get(&Target::new(target)) {
            for (Source { collection, path }, linkers) in sources {
                let records = linkers.iter().flatten().count() as u64;
                let distinct_dids = linkers
                    .iter()
                    .flatten()
                    .map(|(did, _)| did)
                    .collect::<HashSet<_>>()
                    .len() as u64;
                out.entry(collection.to_string()).or_default().insert(
                    path.to_string(),
                    CountsByCount {
                        records,
                        distinct_dids,
                    },
                );
            }
        }
        Ok(out)
    }

    fn get_stats(&self) -> Result<StorageStats> {
        let data = self.0.lock().unwrap();
        let dids = data.dids.len() as u64;
        let targetables = data
            .targets
            .values()
            .map(|sources| sources.len())
            .sum::<usize>() as u64;
        let linking_records = data.links.values().map(|recs| recs.len()).sum::<usize>() as u64;
        Ok(StorageStats {
            dids,
            targetables,
            linking_records,
        })
    }
}

#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct Target(String);

impl Target {
    fn new(t: &str) -> Self {
        Self(t.into())
    }
}

#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct Source {
    collection: String,
    path: String,
}

impl Source {
    fn new(collection: &str, path: &str) -> Self {
        Self {
            collection: collection.into(),
            path: path.into(),
        }
    }
}

#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct RKey(String);

#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct RepoId {
    collection: String,
    rkey: RKey,
}

impl RepoId {
    fn from_record_id(record_id: &RecordId) -> Self {
        Self {
            collection: record_id.collection.clone(),
            rkey: RKey(record_id.rkey.clone()),
        }
    }
}

#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct RecordPath(String);

impl RecordPath {
    fn new(rp: &str) -> Self {
        Self(rp.into())
    }
}
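
// A minimal, hypothetical usage sketch (not part of the original test suite): it
// only exercises the reader side on an empty store, so it makes no assumptions
// about how Did, RecordId, or CollectedLink are constructed; the target,
// collection, and path strings below are made up for illustration.
#[cfg(test)]
mod mem_storage_sketch {
    use super::*;

    #[test]
    fn empty_store_reads_as_zero() {
        let store = MemStorage::new();
        // no links recorded yet, so counts come back as zero...
        assert_eq!(
            store
                .get_count("at://did:plc:example/app.test.post/1", "app.test.like", ".subject.uri")
                .unwrap(),
            0
        );
        // ...and a page request is empty with no next cursor
        let page = store
            .get_links("at://did:plc:example/app.test.post/1", "app.test.like", ".subject.uri", 10, None)
            .unwrap();
        assert!(page.items.is_empty());
        assert_eq!(page.next, None);
        assert_eq!(page.total, 0);
    }
}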