A community-based topic aggregation platform built on atproto
1package auth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rsa"
8 "encoding/base64"
9 "encoding/json"
10 "fmt"
11 "math/big"
12 "net/url"
13 "os"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/golang-jwt/jwt/v5"
19)
20
// jwtConfig is the JWT verification configuration, read from the environment
// once so that request handling never has to call os.Getenv.
type jwtConfig struct {
	hs256Issuers map[string]struct{} // issuers allowed to use HS256 (used as a set)
	pdsJWTSecret []byte              // PDS_JWT_SECRET; nil when unset or empty
	isDevEnv     bool                // true only when IS_DEV_ENV == "true"
}

var (
	cachedConfig *jwtConfig // populated by InitJWTConfig
	configOnce   sync.Once  // guards the one-time population of cachedConfig
)

// InitJWTConfig loads the JWT configuration from the environment variables
// IS_DEV_ENV, HS256_ISSUERS (comma-separated; blanks ignored) and
// PDS_JWT_SECRET. Only the first call has any effect. Calling it at startup is
// optional: getConfig invokes it lazily on first use.
func InitJWTConfig() {
	configOnce.Do(func() {
		cachedConfig = readJWTConfigFromEnv()
	})
}

// readJWTConfigFromEnv builds a fresh jwtConfig from the current environment.
func readJWTConfigFromEnv() *jwtConfig {
	cfg := &jwtConfig{
		hs256Issuers: map[string]struct{}{},
		isDevEnv:     os.Getenv("IS_DEV_ENV") == "true",
	}

	// Split on commas; surrounding whitespace is trimmed and empty entries skipped.
	isComma := func(r rune) bool { return r == ',' }
	for _, raw := range strings.FieldsFunc(os.Getenv("HS256_ISSUERS"), isComma) {
		if name := strings.TrimSpace(raw); name != "" {
			cfg.hs256Issuers[name] = struct{}{}
		}
	}

	if secret := os.Getenv("PDS_JWT_SECRET"); len(secret) > 0 {
		cfg.pdsJWTSecret = []byte(secret)
	}
	return cfg
}

// getConfig returns the cached configuration, loading it on first use.
func getConfig() *jwtConfig {
	InitJWTConfig()
	return cachedConfig
}

