A community-based topic aggregation platform built on atproto.
1package auth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "crypto/rsa" 8 "encoding/base64" 9 "encoding/json" 10 "fmt" 11 "math/big" 12 "net/url" 13 "os" 14 "strings" 15 "time" 16 17 "github.com/golang-jwt/jwt/v5" 18) 19 20// Claims represents the standard JWT claims we care about 21type Claims struct { 22 jwt.RegisteredClaims 23 Scope string `json:"scope,omitempty"` 24} 25 26// ParseJWT parses a JWT token without verification (Phase 1) 27// Returns the claims if the token is valid JSON and has required fields 28func ParseJWT(tokenString string) (*Claims, error) { 29 // Remove "Bearer " prefix if present 30 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 31 tokenString = strings.TrimSpace(tokenString) 32 33 // Parse without verification first to extract claims 34 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 35 token, _, err := parser.ParseUnverified(tokenString, &Claims{}) 36 if err != nil { 37 return nil, fmt.Errorf("failed to parse JWT: %w", err) 38 } 39 40 claims, ok := token.Claims.(*Claims) 41 if !ok { 42 return nil, fmt.Errorf("invalid claims type") 43 } 44 45 // Validate required fields 46 if claims.Subject == "" { 47 return nil, fmt.Errorf("missing 'sub' claim (user DID)") 48 } 49 50 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server 51 // If 'iss' is missing, use 'aud' as the authorization server identifier 52 if claims.Issuer == "" { 53 if len(claims.Audience) > 0 { 54 claims.Issuer = claims.Audience[0] 55 } else { 56 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)") 57 } 58 } 59 60 // Validate claims (even in Phase 1, we need basic validation like expiry) 61 if err := validateClaims(claims); err != nil { 62 return nil, err 63 } 64 65 return claims, nil 66} 67 68// VerifyJWT verifies a JWT token's signature and claims (Phase 2) 69// Fetches the public key from the issuer's JWKS endpoint and validates the signature 70func VerifyJWT(ctx context.Context, 
tokenString string, keyFetcher JWKSFetcher) (*Claims, error) { 71 // First parse to get the issuer 72 claims, err := ParseJWT(tokenString) 73 if err != nil { 74 return nil, err 75 } 76 77 // Fetch the public key from the issuer 78 publicKey, err := keyFetcher.FetchPublicKey(ctx, claims.Issuer, tokenString) 79 if err != nil { 80 return nil, fmt.Errorf("failed to fetch public key: %w", err) 81 } 82 83 // Now parse and verify with the public key 84 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 85 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { 86 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily) 87 switch token.Method.(type) { 88 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA: 89 // Valid signing methods for atProto 90 default: 91 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 92 } 93 return publicKey, nil 94 }) 95 if err != nil { 96 return nil, fmt.Errorf("failed to verify JWT: %w", err) 97 } 98 99 if !token.Valid { 100 return nil, fmt.Errorf("token is invalid") 101 } 102 103 verifiedClaims, ok := token.Claims.(*Claims) 104 if !ok { 105 return nil, fmt.Errorf("invalid claims type after verification") 106 } 107 108 // Additional validation 109 if err := validateClaims(verifiedClaims); err != nil { 110 return nil, err 111 } 112 113 return verifiedClaims, nil 114} 115 116// validateClaims performs additional validation on JWT claims 117func validateClaims(claims *Claims) error { 118 now := time.Now() 119 120 // Check expiration 121 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) { 122 return fmt.Errorf("token has expired") 123 } 124 125 // Check not before 126 if claims.NotBefore != nil && claims.NotBefore.After(now) { 127 return fmt.Errorf("token not yet valid") 128 } 129 130 // Validate DID format in sub claim 131 if !strings.HasPrefix(claims.Subject, "did:") { 132 return fmt.Errorf("invalid DID format in 'sub' claim: 
%s", claims.Subject) 133 } 134 135 // Validate issuer is either an HTTPS URL or a DID 136 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers 137 // In dev mode (IS_DEV_ENV=true), allow HTTP for local PDS testing 138 isHTTP := strings.HasPrefix(claims.Issuer, "http://") 139 isHTTPS := strings.HasPrefix(claims.Issuer, "https://") 140 isDID := strings.HasPrefix(claims.Issuer, "did:") 141 142 if !isHTTPS && !isDID && !isHTTP { 143 return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer) 144 } 145 146 // In production, reject HTTP issuers (only for non-dev environments) 147 // Check IS_DEV_ENV environment variable 148 if isHTTP && os.Getenv("IS_DEV_ENV") != "true" { 149 return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer) 150 } 151 152 // Parse to ensure it's a valid URL 153 if _, err := url.Parse(claims.Issuer); err != nil { 154 return fmt.Errorf("invalid issuer URL: %w", err) 155 } 156 157 // Validate scope if present (lenient: allow empty, but reject wrong scopes) 158 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") { 159 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope) 160 } 161 162 return nil 163} 164 165// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints 166// Returns interface{} to support both RSA and ECDSA keys 167type JWKSFetcher interface { 168 FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) 169} 170 171// JWK represents a JSON Web Key from a JWKS endpoint 172// Supports both RSA and EC (ECDSA) keys 173type JWK struct { 174 Kid string `json:"kid"` // Key ID 175 Kty string `json:"kty"` // Key type ("RSA" or "EC") 176 Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256") 177 Use string `json:"use"` // Public key use (should be "sig" for signatures) 178 // RSA fields 179 N string `json:"n,omitempty"` // RSA modulus 180 E string 
`json:"e,omitempty"` // RSA exponent 181 // EC fields 182 Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256") 183 X string `json:"x,omitempty"` // EC x coordinate 184 Y string `json:"y,omitempty"` // EC y coordinate 185} 186 187// ToPublicKey converts a JWK to a public key (RSA or ECDSA) 188func (j *JWK) ToPublicKey() (interface{}, error) { 189 switch j.Kty { 190 case "RSA": 191 return j.toRSAPublicKey() 192 case "EC": 193 return j.toECPublicKey() 194 default: 195 return nil, fmt.Errorf("unsupported key type: %s", j.Kty) 196 } 197} 198 199// toRSAPublicKey converts a JWK to an RSA public key 200func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { 201 // Decode modulus 202 nBytes, err := base64.RawURLEncoding.DecodeString(j.N) 203 if err != nil { 204 return nil, fmt.Errorf("failed to decode RSA modulus: %w", err) 205 } 206 207 // Decode exponent 208 eBytes, err := base64.RawURLEncoding.DecodeString(j.E) 209 if err != nil { 210 return nil, fmt.Errorf("failed to decode RSA exponent: %w", err) 211 } 212 213 // Convert exponent to int 214 var eInt int 215 for _, b := range eBytes { 216 eInt = eInt*256 + int(b) 217 } 218 219 return &rsa.PublicKey{ 220 N: new(big.Int).SetBytes(nBytes), 221 E: eInt, 222 }, nil 223} 224 225// toECPublicKey converts a JWK to an ECDSA public key 226func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) { 227 // Determine curve 228 var curve elliptic.Curve 229 switch j.Crv { 230 case "P-256": 231 curve = elliptic.P256() 232 case "P-384": 233 curve = elliptic.P384() 234 case "P-521": 235 curve = elliptic.P521() 236 default: 237 return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv) 238 } 239 240 // Decode X coordinate 241 xBytes, err := base64.RawURLEncoding.DecodeString(j.X) 242 if err != nil { 243 return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err) 244 } 245 246 // Decode Y coordinate 247 yBytes, err := base64.RawURLEncoding.DecodeString(j.Y) 248 if err != nil { 249 return nil, fmt.Errorf("failed to decode EC 
y coordinate: %w", err) 250 } 251 252 return &ecdsa.PublicKey{ 253 Curve: curve, 254 X: new(big.Int).SetBytes(xBytes), 255 Y: new(big.Int).SetBytes(yBytes), 256 }, nil 257} 258 259// JWKS represents a JSON Web Key Set 260type JWKS struct { 261 Keys []JWK `json:"keys"` 262} 263 264// FindKeyByID finds a key in the JWKS by its key ID 265func (j *JWKS) FindKeyByID(kid string) (*JWK, error) { 266 for _, key := range j.Keys { 267 if key.Kid == kid { 268 return &key, nil 269 } 270 } 271 return nil, fmt.Errorf("key with kid %s not found", kid) 272} 273 274// ExtractKeyID extracts the key ID from a JWT token header 275func ExtractKeyID(tokenString string) (string, error) { 276 tokenString = strings.TrimPrefix(tokenString, "Bearer ") 277 parts := strings.Split(tokenString, ".") 278 if len(parts) != 3 { 279 return "", fmt.Errorf("invalid JWT format") 280 } 281 282 // Decode header 283 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) 284 if err != nil { 285 return "", fmt.Errorf("failed to decode header: %w", err) 286 } 287 288 var header struct { 289 Kid string `json:"kid"` 290 } 291 if err := json.Unmarshal(headerBytes, &header); err != nil { 292 return "", fmt.Errorf("failed to unmarshal header: %w", err) 293 } 294 295 if header.Kid == "" { 296 return "", fmt.Errorf("missing kid in token header") 297 } 298 299 return header.Kid, nil 300}