NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Blank
context lines were placed to satisfy the original hunk counts, stripped generic
parameters were restored from usage, and the `index` hashes predate the edits
below. Verify against the target tree before applying.

diff --git a/.env.example b/.env.example
index 4beeeb7..873d716 100644
--- a/.env.example
+++ b/.env.example
@@ -1,3 +1,9 @@
+# the main relay to listen to
+RELAY_URL=wss://pds.upcloud.world
+
+# comma-separated PDS hostnames (no scheme) to crawl for active DIDs at startup
+PDS_LIST=pds.upcloud.world
+
 # shitsky saves state here
 DATABASE_URL=postgres://postgres:postgres@localhost:5432/shitsky
 
diff --git a/Cargo.toml b/Cargo.toml
index 9c89f1e..ad6b57b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,7 +13,7 @@ futures = "0.3.31"
 futures-util = "0.3.31"
 listenfd = "1.0.2"
 maud = { version = "0.27", features = ["axum"] }
-reqwest = "0.12.23"
+reqwest = { version = "0.12", features = ["json"] }
 rs-car = "0.5.0"
 serde = { version = "1.0", features = ["derive"] }
 serde_bytes = "0.11.19"
diff --git a/src/main.rs b/src/main.rs
index fe04900..15d19a0 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,6 +12,9 @@ use tower_http::trace::TraceLayer;
 pub mod firehose;
 use firehose::{FirehoseEvent, FirehoseOptions, subscribe_repos};
 
+mod pds;
+use pds::get_all_active_dids_from_pdses;
+
 type Db = PgPool;
 
 #[tokio::main]
@@ -19,6 +22,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     tracing_subscriber::fmt::init();
     dotenvy::dotenv().ok();
 
+    // PDS_LIST is a comma-separated list of bare PDS hostnames; pds.rs prepends
+    // `https://` to each entry.
+    let pds_hosts_str = std::env::var("PDS_LIST")?;
+
+    let pds_hosts: Vec<String> = pds_hosts_str
+        .split(',')
+        .map(|s| s.trim().to_string())
+        .filter(|s| !s.is_empty())
+        .collect();
+
+    if pds_hosts.is_empty() {
+        tracing::error!("Error: PDS_LIST environment variable is empty or contains only commas.");
+        // A misconfiguration must not exit with status 0.
+        return Err("PDS_LIST is empty or contains only commas".into());
+    }
+
+    tracing::info!("Querying {} PDS(es): {:?}", pds_hosts.len(), pds_hosts);
+
+    let _all_dids = get_all_active_dids_from_pdses(&pds_hosts).await?;
+
     let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
     let pool = PgPoolOptions::new()
         .max_connections(5)
@@ -32,17 +55,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let web_server_pool = pool.clone();
     tokio::spawn(async move { web_server(web_server_pool).await });
 
-    firehose_subscriber(pool).await;
+    // An unset or empty RELAY_URL makes the subscriber use FirehoseOptions::default().
+    let relay_url = std::env::var("RELAY_URL").unwrap_or_default();
+    firehose_subscriber(pool, relay_url).await;
 
     Ok(())
 }
 
