A community-based topic aggregation platform built on atproto
1package auth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rsa"
8 "encoding/base64"
9 "encoding/json"
10 "fmt"
11 "math/big"
12 "net/url"
13 "os"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/golang-jwt/jwt/v5"
19)
20
// jwtConfig is a snapshot of the JWT-related environment settings, loaded once
// so that request handling never has to call os.Getenv.
type jwtConfig struct {
	hs256Issuers map[string]struct{} // issuers allowed to use HS256 (from HS256_ISSUERS)
	pdsJWTSecret []byte              // shared HMAC secret (from PDS_JWT_SECRET); nil when unset
	isDevEnv     bool                // IS_DEV_ENV == "true"
}

var (
	// cachedConfig is assigned inside configOnce.Do and read via getConfig.
	cachedConfig *jwtConfig
	configOnce   sync.Once
)

// InitJWTConfig loads the JWT configuration from the environment:
//   - IS_DEV_ENV: exactly "true" enables dev mode (http:// issuers are tolerated)
//   - HS256_ISSUERS: comma-separated issuers; entries are trimmed, blanks skipped
//   - PDS_JWT_SECRET: shared secret for HS256 verification
//
// Only the first call reads the environment (until ResetJWTConfigForTesting is
// used); later calls are no-ops. Calling it at startup is optional, because
// getConfig invokes it lazily.
func InitJWTConfig() {
	configOnce.Do(func() {
		cachedConfig = loadJWTConfigFromEnv()
	})
}

// loadJWTConfigFromEnv builds a fresh jwtConfig from the current environment.
func loadJWTConfigFromEnv() *jwtConfig {
	cfg := &jwtConfig{
		hs256Issuers: make(map[string]struct{}),
		isDevEnv:     os.Getenv("IS_DEV_ENV") == "true",
	}

	// Build a set for O(1) issuer lookups; FieldsFunc already drops empty
	// fields, TrimSpace handles padding around the commas.
	isComma := func(r rune) bool { return r == ',' }
	for _, field := range strings.FieldsFunc(os.Getenv("HS256_ISSUERS"), isComma) {
		if issuer := strings.TrimSpace(field); issuer != "" {
			cfg.hs256Issuers[issuer] = struct{}{}
		}
	}

	if secret := os.Getenv("PDS_JWT_SECRET"); len(secret) > 0 {
		cfg.pdsJWTSecret = []byte(secret)
	}
	return cfg
}

// getConfig returns the cached configuration, loading it on first use.
func getConfig() *jwtConfig {
	InitJWTConfig()
	return cachedConfig
}

