A community-based topic aggregation platform built on atproto
1package auth
2
3import (
4 "crypto/sha256"
5 "encoding/base64"
6 "encoding/json"
7 "fmt"
8 "strings"
9 "sync"
10 "time"
11
12 indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
13 "github.com/golang-jwt/jwt/v5"
14)
15
// NonceCache provides replay protection for DPoP proofs by tracking seen jti values.
// This prevents an attacker from reusing a captured DPoP proof within the validity window.
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks.
//
// Create a NonceCache with NewNonceCache; it starts a background cleanup
// goroutine that runs until Stop is called. Every method that touches the
// entry map takes the internal mutex.
type NonceCache struct {
	seen    map[string]time.Time // jti -> expiration time
	stopCh  chan struct{}        // closed by Stop to end cleanupLoop
	maxAge  time.Duration        // How long to keep entries
	cleanup time.Duration        // How often to clean up expired entries (always > 0)
	mu      sync.RWMutex         // guards seen and the closing of stopCh
}

// NewNonceCache creates a new nonce cache for DPoP replay protection and
// starts its background cleanup goroutine.
//
// maxAge is how long a jti is remembered. It should cover the whole window in
// which a proof can still be accepted, i.e. at least DPoPVerifier.MaxProofAge
// plus DPoPVerifier.MaxClockSkew. With a non-positive maxAge entries expire
// immediately, so no replay is ever detected.
func NewNonceCache(maxAge time.Duration) *NonceCache {
	// time.NewTicker panics on a non-positive interval, and maxAge/2 is
	// non-positive for any maxAge below 2ns. Because the ticker is created in
	// the background goroutine, that panic would take down the whole process,
	// so fall back to a fixed interval instead.
	const fallbackCleanup = time.Second

	cleanup := maxAge / 2 // Clean up at half the max age
	if cleanup <= 0 {
		cleanup = fallbackCleanup
	}

	nc := &NonceCache{
		seen:    make(map[string]time.Time),
		maxAge:  maxAge,
		cleanup: cleanup,
		stopCh:  make(chan struct{}),
	}

	// Start background cleanup goroutine; it exits when Stop closes stopCh.
	go nc.cleanupLoop()

	return nc
}

// CheckAndStore checks if a jti has been seen before and stores it if not.
// Returns true if the jti is fresh (not a replay), false if it's a replay.
// A jti whose previous entry has expired counts as fresh and is stored again
// with a new expiry of now + maxAge.
func (nc *NonceCache) CheckAndStore(jti string) bool {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	now := time.Now()

	// An entry that has not yet expired means this jti is being replayed.
	if expiry, seen := nc.seen[jti]; seen && expiry.After(now) {
		return false
	}

	// First sighting, or the old entry expired: (re)record the jti.
	nc.seen[jti] = now.Add(nc.maxAge)
	return true
}

// cleanupLoop removes expired entries every nc.cleanup until Stop is called.
func (nc *NonceCache) cleanupLoop() {
	ticker := time.NewTicker(nc.cleanup)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			nc.cleanupExpired()
		case <-nc.stopCh:
			return
		}
	}
}

// cleanupExpired removes expired entries from the cache.
func (nc *NonceCache) cleanupExpired() {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	now := time.Now()
	for jti, expiry := range nc.seen {
		// Same liveness rule as CheckAndStore: an entry is live only while
		// its expiry is strictly after now.
		if !expiry.After(now) {
			delete(nc.seen, jti)
		}
	}
}

// Stop stops the cleanup goroutine. Call this when done with the cache.
// Stop is idempotent: calling it more than once is a no-op rather than a
// double-close panic.
func (nc *NonceCache) Stop() {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	select {
	case <-nc.stopCh:
		// Already closed by an earlier Stop.
	default:
		close(nc.stopCh)
	}
}