// ResetJWTConfigForTesting discards the cached configuration so that the next
// getConfig call re-reads the environment. Tests only: it is not synchronized
// with concurrent readers.
func ResetJWTConfigForTesting() {
	configOnce, cachedConfig = sync.Once{}, nil
}
72
// JWT "alg" header values that this package distinguishes between when it
// chooses a verification strategy.
const (
	AlgorithmHS256 = "HS256" // HMAC with SHA-256 (shared PDS secret)
	AlgorithmRS256 = "RS256" // RSASSA-PKCS1-v1_5 with SHA-256
	AlgorithmES256 = "ES256" // ECDSA on P-256 with SHA-256
)
79
// JWTHeader holds the decoded JOSE header fields of a compact JWT that are
// needed to pick a verification path.
type JWTHeader struct {
	Alg string `json:"alg"`           // signing algorithm, e.g. "HS256" or "ES256"
	Kid string `json:"kid"`           // key ID; empty when the header omits it
	Typ string `json:"typ,omitempty"` // token type, when present
}
86
87// Claims represents the standard JWT claims we care about
88type Claims struct {
89 jwt.RegisteredClaims
90 // Confirmation claim for DPoP token binding (RFC 9449)
91 // Contains "jkt" (JWK thumbprint) when token is bound to a DPoP key
92 Confirmation map[string]interface{} `json:"cnf,omitempty"`
93 Scope string `json:"scope,omitempty"`
94}
95
// stripBearerPrefix removes an optional "Bearer" authorization-scheme prefix
// from tokenString and trims the surrounding whitespace.
//
// The scheme name is matched case-insensitively, because HTTP authentication
// scheme names are case-insensitive (RFC 7235 §2.1). Leading whitespace before
// the scheme is tolerated. Strings without the prefix are returned trimmed.
// Examples: "Bearer abc", "bearer abc" and "  Bearer abc " all yield "abc".
func stripBearerPrefix(tokenString string) string {
	const scheme = "bearer "

	s := strings.TrimLeft(tokenString, " \t\r\n")
	// The scheme is ASCII, so a byte-length slice is safe to compare. A slice
	// that splits a multi-byte rune simply fails to match.
	if len(s) >= len(scheme) && strings.EqualFold(s[:len(scheme)], scheme) {
		s = s[len(scheme):]
	}
	return strings.TrimSpace(s)
}
101
102// ParseJWTHeader extracts and parses the JWT header from a token string
103// This is a reusable function for getting algorithm and key ID information
104func ParseJWTHeader(tokenString string) (*JWTHeader, error) {
105 tokenString = stripBearerPrefix(tokenString)
106
107 parts := strings.Split(tokenString, ".")
108 if len(parts) != 3 {
109 return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
110 }
111
112 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
113 if err != nil {
114 return nil, fmt.Errorf("failed to decode JWT header: %w", err)
115 }
116
117 var header JWTHeader
118 if err := json.Unmarshal(headerBytes, &header); err != nil {
119 return nil, fmt.Errorf("failed to parse JWT header: %w", err)
120 }
121
122 return &header, nil
123}
124
125// shouldUseHS256 determines if a token should use HS256 verification
126// This prevents algorithm confusion attacks by using multiple signals:
127// 1. If the token has a `kid` (key ID), it MUST use asymmetric verification
128// 2. If no `kid`, only allow HS256 from whitelisted issuers (your own PDS)
129//
130// This approach supports open federation because:
131// - External PDSes publish keys via JWKS and include `kid` in their tokens
132// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid`
133func shouldUseHS256(header *JWTHeader, issuer string) bool {
134 // If token has a key ID, it MUST use asymmetric verification
135 // This is the primary defense against algorithm confusion attacks
136 if header.Kid != "" {
137 return false
138 }
139
140 // No kid - check if issuer is whitelisted for HS256
141 // This should only include your own PDS URL(s)
142 return isHS256IssuerWhitelisted(issuer)
143}
144
145// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist
146// Only your own PDS should be in this list - external PDSes should use JWKS
147func isHS256IssuerWhitelisted(issuer string) bool {
148 cfg := getConfig()
149 _, whitelisted := cfg.hs256Issuers[issuer]
150 return whitelisted
151}
152
153// ParseJWT parses a JWT token without verification (Phase 1)
154// Returns the claims if the token is valid JSON and has required fields
155func ParseJWT(tokenString string) (*Claims, error) {
156 // Remove "Bearer " prefix if present
157 tokenString = stripBearerPrefix(tokenString)
158
159 // Parse without verification first to extract claims
160 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
161 token, _, err := parser.ParseUnverified(tokenString, &Claims{})
162 if err != nil {
163 return nil, fmt.Errorf("failed to parse JWT: %w", err)
164 }
165
166 claims, ok := token.Claims.(*Claims)
167 if !ok {
168 return nil, fmt.Errorf("invalid claims type")
169 }
170
171 // Validate required fields
172 if claims.Subject == "" {
173 return nil, fmt.Errorf("missing 'sub' claim (user DID)")
174 }
175
176 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server
177 // If 'iss' is missing, use 'aud' as the authorization server identifier
178 if claims.Issuer == "" {
179 if len(claims.Audience) > 0 {
180 claims.Issuer = claims.Audience[0]
181 } else {
182 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
183 }
184 }
185
186 // Validate claims (even in Phase 1, we need basic validation like expiry)
187 if err := validateClaims(claims); err != nil {
188 return nil, err
189 }
190
191 return claims, nil
192}
193
194// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
195// Fetches the public key from the issuer's JWKS endpoint and validates the signature
196// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET
197//
198// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header,
199// to prevent algorithm confusion attacks where an attacker could re-sign a token
200// with HS256 using a public key as the secret.
201func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
202 // Strip Bearer prefix once at the start
203 tokenString = stripBearerPrefix(tokenString)
204
205 // First parse to get the issuer (needed to determine expected algorithm)
206 claims, err := ParseJWT(tokenString)
207 if err != nil {
208 return nil, err
209 }
210
211 // Parse header to get the claimed algorithm (for validation)
212 header, err := ParseJWTHeader(tokenString)
213 if err != nil {
214 return nil, err
215 }
216
217 // SECURITY: Determine verification method based on token characteristics
218 // 1. Tokens with `kid` MUST use asymmetric verification (supports federation)
219 // 2. Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS)
220 useHS256 := shouldUseHS256(header, claims.Issuer)
221
222 if useHS256 {
223 // Verify token actually claims to use HS256
224 if header.Alg != AlgorithmHS256 {
225 return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg)
226 }
227 return verifyHS256Token(tokenString)
228 }
229
230 // Token must use asymmetric verification
231 // Reject HS256 tokens that don't meet the criteria above
232 if header.Alg == AlgorithmHS256 {
233 if header.Kid != "" {
234 return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification")
235 }
236 return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer)
237 }
238
239 // For RSA/ECDSA, fetch public key from JWKS and verify
240 return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher)
241}
242
243// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret
244func verifyHS256Token(tokenString string) (*Claims, error) {
245 cfg := getConfig()
246 if len(cfg.pdsJWTSecret) == 0 {
247 return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured")
248 }
249
250 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
251 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
252 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
253 }
254 return cfg.pdsJWTSecret, nil
255 })
256 if err != nil {
257 return nil, fmt.Errorf("HS256 verification failed: %w", err)
258 }
259
260 if !token.Valid {
261 return nil, fmt.Errorf("HS256 verification failed: token signature invalid")
262 }
263
264 verifiedClaims, ok := token.Claims.(*Claims)
265 if !ok {
266 return nil, fmt.Errorf("HS256 verification failed: invalid claims type")
267 }
268
269 if err := validateClaims(verifiedClaims); err != nil {
270 return nil, err
271 }
272
273 return verifiedClaims, nil
274}
275
276// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS
277func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
278 publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
279 if err != nil {
280 return nil, fmt.Errorf("failed to fetch public key: %w", err)
281 }
282
283 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
284 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
285 switch token.Method.(type) {
286 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
287 // Valid signing methods for atProto
288 default:
289 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
290 }
291 return publicKey, nil
292 })
293 if err != nil {
294 return nil, fmt.Errorf("asymmetric verification failed: %w", err)
295 }
296
297 if !token.Valid {
298 return nil, fmt.Errorf("asymmetric verification failed: token signature invalid")
299 }
300
301 verifiedClaims, ok := token.Claims.(*Claims)
302 if !ok {
303 return nil, fmt.Errorf("asymmetric verification failed: invalid claims type")
304 }
305
306 if err := validateClaims(verifiedClaims); err != nil {
307 return nil, err
308 }
309
310 return verifiedClaims, nil
311}
312
313// validateClaims performs additional validation on JWT claims
314func validateClaims(claims *Claims) error {
315 now := time.Now()
316
317 // Check expiration
318 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
319 return fmt.Errorf("token has expired")
320 }
321
322 // Check not before
323 if claims.NotBefore != nil && claims.NotBefore.After(now) {
324 return fmt.Errorf("token not yet valid")
325 }
326
327 // Validate DID format in sub claim
328 if !strings.HasPrefix(claims.Subject, "did:") {
329 return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
330 }
331
332 // Validate issuer is either an HTTPS URL or a DID
333 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
334 // In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing
335 isHTTP := strings.HasPrefix(claims.Issuer, "http://")
336 isHTTPS := strings.HasPrefix(claims.Issuer, "https://")
337 isDID := strings.HasPrefix(claims.Issuer, "did:")
338
339 if !isHTTPS && !isDID && !isHTTP {
340 return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer)
341 }
342
343 // In production, reject HTTP issuers (only for non-dev environments)
344 cfg := getConfig()
345 if isHTTP && !cfg.isDevEnv {
346 return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer)
347 }
348
349 // Parse to ensure it's a valid URL
350 if _, err := url.Parse(claims.Issuer); err != nil {
351 return fmt.Errorf("invalid issuer URL: %w", err)
352 }
353
354 // Validate scope if present (lenient: allow empty, but reject wrong scopes)
355 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
356 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
357 }
358
359 return nil
360}
361
// JWKSFetcher resolves the public key that should verify a token, typically by
// consulting the issuer's JWKS endpoint.
type JWKSFetcher interface {
	// FetchPublicKey returns the verification key for token as issued by
	// issuer. The result is untyped so that both *rsa.PublicKey and
	// *ecdsa.PublicKey can be returned.
	FetchPublicKey(ctx context.Context, issuer, token string) (any, error)
}
367
// JWK is a single JSON Web Key (RFC 7517) as published in a JWKS document.
// Only the public-key members needed for RSA ("RSA") and elliptic-curve ("EC")
// signature keys are decoded.
type JWK struct {
	Kid string `json:"kid"` // Key ID
	Kty string `json:"kty"` // Key type: "RSA" or "EC"
	Alg string `json:"alg"` // Intended algorithm, e.g. "RS256" or "ES256" (not checked by ToPublicKey)
	Use string `json:"use"` // Public key use, normally "sig" (not checked by ToPublicKey)

	// RSA members (RFC 7518 §6.3.1): base64url-encoded big-endian integers.
	N string `json:"n,omitempty"` // Modulus
	E string `json:"e,omitempty"` // Public exponent

	// EC members (RFC 7518 §6.2.1).
	Crv string `json:"crv,omitempty"` // Curve: "P-256", "P-384" or "P-521"
	X   string `json:"x,omitempty"`   // base64url x coordinate, full coordinate length
	Y   string `json:"y,omitempty"`   // base64url y coordinate, full coordinate length
}

