A community-based topic aggregation platform built on atproto.
1package auth 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "encoding/json" 7 "fmt" 8 "strings" 9 "sync" 10 "time" 11 12 indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto" 13 "github.com/golang-jwt/jwt/v5" 14) 15 16// NonceCache provides replay protection for DPoP proofs by tracking seen jti values. 17// This prevents an attacker from reusing a captured DPoP proof within the validity window. 18// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks. 19type NonceCache struct { 20 seen map[string]time.Time // jti -> expiration time 21 stopCh chan struct{} 22 maxAge time.Duration // How long to keep entries 23 cleanup time.Duration // How often to clean up expired entries 24 mu sync.RWMutex 25} 26 27// NewNonceCache creates a new nonce cache for DPoP replay protection. 28// maxAge should match or exceed DPoPVerifier.MaxProofAge. 29func NewNonceCache(maxAge time.Duration) *NonceCache { 30 nc := &NonceCache{ 31 seen: make(map[string]time.Time), 32 maxAge: maxAge, 33 cleanup: maxAge / 2, // Clean up at half the max age 34 stopCh: make(chan struct{}), 35 } 36 37 // Start background cleanup goroutine 38 go nc.cleanupLoop() 39 40 return nc 41} 42 43// CheckAndStore checks if a jti has been seen before and stores it if not. 44// Returns true if the jti is fresh (not a replay), false if it's a replay. 
//
// A stored jti is remembered for the cache's maxAge. Once that window has
// lapsed, the same jti is accepted again and its window restarts.
func (nc *NonceCache) CheckAndStore(jti string) bool {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	now := time.Now()
	if prevExpiry, found := nc.seen[jti]; found && prevExpiry.After(now) {
		// Still inside the retention window: this jti was already used.
		return false
	}

	// Either never seen, or seen so long ago that the entry has lapsed.
	nc.seen[jti] = now.Add(nc.maxAge)
	return true
}

// cleanupLoop sweeps expired entries every nc.cleanup until stopCh is closed.
func (nc *NonceCache) cleanupLoop() {
	tick := time.NewTicker(nc.cleanup)
	defer tick.Stop()

	for {
		select {
		case <-nc.stopCh:
			return
		case <-tick.C:
			nc.cleanupExpired()
		}
	}
}

// cleanupExpired deletes every entry whose expiry is strictly before the
// current time.
func (nc *NonceCache) cleanupExpired() {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	cutoff := time.Now()
	for id, exp := range nc.seen {
		if !exp.Before(cutoff) {
			continue
		}
		delete(nc.seen, id)
	}
}

// Stop stops the cleanup goroutine. Call this when done with the cache.
// It is safe to call Stop more than once (for example directly and again via
// DPoPVerifier.Stop): later calls are no-ops instead of panicking on a second
// close of the stop channel.
func (nc *NonceCache) Stop() {
	// Holding nc.mu makes the "already closed?" check and the close a single
	// step with respect to other Stop callers.
	nc.mu.Lock()
	defer nc.mu.Unlock()

	select {
	case <-nc.stopCh:
		// Already stopped by an earlier call.
	default:
		close(nc.stopCh)
	}
}

// Size returns the number of entries in the cache (for testing/monitoring).
// Entries that have expired but have not yet been swept are still counted.
func (nc *NonceCache) Size() int {
	nc.mu.RLock()
	defer nc.mu.RUnlock()
	return len(nc.seen)
}

// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449)
type DPoPClaims struct {
	jwt.RegisteredClaims

	// HTTP method of the request (e.g., "GET", "POST")
	HTTPMethod string `json:"htm"`

	// HTTP URI of the request (without query and fragment parts)
	HTTPURI string `json:"htu"`

	// Access token hash (optional, for token binding)
	AccessTokenHash string `json:"ath,omitempty"`
}

// DPoPProof represents a parsed and verified DPoP proof
type DPoPProof struct {
	RawPublicJWK map[string]interface{} // the jwk header object as decoded from JSON
	Claims       *DPoPClaims
	PublicKey    interface{} // as set by VerifyDPoPProof: an indigo atcrypto PublicKey
	Thumbprint   string      // RFC 7638 JWK thumbprint (base64url, no padding)
}

// DPoPVerifier verifies DPoP proofs for OAuth token binding
type DPoPVerifier struct {
	// Optional: custom nonce validation function (for server-issued nonces)
	// NOTE(review): VerifyDPoPProof does not read a nonce claim or call this
	// function; confirm whether any caller uses it.
	ValidateNonce func(nonce string) bool

	// NonceCache for replay protection (optional but recommended)
	// If nil, jti replay protection is disabled
	NonceCache *NonceCache

	// Maximum allowed clock skew for timestamp validation
	MaxClockSkew time.Duration

	// Maximum age of DPoP proof (prevents replay with old proofs)
	MaxProofAge time.Duration
}

// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
func NewDPoPVerifier() *DPoPVerifier {
	const (
		maxClockSkew = 30 * time.Second
		maxProofAge  = 5 * time.Minute
	)
	return &DPoPVerifier{
		MaxClockSkew: maxClockSkew,
		MaxProofAge:  maxProofAge,
		// A proof is accepted until iat+MaxProofAge, and iat may lie up to
		// MaxClockSkew in the future. The jti must therefore be remembered
		// for MaxProofAge+MaxClockSkew; otherwise a future-dated proof could
		// be replayed once its cache entry lapses.
		NonceCache: NewNonceCache(maxProofAge + maxClockSkew),
	}
}

// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
// This should only be used in testing or when replay protection is handled externally.
func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
	return &DPoPVerifier{
		MaxClockSkew: 30 * time.Second,
		MaxProofAge:  5 * time.Minute,
		NonceCache:   nil, // No replay protection: validateDPoPClaims skips the jti check
	}
}

// Stop stops background goroutines. Call this when shutting down.
// It does nothing when replay protection is disabled (NonceCache is nil).
func (v *DPoPVerifier) Stop() {
	if v.NonceCache != nil {
		v.NonceCache.Stop()
	}
}

// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof.
// The JWT is parsed manually rather than with golang-jwt so that ES256K
// (secp256k1) proofs can be handled.
//
// Checks, in order:
//   - header typ is "dpop+jwt" and an alg header is present;
//   - header jwk is an object and contains no private key material;
//   - alg is accepted by validateAlgorithmCurveBinding and agrees with the jwk;
//   - the signature verifies against the embedded public key;
//   - the claims pass validateDPoPClaims for httpMethod/httpURI (which also
//     records the jti when a NonceCache is configured).
//
// Binding the proof to an access token is left to the caller
// (see VerifyTokenBinding and VerifyAccessTokenHash).
func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
	// Manually parse the JWT to support ES256K (which golang-jwt doesn't recognize)
	header, claims, err := parseJWTHeaderAndClaims(dpopProof)
	if err != nil {
		return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
	}

	// Extract and validate the typ header
	typ, ok := header["typ"].(string)
	if !ok || typ != "dpop+jwt" {
		return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", typ)
	}

	alg, ok := header["alg"].(string)
	if !ok {
		return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
	}

	// Extract the JWK from the header first (needed for algorithm-curve validation)
	jwkRaw, ok := header["jwk"]
	if !ok {
		return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
	}

	jwkMap, ok := jwkRaw.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
	}

	// SECURITY: RFC 9449 section 4.3 requires rejecting a proof whose jwk
	// header contains a private key. A client that leaks its private key this
	// way has compromised the binding, so the proof must not be honored.
	if jwkHasPrivateKeyMaterial(jwkMap) {
		return nil, fmt.Errorf("invalid DPoP proof: jwk must not contain a private key")
	}

	// Validate the algorithm is supported and matches the JWK curve
	// This is critical for security - prevents algorithm confusion attacks
	if err := validateAlgorithmCurveBinding(alg, jwkMap); err != nil {
		return nil, fmt.Errorf("invalid DPoP proof: %w", err)
	}

	// Parse the public key using indigo's crypto package
	publicKey, err := parseJWKToIndigoPublicKey(jwkMap)
	if err != nil {
		return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
	}

	// Calculate the JWK thumbprint (compared later against the token's cnf.jkt)
	thumbprint, err := CalculateJWKThumbprint(jwkMap)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
	}

	// Verify the signature using indigo's crypto package
	if err := verifyJWTSignatureWithIndigo(dpopProof, publicKey); err != nil {
		return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
	}

	// Claims are validated only after the signature checks out, so a forged
	// proof never gets its jti recorded in the replay cache.
	if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
		return nil, err
	}

	return &DPoPProof{
		Claims:       claims,
		PublicKey:    publicKey,
		Thumbprint:   thumbprint,
		RawPublicJWK: jwkMap,
	}, nil
}

// jwkHasPrivateKeyMaterial reports whether a JWK contains any member that only
// appears in private or symmetric keys (RFC 7518 sections 6.2.2, 6.3.2, 6.4.1).
func jwkHasPrivateKeyMaterial(jwk map[string]interface{}) bool {
	for _, member := range [...]string{"d", "p", "q", "dp", "dq", "qi", "oth", "k"} {
		if _, present := jwk[member]; present {
			return true
		}
	}
	return false
}

