From 4025e761a92d932ec63f298be0272d6267f87ee3 Mon Sep 17 00:00:00 2001 From: Bailey Townsend Date: Fri, 5 Sep 2025 20:25:44 -0500 Subject: [PATCH] Works but I feel like it should be more secure --- Cargo.toml | 5 ++--- src/main.rs | 3 +++ src/middleware.rs | 57 +++++++++++++++++++++++++++++++++-------------- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 30a8a14..5193a46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,10 @@ tower_governor = "0.8.0" hex = "0.4" jwt-compact = { version = "0.8.0", features = ["es256k"] } scrypt = "0.11" -#lettre = { version = "0.11.18", default-features = false, features = ["pool", "tokio1-rustls", "smtp-transport", "hostname", "builder"] } -#lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } +#Leaveing these two cause I think it is needed by the aws-lc-rs = "1.13.0" -lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } rustls = { version = "0.23", default-features = false, features = ["tls12", "std", "logging", "aws_lc_rs"] } +lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } handlebars = { version = "6.3.2", features = ["rust-embed"] } rust-embed = "8.7.2" axum-template = { version = "3.0.0", features = ["handlebars"] } diff --git a/src/main.rs b/src/main.rs index 03bb851..abf3306 100644 --- a/src/main.rs +++ b/src/main.rs @@ -175,6 +175,9 @@ async fn main() -> Result<(), Box> { .finish() .expect("failed to create governor config. this should not happen and is a bug"); + // let create_account_limiter_time: Option = + // env::var("GATEKEEPER_CREATE_ACCOUNT_LIMITER_WINDOW").unwrap_or_else(|_| None); + let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); let interval = Duration::from_secs(60); diff --git a/src/middleware.rs b/src/middleware.rs index ef9c104..3025c38 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,5 +1,6 @@ use crate::helpers::json_error_response; use axum::extract::Request; +use axum::http::header::AUTHORIZATION; use axum::http::{HeaderMap, StatusCode}; use axum::middleware::Next; use axum::response::IntoResponse; @@ -12,21 +13,31 @@ use tracing::log; #[derive(Clone, Debug)] pub struct Did(pub Option); +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AuthScheme { + Bearer, + DPoP, +} + #[derive(Serialize, Deserialize)] pub struct TokenClaims { pub sub: String, } pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { - let token = extract_bearer(req.headers()); + let auth = extract_auth(req.headers()); - match token { - Ok(token) => { - match token { + match auth { + Ok(auth_opt) => { + match auth_opt { None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") .expect("Error creating an error response"), - Some(token) => { - let token = UntrustedToken::new(&token); + Some((scheme, token_str)) => { + // For Bearer, validate JWT and extract DID from `sub`. + // For DPoP, we currently only pass through and do not validate here; insert None DID. + // match scheme { + // AuthScheme::Bearer => { + let token = UntrustedToken::new(&token_str); if token.is_err() { return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") .expect("Error creating an error response"); @@ -49,9 +60,16 @@ pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { .expect("Error creating an error response"); } let token = token.expect("Already checked for error,"); - //Not going to worry about expiration since it still goes to the PDS + // Not going to worry about expiration since it still goes to the PDS req.extensions_mut() .insert(Did(Some(token.claims().custom.sub.clone()))); + // } + // AuthScheme::DPoP => { + // // No DID extraction from DPoP here; leave None + // req.extensions_mut().insert(Did(None)); + // } + // } + next.run(req).await } } @@ -64,19 +82,24 @@ pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { } } -fn extract_bearer(headers: &HeaderMap) -> Result, String> { +fn extract_auth(headers: &HeaderMap) -> Result, String> { match headers.get(axum::http::header::AUTHORIZATION) { None => Ok(None), - Some(hv) => match hv.to_str() { - Err(_) => Err("Authorization header is not valid".into()), - Ok(s) => { - // Accept forms like: "Bearer " (case-sensitive for the scheme here) - let mut parts = s.splitn(2, ' '); - match (parts.next(), parts.next()) { - (Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())), - _ => Err("Authorization header must be in format 'Bearer '".into()), + Some(hv) => { + match hv.to_str() { + Err(_) => Err("Authorization header is not valid".into()), + Ok(s) => { + // Accept forms like: "Bearer " or "DPoP " (case-sensitive for the scheme here) + let mut parts = s.splitn(2, ' '); + match (parts.next(), parts.next()) { + (Some("Bearer"), Some(tok)) if !tok.is_empty() => + Ok(Some((AuthScheme::Bearer, tok.to_string()))), + (Some("DPoP"), Some(tok)) if !tok.is_empty() => + Ok(Some((AuthScheme::DPoP, tok.to_string()))), + _ => Err("Authorization header must be in format 'Bearer ' or 'DPoP '".into()), + } } } - }, + } } } -- 2.43.0 From 95e5a3c055275400298111f8a6dd675049af1517 Mon Sep 17 00:00:00 2001 From: Bailey Townsend Date: Fri, 5 Sep 2025 21:06:56 -0500 Subject: [PATCH] WIP --- Cargo.lock | 6 +-- src/main.rs | 51 ++++++++++++++---- src/middleware.rs | 77 ++++++++++++++++----------- src/xrpc/com_atproto_server.rs | 97 ++++++++++++++++++---------------- 4 files changed, 140 insertions(+), 91 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f2cee2..0ba727b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -656,7 +656,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1392,7 +1392,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -2136,7 +2136,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/src/main.rs b/src/main.rs index abf3306..42fcb1a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,7 +19,8 @@ use std::path::Path; use std::time::Duration; use std::{env, net::SocketAddr}; use tower_governor::GovernorLayer; -use tower_governor::governor::GovernorConfigBuilder; +use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder}; +use tower_governor::key_extractor::PeerIpKeyExtractor; use tower_http::compression::CompressionLayer; use tower_http::cors::{Any, CorsLayer}; use tracing::log; @@ -166,20 +167,49 @@ async fn main() -> Result<(), Box> { .per_second(60) .burst_size(5) .finish() - .expect("failed to create governor config. this should not happen and is a bug"); + .expect("failed to create governor config for create session. this should not happen and is a bug"); // Create a second config with the same settings for the other endpoint let sign_in_governor_conf = GovernorConfigBuilder::default() .per_second(60) .burst_size(5) .finish() - .expect("failed to create governor config. this should not happen and is a bug"); - - // let create_account_limiter_time: Option = - // env::var("GATEKEEPER_CREATE_ACCOUNT_LIMITER_WINDOW").unwrap_or_else(|_| None); + .expect( + "failed to create governor config for sign in. this should not happen and is a bug", + ); + + let create_account_limiter_time: Option = + env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); + let create_account_limiter_burst: Option = + env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); + let mut create_account_governor_conf = None; + + if create_account_governor_conf.is_some() && create_account_limiter_time.is_some() { + let time = create_account_limiter_time + .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") + .parse::() + .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); + let burst = create_account_limiter_burst + .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") + .parse::() + .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); + + create_account_governor_conf = Some( + GovernorConfigBuilder::default() + .per_second(time) + .burst_size(burst) + .finish() + .expect("failed to create governor config for create account. this should not happen and is a bug"), + ) + } let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); + let create_account_governor_limiter = match create_account_governor_conf { + None => None, + Some(conf) => Some(conf.limiter().clone()), + }; + let interval = Duration::from_secs(60); // a separate background task to clean up std::thread::spawn(move || { @@ -187,6 +217,9 @@ async fn main() -> Result<(), Box> { std::thread::sleep(interval); create_session_governor_limiter.retain_recent(); sign_in_governor_limiter.retain_recent(); + if let Some(ref limiter) = create_account_governor_limiter { + limiter.retain_recent(); + } } }); @@ -197,10 +230,7 @@ async fn main() -> Result<(), Box> { let app = Router::new() .route("/", get(root_handler)) - .route( - "/xrpc/com.atproto.server.getSession", - get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)), - ) + .route("/xrpc/com.atproto.server.getSession", get(get_session)) .route( "/xrpc/com.atproto.server.updateEmail", post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), @@ -213,6 +243,7 @@ async fn main() -> Result<(), Box> { "/xrpc/com.atproto.server.createSession", post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), ) + .route("/xrpc/com.atproto.server.createAccount") .layer(CompressionLayer::new()) .layer(cors) .with_state(state); diff --git a/src/middleware.rs b/src/middleware.rs index 3025c38..4b8044a 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -35,40 +35,53 @@ pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { Some((scheme, token_str)) => { // For Bearer, validate JWT and extract DID from `sub`. // For DPoP, we currently only pass through and do not validate here; insert None DID. - // match scheme { - // AuthScheme::Bearer => { - let token = UntrustedToken::new(&token_str); - if token.is_err() { - return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") - .expect("Error creating an error response"); - } - let parsed_token = token.expect("Already checked for error"); - let claims: Result, ValidationError> = - parsed_token.deserialize_claims_unchecked(); - if claims.is_err() { - return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") - .expect("Error creating an error response"); - } + match scheme { + AuthScheme::Bearer => { + let token = UntrustedToken::new(&token_str); + if token.is_err() { + return json_error_response( + StatusCode::BAD_REQUEST, + "TokenRequired", + "", + ) + .expect("Error creating an error response"); + } + let parsed_token = token.expect("Already checked for error"); + let claims: Result, ValidationError> = + parsed_token.deserialize_claims_unchecked(); + if claims.is_err() { + return json_error_response( + StatusCode::BAD_REQUEST, + "TokenRequired", + "", + ) + .expect("Error creating an error response"); + } - let key = Hs256Key::new( - env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), - ); - let token: Result, ValidationError> = - Hs256.validator(&key).validate(&parsed_token); - if token.is_err() { - return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") - .expect("Error creating an error response"); + let key = Hs256Key::new( + env::var("PDS_JWT_SECRET") + .expect("PDS_JWT_SECRET not set in the pds.env"), + ); + let token: Result, ValidationError> = + Hs256.validator(&key).validate(&parsed_token); + if token.is_err() { + return json_error_response( + StatusCode::BAD_REQUEST, + "InvalidToken", + "", + ) + .expect("Error creating an error response"); + } + let token = token.expect("Already checked for error,"); + // Not going to worry about expiration since it still goes to the PDS + req.extensions_mut() + .insert(Did(Some(token.claims().custom.sub.clone()))); + } + AuthScheme::DPoP => { + //Not going to worry about oauth email update for now, just always forward to the PDS + req.extensions_mut().insert(Did(None)); + } } - let token = token.expect("Already checked for error,"); - // Not going to worry about expiration since it still goes to the PDS - req.extensions_mut() - .insert(Did(Some(token.claims().custom.sub.clone()))); - // } - // AuthScheme::DPoP => { - // // No DID extraction from DPoP here; leave None - // req.extensions_mut().insert(Did(None)); - // } - // } next.run(req).await } diff --git a/src/xrpc/com_atproto_server.rs b/src/xrpc/com_atproto_server.rs index 0f40de4..b05676f 100644 --- a/src/xrpc/com_atproto_server.rs +++ b/src/xrpc/com_atproto_server.rs @@ -147,63 +147,68 @@ pub async fn update_email( //If email auth is set it is to either turn on or off 2fa let email_auth_update = payload.email_auth_factor.unwrap_or(false); - // Email update asked for - if email_auth_update { - let email = payload.email.clone(); - let email_confirmed = sqlx::query_as::<_, (String,)>( - "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", - ) - .bind(&email) - .fetch_optional(&state.account_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; - - //Since the email is already confirmed we can enable 2fa - return match email_confirmed { - None => Err(StatusCode::BAD_REQUEST), - Some(did_row) => { - let _ = sqlx::query( - "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", - ) - .bind(&did_row.0) - .execute(&state.pds_gatekeeper_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; + //This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS + //This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented + let did_is_not_empty = did.0.is_some(); - Ok(StatusCode::OK.into_response()) - } - }; - } - - // User wants auth turned off - if !email_auth_update && !email_auth_not_set { - //User wants auth turned off and has a token - if let Some(token) = &payload.token { - let token_found = sqlx::query_as::<_, (String,)>( - "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", + if did_is_not_empty { + // Email update asked for + if email_auth_update { + let email = payload.email.clone(); + let email_confirmed = sqlx::query_as::<_, (String,)>( + "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", ) - .bind(token) - .bind(&did.0) + .bind(&email) .fetch_optional(&state.account_pool) .await .map_err(|_| StatusCode::BAD_REQUEST)?; - if token_found.is_some() { - let _ = sqlx::query( - "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", + //Since the email is already confirmed we can enable 2fa + return match email_confirmed { + None => Err(StatusCode::BAD_REQUEST), + Some(did_row) => { + let _ = sqlx::query( + "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", + ) + .bind(&did_row.0) + .execute(&state.pds_gatekeeper_pool) + .await + .map_err(|_| StatusCode::BAD_REQUEST)?; + + Ok(StatusCode::OK.into_response()) + } + }; + } + + // User wants auth turned off + if !email_auth_update && !email_auth_not_set { + //User wants auth turned off and has a token + if let Some(token) = &payload.token { + let token_found = sqlx::query_as::<_, (String,)>( + "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", ) - .bind(&did.0) - .execute(&state.pds_gatekeeper_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; + .bind(token) + .bind(&did.0) + .fetch_optional(&state.account_pool) + .await + .map_err(|_| StatusCode::BAD_REQUEST)?; - return Ok(StatusCode::OK.into_response()); - } else { - return Err(StatusCode::BAD_REQUEST); + return if token_found.is_some() { + let _ = sqlx::query( + "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", + ) + .bind(&did.0) + .execute(&state.pds_gatekeeper_pool) + .await + .map_err(|_| StatusCode::BAD_REQUEST)?; + + Ok(StatusCode::OK.into_response()) + } else { + Err(StatusCode::BAD_REQUEST) + }; } } } - // Updating the actual email address by sending it on to the PDS let uri = format!( "{}{}", -- 2.43.0 From 3c1c65c23b0b3a2db518bda42a1a9c23a3de4a26 Mon Sep 17 00:00:00 2001 From: Bailey Townsend Date: Fri, 5 Sep 2025 22:27:17 -0500 Subject: [PATCH] Create account rate limiting --- examples/Caddyfile | 1 + src/main.rs | 46 ++++++++++++++++++---------------- src/xrpc/com_atproto_server.rs | 24 ++++++++++++++++++ 3 files changed, 50 insertions(+), 21 deletions(-) diff --git a/examples/Caddyfile b/examples/Caddyfile index 9832246..26b8fa3 100644 --- a/examples/Caddyfile +++ b/examples/Caddyfile @@ -14,6 +14,7 @@ path /xrpc/com.atproto.server.getSession path /xrpc/com.atproto.server.updateEmail path /xrpc/com.atproto.server.createSession + path /xrpc/com.atproto.server.createAccount path /@atproto/oauth-provider/~api/sign-in } diff --git a/src/main.rs b/src/main.rs index 42fcb1a..2381dc5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ #![warn(clippy::unwrap_used)] use crate::oauth_provider::sign_in; -use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; +use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email}; use axum::body::Body; use axum::handler::Handler; use axum::http::{Method, header}; @@ -20,7 +20,6 @@ use std::time::Duration; use std::{env, net::SocketAddr}; use tower_governor::GovernorLayer; use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder}; -use tower_governor::key_extractor::PeerIpKeyExtractor; use tower_http::compression::CompressionLayer; use tower_http::cors::{Any, CorsLayer}; use tracing::log; @@ -92,7 +91,12 @@ async fn main() -> Result<(), Box> { let pds_env_location = env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); - dotenvy::from_path(Path::new(&pds_env_location))?; + let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); + if let Err(e) = result_of_finding_pds_env { + log::error!( + "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" + ); + } let pds_root = env::var("PDS_DATA_DIRECTORY")?; let account_db_url = format!("{pds_root}/account.sqlite"); @@ -182,33 +186,32 @@ async fn main() -> Result<(), Box> { env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); let create_account_limiter_burst: Option = env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); - let mut create_account_governor_conf = None; - if create_account_governor_conf.is_some() && create_account_limiter_time.is_some() { + //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally + let mut create_account_governor_conf = GovernorConfigBuilder::default(); + if create_account_limiter_time.is_some() { let time = create_account_limiter_time .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") .parse::() .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); + create_account_governor_conf.per_second(time); + } + + if create_account_limiter_burst.is_some() { let burst = create_account_limiter_burst .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") .parse::() .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); - - create_account_governor_conf = Some( - GovernorConfigBuilder::default() - .per_second(time) - .burst_size(burst) - .finish() - .expect("failed to create governor config for create account. this should not happen and is a bug"), - ) + create_account_governor_conf.burst_size(burst); } + let create_account_governor_conf = create_account_governor_conf.finish().expect( + "failed to create governor config for create account. this should not happen and is a bug", + ); + let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); - let create_account_governor_limiter = match create_account_governor_conf { - None => None, - Some(conf) => Some(conf.limiter().clone()), - }; + let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); let interval = Duration::from_secs(60); // a separate background task to clean up @@ -217,9 +220,7 @@ async fn main() -> Result<(), Box> { std::thread::sleep(interval); create_session_governor_limiter.retain_recent(); sign_in_governor_limiter.retain_recent(); - if let Some(ref limiter) = create_account_governor_limiter { - limiter.retain_recent(); - } + create_account_governor_limiter.retain_recent(); } }); @@ -243,7 +244,10 @@ async fn main() -> Result<(), Box> { "/xrpc/com.atproto.server.createSession", post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), ) - .route("/xrpc/com.atproto.server.createAccount") + .route( + "/xrpc/com.atproto.server.createAccount", + post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), + ) .layer(CompressionLayer::new()) .layer(cors) .with_state(state); diff --git a/src/xrpc/com_atproto_server.rs b/src/xrpc/com_atproto_server.rs index b05676f..552eddf 100644 --- a/src/xrpc/com_atproto_server.rs +++ b/src/xrpc/com_atproto_server.rs @@ -264,3 +264,27 @@ pub async fn get_session( ProxiedResult::Passthrough(resp) => Ok(resp), } } + +pub async fn create_account( + State(state): State, + mut req: Request, +) -> Result, StatusCode> { + let uri = format!( + "{}{}", + state.pds_base_url, "/xrpc/com.atproto.server.createAccount" + ); + + // Rewrite the URI to point at the upstream PDS; keep headers, method, and body intact + *req.uri_mut() = uri + .parse() + .map_err(|_| StatusCode::BAD_REQUEST)?; + + let proxied = state + .reverse_proxy_client + .request(req) + .await + .map_err(|_| StatusCode::BAD_REQUEST)? + .into_response(); + + Ok(proxied) +} -- 2.43.0 From cb30160629c02adf5a1f2811dd32323f484612ce Mon Sep 17 00:00:00 2001 From: Bailey Townsend Date: Fri, 5 Sep 2025 22:47:19 -0500 Subject: [PATCH] logging --- src/xrpc/com_atproto_server.rs | 40 ++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/xrpc/com_atproto_server.rs b/src/xrpc/com_atproto_server.rs index 552eddf..de6a898 100644 --- a/src/xrpc/com_atproto_server.rs +++ b/src/xrpc/com_atproto_server.rs @@ -10,6 +10,8 @@ use axum::response::{IntoResponse, Response}; use axum::{Extension, Json, debug_handler, extract, extract::Request}; use serde::{Deserialize, Serialize}; use serde_json; +use sqlx::Error; +use sqlx::sqlite::SqliteQueryResult; use tracing::log; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -155,13 +157,19 @@ pub async fn update_email( // Email update asked for if email_auth_update { let email = payload.email.clone(); - let email_confirmed = sqlx::query_as::<_, (String,)>( + let email_confirmed = match sqlx::query_as::<_, (String,)>( "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", ) .bind(&email) .fetch_optional(&state.account_pool) .await - .map_err(|_| StatusCode::BAD_REQUEST)?; + { + Ok(row) => row, + Err(err) => { + log::error!("Error checking if email is confirmed: {err}"); + return Err(StatusCode::BAD_REQUEST); + } + }; //Since the email is already confirmed we can enable 2fa return match email_confirmed { @@ -184,23 +192,35 @@ pub async fn update_email( if !email_auth_update && !email_auth_not_set { //User wants auth turned off and has a token if let Some(token) = &payload.token { - let token_found = sqlx::query_as::<_, (String,)>( + let token_found = match sqlx::query_as::<_, (String,)>( "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", ) .bind(token) .bind(&did.0) .fetch_optional(&state.account_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; + .await{ + Ok(token) => token, + Err(err) => { + log::error!("Error checking if token is valid: {err}"); + return Err(StatusCode::BAD_REQUEST); + } + }; return if token_found.is_some() { - let _ = sqlx::query( + //TODO I think there may be a bug here and need to do some retry logic + // First try was erroring, seconds was allowing + match sqlx::query( "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", ) .bind(&did.0) .execute(&state.pds_gatekeeper_pool) - .await - .map_err(|_| StatusCode::BAD_REQUEST)?; + .await { + Ok(_) => {} + Err(err) => { + log::error!("Error updating email auth: {err}"); + return Err(StatusCode::BAD_REQUEST); + } + } Ok(StatusCode::OK.into_response()) } else { @@ -275,9 +295,7 @@ pub async fn create_account( ); // Rewrite the URI to point at the upstream PDS; keep headers, method, and body intact - *req.uri_mut() = uri - .parse() - .map_err(|_| StatusCode::BAD_REQUEST)?; + *req.uri_mut() = uri.parse().map_err(|_| StatusCode::BAD_REQUEST)?; let proxied = state .reverse_proxy_client -- 2.43.0 From 6066b48f0220e72a1af723bb69da6c19b6800bd9 Mon Sep 17 00:00:00 2001 From: Bailey Townsend Date: Fri, 5 Sep 2025 23:03:20 -0500 Subject: [PATCH] little more clean up --- src/main.rs | 6 ++++-- src/middleware.rs | 2 -- src/xrpc/com_atproto_server.rs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2381dc5..7429e46 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,7 +19,7 @@ use std::path::Path; use std::time::Duration; use std::{env, net::SocketAddr}; use tower_governor::GovernorLayer; -use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder}; +use tower_governor::governor::GovernorConfigBuilder; use tower_http::compression::CompressionLayer; use tower_http::cors::{Any, CorsLayer}; use tracing::log; @@ -97,7 +97,9 @@ async fn main() -> Result<(), Box> { "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" ); } - let pds_root = env::var("PDS_DATA_DIRECTORY")?; + + let pds_root = + env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); let account_db_url = format!("{pds_root}/account.sqlite"); let account_options = SqliteConnectOptions::new() diff --git a/src/middleware.rs b/src/middleware.rs index 4b8044a..b69c4ce 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,6 +1,5 @@ use crate::helpers::json_error_response; use axum::extract::Request; -use axum::http::header::AUTHORIZATION; use axum::http::{HeaderMap, StatusCode}; use axum::middleware::Next; use axum::response::IntoResponse; @@ -73,7 +72,6 @@ pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { .expect("Error creating an error response"); } let token = token.expect("Already checked for error,"); - // Not going to worry about expiration since it still goes to the PDS req.extensions_mut() .insert(Did(Some(token.claims().custom.sub.clone()))); } diff --git a/src/xrpc/com_atproto_server.rs b/src/xrpc/com_atproto_server.rs index de6a898..38a072b 100644 --- a/src/xrpc/com_atproto_server.rs +++ b/src/xrpc/com_atproto_server.rs @@ -10,8 +10,6 @@ use axum::response::{IntoResponse, Response}; use axum::{Extension, Json, debug_handler, extract, extract::Request}; use serde::{Deserialize, Serialize}; use serde_json; -use sqlx::Error; -use sqlx::sqlite::SqliteQueryResult; use tracing::log; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -289,6 +287,8 @@ pub async fn create_account( State(state): State, mut req: Request, ) -> Result, StatusCode> { + //TODO if I add the block of only accounts authenticated just take the body as json here and grab the lxm token. No middle ware is needed + let uri = format!( "{}{}", state.pds_base_url, "/xrpc/com.atproto.server.createAccount" -- 2.43.0 From 9a5a6d9ffd463663d9e8c0b3eeb7a25331e8ca6f Mon Sep 17 00:00:00 2001 From: Bailey Townsend Date: Fri, 5 Sep 2025 23:07:18 -0500 Subject: [PATCH] forgot some docs --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 2a0babe..c04cf60 100644 --- a/README.md +++ b/README.md @@ -113,3 +113,10 @@ same. `GATEKEEPER_HOST` - Host for pds gatekeeper. Defaults to `127.0.0.1` `GATEKEEPER_PORT` - Port for pds gatekeeper. Defaults to `8080` + +`GATEKEEPER_CREATE_ACCOUNT_PER_SECOND` - Sets how often it takes a count off the limiter. example if you hit the rate +limit of 5 and set to 60, then in 60 seconds you will be able to make one more. Or in 5 minutes be able to make 5 more. + +`GATEKEEPER_CREATE_ACCOUNT_BURST` - Sets how many requests can be made in a burst. In the prior example this is where +the 5 comes from. Example can set this to 10 to allow for 10 requests in a burst, and after 60 seconds it will drop one +off. \ No newline at end of file -- 2.43.0