A better Rust ATProto crate
at main 16 kB view raw
1//! Service authentication extractor and middleware 2//! 3//! # Example 4//! 5//! ```no_run 6//! use axum::{Router, routing::get}; 7//! use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth}; 8//! use jacquard_identity::JacquardResolver; 9//! use jacquard_identity::resolver::ResolverOptions; 10//! use jacquard_common::types::string::Did; 11//! 12//! async fn handler( 13//! ExtractServiceAuth(auth): ExtractServiceAuth, 14//! ) -> String { 15//! format!("Authenticated as {}", auth.did()) 16//! } 17//! 18//! #[tokio::main] 19//! async fn main() { 20//! let resolver = JacquardResolver::new( 21//! reqwest::Client::new(), 22//! ResolverOptions::default(), 23//! ); 24//! let config = ServiceAuthConfig::new( 25//! Did::new_static("did:web:feedgen.example.com").unwrap(), 26//! resolver, 27//! ); 28//! 29//! let app = Router::new() 30//! .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 31//! .with_state(config); 32//! 33//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 34//! .await 35//! .unwrap(); 36//! axum::serve(listener, app).await.unwrap(); 37//! } 38//! ``` 39 40use axum::{ 41 Json, 42 extract::FromRequestParts, 43 http::{HeaderValue, StatusCode, header, request::Parts}, 44 middleware::Next, 45 response::{IntoResponse, Response}, 46}; 47use jacquard_common::{ 48 CowStr, IntoStatic, 49 service_auth::{self, PublicKey}, 50 types::{ 51 did_doc::VerificationMethod, 52 string::{Did, Nsid}, 53 }, 54}; 55use jacquard_identity::resolver::IdentityResolver; 56use serde_json::json; 57use std::sync::Arc; 58use thiserror::Error; 59 60/// Trait for providing service authentication configuration. 61/// 62/// This trait allows custom state types to provide service auth configuration 63/// without requiring `ServiceAuthConfig<R>` directly. 
64pub trait ServiceAuth { 65 /// The identity resolver type 66 type Resolver: IdentityResolver; 67 68 /// Get the service DID (expected audience) 69 fn service_did(&self) -> &Did<'_>; 70 71 /// Get a reference to the identity resolver 72 fn resolver(&self) -> &Self::Resolver; 73 74 /// Whether to require the `lxm` (method binding) field 75 fn require_lxm(&self) -> bool; 76} 77 78/// Configuration for service auth verification. 79/// 80/// This should be stored in your Axum app state and will be extracted 81/// by the `ExtractServiceAuth` extractor. 82pub struct ServiceAuthConfig<R> { 83 /// The DID of your service (the expected audience) 84 service_did: Did<'static>, 85 /// Identity resolver for fetching DID documents 86 resolver: Arc<R>, 87 /// Whether to require the `lxm` (method binding) field 88 require_lxm: bool, 89} 90 91impl<R> Clone for ServiceAuthConfig<R> { 92 fn clone(&self) -> Self { 93 Self { 94 service_did: self.service_did.clone(), 95 resolver: Arc::clone(&self.resolver), 96 require_lxm: self.require_lxm, 97 } 98 } 99} 100 101impl<R: IdentityResolver> ServiceAuthConfig<R> { 102 /// Create a new service auth config. 103 /// 104 /// This enables `lxm` (method binding). If you need backward compatibility, 105 /// use `ServiceAuthConfig::new_legacy()` 106 pub fn new(service_did: Did<'static>, resolver: R) -> Self { 107 Self { 108 service_did, 109 resolver: Arc::new(resolver), 110 require_lxm: true, 111 } 112 } 113 114 /// Create a new service auth config. 115 /// 116 /// `lxm` (method binding) is disabled for backwards compatibility 117 pub fn new_legacy(service_did: Did<'static>, resolver: R) -> Self { 118 Self { 119 service_did, 120 resolver: Arc::new(resolver), 121 require_lxm: false, 122 } 123 } 124 125 /// Set whether to require the `lxm` field (method binding). 126 /// 127 /// When enabled, the JWT must contain an `lxm` field matching the requested endpoint. 128 /// This prevents token reuse across different methods. 
129 pub fn require_lxm(mut self, require: bool) -> Self { 130 self.require_lxm = require; 131 self 132 } 133 134 /// Get the service DID. 135 pub fn service_did(&self) -> &Did<'static> { 136 &self.service_did 137 } 138 139 /// Get a reference to the identity resolver. 140 pub fn resolver(&self) -> &R { 141 &self.resolver 142 } 143} 144 145impl<R: IdentityResolver> ServiceAuth for ServiceAuthConfig<R> { 146 type Resolver = R; 147 148 fn service_did(&self) -> &Did<'_> { 149 &self.service_did 150 } 151 152 fn resolver(&self) -> &Self::Resolver { 153 &self.resolver 154 } 155 156 fn require_lxm(&self) -> bool { 157 self.require_lxm 158 } 159} 160 161/// Verified service authentication information. 162/// 163/// This is the result of successfully verifying a service auth JWT. 164/// This type is extracted by the `ExtractServiceAuth` extractor. 165#[derive(Debug, Clone, jacquard_derive::IntoStatic)] 166pub struct VerifiedServiceAuth<'a> { 167 /// The authenticated user's DID (from `iss` claim) 168 did: Did<'a>, 169 /// The audience (should match your service DID) 170 aud: Did<'a>, 171 /// The lexicon method NSID, if present 172 lxm: Option<Nsid<'a>>, 173 /// JWT ID (nonce), if present 174 jti: Option<CowStr<'a>>, 175} 176 177impl<'a> VerifiedServiceAuth<'a> { 178 /// Get the authenticated user's DID. 179 pub fn did(&self) -> &Did<'a> { 180 &self.did 181 } 182 183 /// Get the audience (your service DID). 184 pub fn aud(&self) -> &Did<'a> { 185 &self.aud 186 } 187 188 /// Get the lexicon method NSID, if present. 189 pub fn lxm(&self) -> Option<&Nsid<'a>> { 190 self.lxm.as_ref() 191 } 192 193 /// Get the JWT ID (nonce), if present. 194 /// 195 /// You can use this for replay protection by tracking seen JTIs 196 /// until their expiration time. 197 pub fn jti(&self) -> Option<&str> { 198 self.jti.as_ref().map(|j| j.as_ref()) 199 } 200} 201 202/// Axum extractor for service authentication. 
///
/// This extracts and verifies a service auth JWT from the Authorization header,
/// resolving the issuer's DID to verify the signature.
///
/// The application state must implement `ServiceAuth` (for example
/// `ServiceAuthConfig`). When verification fails, the request is rejected
/// with a `ServiceAuthError`.
///
/// # Example
///
/// ```no_run
/// use axum::{Router, routing::get};
/// use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth};
/// use jacquard_identity::JacquardResolver;
/// use jacquard_identity::resolver::ResolverOptions;
/// use jacquard_common::types::string::Did;
///
/// async fn handler(
///     ExtractServiceAuth(auth): ExtractServiceAuth,
/// ) -> String {
///     format!("Authenticated as {}", auth.did())
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let resolver = JacquardResolver::new(
///         reqwest::Client::new(),
///         ResolverOptions::default(),
///     );
///     let config = ServiceAuthConfig::new(
///         Did::new_static("did:web:feedgen.example.com").unwrap(),
///         resolver,
///     );
///
///     let app = Router::new()
///         .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler))
///         .with_state(config);
///
///     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
///         .await
///         .unwrap();
///     axum::serve(listener, app).await.unwrap();
/// }
/// ```
pub struct ExtractServiceAuth(pub VerifiedServiceAuth<'static>);

/// Errors that can occur during service auth verification.
246#[derive(Debug, Error, miette::Diagnostic)] 247pub enum ServiceAuthError { 248 /// Authorization header is missing 249 #[error("missing Authorization header")] 250 MissingAuthHeader, 251 252 /// Authorization header is malformed (not "Bearer `token`") 253 #[error("invalid Authorization header format")] 254 InvalidAuthHeader, 255 256 /// JWT parsing or verification failed 257 #[error("JWT verification failed: {0}")] 258 JwtError(#[from] service_auth::ServiceAuthError), 259 260 /// DID resolution failed 261 #[error("failed to resolve DID {did}: {source}")] 262 DidResolutionFailed { 263 did: Did<'static>, 264 #[source] 265 source: Box<dyn std::error::Error + Send + Sync>, 266 }, 267 268 /// No valid signing key found in DID document 269 #[error("no valid signing key found in DID document for {0}")] 270 NoSigningKey(Did<'static>), 271 272 /// Method binding required but missing 273 #[error("lxm (method binding) is required but missing from token")] 274 MethodBindingRequired, 275 276 /// Invalid key format 277 #[error("invalid key format: {0}")] 278 InvalidKey(String), 279} 280 281impl IntoResponse for ServiceAuthError { 282 fn into_response(self) -> Response { 283 let (status, error_code, message) = match &self { 284 ServiceAuthError::MissingAuthHeader => { 285 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string()) 286 } 287 ServiceAuthError::InvalidAuthHeader => { 288 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string()) 289 } 290 ServiceAuthError::JwtError(_) => ( 291 StatusCode::UNAUTHORIZED, 292 "AuthenticationRequired", 293 self.to_string(), 294 ), 295 ServiceAuthError::DidResolutionFailed { .. 
} => ( 296 StatusCode::UNAUTHORIZED, 297 "AuthenticationRequired", 298 self.to_string(), 299 ), 300 ServiceAuthError::NoSigningKey(_) => ( 301 StatusCode::UNAUTHORIZED, 302 "AuthenticationRequired", 303 self.to_string(), 304 ), 305 ServiceAuthError::MethodBindingRequired => ( 306 StatusCode::UNAUTHORIZED, 307 "AuthenticationRequired", 308 self.to_string(), 309 ), 310 ServiceAuthError::InvalidKey(_) => ( 311 StatusCode::UNAUTHORIZED, 312 "AuthenticationRequired", 313 self.to_string(), 314 ), 315 }; 316 317 tracing::warn!("Service auth failed: {}", message); 318 319 ( 320 status, 321 [( 322 header::CONTENT_TYPE, 323 HeaderValue::from_static("application/json"), 324 )], 325 Json(json!({ 326 "error": error_code, 327 "message": message, 328 })), 329 ) 330 .into_response() 331 } 332} 333 334impl<S> FromRequestParts<S> for ExtractServiceAuth 335where 336 S: ServiceAuth + Send + Sync, 337 S::Resolver: Send + Sync, 338{ 339 type Rejection = ServiceAuthError; 340 341 fn from_request_parts( 342 parts: &mut Parts, 343 state: &S, 344 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { 345 async move { 346 // Extract Authorization header 347 let auth_header = parts 348 .headers 349 .get(header::AUTHORIZATION) 350 .ok_or(ServiceAuthError::MissingAuthHeader)?; 351 352 // Parse Bearer token 353 let auth_str = auth_header 354 .to_str() 355 .map_err(|_| ServiceAuthError::InvalidAuthHeader)?; 356 357 let token = auth_str 358 .strip_prefix("Bearer ") 359 .ok_or(ServiceAuthError::InvalidAuthHeader)?; 360 361 // Parse JWT 362 let parsed = service_auth::parse_jwt(token)?; 363 364 // Get claims for DID resolution 365 let claims = parsed.claims(); 366 367 // Resolve DID to get signing key (do this before checking claims) 368 let did_doc = state 369 .resolver() 370 .resolve_did_doc(&claims.iss) 371 .await 372 .map_err(|e| ServiceAuthError::DidResolutionFailed { 373 did: claims.iss.clone().into_static(), 374 source: Box::new(e), 375 })?; 376 377 // Parse the DID 
document response to get verification methods 378 let doc = did_doc 379 .parse() 380 .map_err(|e| ServiceAuthError::DidResolutionFailed { 381 did: claims.iss.clone().into_static(), 382 source: Box::new(e), 383 })?; 384 385 // Extract signing key from DID document 386 let verification_methods = doc 387 .verification_method 388 .as_deref() 389 .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 390 391 let signing_key = extract_signing_key(verification_methods) 392 .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 393 394 // Verify signature FIRST - if this fails, nothing else matters 395 service_auth::verify_signature(&parsed, &signing_key)?; 396 397 // Now validate claims (audience, expiration, etc.) 398 claims.validate(state.service_did())?; 399 400 // Check method binding if required 401 if state.require_lxm() && claims.lxm.is_none() { 402 return Err(ServiceAuthError::MethodBindingRequired); 403 } 404 405 // All checks passed - return verified auth 406 Ok(ExtractServiceAuth(VerifiedServiceAuth { 407 did: claims.iss.clone().into_static(), 408 aud: claims.aud.clone().into_static(), 409 lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()), 410 jti: claims.jti.as_ref().map(|j| j.clone().into_static()), 411 })) 412 } 413 } 414} 415 416/// Extract the signing key from a DID document's verification methods. 417/// 418/// This looks for a key with type "atproto" or the first available key 419/// if no atproto-specific key is found. 
420fn extract_signing_key(methods: &[VerificationMethod]) -> Option<PublicKey> { 421 // First try to find an atproto-specific key 422 let atproto_method = methods 423 .iter() 424 .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto"); 425 426 let method = atproto_method.or_else(|| methods.first())?; 427 428 // Parse the multikey 429 let public_key_multibase = method.public_key_multibase.as_ref()?; 430 431 // Decode multibase 432 let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?; 433 434 // First two bytes are the multicodec prefix 435 if key_bytes.len() < 2 { 436 return None; 437 } 438 439 let codec = &key_bytes[..2]; 440 let key_material = &key_bytes[2..]; 441 442 match codec { 443 // p256-pub (0x1200) 444 [0x80, 0x24] => PublicKey::from_p256_bytes(key_material).ok(), 445 // secp256k1-pub (0xe7) 446 [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material).ok(), 447 _ => None, 448 } 449} 450 451/// Middleware for verifying service authentication on all requests. 452/// 453/// This middleware extracts and verifies the service auth JWT, then adds the 454/// `VerifiedServiceAuth` to request extensions for downstream handlers to access. 
455/// 456/// # Example 457/// 458/// ```no_run 459/// use axum::{Router, routing::get, middleware, Extension}; 460/// use jacquard_axum::service_auth::{ServiceAuthConfig, service_auth_middleware}; 461/// use jacquard_identity::JacquardResolver; 462/// use jacquard_identity::resolver::ResolverOptions; 463/// use jacquard_common::types::string::Did; 464/// 465/// async fn handler( 466/// Extension(auth): Extension<jacquard_axum::service_auth::VerifiedServiceAuth<'static>>, 467/// ) -> String { 468/// format!("Authenticated as {}", auth.did()) 469/// } 470/// 471/// #[tokio::main] 472/// async fn main() { 473/// let resolver = JacquardResolver::new( 474/// reqwest::Client::new(), 475/// ResolverOptions::default(), 476/// ); 477/// let config = ServiceAuthConfig::new( 478/// Did::new_static("did:web:feedgen.example.com").unwrap(), 479/// resolver, 480/// ); 481/// 482/// let app = Router::new() 483/// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 484/// .layer(middleware::from_fn_with_state( 485/// config.clone(), 486/// service_auth_middleware::<ServiceAuthConfig<JacquardResolver>>, 487/// )) 488/// .with_state(config); 489/// 490/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 491/// .await 492/// .unwrap(); 493/// axum::serve(listener, app).await.unwrap(); 494/// } 495/// ``` 496pub async fn service_auth_middleware<S>( 497 state: axum::extract::State<S>, 498 mut req: axum::extract::Request, 499 next: Next, 500) -> Result<Response, ServiceAuthError> 501where 502 S: ServiceAuth + Send + Sync + Clone, 503 S::Resolver: Send + Sync, 504{ 505 // Extract auth from request parts 506 let (mut parts, body) = req.into_parts(); 507 let ExtractServiceAuth(auth) = 508 ExtractServiceAuth::from_request_parts(&mut parts, &state.0).await?; 509 510 // Add auth to extensions 511 parts.extensions.insert(auth); 512 513 // Reconstruct request and continue 514 req = axum::extract::Request::from_parts(parts, body); 515 Ok(next.run(req).await) 516}