// validateDPoPClaims validates the DPoP proof claims against the request:
// jti present, htm/htu match, and iat within [now-MaxProofAge, now+MaxClockSkew].
// It also honors optional exp/nbf claims (with MaxClockSkew leeway) and,
// as the final step, records the jti in NonceCache when one is configured.
func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
	// Validate jti (unique identifier) is present
	if claims.ID == "" {
		return fmt.Errorf("DPoP proof missing jti claim")
	}

	// Validate htm (HTTP method); compared case-insensitively
	if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
		return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
	}

	// Validate htu (HTTP URI) - compare without query/fragment
	expectedURIBase := stripQueryFragment(expectedURI)
	claimURIBase := stripQueryFragment(claims.HTTPURI)
	if expectedURIBase != claimURIBase {
		return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
	}

	// Validate iat (issued at) is present and recent
	if claims.IssuedAt == nil {
		return fmt.Errorf("DPoP proof missing iat claim")
	}

	now := time.Now()
	iat := claims.IssuedAt.Time

	// Check clock skew (not too far in the future)
	if iat.After(now.Add(v.MaxClockSkew)) {
		return fmt.Errorf("DPoP proof iat is in the future")
	}

	// Check proof age (not too old)
	if now.Sub(iat) > v.MaxProofAge {
		return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
	}

	// SECURITY: Validate exp claim if present (RFC standard JWT validation)
	// While DPoP proofs typically use iat + MaxProofAge, if exp is included it must be honored
	if claims.ExpiresAt != nil {
		expWithSkew := claims.ExpiresAt.Time.Add(v.MaxClockSkew)
		if now.After(expWithSkew) {
			return fmt.Errorf("DPoP proof expired at %v", claims.ExpiresAt.Time)
		}
	}

	// SECURITY: Validate nbf claim if present (RFC standard JWT validation)
	if claims.NotBefore != nil {
		nbfWithSkew := claims.NotBefore.Time.Add(-v.MaxClockSkew)
		if now.Before(nbfWithSkew) {
			return fmt.Errorf("DPoP proof not valid before %v", claims.NotBefore.Time)
		}
	}

	// SECURITY: Check for replay attack using jti
	// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks.
	// Done last so that proofs rejected above do not consume a cache slot.
	if v.NonceCache != nil {
		if !v.NonceCache.CheckAndStore(claims.ID) {
			return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
		}
	}

	return nil
}

// VerifyTokenBinding verifies that the DPoP proof binds to the access token
// by comparing the proof's thumbprint to the token's cnf.jkt claim
func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
	if proof.Thumbprint != expectedThumbprint {
		return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
			expectedThumbprint, proof.Thumbprint)
	}
	return nil
}

// VerifyAccessTokenHash verifies that the DPoP proof's ath (access token hash)
// claim equals base64url(SHA-256(accessToken)).
//
// A proof with no ath claim is accepted.
// NOTE(review): RFC 9449 section 4.2 says a proof presented together with an
// access token MUST contain ath; confirm whether callers should reject proofs
// that omit it instead of relying on this lenient behavior.
func (v *DPoPVerifier) VerifyAccessTokenHash(proof *DPoPProof, accessToken string) error {
	// Lenient: a missing ath is currently treated as acceptable (see NOTE above)
	if proof.Claims.AccessTokenHash == "" {
		return nil
	}

	// Calculate the expected ath: base64url(SHA-256(access_token))
	hash := sha256.Sum256([]byte(accessToken))
	expectedAth := base64.RawURLEncoding.EncodeToString(hash[:])

	if proof.Claims.AccessTokenHash != expectedAth {
		return fmt.Errorf("DPoP proof ath mismatch: proof bound to different access token")
	}

	return nil
}

