A community-based topic aggregation platform built on atproto.
1package auth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rsa" 8 "encoding/base64" 9 "encoding/json" 10 "fmt" 11 "math/big" 12 "net/url" 13 "os" 14 "strings" 15 "sync" 16 "time" 17 18 "github.com/golang-jwt/jwt/v5" 19) 20 21// jwtConfig holds cached JWT configuration to avoid reading env vars on every request 22type jwtConfig struct { 23 hs256Issuers map[string]struct{} // Set of whitelisted HS256 issuers 24 pdsJWTSecret []byte // Cached PDS_JWT_SECRET 25 isDevEnv bool // Cached IS_DEV_ENV 26} 27 28var ( 29 cachedConfig *jwtConfig 30 configOnce sync.Once 31) 32 33// InitJWTConfig initializes the JWT configuration from environment variables. 34// This should be called once at startup. If not called explicitly, it will be 35// initialized lazily on first use. 36func InitJWTConfig() { 37 configOnce.Do(func() { 38 cachedConfig = &jwtConfig{ 39 hs256Issuers: make(map[string]struct{}), 40 isDevEnv: os.Getenv("IS_DEV_ENV") == "true", 41 } 42 43 // Parse HS256_ISSUERS into a set for O(1) lookup 44 if issuers := os.Getenv("HS256_ISSUERS"); issuers != "" { 45 for _, issuer := range strings.Split(issuers, ",") { 46 issuer = strings.TrimSpace(issuer) 47 if issuer != "" { 48 cachedConfig.hs256Issuers[issuer] = struct{}{} 49 } 50 } 51 } 52 53 // Cache PDS_JWT_SECRET 54 if secret := os.Getenv("PDS_JWT_SECRET"); secret != "" { 55 cachedConfig.pdsJWTSecret = []byte(secret) 56 } 57 }) 58} 59 60// getConfig returns the cached config, initializing if needed 61func getConfig() *jwtConfig { 62 InitJWTConfig() 63 return cachedConfig 64} 65 66// ResetJWTConfigForTesting resets the cached config to allow re-initialization. 67// This should ONLY be used in tests. 
68func ResetJWTConfigForTesting() { 69 cachedConfig = nil 70 configOnce = sync.Once{} 71} 72 73// Algorithm constants for JWT signing methods 74const ( 75 AlgorithmHS256 = "HS256" 76 AlgorithmRS256 = "RS256" 77 AlgorithmES256 = "ES256" 78) 79 80// JWTHeader represents the parsed JWT header 81type JWTHeader struct { 82 Alg string `json:"alg"` 83 Kid string `json:"kid"` 84 Typ string `json:"typ,omitempty"` 85} 86 87// Claims represents the standard JWT claims we care about 88type Claims struct { 89 jwt.RegisteredClaims 90 // Confirmation claim for DPoP token binding (RFC 9449) 91 // Contains "jkt" (JWK thumbprint) when token is bound to a DPoP key 92 Confirmation map[string]interface{} `json:"cnf,omitempty"` 93 Scope string `json:"scope,omitempty"` 94} 95 96// stripBearerPrefix removes the "Bearer " prefix from a token string 97func stripBearerPrefix(tokenString string) string { 98 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 99 return strings.TrimSpace(tokenString) 100} 101 102// ParseJWTHeader extracts and parses the JWT header from a token string 103// This is a reusable function for getting algorithm and key ID information 104func ParseJWTHeader(tokenString string) (*JWTHeader, error) { 105 tokenString = stripBearerPrefix(tokenString) 106 107 parts := strings.Split(tokenString, ".") 108 if len(parts) != 3 { 109 return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) 110 } 111 112 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 113 if err != nil { 114 return nil, fmt.Errorf("failed to decode JWT header: %w", err) 115 } 116 117 var header JWTHeader 118 if err := json.Unmarshal(headerBytes, &header); err != nil { 119 return nil, fmt.Errorf("failed to parse JWT header: %w", err) 120 } 121 122 return &header, nil 123} 124 125// shouldUseHS256 determines if a token should use HS256 verification 126// This prevents algorithm confusion attacks by using multiple signals: 127// 1. 
If the token has a `kid` (key ID), it MUST use asymmetric verification 128// 2. If no `kid`, only allow HS256 from whitelisted issuers (your own PDS) 129// 130// This approach supports open federation because: 131// - External PDSes publish keys via JWKS and include `kid` in their tokens 132// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid` 133func shouldUseHS256(header *JWTHeader, issuer string) bool { 134 // If token has a key ID, it MUST use asymmetric verification 135 // This is the primary defense against algorithm confusion attacks 136 if header.Kid != "" { 137 return false 138 } 139 140 // No kid - check if issuer is whitelisted for HS256 141 // This should only include your own PDS URL(s) 142 return isHS256IssuerWhitelisted(issuer) 143} 144 145// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist 146// Only your own PDS should be in this list - external PDSes should use JWKS 147func isHS256IssuerWhitelisted(issuer string) bool { 148 cfg := getConfig() 149 _, whitelisted := cfg.hs256Issuers[issuer] 150 return whitelisted 151} 152 153// ParseJWT parses a JWT token without verification (Phase 1) 154// Returns the claims if the token is valid JSON and has required fields 155func ParseJWT(tokenString string) (*Claims, error) { 156 // Remove "Bearer " prefix if present 157 tokenString = stripBearerPrefix(tokenString) 158 159 // Parse without verification first to extract claims 160 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 161 token, _, err := parser.ParseUnverified(tokenString, &Claims{}) 162 if err != nil { 163 return nil, fmt.Errorf("failed to parse JWT: %w", err) 164 } 165 166 claims, ok := token.Claims.(*Claims) 167 if !ok { 168 return nil, fmt.Errorf("invalid claims type") 169 } 170 171 // Validate required fields 172 if claims.Subject == "" { 173 return nil, fmt.Errorf("missing 'sub' claim (user DID)") 174 } 175 176 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server 177 
// If 'iss' is missing, use 'aud' as the authorization server identifier 178 if claims.Issuer == "" { 179 if len(claims.Audience) > 0 { 180 claims.Issuer = claims.Audience[0] 181 } else { 182 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)") 183 } 184 } 185 186 // Validate claims (even in Phase 1, we need basic validation like expiry) 187 if err := validateClaims(claims); err != nil { 188 return nil, err 189 } 190 191 return claims, nil 192} 193 194// VerifyJWT verifies a JWT token's signature and claims (Phase 2) 195// Fetches the public key from the issuer's JWKS endpoint and validates the signature 196// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET 197// 198// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header, 199// to prevent algorithm confusion attacks where an attacker could re-sign a token 200// with HS256 using a public key as the secret. 201func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) { 202 // Strip Bearer prefix once at the start 203 tokenString = stripBearerPrefix(tokenString) 204 205 // First parse to get the issuer (needed to determine expected algorithm) 206 claims, err := ParseJWT(tokenString) 207 if err != nil { 208 return nil, err 209 } 210 211 // Parse header to get the claimed algorithm (for validation) 212 header, err := ParseJWTHeader(tokenString) 213 if err != nil { 214 return nil, err 215 } 216 217 // SECURITY: Determine verification method based on token characteristics 218 // 1. Tokens with `kid` MUST use asymmetric verification (supports federation) 219 // 2. 
Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS) 220 useHS256 := shouldUseHS256(header, claims.Issuer) 221 222 if useHS256 { 223 // Verify token actually claims to use HS256 224 if header.Alg != AlgorithmHS256 { 225 return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg) 226 } 227 return verifyHS256Token(tokenString) 228 } 229 230 // Token must use asymmetric verification 231 // Reject HS256 tokens that don't meet the criteria above 232 if header.Alg == AlgorithmHS256 { 233 if header.Kid != "" { 234 return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification") 235 } 236 return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer) 237 } 238 239 // For RSA/ECDSA, fetch public key from JWKS and verify 240 return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher) 241} 242 243// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret 244func verifyHS256Token(tokenString string) (*Claims, error) { 245 cfg := getConfig() 246 if len(cfg.pdsJWTSecret) == 0 { 247 return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured") 248 } 249 250 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { 251 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { 252 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 253 } 254 return cfg.pdsJWTSecret, nil 255 }) 256 if err != nil { 257 return nil, fmt.Errorf("HS256 verification failed: %w", err) 258 } 259 260 if !token.Valid { 261 return nil, fmt.Errorf("HS256 verification failed: token signature invalid") 262 } 263 264 verifiedClaims, ok := token.Claims.(*Claims) 265 if !ok { 266 return nil, fmt.Errorf("HS256 verification failed: invalid claims type") 267 } 268 269 if err := validateClaims(verifiedClaims); err != nil { 270 return nil, err 271 } 272 273 return 
verifiedClaims, nil 274} 275 276// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS 277func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) { 278 publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString) 279 if err != nil { 280 return nil, fmt.Errorf("failed to fetch public key: %w", err) 281 } 282 283 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { 284 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily) 285 switch token.Method.(type) { 286 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA: 287 // Valid signing methods for atProto 288 default: 289 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 290 } 291 return publicKey, nil 292 }) 293 if err != nil { 294 return nil, fmt.Errorf("asymmetric verification failed: %w", err) 295 } 296 297 if !token.Valid { 298 return nil, fmt.Errorf("asymmetric verification failed: token signature invalid") 299 } 300 301 verifiedClaims, ok := token.Claims.(*Claims) 302 if !ok { 303 return nil, fmt.Errorf("asymmetric verification failed: invalid claims type") 304 } 305 306 if err := validateClaims(verifiedClaims); err != nil { 307 return nil, err 308 } 309 310 return verifiedClaims, nil 311} 312 313// validateClaims performs additional validation on JWT claims 314func validateClaims(claims *Claims) error { 315 now := time.Now() 316 317 // Check expiration 318 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) { 319 return fmt.Errorf("token has expired") 320 } 321 322 // Check not before 323 if claims.NotBefore != nil && claims.NotBefore.After(now) { 324 return fmt.Errorf("token not yet valid") 325 } 326 327 // Validate DID format in sub claim 328 if !strings.HasPrefix(claims.Subject, "did:") { 329 return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject) 330 } 331 332 // Validate 
issuer is either an HTTPS URL or a DID 333 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers 334 // In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing 335 isHTTP := strings.HasPrefix(claims.Issuer, "http://") 336 isHTTPS := strings.HasPrefix(claims.Issuer, "https://") 337 isDID := strings.HasPrefix(claims.Issuer, "did:") 338 339 if !isHTTPS && !isDID && !isHTTP { 340 return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer) 341 } 342 343 // In production, reject HTTP issuers (only for non-dev environments) 344 cfg := getConfig() 345 if isHTTP && !cfg.isDevEnv { 346 return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer) 347 } 348 349 // Parse to ensure it's a valid URL 350 if _, err := url.Parse(claims.Issuer); err != nil { 351 return fmt.Errorf("invalid issuer URL: %w", err) 352 } 353 354 // Validate scope if present (lenient: allow empty, but reject wrong scopes) 355 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") { 356 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope) 357 } 358 359 return nil 360} 361 362// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints 363// Returns interface{} to support both RSA and ECDSA keys 364type JWKSFetcher interface { 365 FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) 366} 367 368// JWK represents a JSON Web Key from a JWKS endpoint 369// Supports both RSA and EC (ECDSA) keys 370type JWK struct { 371 Kid string `json:"kid"` // Key ID 372 Kty string `json:"kty"` // Key type ("RSA" or "EC") 373 Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256") 374 Use string `json:"use"` // Public key use (should be "sig" for signatures) 375 // RSA fields 376 N string `json:"n,omitempty"` // RSA modulus 377 E string `json:"e,omitempty"` // RSA exponent 378 // EC fields 379 Crv string `json:"crv,omitempty"` // EC curve 
(e.g., "P-256") 380 X string `json:"x,omitempty"` // EC x coordinate 381 Y string `json:"y,omitempty"` // EC y coordinate 382} 383 384// ToPublicKey converts a JWK to a public key (RSA or ECDSA) 385func (j *JWK) ToPublicKey() (interface{}, error) { 386 switch j.Kty { 387 case "RSA": 388 return j.toRSAPublicKey() 389 case "EC": 390 return j.toECPublicKey() 391 default: 392 return nil, fmt.Errorf("unsupported key type: %s", j.Kty) 393 } 394} 395 396// toRSAPublicKey converts a JWK to an RSA public key 397func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { 398 // Decode modulus 399 nBytes, err := base64.RawURLEncoding.DecodeString(j.N) 400 if err != nil { 401 return nil, fmt.Errorf("failed to decode RSA modulus: %w", err) 402 } 403 404 // Decode exponent 405 eBytes, err := base64.RawURLEncoding.DecodeString(j.E) 406 if err != nil { 407 return nil, fmt.Errorf("failed to decode RSA exponent: %w", err) 408 } 409 410 // Convert exponent to int 411 var eInt int 412 for _, b := range eBytes { 413 eInt = eInt*256 + int(b) 414 } 415 416 return &rsa.PublicKey{ 417 N: new(big.Int).SetBytes(nBytes), 418 E: eInt, 419 }, nil 420} 421 422// toECPublicKey converts a JWK to an ECDSA public key 423func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) { 424 // Determine curve 425 var curve elliptic.Curve 426 switch j.Crv { 427 case "P-256": 428 curve = elliptic.P256() 429 case "P-384": 430 curve = elliptic.P384() 431 case "P-521": 432 curve = elliptic.P521() 433 default: 434 return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv) 435 } 436 437 // Decode X coordinate 438 xBytes, err := base64.RawURLEncoding.DecodeString(j.X) 439 if err != nil { 440 return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err) 441 } 442 443 // Decode Y coordinate 444 yBytes, err := base64.RawURLEncoding.DecodeString(j.Y) 445 if err != nil { 446 return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err) 447 } 448 449 return &ecdsa.PublicKey{ 450 Curve: curve, 451 X: 
new(big.Int).SetBytes(xBytes), 452 Y: new(big.Int).SetBytes(yBytes), 453 }, nil 454} 455 456// JWKS represents a JSON Web Key Set 457type JWKS struct { 458 Keys []JWK `json:"keys"` 459} 460 461// FindKeyByID finds a key in the JWKS by its key ID 462func (j *JWKS) FindKeyByID(kid string) (*JWK, error) { 463 for _, key := range j.Keys { 464 if key.Kid == kid { 465 return &key, nil 466 } 467 } 468 return nil, fmt.Errorf("key with kid %s not found", kid) 469} 470 471// ExtractKeyID extracts the key ID from a JWT token header 472func ExtractKeyID(tokenString string) (string, error) { 473 header, err := ParseJWTHeader(tokenString) 474 if err != nil { 475 return "", err 476 } 477 478 if header.Kid == "" { 479 return "", fmt.Errorf("missing kid in token header") 480 } 481 482 return header.Kid, nil 483}