From 36b9c54a79bc21984e3ef0ca8ccf80e6ea0ae90f Mon Sep 17 00:00:00 2001 From: Bailey Townsend Date: Wed, 20 Aug 2025 14:36:41 -0500 Subject: [PATCH] 2FA gatekeeping --- Cargo.lock | 10 + Cargo.toml | 6 +- README.md | 11 +- migrations_bells_and_whistles/.keep | 3 - src/helpers.rs | 524 ++++++++++++++++++++++++++++ src/main.rs | 79 +++-- src/middleware.rs | 53 +-- src/oauth_provider.rs | 141 ++++++++ src/xrpc/com_atproto_server.rs | 277 ++++----------- src/xrpc/helpers.rs | 150 -------- src/xrpc/mod.rs | 1 - 11 files changed, 823 insertions(+), 432 deletions(-) delete mode 100644 migrations_bells_and_whistles/.keep create mode 100644 src/helpers.rs create mode 100644 src/oauth_provider.rs delete mode 100644 src/xrpc/helpers.rs diff --git a/Cargo.lock b/Cargo.lock index 7955d87..fe3f1ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -287,7 +287,9 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", "windows-link", ] @@ -1652,18 +1654,22 @@ dependencies = [ name = "pds_gatekeeper" version = "0.1.0" dependencies = [ + "anyhow", "axum", "axum-template", + "chrono", "dotenvy", "handlebars", "hex", "hyper-util", "jwt-compact", "lettre", + "rand 0.9.2", "rust-embed", "scrypt", "serde", "serde_json", + "sha2", "sqlx", "tokio", "tower-http", @@ -2393,6 +2399,7 @@ checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6" dependencies = [ "base64", "bytes", + "chrono", "crc", "crossbeam-queue", "either", @@ -2470,6 +2477,7 @@ dependencies = [ "bitflags", "byteorder", "bytes", + "chrono", "crc", "digest", "dotenvy", @@ -2511,6 +2519,7 @@ dependencies = [ "base64", "bitflags", "byteorder", + "chrono", "crc", "dotenvy", "etcetera", @@ -2545,6 +2554,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" dependencies = [ "atoi", + "chrono", "flume", "futures-channel", "futures-core", diff --git a/Cargo.toml b/Cargo.toml index 796b7b5..9bed70d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2024" [dependencies] axum = { version = "0.8.4", features = ["macros", "json"] } tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] } -sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate"] } +sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate", "chrono"] } dotenvy = "0.15.7" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -22,3 +22,7 @@ lettre = { version = "0.11.18", features = ["tokio1", "pool", "tokio1-native-tls handlebars = { version = "6.3.2", features = ["rust-embed"] } rust-embed = "8.7.2" axum-template = { version = "3.0.0", features = ["handlebars"] } +rand = "0.9.2" +anyhow = "1.0.99" +chrono = "0.4.41" +sha2 = "0.10" diff --git a/README.md b/README.md index 1902fc2..bf8deef 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,8 @@ logic of these endpoints still happens on the PDS via a proxied request, just so ## 2FA -- [x] Ability to turn on/off 2FA -- [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on -- [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email. -- [ ] generate a 2FA code -- [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet) -- [ ] oauth endpoint gatekeeping +- Overrides The login endpoint to add 2FA for both Bluesky client logged in and OAuth logins +- Overrides the settings endpoints as well. As long as you have a confirmed email you can turn on 2FA ## Captcha on Create Account @@ -25,6 +21,8 @@ Future feature? # Setup +We are getting close! Testing now + Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up. But I want to run it locally on my own PDS first to test run it a bit. @@ -37,6 +35,7 @@ http://localhost { path /xrpc/com.atproto.server.getSession path /xrpc/com.atproto.server.updateEmail path /xrpc/com.atproto.server.createSession + path /@atproto/oauth-provider/~api/sign-in } handle @gatekeeper { diff --git a/migrations_bells_and_whistles/.keep b/migrations_bells_and_whistles/.keep deleted file mode 100644 index 501ab63..0000000 --- a/migrations_bells_and_whistles/.keep +++ /dev/null @@ -1,3 +0,0 @@ -# This directory holds SQLx migrations for the bells_and_whistles.sqlite database. -# It is intentionally empty for now; running `sqlx::migrate!` will still ensure the -# migrations table exists and succeed with zero migrations. diff --git a/src/helpers.rs b/src/helpers.rs new file mode 100644 index 0000000..7790d93 --- /dev/null +++ b/src/helpers.rs @@ -0,0 +1,524 @@ +use crate::AppState; +use crate::helpers::TokenCheckError::InvalidToken; +use anyhow::anyhow; +use axum::body::{Body, to_bytes}; +use axum::extract::Request; +use axum::http::header::CONTENT_TYPE; +use axum::http::{HeaderMap, StatusCode, Uri}; +use axum::response::{IntoResponse, Response}; +use axum_template::TemplateEngine; +use chrono::Utc; +use lettre::message::{MultiPart, SinglePart, header}; +use lettre::{AsyncTransport, Message}; +use rand::Rng; +use serde::de::DeserializeOwned; +use serde_json::{Map, Value}; +use sha2::{Digest, Sha256}; +use sqlx::SqlitePool; +use tracing::{error, log}; + +///Used to generate the email 2fa code +const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; + +/// The result of a proxied call that attempts to parse JSON. +pub enum ProxiedResult { + /// Successfully parsed JSON body along with original response headers. + Parsed { value: T, _headers: HeaderMap }, + /// Could not or should not parse: return the original (or rebuilt) response as-is. + Passthrough(Response), +} + +/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse +/// the successful response body as JSON into `T`. +/// +pub async fn proxy_get_json( + state: &AppState, + mut req: Request, + path: &str, +) -> Result, StatusCode> +where + T: DeserializeOwned, +{ + let uri = format!("{}{}", state.pds_base_url, path); + *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; + + let result = state + .reverse_proxy_client + .request(req) + .await + .map_err(|_| StatusCode::BAD_REQUEST)? + .into_response(); + + if result.status() != StatusCode::OK { + return Ok(ProxiedResult::Passthrough(result)); + } + + let response_headers = result.headers().clone(); + let body = result.into_body(); + let body_bytes = to_bytes(body, usize::MAX) + .await + .map_err(|_| StatusCode::BAD_REQUEST)?; + + match serde_json::from_slice::(&body_bytes) { + Ok(value) => Ok(ProxiedResult::Parsed { + value, + _headers: response_headers, + }), + Err(err) => { + error!(%err, "failed to parse proxied JSON response; returning original body"); + let mut builder = Response::builder().status(StatusCode::OK); + if let Some(headers) = builder.headers_mut() { + *headers = response_headers; + } + let resp = builder + .body(Body::from(body_bytes)) + .map_err(|_| StatusCode::BAD_REQUEST)?; + Ok(ProxiedResult::Passthrough(resp)) + } + } +} + +/// Build a JSON error response with the required Content-Type header +/// Content-Type: application/json;charset=utf-8 +/// Body shape: { "error": string, "message": string } +pub fn json_error_response( + status: StatusCode, + error: impl Into, + message: impl Into, +) -> Result, StatusCode> { + let body_str = match serde_json::to_string(&serde_json::json!({ + "error": error.into(), + "message": message.into(), + })) { + Ok(s) => s, + Err(_) => return Err(StatusCode::BAD_REQUEST), + }; + + Response::builder() + .status(status) + .header(CONTENT_TYPE, "application/json;charset=utf-8") + .body(Body::from(body_str)) + .map_err(|_| StatusCode::BAD_REQUEST) +} + +/// Build a JSON error response with the required Content-Type header +/// Content-Type: application/json (oauth endpoint does not like utf ending) +/// Body shape: { "error": string, "error_description": string } +pub fn oauth_json_error_response( + status: StatusCode, + error: impl Into, + message: impl Into, +) -> Result, StatusCode> { + let body_str = match serde_json::to_string(&serde_json::json!({ + "error": error.into(), + "error_description": message.into(), + })) { + Ok(s) => s, + Err(_) => return Err(StatusCode::BAD_REQUEST), + }; + + Response::builder() + .status(status) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(body_str)) + .map_err(|_| StatusCode::BAD_REQUEST) +} + +/// Creates a random token of 10 characters for email 2FA +pub fn get_random_token() -> String { + let mut rng = rand::rng(); + + let mut full_code = String::with_capacity(10); + for _ in 0..10 { + let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len()); + full_code.push(UPPERCASE_BASE32_CHARS[idx] as char); + } + + //The PDS implementation creates in lowercase, then converts to uppercase. + //Just going a head and doing uppercase here. + let slice_one = &full_code[0..5].to_ascii_uppercase(); + let slice_two = &full_code[5..10].to_ascii_uppercase(); + format!("{slice_one}-{slice_two}") +} + +pub enum TokenCheckError { + InvalidToken, + ExpiredToken, +} + +pub enum AuthResult { + WrongIdentityOrPassword, + /// The string here is the email address to create a hint for oauth + TwoFactorRequired(String), + /// User does not have 2FA enabled, or using an app password, or passes it + ProxyThrough, + TokenCheckFailed(TokenCheckError), +} + +pub enum IdentifierType { + Email, + Did, + Handle, +} + +impl IdentifierType { + fn what_is_it(identifier: String) -> Self { + if identifier.contains("@") { + IdentifierType::Email + } else if identifier.contains("did:") { + IdentifierType::Did + } else { + IdentifierType::Handle + } + } +} + +/// Creates a hex string from the password and salt to find app passwords +fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result { + let params = scrypt::Params::new(14, 8, 1, 64)?; + let mut derived = [0u8; 64]; + scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived)?; + Ok(hex::encode(derived)) +} + +/// Hashes the app password. did is used as the salt. +pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result { + let mut hasher = Sha256::new(); + hasher.update(did.as_bytes()); + let sha = hasher.finalize(); + let salt = hex::encode(&sha[..16]); + let hash_hex = scrypt_hex(password, &salt)?; + Ok(format!("{salt}:{hash_hex}")) +} + +async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result { + // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) + let mut parts = password_scrypt.splitn(2, ':'); + let salt = match parts.next() { + Some(s) if !s.is_empty() => s, + _ => return Ok(false), + }; + let stored_hash_hex = match parts.next() { + Some(h) if !h.is_empty() => h, + _ => return Ok(false), + }; + + // Derive using the shared helper and compare + let derived_hex = match scrypt_hex(password, salt) { + Ok(h) => h, + Err(_) => return Ok(false), + }; + + Ok(derived_hex.as_str() == stored_hash_hex) +} + +/// Handles the auth checks along with sending a 2fa email +pub async fn preauth_check( + state: &AppState, + identifier: &str, + password: &str, + two_factor_code: Option, + oauth: bool, +) -> anyhow::Result { + // Determine identifier type + let id_type = IdentifierType::what_is_it(identifier.to_string()); + + // Query account DB for did and passwordScrypt based on identifier type + let account_row: Option<(String, String, String, String)> = match id_type { + IdentifierType::Email => { + sqlx::query_as::<_, (String, String, String, String)>( + "SELECT account.did, account.passwordScrypt, account.email, actor.handle + FROM actor + LEFT JOIN account ON actor.did = account.did + where account.email = ? LIMIT 1", + ) + .bind(identifier) + .fetch_optional(&state.account_pool) + .await? + } + IdentifierType::Handle => { + sqlx::query_as::<_, (String, String, String, String)>( + "SELECT account.did, account.passwordScrypt, account.email, actor.handle + FROM actor + LEFT JOIN account ON actor.did = account.did + where actor.handle = ? LIMIT 1", + ) + .bind(identifier) + .fetch_optional(&state.account_pool) + .await? + } + IdentifierType::Did => { + sqlx::query_as::<_, (String, String, String, String)>( + "SELECT account.did, account.passwordScrypt, account.email, actor.handle + FROM actor + LEFT JOIN account ON actor.did = account.did + where account.did = ? LIMIT 1", + ) + .bind(identifier) + .fetch_optional(&state.account_pool) + .await? + } + }; + + if let Some((did, password_scrypt, email, handle)) = account_row { + // Verify password before proceeding to 2FA email step + let verified = verify_password(password, &password_scrypt).await?; + if !verified { + if oauth { + //OAuth does not allow app password logins so just go ahead and send it along it's way + return Ok(AuthResult::WrongIdentityOrPassword); + } + //Theres a chance it could be an app password so check that as well + return match verify_app_password(&state.account_pool, &did, password).await { + Ok(valid) => { + if valid { + //Was a valid app password up to the PDS now + Ok(AuthResult::ProxyThrough) + } else { + Ok(AuthResult::WrongIdentityOrPassword) + } + } + Err(err) => { + log::error!("Error checking the app password: {err}"); + Err(err) + } + }; + } + + // Check two-factor requirement for this DID in the gatekeeper DB + let required_opt = sqlx::query_as::<_, (u8,)>( + "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", + ) + .bind(did.clone()) + .fetch_optional(&state.pds_gatekeeper_pool) + .await?; + + let two_factor_required = match required_opt { + Some(row) => row.0 != 0, + None => false, + }; + + if two_factor_required { + //Two factor is required and a taken was provided + if let Some(two_factor_code) = two_factor_code { + //if the two_factor_code is set need to see if we have a valid token + if !two_factor_code.is_empty() { + return match assert_valid_token( + &state.account_pool, + did.clone(), + two_factor_code, + ) + .await + { + Ok(_) => { + let result_of_cleanup = + delete_all_email_tokens(&state.account_pool, did.clone()).await; + if result_of_cleanup.is_err() { + log::error!( + "There was an error deleting the email tokens after login: {:?}", + result_of_cleanup.err() + ) + } + Ok(AuthResult::ProxyThrough) + } + Err(err) => Ok(AuthResult::TokenCheckFailed(err)), + }; + } + } + + return match create_two_factor_token(&state.account_pool, did).await { + Ok(code) => { + let mut email_data = Map::new(); + email_data.insert("token".to_string(), Value::from(code.clone())); + email_data.insert("handle".to_string(), Value::from(handle.clone())); + let email_body = state + .template_engine + .render("two_factor_code.hbs", email_data)?; + + let email_message = Message::builder() + //TODO prob get the proper type in the state + .from(state.mailer_from.parse()?) + .to(email.parse()?) + .subject("Sign in to Bluesky") + .multipart( + MultiPart::alternative() // This is composed of two parts. + .singlepart( + SinglePart::builder() + .header(header::ContentType::TEXT_PLAIN) + .body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback. + ) + .singlepart( + SinglePart::builder() + .header(header::ContentType::TEXT_HTML) + .body(email_body), + ), + )?; + match state.mailer.send(email_message).await { + Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))), + Err(err) => { + log::error!("Error sending the 2FA email: {err}"); + Err(anyhow!(err)) + } + } + } + Err(err) => { + log::error!("error on creating a 2fa token: {err}"); + Err(anyhow!(err)) + } + }; + } + } + + // No local 2FA requirement (or account not found) + Ok(AuthResult::ProxyThrough) +} + +pub async fn create_two_factor_token( + account_db: &SqlitePool, + did: String, +) -> anyhow::Result { + let purpose = "2fa_code"; + + let token = get_random_token(); + let right_now = Utc::now(); + + let res = sqlx::query( + "INSERT INTO email_token (purpose, did, token, requestedAt) + VALUES (?, ?, ?, ?) + ON CONFLICT(purpose, did) DO UPDATE SET + token=excluded.token, + requestedAt=excluded.requestedAt + WHERE did=excluded.did", + ) + .bind(purpose) + .bind(&did) + .bind(&token) + .bind(right_now) + .execute(account_db) + .await; + + match res { + Ok(_) => Ok(token), + Err(err) => { + log::error!("Error creating a two factor token: {err}"); + Err(anyhow::anyhow!(err)) + } + } +} + +pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> { + sqlx::query("DELETE FROM email_token WHERE did = ?") + .bind(did) + .execute(account_db) + .await?; + Ok(()) +} + +pub async fn assert_valid_token( + account_db: &SqlitePool, + did: String, + token: String, +) -> Result<(), TokenCheckError> { + let token_upper = token.to_ascii_uppercase(); + let purpose = "2fa_code"; + + let row: Option<(String,)> = sqlx::query_as( + "SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1", + ) + .bind(purpose) + .bind(did) + .bind(token_upper) + .fetch_optional(account_db) + .await + .map_err(|err| { + log::error!("Error getting the 2fa token: {err}"); + InvalidToken + })?; + + match row { + None => Err(InvalidToken), + Some(row) => { + // Token lives for 15 minutes + let expiration_ms = 15 * 60_000; + + let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) { + Ok(dt) => dt.with_timezone(&Utc), + Err(_) => { + return Err(TokenCheckError::InvalidToken); + } + }; + + let now = Utc::now(); + let age_ms = (now - requested_at_utc).num_milliseconds(); + let expired = age_ms > expiration_ms; + if expired { + return Err(TokenCheckError::ExpiredToken); + } + + Ok(()) + } + } +} + +/// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions +pub async fn verify_app_password( + account_db: &SqlitePool, + did: &str, + password: &str, +) -> anyhow::Result { + let password_scrypt = hash_app_password(did, password)?; + + let row: Option<(i64,)> = sqlx::query_as( + "SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1", + ) + .bind(did) + .bind(password_scrypt) + .fetch_optional(account_db) + .await?; + + Ok(match row { + None => false, + Some((count,)) => count > 0, + }) +} + +/// Mask an email address into a hint like "2***0@p***m". +pub fn mask_email(email: String) -> String { + // Basic split on first '@' + let mut parts = email.splitn(2, '@'); + let local = match parts.next() { + Some(l) => l, + None => return email.to_string(), + }; + let domain_rest = match parts.next() { + Some(d) if !d.is_empty() => d, + _ => return email.to_string(), + }; + + // Helper to mask a single label (keep first and last, middle becomes ***). + fn mask_label(s: &str) -> String { + let chars: Vec = s.chars().collect(); + match chars.len() { + 0 => String::new(), + 1 => format!("{}***", chars[0]), + 2 => format!("{}***{}", chars[0], chars[1]), + _ => format!("{}***{}", chars[0], chars[chars.len() - 1]), + } + } + + // Mask local + let masked_local = mask_label(local); + + // Mask first domain label only, keep the rest of the domain intact + let mut dom_parts = domain_rest.splitn(2, '.'); + let first_label = dom_parts.next().unwrap_or(""); + let rest = dom_parts.next(); + let masked_first = mask_label(first_label); + let masked_domain = if let Some(rest) = rest { + format!("{}.{rest}", masked_first) + } else { + masked_first + }; + + format!("{masked_local}@{masked_domain}") +} diff --git a/src/main.rs b/src/main.rs index 530e7c5..a95eef1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,10 @@ +#![warn(clippy::unwrap_used)] +use crate::oauth_provider::sign_in; use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; -use axum::middleware as ax_middleware; -mod middleware; use axum::body::Body; use axum::handler::Handler; use axum::http::{Method, header}; +use axum::middleware as ax_middleware; use axum::routing::post; use axum::{Router, routing::get}; use axum_template::engine::Engine; @@ -21,9 +22,12 @@ use tower_governor::GovernorLayer; use tower_governor::governor::GovernorConfigBuilder; use tower_http::compression::CompressionLayer; use tower_http::cors::{Any, CorsLayer}; -use tracing::{error, log}; +use tracing::log; use tracing_subscriber::{EnvFilter, fmt, prelude::*}; +pub mod helpers; +mod middleware; +mod oauth_provider; mod xrpc; type HyperUtilClient = hyper_util::client::legacy::Client; @@ -34,7 +38,7 @@ type HyperUtilClient = hyper_util::client::legacy::Client; struct EmailTemplates; #[derive(Clone)] -struct AppState { +pub struct AppState { account_pool: SqlitePool, pds_gatekeeper_pool: SqlitePool, reverse_proxy_client: HyperUtilClient, @@ -73,7 +77,7 @@ ______________| | || / \ / \||/ \ / \ || | |______________ let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; - let banner = format!(" {}\n{}", body, intro); + let banner = format!(" {body}\n{intro}"); ( [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], @@ -84,33 +88,32 @@ ______________| | || / \ / \||/ \ / \ || | |______________ #[tokio::main] async fn main() -> Result<(), Box> { setup_tracing(); - //TODO prod + //TODO may need to change where this reads from? Like an env variable for it's location? Or arg? dotenvy::from_path(Path::new("./pds.env"))?; let pds_root = env::var("PDS_DATA_DIRECTORY")?; - // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data"; - let account_db_url = format!("{}/account.sqlite", pds_root); - log::info!("accounts_db_url: {}", account_db_url); + let account_db_url = format!("{pds_root}/account.sqlite"); let account_options = SqliteConnectOptions::new() - .journal_mode(SqliteJournalMode::Wal) - .filename(account_db_url); + .filename(account_db_url) + .busy_timeout(Duration::from_secs(5)); let account_pool = SqlitePoolOptions::new() .max_connections(5) .connect_with(account_options) .await?; - let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root); + let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); let options = SqliteConnectOptions::new() .journal_mode(SqliteJournalMode::Wal) .filename(bells_db_url) - .create_if_missing(true); + .create_if_missing(true) + .busy_timeout(Duration::from_secs(5)); let pds_gatekeeper_pool = SqlitePoolOptions::new() .max_connections(5) .connect_with(options) .await?; - // Run migrations for the bells_and_whistles database + // Run migrations for the extra database // Note: the migrations are embedded at compile time from the given directory // sqlx sqlx::migrate!("./migrations") @@ -130,14 +133,25 @@ async fn main() -> Result<(), Box> { AsyncSmtpTransport::::from_url(smtp_url.as_str())?.build(); //Email templates setup let mut hbs = Handlebars::new(); - let _ = hbs.register_embed_templates::(); + + let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); + if let Ok(users_email_directory) = users_email_directory { + hbs.register_template_file( + "two_factor_code.hbs", + format!("{users_email_directory}/two_factor_code.hbs"), + )?; + } else { + let _ = hbs.register_embed_templates::(); + } + + let pds_base_url = + env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); let state = AppState { account_pool, pds_gatekeeper_pool, reverse_proxy_client: client, - //TODO should be env prob - pds_base_url: "http://localhost:3000".to_string(), + pds_base_url, mailer, mailer_from: sent_from, template_engine: Engine::from(hbs), @@ -145,19 +159,28 @@ async fn main() -> Result<(), Box> { // Rate limiting //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. - let governor_conf = GovernorConfigBuilder::default() + let create_session_governor_conf = GovernorConfigBuilder::default() .per_second(60) .burst_size(5) .finish() - .unwrap(); - let governor_limiter = governor_conf.limiter().clone(); + .expect("failed to create governor config. this should not happen and is a bug"); + + // Create a second config with the same settings for the other endpoint + let sign_in_governor_conf = GovernorConfigBuilder::default() + .per_second(60) + .burst_size(5) + .finish() + .expect("failed to create governor config. this should not happen and is a bug"); + + let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); + let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); let interval = Duration::from_secs(60); // a separate background task to clean up std::thread::spawn(move || { loop { std::thread::sleep(interval); - tracing::info!("rate limiting storage size: {}", governor_limiter.len()); - governor_limiter.retain_recent(); + create_session_governor_limiter.retain_recent(); + sign_in_governor_limiter.retain_recent(); } }); @@ -176,16 +199,20 @@ async fn main() -> Result<(), Box> { "/xrpc/com.atproto.server.updateEmail", post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), ) + .route( + "/@atproto/oauth-provider/~api/sign-in", + post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), + ) .route( "/xrpc/com.atproto.server.createSession", - post(create_session.layer(GovernorLayer::new(governor_conf))), + post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), ) .layer(CompressionLayer::new()) .layer(cors) .with_state(state); - let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); - let port: u16 = env::var("PORT") + let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); + let port: u16 = env::var("GATEKEEPER_PORT") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(8080); @@ -202,7 +229,7 @@ async fn main() -> Result<(), Box> { .with_graceful_shutdown(shutdown_signal()); if let Err(err) = server.await { - error!(error = %err, "server error"); + log::error!("server error:{err}"); } Ok(()) diff --git a/src/middleware.rs b/src/middleware.rs index 29c5cc2..ef9c104 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,4 +1,4 @@ -use crate::xrpc::helpers::json_error_response; +use crate::helpers::json_error_response; use axum::extract::Request; use axum::http::{HeaderMap, StatusCode}; use axum::middleware::Next; @@ -7,6 +7,7 @@ use jwt_compact::alg::{Hs256, Hs256Key}; use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; use serde::{Deserialize, Serialize}; use std::env; +use tracing::log; #[derive(Clone, Debug)] pub struct Did(pub Option); @@ -22,59 +23,43 @@ pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { match token { Ok(token) => { match token { - None => { - return json_error_response( - StatusCode::BAD_REQUEST, - "TokenRequired", - "", - ).unwrap(); - } + None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") + .expect("Error creating an error response"), Some(token) => { let token = UntrustedToken::new(&token); - //Doing weird unwraps cause I can't do Result for middleware? if token.is_err() { - return json_error_response( - StatusCode::BAD_REQUEST, - "TokenRequired", - "", - ).unwrap(); + return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") + .expect("Error creating an error response"); } - let parsed_token = token.unwrap(); + let parsed_token = token.expect("Already checked for error"); let claims: Result, ValidationError> = parsed_token.deserialize_claims_unchecked(); if claims.is_err() { - return json_error_response( - StatusCode::BAD_REQUEST, - "TokenRequired", - "", - ).unwrap(); + return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") + .expect("Error creating an error response"); } - let key = Hs256Key::new(env::var("PDS_JWT_SECRET").unwrap()); + let key = Hs256Key::new( + env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), + ); let token: Result, ValidationError> = Hs256.validator(&key).validate(&parsed_token); if token.is_err() { - return json_error_response( - StatusCode::BAD_REQUEST, - "InvalidToken", - "", - ).unwrap(); + return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") + .expect("Error creating an error response"); } - let token = token.unwrap(); + let token = token.expect("Already checked for error,"); //Not going to worry about expiration since it still goes to the PDS - req.extensions_mut() .insert(Did(Some(token.claims().custom.sub.clone()))); next.run(req).await } } } - Err(_) => { - return json_error_response( - StatusCode::BAD_REQUEST, - "InvalidToken", - "", - ).unwrap(); + Err(err) => { + log::error!("Error extracting token: {err}"); + json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") + .expect("Error creating an error response") } } } diff --git a/src/oauth_provider.rs b/src/oauth_provider.rs new file mode 100644 index 0000000..ae3caad --- /dev/null +++ b/src/oauth_provider.rs @@ -0,0 +1,141 @@ +use crate::AppState; +use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check}; +use axum::body::Body; +use axum::extract::State; +use axum::http::header::CONTENT_TYPE; +use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; +use axum::response::{IntoResponse, Response}; +use axum::{Json, extract}; +use serde::{Deserialize, Serialize}; +use tracing::log; + +#[derive(Serialize, Deserialize, Clone)] +pub struct SignInRequest { + pub username: String, + pub password: String, + pub remember: bool, + pub locale: String, + #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] + pub email_otp: Option, +} + +pub async fn sign_in( + State(state): State, + headers: HeaderMap, + Json(mut payload): extract::Json, +) -> Result, StatusCode> { + let identifier = payload.username.clone(); + let password = payload.password.clone(); + let auth_factor_token = payload.email_otp.clone(); + + match preauth_check(&state, &identifier, &password, auth_factor_token, true).await { + Ok(result) => match result { + AuthResult::WrongIdentityOrPassword => oauth_json_error_response( + StatusCode::BAD_REQUEST, + "invalid_request", + "Invalid identifier or password", + ), + AuthResult::TwoFactorRequired(masked_email) => { + // Email sending step can be handled here if needed in the future. + + // {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"} + let body_str = match serde_json::to_string(&serde_json::json!({ + "error": "second_authentication_factor_required", + "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email), + "type": "emailOtp", + "hint": masked_email, + })) { + Ok(s) => s, + Err(_) => return Err(StatusCode::BAD_REQUEST), + }; + + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(body_str)) + .map_err(|_| StatusCode::BAD_REQUEST) + } + AuthResult::ProxyThrough => { + //No 2FA or already passed + let uri = format!( + "{}{}", + state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in" + ); + + let mut req = axum::http::Request::post(uri); + if let Some(req_headers) = req.headers_mut() { + // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers + copy_filtered_headers(&headers, req_headers); + //Setting the content type to application/json manually + req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + } + + //Clears the email_otp because the pds will reject a request with it. + payload.email_otp = None; + let payload_bytes = + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; + + let req = req + .body(Body::from(payload_bytes)) + .map_err(|_| StatusCode::BAD_REQUEST)?; + + let proxied = state + .reverse_proxy_client + .request(req) + .await + .map_err(|_| StatusCode::BAD_REQUEST)? + .into_response(); + + Ok(proxied) + } + //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same. + AuthResult::TokenCheckFailed(_) => oauth_json_error_response( + StatusCode::BAD_REQUEST, + "invalid_request", + "Unable to sign-in due to an unexpected server error", + ), + }, + Err(err) => { + log::error!( + "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" + ); + oauth_json_error_response( + StatusCode::BAD_REQUEST, + "pds_gatekeeper_error", + "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", + ) + } + } +} + +fn is_disallowed_header(name: &HeaderName) -> bool { + // possible problematic headers with proxying + matches!( + name.as_str(), + "connection" + | "keep-alive" + | "proxy-authenticate" + | "proxy-authorization" + | "te" + | "trailer" + | "transfer-encoding" + | "upgrade" + | "host" + | "content-length" + | "content-encoding" + | "expect" + | "accept-encoding" + ) +} + +fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) { + for (name, value) in src.iter() { + if is_disallowed_header(name) { + continue; + } + // Only copy valid headers + if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) { + dst.insert(name.clone(), hv); + } + } +} diff --git a/src/xrpc/com_atproto_server.rs b/src/xrpc/com_atproto_server.rs index 9ea6e15..45b180b 100644 --- a/src/xrpc/com_atproto_server.rs +++ b/src/xrpc/com_atproto_server.rs @@ -1,18 +1,15 @@ use crate::AppState; +use crate::helpers::{ + AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json, +}; use crate::middleware::Did; -use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json}; use axum::body::Body; use axum::extract::State; use axum::http::{HeaderMap, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::{Extension, Json, debug_handler, extract, extract::Request}; -use axum_template::TemplateEngine; -use lettre::message::{MultiPart, SinglePart, header}; -use lettre::{AsyncTransport, Message}; use serde::{Deserialize, Serialize}; use serde_json; -use serde_json::Value; -use serde_json::value::Map; use tracing::log; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -58,170 +55,10 @@ pub struct UpdateEmailResponse { pub struct CreateSessionRequest { identifier: String, password: String, - auth_factor_token: String, - allow_takendown: bool, -} - -pub enum AuthResult { - WrongIdentityOrPassword, - TwoFactorRequired, - TwoFactorFailed, - /// User does not have 2FA enabled, or passes it - ProxyThrough, -} - -pub enum IdentifierType { - Email, - DID, - Handle, -} - -impl IdentifierType { - fn what_is_it(identifier: String) -> Self { - if identifier.contains("@") { - IdentifierType::Email - } else if identifier.contains("did:") { - IdentifierType::DID - } else { - IdentifierType::Handle - } - } -} - -async fn verify_password(password: &str, password_scrypt: &str) -> Result { - // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) - let mut parts = password_scrypt.splitn(2, ':'); - let salt = match parts.next() { - Some(s) if !s.is_empty() => s, - _ => return Ok(false), - }; - let stored_hash_hex = match parts.next() { - Some(h) if !h.is_empty() => h, - _ => return Ok(false), - }; - - //Sets up scrypt to mimic node's scrypt - let params = match scrypt::Params::new(14, 8, 1, 64) { - Ok(p) => p, - Err(_) => return Ok(false), - }; - let mut derived = [0u8; 64]; - if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived).is_err() { - return Ok(false); - } - - let stored_bytes = match hex::decode(stored_hash_hex) { - Ok(b) => b, - Err(e) => { - log::error!("Error decoding stored hash: {}", e); - return Ok(false); - } - }; - - Ok(derived.as_slice() == stored_bytes.as_slice()) -} - -async fn preauth_check( - state: &AppState, - identifier: &str, - password: &str, -) -> Result { - // Determine identifier type - let id_type = IdentifierType::what_is_it(identifier.to_string()); - - // Query account DB for did and passwordScrypt based on identifier type - let account_row: Option<(String, String, String)> = match id_type { - IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>( - "SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1", - ) - .bind(identifier) - .fetch_optional(&state.account_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?, - IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>( - "SELECT account.did, account.passwordScrypt, account.email - FROM actor - LEFT JOIN account ON actor.did = account.did - where actor.handle =? LIMIT 1", - ) - .bind(identifier) - .fetch_optional(&state.account_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?, - IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>( - "SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1", - ) - .bind(identifier) - .fetch_optional(&state.account_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?, - }; - - if let Some((did, password_scrypt, email)) = account_row { - // Check two-factor requirement for this DID in the gatekeeper DB - let required_opt = sqlx::query_as::<_, (u8,)>( - "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", - ) - .bind(&did) - .fetch_optional(&state.pds_gatekeeper_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; - - let two_factor_required = match required_opt { - Some(row) => row.0 != 0, - None => false, - }; - - if two_factor_required { - // Verify password before proceeding to 2FA email step - let verified = verify_password(password, &password_scrypt).await?; - if !verified { - return Ok(AuthResult::WrongIdentityOrPassword); - } - let mut email_data = Map::new(); - //TODO these need real values - let token = "test".to_string(); - let handle = "baileytownsend.dev".to_string(); - email_data.insert("token".to_string(), Value::from(token.clone())); - email_data.insert("handle".to_string(), Value::from(handle.clone())); - //TODO bad unwrap - let email_body = state - .template_engine - .render("two_factor_code.hbs", email_data) - .unwrap(); - - let email = Message::builder() - //TODO prob get the proper type in the state - .from(state.mailer_from.parse().unwrap()) - .to(email.parse().unwrap()) - .subject("Sign in to Bluesky") - .multipart( - MultiPart::alternative() // This is composed of two parts. - .singlepart( - SinglePart::builder() - .header(header::ContentType::TEXT_PLAIN) - .body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback. - ) - .singlepart( - SinglePart::builder() - .header(header::ContentType::TEXT_HTML) - .body(email_body), - ), - ) - //TODO bad - .unwrap(); - return match state.mailer.send(email).await { - Ok(_) => Ok(AuthResult::TwoFactorRequired), - Err(err) => { - log::error!("Error sending the 2FA email: {}", err); - Err(StatusCode::BAD_REQUEST) - } - }; - } - } - - // No local 2FA requirement (or account not found) - Ok(AuthResult::ProxyThrough) + #[serde(skip_serializing_if = "Option::is_none")] + auth_factor_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + allow_takendown: Option, } pub async fn create_session( @@ -231,52 +68,70 @@ pub async fn create_session( ) -> Result, StatusCode> { let identifier = payload.identifier.clone(); let password = payload.password.clone(); + let auth_factor_token = payload.auth_factor_token.clone(); // Run the shared pre-auth logic to validate and check 2FA requirement - match preauth_check(&state, &identifier, &password).await? { - AuthResult::WrongIdentityOrPassword => json_error_response( - StatusCode::UNAUTHORIZED, - "AuthenticationRequired", - "Invalid identifier or password", - ), - AuthResult::TwoFactorRequired => { - // Email sending step can be handled here if needed in the future. - json_error_response( + match preauth_check(&state, &identifier, &password, auth_factor_token, false).await { + Ok(result) => match result { + AuthResult::WrongIdentityOrPassword => json_error_response( StatusCode::UNAUTHORIZED, - "AuthFactorTokenRequired", - "A sign in code has been sent to your email address", - ) - } - AuthResult::TwoFactorFailed => { - //Not sure what the errors are for this response is yet - json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER") - } - AuthResult::ProxyThrough => { - //No 2FA or already passed - let uri = format!( - "{}{}", - state.pds_base_url, "/xrpc/com.atproto.server.createSession" - ); - - let mut req = axum::http::Request::post(uri); - if let Some(req_headers) = req.headers_mut() { - req_headers.extend(headers.clone()); + "AuthenticationRequired", + "Invalid identifier or password", + ), + AuthResult::TwoFactorRequired(_) => { + // Email sending step can be handled here if needed in the future. + json_error_response( + StatusCode::UNAUTHORIZED, + "AuthFactorTokenRequired", + "A sign in code has been sent to your email address", + ) } + AuthResult::ProxyThrough => { + log::info!("Proxying through"); + //No 2FA or already passed + let uri = format!( + "{}{}", + state.pds_base_url, "/xrpc/com.atproto.server.createSession" + ); + + let mut req = axum::http::Request::post(uri); + if let Some(req_headers) = req.headers_mut() { + req_headers.extend(headers.clone()); + } - let payload_bytes = - serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; - let req = req - .body(Body::from(payload_bytes)) - .map_err(|_| StatusCode::BAD_REQUEST)?; + let payload_bytes = + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; + let req = req + .body(Body::from(payload_bytes)) + .map_err(|_| StatusCode::BAD_REQUEST)?; - let proxied = state - .reverse_proxy_client - .request(req) - .await - .map_err(|_| StatusCode::BAD_REQUEST)? - .into_response(); + let proxied = state + .reverse_proxy_client + .request(req) + .await + .map_err(|_| StatusCode::BAD_REQUEST)? + .into_response(); - Ok(proxied) + Ok(proxied) + } + AuthResult::TokenCheckFailed(err) => match err { + TokenCheckError::InvalidToken => { + json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid") + } + TokenCheckError::ExpiredToken => { + json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired") + } + }, + }, + Err(err) => { + log::error!( + "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" + ); + json_error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "InternalServerError", + "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", + ) } } } @@ -290,7 +145,7 @@ pub async fn update_email( ) -> Result, StatusCode> { //If email auth is not set at all it is a update email address let email_auth_not_set = payload.email_auth_factor.is_none(); - //If email aurth is set it is to either turn on or off 2fa + //If email auth is set it is to either turn on or off 2fa let email_auth_update = payload.email_auth_factor.unwrap_or(false); // Email update asked for @@ -350,7 +205,7 @@ pub async fn update_email( } } - // Updating the acutal email address + // Updating the actual email address by sending it on to the PDS let uri = format!( "{}{}", state.pds_base_url, "/xrpc/com.atproto.server.updateEmail" diff --git a/src/xrpc/helpers.rs b/src/xrpc/helpers.rs deleted file mode 100644 index b3cbddf..0000000 --- a/src/xrpc/helpers.rs +++ /dev/null @@ -1,150 +0,0 @@ -use axum::body::{Body, to_bytes}; -use axum::extract::Request; -use axum::http::{HeaderMap, Method, StatusCode, Uri}; -use axum::http::header::CONTENT_TYPE; -use axum::response::{IntoResponse, Response}; -use serde::de::DeserializeOwned; -use tracing::error; - -use crate::AppState; - -/// The result of a proxied call that attempts to parse JSON. -pub enum ProxiedResult { - /// Successfully parsed JSON body along with original response headers. - Parsed { value: T, _headers: HeaderMap }, - /// Could not or should not parse: return the original (or rebuilt) response as-is. - Passthrough(Response), -} - -/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse -/// the successful response body as JSON into `T`. -/// -/// Behavior: -/// - If the proxied response is non-200, returns Passthrough with the original response. -/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers. -/// - If parsing succeeds, returns Parsed { value, headers }. -pub async fn proxy_get_json( - state: &AppState, - mut req: Request, - path: &str, -) -> Result, StatusCode> -where - T: DeserializeOwned, -{ - let uri = format!("{}{}", state.pds_base_url, path); - *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; - - let result = state - .reverse_proxy_client - .request(req) - .await - .map_err(|_| StatusCode::BAD_REQUEST)? - .into_response(); - - if result.status() != StatusCode::OK { - return Ok(ProxiedResult::Passthrough(result)); - } - - let response_headers = result.headers().clone(); - let body = result.into_body(); - let body_bytes = to_bytes(body, usize::MAX) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; - - match serde_json::from_slice::(&body_bytes) { - Ok(value) => Ok(ProxiedResult::Parsed { - value, - _headers: response_headers, - }), - Err(err) => { - error!(%err, "failed to parse proxied JSON response; returning original body"); - let mut builder = Response::builder().status(StatusCode::OK); - if let Some(headers) = builder.headers_mut() { - *headers = response_headers; - } - let resp = builder - .body(Body::from(body_bytes)) - .map_err(|_| StatusCode::BAD_REQUEST)?; - Ok(ProxiedResult::Passthrough(resp)) - } - } -} - -/// Proxy the incoming request as a POST to the PDS base URL plus the provided path and attempt to parse -/// the successful response body as JSON into `T`. -/// -/// Behavior mirrors `proxy_get_json`: -/// - If the proxied response is non-200, returns Passthrough with the original response. -/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers. -/// - If parsing succeeds, returns Parsed { value, headers }. -pub async fn _proxy_post_json( - state: &AppState, - mut req: Request, - path: &str, -) -> Result, StatusCode> -where - T: DeserializeOwned, -{ - let uri = format!("{}{}", state.pds_base_url, path); - *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; - *req.method_mut() = Method::POST; - - let result = state - .reverse_proxy_client - .request(req) - .await - .map_err(|_| StatusCode::BAD_REQUEST)? - .into_response(); - - if result.status() != StatusCode::OK { - return Ok(ProxiedResult::Passthrough(result)); - } - - let response_headers = result.headers().clone(); - let body = result.into_body(); - let body_bytes = to_bytes(body, usize::MAX) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; - - match serde_json::from_slice::(&body_bytes) { - Ok(value) => Ok(ProxiedResult::Parsed { - value, - _headers: response_headers, - }), - Err(err) => { - error!(%err, "failed to parse proxied JSON response (POST); returning original body"); - let mut builder = Response::builder().status(StatusCode::OK); - if let Some(headers) = builder.headers_mut() { - *headers = response_headers; - } - let resp = builder - .body(Body::from(body_bytes)) - .map_err(|_| StatusCode::BAD_REQUEST)?; - Ok(ProxiedResult::Passthrough(resp)) - } - } -} - - -/// Build a JSON error response with the required Content-Type header -/// Content-Type: application/json;charset=utf-8 -/// Body shape: { "error": string, "message": string } -pub fn json_error_response( - status: StatusCode, - error: impl Into, - message: impl Into, -) -> Result, StatusCode> { - let body_str = match serde_json::to_string(&serde_json::json!({ - "error": error.into(), - "message": message.into(), - })) { - Ok(s) => s, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; - - Response::builder() - .status(status) - .header(CONTENT_TYPE, "application/json;charset=utf-8") - .body(Body::from(body_str)) - .map_err(|_| StatusCode::BAD_REQUEST) -} diff --git a/src/xrpc/mod.rs b/src/xrpc/mod.rs index ad6127e..0988f80 100644 --- a/src/xrpc/mod.rs +++ b/src/xrpc/mod.rs @@ -1,2 +1 @@ pub mod com_atproto_server; -pub mod helpers; -- 2.43.0