// ToPublicKey converts the JWK into a *rsa.PublicKey (kty "RSA") or an
// *ecdsa.PublicKey (kty "EC"). Other key types are rejected.
func (j *JWK) ToPublicKey() (interface{}, error) {
	switch j.Kty {
	case "RSA":
		return j.toRSAPublicKey()
	case "EC":
		return j.toECPublicKey()
	default:
		return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
	}
}

// toRSAPublicKey decodes the "n" and "e" members into an *rsa.PublicKey.
//
// It rejects an empty or zero modulus. It also rejects exponents outside
// [2, 2^31-1], the range crypto/rsa accepts when verifying. The exponent is
// decoded as a big.Int first, so an oversized value is rejected instead of
// silently overflowing int (a 9-byte exponent could otherwise wrap to 65537).
// Leading zero bytes in either value are tolerated.
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
	}

	eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("invalid RSA modulus: empty or zero")
	}

	e := new(big.Int).SetBytes(eBytes)
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > 1<<31-1 {
		return nil, fmt.Errorf("invalid RSA exponent: must be between 2 and 2^31-1")
	}

	return &rsa.PublicKey{
		N: n,
		E: int(e.Int64()),
	}, nil
}

// toECPublicKey decodes the "crv", "x" and "y" members into an
// *ecdsa.PublicKey on P-256, P-384 or P-521.
//
// Per RFC 7518 §6.2.1.2 and §6.2.1.3, each coordinate MUST be the full
// coordinate size of the curve (e.g. 32 bytes for P-256). Shorter or longer
// encodings, including missing coordinates, are rejected.
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
	var curve elliptic.Curve
	switch j.Crv {
	case "P-256":
		curve = elliptic.P256()
	case "P-384":
		curve = elliptic.P384()
	case "P-521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
	}

	xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
	}

	yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
	}

	size := (curve.Params().BitSize + 7) / 8
	if len(xBytes) != size || len(yBytes) != size {
		return nil, fmt.Errorf("invalid EC public key: %s coordinates must be %d bytes, got x=%d y=%d",
			j.Crv, size, len(xBytes), len(yBytes))
	}

	return &ecdsa.PublicKey{
		Curve: curve,
		X:     new(big.Int).SetBytes(xBytes),
		Y:     new(big.Int).SetBytes(yBytes),
	}, nil
}
455
456// JWKS represents a JSON Web Key Set
457type JWKS struct {
458 Keys []JWK `json:"keys"`
459}
460
461// FindKeyByID finds a key in the JWKS by its key ID
462func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
463 for _, key := range j.Keys {
464 if key.Kid == kid {
465 return &key, nil
466 }
467 }
468 return nil, fmt.Errorf("key with kid %s not found", kid)
469}
470
471// ExtractKeyID extracts the key ID from a JWT token header
472func ExtractKeyID(tokenString string) (string, error) {
473 header, err := ParseJWTHeader(tokenString)
474 if err != nil {
475 return "", err
476 }
477
478 if header.Kid == "" {
479 return "", fmt.Errorf("missing kid in token header")
480 }
481
482 return header.Kid, nil
483}