1//! Service authentication extractor and middleware
2//!
3//! # Example
4//!
5//! ```no_run
6//! use axum::{Router, routing::get};
7//! use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth};
8//! use jacquard_identity::JacquardResolver;
9//! use jacquard_identity::resolver::ResolverOptions;
10//! use jacquard_common::types::string::Did;
11//!
12//! async fn handler(
13//! ExtractServiceAuth(auth): ExtractServiceAuth,
14//! ) -> String {
15//! format!("Authenticated as {}", auth.did())
16//! }
17//!
18//! #[tokio::main]
19//! async fn main() {
20//! let resolver = JacquardResolver::new(
21//! reqwest::Client::new(),
22//! ResolverOptions::default(),
23//! );
24//! let config = ServiceAuthConfig::new(
25//! Did::new_static("did:web:feedgen.example.com").unwrap(),
26//! resolver,
27//! );
28//!
29//! let app = Router::new()
30//! .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler))
31//! .with_state(config);
32//!
33//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
34//! .await
35//! .unwrap();
36//! axum::serve(listener, app).await.unwrap();
37//! }
38//! ```
39
40use axum::{
41 Json,
42 extract::FromRequestParts,
43 http::{HeaderValue, StatusCode, header, request::Parts},
44 middleware::Next,
45 response::{IntoResponse, Response},
46};
47use jacquard_common::{
48 CowStr, IntoStatic,
49 service_auth::{self, PublicKey},
50 types::{
51 did_doc::VerificationMethod,
52 string::{Did, Nsid},
53 },
54};
55use jacquard_identity::resolver::IdentityResolver;
56use serde_json::json;
57use std::sync::Arc;
58use thiserror::Error;
59
/// Trait for providing service authentication configuration.
///
/// This trait allows custom state types to provide service auth configuration
/// without requiring `ServiceAuthConfig<R>` directly. Any state type that
/// implements it can be used with `ExtractServiceAuth` and
/// `service_auth_middleware`.
pub trait ServiceAuth {
    /// The identity resolver type used to fetch the issuer's DID document.
    type Resolver: IdentityResolver;

    /// Get the service DID (the audience that incoming tokens are validated against).
    fn service_did(&self) -> &Did<'_>;

    /// Get a reference to the identity resolver.
    fn resolver(&self) -> &Self::Resolver;

