1//! Service authentication JWT parsing and verification for AT Protocol.
2//!
//! Service auth is atproto's inter-service authentication mechanism. When a user's PDS
//! proxies a request to a backend service (feed generator, labeler, etc.), it signs a
//! short-lived JWT with the user's signing key and sends it as a Bearer token.
6//!
7//! # JWT Structure
8//!
9//! - Header: `alg` (ES256K for k256, ES256 for p256), `typ` ("JWT")
//! - Payload:
//!   - `iss`: user's DID (issuer)
//!   - `aud`: target service DID (audience)
//!   - `exp`: expiration unix timestamp
//!   - `iat`: issued-at unix timestamp
//!   - `jti` (optional): random nonce (128-bit hex) for replay protection
//!   - `lxm` (optional): lexicon method NSID (method binding)
17//! - Signature: signed with user's signing key from DID doc (ES256 or ES256K)
18
19use crate::CowStr;
20use crate::IntoStatic;
21use crate::types::string::{Did, Nsid};
22use base64::Engine;
23use base64::engine::general_purpose::URL_SAFE_NO_PAD;
24use ouroboros::self_referencing;
25use serde::{Deserialize, Serialize};
26use signature::Verifier;
27use smol_str::SmolStr;
28use smol_str::format_smolstr;
29use thiserror::Error;
30
31#[cfg(feature = "crypto-p256")]
32use p256::ecdsa::{Signature as P256Signature, VerifyingKey as P256VerifyingKey};
33
34#[cfg(feature = "crypto-k256")]
35use k256::ecdsa::{Signature as K256Signature, VerifyingKey as K256VerifyingKey};
36
37/// Errors that can occur during JWT parsing and verification.
38#[derive(Debug, Error, miette::Diagnostic)]
39pub enum ServiceAuthError {
40 /// JWT format is invalid (not three base64-encoded parts separated by dots)
41 #[error("malformed JWT: {0}")]
42 MalformedToken(CowStr<'static>),
43
44 /// Base64 decoding failed
45 #[error("base64 decode error: {0}")]
46 Base64Decode(#[from] base64::DecodeError),
47
48 /// JSON parsing failed
49 #[error("JSON parsing error: {0}")]
50 JsonParse(#[from] serde_json::Error),
51
52 /// Signature verification failed
53 #[error("invalid signature")]
54 InvalidSignature,
55
56 /// Unsupported algorithm
57 #[error("unsupported algorithm: {alg}")]
58 UnsupportedAlgorithm {
59 /// Algorithm name from JWT header
60 alg: SmolStr,
61 },
62
63 /// Token has expired
64 #[error("token expired at {exp} (current time: {now})")]
65 Expired {
66 /// Expiration timestamp from token
67 exp: i64,
68 /// Current timestamp
69 now: i64,
70 },
71
72 /// Audience mismatch
73 #[error("audience mismatch: expected {expected}, got {actual}")]
74 AudienceMismatch {
75 /// Expected audience DID
76 expected: Did<'static>,
77 /// Actual audience DID in token
78 actual: Did<'static>,
79 },
80
81 /// Method mismatch (lxm field)
82 #[error("method mismatch: expected {expected}, got {actual:?}")]
83 MethodMismatch {
84 /// Expected method NSID
85 expected: Nsid<'static>,
86 /// Actual method NSID in token (if any)
87 actual: Option<Nsid<'static>>,
88 },
89
90 /// Missing required field
91 #[error("missing required field: {0}")]
92 MissingField(&'static str),
93
94 /// Crypto error
95 #[error("crypto error: {0}")]
96 Crypto(CowStr<'static>),
97}
98
99/// JWT header for service auth tokens.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct JwtHeader<'a> {
102 /// Algorithm used for signing
103 #[serde(borrow)]
104 pub alg: CowStr<'a>,
105 /// Type (always "JWT")
106 #[serde(borrow)]
107 pub typ: CowStr<'a>,
108}
109
110impl IntoStatic for JwtHeader<'_> {
111 type Output = JwtHeader<'static>;
112
113 fn into_static(self) -> Self::Output {
114 JwtHeader {
115 alg: self.alg.into_static(),
116 typ: self.typ.into_static(),
117 }
118 }
119}
120
121/// Service authentication claims.
122///
123/// These are the payload fields in a service auth JWT.
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct ServiceAuthClaims<'a> {
126 /// Issuer (user's DID)
127 #[serde(borrow)]
128 pub iss: Did<'a>,
129
130 /// Audience (target service DID)
131 #[serde(borrow)]
132 pub aud: Did<'a>,
133
134 /// Expiration time (unix timestamp)
135 pub exp: i64,
136
137 /// Issued at (unix timestamp)
138 pub iat: i64,
139
140 /// JWT ID (nonce for replay protection)
141 #[serde(borrow, skip_serializing_if = "Option::is_none")]
142 pub jti: Option<CowStr<'a>>,
143
144 /// Lexicon method NSID (method binding)
145 #[serde(borrow, skip_serializing_if = "Option::is_none")]
146 pub lxm: Option<Nsid<'a>>,
147}
148
149impl<'a> IntoStatic for ServiceAuthClaims<'a> {
150 type Output = ServiceAuthClaims<'static>;
151
152 fn into_static(self) -> Self::Output {
153 ServiceAuthClaims {
154 iss: self.iss.into_static(),
155 aud: self.aud.into_static(),
156 exp: self.exp,
157 iat: self.iat,
158 jti: self.jti.map(|j| j.into_static()),
159 lxm: self.lxm.map(|l| l.into_static()),
160 }
161 }
162}
163
164impl<'a> ServiceAuthClaims<'a> {
165 /// Validate the claims against expected values.
166 ///
167 /// Checks:
168 /// - Audience matches expected DID
169 /// - Token is not expired
170 pub fn validate(&self, expected_aud: &Did) -> Result<(), ServiceAuthError> {
171 // Check audience
172 if self.aud.as_str() != expected_aud.as_str() {
173 return Err(ServiceAuthError::AudienceMismatch {
174 expected: expected_aud.clone().into_static(),
175 actual: self.aud.clone().into_static(),
176 });
177 }
178
179 // Check expiration
180 if self.is_expired() {
181 let now = chrono::Utc::now().timestamp();
182 return Err(ServiceAuthError::Expired { exp: self.exp, now });
183 }
184
185 Ok(())
186 }
187
188 /// Check if the token has expired.
189 pub fn is_expired(&self) -> bool {
190 let now = chrono::Utc::now().timestamp();
191 self.exp <= now
192 }
193
194 /// Check if the method (lxm) matches the expected NSID.
195 pub fn check_method(&self, nsid: &Nsid) -> bool {
196 self.lxm
197 .as_ref()
198 .map(|lxm| lxm.as_str() == nsid.as_str())
199 .unwrap_or(false)
200 }
201
202 /// Require that the method (lxm) matches the expected NSID.
203 pub fn require_method(&self, nsid: &Nsid) -> Result<(), ServiceAuthError> {
204 if !self.check_method(nsid) {
205 return Err(ServiceAuthError::MethodMismatch {
206 expected: nsid.clone().into_static(),
207 actual: self.lxm.as_ref().map(|l| l.clone().into_static()),
208 });
209 }
210 Ok(())
211 }
212}
213
/// Parsed JWT components.
///
/// This struct owns the decoded buffers and parsed components using ouroboros
/// self-referencing. The header and claims borrow from their respective buffers,
/// so string data in them is not copied a second time.
///
/// Besides the parsed values it keeps the complete original token text (from
/// which [`ParsedJwt::signing_input`] slices the `header.payload` prefix) and the
/// decoded signature bytes. Normally obtained through [`parse_jwt`], which only
/// builds it from tokens with exactly three dot-separated segments.
// NOTE: ouroboros requires a borrowing field to be declared after the field it
// borrows from, so `header`/`claims` must stay below `header_buf`/`payload_buf`.
#[self_referencing]
pub struct ParsedJwt {
    /// Decoded header buffer (owned): base64url-decoded first segment (JSON bytes)
    header_buf: Vec<u8>,
    /// Decoded payload buffer (owned): base64url-decoded second segment (JSON bytes)
    payload_buf: Vec<u8>,
    /// Original token string for signing_input (all three segments, undecoded)
    token: String,
    /// Signature bytes: base64url-decoded third segment
    signature: Vec<u8>,
    /// Parsed header borrowing from header_buf
    #[borrows(header_buf)]
    #[covariant]
    header: JwtHeader<'this>,
    /// Parsed claims borrowing from payload_buf
    #[borrows(payload_buf)]
    #[covariant]
    claims: ServiceAuthClaims<'this>,
}
237
238impl ParsedJwt {
239 /// Get the signing input (header.payload) for signature verification.
240 pub fn signing_input(&self) -> &[u8] {
241 self.with_token(|token| {
242 let dot_pos = token.find('.').unwrap();
243 let second_dot_pos = token[dot_pos + 1..].find('.').unwrap() + dot_pos + 1;
244 token[..second_dot_pos].as_bytes()
245 })
246 }
247
248 /// Get a reference to the header.
249 pub fn header(&self) -> &JwtHeader<'_> {
250 self.borrow_header()
251 }
252
253 /// Get a reference to the claims.
254 pub fn claims(&self) -> &ServiceAuthClaims<'_> {
255 self.borrow_claims()
256 }
257
258 /// Get a reference to the signature.
259 pub fn signature(&self) -> &[u8] {
260 self.borrow_signature()
261 }
262
263 /// Get owned header with 'static lifetime.
264 pub fn into_header(self) -> JwtHeader<'static> {
265 self.with_header(|header| header.clone().into_static())
266 }
267
268 /// Get owned claims with 'static lifetime.
269 pub fn into_claims(self) -> ServiceAuthClaims<'static> {
270 self.with_claims(|claims| claims.clone().into_static())
271 }
272}
273
274/// Parse a JWT token into its components without verifying the signature.
275///
276/// This extracts and decodes all JWT components. The header and claims are parsed
277/// and borrow from their respective owned buffers using ouroboros self-referencing.
278pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> {
279 let parts: Vec<&str> = token.split('.').collect();
280 if parts.len() != 3 {
281 return Err(ServiceAuthError::MalformedToken(CowStr::new_static(
282 "JWT must have exactly 3 parts separated by dots",
283 )));
284 }
285
286 let header_b64 = parts[0];
287 let payload_b64 = parts[1];
288 let signature_b64 = parts[2];
289
290 // Decode all components
291 let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?;
292 let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?;
293 let signature = URL_SAFE_NO_PAD.decode(signature_b64)?;
294
295 // Validate that buffers contain valid JSON for their types
296 // We parse once here to validate, then again in the builder (unavoidable with ouroboros)
297 let _header: JwtHeader = serde_json::from_slice(&header_buf)?;
298 let _claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?;
299
300 Ok(ParsedJwtBuilder {
301 header_buf,
302 payload_buf,
303 token: token.to_string(),
304 signature,
305 header_builder: |buf| {
306 // Safe: we validated this succeeds above
307 serde_json::from_slice(buf).expect("header was validated")
308 },
309 claims_builder: |buf| {
310 // Safe: we validated this succeeds above
311 serde_json::from_slice(buf).expect("claims were validated")
312 },
313 }
314 .build())
315}
316
/// A verifying key on one of the supported curves.
///
/// The available variants depend on the enabled crate features
/// (`crypto-p256` and/or `crypto-k256`).
#[derive(Debug, Clone)]
pub enum PublicKey {
    /// NIST P-256 key; pairs with the `ES256` JWT algorithm.
    #[cfg(feature = "crypto-p256")]
    P256(P256VerifyingKey),