-async fn firehose_subscriber(db: Db) {
+async fn firehose_subscriber(db: Db, relay_url: String) {
     tracing::info!("Starting firehose subscriber...");
 
-    let options = FirehoseOptions {
-        relay_url: "wss://bsky.network".to_string(),
-        ..Default::default()
+    let options = if relay_url.is_empty() {
+        FirehoseOptions::default()
+    } else {
+        FirehoseOptions {
+            relay_url,
+            ..Default::default()
+        }
     };
 
     let mut stream = Box::pin(subscribe_repos(options));
diff --git a/src/pds.rs b/src/pds.rs
new file mode 100644
index 0000000..6321e5f
--- /dev/null
+++ b/src/pds.rs
@@ -0,0 +1,170 @@
+//! Crawls atproto PDS hosts and collects the DIDs whose handle resolves correctly.
+//!
+//! For every host, `com.atproto.sync.listRepos` is paged to enumerate repos and
+//! `com.atproto.repo.describeRepo` is called per repo; a DID is kept when the
+//! response has `handleIsCorrect: true`.
+//!
+//! NOTE(review): "active" here means `handleIsCorrect`; the `listRepos` lexicon
+//! appears to carry its own per-repo `active` flag. Confirm which is intended.
+
+use futures_util::future::join_all;
+use futures_util::stream::{self, StreamExt};
+use reqwest::Client;
+use serde::Deserialize;
+use thiserror::Error;
+
+/// Page size requested from `com.atproto.sync.listRepos`.
+const LIST_REPOS_PAGE_SIZE: u32 = 1000;
+
+/// Upper bound on in-flight `describeRepo` requests per host, so a full page of
+/// repos does not turn into `LIST_REPOS_PAGE_SIZE` simultaneous requests.
+const MAX_CONCURRENT_CHECKS: usize = 32;
+
+/// Errors surfaced by this module.
+#[derive(Debug, Error)]
+pub enum PdsError {
+    #[error("Network request failed: {0}")]
+    RequestError(#[from] reqwest::Error),
+    #[error("Failed to join task: {0}")]
+    JoinError(#[from] tokio::task::JoinError),
+    #[error("Environment variable not found: {0}")]
+    EnvVarError(#[from] std::env::VarError),
+}
+
+/// One entry of a `listRepos` page; only the DID is read.
+#[derive(Deserialize, Debug)]
+struct Repo {
+    did: String,
+}
+
+/// Response body of `com.atproto.sync.listRepos`.
+#[derive(Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+struct ListReposResponse {
+    /// Pagination token for the next page; `None` ends the crawl.
+    cursor: Option<String>,
+    repos: Vec<Repo>,
+}
+
+/// Response body of `com.atproto.repo.describeRepo`; only `handleIsCorrect` is read.
+#[derive(Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+struct DescribeRepoResponse {
+    handle_is_correct: bool,
+}
+
+/// Crawls every host in `pds_hosts` concurrently, one task per host, and returns
+/// the active DIDs of all hosts concatenated in host order.
+///
+/// # Errors
+///
+/// Returns the first failure in host order: a `listRepos` request that failed on
+/// that host, or a crawl task that panicked or was cancelled.
+pub async fn get_all_active_dids_from_pdses(pds_hosts: &[String]) -> Result<Vec<String>, PdsError> {
+    let client = Client::new();
+
+    let crawls: Vec<_> = pds_hosts
+        .iter()
+        .map(|host| tokio::spawn(fetch_active_dids_from_single_pds(host.clone(), client.clone())))
+        .collect();
+
+    let mut all_dids = Vec::new();
+    for crawl in join_all(crawls).await {
+        // Outer `?` is the task join result, inner `?` the crawl result.
+        all_dids.extend(crawl??);
+    }
+
+    tracing::info!("--- Sample of fetched ACTIVE DIDs (first 10) ---");
+    for did in all_dids.iter().take(10) {
+        tracing::info!("{}", did);
+    }
+    tracing::info!("... and {} more.", all_dids.len().saturating_sub(10));
+    tracing::info!("--- Total Active DIDs fetched: {} ---", all_dids.len());
+
+    Ok(all_dids)
+}
+
+/// Pages through `com.atproto.sync.listRepos` on `host` and returns the DIDs
+/// accepted by [`check_handle`], in the order the server listed them.
+///
+/// # Errors
+///
+/// Fails when a `listRepos` request cannot be sent, returns a non-success status,
+/// or has a body that does not deserialize. A failed `describeRepo` check is not
+/// an error; that DID is skipped.
+async fn fetch_active_dids_from_single_pds(
+    host: String,
+    client: Client,
+) -> Result<Vec<String>, PdsError> {
+    let list_url = format!("https://{host}/xrpc/com.atproto.sync.listRepos");
+    let describe_url = format!("https://{host}/xrpc/com.atproto.repo.describeRepo");
+    let mut active_dids = Vec::new();
+    let mut cursor: Option<String> = None;
+
+    tracing::info!("Fetching active DIDs from PDS: {}", host);
+
+    loop {
+        // Let reqwest build the query string so the opaque cursor is percent-encoded.
+        let mut query = vec![("limit", LIST_REPOS_PAGE_SIZE.to_string())];
+        if let Some(c) = &cursor {
+            query.push(("cursor", c.clone()));
+        }
+
+        let page: ListReposResponse = client
+            .get(&list_url)
+            .query(&query)
+            .send()
+            .await?
+            .error_for_status()?
+            .json()
+            .await?;
+
+        // A server may keep handing out a cursor after its last repo; an empty
+        // page means nothing is left, so stop instead of looping forever.
+        if page.repos.is_empty() {
+            break;
+        }
+
+        // Run the per-repo checks with bounded concurrency. `buffered` yields
+        // results in input order, so DIDs stay in listing order.
+        let check_client = client.clone();
+        let check_url = describe_url.clone();
+        let checked: Vec<Option<String>> = stream::iter(page.repos)
+            .map(move |repo| check_handle(check_client.clone(), check_url.clone(), repo.did))
+            .buffered(MAX_CONCURRENT_CHECKS)
+            .collect()
+            .await;
+        active_dids.extend(checked.into_iter().flatten());
+
+        // Stop on the last page, or if the server repeats the cursor just sent.
+        if page.cursor.is_none() || page.cursor == cursor {
+            break;
+        }
+        cursor = page.cursor;
+    }
+
+    tracing::info!(
+        "Finished fetching {} active DIDs from {}.",
+        active_dids.len(),
+        host
+    );
+
+    Ok(active_dids)
+}
+
+/// Returns `did` when `describeRepo` at `describe_url` reports `handleIsCorrect: true`.
+///
+/// Transport, status, and decode failures all yield `None`, so one bad repo
+/// cannot abort the crawl of its host.
+async fn check_handle(client: Client, describe_url: String, did: String) -> Option<String> {
+    let response = client
+        .get(&describe_url)
+        .query(&[("repo", did.as_str())])
+        .send()
+        .await
+        .ok()?
+        .error_for_status()
+        .ok()?;
+    let info: DescribeRepoResponse = response.json().await.ok()?;
+    info.handle_is_correct.then_some(did)
+}