// CalculateJWKThumbprint calculates the JWK thumbprint per RFC 7638
// The thumbprint is the base64url-encoded SHA-256 hash of the canonical JWK representation
// built from only the required members for the key type (EC, RSA, or OKP).
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
	kty, ok := jwk["kty"].(string)
	if !ok {
		return "", fmt.Errorf("JWK missing kty")
	}

	// Build the canonical JWK representation based on key type
	// Per RFC 7638, only specific members are included, in lexicographic order
	var canonical map[string]string

	switch kty {
	case "EC":
		crv, ok := jwk["crv"].(string)
		if !ok {
			return "", fmt.Errorf("EC JWK missing crv")
		}
		x, ok := jwk["x"].(string)
		if !ok {
			return "", fmt.Errorf("EC JWK missing x")
		}
		y, ok := jwk["y"].(string)
		if !ok {
			return "", fmt.Errorf("EC JWK missing y")
		}
		// Lexicographic order: crv, kty, x, y
		canonical = map[string]string{
			"crv": crv,
			"kty": kty,
			"x":   x,
			"y":   y,
		}
	case "RSA":
		e, ok := jwk["e"].(string)
		if !ok {
			return "", fmt.Errorf("RSA JWK missing e")
		}
		n, ok := jwk["n"].(string)
		if !ok {
			return "", fmt.Errorf("RSA JWK missing n")
		}
		// Lexicographic order: e, kty, n
		canonical = map[string]string{
			"e":   e,
			"kty": kty,
			"n":   n,
		}
	case "OKP":
		crv, ok := jwk["crv"].(string)
		if !ok {
			return "", fmt.Errorf("OKP JWK missing crv")
		}
		x, ok := jwk["x"].(string)
		if !ok {
			return "", fmt.Errorf("OKP JWK missing x")
		}
		// Lexicographic order: crv, kty, x
		canonical = map[string]string{
			"crv": crv,
			"kty": kty,
			"x":   x,
		}
	default:
		return "", fmt.Errorf("unsupported JWK key type: %s", kty)
	}

	// Serialize to JSON (Go's json.Marshal produces lexicographically ordered keys for map[string]string)
	canonicalJSON, err := json.Marshal(canonical)
	if err != nil {
		return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
	}

	// SHA-256 hash
	hash := sha256.Sum256(canonicalJSON)

	// Base64url encode (no padding)
	thumbprint := base64.RawURLEncoding.EncodeToString(hash[:])

	return thumbprint, nil
}

