A better Rust ATProto crate
at main 15 kB view raw
1//! Service authentication JWT parsing and verification for AT Protocol. 2//! 3//! Service auth is atproto's inter-service authentication mechanism. When a backend 4//! service (feed generator, labeler, etc.) receives requests, the PDS signs a 5//! short-lived JWT with the user's signing key and includes it as a Bearer token. 6//! 7//! # JWT Structure 8//! 9//! - Header: `alg` (ES256K for k256, ES256 for p256), `typ` ("JWT") 10//! - Payload: 11//! - `iss`: user's DID (issuer) 12//! - `aud`: target service DID (audience) 13//! - `exp`: expiration unix timestamp 14//! - `iat`: issued at unix timestamp 15//! - `jti`: random nonce (128-bit hex) for replay protection 16//! - `lxm`: lexicon method NSID (method binding) 17//! - Signature: signed with user's signing key from DID doc (ES256 or ES256K) 18 19use crate::CowStr; 20use crate::IntoStatic; 21use crate::types::string::{Did, Nsid}; 22use base64::Engine; 23use base64::engine::general_purpose::URL_SAFE_NO_PAD; 24use ouroboros::self_referencing; 25use serde::{Deserialize, Serialize}; 26use signature::Verifier; 27use smol_str::SmolStr; 28use smol_str::format_smolstr; 29use thiserror::Error; 30 31#[cfg(feature = "crypto-p256")] 32use p256::ecdsa::{Signature as P256Signature, VerifyingKey as P256VerifyingKey}; 33 34#[cfg(feature = "crypto-k256")] 35use k256::ecdsa::{Signature as K256Signature, VerifyingKey as K256VerifyingKey}; 36 37/// Errors that can occur during JWT parsing and verification. 38#[derive(Debug, Error, miette::Diagnostic)] 39pub enum ServiceAuthError { 40 /// JWT format is invalid (not three base64-encoded parts separated by dots) 41 #[error("malformed JWT: {0}")] 42 MalformedToken(CowStr<'static>), 43 44 /// Base64 decoding failed 45 #[error("base64 decode error: {0}")] 46 Base64Decode(#[from] base64::DecodeError), 47 48 /// JSON parsing failed 49 #[error("JSON parsing error: {0}")] 50 JsonParse(#[from] serde_json::Error), 51 52 /// Signature verification failed 53 #[error("invalid signature")] 54 InvalidSignature, 55 56 /// Unsupported algorithm 57 #[error("unsupported algorithm: {alg}")] 58 UnsupportedAlgorithm { 59 /// Algorithm name from JWT header 60 alg: SmolStr, 61 }, 62 63 /// Token has expired 64 #[error("token expired at {exp} (current time: {now})")] 65 Expired { 66 /// Expiration timestamp from token 67 exp: i64, 68 /// Current timestamp 69 now: i64, 70 }, 71 72 /// Audience mismatch 73 #[error("audience mismatch: expected {expected}, got {actual}")] 74 AudienceMismatch { 75 /// Expected audience DID 76 expected: Did<'static>, 77 /// Actual audience DID in token 78 actual: Did<'static>, 79 }, 80 81 /// Method mismatch (lxm field) 82 #[error("method mismatch: expected {expected}, got {actual:?}")] 83 MethodMismatch { 84 /// Expected method NSID 85 expected: Nsid<'static>, 86 /// Actual method NSID in token (if any) 87 actual: Option<Nsid<'static>>, 88 }, 89 90 /// Missing required field 91 #[error("missing required field: {0}")] 92 MissingField(&'static str), 93 94 /// Crypto error 95 #[error("crypto error: {0}")] 96 Crypto(CowStr<'static>), 97} 98 99/// JWT header for service auth tokens. 100#[derive(Debug, Clone, Serialize, Deserialize)] 101pub struct JwtHeader<'a> { 102 /// Algorithm used for signing 103 #[serde(borrow)] 104 pub alg: CowStr<'a>, 105 /// Type (always "JWT") 106 #[serde(borrow)] 107 pub typ: CowStr<'a>, 108} 109 110impl IntoStatic for JwtHeader<'_> { 111 type Output = JwtHeader<'static>; 112 113 fn into_static(self) -> Self::Output { 114 JwtHeader { 115 alg: self.alg.into_static(), 116 typ: self.typ.into_static(), 117 } 118 } 119} 120 121/// Service authentication claims. 122/// 123/// These are the payload fields in a service auth JWT. 124#[derive(Debug, Clone, Serialize, Deserialize)] 125pub struct ServiceAuthClaims<'a> { 126 /// Issuer (user's DID) 127 #[serde(borrow)] 128 pub iss: Did<'a>, 129 130 /// Audience (target service DID) 131 #[serde(borrow)] 132 pub aud: Did<'a>, 133 134 /// Expiration time (unix timestamp) 135 pub exp: i64, 136 137 /// Issued at (unix timestamp) 138 pub iat: i64, 139 140 /// JWT ID (nonce for replay protection) 141 #[serde(borrow, skip_serializing_if = "Option::is_none")] 142 pub jti: Option<CowStr<'a>>, 143 144 /// Lexicon method NSID (method binding) 145 #[serde(borrow, skip_serializing_if = "Option::is_none")] 146 pub lxm: Option<Nsid<'a>>, 147} 148 149impl<'a> IntoStatic for ServiceAuthClaims<'a> { 150 type Output = ServiceAuthClaims<'static>; 151 152 fn into_static(self) -> Self::Output { 153 ServiceAuthClaims { 154 iss: self.iss.into_static(), 155 aud: self.aud.into_static(), 156 exp: self.exp, 157 iat: self.iat, 158 jti: self.jti.map(|j| j.into_static()), 159 lxm: self.lxm.map(|l| l.into_static()), 160 } 161 } 162} 163 164impl<'a> ServiceAuthClaims<'a> { 165 /// Validate the claims against expected values. 166 /// 167 /// Checks: 168 /// - Audience matches expected DID 169 /// - Token is not expired 170 pub fn validate(&self, expected_aud: &Did) -> Result<(), ServiceAuthError> { 171 // Check audience 172 if self.aud.as_str() != expected_aud.as_str() { 173 return Err(ServiceAuthError::AudienceMismatch { 174 expected: expected_aud.clone().into_static(), 175 actual: self.aud.clone().into_static(), 176 }); 177 } 178 179 // Check expiration 180 if self.is_expired() { 181 let now = chrono::Utc::now().timestamp(); 182 return Err(ServiceAuthError::Expired { exp: self.exp, now }); 183 } 184 185 Ok(()) 186 } 187 188 /// Check if the token has expired. 189 pub fn is_expired(&self) -> bool { 190 let now = chrono::Utc::now().timestamp(); 191 self.exp <= now 192 } 193 194 /// Check if the method (lxm) matches the expected NSID. 195 pub fn check_method(&self, nsid: &Nsid) -> bool { 196 self.lxm 197 .as_ref() 198 .map(|lxm| lxm.as_str() == nsid.as_str()) 199 .unwrap_or(false) 200 } 201 202 /// Require that the method (lxm) matches the expected NSID. 203 pub fn require_method(&self, nsid: &Nsid) -> Result<(), ServiceAuthError> { 204 if !self.check_method(nsid) { 205 return Err(ServiceAuthError::MethodMismatch { 206 expected: nsid.clone().into_static(), 207 actual: self.lxm.as_ref().map(|l| l.clone().into_static()), 208 }); 209 } 210 Ok(()) 211 } 212} 213 214/// Parsed JWT components. 215/// 216/// This struct owns the decoded buffers and parsed components using ouroboros 217/// self-referencing. The header and claims borrow from their respective buffers. 218#[self_referencing] 219pub struct ParsedJwt { 220 /// Decoded header buffer (owned) 221 header_buf: Vec<u8>, 222 /// Decoded payload buffer (owned) 223 payload_buf: Vec<u8>, 224 /// Original token string for signing_input 225 token: String, 226 /// Signature bytes 227 signature: Vec<u8>, 228 /// Parsed header borrowing from header_buf 229 #[borrows(header_buf)] 230 #[covariant] 231 header: JwtHeader<'this>, 232 /// Parsed claims borrowing from payload_buf 233 #[borrows(payload_buf)] 234 #[covariant] 235 claims: ServiceAuthClaims<'this>, 236} 237 238impl ParsedJwt { 239 /// Get the signing input (header.payload) for signature verification. 240 pub fn signing_input(&self) -> &[u8] { 241 self.with_token(|token| { 242 let dot_pos = token.find('.').unwrap(); 243 let second_dot_pos = token[dot_pos + 1..].find('.').unwrap() + dot_pos + 1; 244 token[..second_dot_pos].as_bytes() 245 }) 246 } 247 248 /// Get a reference to the header. 249 pub fn header(&self) -> &JwtHeader<'_> { 250 self.borrow_header() 251 } 252 253 /// Get a reference to the claims. 254 pub fn claims(&self) -> &ServiceAuthClaims<'_> { 255 self.borrow_claims() 256 } 257 258 /// Get a reference to the signature. 259 pub fn signature(&self) -> &[u8] { 260 self.borrow_signature() 261 } 262 263 /// Get owned header with 'static lifetime. 264 pub fn into_header(self) -> JwtHeader<'static> { 265 self.with_header(|header| header.clone().into_static()) 266 } 267 268 /// Get owned claims with 'static lifetime. 269 pub fn into_claims(self) -> ServiceAuthClaims<'static> { 270 self.with_claims(|claims| claims.clone().into_static()) 271 } 272} 273 274/// Parse a JWT token into its components without verifying the signature. 275/// 276/// This extracts and decodes all JWT components. The header and claims are parsed 277/// and borrow from their respective owned buffers using ouroboros self-referencing. 278pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> { 279 let parts: Vec<&str> = token.split('.').collect(); 280 if parts.len() != 3 { 281 return Err(ServiceAuthError::MalformedToken(CowStr::new_static( 282 "JWT must have exactly 3 parts separated by dots", 283 ))); 284 } 285 286 let header_b64 = parts[0]; 287 let payload_b64 = parts[1]; 288 let signature_b64 = parts[2]; 289 290 // Decode all components 291 let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?; 292 let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?; 293 let signature = URL_SAFE_NO_PAD.decode(signature_b64)?; 294 295 // Validate that buffers contain valid JSON for their types 296 // We parse once here to validate, then again in the builder (unavoidable with ouroboros) 297 let _header: JwtHeader = serde_json::from_slice(&header_buf)?; 298 let _claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?; 299 300 Ok(ParsedJwtBuilder { 301 header_buf, 302 payload_buf, 303 token: token.to_string(), 304 signature, 305 header_builder: |buf| { 306 // Safe: we validated this succeeds above 307 serde_json::from_slice(buf).expect("header was validated") 308 }, 309 claims_builder: |buf| { 310 // Safe: we validated this succeeds above 311 serde_json::from_slice(buf).expect("claims were validated") 312 }, 313 } 314 .build()) 315} 316 317/// Public key types for signature verification. 318#[derive(Debug, Clone)] 319pub enum PublicKey { 320 /// P-256 (ES256) public key 321 #[cfg(feature = "crypto-p256")] 322 P256(P256VerifyingKey), 323 324 /// secp256k1 (ES256K) public key 325 #[cfg(feature = "crypto-k256")] 326 K256(K256VerifyingKey), 327} 328 329impl PublicKey { 330 /// Create a P-256 public key from compressed or uncompressed bytes. 331 #[cfg(feature = "crypto-p256")] 332 pub fn from_p256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> { 333 let key = P256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| { 334 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid P-256 key: {}", e))) 335 })?; 336 Ok(PublicKey::P256(key)) 337 } 338 339 /// Create a secp256k1 public key from compressed or uncompressed bytes. 340 #[cfg(feature = "crypto-k256")] 341 pub fn from_k256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> { 342 let key = K256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| { 343 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid K-256 key: {}", e))) 344 })?; 345 Ok(PublicKey::K256(key)) 346 } 347} 348 349/// Verify a JWT signature using the provided public key. 350/// 351/// The algorithm is determined by the JWT header and must match the public key type. 352pub fn verify_signature( 353 parsed: &ParsedJwt, 354 public_key: &PublicKey, 355) -> Result<(), ServiceAuthError> { 356 let alg = parsed.header().alg.as_str(); 357 let signing_input = parsed.signing_input(); 358 let signature = parsed.signature(); 359 360 match (alg, public_key) { 361 #[cfg(feature = "crypto-p256")] 362 ("ES256", PublicKey::P256(key)) => { 363 let sig = P256Signature::from_slice(signature).map_err(|e| { 364 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!( 365 "invalid ES256 signature: {}", 366 e 367 ))) 368 })?; 369 key.verify(signing_input, &sig) 370 .map_err(|_| ServiceAuthError::InvalidSignature)?; 371 Ok(()) 372 } 373 374 #[cfg(feature = "crypto-k256")] 375 ("ES256K", PublicKey::K256(key)) => { 376 let sig = K256Signature::from_slice(signature).map_err(|e| { 377 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!( 378 "invalid ES256K signature: {}", 379 e 380 ))) 381 })?; 382 key.verify(signing_input, &sig) 383 .map_err(|_| ServiceAuthError::InvalidSignature)?; 384 Ok(()) 385 } 386 387 _ => Err(ServiceAuthError::UnsupportedAlgorithm { 388 alg: SmolStr::new(alg), 389 }), 390 } 391} 392 393/// Parse and verify a service auth JWT in one step, returning owned claims. 394/// 395/// This is a convenience function that combines parsing and signature verification. 396pub fn verify_service_jwt( 397 token: &str, 398 public_key: &PublicKey, 399) -> Result<ServiceAuthClaims<'static>, ServiceAuthError> { 400 let parsed = parse_jwt(token)?; 401 verify_signature(&parsed, public_key)?; 402 Ok(parsed.into_claims()) 403} 404 405#[cfg(test)] 406mod tests { 407 use super::*; 408 409 #[test] 410 fn test_parse_jwt_invalid_format() { 411 let result = parse_jwt("not.a.valid.jwt.with.too.many.parts"); 412 assert!(matches!(result, Err(ServiceAuthError::MalformedToken(_)))); 413 } 414 415 #[test] 416 fn test_claims_expiration() { 417 let now = chrono::Utc::now().timestamp(); 418 let expired_claims = ServiceAuthClaims { 419 iss: Did::new("did:plc:test").unwrap(), 420 aud: Did::new("did:web:example.com").unwrap(), 421 exp: now - 100, 422 iat: now - 200, 423 jti: None, 424 lxm: None, 425 }; 426 427 assert!(expired_claims.is_expired()); 428 429 let valid_claims = ServiceAuthClaims { 430 iss: Did::new("did:plc:test").unwrap(), 431 aud: Did::new("did:web:example.com").unwrap(), 432 exp: now + 100, 433 iat: now, 434 jti: None, 435 lxm: None, 436 }; 437 438 assert!(!valid_claims.is_expired()); 439 } 440 441 #[test] 442 fn test_audience_validation() { 443 let now = chrono::Utc::now().timestamp(); 444 let claims = ServiceAuthClaims { 445 iss: Did::new("did:plc:test").unwrap(), 446 aud: Did::new("did:web:example.com").unwrap(), 447 exp: now + 100, 448 iat: now, 449 jti: None, 450 lxm: None, 451 }; 452 453 let expected_aud = Did::new("did:web:example.com").unwrap(); 454 assert!(claims.validate(&expected_aud).is_ok()); 455 456 let wrong_aud = Did::new("did:web:wrong.com").unwrap(); 457 assert!(matches!( 458 claims.validate(&wrong_aud), 459 Err(ServiceAuthError::AudienceMismatch { .. }) 460 )); 461 } 462 463 #[test] 464 fn test_method_check() { 465 let claims = ServiceAuthClaims { 466 iss: Did::new("did:plc:test").unwrap(), 467 aud: Did::new("did:web:example.com").unwrap(), 468 exp: chrono::Utc::now().timestamp() + 100, 469 iat: chrono::Utc::now().timestamp(), 470 jti: None, 471 lxm: Some(Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap()), 472 }; 473 474 let expected = Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap(); 475 assert!(claims.check_method(&expected)); 476 477 let wrong = Nsid::new("app.bsky.feed.getTimeline").unwrap(); 478 assert!(!claims.check_method(&wrong)); 479 } 480}