A community-based topic aggregation platform built on atproto
1package auth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rsa" 8 "encoding/base64" 9 "encoding/json" 10 "fmt" 11 "math/big" 12 "net/url" 13 "strings" 14 "time" 15 16 "github.com/golang-jwt/jwt/v5" 17) 18 19// Claims represents the standard JWT claims we care about 20type Claims struct { 21 jwt.RegisteredClaims 22 Scope string `json:"scope,omitempty"` 23} 24 25// ParseJWT parses a JWT token without verification (Phase 1) 26// Returns the claims if the token is valid JSON and has required fields 27func ParseJWT(tokenString string) (*Claims, error) { 28 // Remove "Bearer " prefix if present 29 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 30 tokenString = strings.TrimSpace(tokenString) 31 32 // Parse without verification first to extract claims 33 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 34 token, _, err := parser.ParseUnverified(tokenString, &Claims{}) 35 if err != nil { 36 return nil, fmt.Errorf("failed to parse JWT: %w", err) 37 } 38 39 claims, ok := token.Claims.(*Claims) 40 if !ok { 41 return nil, fmt.Errorf("invalid claims type") 42 } 43 44 // Validate required fields 45 if claims.Subject == "" { 46 return nil, fmt.Errorf("missing 'sub' claim (user DID)") 47 } 48 49 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server 50 // If 'iss' is missing, use 'aud' as the authorization server identifier 51 if claims.Issuer == "" { 52 if len(claims.Audience) > 0 { 53 claims.Issuer = claims.Audience[0] 54 } else { 55 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)") 56 } 57 } 58 59 // Validate claims (even in Phase 1, we need basic validation like expiry) 60 if err := validateClaims(claims); err != nil { 61 return nil, err 62 } 63 64 return claims, nil 65} 66 67// VerifyJWT verifies a JWT token's signature and claims (Phase 2) 68// Fetches the public key from the issuer's JWKS endpoint and validates the signature 69func VerifyJWT(ctx context.Context, tokenString 
string, keyFetcher JWKSFetcher) (*Claims, error) { 70 // First parse to get the issuer 71 claims, err := ParseJWT(tokenString) 72 if err != nil { 73 return nil, err 74 } 75 76 // Fetch the public key from the issuer 77 publicKey, err := keyFetcher.FetchPublicKey(ctx, claims.Issuer, tokenString) 78 if err != nil { 79 return nil, fmt.Errorf("failed to fetch public key: %w", err) 80 } 81 82 // Now parse and verify with the public key 83 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 84 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { 85 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily) 86 switch token.Method.(type) { 87 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA: 88 // Valid signing methods for atProto 89 default: 90 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 91 } 92 return publicKey, nil 93 }) 94 if err != nil { 95 return nil, fmt.Errorf("failed to verify JWT: %w", err) 96 } 97 98 if !token.Valid { 99 return nil, fmt.Errorf("token is invalid") 100 } 101 102 verifiedClaims, ok := token.Claims.(*Claims) 103 if !ok { 104 return nil, fmt.Errorf("invalid claims type after verification") 105 } 106 107 // Additional validation 108 if err := validateClaims(verifiedClaims); err != nil { 109 return nil, err 110 } 111 112 return verifiedClaims, nil 113} 114 115// validateClaims performs additional validation on JWT claims 116func validateClaims(claims *Claims) error { 117 now := time.Now() 118 119 // Check expiration 120 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) { 121 return fmt.Errorf("token has expired") 122 } 123 124 // Check not before 125 if claims.NotBefore != nil && claims.NotBefore.After(now) { 126 return fmt.Errorf("token not yet valid") 127 } 128 129 // Validate DID format in sub claim 130 if !strings.HasPrefix(claims.Subject, "did:") { 131 return fmt.Errorf("invalid DID format in 'sub' claim: %s", 
claims.Subject) 132 } 133 134 // Validate issuer is either an HTTPS URL or a DID 135 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers 136 if !strings.HasPrefix(claims.Issuer, "https://") && !strings.HasPrefix(claims.Issuer, "did:") { 137 return fmt.Errorf("issuer must be HTTPS URL or DID, got: %s", claims.Issuer) 138 } 139 140 // Parse to ensure it's a valid URL 141 if _, err := url.Parse(claims.Issuer); err != nil { 142 return fmt.Errorf("invalid issuer URL: %w", err) 143 } 144 145 // Validate scope if present (lenient: allow empty, but reject wrong scopes) 146 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") { 147 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope) 148 } 149 150 return nil 151} 152 153// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints 154// Returns interface{} to support both RSA and ECDSA keys 155type JWKSFetcher interface { 156 FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) 157} 158 159// JWK represents a JSON Web Key from a JWKS endpoint 160// Supports both RSA and EC (ECDSA) keys 161type JWK struct { 162 Kid string `json:"kid"` // Key ID 163 Kty string `json:"kty"` // Key type ("RSA" or "EC") 164 Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256") 165 Use string `json:"use"` // Public key use (should be "sig" for signatures) 166 // RSA fields 167 N string `json:"n,omitempty"` // RSA modulus 168 E string `json:"e,omitempty"` // RSA exponent 169 // EC fields 170 Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256") 171 X string `json:"x,omitempty"` // EC x coordinate 172 Y string `json:"y,omitempty"` // EC y coordinate 173} 174 175// ToPublicKey converts a JWK to a public key (RSA or ECDSA) 176func (j *JWK) ToPublicKey() (interface{}, error) { 177 switch j.Kty { 178 case "RSA": 179 return j.toRSAPublicKey() 180 case "EC": 181 return j.toECPublicKey() 182 default: 183 return nil, 
fmt.Errorf("unsupported key type: %s", j.Kty) 184 } 185} 186 187// toRSAPublicKey converts a JWK to an RSA public key 188func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { 189 // Decode modulus 190 nBytes, err := base64.RawURLEncoding.DecodeString(j.N) 191 if err != nil { 192 return nil, fmt.Errorf("failed to decode RSA modulus: %w", err) 193 } 194 195 // Decode exponent 196 eBytes, err := base64.RawURLEncoding.DecodeString(j.E) 197 if err != nil { 198 return nil, fmt.Errorf("failed to decode RSA exponent: %w", err) 199 } 200 201 // Convert exponent to int 202 var eInt int 203 for _, b := range eBytes { 204 eInt = eInt*256 + int(b) 205 } 206 207 return &rsa.PublicKey{ 208 N: new(big.Int).SetBytes(nBytes), 209 E: eInt, 210 }, nil 211} 212 213// toECPublicKey converts a JWK to an ECDSA public key 214func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) { 215 // Determine curve 216 var curve elliptic.Curve 217 switch j.Crv { 218 case "P-256": 219 curve = elliptic.P256() 220 case "P-384": 221 curve = elliptic.P384() 222 case "P-521": 223 curve = elliptic.P521() 224 default: 225 return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv) 226 } 227 228 // Decode X coordinate 229 xBytes, err := base64.RawURLEncoding.DecodeString(j.X) 230 if err != nil { 231 return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err) 232 } 233 234 // Decode Y coordinate 235 yBytes, err := base64.RawURLEncoding.DecodeString(j.Y) 236 if err != nil { 237 return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err) 238 } 239 240 return &ecdsa.PublicKey{ 241 Curve: curve, 242 X: new(big.Int).SetBytes(xBytes), 243 Y: new(big.Int).SetBytes(yBytes), 244 }, nil 245} 246 247// JWKS represents a JSON Web Key Set 248type JWKS struct { 249 Keys []JWK `json:"keys"` 250} 251 252// FindKeyByID finds a key in the JWKS by its key ID 253func (j *JWKS) FindKeyByID(kid string) (*JWK, error) { 254 for _, key := range j.Keys { 255 if key.Kid == kid { 256 return &key, nil 257 } 258 } 259 
return nil, fmt.Errorf("key with kid %s not found", kid) 260} 261 262// ExtractKeyID extracts the key ID from a JWT token header 263func ExtractKeyID(tokenString string) (string, error) { 264 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 265 parts := strings.Split(tokenString, ".") 266 if len(parts) != 3 { 267 return "", fmt.Errorf("invalid JWT format") 268 } 269 270 // Decode header 271 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 272 if err != nil { 273 return "", fmt.Errorf("failed to decode header: %w", err) 274 } 275 276 var header struct { 277 Kid string `json:"kid"` 278 } 279 if err := json.Unmarshal(headerBytes, &header); err != nil { 280 return "", fmt.Errorf("failed to unmarshal header: %w", err) 281 } 282 283 if header.Kid == "" { 284 return "", fmt.Errorf("missing kid in token header") 285 } 286 287 return header.Kid, nil 288}