// ResetJWTConfigForTesting discards the cached configuration so that the next
// InitJWTConfig/getConfig call re-reads the environment. It performs no
// synchronization with concurrent readers; use it ONLY in tests.
func ResetJWTConfigForTesting() {
	configOnce = sync.Once{}
	cachedConfig = nil
}
72
// Algorithm constants for JWT signing methods: JWS "alg" header values as
// registered in RFC 7518. VerifyJWT compares the token header against
// AlgorithmHS256 both when taking and when refusing the shared-secret path.
const (
	AlgorithmHS256 = "HS256" // HMAC with SHA-256 (shared secret)
	AlgorithmRS256 = "RS256" // RSASSA-PKCS1-v1_5 with SHA-256
	AlgorithmES256 = "ES256" // ECDSA using P-256 and SHA-256
)
79
// JWTHeader represents the parsed JWT header: the JSON decoded from the first
// dot-separated segment of a compact token (see ParseJWTHeader).
type JWTHeader struct {
	Alg string `json:"alg"`           // signing algorithm claimed by the token, e.g. "HS256"
	Kid string `json:"kid"`           // key ID; when non-empty, shouldUseHS256 forces asymmetric verification
	Typ string `json:"typ,omitempty"` // optional token type
}
86
// Claims represents the JWT claims we care about: the registered claims
// (iss, sub, aud, exp, nbf, ...) plus the "scope" claim.
type Claims struct {
	jwt.RegisteredClaims
	// Scope is the raw "scope" claim; validateClaims rejects a non-empty value
	// that does not contain "atproto".
	Scope string `json:"scope,omitempty"`
}
92
// stripBearerPrefix normalizes a raw token or an Authorization header value.
//
// Surrounding whitespace is removed. A leading "Bearer" scheme followed by a
// space or tab is dropped. The scheme is matched case-insensitively because
// HTTP authentication scheme names are case-insensitive (RFC 7235 §2.1).
// A bare scheme with no credentials yields "". Input without the scheme is
// returned trimmed but otherwise unchanged.
func stripBearerPrefix(tokenString string) string {
	const scheme = "Bearer"

	s := strings.TrimSpace(tokenString)
	if len(s) >= len(scheme) && strings.EqualFold(s[:len(scheme)], scheme) {
		rest := s[len(scheme):]
		// Require a separator so a value that merely starts with the letters
		// "bearer" is left untouched.
		if rest == "" || rest[0] == ' ' || rest[0] == '\t' {
			return strings.TrimSpace(rest)
		}
	}
	return s
}
98
99// ParseJWTHeader extracts and parses the JWT header from a token string
100// This is a reusable function for getting algorithm and key ID information
101func ParseJWTHeader(tokenString string) (*JWTHeader, error) {
102 tokenString = stripBearerPrefix(tokenString)
103
104 parts := strings.Split(tokenString, ".")
105 if len(parts) != 3 {
106 return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
107 }
108
109 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
110 if err != nil {
111 return nil, fmt.Errorf("failed to decode JWT header: %w", err)
112 }
113
114 var header JWTHeader
115 if err := json.Unmarshal(headerBytes, &header); err != nil {
116 return nil, fmt.Errorf("failed to parse JWT header: %w", err)
117 }
118
119 return &header, nil
120}
121
122// shouldUseHS256 determines if a token should use HS256 verification
123// This prevents algorithm confusion attacks by using multiple signals:
124// 1. If the token has a `kid` (key ID), it MUST use asymmetric verification
125// 2. If no `kid`, only allow HS256 from whitelisted issuers (your own PDS)
126//
127// This approach supports open federation because:
128// - External PDSes publish keys via JWKS and include `kid` in their tokens
129// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid`
130func shouldUseHS256(header *JWTHeader, issuer string) bool {
131 // If token has a key ID, it MUST use asymmetric verification
132 // This is the primary defense against algorithm confusion attacks
133 if header.Kid != "" {
134 return false
135 }
136
137 // No kid - check if issuer is whitelisted for HS256
138 // This should only include your own PDS URL(s)
139 return isHS256IssuerWhitelisted(issuer)
140}
141
142// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist
143// Only your own PDS should be in this list - external PDSes should use JWKS
144func isHS256IssuerWhitelisted(issuer string) bool {
145 cfg := getConfig()
146 _, whitelisted := cfg.hs256Issuers[issuer]
147 return whitelisted
148}
149
150// ParseJWT parses a JWT token without verification (Phase 1)
151// Returns the claims if the token is valid JSON and has required fields
152func ParseJWT(tokenString string) (*Claims, error) {
153 // Remove "Bearer " prefix if present
154 tokenString = stripBearerPrefix(tokenString)
155
156 // Parse without verification first to extract claims
157 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
158 token, _, err := parser.ParseUnverified(tokenString, &Claims{})
159 if err != nil {
160 return nil, fmt.Errorf("failed to parse JWT: %w", err)
161 }
162
163 claims, ok := token.Claims.(*Claims)
164 if !ok {
165 return nil, fmt.Errorf("invalid claims type")
166 }
167
168 // Validate required fields
169 if claims.Subject == "" {
170 return nil, fmt.Errorf("missing 'sub' claim (user DID)")
171 }
172
173 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server
174 // If 'iss' is missing, use 'aud' as the authorization server identifier
175 if claims.Issuer == "" {
176 if len(claims.Audience) > 0 {
177 claims.Issuer = claims.Audience[0]
178 } else {
179 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
180 }
181 }
182
183 // Validate claims (even in Phase 1, we need basic validation like expiry)
184 if err := validateClaims(claims); err != nil {
185 return nil, err
186 }
187
188 return claims, nil
189}
190
191// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
192// Fetches the public key from the issuer's JWKS endpoint and validates the signature
193// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET
194//
195// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header,
196// to prevent algorithm confusion attacks where an attacker could re-sign a token
197// with HS256 using a public key as the secret.
198func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
199 // Strip Bearer prefix once at the start
200 tokenString = stripBearerPrefix(tokenString)
201
202 // First parse to get the issuer (needed to determine expected algorithm)
203 claims, err := ParseJWT(tokenString)
204 if err != nil {
205 return nil, err
206 }
207
208 // Parse header to get the claimed algorithm (for validation)
209 header, err := ParseJWTHeader(tokenString)
210 if err != nil {
211 return nil, err
212 }
213
214 // SECURITY: Determine verification method based on token characteristics
215 // 1. Tokens with `kid` MUST use asymmetric verification (supports federation)
216 // 2. Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS)
217 useHS256 := shouldUseHS256(header, claims.Issuer)
218
219 if useHS256 {
220 // Verify token actually claims to use HS256
221 if header.Alg != AlgorithmHS256 {
222 return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg)
223 }
224 return verifyHS256Token(tokenString)
225 }
226
227 // Token must use asymmetric verification
228 // Reject HS256 tokens that don't meet the criteria above
229 if header.Alg == AlgorithmHS256 {
230 if header.Kid != "" {
231 return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification")
232 }
233 return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer)
234 }
235
236 // For RSA/ECDSA, fetch public key from JWKS and verify
237 return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher)
238}
239
240// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret
241func verifyHS256Token(tokenString string) (*Claims, error) {
242 cfg := getConfig()
243 if len(cfg.pdsJWTSecret) == 0 {
244 return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured")
245 }
246
247 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
248 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
249 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
250 }
251 return cfg.pdsJWTSecret, nil
252 })
253 if err != nil {
254 return nil, fmt.Errorf("HS256 verification failed: %w", err)
255 }
256
257 if !token.Valid {
258 return nil, fmt.Errorf("HS256 verification failed: token signature invalid")
259 }
260
261 verifiedClaims, ok := token.Claims.(*Claims)
262 if !ok {
263 return nil, fmt.Errorf("HS256 verification failed: invalid claims type")
264 }
265
266 if err := validateClaims(verifiedClaims); err != nil {
267 return nil, err
268 }
269
270 return verifiedClaims, nil
271}
272
273// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS
274func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
275 publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
276 if err != nil {
277 return nil, fmt.Errorf("failed to fetch public key: %w", err)
278 }
279
280 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
281 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
282 switch token.Method.(type) {
283 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
284 // Valid signing methods for atProto
285 default:
286 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
287 }
288 return publicKey, nil
289 })
290 if err != nil {
291 return nil, fmt.Errorf("asymmetric verification failed: %w", err)
292 }
293
294 if !token.Valid {
295 return nil, fmt.Errorf("asymmetric verification failed: token signature invalid")
296 }
297
298 verifiedClaims, ok := token.Claims.(*Claims)
299 if !ok {
300 return nil, fmt.Errorf("asymmetric verification failed: invalid claims type")
301 }
302
303 if err := validateClaims(verifiedClaims); err != nil {
304 return nil, err
305 }
306
307 return verifiedClaims, nil
308}
309
310// validateClaims performs additional validation on JWT claims
311func validateClaims(claims *Claims) error {
312 now := time.Now()
313
314 // Check expiration
315 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
316 return fmt.Errorf("token has expired")
317 }
318
319 // Check not before
320 if claims.NotBefore != nil && claims.NotBefore.After(now) {
321 return fmt.Errorf("token not yet valid")
322 }
323
324 // Validate DID format in sub claim
325 if !strings.HasPrefix(claims.Subject, "did:") {
326 return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
327 }
328
329 // Validate issuer is either an HTTPS URL or a DID
330 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
331 // In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing
332 isHTTP := strings.HasPrefix(claims.Issuer, "http://")
333 isHTTPS := strings.HasPrefix(claims.Issuer, "https://")
334 isDID := strings.HasPrefix(claims.Issuer, "did:")
335
336 if !isHTTPS && !isDID && !isHTTP {
337 return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer)
338 }
339
340 // In production, reject HTTP issuers (only for non-dev environments)
341 cfg := getConfig()
342 if isHTTP && !cfg.isDevEnv {
343 return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer)
344 }
345
346 // Parse to ensure it's a valid URL
347 if _, err := url.Parse(claims.Issuer); err != nil {
348 return fmt.Errorf("invalid issuer URL: %w", err)
349 }
350
351 // Validate scope if present (lenient: allow empty, but reject wrong scopes)
352 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
353 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
354 }
355
356 return nil
357}
358
// JWKSFetcher resolves the public key used to verify a token from issuer.
// Implementations are expected to consult the issuer's JWKS endpoint. The full
// token is passed along, presumably so the key can be selected by the header's
// kid (confirm against implementations). The result is untyped so that both
// RSA and ECDSA keys can be returned.
type JWKSFetcher interface {
	FetchPublicKey(ctx context.Context, issuer, token string) (any, error)
}
364
// JWK represents a JSON Web Key from a JWKS endpoint.
// Supports both RSA and EC (ECDSA) keys.
type JWK struct {
	Kid string `json:"kid"` // Key ID
	Kty string `json:"kty"` // Key type ("RSA" or "EC")
	Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
	Use string `json:"use"` // Public key use (should be "sig" for signatures)
	// RSA fields (unpadded base64url, big-endian)
	N string `json:"n,omitempty"` // RSA modulus
	E string `json:"e,omitempty"` // RSA exponent
	// EC fields (unpadded base64url, big-endian)
	Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
	X   string `json:"x,omitempty"`   // EC x coordinate
	Y   string `json:"y,omitempty"`   // EC y coordinate
}

