A community-based topic aggregation platform built on atproto
1package auth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rsa"
8 "encoding/base64"
9 "encoding/json"
10 "fmt"
11 "math/big"
12 "net/url"
13 "os"
14 "strings"
15 "sync"
16 "time"
17
18 indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
19 "github.com/golang-jwt/jwt/v5"
20)
21
// jwtConfig is a snapshot of the JWT-related environment settings, kept in
// memory so that request handling does not have to read env vars each time.
type jwtConfig struct {
	// hs256Issuers is the set of issuers whitelisted for HS256 tokens.
	hs256Issuers map[string]struct{}

	// pdsJWTSecret is the cached value of PDS_JWT_SECRET.
	pdsJWTSecret []byte

	// isDevEnv is the cached value of IS_DEV_ENV.
	isDevEnv bool
}
28
29var (
30 cachedConfig *jwtConfig
31 configOnce sync.Once
32)
33
34// InitJWTConfig initializes the JWT configuration from environment variables.
35// This should be called once at startup. If not called explicitly, it will be
36// initialized lazily on first use.
37func InitJWTConfig() {
38 configOnce.Do(func() {
39 cachedConfig = &jwtConfig{
40 hs256Issuers: make(map[string]struct{}),
41 isDevEnv: os.Getenv("IS_DEV_ENV") == "true",
42 }
43
44 // Parse HS256_ISSUERS into a set for O(1) lookup
45 if issuers := os.Getenv("HS256_ISSUERS"); issuers != "" {
46 for _, issuer := range strings.Split(issuers, ",") {
47 issuer = strings.TrimSpace(issuer)
48 if issuer != "" {
49 cachedConfig.hs256Issuers[issuer] = struct{}{}
50 }
51 }
52 }
53
54 // Cache PDS_JWT_SECRET
55 if secret := os.Getenv("PDS_JWT_SECRET"); secret != "" {
56 cachedConfig.pdsJWTSecret = []byte(secret)
57 }
58 })
59}
60
61// getConfig returns the cached config, initializing if needed
62func getConfig() *jwtConfig {
63 InitJWTConfig()
64 return cachedConfig
65}
66
67// ResetJWTConfigForTesting resets the cached config to allow re-initialization.
68// This should ONLY be used in tests.
69func ResetJWTConfigForTesting() {
70 cachedConfig = nil
71 configOnce = sync.Once{}
72}
73
// JWT signing algorithm names, as they appear in a token's "alg" header.
const (
	AlgorithmHS256 = "HS256" // HMAC using SHA-256
	AlgorithmRS256 = "RS256" // RSASSA-PKCS1-v1_5 using SHA-256
	AlgorithmES256 = "ES256" // ECDSA using P-256 and SHA-256
)
80
// JWTHeader holds the fields of a decoded JWT header that this package reads.
type JWTHeader struct {
	// Alg is the signing algorithm named by the token ("alg").
	Alg string `json:"alg"`

	// Kid is the key ID ("kid"); empty when the token does not carry one.
	Kid string `json:"kid"`

	// Typ is the optional token type ("typ").
	Typ string `json:"typ,omitempty"`
}
87
88// Claims represents the standard JWT claims we care about
89type Claims struct {
90 jwt.RegisteredClaims
91 // Confirmation claim for DPoP token binding (RFC 9449)
92 // Contains "jkt" (JWK thumbprint) when token is bound to a DPoP key
93 Confirmation map[string]interface{} `json:"cnf,omitempty"`
94 Scope string `json:"scope,omitempty"`
95}
96
// stripBearerPrefix returns the bare token from an Authorization header value.
//
// Leading whitespace is skipped, a leading "Bearer " scheme is removed if
// present, and the remainder is trimmed. The scheme is matched
// case-insensitively because HTTP authentication scheme names are
// case-insensitive, so "bearer x" and "BEARER x" are treated like "Bearer x".
// A value without the scheme is returned trimmed and otherwise unchanged.
func stripBearerPrefix(tokenString string) string {
	const scheme = "Bearer "

	// Only leading whitespace is dropped before matching, so a bare
	// "Bearer " (scheme with no token) still reduces to the empty string.
	rest := strings.TrimLeft(tokenString, " \t\r\n")
	if len(rest) >= len(scheme) && strings.EqualFold(rest[:len(scheme)], scheme) {
		rest = rest[len(scheme):]
	}
	return strings.TrimSpace(rest)
}
102
103// ParseJWTHeader extracts and parses the JWT header from a token string
104// This is a reusable function for getting algorithm and key ID information
105func ParseJWTHeader(tokenString string) (*JWTHeader, error) {
106 tokenString = stripBearerPrefix(tokenString)
107
108 parts := strings.Split(tokenString, ".")
109 if len(parts) != 3 {
110 return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
111 }
112
113 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
114 if err != nil {
115 return nil, fmt.Errorf("failed to decode JWT header: %w", err)
116 }
117
118 var header JWTHeader
119 if err := json.Unmarshal(headerBytes, &header); err != nil {
120 return nil, fmt.Errorf("failed to parse JWT header: %w", err)
121 }
122
123 return &header, nil
124}
125
126// shouldUseHS256 determines if a token should use HS256 verification
127// This prevents algorithm confusion attacks by using multiple signals:
128// 1. If the token has a `kid` (key ID), it MUST use asymmetric verification
129// 2. If no `kid`, only allow HS256 from whitelisted issuers (your own PDS)
130//
131// This approach supports open federation because:
132// - External PDSes publish keys via JWKS and include `kid` in their tokens
133// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid`
134func shouldUseHS256(header *JWTHeader, issuer string) bool {
135 // If token has a key ID, it MUST use asymmetric verification
136 // This is the primary defense against algorithm confusion attacks
137 if header.Kid != "" {
138 return false
139 }
140
141 // No kid - check if issuer is whitelisted for HS256
142 // This should only include your own PDS URL(s)
143 return isHS256IssuerWhitelisted(issuer)
144}
145
146// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist
147// Only your own PDS should be in this list - external PDSes should use JWKS
148func isHS256IssuerWhitelisted(issuer string) bool {
149 cfg := getConfig()
150 _, whitelisted := cfg.hs256Issuers[issuer]
151 return whitelisted
152}
153
154// ParseJWT parses a JWT token without verification (Phase 1)
155// Returns the claims if the token is valid JSON and has required fields
156func ParseJWT(tokenString string) (*Claims, error) {
157 // Remove "Bearer " prefix if present
158 tokenString = stripBearerPrefix(tokenString)
159
160 // Parse without verification first to extract claims
161 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
162 token, _, err := parser.ParseUnverified(tokenString, &Claims{})
163 if err != nil {
164 return nil, fmt.Errorf("failed to parse JWT: %w", err)
165 }
166
167 claims, ok := token.Claims.(*Claims)
168 if !ok {
169 return nil, fmt.Errorf("invalid claims type")
170 }
171
172 // Validate required fields
173 if claims.Subject == "" {
174 return nil, fmt.Errorf("missing 'sub' claim (user DID)")
175 }
176
177 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server
178 // If 'iss' is missing, use 'aud' as the authorization server identifier
179 if claims.Issuer == "" {
180 if len(claims.Audience) > 0 {
181 claims.Issuer = claims.Audience[0]
182 } else {
183 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
184 }
185 }
186
187 // Validate claims (even in Phase 1, we need basic validation like expiry)
188 if err := validateClaims(claims); err != nil {
189 return nil, err
190 }
191
192 return claims, nil
193}
194
195// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
196// Fetches the public key from the issuer's JWKS endpoint and validates the signature
197// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET
198//
199// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header,
200// to prevent algorithm confusion attacks where an attacker could re-sign a token
201// with HS256 using a public key as the secret.
202func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
203 // Strip Bearer prefix once at the start
204 tokenString = stripBearerPrefix(tokenString)
205
206 // First parse to get the issuer (needed to determine expected algorithm)
207 claims, err := ParseJWT(tokenString)
208 if err != nil {
209 return nil, err
210 }
211
212 // Parse header to get the claimed algorithm (for validation)
213 header, err := ParseJWTHeader(tokenString)
214 if err != nil {
215 return nil, err
216 }
217
218 // SECURITY: Determine verification method based on token characteristics
219 // 1. Tokens with `kid` MUST use asymmetric verification (supports federation)
220 // 2. Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS)
221 useHS256 := shouldUseHS256(header, claims.Issuer)
222
223 if useHS256 {
224 // Verify token actually claims to use HS256
225 if header.Alg != AlgorithmHS256 {
226 return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg)
227 }
228 return verifyHS256Token(tokenString)
229 }
230
231 // Token must use asymmetric verification
232 // Reject HS256 tokens that don't meet the criteria above
233 if header.Alg == AlgorithmHS256 {
234 if header.Kid != "" {
235 return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification")
236 }
237 return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer)
238 }
239
240 // For RSA/ECDSA, fetch public key from JWKS and verify
241 return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher)
242}
243
244// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret
245func verifyHS256Token(tokenString string) (*Claims, error) {
246 cfg := getConfig()
247 if len(cfg.pdsJWTSecret) == 0 {
248 return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured")
249 }
250
251 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
252 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
253 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
254 }
255 return cfg.pdsJWTSecret, nil
256 })
257 if err != nil {
258 return nil, fmt.Errorf("HS256 verification failed: %w", err)
259 }
260
261 if !token.Valid {
262 return nil, fmt.Errorf("HS256 verification failed: token signature invalid")
263 }
264
265 verifiedClaims, ok := token.Claims.(*Claims)
266 if !ok {
267 return nil, fmt.Errorf("HS256 verification failed: invalid claims type")
268 }
269
270 if err := validateClaims(verifiedClaims); err != nil {
271 return nil, err
272 }
273
274 return verifiedClaims, nil
275}
276
277// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS.
278// For ES256K (secp256k1), uses indigo's crypto package since golang-jwt doesn't support it.
279func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
280 // Parse header to check algorithm
281 header, err := ParseJWTHeader(tokenString)
282 if err != nil {
283 return nil, fmt.Errorf("failed to parse JWT header: %w", err)
284 }
285
286 // ES256K (secp256k1) requires special handling via indigo's crypto package
287 // golang-jwt doesn't recognize ES256K as a valid signing method
288 if header.Alg == "ES256K" {
289 return verifyES256KToken(ctx, tokenString, issuer, keyFetcher)
290 }
291
292 // For standard algorithms (ES256, ES384, ES512, RS256, etc.), use golang-jwt
293 publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
294 if err != nil {
295 return nil, fmt.Errorf("failed to fetch public key: %w", err)
296 }
297
298 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
299 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
300 switch token.Method.(type) {
301 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
302 // Valid signing methods for atProto
303 default:
304 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
305 }
306 return publicKey, nil
307 })
308 if err != nil {
309 return nil, fmt.Errorf("asymmetric verification failed: %w", err)
310 }
311
312 if !token.Valid {
313 return nil, fmt.Errorf("asymmetric verification failed: token signature invalid")
314 }
315
316 verifiedClaims, ok := token.Claims.(*Claims)
317 if !ok {
318 return nil, fmt.Errorf("asymmetric verification failed: invalid claims type")
319 }
320
321 if err := validateClaims(verifiedClaims); err != nil {
322 return nil, err
323 }
324
325 return verifiedClaims, nil
326}
327
328// verifyES256KToken verifies a JWT signed with ES256K (secp256k1) using indigo's crypto package.
329// This is necessary because golang-jwt doesn't support ES256K as a signing method.
330func verifyES256KToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
331 // Fetch the public key - for ES256K, the fetcher returns a JWK map or indigo PublicKey
332 keyData, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
333 if err != nil {
334 return nil, fmt.Errorf("failed to fetch public key for ES256K: %w", err)
335 }
336
337 // Convert to indigo PublicKey based on what the fetcher returned
338 var pubKey indigoCrypto.PublicKey
339 switch k := keyData.(type) {
340 case indigoCrypto.PublicKey:
341 // Already an indigo PublicKey (from DIDKeyFetcher or updated JWKSFetcher)
342 pubKey = k
343 case map[string]interface{}:
344 // Raw JWK map - parse with indigo
345 pubKey, err = parseJWKMapToIndigoPublicKey(k)
346 if err != nil {
347 return nil, fmt.Errorf("failed to parse ES256K JWK: %w", err)
348 }
349 default:
350 return nil, fmt.Errorf("ES256K verification requires indigo PublicKey or JWK map, got %T", keyData)
351 }
352
353 // Verify signature using indigo
354 if err := verifyJWTSignatureWithIndigoKey(tokenString, pubKey); err != nil {
355 return nil, fmt.Errorf("ES256K signature verification failed: %w", err)
356 }
357
358 // Parse claims (signature already verified)
359 claims, err := parseJWTClaimsManually(tokenString)
360 if err != nil {
361 return nil, fmt.Errorf("failed to parse ES256K JWT claims: %w", err)
362 }
363
364 if err := validateClaims(claims); err != nil {
365 return nil, err
366 }
367
368 return claims, nil
369}
370
371// parseJWKMapToIndigoPublicKey converts a JWK map to an indigo PublicKey.
372// This uses indigo's crypto package which supports all atProto curves including secp256k1.
373func parseJWKMapToIndigoPublicKey(jwkMap map[string]interface{}) (indigoCrypto.PublicKey, error) {
374 // Convert map to JSON bytes for indigo's parser
375 jwkBytes, err := json.Marshal(jwkMap)
376 if err != nil {
377 return nil, fmt.Errorf("failed to serialize JWK: %w", err)
378 }
379
380 // Parse with indigo's crypto package - supports all atProto curves
381 pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
382 if err != nil {
383 return nil, fmt.Errorf("failed to parse JWK with indigo: %w", err)
384 }
385
386 return pubKey, nil
387}
388
389// verifyJWTSignatureWithIndigoKey verifies a JWT signature using indigo's crypto package.
390// This works for all ECDSA algorithms including ES256K (secp256k1).
391func verifyJWTSignatureWithIndigoKey(tokenString string, pubKey indigoCrypto.PublicKey) error {
392 parts := strings.Split(tokenString, ".")
393 if len(parts) != 3 {
394 return fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
395 }
396
397 // The signing input is "header.payload" (without decoding)
398 signingInput := parts[0] + "." + parts[1]
399
400 // Decode the signature from base64url
401 signature, err := base64.RawURLEncoding.DecodeString(parts[2])
402 if err != nil {
403 return fmt.Errorf("failed to decode JWT signature: %w", err)
404 }
405
406 // Use indigo's verification - HashAndVerifyLenient handles hashing internally
407 // and accepts both low-S and high-S signatures for maximum compatibility
408 if err := pubKey.HashAndVerifyLenient([]byte(signingInput), signature); err != nil {
409 return fmt.Errorf("signature verification failed: %w", err)
410 }
411
412 return nil
413}
414
415// parseJWTClaimsManually parses JWT claims without using golang-jwt.
416// This is used for ES256K tokens where golang-jwt would reject the algorithm.
417func parseJWTClaimsManually(tokenString string) (*Claims, error) {
418 parts := strings.Split(tokenString, ".")
419 if len(parts) != 3 {
420 return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
421 }
422
423 // Decode claims
424 claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
425 if err != nil {
426 return nil, fmt.Errorf("failed to decode JWT claims: %w", err)
427 }
428
429 // Parse into raw map first
430 var rawClaims map[string]interface{}
431 if err := json.Unmarshal(claimsBytes, &rawClaims); err != nil {
432 return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
433 }
434
435 // Build Claims struct
436 claims := &Claims{}
437
438 // Extract sub (subject/DID)
439 if sub, ok := rawClaims["sub"].(string); ok {
440 claims.Subject = sub
441 }
442
443 // Extract iss (issuer)
444 if iss, ok := rawClaims["iss"].(string); ok {
445 claims.Issuer = iss
446 }
447
448 // Extract aud (audience) - can be string or array
449 switch aud := rawClaims["aud"].(type) {
450 case string:
451 claims.Audience = jwt.ClaimStrings{aud}
452 case []interface{}:
453 for _, a := range aud {
454 if s, ok := a.(string); ok {
455 claims.Audience = append(claims.Audience, s)
456 }
457 }
458 }
459
460 // Extract exp (expiration)
461 if exp, ok := rawClaims["exp"].(float64); ok {
462 t := time.Unix(int64(exp), 0)
463 claims.ExpiresAt = jwt.NewNumericDate(t)
464 }
465
466 // Extract iat (issued at)
467 if iat, ok := rawClaims["iat"].(float64); ok {
468 t := time.Unix(int64(iat), 0)
469 claims.IssuedAt = jwt.NewNumericDate(t)
470 }
471
472 // Extract nbf (not before)
473 if nbf, ok := rawClaims["nbf"].(float64); ok {
474 t := time.Unix(int64(nbf), 0)
475 claims.NotBefore = jwt.NewNumericDate(t)
476 }
477
478 // Extract jti (JWT ID)
479 if jti, ok := rawClaims["jti"].(string); ok {
480 claims.ID = jti
481 }
482
483 // Extract scope
484 if scope, ok := rawClaims["scope"].(string); ok {
485 claims.Scope = scope
486 }
487
488 // Extract cnf (confirmation) for DPoP binding
489 if cnf, ok := rawClaims["cnf"].(map[string]interface{}); ok {
490 claims.Confirmation = cnf
491 }
492
493 return claims, nil
494}
495
496// validateClaims performs additional validation on JWT claims
497func validateClaims(claims *Claims) error {
498 now := time.Now()
499
500 // Check expiration
501 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
502 return fmt.Errorf("token has expired")
503 }
504
505 // Check not before
506 if claims.NotBefore != nil && claims.NotBefore.After(now) {
507 return fmt.Errorf("token not yet valid")
508 }
509
510 // Validate DID format in sub claim
511 if !strings.HasPrefix(claims.Subject, "did:") {
512 return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
513 }
514
515 // Validate issuer is either an HTTPS URL or a DID
516 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
517 // In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing
518 isHTTP := strings.HasPrefix(claims.Issuer, "http://")
519 isHTTPS := strings.HasPrefix(claims.Issuer, "https://")
520 isDID := strings.HasPrefix(claims.Issuer, "did:")
521
522 if !isHTTPS && !isDID && !isHTTP {
523 return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer)
524 }
525
526 // In production, reject HTTP issuers (only for non-dev environments)
527 cfg := getConfig()
528 if isHTTP && !cfg.isDevEnv {
529 return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer)
530 }
531
532 // Parse to ensure it's a valid URL
533 if _, err := url.Parse(claims.Issuer); err != nil {
534 return fmt.Errorf("invalid issuer URL: %w", err)
535 }
536
537 // Validate scope if present (lenient: allow empty, but reject wrong scopes)
538 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
539 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
540 }
541
542 return nil
543}
544
// JWKSFetcher is the abstraction used to obtain the public key for verifying a
// token, for example from a JWKS endpoint.
type JWKSFetcher interface {
	// FetchPublicKey returns the key to verify token with, given its issuer.
	// The result is an interface{} so that both RSA and ECDSA keys can be
	// returned.
	FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error)
}
550
// JWK represents a JSON Web Key from a JWKS endpoint.
// It supports both RSA and EC (ECDSA) keys.
type JWK struct {
	Kid string `json:"kid"` // Key ID
	Kty string `json:"kty"` // Key type ("RSA" or "EC")
	Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
	Use string `json:"use"` // Public key use (should be "sig" for signatures)

	// RSA fields
	N string `json:"n,omitempty"` // RSA modulus
	E string `json:"e,omitempty"` // RSA exponent

	// EC fields
	Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
	X   string `json:"x,omitempty"`   // EC x coordinate
	Y   string `json:"y,omitempty"`   // EC y coordinate
}

