back interdiff of round #2 and #1

2FA logins gatekept #1

merged
opened by baileytownsend.dev targeting main from feature/2faCodeGeneration
ERROR
migrations_bells_and_whistles/.keep

Failed to calculate interdiff for this file.

REBASED
Cargo.lock

This patch was likely rebased, as context lines do not match.

ERROR
Cargo.toml

Failed to calculate interdiff for this file.

REBASED
src/xrpc/helpers.rs

This patch was likely rebased, as context lines do not match.

REBASED
src/middleware.rs

This patch was likely rebased, as context lines do not match.

ERROR
src/xrpc/mod.rs

Failed to calculate interdiff for this file.

ERROR
README.md

Failed to calculate interdiff for this file.

ERROR
src/helpers.rs

Failed to calculate interdiff for this file.

ERROR
src/oauth_provider.rs

Failed to calculate interdiff for this file.

NEW
src/main.rs
···
+
#![warn(clippy::unwrap_used)]
+
use crate::oauth_provider::sign_in;
use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
-
use axum::middleware as ax_middleware;
-
mod middleware;
use axum::body::Body;
use axum::handler::Handler;
use axum::http::{Method, header};
+
use axum::middleware as ax_middleware;
use axum::routing::post;
use axum::{Router, routing::get};
use axum_template::engine::Engine;
···
use tower_governor::governor::GovernorConfigBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer};
-
use tracing::{error, log};
+
use tracing::log;
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
+
pub mod helpers;
+
mod middleware;
+
mod oauth_provider;
mod xrpc;
type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
···
struct EmailTemplates;
#[derive(Clone)]
-
struct AppState {
+
pub struct AppState {
account_pool: SqlitePool,
pds_gatekeeper_pool: SqlitePool,
reverse_proxy_client: HyperUtilClient,
···
let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
-
let banner = format!(" {}\n{}", body, intro);
+
let banner = format!(" {body}\n{intro}");
(
[(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
···
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
setup_tracing();
-
//TODO prod
+
//TODO may need to change where this reads from? Like an env variable for it's location? Or arg?
dotenvy::from_path(Path::new("./pds.env"))?;
let pds_root = env::var("PDS_DATA_DIRECTORY")?;
-
// let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
-
let account_db_url = format!("{}/account.sqlite", pds_root);
-
log::info!("accounts_db_url: {}", account_db_url);
+
let account_db_url = format!("{pds_root}/account.sqlite");
let account_options = SqliteConnectOptions::new()
-
.journal_mode(SqliteJournalMode::Wal)
-
.filename(account_db_url);
+
.filename(account_db_url)
+
.busy_timeout(Duration::from_secs(5));
let account_pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(account_options)
.await?;
-
let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root);
+
let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
let options = SqliteConnectOptions::new()
.journal_mode(SqliteJournalMode::Wal)
.filename(bells_db_url)
-
.create_if_missing(true);
+
.create_if_missing(true)
+
.busy_timeout(Duration::from_secs(5));
let pds_gatekeeper_pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await?;
-
// Run migrations for the bells_and_whistles database
+
// Run migrations for the extra database
// Note: the migrations are embedded at compile time from the given directory
// sqlx
sqlx::migrate!("./migrations")
···
AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
//Email templates setup
let mut hbs = Handlebars::new();
-
let _ = hbs.register_embed_templates::<EmailTemplates>();
+
+
let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
+
if let Ok(users_email_directory) = users_email_directory {
+
hbs.register_template_file(
+
"two_factor_code.hbs",
+
format!("{users_email_directory}/two_factor_code.hbs"),
+
)?;
+
} else {
+
let _ = hbs.register_embed_templates::<EmailTemplates>();
+
}
+
+
let pds_base_url =
+
env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
let state = AppState {
account_pool,
pds_gatekeeper_pool,
reverse_proxy_client: client,
-
//TODO should be env prob
-
pds_base_url: "http://localhost:3000".to_string(),
+
pds_base_url,
mailer,
mailer_from: sent_from,
template_engine: Engine::from(hbs),
···
// Rate limiting
//Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
-
let governor_conf = GovernorConfigBuilder::default()
+
let create_session_governor_conf = GovernorConfigBuilder::default()
.per_second(60)
.burst_size(5)
.finish()
-
.unwrap();
-
let governor_limiter = governor_conf.limiter().clone();
+
.expect("failed to create governor config. this should not happen and is a bug");
+
+
// Create a second config with the same settings for the other endpoint
+
let sign_in_governor_conf = GovernorConfigBuilder::default()
+
.per_second(60)
+
.burst_size(5)
+
.finish()
+
.expect("failed to create governor config. this should not happen and is a bug");
+
+
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
+
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
let interval = Duration::from_secs(60);
// a separate background task to clean up
std::thread::spawn(move || {
loop {
std::thread::sleep(interval);
-
tracing::info!("rate limiting storage size: {}", governor_limiter.len());
-
governor_limiter.retain_recent();
+
create_session_governor_limiter.retain_recent();
+
sign_in_governor_limiter.retain_recent();
}
});
···
"/xrpc/com.atproto.server.updateEmail",
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
)
+
.route(
+
"/@atproto/oauth-provider/~api/sign-in",
+
post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)),
+
)
.route(
"/xrpc/com.atproto.server.createSession",
-
post(create_session.layer(GovernorLayer::new(governor_conf))),
+
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
)
.layer(CompressionLayer::new())
.layer(cors)
.with_state(state);
-
let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
-
let port: u16 = env::var("PORT")
+
let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
+
let port: u16 = env::var("GATEKEEPER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(8080);
···
.with_graceful_shutdown(shutdown_signal());
if let Err(err) = server.await {
-
error!(error = %err, "server error");
+
log::error!("server error:{err}");
}
Ok(())
NEW
src/xrpc/com_atproto_server.rs
···
use crate::AppState;
+
use crate::helpers::{
+
AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json,
+
};
use crate::middleware::Did;
-
use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json};
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json, debug_handler, extract, extract::Request};
-
use axum_template::TemplateEngine;
-
use lettre::message::{MultiPart, SinglePart, header};
-
use lettre::{AsyncTransport, Message};
use serde::{Deserialize, Serialize};
use serde_json;
-
use serde_json::Value;
-
use serde_json::value::Map;
use tracing::log;
#[derive(Serialize, Deserialize, Debug, Clone)]
···
pub struct CreateSessionRequest {
identifier: String,
password: String,
-
auth_factor_token: String,
-
allow_takendown: bool,
-
}
-
-
pub enum AuthResult {
-
WrongIdentityOrPassword,
-
TwoFactorRequired,
-
TwoFactorFailed,
-
/// User does not have 2FA enabled, or passes it
-
ProxyThrough,
-
}
-
-
pub enum IdentifierType {
-
Email,
-
DID,
-
Handle,
-
}
-
-
impl IdentifierType {
-
fn what_is_it(identifier: String) -> Self {
-
if identifier.contains("@") {
-
IdentifierType::Email
-
} else if identifier.contains("did:") {
-
IdentifierType::DID
-
} else {
-
IdentifierType::Handle
-
}
-
}
-
}
-
-
async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> {
-
// Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
-
let mut parts = password_scrypt.splitn(2, ':');
-
let salt = match parts.next() {
-
Some(s) if !s.is_empty() => s,
-
_ => return Ok(false),
-
};
-
let stored_hash_hex = match parts.next() {
-
Some(h) if !h.is_empty() => h,
-
_ => return Ok(false),
-
};
-
-
//Sets up scrypt to mimic node's scrypt
-
let params = match scrypt::Params::new(14, 8, 1, 64) {
-
Ok(p) => p,
-
Err(_) => return Ok(false),
-
};
-
let mut derived = [0u8; 64];
-
if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived).is_err() {
-
return Ok(false);
-
}
-
-
let stored_bytes = match hex::decode(stored_hash_hex) {
-
Ok(b) => b,
-
Err(e) => {
-
log::error!("Error decoding stored hash: {}", e);
-
return Ok(false);
-
}
-
};
-
-
Ok(derived.as_slice() == stored_bytes.as_slice())
-
}
-
-
async fn preauth_check(
-
state: &AppState,
-
identifier: &str,
-
password: &str,
-
) -> Result<AuthResult, StatusCode> {
-
// Determine identifier type
-
let id_type = IdentifierType::what_is_it(identifier.to_string());
-
-
// Query account DB for did and passwordScrypt based on identifier type
-
let account_row: Option<(String, String, String)> = match id_type {
-
IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>(
-
"SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1",
-
)
-
.bind(identifier)
-
.fetch_optional(&state.account_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
-
IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>(
-
"SELECT account.did, account.passwordScrypt, account.email
-
FROM actor
-
LEFT JOIN account ON actor.did = account.did
-
where actor.handle =? LIMIT 1",
-
)
-
.bind(identifier)
-
.fetch_optional(&state.account_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
-
IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>(
-
"SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1",
-
)
-
.bind(identifier)
-
.fetch_optional(&state.account_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
-
};
-
-
if let Some((did, password_scrypt, email)) = account_row {
-
// Check two-factor requirement for this DID in the gatekeeper DB
-
let required_opt = sqlx::query_as::<_, (u8,)>(
-
"SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
-
)
-
.bind(&did)
-
.fetch_optional(&state.pds_gatekeeper_pool)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
-
let two_factor_required = match required_opt {
-
Some(row) => row.0 != 0,
-
None => false,
-
};
-
-
if two_factor_required {
-
// Verify password before proceeding to 2FA email step
-
let verified = verify_password(password, &password_scrypt).await?;
-
if !verified {
-
return Ok(AuthResult::WrongIdentityOrPassword);
-
}
-
let mut email_data = Map::new();
-
//TODO these need real values
-
let token = "test".to_string();
-
let handle = "baileytownsend.dev".to_string();
-
email_data.insert("token".to_string(), Value::from(token.clone()));
-
email_data.insert("handle".to_string(), Value::from(handle.clone()));
-
//TODO bad unwrap
-
let email_body = state
-
.template_engine
-
.render("two_factor_code.hbs", email_data)
-
.unwrap();
-
-
let email = Message::builder()
-
//TODO prob get the proper type in the state
-
.from(state.mailer_from.parse().unwrap())
-
.to(email.parse().unwrap())
-
.subject("Sign in to Bluesky")
-
.multipart(
-
MultiPart::alternative() // This is composed of two parts.
-
.singlepart(
-
SinglePart::builder()
-
.header(header::ContentType::TEXT_PLAIN)
-
.body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback.
-
)
-
.singlepart(
-
SinglePart::builder()
-
.header(header::ContentType::TEXT_HTML)
-
.body(email_body),
-
),
-
)
-
//TODO bad
-
.unwrap();
-
return match state.mailer.send(email).await {
-
Ok(_) => Ok(AuthResult::TwoFactorRequired),
-
Err(err) => {
-
log::error!("Error sending the 2FA email: {}", err);
-
Err(StatusCode::BAD_REQUEST)
-
}
-
};
-
}
-
}
-
-
// No local 2FA requirement (or account not found)
-
Ok(AuthResult::ProxyThrough)
+
#[serde(skip_serializing_if = "Option::is_none")]
+
auth_factor_token: Option<String>,
+
#[serde(skip_serializing_if = "Option::is_none")]
+
allow_takendown: Option<bool>,
}
pub async fn create_session(
···
) -> Result<Response<Body>, StatusCode> {
let identifier = payload.identifier.clone();
let password = payload.password.clone();
+
let auth_factor_token = payload.auth_factor_token.clone();
// Run the shared pre-auth logic to validate and check 2FA requirement
-
match preauth_check(&state, &identifier, &password).await? {
-
AuthResult::WrongIdentityOrPassword => json_error_response(
-
StatusCode::UNAUTHORIZED,
-
"AuthenticationRequired",
-
"Invalid identifier or password",
-
),
-
AuthResult::TwoFactorRequired => {
-
// Email sending step can be handled here if needed in the future.
-
json_error_response(
+
match preauth_check(&state, &identifier, &password, auth_factor_token, false).await {
+
Ok(result) => match result {
+
AuthResult::WrongIdentityOrPassword => json_error_response(
StatusCode::UNAUTHORIZED,
-
"AuthFactorTokenRequired",
-
"A sign in code has been sent to your email address",
-
)
-
}
-
AuthResult::TwoFactorFailed => {
-
//Not sure what the errors are for this response is yet
-
json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER")
-
}
-
AuthResult::ProxyThrough => {
-
//No 2FA or already passed
-
let uri = format!(
-
"{}{}",
-
state.pds_base_url, "/xrpc/com.atproto.server.createSession"
-
);
-
-
let mut req = axum::http::Request::post(uri);
-
if let Some(req_headers) = req.headers_mut() {
-
req_headers.extend(headers.clone());
+
"AuthenticationRequired",
+
"Invalid identifier or password",
+
),
+
AuthResult::TwoFactorRequired(_) => {
+
// Email sending step can be handled here if needed in the future.
+
json_error_response(
+
StatusCode::UNAUTHORIZED,
+
"AuthFactorTokenRequired",
+
"A sign in code has been sent to your email address",
+
)
}
+
AuthResult::ProxyThrough => {
+
log::info!("Proxying through");
+
//No 2FA or already passed
+
let uri = format!(
+
"{}{}",
+
state.pds_base_url, "/xrpc/com.atproto.server.createSession"
+
);
+
+
let mut req = axum::http::Request::post(uri);
+
if let Some(req_headers) = req.headers_mut() {
+
req_headers.extend(headers.clone());
+
}
-
let payload_bytes =
-
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
-
let req = req
-
.body(Body::from(payload_bytes))
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
+
let payload_bytes =
+
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
+
let req = req
+
.body(Body::from(payload_bytes))
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
-
let proxied = state
-
.reverse_proxy_client
-
.request(req)
-
.await
-
.map_err(|_| StatusCode::BAD_REQUEST)?
-
.into_response();
+
let proxied = state
+
.reverse_proxy_client
+
.request(req)
+
.await
+
.map_err(|_| StatusCode::BAD_REQUEST)?
+
.into_response();
-
Ok(proxied)
+
Ok(proxied)
+
}
+
AuthResult::TokenCheckFailed(err) => match err {
+
TokenCheckError::InvalidToken => {
+
json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid")
+
}
+
TokenCheckError::ExpiredToken => {
+
json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired")
+
}
+
},
+
},
+
Err(err) => {
+
log::error!(
+
"Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
+
);
+
json_error_response(
+
StatusCode::INTERNAL_SERVER_ERROR,
+
"InternalServerError",
+
"This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
+
)
}
}
}
···
) -> Result<Response<Body>, StatusCode> {
//If email auth is not set at all it is a update email address
let email_auth_not_set = payload.email_auth_factor.is_none();
-
//If email aurth is set it is to either turn on or off 2fa
+
//If email auth is set it is to either turn on or off 2fa
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
// Email update asked for
···
}
}
-
// Updating the acutal email address
+
// Updating the actual email address by sending it on to the PDS
let uri = format!(
"{}{}",
state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"