2FA logins gatekept #1

merged
opened by baileytownsend.dev targeting main from feature/2faCodeGeneration
+10
Cargo.lock
···
dependencies = [
"android-tzdata",
"iana-time-zone",
"num-traits",
"windows-link",
]
···
name = "pds_gatekeeper"
version = "0.1.0"
dependencies = [
"axum",
"axum-template",
"dotenvy",
"handlebars",
"hex",
"hyper-util",
"jwt-compact",
"lettre",
"rust-embed",
"scrypt",
"serde",
"serde_json",
"sqlx",
"tokio",
"tower-http",
···
dependencies = [
"base64",
"bytes",
"crc",
"crossbeam-queue",
"either",
···
"bitflags",
"byteorder",
"bytes",
"crc",
"digest",
"dotenvy",
···
"base64",
"bitflags",
"byteorder",
"crc",
"dotenvy",
"etcetera",
···
checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea"
dependencies = [
"atoi",
"flume",
"futures-channel",
"futures-core",
···
dependencies = [
"android-tzdata",
"iana-time-zone",
+
"js-sys",
"num-traits",
+
"wasm-bindgen",
"windows-link",
]
···
name = "pds_gatekeeper"
version = "0.1.0"
dependencies = [
+
"anyhow",
"axum",
"axum-template",
+
"chrono",
"dotenvy",
"handlebars",
"hex",
"hyper-util",
"jwt-compact",
"lettre",
+
"rand 0.9.2",
"rust-embed",
"scrypt",
"serde",
"serde_json",
+
"sha2",
"sqlx",
"tokio",
"tower-http",
···
dependencies = [
"base64",
"bytes",
+
"chrono",
"crc",
"crossbeam-queue",
"either",
···
"bitflags",
"byteorder",
"bytes",
+
"chrono",
"crc",
"digest",
"dotenvy",
···
"base64",
"bitflags",
"byteorder",
+
"chrono",
"crc",
"dotenvy",
"etcetera",
···
checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea"
dependencies = [
"atoi",
+
"chrono",
"flume",
"futures-channel",
"futures-core",
+5 -1
Cargo.toml
···
[dependencies]
axum = { version = "0.8.4", features = ["macros", "json"] }
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] }
-
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate"] }
dotenvy = "0.15.7"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
···
handlebars = { version = "6.3.2", features = ["rust-embed"] }
rust-embed = "8.7.2"
axum-template = { version = "3.0.0", features = ["handlebars"] }
···
[dependencies]
axum = { version = "0.8.4", features = ["macros", "json"] }
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] }
+
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate", "chrono"] }
dotenvy = "0.15.7"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
···
handlebars = { version = "6.3.2", features = ["rust-embed"] }
rust-embed = "8.7.2"
axum-template = { version = "3.0.0", features = ["handlebars"] }
+
rand = "0.9.2"
+
anyhow = "1.0.99"
+
chrono = "0.4.41"
+
sha2 = "0.10"
+5 -6
README.md
···
## 2FA
-
- [x] Ability to turn on/off 2FA
-
- [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on
-
- [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email.
-
- [ ] generate a 2FA code
-
- [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet)
-
- [ ] oauth endpoint gatekeeping
## Captcha on Create Account
···
# Setup
Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up.
But I want to run it locally on my own PDS first to test run it a bit.
···
path /xrpc/com.atproto.server.getSession
path /xrpc/com.atproto.server.updateEmail
path /xrpc/com.atproto.server.createSession
}
handle @gatekeeper {
···
## 2FA
+
- Overrides The login endpoint to add 2FA for both Bluesky client logged in and OAuth logins
+
- Overrides the settings endpoints as well. As long as you have a confirmed email you can turn on 2FA
## Captcha on Create Account
···
# Setup
+
We are getting close! Testing now
+
Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up.
But I want to run it locally on my own PDS first to test run it a bit.
···
path /xrpc/com.atproto.server.getSession
path /xrpc/com.atproto.server.updateEmail
path /xrpc/com.atproto.server.createSession
+
path /@atproto/oauth-provider/~api/sign-in
}
handle @gatekeeper {
-3
migrations_bells_and_whistles/.keep
···
-
# This directory holds SQLx migrations for the bells_and_whistles.sqlite database.
-
# It is intentionally empty for now; running `sqlx::migrate!` will still ensure the
-
# migrations table exists and succeed with zero migrations.
···
+524
src/helpers.rs
···
···
+
use crate::AppState;
+
use crate::helpers::TokenCheckError::InvalidToken;
+
use anyhow::anyhow;
+
use axum::body::{Body, to_bytes};
+
use axum::extract::Request;
+
use axum::http::header::CONTENT_TYPE;
+
use axum::http::{HeaderMap, StatusCode, Uri};
+
use axum::response::{IntoResponse, Response};
+
use axum_template::TemplateEngine;
+
use chrono::Utc;
+
use lettre::message::{MultiPart, SinglePart, header};
+
use lettre::{AsyncTransport, Message};
+
use rand::Rng;
+
use serde::de::DeserializeOwned;
+
use serde_json::{Map, Value};
+
use sha2::{Digest, Sha256};
+
use sqlx::SqlitePool;
+
use tracing::{error, log};
+
+
///Used to generate the email 2fa code
+
const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
+
+
/// The result of a proxied call that attempts to parse JSON.
+
pub enum ProxiedResult<T> {
+
/// Successfully parsed JSON body along with original response headers.
+
Parsed { value: T, _headers: HeaderMap },
+
/// Could not or should not parse: return the original (or rebuilt) response as-is.
+
Passthrough(Response<Body>),
+
}
+
+
/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
+
/// the successful response body as JSON into `T`.
+
///
+
pub async fn proxy_get_json<T>(
+
state: &AppState,
+
mut req: Request,
+
path: &str,
+
) -> Result<ProxiedResult<T>, StatusCode>
+
where
+
T: DeserializeOwned,
+
{
+
let uri = format!("{}{}", state.pds_base_url, path);
+
*req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
+
+
let result = state
+
.reverse_proxy_client
+
.request(req)
+
.await
+
.map_err(|_| StatusCode::BAD_REQUEST)?
+
.into_response();
+
+
if result.status() != StatusCode::OK {
+
return Ok(ProxiedResult::Passthrough(result));
+
}
+
+
let response_headers = result.headers().clone();
+
let body = result.into_body();
+
let body_bytes = to_bytes(body, usize::MAX)
+
.await
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
+
+
match serde_json::from_slice::<T>(&body_bytes) {
+
Ok(value) => Ok(ProxiedResult::Parsed {
+
value,
+
_headers: response_headers,
+
}),
+
Err(err) => {
+
error!(%err, "failed to parse proxied JSON response; returning original body");
+
let mut builder = Response::builder().status(StatusCode::OK);
+
if let Some(headers) = builder.headers_mut() {
+
*headers = response_headers;
+
}
+
let resp = builder
+
.body(Body::from(body_bytes))
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
+
Ok(ProxiedResult::Passthrough(resp))
+
}
+
}
+
}
+
+
/// Build a JSON error response with the required Content-Type header
+
/// Content-Type: application/json;charset=utf-8
+
/// Body shape: { "error": string, "message": string }
+
pub fn json_error_response(
+
status: StatusCode,
+
error: impl Into<String>,
+
message: impl Into<String>,
+
) -> Result<Response<Body>, StatusCode> {
+
let body_str = match serde_json::to_string(&serde_json::json!({
+
"error": error.into(),
+
"message": message.into(),
+
})) {
+
Ok(s) => s,
+
Err(_) => return Err(StatusCode::BAD_REQUEST),
+
};
+
+
Response::builder()
+
.status(status)
+
.header(CONTENT_TYPE, "application/json;charset=utf-8")
+
.body(Body::from(body_str))
+
.map_err(|_| StatusCode::BAD_REQUEST)
+
}
+
+
/// Build a JSON error response with the required Content-Type header
+
/// Content-Type: application/json (oauth endpoint does not like utf ending)
+
/// Body shape: { "error": string, "error_description": string }
+
pub fn oauth_json_error_response(
+
status: StatusCode,
+
error: impl Into<String>,
+
message: impl Into<String>,
+
) -> Result<Response<Body>, StatusCode> {
+
let body_str = match serde_json::to_string(&serde_json::json!({
+
"error": error.into(),
+
"error_description": message.into(),
+
})) {
+
Ok(s) => s,
+
Err(_) => return Err(StatusCode::BAD_REQUEST),
+
};
+
+
Response::builder()
+
.status(status)
+
.header(CONTENT_TYPE, "application/json")
+
.body(Body::from(body_str))
+
.map_err(|_| StatusCode::BAD_REQUEST)
+
}
+
+
/// Creates a random token of 10 characters for email 2FA
+
pub fn get_random_token() -> String {
+
let mut rng = rand::rng();
+
+
let mut full_code = String::with_capacity(10);
+
for _ in 0..10 {
+
let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len());
+
full_code.push(UPPERCASE_BASE32_CHARS[idx] as char);
+
}
+
+
//The PDS implementation creates in lowercase, then converts to uppercase.
+
//Just going a head and doing uppercase here.
+
let slice_one = &full_code[0..5].to_ascii_uppercase();
+
let slice_two = &full_code[5..10].to_ascii_uppercase();
+
format!("{slice_one}-{slice_two}")
+
}
+
+
pub enum TokenCheckError {
+
InvalidToken,
+
ExpiredToken,
+
}
+
+
pub enum AuthResult {
+
WrongIdentityOrPassword,
+
/// The string here is the email address to create a hint for oauth
+
TwoFactorRequired(String),
+
/// User does not have 2FA enabled, or using an app password, or passes it
+
ProxyThrough,
+
TokenCheckFailed(TokenCheckError),
+
}
+
+
pub enum IdentifierType {
+
Email,
+
Did,
+
Handle,
+
}
+
+
impl IdentifierType {
+
fn what_is_it(identifier: String) -> Self {
+
if identifier.contains("@") {
+
IdentifierType::Email
+
} else if identifier.contains("did:") {
+
IdentifierType::Did
+
} else {
+
IdentifierType::Handle
+
}
+
}
+
}
+
+
/// Creates a hex string from the password and salt to find app passwords
+
fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result<String> {
+
let params = scrypt::Params::new(14, 8, 1, 64)?;
+
let mut derived = [0u8; 64];
+
scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived)?;
+
Ok(hex::encode(derived))
+
}
+
+
/// Hashes the app password. did is used as the salt.
+
pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result<String> {
+
let mut hasher = Sha256::new();
+
hasher.update(did.as_bytes());
+
let sha = hasher.finalize();
+
let salt = hex::encode(&sha[..16]);
+
let hash_hex = scrypt_hex(password, &salt)?;
+
Ok(format!("{salt}:{hash_hex}"))
+
}
+
+
async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result<bool> {
+
// Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
+
let mut parts = password_scrypt.splitn(2, ':');
+
let salt = match parts.next() {
+
Some(s) if !s.is_empty() => s,
+
_ => return Ok(false),
+
};
+
let stored_hash_hex = match parts.next() {
+
Some(h) if !h.is_empty() => h,
+
_ => return Ok(false),
+
};
+
+
// Derive using the shared helper and compare
+
let derived_hex = match scrypt_hex(password, salt) {
+
Ok(h) => h,
+
Err(_) => return Ok(false),
+
};
+
+
Ok(derived_hex.as_str() == stored_hash_hex)
+
}
+
+
/// Handles the auth checks along with sending a 2fa email
+
pub async fn preauth_check(
+
state: &AppState,
+
identifier: &str,
+
password: &str,
+
two_factor_code: Option<String>,
+
oauth: bool,
+
) -> anyhow::Result<AuthResult> {
+
// Determine identifier type
+
let id_type = IdentifierType::what_is_it(identifier.to_string());
+
+
// Query account DB for did and passwordScrypt based on identifier type
+
let account_row: Option<(String, String, String, String)> = match id_type {
+
IdentifierType::Email => {
+
sqlx::query_as::<_, (String, String, String, String)>(
+
"SELECT account.did, account.passwordScrypt, account.email, actor.handle
+
FROM actor
+
LEFT JOIN account ON actor.did = account.did
+
where account.email = ? LIMIT 1",
+
)
+
.bind(identifier)
+
.fetch_optional(&state.account_pool)
+
.await?
+
}
+
IdentifierType::Handle => {
+
sqlx::query_as::<_, (String, String, String, String)>(
+
"SELECT account.did, account.passwordScrypt, account.email, actor.handle
+
FROM actor
+
LEFT JOIN account ON actor.did = account.did
+
where actor.handle = ? LIMIT 1",
+
)
+
.bind(identifier)
+
.fetch_optional(&state.account_pool)
+
.await?
+
}
+
IdentifierType::Did => {
+
sqlx::query_as::<_, (String, String, String, String)>(
+
"SELECT account.did, account.passwordScrypt, account.email, actor.handle
+
FROM actor
+
LEFT JOIN account ON actor.did = account.did
+
where account.did = ? LIMIT 1",
+
)
+
.bind(identifier)
+
.fetch_optional(&state.account_pool)
+
.await?
+
}
+
};
+
+
if let Some((did, password_scrypt, email, handle)) = account_row {
+
// Verify password before proceeding to 2FA email step
+
let verified = verify_password(password, &password_scrypt).await?;
+
if !verified {
+
if oauth {
+
//OAuth does not allow app password logins so just go ahead and send it along it's way
+
return Ok(AuthResult::WrongIdentityOrPassword);
+
}
+
//Theres a chance it could be an app password so check that as well
+
return match verify_app_password(&state.account_pool, &did, password).await {
+
Ok(valid) => {
+
if valid {
+
//Was a valid app password up to the PDS now
+
Ok(AuthResult::ProxyThrough)
+
} else {
+
Ok(AuthResult::WrongIdentityOrPassword)
+
}
+
}
+
Err(err) => {
+
log::error!("Error checking the app password: {err}");
+
Err(err)
+
}
+
};
+
}
+
+
// Check two-factor requirement for this DID in the gatekeeper DB
+
let required_opt = sqlx::query_as::<_, (u8,)>(
+
"SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
+
)
+
.bind(did.clone())
+
.fetch_optional(&state.pds_gatekeeper_pool)
+
.await?;
+
+
let two_factor_required = match required_opt {
+
Some(row) => row.0 != 0,
+
None => false,
+
};
+
+
if two_factor_required {
+
//Two factor is required and a taken was provided
+
if let Some(two_factor_code) = two_factor_code {
+
//if the two_factor_code is set need to see if we have a valid token
+
if !two_factor_code.is_empty() {
+
return match assert_valid_token(
+
&state.account_pool,
+
did.clone(),
+
two_factor_code,
+
)
+
.await
+
{
+
Ok(_) => {
+
let result_of_cleanup =
+
delete_all_email_tokens(&state.account_pool, did.clone()).await;
+
if result_of_cleanup.is_err() {
+
log::error!(
+
"There was an error deleting the email tokens after login: {:?}",
+
result_of_cleanup.err()
+
)
+
}
+
Ok(AuthResult::ProxyThrough)
+
}
+
Err(err) => Ok(AuthResult::TokenCheckFailed(err)),
+
};
+
}
+
}
+
+
return match create_two_factor_token(&state.account_pool, did).await {
+
Ok(code) => {
+
let mut email_data = Map::new();
+
email_data.insert("token".to_string(), Value::from(code.clone()));
+
email_data.insert("handle".to_string(), Value::from(handle.clone()));
+
let email_body = state
+
.template_engine
+
.render("two_factor_code.hbs", email_data)?;
+
+
let email_message = Message::builder()
+
//TODO prob get the proper type in the state
+
.from(state.mailer_from.parse()?)
+
.to(email.parse()?)
+
.subject("Sign in to Bluesky")
+
.multipart(
+
MultiPart::alternative() // This is composed of two parts.
+
.singlepart(
+
SinglePart::builder()
+
.header(header::ContentType::TEXT_PLAIN)
+
.body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback.
+
)
+
.singlepart(
+
SinglePart::builder()
+
.header(header::ContentType::TEXT_HTML)
+
.body(email_body),
+
),
+
)?;
+
match state.mailer.send(email_message).await {
+
Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))),
+
Err(err) => {
+
log::error!("Error sending the 2FA email: {err}");
+
Err(anyhow!(err))
+
}
+
}
+
}
+
Err(err) => {
+
log::error!("error on creating a 2fa token: {err}");
+
Err(anyhow!(err))
+
}
+
};
+
}
+
}
+
+
// No local 2FA requirement (or account not found)
+
Ok(AuthResult::ProxyThrough)
+
}
+
+
pub async fn create_two_factor_token(
+
account_db: &SqlitePool,
+
did: String,
+
) -> anyhow::Result<String> {
+
let purpose = "2fa_code";
+
+
let token = get_random_token();
+
let right_now = Utc::now();
+
+
let res = sqlx::query(
+
"INSERT INTO email_token (purpose, did, token, requestedAt)
+
VALUES (?, ?, ?, ?)
+
ON CONFLICT(purpose, did) DO UPDATE SET
+
token=excluded.token,
+
requestedAt=excluded.requestedAt
+
WHERE did=excluded.did",
+
)
+
.bind(purpose)
+
.bind(&did)
+
.bind(&token)
+
.bind(right_now)
+
.execute(account_db)
+
.await;
+
+
match res {
+
Ok(_) => Ok(token),
+
Err(err) => {
+
log::error!("Error creating a two factor token: {err}");
+
Err(anyhow::anyhow!(err))
+
}
+
}
+
}
+
+
pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> {
+
sqlx::query("DELETE FROM email_token WHERE did = ?")
+
.bind(did)
+
.execute(account_db)
+
.await?;
+
Ok(())
+
}
+
+
pub async fn assert_valid_token(
+
account_db: &SqlitePool,
+
did: String,
+
token: String,
+
) -> Result<(), TokenCheckError> {
+
let token_upper = token.to_ascii_uppercase();
+
let purpose = "2fa_code";
+
+
let row: Option<(String,)> = sqlx::query_as(
+
"SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1",
+
)
+
.bind(purpose)
+
.bind(did)
+
.bind(token_upper)
+
.fetch_optional(account_db)
+
.await
+
.map_err(|err| {
+
log::error!("Error getting the 2fa token: {err}");
+
InvalidToken
+
})?;
+
+
match row {
+
None => Err(InvalidToken),
+
Some(row) => {
+
// Token lives for 15 minutes
+
let expiration_ms = 15 * 60_000;
+
+
let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) {
+
Ok(dt) => dt.with_timezone(&Utc),
+
Err(_) => {
+
return Err(TokenCheckError::InvalidToken);
+
}
+
};
+
+
let now = Utc::now();
+
let age_ms = (now - requested_at_utc).num_milliseconds();
+
let expired = age_ms > expiration_ms;
+
if expired {
+
return Err(TokenCheckError::ExpiredToken);
+
}
+
+
Ok(())
+
}
+
}
+
}
+
+
/// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions
+
pub async fn verify_app_password(
+
account_db: &SqlitePool,
+
did: &str,
+
password: &str,
+
) -> anyhow::Result<bool> {
+
let password_scrypt = hash_app_password(did, password)?;
+
+
let row: Option<(i64,)> = sqlx::query_as(
+
"SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1",
+
)
+
.bind(did)
+
.bind(password_scrypt)
+
.fetch_optional(account_db)
+
.await?;
+
+
Ok(match row {
+
None => false,
+
Some((count,)) => count > 0,
+
})
+
}
+
+
/// Mask an email address into a hint like "2***0@p***m".
+
pub fn mask_email(email: String) -> String {
+
// Basic split on first '@'
+
let mut parts = email.splitn(2, '@');
+
let local = match parts.next() {
+
Some(l) => l,
+
None => return email.to_string(),
+
};
+
let domain_rest = match parts.next() {
+
Some(d) if !d.is_empty() => d,
+
_ => return email.to_string(),
+
};
+
+
// Helper to mask a single label (keep first and last, middle becomes ***).
+
fn mask_label(s: &str) -> String {
+
let chars: Vec<char> = s.chars().collect();
+
match chars.len() {
+
0 => String::new(),
+
1 => format!("{}***", chars[0]),
+
2 => format!("{}***{}", chars[0], chars[1]),
+
_ => format!("{}***{}", chars[0], chars[chars.len() - 1]),
+
}
+
}
+
+
// Mask local
+
let masked_local = mask_label(local);
+
+
// Mask first domain label only, keep the rest of the domain intact
+
let mut dom_parts = domain_rest.splitn(2, '.');
+
let first_label = dom_parts.next().unwrap_or("");
+
let rest = dom_parts.next();
+
let masked_first = mask_label(first_label);
+
let masked_domain = if let Some(rest) = rest {
+
format!("{}.{rest}", masked_first)
+
} else {
+
masked_first
+
};
+
+
format!("{masked_local}@{masked_domain}")
+
}
+53 -26
src/main.rs
···
use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
-
use axum::middleware as ax_middleware;
-
mod middleware;
use axum::body::Body;
use axum::handler::Handler;
use axum::http::{Method, header};
use axum::routing::post;
use axum::{Router, routing::get};
use axum_template::engine::Engine;
···
use tower_governor::governor::GovernorConfigBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer};
-
use tracing::{error, log};
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
mod xrpc;
type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
···
struct EmailTemplates;
#[derive(Clone)]
-
struct AppState {
account_pool: SqlitePool,
pds_gatekeeper_pool: SqlitePool,
reverse_proxy_client: HyperUtilClient,
···
let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
-
let banner = format!(" {}\n{}", body, intro);
(
[(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
···
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
setup_tracing();
-
//TODO prod
dotenvy::from_path(Path::new("./pds.env"))?;
let pds_root = env::var("PDS_DATA_DIRECTORY")?;
-
// let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
-
let account_db_url = format!("{}/account.sqlite", pds_root);
-
log::info!("accounts_db_url: {}", account_db_url);
let account_options = SqliteConnectOptions::new()
-
.journal_mode(SqliteJournalMode::Wal)
-
.filename(account_db_url);
let account_pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(account_options)
.await?;
-
let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root);
let options = SqliteConnectOptions::new()
.journal_mode(SqliteJournalMode::Wal)
.filename(bells_db_url)
-
.create_if_missing(true);
let pds_gatekeeper_pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await?;
-
// Run migrations for the bells_and_whistles database
// Note: the migrations are embedded at compile time from the given directory
// sqlx
sqlx::migrate!("./migrations")
···
AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
//Email templates setup
let mut hbs = Handlebars::new();
-
let _ = hbs.register_embed_templates::<EmailTemplates>();
let state = AppState {
account_pool,
pds_gatekeeper_pool,
reverse_proxy_client: client,
-
//TODO should be env prob
-
pds_base_url: "http://localhost:3000".to_string(),
mailer,
mailer_from: sent_from,
template_engine: Engine::from(hbs),
···
// Rate limiting
//Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
-
let governor_conf = GovernorConfigBuilder::default()
.per_second(60)
.burst_size(5)
.finish()
-
.unwrap();
-
let governor_limiter = governor_conf.limiter().clone();
let interval = Duration::from_secs(60);
// a separate background task to clean up
std::thread::spawn(move || {
loop {
std::thread::sleep(interval);
-
tracing::info!("rate limiting storage size: {}", governor_limiter.len());
-
governor_limiter.retain_recent();
}
});
···
"/xrpc/com.atproto.server.updateEmail",
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
)
.route(
"/xrpc/com.atproto.server.createSession",
-
post(create_session.layer(GovernorLayer::new(governor_conf))),
)
.layer(CompressionLayer::new())
.layer(cors)
.with_state(state);
-
let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
-
let port: u16 = env::var("PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(8080);
···
.with_graceful_shutdown(shutdown_signal());
if let Err(err) = server.await {
-
error!(error = %err, "server error");
}
Ok(())
···
+
#![warn(clippy::unwrap_used)]
+
use crate::oauth_provider::sign_in;
use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
use axum::body::Body;
use axum::handler::Handler;
use axum::http::{Method, header};
+
use axum::middleware as ax_middleware;
use axum::routing::post;
use axum::{Router, routing::get};
use axum_template::engine::Engine;
···
use tower_governor::governor::GovernorConfigBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer};
+
use tracing::log;
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
+
pub mod helpers;
+
mod middleware;
+
mod oauth_provider;
mod xrpc;
type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
···
struct EmailTemplates;
#[derive(Clone)]
+
pub struct AppState {
account_pool: SqlitePool,
pds_gatekeeper_pool: SqlitePool,
reverse_proxy_client: HyperUtilClient,
···
let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
+
let banner = format!(" {body}\n{intro}");
(
[(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
···
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
setup_tracing();
+
//TODO may need to change where this reads from? Like an env variable for it's location? Or arg?
dotenvy::from_path(Path::new("./pds.env"))?;
let pds_root = env::var("PDS_DATA_DIRECTORY")?;
+
let account_db_url = format!("{pds_root}/account.sqlite");
let account_options = SqliteConnectOptions::new()
+
.filename(account_db_url)
+
.busy_timeout(Duration::from_secs(5));
let account_pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(account_options)
.await?;
+
let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
let options = SqliteConnectOptions::new()
.journal_mode(SqliteJournalMode::Wal)
.filename(bells_db_url)
+
.create_if_missing(true)
+
.busy_timeout(Duration::from_secs(5));
let pds_gatekeeper_pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await?;
+
// Run migrations for the extra database
// Note: the migrations are embedded at compile time from the given directory
// sqlx
sqlx::migrate!("./migrations")
···
AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
//Email templates setup
let mut hbs = Handlebars::new();
+
+
let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
+
if let Ok(users_email_directory) = users_email_directory {
+
hbs.register_template_file(
+
"two_factor_code.hbs",
+
format!("{users_email_directory}/two_factor_code.hbs"),
+
)?;
+
} else {
+
let _ = hbs.register_embed_templates::<EmailTemplates>();
+
}
+
+
let pds_base_url =
+
env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
let state = AppState {
account_pool,
pds_gatekeeper_pool,
reverse_proxy_client: client,
+
pds_base_url,
mailer,
mailer_from: sent_from,
template_engine: Engine::from(hbs),
···
// Rate limiting
//Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
+
let create_session_governor_conf = GovernorConfigBuilder::default()
.per_second(60)
.burst_size(5)
.finish()
+
.expect("failed to create governor config. this should not happen and is a bug");
+
+
// Create a second config with the same settings for the other endpoint
+
let sign_in_governor_conf = GovernorConfigBuilder::default()
+
.per_second(60)
+
.burst_size(5)
+
.finish()
+
.expect("failed to create governor config. this should not happen and is a bug");
+
+
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
+
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
let interval = Duration::from_secs(60);
// a separate background task to clean up
std::thread::spawn(move || {
loop {
std::thread::sleep(interval);
+
create_session_governor_limiter.retain_recent();
+
sign_in_governor_limiter.retain_recent();
}
});
···
"/xrpc/com.atproto.server.updateEmail",
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
)
+
.route(
+
"/@atproto/oauth-provider/~api/sign-in",
+
post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)),
+
)
.route(
"/xrpc/com.atproto.server.createSession",
+
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
)
.layer(CompressionLayer::new())
.layer(cors)
.with_state(state);
+
let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
+
let port: u16 = env::var("GATEKEEPER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(8080);
···
.with_graceful_shutdown(shutdown_signal());
if let Err(err) = server.await {
+
log::error!("server error:{err}");
}
Ok(())
+19 -34
src/middleware.rs
···
-
use crate::xrpc::helpers::json_error_response;
use axum::extract::Request;
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
···
use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
use serde::{Deserialize, Serialize};
use std::env;
#[derive(Clone, Debug)]
pub struct Did(pub Option<String>);
···
match token {
Ok(token) => {
match token {
-
None => {
-
return json_error_response(
-
StatusCode::BAD_REQUEST,
-
"TokenRequired",
-
"",
-
).unwrap();
-
}
Some(token) => {
let token = UntrustedToken::new(&token);
-
//Doing weird unwraps cause I can't do Result for middleware?
if token.is_err() {
-
return json_error_response(
-
StatusCode::BAD_REQUEST,
-
"TokenRequired",
-
"",
-
).unwrap();
}
-
let parsed_token = token.unwrap();
let claims: Result<Claims<TokenClaims>, ValidationError> =
parsed_token.deserialize_claims_unchecked();
if claims.is_err() {
-
return json_error_response(
-
StatusCode::BAD_REQUEST,
-
"TokenRequired",
-
"",
-
).unwrap();
}
-
let key = Hs256Key::new(env::var("PDS_JWT_SECRET").unwrap());
let token: Result<Token<TokenClaims>, ValidationError> =
Hs256.validator(&key).validate(&parsed_token);
if token.is_err() {
-
return json_error_response(
-
StatusCode::BAD_REQUEST,
-
"InvalidToken",
-
"",
-
).unwrap();
}
-
let token = token.unwrap();
//Not going to worry about expiration since it still goes to the PDS
-
req.extensions_mut()
.insert(Did(Some(token.claims().custom.sub.clone())));
next.run(req).await
}
}
}
-
Err(_) => {
-
return json_error_response(
-
StatusCode::BAD_REQUEST,
-
"InvalidToken",
-
"",
-
).unwrap();
}
}
}
···
+
use crate::helpers::json_error_response;
use axum::extract::Request;
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
···
use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
use serde::{Deserialize, Serialize};
use std::env;
+
use tracing::log;
#[derive(Clone, Debug)]
pub struct Did(pub Option<String>);
···
match token {
Ok(token) => {
match token {
+
None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
+
.expect("Error creating an error response"),
Some(token) => {
let token = UntrustedToken::new(&token);
if token.is_err() {
+
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
+
.expect("Error creating an error response");
}
+
let parsed_token = token.expect("Already checked for error");
let claims: Result<Claims<TokenClaims>, ValidationError> =
parsed_token.deserialize_claims_unchecked();
if claims.is_err() {
+
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
+
.expect("Error creating an error response");
}
+
let key = Hs256Key::new(
+
env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
+
);
let token: Result<Token<TokenClaims>, ValidationError> =
Hs256.validator(&key).validate(&parsed_token);
if token.is_err() {
+
return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
+
.expect("Error creating an error response");
}
+
let token = token.expect("Already checked for error,");
//Not going to worry about expiration since it still goes to the PDS
req.extensions_mut()
.insert(Did(Some(token.claims().custom.sub.clone())));
next.run(req).await
}
}
}
+
Err(err) => {
+
log::error!("Error extracting token: {err}");
+
json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
+
.expect("Error creating an error response")
}
}
}
+141
src/oauth_provider.rs
···
···
+
use crate::AppState;
+
use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check};
+
use axum::body::Body;
+
use axum::extract::State;
+
use axum::http::header::CONTENT_TYPE;
+
use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
+
use axum::response::{IntoResponse, Response};
+
use axum::{Json, extract};
+
use serde::{Deserialize, Serialize};
+
use tracing::log;
+
+
#[derive(Serialize, Deserialize, Clone)]
+
pub struct SignInRequest {
+
pub username: String,
+
pub password: String,
+
pub remember: bool,
+
pub locale: String,
+
#[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")]
+
pub email_otp: Option<String>,
+
}
+
+
pub async fn sign_in(
+
State(state): State<AppState>,
+
headers: HeaderMap,
+
Json(mut payload): extract::Json<SignInRequest>,
+
) -> Result<Response<Body>, StatusCode> {
+
let identifier = payload.username.clone();
+
let password = payload.password.clone();
+
let auth_factor_token = payload.email_otp.clone();
+
+
match preauth_check(&state, &identifier, &password, auth_factor_token, true).await {
+
Ok(result) => match result {
+
AuthResult::WrongIdentityOrPassword => oauth_json_error_response(
+
StatusCode::BAD_REQUEST,
+
"invalid_request",
+
"Invalid identifier or password",
+
),
+
AuthResult::TwoFactorRequired(masked_email) => {
+
// Email sending step can be handled here if needed in the future.
+
+
// {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"}
+
let body_str = match serde_json::to_string(&serde_json::json!({
+
"error": "second_authentication_factor_required",
+
"error_description": format!("emailOtp authentication factor required (hint: {})", masked_email),
+
"type": "emailOtp",
+
"hint": masked_email,
+
})) {
+
Ok(s) => s,
+
Err(_) => return Err(StatusCode::BAD_REQUEST),
+
};
+
+
Response::builder()
+
.status(StatusCode::BAD_REQUEST)
+
.header(CONTENT_TYPE, "application/json")
+
.body(Body::from(body_str))
+
.map_err(|_| StatusCode::BAD_REQUEST)
+
}
+
AuthResult::ProxyThrough => {
+
//No 2FA or already passed
+
let uri = format!(
+
"{}{}",
+
state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in"
+
);
+
+
let mut req = axum::http::Request::post(uri);
+
if let Some(req_headers) = req.headers_mut() {
+
// Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers
+
copy_filtered_headers(&headers, req_headers);
+
//Setting the content type to application/json manually
+
req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
+
}
+
+
//Clears the email_otp because the pds will reject a request with it.
+
payload.email_otp = None;
+
let payload_bytes =
+
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
+
+
let req = req
+
.body(Body::from(payload_bytes))
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
+
+
let proxied = state
+
.reverse_proxy_client
+
.request(req)
+
.await
+
.map_err(|_| StatusCode::BAD_REQUEST)?
+
.into_response();
+
+
Ok(proxied)
+
}
+
//Ignoring the type of token check failure. Looks like oauth on the entry treads them the same.
+
AuthResult::TokenCheckFailed(_) => oauth_json_error_response(
+
StatusCode::BAD_REQUEST,
+
"invalid_request",
+
"Unable to sign-in due to an unexpected server error",
+
),
+
},
+
Err(err) => {
+
log::error!(
+
"Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
+
);
+
oauth_json_error_response(
+
StatusCode::BAD_REQUEST,
+
"pds_gatekeeper_error",
+
"This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
+
)
+
}
+
}
+
}
+
+
fn is_disallowed_header(name: &HeaderName) -> bool {
+
// possible problematic headers with proxying
+
matches!(
+
name.as_str(),
+
"connection"
+
| "keep-alive"
+
| "proxy-authenticate"
+
| "proxy-authorization"
+
| "te"
+
| "trailer"
+
| "transfer-encoding"
+
| "upgrade"
+
| "host"
+
| "content-length"
+
| "content-encoding"
+
| "expect"
+
| "accept-encoding"
+
)
+
}
+
+
fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
+
for (name, value) in src.iter() {
+
if is_disallowed_header(name) {
+
continue;
+
}
+
// Only copy valid headers
+
if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
+
dst.insert(name.clone(), hv);
+
}
+
}
+
}
+66 -211
src/xrpc/com_atproto_server.rs
···
use crate::AppState;
use crate::middleware::Did;
-
use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json};
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json, debug_handler, extract, extract::Request};
-
use axum_template::TemplateEngine;
-
use lettre::message::{MultiPart, SinglePart, header};
-
use lettre::{AsyncTransport, Message};
use serde::{Deserialize, Serialize};
use serde_json;
-
use serde_json::Value;
-
use serde_json::value::Map;
use tracing::log;
#[derive(Serialize, Deserialize, Debug, Clone)]
···
pub struct CreateSessionRequest {
identifier: String,
password: String,
-
auth_factor_token: String,
-
allow_takendown: bool,
-
}
-
-
pub enum AuthResult {
-
WrongIdentityOrPassword,
-
TwoFactorRequired,
-
TwoFactorFailed,
-
/// User does not have 2FA enabled, or passes it
-
ProxyThrough,
-
}
-
-
pub enum IdentifierType {
-
Email,
-
DID,
-
Handle,
-
}
-
-
impl IdentifierType {
-
fn what_is_it(identifier: String) -> Self {
-
if identifier.contains("@") {
-
IdentifierType::Email
-
} else if identifier.contains("did:") {
-
IdentifierType::DID
-
} else {
-
IdentifierType::Handle
-
}
-
}
-
}
-
-
async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> {
-
// Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
-
let mut parts = password_scrypt.splitn(2, ':');
-
let salt = match parts.next() {
-
Some(s) if !s.is_empty() => s,
-
_ => return Ok(false),
-
};
-
let stored_hash_hex = match parts.next() {
-
Some(h) if !h.is_empty() => h,
-
_ => return Ok(false),
-
};
-
-
//Sets up scrypt to mimic node's scrypt
-
let params = match scrypt::Params::new(14, 8, 1, 64) {
-
Ok(p) => p,
-
Err(_) => return Ok(false),
-
};
-
let mut derived = [0u8; 64];
-
if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived).is_err() {
-
return Ok(false);
-
}
-
-
let stored_bytes = match hex::decode(stored_hash_hex) {
-
Ok(b) => b,
-
Err(e) => {
-
log::error!("Error decoding stored hash: {}", e);
-
return Ok(false);
-
}
-
};
-
-
Ok(derived.as_slice() == stored_bytes.as_slice())
-
}
-
-
async fn preauth_check(
-
state: &AppState,
-
identifier: &str,
-
password: &str,
-
) -> Result<AuthResult, StatusCode> {
-
// Determine identifier type
-
let id_type = IdentifierType::what_is_it(identifier.to_string());
-
-
// Query account DB for did and passwordScrypt based on identifier type
-
let account_row: Option<(String, String, String)> = match id_type {
-
IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>(
-
"SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1",
-
)
-
.bind(identifier)
-
.fetch_optional(&state.account_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
-
IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>(
-
"SELECT account.did, account.passwordScrypt, account.email
-
FROM actor
-
LEFT JOIN account ON actor.did = account.did
-
where actor.handle =? LIMIT 1",
-
)
-
.bind(identifier)
-
.fetch_optional(&state.account_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
-
IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>(
-
"SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1",
-
)
-
.bind(identifier)
-
.fetch_optional(&state.account_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
-
};
-
-
if let Some((did, password_scrypt, email)) = account_row {
-
// Check two-factor requirement for this DID in the gatekeeper DB
-
let required_opt = sqlx::query_as::<_, (u8,)>(
-
"SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
-
)
-
.bind(&did)
-
.fetch_optional(&state.pds_gatekeeper_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
-
let two_factor_required = match required_opt {
-
Some(row) => row.0 != 0,
-
None => false,
-
};
-
-
if two_factor_required {
-
// Verify password before proceeding to 2FA email step
-
let verified = verify_password(password, &password_scrypt).await?;
-
if !verified {
-
return Ok(AuthResult::WrongIdentityOrPassword);
-
}
-
let mut email_data = Map::new();
-
//TODO these need real values
-
let token = "test".to_string();
-
let handle = "baileytownsend.dev".to_string();
-
email_data.insert("token".to_string(), Value::from(token.clone()));
-
email_data.insert("handle".to_string(), Value::from(handle.clone()));
-
//TODO bad unwrap
-
let email_body = state
-
.template_engine
-
.render("two_factor_code.hbs", email_data)
-
.unwrap();
-
-
let email = Message::builder()
-
//TODO prob get the proper type in the state
-
.from(state.mailer_from.parse().unwrap())
-
.to(email.parse().unwrap())
-
.subject("Sign in to Bluesky")
-
.multipart(
-
MultiPart::alternative() // This is composed of two parts.
-
.singlepart(
-
SinglePart::builder()
-
.header(header::ContentType::TEXT_PLAIN)
-
.body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback.
-
)
-
.singlepart(
-
SinglePart::builder()
-
.header(header::ContentType::TEXT_HTML)
-
.body(email_body),
-
),
-
)
-
//TODO bad
-
.unwrap();
-
return match state.mailer.send(email).await {
-
Ok(_) => Ok(AuthResult::TwoFactorRequired),
-
Err(err) => {
-
log::error!("Error sending the 2FA email: {}", err);
-
Err(StatusCode::BAD_REQUEST)
-
}
-
};
-
}
-
}
-
-
// No local 2FA requirement (or account not found)
-
Ok(AuthResult::ProxyThrough)
}
pub async fn create_session(
···
) -> Result<Response<Body>, StatusCode> {
let identifier = payload.identifier.clone();
let password = payload.password.clone();
// Run the shared pre-auth logic to validate and check 2FA requirement
-
match preauth_check(&state, &identifier, &password).await? {
-
AuthResult::WrongIdentityOrPassword => json_error_response(
-
StatusCode::UNAUTHORIZED,
-
"AuthenticationRequired",
-
"Invalid identifier or password",
-
),
-
AuthResult::TwoFactorRequired => {
-
// Email sending step can be handled here if needed in the future.
-
json_error_response(
StatusCode::UNAUTHORIZED,
-
"AuthFactorTokenRequired",
-
"A sign in code has been sent to your email address",
-
)
-
}
-
AuthResult::TwoFactorFailed => {
-
//Not sure what the errors are for this response is yet
-
json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER")
-
}
-
AuthResult::ProxyThrough => {
-
//No 2FA or already passed
-
let uri = format!(
-
"{}{}",
-
state.pds_base_url, "/xrpc/com.atproto.server.createSession"
-
);
-
-
let mut req = axum::http::Request::post(uri);
-
if let Some(req_headers) = req.headers_mut() {
-
req_headers.extend(headers.clone());
}
-
let payload_bytes =
-
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
-
let req = req
-
.body(Body::from(payload_bytes))
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
let proxied = state
-
.reverse_proxy_client
-
.request(req)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?
-
.into_response();
-
Ok(proxied)
}
}
}
···
) -> Result<Response<Body>, StatusCode> {
//If email auth is not set at all it is a update email address
let email_auth_not_set = payload.email_auth_factor.is_none();
-
//If email aurth is set it is to either turn on or off 2fa
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
// Email update asked for
···
}
}
-
// Updating the acutal email address
let uri = format!(
"{}{}",
state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
···
use crate::AppState;
+
use crate::helpers::{
+
AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json,
+
};
use crate::middleware::Did;
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json, debug_handler, extract, extract::Request};
use serde::{Deserialize, Serialize};
use serde_json;
use tracing::log;
#[derive(Serialize, Deserialize, Debug, Clone)]
···
pub struct CreateSessionRequest {
identifier: String,
password: String,
+
#[serde(skip_serializing_if = "Option::is_none")]
+
auth_factor_token: Option<String>,
+
#[serde(skip_serializing_if = "Option::is_none")]
+
allow_takendown: Option<bool>,
}
pub async fn create_session(
···
) -> Result<Response<Body>, StatusCode> {
let identifier = payload.identifier.clone();
let password = payload.password.clone();
+
let auth_factor_token = payload.auth_factor_token.clone();
// Run the shared pre-auth logic to validate and check 2FA requirement
+
match preauth_check(&state, &identifier, &password, auth_factor_token, false).await {
+
Ok(result) => match result {
+
AuthResult::WrongIdentityOrPassword => json_error_response(
StatusCode::UNAUTHORIZED,
+
"AuthenticationRequired",
+
"Invalid identifier or password",
+
),
+
AuthResult::TwoFactorRequired(_) => {
+
// Email sending step can be handled here if needed in the future.
+
json_error_response(
+
StatusCode::UNAUTHORIZED,
+
"AuthFactorTokenRequired",
+
"A sign in code has been sent to your email address",
+
)
}
+
AuthResult::ProxyThrough => {
+
log::info!("Proxying through");
+
//No 2FA or already passed
+
let uri = format!(
+
"{}{}",
+
state.pds_base_url, "/xrpc/com.atproto.server.createSession"
+
);
+
+
let mut req = axum::http::Request::post(uri);
+
if let Some(req_headers) = req.headers_mut() {
+
req_headers.extend(headers.clone());
+
}
+
let payload_bytes =
+
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
+
let req = req
+
.body(Body::from(payload_bytes))
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
+
let proxied = state
+
.reverse_proxy_client
+
.request(req)
+
.await
+
.map_err(|_| StatusCode::BAD_REQUEST)?
+
.into_response();
+
Ok(proxied)
+
}
+
AuthResult::TokenCheckFailed(err) => match err {
+
TokenCheckError::InvalidToken => {
+
json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid")
+
}
+
TokenCheckError::ExpiredToken => {
+
json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired")
+
}
+
},
+
},
+
Err(err) => {
+
log::error!(
+
"Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
+
);
+
json_error_response(
+
StatusCode::INTERNAL_SERVER_ERROR,
+
"InternalServerError",
+
"This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
+
)
}
}
}
···
) -> Result<Response<Body>, StatusCode> {
//If email auth is not set at all it is a update email address
let email_auth_not_set = payload.email_auth_factor.is_none();
+
//If email auth is set it is to either turn on or off 2fa
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
// Email update asked for
···
}
}
+
// Updating the actual email address by sending it on to the PDS
let uri = format!(
"{}{}",
state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
-150
src/xrpc/helpers.rs
···
-
use axum::body::{Body, to_bytes};
-
use axum::extract::Request;
-
use axum::http::{HeaderMap, Method, StatusCode, Uri};
-
use axum::http::header::CONTENT_TYPE;
-
use axum::response::{IntoResponse, Response};
-
use serde::de::DeserializeOwned;
-
use tracing::error;
-
-
use crate::AppState;
-
-
/// The result of a proxied call that attempts to parse JSON.
-
pub enum ProxiedResult<T> {
-
/// Successfully parsed JSON body along with original response headers.
-
Parsed { value: T, _headers: HeaderMap },
-
/// Could not or should not parse: return the original (or rebuilt) response as-is.
-
Passthrough(Response<Body>),
-
}
-
-
/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
-
/// the successful response body as JSON into `T`.
-
///
-
/// Behavior:
-
/// - If the proxied response is non-200, returns Passthrough with the original response.
-
/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
-
/// - If parsing succeeds, returns Parsed { value, headers }.
-
pub async fn proxy_get_json<T>(
-
state: &AppState,
-
mut req: Request,
-
path: &str,
-
) -> Result<ProxiedResult<T>, StatusCode>
-
where
-
T: DeserializeOwned,
-
{
-
let uri = format!("{}{}", state.pds_base_url, path);
-
*req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
-
-
let result = state
-
.reverse_proxy_client
-
.request(req)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?
-
.into_response();
-
-
if result.status() != StatusCode::OK {
-
return Ok(ProxiedResult::Passthrough(result));
-
}
-
-
let response_headers = result.headers().clone();
-
let body = result.into_body();
-
let body_bytes = to_bytes(body, usize::MAX)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
-
match serde_json::from_slice::<T>(&body_bytes) {
-
Ok(value) => Ok(ProxiedResult::Parsed {
-
value,
-
_headers: response_headers,
-
}),
-
Err(err) => {
-
error!(%err, "failed to parse proxied JSON response; returning original body");
-
let mut builder = Response::builder().status(StatusCode::OK);
-
if let Some(headers) = builder.headers_mut() {
-
*headers = response_headers;
-
}
-
let resp = builder
-
.body(Body::from(body_bytes))
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
Ok(ProxiedResult::Passthrough(resp))
-
}
-
}
-
}
-
-
/// Proxy the incoming request as a POST to the PDS base URL plus the provided path and attempt to parse
-
/// the successful response body as JSON into `T`.
-
///
-
/// Behavior mirrors `proxy_get_json`:
-
/// - If the proxied response is non-200, returns Passthrough with the original response.
-
/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
-
/// - If parsing succeeds, returns Parsed { value, headers }.
-
pub async fn _proxy_post_json<T>(
-
state: &AppState,
-
mut req: Request,
-
path: &str,
-
) -> Result<ProxiedResult<T>, StatusCode>
-
where
-
T: DeserializeOwned,
-
{
-
let uri = format!("{}{}", state.pds_base_url, path);
-
*req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
-
*req.method_mut() = Method::POST;
-
-
let result = state
-
.reverse_proxy_client
-
.request(req)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?
-
.into_response();
-
-
if result.status() != StatusCode::OK {
-
return Ok(ProxiedResult::Passthrough(result));
-
}
-
-
let response_headers = result.headers().clone();
-
let body = result.into_body();
-
let body_bytes = to_bytes(body, usize::MAX)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
-
match serde_json::from_slice::<T>(&body_bytes) {
-
Ok(value) => Ok(ProxiedResult::Parsed {
-
value,
-
_headers: response_headers,
-
}),
-
Err(err) => {
-
error!(%err, "failed to parse proxied JSON response (POST); returning original body");
-
let mut builder = Response::builder().status(StatusCode::OK);
-
if let Some(headers) = builder.headers_mut() {
-
*headers = response_headers;
-
}
-
let resp = builder
-
.body(Body::from(body_bytes))
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
Ok(ProxiedResult::Passthrough(resp))
-
}
-
}
-
}
-
-
-
/// Build a JSON error response with the required Content-Type header
-
/// Content-Type: application/json;charset=utf-8
-
/// Body shape: { "error": string, "message": string }
-
pub fn json_error_response(
-
status: StatusCode,
-
error: impl Into<String>,
-
message: impl Into<String>,
-
) -> Result<Response<Body>, StatusCode> {
-
let body_str = match serde_json::to_string(&serde_json::json!({
-
"error": error.into(),
-
"message": message.into(),
-
})) {
-
Ok(s) => s,
-
Err(_) => return Err(StatusCode::BAD_REQUEST),
-
};
-
-
Response::builder()
-
.status(status)
-
.header(CONTENT_TYPE, "application/json;charset=utf-8")
-
.body(Body::from(body_str))
-
.map_err(|_| StatusCode::BAD_REQUEST)
-
}
···
-1
src/xrpc/mod.rs
···
pub mod com_atproto_server;
-
pub mod helpers;
···
pub mod com_atproto_server;