    /// Whether to require the `lxm` (method binding) field.
    ///
    /// When this returns `true`, the extractor rejects tokens whose `lxm`
    /// claim is absent.
    fn require_lxm(&self) -> bool;
}
77
/// Configuration for service auth verification.
///
/// This should be stored in your Axum app state and will be read
/// by the `ExtractServiceAuth` extractor.
///
/// The resolver is held behind an `Arc`, so cloning the config shares the
/// same resolver instance rather than requiring `R: Clone`.
pub struct ServiceAuthConfig<R> {
    /// The DID of your service (the expected audience)
    service_did: Did<'static>,
    /// Identity resolver for fetching DID documents (shared between clones)
    resolver: Arc<R>,
    /// Whether to require the `lxm` (method binding) field
    require_lxm: bool,
}
90
91impl<R> Clone for ServiceAuthConfig<R> {
92 fn clone(&self) -> Self {
93 Self {
94 service_did: self.service_did.clone(),
95 resolver: Arc::clone(&self.resolver),
96 require_lxm: self.require_lxm,
97 }
98 }
99}
100
101impl<R: IdentityResolver> ServiceAuthConfig<R> {
102 /// Create a new service auth config.
103 ///
104 /// This enables `lxm` (method binding). If you need backward compatibility,
105 /// use `ServiceAuthConfig::new_legacy()`
106 pub fn new(service_did: Did<'static>, resolver: R) -> Self {
107 Self {
108 service_did,
109 resolver: Arc::new(resolver),
110 require_lxm: true,
111 }
112 }
113
114 /// Create a new service auth config.
115 ///
116 /// `lxm` (method binding) is disabled for backwards compatibility
117 pub fn new_legacy(service_did: Did<'static>, resolver: R) -> Self {
118 Self {
119 service_did,
120 resolver: Arc::new(resolver),
121 require_lxm: false,
122 }
123 }
124
125 /// Set whether to require the `lxm` field (method binding).
126 ///
127 /// When enabled, the JWT must contain an `lxm` field matching the requested endpoint.
128 /// This prevents token reuse across different methods.
129 pub fn require_lxm(mut self, require: bool) -> Self {
130 self.require_lxm = require;
131 self
132 }
133
134 /// Get the service DID.
135 pub fn service_did(&self) -> &Did<'static> {
136 &self.service_did
137 }
138
139 /// Get a reference to the identity resolver.
140 pub fn resolver(&self) -> &R {
141 &self.resolver
142 }
143}
144
145impl<R: IdentityResolver> ServiceAuth for ServiceAuthConfig<R> {
146 type Resolver = R;
147
148 fn service_did(&self) -> &Did<'_> {
149 &self.service_did
150 }
151
152 fn resolver(&self) -> &Self::Resolver {
153 &self.resolver
154 }
155
156 fn require_lxm(&self) -> bool {
157 self.require_lxm
158 }
159}
160
/// Verified service authentication information.
///
/// This is the result of successfully verifying a service auth JWT.
/// This type is produced by the `ExtractServiceAuth` extractor and is also
/// inserted into request extensions by `service_auth_middleware`.
#[derive(Debug, Clone, jacquard_derive::IntoStatic)]
pub struct VerifiedServiceAuth<'a> {
    /// The authenticated user's DID (from `iss` claim)
    did: Did<'a>,
    /// The audience (from the `aud` claim; should match your service DID)
    aud: Did<'a>,
    /// The lexicon method NSID (from the `lxm` claim), if present
    lxm: Option<Nsid<'a>>,
    /// JWT ID (nonce, from the `jti` claim), if present
    jti: Option<CowStr<'a>>,
}
176
177impl<'a> VerifiedServiceAuth<'a> {
178 /// Get the authenticated user's DID.
179 pub fn did(&self) -> &Did<'a> {
180 &self.did
181 }
182
183 /// Get the audience (your service DID).
184 pub fn aud(&self) -> &Did<'a> {
185 &self.aud
186 }
187
188 /// Get the lexicon method NSID, if present.
189 pub fn lxm(&self) -> Option<&Nsid<'a>> {
190 self.lxm.as_ref()
191 }
192
193 /// Get the JWT ID (nonce), if present.
194 ///
195 /// You can use this for replay protection by tracking seen JTIs
196 /// until their expiration time.
197 pub fn jti(&self) -> Option<&str> {
198 self.jti.as_ref().map(|j| j.as_ref())
199 }
200}
201
/// Axum extractor for service authentication.
///
/// This extracts and verifies a service auth JWT from the Authorization header,
/// resolving the issuer's DID to verify the signature.
///
/// The router state must implement `ServiceAuth` (for example
/// `ServiceAuthConfig`). On failure the request is rejected with a
/// `ServiceAuthError`.
///
/// # Example
///
/// ```no_run
/// use axum::{Router, routing::get};
/// use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth};
/// use jacquard_identity::JacquardResolver;
/// use jacquard_identity::resolver::ResolverOptions;
/// use jacquard_common::types::string::Did;
///
/// async fn handler(
///     ExtractServiceAuth(auth): ExtractServiceAuth,
/// ) -> String {
///     format!("Authenticated as {}", auth.did())
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let resolver = JacquardResolver::new(
///         reqwest::Client::new(),
///         ResolverOptions::default(),
///     );
///     let config = ServiceAuthConfig::new(
///         Did::new_static("did:web:feedgen.example.com").unwrap(),
///         resolver,
///     );
///
///     let app = Router::new()
///         .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler))
///         .with_state(config);
///
///     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
///         .await
///         .unwrap();
///     axum::serve(listener, app).await.unwrap();
/// }
/// ```
pub struct ExtractServiceAuth(pub VerifiedServiceAuth<'static>);
244
/// Errors that can occur during service auth verification.
///
/// This is the rejection type of `ExtractServiceAuth` and the error type of
/// `service_auth_middleware`.
#[derive(Debug, Error, miette::Diagnostic)]
pub enum ServiceAuthError {
    /// Authorization header is missing
    #[error("missing Authorization header")]
    MissingAuthHeader,

    /// Authorization header is malformed (not "Bearer `token`")
    #[error("invalid Authorization header format")]
    InvalidAuthHeader,

    /// JWT parsing or verification failed
    #[error("JWT verification failed: {0}")]
    JwtError(#[from] service_auth::ServiceAuthError),

    /// DID resolution failed (either fetching or parsing the DID document)
    #[error("failed to resolve DID {did}: {source}")]
    DidResolutionFailed {
        /// The DID that could not be resolved (the token's issuer)
        did: Did<'static>,
        /// The underlying resolver or document-parsing error
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// No valid signing key found in DID document
    #[error("no valid signing key found in DID document for {0}")]
    NoSigningKey(Did<'static>),

    /// Method binding required but missing
    #[error("lxm (method binding) is required but missing from token")]
    MethodBindingRequired,

    /// Invalid key format
    #[error("invalid key format: {0}")]
    InvalidKey(String),
}
280
281impl IntoResponse for ServiceAuthError {
282 fn into_response(self) -> Response {
283 let (status, error_code, message) = match &self {
284 ServiceAuthError::MissingAuthHeader => {
285 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string())
286 }
287 ServiceAuthError::InvalidAuthHeader => {
288 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string())
289 }
290 ServiceAuthError::JwtError(_) => (
291 StatusCode::UNAUTHORIZED,
292 "AuthenticationRequired",
293 self.to_string(),
294 ),
295 ServiceAuthError::DidResolutionFailed { .. } => (
296 StatusCode::UNAUTHORIZED,
297 "AuthenticationRequired",
298 self.to_string(),
299 ),
300 ServiceAuthError::NoSigningKey(_) => (
301 StatusCode::UNAUTHORIZED,
302 "AuthenticationRequired",
303 self.to_string(),
304 ),
305 ServiceAuthError::MethodBindingRequired => (
306 StatusCode::UNAUTHORIZED,
307 "AuthenticationRequired",
308 self.to_string(),
309 ),
310 ServiceAuthError::InvalidKey(_) => (
311 StatusCode::UNAUTHORIZED,
312 "AuthenticationRequired",
313 self.to_string(),
314 ),
315 };
316
317 tracing::warn!("Service auth failed: {}", message);
318
319 (
320 status,
321 [(
322 header::CONTENT_TYPE,
323 HeaderValue::from_static("application/json"),
324 )],
325 Json(json!({
326 "error": error_code,
327 "message": message,
328 })),
329 )
330 .into_response()
331 }
332}
333
334impl<S> FromRequestParts<S> for ExtractServiceAuth
335where
336 S: ServiceAuth + Send + Sync,
337 S::Resolver: Send + Sync,
338{
339 type Rejection = ServiceAuthError;
340
341 fn from_request_parts(
342 parts: &mut Parts,
343 state: &S,
344 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
345 async move {
346 // Extract Authorization header
347 let auth_header = parts
348 .headers
349 .get(header::AUTHORIZATION)
350 .ok_or(ServiceAuthError::MissingAuthHeader)?;
351
352 // Parse Bearer token
353 let auth_str = auth_header
354 .to_str()
355 .map_err(|_| ServiceAuthError::InvalidAuthHeader)?;
356
357 let token = auth_str
358 .strip_prefix("Bearer ")
359 .ok_or(ServiceAuthError::InvalidAuthHeader)?;
360
361 // Parse JWT
362 let parsed = service_auth::parse_jwt(token)?;
363
364 // Get claims for DID resolution
365 let claims = parsed.claims();
366
367 // Resolve DID to get signing key (do this before checking claims)
368 let did_doc = state
369 .resolver()
370 .resolve_did_doc(&claims.iss)
371 .await
372 .map_err(|e| ServiceAuthError::DidResolutionFailed {
373 did: claims.iss.clone().into_static(),
374 source: Box::new(e),
375 })?;
376
377 // Parse the DID document response to get verification methods
378 let doc = did_doc
379 .parse()
380 .map_err(|e| ServiceAuthError::DidResolutionFailed {
381 did: claims.iss.clone().into_static(),
382 source: Box::new(e),
383 })?;
384
385 // Extract signing key from DID document
386 let verification_methods = doc
387 .verification_method
388 .as_deref()
389 .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?;
390
391 let signing_key = extract_signing_key(verification_methods)
392 .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?;
393
394 // Verify signature FIRST - if this fails, nothing else matters
395 service_auth::verify_signature(&parsed, &signing_key)?;
396
397 // Now validate claims (audience, expiration, etc.)
398 claims.validate(state.service_did())?;
399
400 // Check method binding if required
401 if state.require_lxm() && claims.lxm.is_none() {
402 return Err(ServiceAuthError::MethodBindingRequired);
403 }
404
405 // All checks passed - return verified auth
406 Ok(ExtractServiceAuth(VerifiedServiceAuth {
407 did: claims.iss.clone().into_static(),
408 aud: claims.aud.clone().into_static(),
409 lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()),
410 jti: claims.jti.as_ref().map(|j| j.clone().into_static()),
411 }))
412 }
413 }
414}
415
416/// Extract the signing key from a DID document's verification methods.
417///
418/// This looks for a key with type "atproto" or the first available key
419/// if no atproto-specific key is found.
420fn extract_signing_key(methods: &[VerificationMethod]) -> Option<PublicKey> {
421 // First try to find an atproto-specific key
422 let atproto_method = methods
423 .iter()
424 .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto");
425
426 let method = atproto_method.or_else(|| methods.first())?;
427
428 // Parse the multikey
429 let public_key_multibase = method.public_key_multibase.as_ref()?;
430
431 // Decode multibase
432 let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?;
433
434 // First two bytes are the multicodec prefix
435 if key_bytes.len() < 2 {
436 return None;
437 }
438
439 let codec = &key_bytes[..2];
440 let key_material = &key_bytes[2..];
441
442 match codec {
443 // p256-pub (0x1200)
444 [0x80, 0x24] => PublicKey::from_p256_bytes(key_material).ok(),
445 // secp256k1-pub (0xe7)
446 [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material).ok(),
447 _ => None,
448 }
449}
450
451/// Middleware for verifying service authentication on all requests.
452///
453/// This middleware extracts and verifies the service auth JWT, then adds the
454/// `VerifiedServiceAuth` to request extensions for downstream handlers to access.
455///
456/// # Example
457///
458/// ```no_run
459/// use axum::{Router, routing::get, middleware, Extension};
460/// use jacquard_axum::service_auth::{ServiceAuthConfig, service_auth_middleware};
461/// use jacquard_identity::JacquardResolver;
462/// use jacquard_identity::resolver::ResolverOptions;
463/// use jacquard_common::types::string::Did;
464///
465/// async fn handler(
466/// Extension(auth): Extension<jacquard_axum::service_auth::VerifiedServiceAuth<'static>>,
467/// ) -> String {
468/// format!("Authenticated as {}", auth.did())
469/// }
470///
471/// #[tokio::main]
472/// async fn main() {
473/// let resolver = JacquardResolver::new(
474/// reqwest::Client::new(),
475/// ResolverOptions::default(),
476/// );
477/// let config = ServiceAuthConfig::new(
478/// Did::new_static("did:web:feedgen.example.com").unwrap(),
479/// resolver,
480/// );
481///
482/// let app = Router::new()
483/// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler))
484/// .layer(middleware::from_fn_with_state(
485/// config.clone(),
486/// service_auth_middleware::<ServiceAuthConfig<JacquardResolver>>,
487/// ))
488/// .with_state(config);
489///
490/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
491/// .await
492/// .unwrap();
493/// axum::serve(listener, app).await.unwrap();
494/// }
495/// ```
496pub async fn service_auth_middleware<S>(
497 state: axum::extract::State<S>,
498 mut req: axum::extract::Request,
499 next: Next,
500) -> Result<Response, ServiceAuthError>
501where
502 S: ServiceAuth + Send + Sync + Clone,
503 S::Resolver: Send + Sync,
504{
505 // Extract auth from request parts
506 let (mut parts, body) = req.into_parts();
507 let ExtractServiceAuth(auth) =
508 ExtractServiceAuth::from_request_parts(&mut parts, &state.0).await?;
509
510 // Add auth to extensions
511 parts.extensions.insert(auth);
512
513 // Reconstruct request and continue
514 req = axum::extract::Request::from_parts(parts, body);
515 Ok(next.run(req).await)
516}