A community-based topic aggregation platform built on atproto.
1package auth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rsa"
8 "encoding/base64"
9 "encoding/json"
10 "fmt"
11 "math/big"
12 "net/url"
13 "os"
14 "strings"
15 "time"
16
17 "github.com/golang-jwt/jwt/v5"
18)
19
20// Claims represents the standard JWT claims we care about
21type Claims struct {
22 jwt.RegisteredClaims
23 Scope string `json:"scope,omitempty"`
24}
25
// ParseJWT decodes a JWT WITHOUT checking its signature (Phase 1).
//
// An optional "Bearer " prefix and surrounding whitespace are removed before
// decoding. The token must carry a 'sub' claim. If 'iss' is absent, the first
// 'aud' entry stands in for it (some atProto PDSes put the authorization
// server there); a token with neither is rejected. The resulting claims are
// then run through validateClaims (expiry, not-before, DID and issuer format,
// scope).
//
// Nothing here is cryptographically verified, so every returned claim is
// caller-controlled. Use VerifyJWT wherever identity must be trusted.
func ParseJWT(tokenString string) (*Claims, error) {
	raw := strings.TrimSpace(strings.TrimPrefix(tokenString, "Bearer "))

	// Registered-claim validation is disabled here; validateClaims performs
	// the checks this package relies on once the issuer has been resolved.
	unverified := jwt.NewParser(jwt.WithoutClaimsValidation())
	token, _, err := unverified.ParseUnverified(raw, &Claims{})
	if err != nil {
		return nil, fmt.Errorf("failed to parse JWT: %w", err)
	}

	claims, ok := token.Claims.(*Claims)
	if !ok {
		return nil, fmt.Errorf("invalid claims type")
	}

	switch {
	case claims.Subject == "":
		return nil, fmt.Errorf("missing 'sub' claim (user DID)")
	case claims.Issuer != "":
		// An explicit issuer is used as-is.
	case len(claims.Audience) == 0:
		return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
	default:
		claims.Issuer = claims.Audience[0]
	}

	if err = validateClaims(claims); err != nil {
		return nil, err
	}
	return claims, nil
}
67
// VerifyJWT verifies a JWT's signature and claims (Phase 2).
//
// Steps:
//  1. Parse the token unverified (via ParseJWT) to learn its issuer.
//  2. Ask keyFetcher for the matching public key; it receives the token string
//     exactly as passed in.
//  3. Re-parse the token with signature verification. Only RSA and ECDSA
//     signing methods are accepted.
//  4. Apply the same 'aud'→'iss' fallback as ParseJWT to the verified claims,
//     then check them with validateClaims.
//
// NOTE(review): the issuer used to select the key comes from the *unverified*
// token. Unless keyFetcher confirms that this issuer is authoritative for the
// subject DID, a caller could sign tokens with a key served from an issuer
// they control. Confirm what the JWKSFetcher implementation enforces.
func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
	// First parse to get the issuer
	claims, err := ParseJWT(tokenString)
	if err != nil {
		return nil, err
	}

	publicKey, err := keyFetcher.FetchPublicKey(ctx, claims.Issuer, tokenString)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch public key: %w", err)
	}

	// Normalize exactly as ParseJWT does. Trimming only the prefix would reject
	// tokens with surrounding whitespace that ParseJWT already accepted.
	tokenString = strings.TrimSpace(strings.TrimPrefix(tokenString, "Bearer "))
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(t *jwt.Token) (interface{}, error) {
		// atProto uses ES256 primarily; RSA is also accepted.
		switch t.Method.(type) {
		case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
		default:
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return publicKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to verify JWT: %w", err)
	}

	if !token.Valid {
		return nil, fmt.Errorf("token is invalid")
	}

	verifiedClaims, ok := token.Claims.(*Claims)
	if !ok {
		return nil, fmt.Errorf("invalid claims type after verification")
	}

	// ParseWithClaims decodes a fresh Claims value, so the 'aud' fallback that
	// ParseJWT applied is absent here. Without re-applying it, a token lacking
	// 'iss' passes ParseJWT but fails validateClaims below with an empty issuer.
	if verifiedClaims.Issuer == "" && len(verifiedClaims.Audience) > 0 {
		verifiedClaims.Issuer = verifiedClaims.Audience[0]
	}

	if err := validateClaims(verifiedClaims); err != nil {
		return nil, err
	}

	return verifiedClaims, nil
}
115
// validateClaims applies the claim checks shared by ParseJWT and VerifyJWT.
//
// It rejects a token when:
//   - 'exp' is present and in the past, or 'nbf' is present and in the future
//     (no clock-skew leeway is applied);
//   - 'sub' is not a DID of the form "did:<method>:<id>";
//   - 'iss' is neither such a DID nor an http(s) URL with a host;
//   - 'iss' is a plain HTTP URL while IS_DEV_ENV is not "true";
//   - 'scope' is non-empty but does not contain "atproto".
func validateClaims(claims *Claims) error {
	now := time.Now()

	if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
		return fmt.Errorf("token has expired")
	}
	if claims.NotBefore != nil && claims.NotBefore.After(now) {
		return fmt.Errorf("token not yet valid")
	}

	// A bare "did:" prefix is not an identity. Require a non-empty method and
	// method-specific identifier (e.g. "did:plc:abc", "did:web:example.com").
	hasDIDSyntax := func(s string) bool {
		parts := strings.SplitN(s, ":", 3)
		return len(parts) == 3 && parts[0] == "did" && parts[1] != "" && parts[2] != ""
	}

	if !hasDIDSyntax(claims.Subject) {
		return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
	}

	// atProto identifies authorization servers by DID (did:web:, did:plc:) or
	// by HTTPS URL. In dev mode (IS_DEV_ENV=true), HTTP is tolerated for local
	// PDS testing.
	isHTTP := strings.HasPrefix(claims.Issuer, "http://")
	isHTTPS := strings.HasPrefix(claims.Issuer, "https://")
	isDID := hasDIDSyntax(claims.Issuer)

	if !isHTTPS && !isDID && !isHTTP {
		return fmt.Errorf("issuer must be HTTPS URL, HTTP URL (dev only), or DID, got: %s", claims.Issuer)
	}
	if isHTTP && os.Getenv("IS_DEV_ENV") != "true" {
		return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer)
	}

	issuerURL, err := url.Parse(claims.Issuer)
	if err != nil {
		return fmt.Errorf("invalid issuer URL: %w", err)
	}
	// url.Parse accepts "https://" with nothing after it. An issuer URL without
	// a host cannot identify a server.
	if (isHTTP || isHTTPS) && issuerURL.Host == "" {
		return fmt.Errorf("invalid issuer URL: missing host in %s", claims.Issuer)
	}

	// Lenient scope check: an absent scope is allowed, but a present one must
	// mention "atproto". This is a substring match, so values such as
	// "com.atproto.access" also pass. NOTE(review): presumably intentional;
	// confirm before tightening to an exact space-delimited token match.
	if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
		return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
	}

	return nil
}
164
165// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints
166// Returns interface{} to support both RSA and ECDSA keys
167type JWKSFetcher interface {
168 FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error)
169}
170
171// JWK represents a JSON Web Key from a JWKS endpoint
172// Supports both RSA and EC (ECDSA) keys
173type JWK struct {
174 Kid string `json:"kid"` // Key ID
175 Kty string `json:"kty"` // Key type ("RSA" or "EC")
176 Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
177 Use string `json:"use"` // Public key use (should be "sig" for signatures)
178 // RSA fields
179 N string `json:"n,omitempty"` // RSA modulus
180 E string `json:"e,omitempty"` // RSA exponent
181 // EC fields
182 Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
183 X string `json:"x,omitempty"` // EC x coordinate
184 Y string `json:"y,omitempty"` // EC y coordinate
185}
186
187// ToPublicKey converts a JWK to a public key (RSA or ECDSA)
188func (j *JWK) ToPublicKey() (interface{}, error) {
189 switch j.Kty {
190 case "RSA":
191 return j.toRSAPublicKey()
192 case "EC":
193 return j.toECPublicKey()
194 default:
195 return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
196 }
197}
198
199// toRSAPublicKey converts a JWK to an RSA public key
200func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
201 // Decode modulus
202 nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
203 if err != nil {
204 return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
205 }
206
207 // Decode exponent
208 eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
209 if err != nil {
210 return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
211 }
212
213 // Convert exponent to int
214 var eInt int
215 for _, b := range eBytes {
216 eInt = eInt*256 + int(b)
217 }
218
219 return &rsa.PublicKey{
220 N: new(big.Int).SetBytes(nBytes),
221 E: eInt,
222 }, nil
223}
224
225// toECPublicKey converts a JWK to an ECDSA public key
226func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
227 // Determine curve
228 var curve elliptic.Curve
229 switch j.Crv {
230 case "P-256":
231 curve = elliptic.P256()
232 case "P-384":
233 curve = elliptic.P384()
234 case "P-521":
235 curve = elliptic.P521()
236 default:
237 return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
238 }
239
240 // Decode X coordinate
241 xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
242 if err != nil {
243 return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
244 }
245
246 // Decode Y coordinate
247 yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
248 if err != nil {
249 return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
250 }
251
252 return &ecdsa.PublicKey{
253 Curve: curve,
254 X: new(big.Int).SetBytes(xBytes),
255 Y: new(big.Int).SetBytes(yBytes),
256 }, nil
257}
258
259// JWKS represents a JSON Web Key Set
260type JWKS struct {
261 Keys []JWK `json:"keys"`
262}
263
264// FindKeyByID finds a key in the JWKS by its key ID
265func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
266 for _, key := range j.Keys {
267 if key.Kid == kid {
268 return &key, nil
269 }
270 }
271 return nil, fmt.Errorf("key with kid %s not found", kid)
272}
273
// ExtractKeyID returns the "kid" (key ID) from a JWT's header without
// verifying the token.
//
// The input is normalized the way ParseJWT normalizes it: a leading "Bearer "
// prefix is removed and surrounding whitespace is trimmed. A token string that
// ParseJWT accepts is therefore not rejected here. It returns an error if:
//   - the token is not three dot-separated segments;
//   - the header is not unpadded base64url-encoded JSON;
//   - the header has no "kid".
func ExtractKeyID(tokenString string) (string, error) {
	tokenString = strings.TrimSpace(strings.TrimPrefix(tokenString, "Bearer "))

	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return "", fmt.Errorf("invalid JWT format")
	}

	// JWS compact serialization uses unpadded base64url (RFC 7515 §2).
	headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("failed to decode header: %w", err)
	}

	var header struct {
		Kid string `json:"kid"`
	}
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return "", fmt.Errorf("failed to unmarshal header: %w", err)
	}

	if header.Kid == "" {
		return "", fmt.Errorf("missing kid in token header")
	}

	return header.Kid, nil
}