// ToPublicKey converts the JWK into a usable public key: *rsa.PublicKey for
// kty "RSA", *ecdsa.PublicKey for kty "EC".
//
// On error the returned interface is a true nil. It is never a typed nil
// pointer wrapped in a non-nil interface, so `key == nil` checks behave as
// expected.
func (j *JWK) ToPublicKey() (interface{}, error) {
	switch j.Kty {
	case "RSA":
		key, err := j.toRSAPublicKey()
		if err != nil {
			return nil, err
		}
		return key, nil
	case "EC":
		key, err := j.toECPublicKey()
		if err != nil {
			return nil, err
		}
		return key, nil
	default:
		return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
	}
}

// toRSAPublicKey decodes the "n" and "e" members into an RSA public key.
// It rejects an empty or zero modulus. It also rejects exponents outside
// [2, 2^31-1], the range crypto/rsa is willing to verify with.
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
	}
	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("RSA modulus is empty or zero")
	}

	eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
	}
	// Decode through big.Int so oversized exponents are detected rather than
	// silently overflowing an int accumulator.
	e := new(big.Int).SetBytes(eBytes)
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > 1<<31-1 {
		return nil, fmt.Errorf("invalid RSA exponent: must be between 2 and 2^31-1")
	}

	return &rsa.PublicKey{
		N: n,
		E: int(e.Int64()),
	}, nil
}