// Size returns the number of entries in the cache (for testing/monitoring).
// Expired entries that have not been cleaned up yet are included.
func (nc *NonceCache) Size() int {
	nc.mu.RLock()
	defer nc.mu.RUnlock()
	return len(nc.seen)
}
105
// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449).
// It embeds the registered JWT claims (jti, iat, exp, nbf, ...) and adds the
// DPoP-specific htm, htu and ath claims.
type DPoPClaims struct {
	jwt.RegisteredClaims

	// HTTPMethod is the HTTP method of the request (e.g., "GET", "POST"),
	// carried in the "htm" claim.
	HTTPMethod string `json:"htm"`

	// HTTPURI is the HTTP URI of the request (without query and fragment
	// parts), carried in the "htu" claim.
	HTTPURI string `json:"htu"`

	// AccessTokenHash is the access token hash from the "ath" claim
	// (optional, for token binding); empty when the claim is absent.
	AccessTokenHash string `json:"ath,omitempty"`
}
119
// DPoPProof represents a parsed and verified DPoP proof, as returned by
// DPoPVerifier.VerifyDPoPProof.
type DPoPProof struct {
	// RawPublicJWK is the "jwk" object exactly as decoded from the proof's JOSE header.
	RawPublicJWK map[string]interface{}
	// Claims holds the claims parsed from the proof's payload.
	Claims *DPoPClaims
	// PublicKey is the key parsed from RawPublicJWK. VerifyDPoPProof stores
	// the indigo atcrypto PublicKey it verified the signature with.
	PublicKey interface{}
	// Thumbprint is the JWK thumbprint (base64url, unpadded) computed by
	// CalculateJWKThumbprint.
	Thumbprint string
}
127
// DPoPVerifier verifies DPoP proofs for OAuth token binding.
// Construct one with NewDPoPVerifier (replay protection enabled) or
// NewDPoPVerifierWithoutReplayProtection, and call Stop on shutdown.
type DPoPVerifier struct {
	// Optional: custom nonce validation function (for server-issued nonces).
	// NOTE(review): no code visible in this file calls ValidateNonce or parses
	// a "nonce" claim — confirm where server-issued nonces are enforced.
	ValidateNonce func(nonce string) bool

	// NonceCache for replay protection (optional but recommended).
	// If nil, jti replay protection is disabled.
	NonceCache *NonceCache

	// Maximum allowed clock skew for timestamp validation. It bounds how far
	// in the future iat may be and is applied as tolerance to exp and nbf.
	MaxClockSkew time.Duration

	// Maximum age of DPoP proof (prevents replay with old proofs): a proof is
	// rejected when now - iat exceeds this value.
	MaxProofAge time.Duration
}
143
144// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
145func NewDPoPVerifier() *DPoPVerifier {
146 maxProofAge := 5 * time.Minute
147 return &DPoPVerifier{
148 MaxClockSkew: 30 * time.Second,
149 MaxProofAge: maxProofAge,
150 NonceCache: NewNonceCache(maxProofAge),
151 }
152}
153
154// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
155// This should only be used in testing or when replay protection is handled externally.
156func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
157 return &DPoPVerifier{
158 MaxClockSkew: 30 * time.Second,
159 MaxProofAge: 5 * time.Minute,
160 NonceCache: nil, // No replay protection
161 }
162}
163
164// Stop stops background goroutines. Call this when shutting down.
165func (v *DPoPVerifier) Stop() {
166 if v.NonceCache != nil {
167 v.NonceCache.Stop()
168 }
169}
170
171// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof.
172// This supports all atProto-compatible ECDSA algorithms including ES256K (secp256k1).
173func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
174 // Manually parse the JWT to support ES256K (which golang-jwt doesn't recognize)
175 header, claims, err := parseJWTHeaderAndClaims(dpopProof)
176 if err != nil {
177 return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
178 }
179
180 // Extract and validate the typ header
181 typ, ok := header["typ"].(string)
182 if !ok || typ != "dpop+jwt" {
183 return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", typ)
184 }
185
186 alg, ok := header["alg"].(string)
187 if !ok {
188 return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
189 }
190
191 // Extract the JWK from the header first (needed for algorithm-curve validation)
192 jwkRaw, ok := header["jwk"]
193 if !ok {
194 return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
195 }
196
197 jwkMap, ok := jwkRaw.(map[string]interface{})
198 if !ok {
199 return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
200 }
201
202 // Validate the algorithm is supported and matches the JWK curve
203 // This is critical for security - prevents algorithm confusion attacks
204 if err := validateAlgorithmCurveBinding(alg, jwkMap); err != nil {
205 return nil, fmt.Errorf("invalid DPoP proof: %w", err)
206 }
207
208 // Parse the public key using indigo's crypto package
209 // This supports all atProto curves including secp256k1 (ES256K)
210 publicKey, err := parseJWKToIndigoPublicKey(jwkMap)
211 if err != nil {
212 return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
213 }
214
215 // Calculate the JWK thumbprint
216 thumbprint, err := CalculateJWKThumbprint(jwkMap)
217 if err != nil {
218 return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
219 }
220
221 // Verify the signature using indigo's crypto package
222 // This works for all ECDSA algorithms including ES256K
223 if err := verifyJWTSignatureWithIndigo(dpopProof, publicKey); err != nil {
224 return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
225 }
226
227 // Validate the claims
228 if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
229 return nil, err
230 }
231
232 return &DPoPProof{
233 Claims: claims,
234 PublicKey: publicKey,
235 Thumbprint: thumbprint,
236 RawPublicJWK: jwkMap,
237 }, nil
238}
239
240// validateDPoPClaims validates the DPoP proof claims
241func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
242 // Validate jti (unique identifier) is present
243 if claims.ID == "" {
244 return fmt.Errorf("DPoP proof missing jti claim")
245 }
246
247 // Validate htm (HTTP method)
248 if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
249 return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
250 }
251
252 // Validate htu (HTTP URI) - compare without query/fragment
253 expectedURIBase := stripQueryFragment(expectedURI)
254 claimURIBase := stripQueryFragment(claims.HTTPURI)
255 if expectedURIBase != claimURIBase {
256 return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
257 }
258
259 // Validate iat (issued at) is present and recent
260 if claims.IssuedAt == nil {
261 return fmt.Errorf("DPoP proof missing iat claim")
262 }
263
264 now := time.Now()
265 iat := claims.IssuedAt.Time
266
267 // Check clock skew (not too far in the future)
268 if iat.After(now.Add(v.MaxClockSkew)) {
269 return fmt.Errorf("DPoP proof iat is in the future")
270 }
271
272 // Check proof age (not too old)
273 if now.Sub(iat) > v.MaxProofAge {
274 return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
275 }
276
277 // SECURITY: Validate exp claim if present (RFC standard JWT validation)
278 // While DPoP proofs typically use iat + MaxProofAge, if exp is included it must be honored
279 if claims.ExpiresAt != nil {
280 expWithSkew := claims.ExpiresAt.Time.Add(v.MaxClockSkew)
281 if now.After(expWithSkew) {
282 return fmt.Errorf("DPoP proof expired at %v", claims.ExpiresAt.Time)
283 }
284 }
285
286 // SECURITY: Validate nbf claim if present (RFC standard JWT validation)
287 if claims.NotBefore != nil {
288 nbfWithSkew := claims.NotBefore.Time.Add(-v.MaxClockSkew)
289 if now.Before(nbfWithSkew) {
290 return fmt.Errorf("DPoP proof not valid before %v", claims.NotBefore.Time)
291 }
292 }
293
294 // SECURITY: Check for replay attack using jti
295 // Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks
296 if v.NonceCache != nil {
297 if !v.NonceCache.CheckAndStore(claims.ID) {
298 return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
299 }
300 }
301
302 return nil
303}
304
305// VerifyTokenBinding verifies that the DPoP proof binds to the access token
306// by comparing the proof's thumbprint to the token's cnf.jkt claim
307func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
308 if proof.Thumbprint != expectedThumbprint {
309 return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
310 expectedThumbprint, proof.Thumbprint)
311 }
312 return nil
313}
314
315// VerifyAccessTokenHash verifies the DPoP proof's ath (access token hash) claim
316// matches the SHA-256 hash of the presented access token.
317// Per RFC 9449 section 4.2, if ath is present, the RS MUST verify it.
318func (v *DPoPVerifier) VerifyAccessTokenHash(proof *DPoPProof, accessToken string) error {
319 // If ath claim is not present, that's acceptable per RFC 9449
320 // (ath is only required when the RS mandates it)
321 if proof.Claims.AccessTokenHash == "" {
322 return nil
323 }
324
325 // Calculate the expected ath: base64url(SHA-256(access_token))
326 hash := sha256.Sum256([]byte(accessToken))
327 expectedAth := base64.RawURLEncoding.EncodeToString(hash[:])
328
329 if proof.Claims.AccessTokenHash != expectedAth {
330 return fmt.Errorf("DPoP proof ath mismatch: proof bound to different access token")
331 }
332
333 return nil
334}
335
// CalculateJWKThumbprint computes the RFC 7638 thumbprint of a JWK: the
// unpadded base64url encoding of the SHA-256 hash of the key's canonical JSON
// form.
//
// Only the required members for the key type go into the canonical form
// (EC: crv, kty, x, y; RSA: e, kty, n; OKP: crv, kty, x); every other member of
// jwk is ignored. An error is returned when kty is missing or not a string,
// when the key type is not EC, RSA or OKP, or when a required member is
// missing or not a string.
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
	kty, ok := jwk["kty"].(string)
	if !ok {
		return "", fmt.Errorf("JWK missing kty")
	}

	// Required members besides kty, listed in the order they are checked.
	var required []string
	switch kty {
	case "EC":
		required = []string{"crv", "x", "y"}
	case "RSA":
		required = []string{"e", "n"}
	case "OKP":
		required = []string{"crv", "x"}
	default:
		return "", fmt.Errorf("unsupported JWK key type: %s", kty)
	}

	canonical := map[string]string{"kty": kty}
	for _, member := range required {
		value, isString := jwk[member].(string)
		if !isString {
			return "", fmt.Errorf("%s JWK missing %s", kty, member)
		}
		canonical[member] = value
	}

	// json.Marshal emits map keys in sorted order, which gives the
	// lexicographic member ordering the canonical form needs.
	encoded, err := json.Marshal(canonical)
	if err != nil {
		return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
	}

	digest := sha256.Sum256(encoded)
	return base64.RawURLEncoding.EncodeToString(digest[:]), nil
}
417
// validateAlgorithmCurveBinding checks that the JWT alg header is an accepted
// ECDSA algorithm and that the JWK in jwkMap is an EC key on the curve that
// algorithm requires (ES256K: secp256k1, ES256: P-256, ES384: P-384,
// ES512: P-521).
//
// This guards against algorithm confusion, e.g. a proof that claims ES256K
// while carrying a P-256 key. RSA algorithms and any unrecognized algorithm
// are rejected. It returns nil only when algorithm, key type and curve agree.
func validateAlgorithmCurveBinding(alg string, jwkMap map[string]interface{}) error {
	kty, ok := jwkMap["kty"].(string)
	if !ok {
		return fmt.Errorf("JWK missing kty")
	}

	// Resolve the curve the algorithm demands, rejecting everything that is
	// not one of the accepted ECDSA algorithms.
	var wantCurve string
	switch alg {
	case "ES256K":
		wantCurve = "secp256k1"
	case "ES256":
		wantCurve = "P-256"
	case "ES384":
		wantCurve = "P-384"
	case "ES512":
		wantCurve = "P-521"
	case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
		return fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg)
	default:
		return fmt.Errorf("unsupported DPoP algorithm: %s", alg)
	}

	// ECDSA algorithms require an EC key.
	if kty != "EC" {
		return fmt.Errorf("algorithm %s requires EC key type, got %s", alg, kty)
	}

	gotCurve, ok := jwkMap["crv"].(string)
	if !ok {
		return fmt.Errorf("EC JWK missing crv")
	}
	if gotCurve != wantCurve {
		return fmt.Errorf("algorithm %s requires curve %s, got %s", alg, wantCurve, gotCurve)
	}

	return nil
}
463
464// parseJWKToIndigoPublicKey parses a JWK map to an indigo PublicKey.
465// This returns indigo's PublicKey interface which supports all atProto curves
466// including secp256k1 (ES256K), P-256 (ES256), P-384 (ES384), and P-521 (ES512).
467func parseJWKToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
468 // Convert map to JSON bytes for indigo's parser
469 jwkBytes, err := json.Marshal(jwkMap)
470 if err != nil {
471 return nil, fmt.Errorf("failed to serialize JWK: %w", err)
472 }
473
474 // Parse with indigo's crypto package - this supports all atProto curves
475 // including secp256k1 (ES256K) which Go's crypto/elliptic doesn't support
476 pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
477 if err != nil {
478 return nil, fmt.Errorf("failed to parse JWK: %w", err)
479 }
480
481 return pubKey, nil
482}
483
484// parseJWTHeaderAndClaims manually parses a JWT's header and claims without using golang-jwt.
485// This is necessary to support ES256K (secp256k1) which golang-jwt doesn't recognize.
486func parseJWTHeaderAndClaims(tokenString string) (map[string]interface{}, *DPoPClaims, error) {
487 parts := strings.Split(tokenString, ".")
488 if len(parts) != 3 {
489 return nil, nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
490 }
491
492 // Decode header
493 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
494 if err != nil {
495 return nil, nil, fmt.Errorf("failed to decode JWT header: %w", err)
496 }
497
498 var header map[string]interface{}
499 if err := json.Unmarshal(headerBytes, &header); err != nil {
500 return nil, nil, fmt.Errorf("failed to parse JWT header: %w", err)
501 }
502
503 // Decode claims
504 claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
505 if err != nil {
506 return nil, nil, fmt.Errorf("failed to decode JWT claims: %w", err)
507 }
508
509 // Parse into raw map first to extract standard claims
510 var rawClaims map[string]interface{}
511 if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
512 return nil, nil, fmt.Errorf("failed to parse JWT claims: %w", err)
513 }
514
515 // Build DPoPClaims struct
516 claims := &DPoPClaims{}
517
518 // Extract jti
519 if jti, ok := rawClaims["jti"].(string); ok {
520 claims.ID = jti
521 }
522
523 // Extract iat (issued at)
524 if iat, ok := rawClaims["iat"].(float64); ok {
525 t := time.Unix(int64(iat), 0)
526 claims.IssuedAt = jwt.NewNumericDate(t)
527 }
528
529 // Extract exp (expiration) if present
530 if exp, ok := rawClaims["exp"].(float64); ok {
531 t := time.Unix(int64(exp), 0)
532 claims.ExpiresAt = jwt.NewNumericDate(t)
533 }
534
535 // Extract nbf (not before) if present
536 if nbf, ok := rawClaims["nbf"].(float64); ok {
537 t := time.Unix(int64(nbf), 0)
538 claims.NotBefore = jwt.NewNumericDate(t)
539 }
540
541 // Extract htm (HTTP method)
542 if htm, ok := rawClaims["htm"].(string); ok {
543 claims.HTTPMethod = htm
544 }
545
546 // Extract htu (HTTP URI)
547 if htu, ok := rawClaims["htu"].(string); ok {
548 claims.HTTPURI = htu
549 }
550
551 // Extract ath (access token hash) if present
552 if ath, ok := rawClaims["ath"].(string); ok {
553 claims.AccessTokenHash = ath
554 }
555
556 return header, claims, nil
557}
558
559// verifyJWTSignatureWithIndigo verifies a JWT signature using indigo's crypto package.
560// This is used instead of golang-jwt for algorithms not supported by golang-jwt (like ES256K).
561// It parses the JWT, extracts the signing input and signature, and uses indigo's
562// PublicKey.HashAndVerifyLenient() for verification.
563//
564// JWT format: header.payload.signature (all base64url-encoded)
565// Signature is verified over the raw bytes of "header.payload"
566// (indigo's HashAndVerifyLenient handles SHA-256 hashing internally)
567func verifyJWTSignatureWithIndigo(tokenString string, pubKey indigoCrypto.PublicKey) error {
568 // Split the JWT into parts
569 parts := strings.Split(tokenString, ".")
570 if len(parts) != 3 {
571 return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
572 }
573
574 // The signing input is "header.payload" (without decoding)
575 signingInput := parts[0] + "." + parts[1]
576
577 // Decode the signature from base64url
578 signature, err := base64.RawURLEncoding.DecodeString(parts[2])
579 if err != nil {
580 return fmt.Errorf("failed to decode JWT signature: %w", err)
581 }
582
583 // Use indigo's verification - HashAndVerifyLenient handles hashing internally
584 // and accepts both low-S and high-S signatures for maximum compatibility
585 err = pubKey.HashAndVerifyLenient([]byte(signingInput), signature)
586 if err != nil {
587 return fmt.Errorf("signature verification failed: %w", err)
588 }
589
590 return nil
591}
592
// stripQueryFragment returns uri truncated at the first '?' or '#', i.e.
// without its query and fragment parts. A uri containing neither character is
// returned unchanged.
func stripQueryFragment(uri string) string {
	cut := strings.IndexAny(uri, "?#")
	if cut < 0 {
		return uri
	}
	return uri[:cut]
}
603
604// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims
605func ExtractCnfJkt(claims *Claims) (string, error) {
606 if claims.Confirmation == nil {
607 return "", fmt.Errorf("token missing cnf claim (no DPoP binding)")
608 }
609
610 jkt, ok := claims.Confirmation["jkt"].(string)
611 if !ok || jkt == "" {
612 return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)")
613 }
614
615 return jkt, nil
616}