    /// secp256k1 key; pairs with the `ES256K` JWT algorithm.
    #[cfg(feature = "crypto-k256")]
    K256(K256VerifyingKey),
}
328
329impl PublicKey {
330 /// Create a P-256 public key from compressed or uncompressed bytes.
331 #[cfg(feature = "crypto-p256")]
332 pub fn from_p256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
333 let key = P256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| {
334 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid P-256 key: {}", e)))
335 })?;
336 Ok(PublicKey::P256(key))
337 }
338
339 /// Create a secp256k1 public key from compressed or uncompressed bytes.
340 #[cfg(feature = "crypto-k256")]
341 pub fn from_k256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
342 let key = K256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| {
343 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid K-256 key: {}", e)))
344 })?;
345 Ok(PublicKey::K256(key))
346 }
347}
348
349/// Verify a JWT signature using the provided public key.
350///
351/// The algorithm is determined by the JWT header and must match the public key type.
352pub fn verify_signature(
353 parsed: &ParsedJwt,
354 public_key: &PublicKey,
355) -> Result<(), ServiceAuthError> {
356 let alg = parsed.header().alg.as_str();
357 let signing_input = parsed.signing_input();
358 let signature = parsed.signature();
359
360 match (alg, public_key) {
361 #[cfg(feature = "crypto-p256")]
362 ("ES256", PublicKey::P256(key)) => {
363 let sig = P256Signature::from_slice(signature).map_err(|e| {
364 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
365 "invalid ES256 signature: {}",
366 e
367 )))
368 })?;
369 key.verify(signing_input, &sig)
370 .map_err(|_| ServiceAuthError::InvalidSignature)?;
371 Ok(())
372 }
373
374 #[cfg(feature = "crypto-k256")]
375 ("ES256K", PublicKey::K256(key)) => {
376 let sig = K256Signature::from_slice(signature).map_err(|e| {
377 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
378 "invalid ES256K signature: {}",
379 e
380 )))
381 })?;
382 key.verify(signing_input, &sig)
383 .map_err(|_| ServiceAuthError::InvalidSignature)?;
384 Ok(())
385 }
386
387 _ => Err(ServiceAuthError::UnsupportedAlgorithm {
388 alg: SmolStr::new(alg),
389 }),
390 }
391}
392
393/// Parse and verify a service auth JWT in one step, returning owned claims.
394///
395/// This is a convenience function that combines parsing and signature verification.
396pub fn verify_service_jwt(
397 token: &str,
398 public_key: &PublicKey,
399) -> Result<ServiceAuthClaims<'static>, ServiceAuthError> {
400 let parsed = parse_jwt(token)?;
401 verify_signature(&parsed, public_key)?;
402 Ok(parsed.into_claims())
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Build claims issued by `did:plc:test` for `did:web:example.com` with the
    /// given expiry, issued-at time, and optional method binding.
    fn make_claims(exp: i64, iat: i64, lxm: Option<&'static str>) -> ServiceAuthClaims<'static> {
        ServiceAuthClaims {
            iss: Did::new("did:plc:test").unwrap(),
            aud: Did::new("did:web:example.com").unwrap(),
            exp,
            iat,
            jti: None,
            lxm: lxm.map(|m| Nsid::new(m).unwrap()),
        }
    }

    /// Encode raw bytes as an unpadded base64url JWT segment.
    fn segment(raw: &[u8]) -> String {
        URL_SAFE_NO_PAD.encode(raw)
    }

    #[test]
    fn test_parse_jwt_invalid_format() {
        let result = parse_jwt("not.a.valid.jwt.with.too.many.parts");
        assert!(matches!(result, Err(ServiceAuthError::MalformedToken(_))));

        let result = parse_jwt("only.two");
        assert!(matches!(result, Err(ServiceAuthError::MalformedToken(_))));
    }

    #[test]
    fn test_parse_jwt_invalid_base64() {
        let result = parse_jwt("!!!.e30.e30");
        assert!(matches!(result, Err(ServiceAuthError::Base64Decode(_))));
    }

    #[test]
    fn test_parse_jwt_invalid_json() {
        let token = format!(
            "{}.{}.{}",
            segment(b"not json"),
            segment(b"{}"),
            segment(b"sig")
        );
        assert!(matches!(
            parse_jwt(&token),
            Err(ServiceAuthError::JsonParse(_))
        ));
    }

    #[test]
    fn test_parse_jwt_roundtrip() {
        let header = segment(br#"{"alg":"ES256K","typ":"JWT"}"#);
        let payload = segment(
            br#"{"iss":"did:plc:test","aud":"did:web:example.com","exp":2000000000,"iat":1000000000,"lxm":"app.bsky.feed.getFeedSkeleton"}"#,
        );
        let signature = segment(b"not-a-real-signature");
        let token = format!("{}.{}.{}", header, payload, signature);

        let parsed = parse_jwt(&token).expect("well-formed token should parse");

        assert_eq!(parsed.header().alg.as_str(), "ES256K");
        assert_eq!(parsed.header().typ.as_str(), "JWT");
        assert_eq!(parsed.claims().iss.as_str(), "did:plc:test");
        assert_eq!(parsed.claims().aud.as_str(), "did:web:example.com");
        assert_eq!(parsed.claims().exp, 2_000_000_000);
        assert_eq!(parsed.claims().iat, 1_000_000_000);
        assert!(parsed.claims().jti.is_none());
        assert_eq!(
            parsed.claims().lxm.as_ref().map(|l| l.as_str()),
            Some("app.bsky.feed.getFeedSkeleton")
        );
        assert_eq!(
            parsed.signing_input(),
            format!("{}.{}", header, payload).as_bytes()
        );
        assert_eq!(parsed.signature(), &b"not-a-real-signature"[..]);

        let owned = parsed.into_claims();
        assert_eq!(owned.iss.as_str(), "did:plc:test");
    }

    #[test]
    fn test_claims_expiration() {
        let now = chrono::Utc::now().timestamp();

        assert!(make_claims(now - 100, now - 200, None).is_expired());
        assert!(!make_claims(now + 100, now, None).is_expired());
    }

    #[test]
    fn test_audience_validation() {
        let now = chrono::Utc::now().timestamp();
        let claims = make_claims(now + 100, now, None);

        let expected_aud = Did::new("did:web:example.com").unwrap();
        assert!(claims.validate(&expected_aud).is_ok());

        let wrong_aud = Did::new("did:web:wrong.com").unwrap();
        assert!(matches!(
            claims.validate(&wrong_aud),
            Err(ServiceAuthError::AudienceMismatch { .. })
        ));

        let expired = make_claims(now - 100, now - 200, None);
        assert!(matches!(
            expired.validate(&expected_aud),
            Err(ServiceAuthError::Expired { .. })
        ));
    }

    #[test]
    fn test_method_check() {
        let now = chrono::Utc::now().timestamp();
        let claims = make_claims(now + 100, now, Some("app.bsky.feed.getFeedSkeleton"));

        let expected = Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap();
        assert!(claims.check_method(&expected));
        assert!(claims.require_method(&expected).is_ok());

        let wrong = Nsid::new("app.bsky.feed.getTimeline").unwrap();
        assert!(!claims.check_method(&wrong));
        assert!(matches!(
            claims.require_method(&wrong),
            Err(ServiceAuthError::MethodMismatch { .. })
        ));

        let unbound = make_claims(now + 100, now, None);
        assert!(!unbound.check_method(&expected));
    }
}