Forked from microcosm.blue/microcosm-rs — Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm.
1use super::{
2 LinkReader, LinkStorage, PagedAppendingCollection, PagedOrderedCollection, StorageStats,
3};
4use crate::{ActionableEvent, CountsByCount, Did, RecordId};
5use anyhow::Result;
6use links::CollectedLink;
7use std::collections::{HashMap, HashSet};
8use std::sync::{Arc, Mutex};
9
// hopefully-correct simple hashmap version, intended only for tests to verify disk impl
/// In-memory link store. Cloning is cheap: clones share the same
/// `Arc<Mutex<..>>`, so a clone sees (and makes) the same writes.
#[derive(Debug, Clone)]
pub struct MemStorage(Arc<Mutex<MemStorageData>>);
13
/// Every (did, rkey) that links to one target from one (collection, path) source.
type Linkers = Vec<Option<(Did, RKey)>>; // optional because we replace with None for deleted links to keep cursors stable
15
/// The actual tables; always accessed through `MemStorage`'s mutex.
#[derive(Debug, Default)]
struct MemStorageData {
    dids: HashMap<Did, bool>, // bool: active or nah
    targets: HashMap<Target, HashMap<Source, Linkers>>, // target -> (collection, path) -> (did, rkey)?[]
    links: HashMap<Did, HashMap<RepoId, Vec<(RecordPath, Target)>>>, // did -> collection:rkey -> (path, target)[]
}
22
23impl MemStorage {
24 pub fn new() -> Self {
25 Self(Arc::new(Mutex::new(MemStorageData::default())))
26 }
27
28 fn add_links(&mut self, record_id: &RecordId, links: &[CollectedLink]) {
29 let mut data = self.0.lock().unwrap();
30 for link in links {
31 data.dids.entry(record_id.did()).or_insert(true); // if they are inserting a link, presumably they are active
32 data.targets
33 .entry(Target::new(link.target.as_str()))
34 .or_default()
35 .entry(Source::new(&record_id.collection, &link.path))
36 .or_default()
37 .push(Some((record_id.did(), RKey(record_id.rkey()))));
38 data.links
39 .entry(record_id.did())
40 .or_default()
41 .entry(RepoId::from_record_id(record_id))
42 .or_insert(Vec::with_capacity(1))
43 .push((
44 RecordPath::new(&link.path),
45 Target::new(link.target.as_str()),
46 ))
47 }
48 }
49
50 fn remove_links(&mut self, record_id: &RecordId) {
51 let mut data = self.0.lock().unwrap();
52 let repo_id = RepoId::from_record_id(record_id);
53 if let Some(Some(link_targets)) = data.links.get(&record_id.did).map(|cr| cr.get(&repo_id))
54 {
55 let link_targets = link_targets.clone(); // satisfy borrowck
56 for (record_path, target) in link_targets {
57 data.targets
58 .get_mut(&target)
59 .expect("must have the target if we have a link saved")
60 .get_mut(&Source::new(&record_id.collection, &record_path.0))
61 .expect("must have the target at this path if we have a link to it saved")
62 .iter_mut()
63 .rfind(|d| **d == Some((record_id.did(), RKey(record_id.rkey()))))
64 .expect("must be in dids list if we have a link to it")
65 .take();
66 }
67 }
68 data.links
69 .get_mut(&record_id.did)
70 .map(|cr| cr.remove(&repo_id));
71 }
72
73 fn update_links(&mut self, record_id: &RecordId, new_links: &[CollectedLink]) {
74 self.remove_links(record_id);
75 self.add_links(record_id, new_links);
76 }
77
78 fn set_account(&mut self, did: &Did, active: bool) {
79 let mut data = self.0.lock().unwrap();
80 if let Some(account) = data.dids.get_mut(did) {
81 *account = active;
82 }
83 }
84
85 fn delete_account(&mut self, did: &Did) {
86 let mut data = self.0.lock().unwrap();
87 if let Some(links) = data.links.get(did) {
88 let links = links.clone();
89 for (repo_id, targets) in links {
90 let targets = targets.clone();
91 for (record_path, target) in targets {
92 data.targets
93 .get_mut(&target)
94 .expect("must have the target if we have a link saved")
95 .get_mut(&Source::new(&repo_id.collection, &record_path.0))
96 .expect("must have the target at this path if we have a link to it saved")
97 .iter_mut()
98 .find(|d| **d == Some((did.clone(), repo_id.rkey.clone())))
99 .expect("lkasjdlfkj")
100 .take();
101 }
102 }
103 }
104 data.links.remove(did); // nb: this is removing by a whole prefix in kv context
105 data.dids.remove(did);
106 }
107}
108
109impl Default for MemStorage {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl LinkStorage for MemStorage {
116 fn push(&mut self, event: &ActionableEvent, _cursor: u64) -> Result<()> {
117 match event {
118 ActionableEvent::CreateLinks { record_id, links } => self.add_links(record_id, links),
119 ActionableEvent::UpdateLinks {
120 record_id,
121 new_links,
122 } => self.update_links(record_id, new_links),
123 ActionableEvent::DeleteRecord(record_id) => self.remove_links(record_id),
124 ActionableEvent::ActivateAccount(did) => self.set_account(did, true),
125 ActionableEvent::DeactivateAccount(did) => self.set_account(did, false),
126 ActionableEvent::DeleteAccount(did) => self.delete_account(did),
127 }
128 Ok(())
129 }
130
131 fn to_readable(&mut self) -> impl LinkReader {
132 self.clone()
133 }
134}
135
136impl LinkReader for MemStorage {
137 fn get_many_to_many_counts(
138 &self,
139 target: &str,
140 collection: &str,
141 path: &str,
142 path_to_other: &str,
143 limit: u64,
144 after: Option<String>,
145 filter_dids: &HashSet<Did>,
146 filter_to_targets: &HashSet<String>,
147 ) -> Result<PagedOrderedCollection<(String, u64, u64), String>> {
148 let data = self.0.lock().unwrap();
149 let Some(paths) = data.targets.get(&Target::new(target)) else {
150 return Ok(PagedOrderedCollection::default());
151 };
152 let Some(linkers) = paths.get(&Source::new(collection, path)) else {
153 return Ok(PagedOrderedCollection::default());
154 };
155
156 let path_to_other = RecordPath::new(path_to_other);
157 let filter_to_targets: HashSet<Target> =
158 HashSet::from_iter(filter_to_targets.iter().map(|s| Target::new(s)));
159
160 let mut grouped_counts: HashMap<Target, (u64, HashSet<Did>)> = HashMap::new();
161 for (did, rkey) in linkers.iter().flatten().cloned() {
162 if !filter_dids.is_empty() && !filter_dids.contains(&did) {
163 continue;
164 }
165 if let Some(fwd_target) = data
166 .links
167 .get(&did)
168 .unwrap_or(&HashMap::new())
169 .get(&RepoId {
170 collection: collection.to_string(),
171 rkey,
172 })
173 .unwrap_or(&Vec::new())
174 .iter()
175 .filter_map(|(path, target)| {
176 if *path == path_to_other
177 && (filter_to_targets.is_empty() || filter_to_targets.contains(target))
178 {
179 Some(target)
180 } else {
181 None
182 }
183 })
184 .take(1)
185 .next()
186 {
187 let e = grouped_counts.entry(fwd_target.clone()).or_default();
188 e.0 += 1;
189 e.1.insert(did.clone());
190 }
191 }
192 let mut items: Vec<(String, u64, u64)> = grouped_counts
193 .iter()
194 .map(|(k, (n, u))| (k.0.clone(), *n, u.len() as u64))
195 .collect();
196 items.sort();
197 items = items
198 .into_iter()
199 .skip_while(|(t, _, _)| after.as_ref().map(|a| t <= a).unwrap_or(false))
200 .take(limit as usize)
201 .collect();
202 let next = if items.len() as u64 >= limit {
203 items.last().map(|(t, _, _)| t.clone())
204 } else {
205 None
206 };
207 Ok(PagedOrderedCollection { items, next })
208 }
209
210 fn get_count(&self, target: &str, collection: &str, path: &str) -> Result<u64> {
211 let data = self.0.lock().unwrap();
212 let Some(paths) = data.targets.get(&Target::new(target)) else {
213 return Ok(0);
214 };
215 let Some(linkers) = paths.get(&Source::new(collection, path)) else {
216 return Ok(0);
217 };
218 Ok(linkers.iter().flatten().count() as u64)
219 }
220
221 fn get_distinct_did_count(&self, target: &str, collection: &str, path: &str) -> Result<u64> {
222 let data = self.0.lock().unwrap();
223 let Some(paths) = data.targets.get(&Target::new(target)) else {
224 return Ok(0);
225 };
226 let Some(linkers) = paths.get(&Source::new(collection, path)) else {
227 return Ok(0);
228 };
229 Ok(linkers
230 .iter()
231 .flatten()
232 .map(|(did, _)| did)
233 .collect::<HashSet<_>>()
234 .len() as u64)
235 }
236
237 fn get_links(
238 &self,
239 target: &str,
240 collection: &str,
241 path: &str,
242 limit: u64,
243 until: Option<u64>,
244 filter_dids: &HashSet<Did>,
245 ) -> Result<PagedAppendingCollection<RecordId>> {
246 let data = self.0.lock().unwrap();
247 let Some(paths) = data.targets.get(&Target::new(target)) else {
248 return Ok(PagedAppendingCollection {
249 version: (0, 0),
250 items: Vec::new(),
251 next: None,
252 total: 0,
253 });
254 };
255 let Some(did_rkeys) = paths.get(&Source::new(collection, path)) else {
256 return Ok(PagedAppendingCollection {
257 version: (0, 0),
258 items: Vec::new(),
259 next: None,
260 total: 0,
261 });
262 };
263
264 let did_rkeys: Vec<_> = if !filter_dids.is_empty() {
265 did_rkeys
266 .iter()
267 .filter(|m| {
268 Option::<(Did, RKey)>::clone(m)
269 .map(|(did, _)| filter_dids.contains(&did))
270 .unwrap_or(false)
271 })
272 .cloned()
273 .collect()
274 } else {
275 did_rkeys.to_vec()
276 };
277
278 let total = did_rkeys.len();
279 let end = until
280 .map(|u| std::cmp::min(u as usize, total))
281 .unwrap_or(total);
282 let begin = end.saturating_sub(limit as usize);
283 let next = if begin == 0 { None } else { Some(begin as u64) };
284
285 let alive = did_rkeys.iter().flatten().count();
286 let gone = total - alive;
287
288 let items: Vec<_> = did_rkeys[begin..end]
289 .iter()
290 .rev()
291 .flatten()
292 .filter(|(did, _)| *data.dids.get(did).expect("did must be in dids"))
293 .map(|(did, rkey)| RecordId {
294 did: did.clone(),
295 rkey: rkey.0.clone(),
296 collection: collection.to_string(),
297 })
298 .collect();
299
300 Ok(PagedAppendingCollection {
301 version: (total as u64, gone as u64),
302 items,
303 next,
304 total: alive as u64,
305 })
306 }
307
308 fn get_distinct_dids(
309 &self,
310 target: &str,
311 collection: &str,
312 path: &str,
313 limit: u64,
314 until: Option<u64>,
315 ) -> Result<PagedAppendingCollection<Did>> {
316 let data = self.0.lock().unwrap();
317 let Some(paths) = data.targets.get(&Target::new(target)) else {
318 return Ok(PagedAppendingCollection {
319 version: (0, 0),
320 items: Vec::new(),
321 next: None,
322 total: 0,
323 });
324 };
325 let Some(did_rkeys) = paths.get(&Source::new(collection, path)) else {
326 return Ok(PagedAppendingCollection {
327 version: (0, 0),
328 items: Vec::new(),
329 next: None,
330 total: 0,
331 });
332 };
333
334 let dids: Vec<Option<Did>> = {
335 let mut seen = HashSet::new();
336 did_rkeys
337 .iter()
338 .map(|o| {
339 o.clone().and_then(|(did, _)| {
340 if seen.contains(&did) {
341 None
342 } else {
343 seen.insert(did.clone());
344 Some(did)
345 }
346 })
347 })
348 .collect()
349 };
350
351 let total = dids.len();
352 let end = until
353 .map(|u| std::cmp::min(u as usize, total))
354 .unwrap_or(total);
355 let begin = end.saturating_sub(limit as usize);
356 let next = if begin == 0 { None } else { Some(begin as u64) };
357
358 let alive = dids.iter().flatten().count();
359 let gone = total - alive;
360
361 let items: Vec<Did> = dids[begin..end]
362 .iter()
363 .rev()
364 .flatten()
365 .filter(|did| *data.dids.get(did).expect("did must be in dids"))
366 .cloned()
367 .collect();
368
369 Ok(PagedAppendingCollection {
370 version: (total as u64, gone as u64),
371 items,
372 next,
373 total: alive as u64,
374 })
375 }
376
377 fn get_all_record_counts(&self, target: &str) -> Result<HashMap<String, HashMap<String, u64>>> {
378 let data = self.0.lock().unwrap();
379 let mut out: HashMap<String, HashMap<String, u64>> = HashMap::new();
380 if let Some(asdf) = data.targets.get(&Target::new(target)) {
381 for (Source { collection, path }, linkers) in asdf {
382 let count = linkers.iter().flatten().count() as u64;
383 out.entry(collection.to_string())
384 .or_default()
385 .insert(path.to_string(), count);
386 }
387 }
388 Ok(out)
389 }
390
391 fn get_all_counts(
392 &self,
393 target: &str,
394 ) -> Result<HashMap<String, HashMap<String, CountsByCount>>> {
395 let data = self.0.lock().unwrap();
396 let mut out: HashMap<String, HashMap<String, CountsByCount>> = HashMap::new();
397 if let Some(asdf) = data.targets.get(&Target::new(target)) {
398 for (Source { collection, path }, linkers) in asdf {
399 let records = linkers.iter().flatten().count() as u64;
400 let distinct_dids = linkers
401 .iter()
402 .flatten()
403 .map(|(did, _)| did)
404 .collect::<HashSet<_>>()
405 .len() as u64;
406 out.entry(collection.to_string()).or_default().insert(
407 path.to_string(),
408 CountsByCount {
409 records,
410 distinct_dids,
411 },
412 );
413 }
414 }
415 Ok(out)
416 }
417
418 fn get_stats(&self) -> Result<StorageStats> {
419 let data = self.0.lock().unwrap();
420 let dids = data.dids.len() as u64;
421 let targetables = data
422 .targets
423 .values()
424 .map(|sources| sources.len())
425 .sum::<usize>() as u64;
426 let linking_records = data.links.values().map(|recs| recs.len()).sum::<usize>() as u64;
427 Ok(StorageStats {
428 dids,
429 targetables,
430 linking_records,
431 started_at: None,
432 other_data: Default::default(),
433 })
434 }
435}
436
/// Newtype for a link target string, used as a map key.
#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct Target(String);

impl Target {
    /// Wrap a borrowed target string in an owned newtype.
    fn new(t: &str) -> Self {
        Target(t.to_owned())
    }
}
445
/// Where a link comes from: the linking collection plus the record path of the
/// linking field.
#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct Source {
    collection: String,
    path: String,
}

impl Source {
    /// Build an owned Source from borrowed parts.
    fn new(collection: &str, path: &str) -> Self {
        Source {
            collection: collection.to_owned(),
            path: path.to_owned(),
        }
    }
}
460
/// Newtype for a record key (rkey) within a collection.
#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct RKey(String);
463
464#[derive(Debug, PartialEq, Hash, Eq, Clone)]
465struct RepoId {
466 collection: String,
467 rkey: RKey,
468}
469
470impl RepoId {
471 fn from_record_id(record_id: &RecordId) -> Self {
472 Self {
473 collection: record_id.collection.clone(),
474 rkey: RKey(record_id.rkey.clone()),
475 }
476 }
477}
478
/// Newtype for the path inside a record at which a link was found.
#[derive(Debug, PartialEq, Hash, Eq, Clone)]
struct RecordPath(String);

impl RecordPath {
    /// Wrap a borrowed path in an owned newtype.
    fn new(rp: &str) -> Self {
        RecordPath(rp.to_owned())
    }
}
486}