A community-based topic aggregation platform built on atproto.
1package auth 2 3import ( 4 "crypto/ecdsa" 5 "crypto/elliptic" 6 "crypto/sha256" 7 "encoding/base64" 8 "encoding/json" 9 "fmt" 10 "math/big" 11 "strings" 12 "sync" 13 "time" 14 15 indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto" 16 "github.com/golang-jwt/jwt/v5" 17) 18 19// NonceCache provides replay protection for DPoP proofs by tracking seen jti values. 20// This prevents an attacker from reusing a captured DPoP proof within the validity window. 21// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks. 22type NonceCache struct { 23 seen map[string]time.Time // jti -> expiration time 24 stopCh chan struct{} 25 maxAge time.Duration // How long to keep entries 26 cleanup time.Duration // How often to clean up expired entries 27 mu sync.RWMutex 28} 29 30// NewNonceCache creates a new nonce cache for DPoP replay protection. 31// maxAge should match or exceed DPoPVerifier.MaxProofAge. 32func NewNonceCache(maxAge time.Duration) *NonceCache { 33 nc := &NonceCache{ 34 seen: make(map[string]time.Time), 35 maxAge: maxAge, 36 cleanup: maxAge / 2, // Clean up at half the max age 37 stopCh: make(chan struct{}), 38 } 39 40 // Start background cleanup goroutine 41 go nc.cleanupLoop() 42 43 return nc 44} 45 46// CheckAndStore checks if a jti has been seen before and stores it if not. 47// Returns true if the jti is fresh (not a replay), false if it's a replay. 
48func (nc *NonceCache) CheckAndStore(jti string) bool { 49 nc.mu.Lock() 50 defer nc.mu.Unlock() 51 52 now := time.Now() 53 expiry := now.Add(nc.maxAge) 54 55 // Check if already seen 56 if existingExpiry, seen := nc.seen[jti]; seen { 57 // Still valid (not expired) - this is a replay 58 if existingExpiry.After(now) { 59 return false 60 } 61 // Expired entry - allow reuse and update expiry 62 } 63 64 // Store the new jti 65 nc.seen[jti] = expiry 66 return true 67} 68 69// cleanupLoop periodically removes expired entries from the cache 70func (nc *NonceCache) cleanupLoop() { 71 ticker := time.NewTicker(nc.cleanup) 72 defer ticker.Stop() 73 74 for { 75 select { 76 case <-ticker.C: 77 nc.cleanupExpired() 78 case <-nc.stopCh: 79 return 80 } 81 } 82} 83 84// cleanupExpired removes expired entries from the cache 85func (nc *NonceCache) cleanupExpired() { 86 nc.mu.Lock() 87 defer nc.mu.Unlock() 88 89 now := time.Now() 90 for jti, expiry := range nc.seen { 91 if expiry.Before(now) { 92 delete(nc.seen, jti) 93 } 94 } 95} 96 97// Stop stops the cleanup goroutine. Call this when done with the cache. 
98func (nc *NonceCache) Stop() { 99 close(nc.stopCh) 100} 101 102// Size returns the number of entries in the cache (for testing/monitoring) 103func (nc *NonceCache) Size() int { 104 nc.mu.RLock() 105 defer nc.mu.RUnlock() 106 return len(nc.seen) 107} 108 109// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449) 110type DPoPClaims struct { 111 jwt.RegisteredClaims 112 113 // HTTP method of the request (e.g., "GET", "POST") 114 HTTPMethod string `json:"htm"` 115 116 // HTTP URI of the request (without query and fragment parts) 117 HTTPURI string `json:"htu"` 118 119 // Access token hash (optional, for token binding) 120 AccessTokenHash string `json:"ath,omitempty"` 121} 122 123// DPoPProof represents a parsed and verified DPoP proof 124type DPoPProof struct { 125 RawPublicJWK map[string]interface{} 126 Claims *DPoPClaims 127 PublicKey interface{} // *ecdsa.PublicKey or similar 128 Thumbprint string // JWK thumbprint (base64url) 129} 130 131// DPoPVerifier verifies DPoP proofs for OAuth token binding 132type DPoPVerifier struct { 133 // Optional: custom nonce validation function (for server-issued nonces) 134 ValidateNonce func(nonce string) bool 135 136 // NonceCache for replay protection (optional but recommended) 137 // If nil, jti replay protection is disabled 138 NonceCache *NonceCache 139 140 // Maximum allowed clock skew for timestamp validation 141 MaxClockSkew time.Duration 142 143 // Maximum age of DPoP proof (prevents replay with old proofs) 144 MaxProofAge time.Duration 145} 146 147// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection 148func NewDPoPVerifier() *DPoPVerifier { 149 maxProofAge := 5 * time.Minute 150 return &DPoPVerifier{ 151 MaxClockSkew: 30 * time.Second, 152 MaxProofAge: maxProofAge, 153 NonceCache: NewNonceCache(maxProofAge), 154 } 155} 156 157// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection. 
158// This should only be used in testing or when replay protection is handled externally. 159func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier { 160 return &DPoPVerifier{ 161 MaxClockSkew: 30 * time.Second, 162 MaxProofAge: 5 * time.Minute, 163 NonceCache: nil, // No replay protection 164 } 165} 166 167// Stop stops background goroutines. Call this when shutting down. 168func (v *DPoPVerifier) Stop() { 169 if v.NonceCache != nil { 170 v.NonceCache.Stop() 171 } 172} 173 174// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof 175func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) { 176 // Parse the DPoP JWT without verification first to extract the header 177 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 178 token, _, err := parser.ParseUnverified(dpopProof, &DPoPClaims{}) 179 if err != nil { 180 return nil, fmt.Errorf("failed to parse DPoP proof: %w", err) 181 } 182 183 // Extract and validate the header 184 header, ok := token.Header["typ"].(string) 185 if !ok || header != "dpop+jwt" { 186 return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", header) 187 } 188 189 alg, ok := token.Header["alg"].(string) 190 if !ok { 191 return nil, fmt.Errorf("invalid DPoP proof: missing alg header") 192 } 193 194 // Extract the JWK from the header 195 jwkRaw, ok := token.Header["jwk"] 196 if !ok { 197 return nil, fmt.Errorf("invalid DPoP proof: missing jwk header") 198 } 199 200 jwkMap, ok := jwkRaw.(map[string]interface{}) 201 if !ok { 202 return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object") 203 } 204 205 // Parse the public key from JWK 206 publicKey, err := parseJWKToPublicKey(jwkMap) 207 if err != nil { 208 return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err) 209 } 210 211 // Calculate the JWK thumbprint 212 thumbprint, err := CalculateJWKThumbprint(jwkMap) 213 if err != nil { 214 return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", 
err) 215 } 216 217 // Now verify the signature 218 verifiedToken, err := jwt.ParseWithClaims(dpopProof, &DPoPClaims{}, func(token *jwt.Token) (interface{}, error) { 219 // Verify the signing method matches what we expect 220 switch alg { 221 case "ES256": 222 if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok { 223 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 224 } 225 case "ES384", "ES512": 226 if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok { 227 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 228 } 229 case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512": 230 // RSA methods - we primarily support ES256 for atproto 231 return nil, fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg) 232 default: 233 return nil, fmt.Errorf("unsupported DPoP algorithm: %s", alg) 234 } 235 return publicKey, nil 236 }) 237 if err != nil { 238 return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err) 239 } 240 241 claims, ok := verifiedToken.Claims.(*DPoPClaims) 242 if !ok { 243 return nil, fmt.Errorf("invalid DPoP claims type") 244 } 245 246 // Validate the claims 247 if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil { 248 return nil, err 249 } 250 251 return &DPoPProof{ 252 Claims: claims, 253 PublicKey: publicKey, 254 Thumbprint: thumbprint, 255 RawPublicJWK: jwkMap, 256 }, nil 257} 258 259// validateDPoPClaims validates the DPoP proof claims 260func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error { 261 // Validate jti (unique identifier) is present 262 if claims.ID == "" { 263 return fmt.Errorf("DPoP proof missing jti claim") 264 } 265 266 // Validate htm (HTTP method) 267 if !strings.EqualFold(claims.HTTPMethod, expectedMethod) { 268 return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod) 269 } 270 271 // Validate htu (HTTP URI) - compare without 
query/fragment 272 expectedURIBase := stripQueryFragment(expectedURI) 273 claimURIBase := stripQueryFragment(claims.HTTPURI) 274 if expectedURIBase != claimURIBase { 275 return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase) 276 } 277 278 // Validate iat (issued at) is present and recent 279 if claims.IssuedAt == nil { 280 return fmt.Errorf("DPoP proof missing iat claim") 281 } 282 283 now := time.Now() 284 iat := claims.IssuedAt.Time 285 286 // Check clock skew (not too far in the future) 287 if iat.After(now.Add(v.MaxClockSkew)) { 288 return fmt.Errorf("DPoP proof iat is in the future") 289 } 290 291 // Check proof age (not too old) 292 if now.Sub(iat) > v.MaxProofAge { 293 return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge) 294 } 295 296 // SECURITY: Check for replay attack using jti 297 // Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks 298 if v.NonceCache != nil { 299 if !v.NonceCache.CheckAndStore(claims.ID) { 300 return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID) 301 } 302 } 303 304 return nil 305} 306 307// VerifyTokenBinding verifies that the DPoP proof binds to the access token 308// by comparing the proof's thumbprint to the token's cnf.jkt claim 309func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error { 310 if proof.Thumbprint != expectedThumbprint { 311 return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s", 312 expectedThumbprint, proof.Thumbprint) 313 } 314 return nil 315} 316 317// CalculateJWKThumbprint calculates the JWK thumbprint per RFC 7638 318// The thumbprint is the base64url-encoded SHA-256 hash of the canonical JWK representation 319func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) { 320 kty, ok := jwk["kty"].(string) 321 if !ok { 322 return "", fmt.Errorf("JWK missing kty") 323 } 324 325 // Build the canonical JWK 
representation based on key type 326 // Per RFC 7638, only specific members are included, in lexicographic order 327 var canonical map[string]string 328 329 switch kty { 330 case "EC": 331 crv, ok := jwk["crv"].(string) 332 if !ok { 333 return "", fmt.Errorf("EC JWK missing crv") 334 } 335 x, ok := jwk["x"].(string) 336 if !ok { 337 return "", fmt.Errorf("EC JWK missing x") 338 } 339 y, ok := jwk["y"].(string) 340 if !ok { 341 return "", fmt.Errorf("EC JWK missing y") 342 } 343 // Lexicographic order: crv, kty, x, y 344 canonical = map[string]string{ 345 "crv": crv, 346 "kty": kty, 347 "x": x, 348 "y": y, 349 } 350 case "RSA": 351 e, ok := jwk["e"].(string) 352 if !ok { 353 return "", fmt.Errorf("RSA JWK missing e") 354 } 355 n, ok := jwk["n"].(string) 356 if !ok { 357 return "", fmt.Errorf("RSA JWK missing n") 358 } 359 // Lexicographic order: e, kty, n 360 canonical = map[string]string{ 361 "e": e, 362 "kty": kty, 363 "n": n, 364 } 365 case "OKP": 366 crv, ok := jwk["crv"].(string) 367 if !ok { 368 return "", fmt.Errorf("OKP JWK missing crv") 369 } 370 x, ok := jwk["x"].(string) 371 if !ok { 372 return "", fmt.Errorf("OKP JWK missing x") 373 } 374 // Lexicographic order: crv, kty, x 375 canonical = map[string]string{ 376 "crv": crv, 377 "kty": kty, 378 "x": x, 379 } 380 default: 381 return "", fmt.Errorf("unsupported JWK key type: %s", kty) 382 } 383 384 // Serialize to JSON (Go's json.Marshal produces lexicographically ordered keys for map[string]string) 385 canonicalJSON, err := json.Marshal(canonical) 386 if err != nil { 387 return "", fmt.Errorf("failed to serialize canonical JWK: %w", err) 388 } 389 390 // SHA-256 hash 391 hash := sha256.Sum256(canonicalJSON) 392 393 // Base64url encode (no padding) 394 thumbprint := base64.RawURLEncoding.EncodeToString(hash[:]) 395 396 return thumbprint, nil 397} 398 399// parseJWKToPublicKey parses a JWK map to a Go public key 400func parseJWKToPublicKey(jwkMap map[string]interface{}) (interface{}, error) { 401 // Convert 
map to JSON bytes for indigo's parser 402 jwkBytes, err := json.Marshal(jwkMap) 403 if err != nil { 404 return nil, fmt.Errorf("failed to serialize JWK: %w", err) 405 } 406 407 // Try to parse with indigo's crypto package 408 pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes) 409 if err != nil { 410 return nil, fmt.Errorf("failed to parse JWK: %w", err) 411 } 412 413 // Convert indigo's PublicKey to Go's ecdsa.PublicKey 414 jwk, err := pubKey.JWK() 415 if err != nil { 416 return nil, fmt.Errorf("failed to get JWK from public key: %w", err) 417 } 418 419 // Use our existing conversion function 420 return atcryptoJWKToECDSAFromIndigoJWK(jwk) 421} 422 423// atcryptoJWKToECDSAFromIndigoJWK converts an indigo JWK to Go ecdsa.PublicKey 424func atcryptoJWKToECDSAFromIndigoJWK(jwk *indigoCrypto.JWK) (*ecdsa.PublicKey, error) { 425 if jwk.KeyType != "EC" { 426 return nil, fmt.Errorf("unsupported JWK key type: %s (expected EC)", jwk.KeyType) 427 } 428 429 xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X) 430 if err != nil { 431 return nil, fmt.Errorf("invalid JWK X coordinate: %w", err) 432 } 433 yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y) 434 if err != nil { 435 return nil, fmt.Errorf("invalid JWK Y coordinate: %w", err) 436 } 437 438 var curve ecdsa.PublicKey 439 switch jwk.Curve { 440 case "P-256": 441 curve.Curve = ecdsaP256Curve() 442 case "P-384": 443 curve.Curve = ecdsaP384Curve() 444 case "P-521": 445 curve.Curve = ecdsaP521Curve() 446 default: 447 return nil, fmt.Errorf("unsupported curve: %s", jwk.Curve) 448 } 449 450 curve.X = new(big.Int).SetBytes(xBytes) 451 curve.Y = new(big.Int).SetBytes(yBytes) 452 453 return &curve, nil 454} 455 456// Helper functions for elliptic curves 457func ecdsaP256Curve() elliptic.Curve { return elliptic.P256() } 458func ecdsaP384Curve() elliptic.Curve { return elliptic.P384() } 459func ecdsaP521Curve() elliptic.Curve { return elliptic.P521() } 460 461// stripQueryFragment removes query and fragment from a 
URI 462func stripQueryFragment(uri string) string { 463 if idx := strings.Index(uri, "?"); idx != -1 { 464 uri = uri[:idx] 465 } 466 if idx := strings.Index(uri, "#"); idx != -1 { 467 uri = uri[:idx] 468 } 469 return uri 470} 471 472// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims 473func ExtractCnfJkt(claims *Claims) (string, error) { 474 if claims.Confirmation == nil { 475 return "", fmt.Errorf("token missing cnf claim (no DPoP binding)") 476 } 477 478 jkt, ok := claims.Confirmation["jkt"].(string) 479 if !ok || jkt == "" { 480 return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)") 481 } 482 483 return jkt, nil 484}