Get DIDs of our wanted PDSes at startup time #4

merged
opened by lewis.moe targeting main

So that we can use the DIDs to prune the saved firehose stream into a nice subset! This PR doesn't include any periodic re-checking to make sure we're getting the most up-to-date data from the PDSes; that could be added later, either on a timer or by listening to account-move events from the firehose.

Changed files
+176 -6
src
+3
.env.example
···
+
# the main relay to listen to
+
RELAY_URL=wss://pds.upcloud.world
+
# shitsky saves state here
DATABASE_URL=postgres://postgres:postgres@localhost:5432/shitsky
+1 -1
Cargo.toml
···
futures-util = "0.3.31"
listenfd = "1.0.2"
maud = { version = "0.27", features = ["axum"] }
-
reqwest = "0.12.23"
+
reqwest = { version = "0.12", features = ["json"] }
rs-car = "0.5.0"
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11.19"
+30 -5
src/main.rs
···
pub mod firehose;
use firehose::{FirehoseEvent, FirehoseOptions, subscribe_repos};
+
mod pds;
+
use pds::get_all_active_dids_from_pdses;
+
type Db = PgPool;
#[tokio::main]
···
tracing_subscriber::fmt::init();
dotenvy::dotenv().ok();
+
let pds_hosts_str = std::env::var("PDS_LIST")?;
+
+
let pds_hosts: Vec<String> = pds_hosts_str
+
.split(',')
+
.map(|s| s.trim().to_string())
+
.filter(|s| !s.is_empty())
+
.collect();
+
+
if pds_hosts.is_empty() {
+
tracing::error!("Error: PDS_LIST environment variable is empty or contains only commas.");
+
return Ok(());
+
}
+
+
tracing::info!("Querying {} PDS(es): {:?}", pds_hosts.len(), pds_hosts);
+
+
let _all_dids = get_all_active_dids_from_pdses(&pds_hosts).await?;
+
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let pool = PgPoolOptions::new()
.max_connections(5)
···
let web_server_pool = pool.clone();
tokio::spawn(async move { web_server(web_server_pool).await });
-
firehose_subscriber(pool).await;
+
let relay_url = std::env::var("RELAY_URL").unwrap_or_default();
+
firehose_subscriber(pool, relay_url).await;
Ok(())
}
-
async fn firehose_subscriber(db: Db) {
+
async fn firehose_subscriber(db: Db, relay_url: String) {
tracing::info!("Starting firehose subscriber...");
-
let options = FirehoseOptions {
-
relay_url: "wss://bsky.network".to_string(),
-
..Default::default()
+
let options = if relay_url.is_empty() {
+
FirehoseOptions::default()
+
} else {
+
FirehoseOptions {
+
relay_url,
+
..Default::default()
+
}
};
let mut stream = Box::pin(subscribe_repos(options));
+142
src/pds.rs
···
+
use futures_util::future::join_all;
+
use reqwest::Client;
+
use serde::Deserialize;
+
use thiserror::Error;
+
+
/// Errors that can occur while collecting DIDs from PDSes.
#[derive(Debug, Error)]
pub enum PdsError {
    /// An HTTP call to a PDS failed (connection, TLS, or JSON decode).
    #[error("Network request failed: {0}")]
    RequestError(#[from] reqwest::Error),
    /// A spawned Tokio task panicked or was cancelled before finishing.
    #[error("Failed to join task: {0}")]
    JoinError(#[from] tokio::task::JoinError),
    /// A required environment variable was missing or not valid Unicode.
    #[error("Environment variable not found: {0}")]
    EnvVarError(#[from] std::env::VarError),
}
+
+
/// A single entry from a `com.atproto.sync.listRepos` page. Only the DID
/// is used here; any other fields in the response entry are ignored by
/// serde's default behavior.
#[derive(Deserialize, Debug)]
struct Repo {
    did: String,
}
+
+
/// One page of `com.atproto.sync.listRepos` results.
/// The pagination loop treats an absent `cursor` as "last page".
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct ListReposResponse {
    cursor: Option<String>,
    repos: Vec<Repo>,
}
+
+
/// Subset of the `com.atproto.repo.describeRepo` response we care about.
/// `rename_all = "camelCase"` maps the field to the wire name
/// `handleIsCorrect`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct DescribeRepoResponse {
    // Used as a proxy for "repo is active" — NOTE(review): confirm this is
    // the intended signal (vs. an explicit active/status flag on the
    // listRepos entries themselves).
    handle_is_correct: bool,
}
+
+
pub async fn get_all_active_dids_from_pdses(pds_hosts: &[String]) -> Result<Vec<String>, PdsError> {
+
let client = Client::new();
+
let mut tasks = Vec::new();
+
+
for host in pds_hosts {
+
let host_clone = host.clone();
+
let client_clone = client.clone();
+
tasks.push(tokio::spawn(async move {
+
fetch_active_dids_from_single_pds(host_clone, client_clone).await
+
}));
+
}
+
+
let results = join_all(tasks).await;
+
let mut all_dids = Vec::new();
+
+
for result in results {
+
let pds_dids = result??;
+
all_dids.extend(pds_dids);
+
}
+
+
tracing::info!("--- Sample of fetched ACTIVE DIDs (first 10) ---");
+
for did in all_dids.iter().take(10) {
+
tracing::info!("{}", did);
+
}
+
tracing::info!("... and {} more.", all_dids.len().saturating_sub(10));
+
tracing::info!("--- Total Active DIDs fetched: {} ---", all_dids.len());
+
+
Ok(all_dids)
+
}
+
+
async fn fetch_active_dids_from_single_pds(
+
host: String,
+
client: Client,
+
) -> Result<Vec<String>, PdsError> {
+
let mut active_dids = Vec::new();
+
let mut cursor: Option<String> = None;
+
let limit = 1000;
+
+
tracing::info!("Fetching active DIDs from PDS: {}", host);
+
+
loop {
+
let url = match &cursor {
+
Some(c) => format!(
+
"https://{}/xrpc/com.atproto.sync.listRepos?limit={}&cursor={}",
+
host, limit, c
+
),
+
None => format!(
+
"https://{}/xrpc/com.atproto.sync.listRepos?limit={}",
+
host, limit
+
),
+
};
+
+
let response: ListReposResponse = client.get(&url).send().await?.json().await?;
+
let dids_on_page: Vec<String> = response.repos.into_iter().map(|r| r.did).collect();
+
+
if !dids_on_page.is_empty() {
+
let mut check_tasks = Vec::new();
+
for did in dids_on_page {
+
let host_clone = host.clone();
+
let client_clone = client.clone();
+
check_tasks.push(tokio::spawn(async move {
+
let url = format!(
+
"https://{}/xrpc/com.atproto.repo.describeRepo?repo={}",
+
host_clone, did
+
);
+
+
let repo_info_result = async {
+
client_clone
+
.get(&url)
+
.send()
+
.await?
+
.json::<DescribeRepoResponse>()
+
.await
+
}
+
.await;
+
+
if let Ok(repo_info) = repo_info_result
+
&& repo_info.handle_is_correct
+
{
+
Some(did)
+
} else {
+
None
+
}
+
}));
+
}
+
+
let checked_results = join_all(check_tasks).await;
+
for result in checked_results {
+
if let Ok(Some(active_did)) = result {
+
active_dids.push(active_did);
+
}
+
}
+
}
+
+
if let Some(next_cursor) = response.cursor {
+
cursor = Some(next_cursor);
+
} else {
+
break;
+
}
+
}
+
+
tracing::info!(
+
"Finished fetching {} active DIDs from {}.",
+
active_dids.len(),
+
host
+
);
+
+
Ok(active_dids)
+
}