// ToPublicKey converts a JWK to a public key (RSA, ECDSA, or a raw map for
// secp256k1).
//
// Returns:
//   - *rsa.PublicKey for RSA keys
//   - *ecdsa.PublicKey for NIST EC curves (P-256, P-384, P-521)
//   - map[string]interface{} for secp256k1 (ES256K), to be parsed by indigo
//
// On failure the returned key is an untyped nil, so `key == nil` holds for
// callers whenever err is non-nil.
func (j *JWK) ToPublicKey() (interface{}, error) {
	switch j.Kty {
	case "RSA":
		key, err := j.toRSAPublicKey()
		if err != nil {
			// Return a literal nil: a nil *rsa.PublicKey stored in the
			// interface would compare as non-nil.
			return nil, err
		}
		return key, nil
	case "EC":
		// secp256k1 is not a standard-library curve; pass the raw JWK on.
		if j.Crv == "secp256k1" {
			return j.toJWKMap(), nil
		}
		key, err := j.toECPublicKey()
		if err != nil {
			return nil, err
		}
		return key, nil
	default:
		return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
	}
}

// toJWKMap converts the JWK struct to a map for indigo parsing. "kty" is
// always present; every other field is included only when it is non-empty.
func (j *JWK) toJWKMap() map[string]interface{} {
	m := map[string]interface{}{"kty": j.Kty}

	optional := [...]struct{ name, value string }{
		{"kid", j.Kid},
		{"alg", j.Alg},
		{"use", j.Use},
		// RSA fields
		{"n", j.N},
		{"e", j.E},
		// EC fields
		{"crv", j.Crv},
		{"x", j.X},
		{"y", j.Y},
	}
	for _, field := range optional {
		if field.value != "" {
			m[field.name] = field.value
		}
	}
	return m
}

