A community-based topic aggregation platform built on atproto
1package auth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rsa" 8 "encoding/base64" 9 "encoding/json" 10 "fmt" 11 "math/big" 12 "net/url" 13 "os" 14 "strings" 15 "sync" 16 "time" 17 18 "github.com/golang-jwt/jwt/v5" 19) 20 21// jwtConfig holds cached JWT configuration to avoid reading env vars on every request 22type jwtConfig struct { 23 hs256Issuers map[string]struct{} // Set of whitelisted HS256 issuers 24 pdsJWTSecret []byte // Cached PDS_JWT_SECRET 25 isDevEnv bool // Cached IS_DEV_ENV 26} 27 28var ( 29 cachedConfig *jwtConfig 30 configOnce sync.Once 31) 32 33// InitJWTConfig initializes the JWT configuration from environment variables. 34// This should be called once at startup. If not called explicitly, it will be 35// initialized lazily on first use. 36func InitJWTConfig() { 37 configOnce.Do(func() { 38 cachedConfig = &jwtConfig{ 39 hs256Issuers: make(map[string]struct{}), 40 isDevEnv: os.Getenv("IS_DEV_ENV") == "true", 41 } 42 43 // Parse HS256_ISSUERS into a set for O(1) lookup 44 if issuers := os.Getenv("HS256_ISSUERS"); issuers != "" { 45 for _, issuer := range strings.Split(issuers, ",") { 46 issuer = strings.TrimSpace(issuer) 47 if issuer != "" { 48 cachedConfig.hs256Issuers[issuer] = struct{}{} 49 } 50 } 51 } 52 53 // Cache PDS_JWT_SECRET 54 if secret := os.Getenv("PDS_JWT_SECRET"); secret != "" { 55 cachedConfig.pdsJWTSecret = []byte(secret) 56 } 57 }) 58} 59 60// getConfig returns the cached config, initializing if needed 61func getConfig() *jwtConfig { 62 InitJWTConfig() 63 return cachedConfig 64} 65 66// ResetJWTConfigForTesting resets the cached config to allow re-initialization. 67// This should ONLY be used in tests. 
68func ResetJWTConfigForTesting() { 69 cachedConfig = nil 70 configOnce = sync.Once{} 71} 72 73// Algorithm constants for JWT signing methods 74const ( 75 AlgorithmHS256 = "HS256" 76 AlgorithmRS256 = "RS256" 77 AlgorithmES256 = "ES256" 78) 79 80// JWTHeader represents the parsed JWT header 81type JWTHeader struct { 82 Alg string `json:"alg"` 83 Kid string `json:"kid"` 84 Typ string `json:"typ,omitempty"` 85} 86 87// Claims represents the standard JWT claims we care about 88type Claims struct { 89 jwt.RegisteredClaims 90 Scope string `json:"scope,omitempty"` 91} 92 93// stripBearerPrefix removes the "Bearer " prefix from a token string 94func stripBearerPrefix(tokenString string) string { 95 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 96 return strings.TrimSpace(tokenString) 97} 98 99// ParseJWTHeader extracts and parses the JWT header from a token string 100// This is a reusable function for getting algorithm and key ID information 101func ParseJWTHeader(tokenString string) (*JWTHeader, error) { 102 tokenString = stripBearerPrefix(tokenString) 103 104 parts := strings.Split(tokenString, ".") 105 if len(parts) != 3 { 106 return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) 107 } 108 109 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 110 if err != nil { 111 return nil, fmt.Errorf("failed to decode JWT header: %w", err) 112 } 113 114 var header JWTHeader 115 if err := json.Unmarshal(headerBytes, &header); err != nil { 116 return nil, fmt.Errorf("failed to parse JWT header: %w", err) 117 } 118 119 return &header, nil 120} 121 122// shouldUseHS256 determines if a token should use HS256 verification 123// This prevents algorithm confusion attacks by using multiple signals: 124// 1. If the token has a `kid` (key ID), it MUST use asymmetric verification 125// 2. 
If no `kid`, only allow HS256 from whitelisted issuers (your own PDS) 126// 127// This approach supports open federation because: 128// - External PDSes publish keys via JWKS and include `kid` in their tokens 129// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid` 130func shouldUseHS256(header *JWTHeader, issuer string) bool { 131 // If token has a key ID, it MUST use asymmetric verification 132 // This is the primary defense against algorithm confusion attacks 133 if header.Kid != "" { 134 return false 135 } 136 137 // No kid - check if issuer is whitelisted for HS256 138 // This should only include your own PDS URL(s) 139 return isHS256IssuerWhitelisted(issuer) 140} 141 142// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist 143// Only your own PDS should be in this list - external PDSes should use JWKS 144func isHS256IssuerWhitelisted(issuer string) bool { 145 cfg := getConfig() 146 _, whitelisted := cfg.hs256Issuers[issuer] 147 return whitelisted 148} 149 150// ParseJWT parses a JWT token without verification (Phase 1) 151// Returns the claims if the token is valid JSON and has required fields 152func ParseJWT(tokenString string) (*Claims, error) { 153 // Remove "Bearer " prefix if present 154 tokenString = stripBearerPrefix(tokenString) 155 156 // Parse without verification first to extract claims 157 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 158 token, _, err := parser.ParseUnverified(tokenString, &Claims{}) 159 if err != nil { 160 return nil, fmt.Errorf("failed to parse JWT: %w", err) 161 } 162 163 claims, ok := token.Claims.(*Claims) 164 if !ok { 165 return nil, fmt.Errorf("invalid claims type") 166 } 167 168 // Validate required fields 169 if claims.Subject == "" { 170 return nil, fmt.Errorf("missing 'sub' claim (user DID)") 171 } 172 173 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server 174 // If 'iss' is missing, use 'aud' as the authorization server identifier 175 if 
claims.Issuer == "" { 176 if len(claims.Audience) > 0 { 177 claims.Issuer = claims.Audience[0] 178 } else { 179 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)") 180 } 181 } 182 183 // Validate claims (even in Phase 1, we need basic validation like expiry) 184 if err := validateClaims(claims); err != nil { 185 return nil, err 186 } 187 188 return claims, nil 189} 190 191// VerifyJWT verifies a JWT token's signature and claims (Phase 2) 192// Fetches the public key from the issuer's JWKS endpoint and validates the signature 193// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET 194// 195// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header, 196// to prevent algorithm confusion attacks where an attacker could re-sign a token 197// with HS256 using a public key as the secret. 198func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) { 199 // Strip Bearer prefix once at the start 200 tokenString = stripBearerPrefix(tokenString) 201 202 // First parse to get the issuer (needed to determine expected algorithm) 203 claims, err := ParseJWT(tokenString) 204 if err != nil { 205 return nil, err 206 } 207 208 // Parse header to get the claimed algorithm (for validation) 209 header, err := ParseJWTHeader(tokenString) 210 if err != nil { 211 return nil, err 212 } 213 214 // SECURITY: Determine verification method based on token characteristics 215 // 1. Tokens with `kid` MUST use asymmetric verification (supports federation) 216 // 2. 
Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS) 217 useHS256 := shouldUseHS256(header, claims.Issuer) 218 219 if useHS256 { 220 // Verify token actually claims to use HS256 221 if header.Alg != AlgorithmHS256 { 222 return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg) 223 } 224 return verifyHS256Token(tokenString) 225 } 226 227 // Token must use asymmetric verification 228 // Reject HS256 tokens that don't meet the criteria above 229 if header.Alg == AlgorithmHS256 { 230 if header.Kid != "" { 231 return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification") 232 } 233 return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer) 234 } 235 236 // For RSA/ECDSA, fetch public key from JWKS and verify 237 return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher) 238} 239 240// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret 241func verifyHS256Token(tokenString string) (*Claims, error) { 242 cfg := getConfig() 243 if len(cfg.pdsJWTSecret) == 0 { 244 return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured") 245 } 246 247 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { 248 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { 249 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 250 } 251 return cfg.pdsJWTSecret, nil 252 }) 253 if err != nil { 254 return nil, fmt.Errorf("HS256 verification failed: %w", err) 255 } 256 257 if !token.Valid { 258 return nil, fmt.Errorf("HS256 verification failed: token signature invalid") 259 } 260 261 verifiedClaims, ok := token.Claims.(*Claims) 262 if !ok { 263 return nil, fmt.Errorf("HS256 verification failed: invalid claims type") 264 } 265 266 if err := validateClaims(verifiedClaims); err != nil { 267 return nil, err 268 } 269 270 return 
verifiedClaims, nil 271} 272 273// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS 274func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) { 275 publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString) 276 if err != nil { 277 return nil, fmt.Errorf("failed to fetch public key: %w", err) 278 } 279 280 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { 281 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily) 282 switch token.Method.(type) { 283 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA: 284 // Valid signing methods for atProto 285 default: 286 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 287 } 288 return publicKey, nil 289 }) 290 if err != nil { 291 return nil, fmt.Errorf("asymmetric verification failed: %w", err) 292 } 293 294 if !token.Valid { 295 return nil, fmt.Errorf("asymmetric verification failed: token signature invalid") 296 } 297 298 verifiedClaims, ok := token.Claims.(*Claims) 299 if !ok { 300 return nil, fmt.Errorf("asymmetric verification failed: invalid claims type") 301 } 302 303 if err := validateClaims(verifiedClaims); err != nil { 304 return nil, err 305 } 306 307 return verifiedClaims, nil 308} 309 310// validateClaims performs additional validation on JWT claims 311func validateClaims(claims *Claims) error { 312 now := time.Now() 313 314 // Check expiration 315 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) { 316 return fmt.Errorf("token has expired") 317 } 318 319 // Check not before 320 if claims.NotBefore != nil && claims.NotBefore.After(now) { 321 return fmt.Errorf("token not yet valid") 322 } 323 324 // Validate DID format in sub claim 325 if !strings.HasPrefix(claims.Subject, "did:") { 326 return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject) 327 } 328 329 // Validate 
issuer is either an HTTPS URL or a DID 330 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers 331 // In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing 332 isHTTP := strings.HasPrefix(claims.Issuer, "http://") 333 isHTTPS := strings.HasPrefix(claims.Issuer, "https://") 334 isDID := strings.HasPrefix(claims.Issuer, "did:") 335 336 if !isHTTPS && !isDID && !isHTTP { 337 return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer) 338 } 339 340 // In production, reject HTTP issuers (only for non-dev environments) 341 cfg := getConfig() 342 if isHTTP && !cfg.isDevEnv { 343 return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer) 344 } 345 346 // Parse to ensure it's a valid URL 347 if _, err := url.Parse(claims.Issuer); err != nil { 348 return fmt.Errorf("invalid issuer URL: %w", err) 349 } 350 351 // Validate scope if present (lenient: allow empty, but reject wrong scopes) 352 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") { 353 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope) 354 } 355 356 return nil 357} 358 359// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints 360// Returns interface{} to support both RSA and ECDSA keys 361type JWKSFetcher interface { 362 FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) 363} 364 365// JWK represents a JSON Web Key from a JWKS endpoint 366// Supports both RSA and EC (ECDSA) keys 367type JWK struct { 368 Kid string `json:"kid"` // Key ID 369 Kty string `json:"kty"` // Key type ("RSA" or "EC") 370 Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256") 371 Use string `json:"use"` // Public key use (should be "sig" for signatures) 372 // RSA fields 373 N string `json:"n,omitempty"` // RSA modulus 374 E string `json:"e,omitempty"` // RSA exponent 375 // EC fields 376 Crv string `json:"crv,omitempty"` // EC curve 
(e.g., "P-256") 377 X string `json:"x,omitempty"` // EC x coordinate 378 Y string `json:"y,omitempty"` // EC y coordinate 379} 380 381// ToPublicKey converts a JWK to a public key (RSA or ECDSA) 382func (j *JWK) ToPublicKey() (interface{}, error) { 383 switch j.Kty { 384 case "RSA": 385 return j.toRSAPublicKey() 386 case "EC": 387 return j.toECPublicKey() 388 default: 389 return nil, fmt.Errorf("unsupported key type: %s", j.Kty) 390 } 391} 392 393// toRSAPublicKey converts a JWK to an RSA public key 394func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { 395 // Decode modulus 396 nBytes, err := base64.RawURLEncoding.DecodeString(j.N) 397 if err != nil { 398 return nil, fmt.Errorf("failed to decode RSA modulus: %w", err) 399 } 400 401 // Decode exponent 402 eBytes, err := base64.RawURLEncoding.DecodeString(j.E) 403 if err != nil { 404 return nil, fmt.Errorf("failed to decode RSA exponent: %w", err) 405 } 406 407 // Convert exponent to int 408 var eInt int 409 for _, b := range eBytes { 410 eInt = eInt*256 + int(b) 411 } 412 413 return &rsa.PublicKey{ 414 N: new(big.Int).SetBytes(nBytes), 415 E: eInt, 416 }, nil 417} 418 419// toECPublicKey converts a JWK to an ECDSA public key 420func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) { 421 // Determine curve 422 var curve elliptic.Curve 423 switch j.Crv { 424 case "P-256": 425 curve = elliptic.P256() 426 case "P-384": 427 curve = elliptic.P384() 428 case "P-521": 429 curve = elliptic.P521() 430 default: 431 return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv) 432 } 433 434 // Decode X coordinate 435 xBytes, err := base64.RawURLEncoding.DecodeString(j.X) 436 if err != nil { 437 return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err) 438 } 439 440 // Decode Y coordinate 441 yBytes, err := base64.RawURLEncoding.DecodeString(j.Y) 442 if err != nil { 443 return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err) 444 } 445 446 return &ecdsa.PublicKey{ 447 Curve: curve, 448 X: 
new(big.Int).SetBytes(xBytes), 449 Y: new(big.Int).SetBytes(yBytes), 450 }, nil 451} 452 453// JWKS represents a JSON Web Key Set 454type JWKS struct { 455 Keys []JWK `json:"keys"` 456} 457 458// FindKeyByID finds a key in the JWKS by its key ID 459func (j *JWKS) FindKeyByID(kid string) (*JWK, error) { 460 for _, key := range j.Keys { 461 if key.Kid == kid { 462 return &key, nil 463 } 464 } 465 return nil, fmt.Errorf("key with kid %s not found", kid) 466} 467 468// ExtractKeyID extracts the key ID from a JWT token header 469func ExtractKeyID(tokenString string) (string, error) { 470 header, err := ParseJWTHeader(tokenString) 471 if err != nil { 472 return "", err 473 } 474 475 if header.Kid == "" { 476 return "", fmt.Errorf("missing kid in token header") 477 } 478 479 return header.Kid, nil 480}