// validateAlgorithmCurveBinding validates that the JWT algorithm matches the JWK curve.
// This is critical for security - an attacker could claim alg: "ES256K" but provide
// a P-256 key, potentially bypassing algorithm binding requirements.
421func validateAlgorithmCurveBinding(alg string, jwkMap map[string]interface{}) error { 422 kty, ok := jwkMap["kty"].(string) 423 if !ok { 424 return fmt.Errorf("JWK missing kty") 425 } 426 427 // ECDSA algorithms require EC key type 428 switch alg { 429 case "ES256K", "ES256", "ES384", "ES512": 430 if kty != "EC" { 431 return fmt.Errorf("algorithm %s requires EC key type, got %s", alg, kty) 432 } 433 case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512": 434 return fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg) 435 default: 436 return fmt.Errorf("unsupported DPoP algorithm: %s", alg) 437 } 438 439 // Validate curve matches algorithm 440 crv, ok := jwkMap["crv"].(string) 441 if !ok { 442 return fmt.Errorf("EC JWK missing crv") 443 } 444 445 var expectedCurve string 446 switch alg { 447 case "ES256K": 448 expectedCurve = "secp256k1" 449 case "ES256": 450 expectedCurve = "P-256" 451 case "ES384": 452 expectedCurve = "P-384" 453 case "ES512": 454 expectedCurve = "P-521" 455 } 456 457 if crv != expectedCurve { 458 return fmt.Errorf("algorithm %s requires curve %s, got %s", alg, expectedCurve, crv) 459 } 460 461 return nil 462} 463 464// parseJWKToIndigoPublicKey parses a JWK map to an indigo PublicKey. 465// This returns indigo's PublicKey interface which supports all atProto curves 466// including secp256k1 (ES256K), P-256 (ES256), P-384 (ES384), and P-521 (ES512). 
467func parseJWKToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) { 468 // Convert map to JSON bytes for indigo's parser 469 jwkBytes, err := json.Marshal(jwkMap) 470 if err != nil { 471 return nil, fmt.Errorf("failed to serialize JWK: %w", err) 472 } 473 474 // Parse with indigo's crypto package - this supports all atProto curves 475 // including secp256k1 (ES256K) which Go's crypto/elliptic doesn't support 476 pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes) 477 if err != nil { 478 return nil, fmt.Errorf("failed to parse JWK: %w", err) 479 } 480 481 return pubKey, nil 482} 483 484// parseJWTHeaderAndClaims manually parses a JWT's header and claims without using golang-jwt. 485// This is necessary to support ES256K (secp256k1) which golang-jwt doesn't recognize. 486func parseJWTHeaderAndClaims(tokenString string) (map[string]interface{}, *DPoPClaims, error) { 487 parts := strings.Split(tokenString, ".") 488 if len(parts) != 3 { 489 return nil, nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) 490 } 491 492 // Decode header 493 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 494 if err != nil { 495 return nil, nil, fmt.Errorf("failed to decode JWT header: %w", err) 496 } 497 498 var header map[string]interface{} 499 if err := json.Unmarshal(headerBytes, &header); err != nil { 500 return nil, nil, fmt.Errorf("failed to parse JWT header: %w", err) 501 } 502 503 // Decode claims 504 claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) 505 if err != nil { 506 return nil, nil, fmt.Errorf("failed to decode JWT claims: %w", err) 507 } 508 509 // Parse into raw map first to extract standard claims 510 var rawClaims map[string]interface{} 511 if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil { 512 return nil, nil, fmt.Errorf("failed to parse JWT claims: %w", err) 513 } 514 515 // Build DPoPClaims struct 516 claims := &DPoPClaims{} 517 518 // Extract jti 519 if jti, ok 
:= rawClaims["jti"].(string); ok { 520 claims.ID = jti 521 } 522 523 // Extract iat (issued at) 524 if iat, ok := rawClaims["iat"].(float64); ok { 525 t := time.Unix(int64(iat), 0) 526 claims.IssuedAt = jwt.NewNumericDate(t) 527 } 528 529 // Extract exp (expiration) if present 530 if exp, ok := rawClaims["exp"].(float64); ok { 531 t := time.Unix(int64(exp), 0) 532 claims.ExpiresAt = jwt.NewNumericDate(t) 533 } 534 535 // Extract nbf (not before) if present 536 if nbf, ok := rawClaims["nbf"].(float64); ok { 537 t := time.Unix(int64(nbf), 0) 538 claims.NotBefore = jwt.NewNumericDate(t) 539 } 540 541 // Extract htm (HTTP method) 542 if htm, ok := rawClaims["htm"].(string); ok { 543 claims.HTTPMethod = htm 544 } 545 546 // Extract htu (HTTP URI) 547 if htu, ok := rawClaims["htu"].(string); ok { 548 claims.HTTPURI = htu 549 } 550 551 // Extract ath (access token hash) if present 552 if ath, ok := rawClaims["ath"].(string); ok { 553 claims.AccessTokenHash = ath 554 } 555 556 return header, claims, nil 557} 558 559// verifyJWTSignatureWithIndigo verifies a JWT signature using indigo's crypto package. 560// This is used instead of golang-jwt for algorithms not supported by golang-jwt (like ES256K). 561// It parses the JWT, extracts the signing input and signature, and uses indigo's 562// PublicKey.HashAndVerifyLenient() for verification. 563// 564// JWT format: header.payload.signature (all base64url-encoded) 565// Signature is verified over the raw bytes of "header.payload" 566// (indigo's HashAndVerifyLenient handles SHA-256 hashing internally) 567func verifyJWTSignatureWithIndigo(tokenString string, pubKey indigoCrypto.PublicKey) error { 568 // Split the JWT into parts 569 parts := strings.Split(tokenString, ".") 570 if len(parts) != 3 { 571 return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) 572 } 573 574 // The signing input is "header.payload" (without decoding) 575 signingInput := parts[0] + "." 
+ parts[1] 576 577 // Decode the signature from base64url 578 signature, err := base64.RawURLEncoding.DecodeString(parts[2]) 579 if err != nil { 580 return fmt.Errorf("failed to decode JWT signature: %w", err) 581 } 582 583 // Use indigo's verification - HashAndVerifyLenient handles hashing internally 584 // and accepts both low-S and high-S signatures for maximum compatibility 585 err = pubKey.HashAndVerifyLenient([]byte(signingInput), signature) 586 if err != nil { 587 return fmt.Errorf("signature verification failed: %w", err) 588 } 589 590 return nil 591} 592 593// stripQueryFragment removes query and fragment from a URI 594func stripQueryFragment(uri string) string { 595 if idx := strings.Index(uri, "?"); idx != -1 { 596 uri = uri[:idx] 597 } 598 if idx := strings.Index(uri, "#"); idx != -1 { 599 uri = uri[:idx] 600 } 601 return uri 602} 603 604// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims 605func ExtractCnfJkt(claims *Claims) (string, error) { 606 if claims.Confirmation == nil { 607 return "", fmt.Errorf("token missing cnf claim (no DPoP binding)") 608 } 609 610 jkt, ok := claims.Confirmation["jkt"].(string) 611 if !ok || jkt == "" { 612 return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)") 613 } 614 615 return jkt, nil 616}