// toRSAPublicKey converts a JWK to an RSA public key.
//
// The modulus and exponent are decoded from unpadded base64url. The key is
// rejected when the modulus is missing or zero, when the exponent does not fit
// in 31 bits (the largest public exponent crypto/rsa accepts, and a guard
// against silent int overflow), or when the exponent is smaller than 2.
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
	}

	modulus := new(big.Int).SetBytes(nBytes)
	if modulus.Sign() == 0 {
		return nil, fmt.Errorf("invalid RSA key: modulus is missing or zero")
	}

	// Decode the exponent as a big integer first so that an over-long value
	// is detected instead of wrapping around.
	exponent := new(big.Int).SetBytes(eBytes)
	if exponent.BitLen() > 31 {
		return nil, fmt.Errorf("invalid RSA key: exponent does not fit in 31 bits")
	}
	e := int(exponent.Int64())
	if e < 2 {
		return nil, fmt.Errorf("invalid RSA key: exponent %d is too small", e)
	}

	return &rsa.PublicKey{N: modulus, E: e}, nil
}

// toECPublicKey converts a JWK to an ECDSA public key on one of the NIST
// curves P-256, P-384 or P-521.
//
// The coordinates are decoded from unpadded base64url, and the resulting point
// is rejected unless it lies on the named curve, so a malformed or malicious
// JWKS entry fails here rather than during signature verification.
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
	var curve elliptic.Curve
	switch j.Crv {
	case "P-256":
		curve = elliptic.P256()
	case "P-384":
		curve = elliptic.P384()
	case "P-521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
	}

	xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
	}
	yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
	}

	pub := &ecdsa.PublicKey{
		Curve: curve,
		X:     new(big.Int).SetBytes(xBytes),
		Y:     new(big.Int).SetBytes(yBytes),
	}

	// ECDH fails for points that are not on the curve (including missing
	// coordinates); its result is not needed beyond that validation.
	if _, err := pub.ECDH(); err != nil {
		return nil, fmt.Errorf("invalid EC public key: %w", err)
	}
	return pub, nil
}
681
682// JWKS represents a JSON Web Key Set
683type JWKS struct {
684 Keys []JWK `json:"keys"`
685}
686
687// FindKeyByID finds a key in the JWKS by its key ID
688func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
689 for _, key := range j.Keys {
690 if key.Kid == kid {
691 return &key, nil
692 }
693 }
694 return nil, fmt.Errorf("key with kid %s not found", kid)
695}
696
697// ExtractKeyID extracts the key ID from a JWT token header
698func ExtractKeyID(tokenString string) (string, error) {
699 header, err := ParseJWTHeader(tokenString)
700 if err != nil {
701 return "", err
702 }
703
704 if header.Kid == "" {
705 return "", fmt.Errorf("missing kid in token header")
706 }
707
708 return header.Kid, nil
709}