// toECPublicKey decodes the "crv", "x" and "y" members into an ECDSA public
// key on P-256, P-384 or P-521. It rejects missing coordinates and
// coordinates not reduced modulo the curve's field prime.
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
	var curve elliptic.Curve
	switch j.Crv {
	case "P-256":
		curve = elliptic.P256()
	case "P-384":
		curve = elliptic.P384()
	case "P-521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
	}

	xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
	}
	yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
	}
	if len(xBytes) == 0 || len(yBytes) == 0 {
		return nil, fmt.Errorf("EC key is missing the x or y coordinate")
	}

	x := new(big.Int).SetBytes(xBytes)
	y := new(big.Int).SetBytes(yBytes)
	// Coordinates are field elements; values >= P can never describe a point.
	if p := curve.Params().P; x.Cmp(p) >= 0 || y.Cmp(p) >= 0 {
		return nil, fmt.Errorf("EC coordinate out of range for curve %s", j.Crv)
	}

	// NOTE(review): the point is not checked to lie on the curve here. If the
	// module targets Go 1.20+, (*ecdsa.PublicKey).ECDH() could be used to
	// validate it.
	return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil
}

// JWKS represents a JSON Web Key Set.
type JWKS struct {
	Keys []JWK `json:"keys"`
}

// FindKeyByID returns a copy of the first key whose "kid" equals kid, or an
// error if none matches. Mutating the returned JWK does not modify the set.
func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
	for i := range j.Keys {
		if j.Keys[i].Kid == kid {
			found := j.Keys[i]
			return &found, nil
		}
	}
	return nil, fmt.Errorf("key with kid %s not found", kid)
}
467
468// ExtractKeyID extracts the key ID from a JWT token header
469func ExtractKeyID(tokenString string) (string, error) {
470 header, err := ParseJWTHeader(tokenString)
471 if err != nil {
472 return "", err
473 }
474
475 if header.Kid == "" {
476 return "", fmt.Errorf("missing kid in token header")
477 }
478
479 return header.Kid, nil
480}