A community-based topic aggregation platform built on atproto.
1package auth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rsa"
8 "encoding/base64"
9 "encoding/json"
10 "fmt"
11 "math/big"
12 "net/url"
13 "strings"
14 "time"
15
16 "github.com/golang-jwt/jwt/v5"
17)
18
// Claims is the set of JWT claims this package reads from a token: the
// registered claims embedded from jwt.RegisteredClaims (sub, iss, aud, exp,
// nbf, ...) plus the "scope" claim.
type Claims struct {
	jwt.RegisteredClaims
	// Scope is the raw "scope" claim string; omitted from JSON when empty.
	Scope string `json:"scope,omitempty"`
}
24
25// ParseJWT parses a JWT token without verification (Phase 1)
26// Returns the claims if the token is valid JSON and has required fields
27func ParseJWT(tokenString string) (*Claims, error) {
28 // Remove "Bearer " prefix if present
29 tokenString = strings.TrimPrefix(tokenString, "Bearer ")
30 tokenString = strings.TrimSpace(tokenString)
31
32 // Parse without verification first to extract claims
33 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
34 token, _, err := parser.ParseUnverified(tokenString, &Claims{})
35 if err != nil {
36 return nil, fmt.Errorf("failed to parse JWT: %w", err)
37 }
38
39 claims, ok := token.Claims.(*Claims)
40 if !ok {
41 return nil, fmt.Errorf("invalid claims type")
42 }
43
44 // Validate required fields
45 if claims.Subject == "" {
46 return nil, fmt.Errorf("missing 'sub' claim (user DID)")
47 }
48
49 // atProto PDSes may use 'aud' instead of 'iss' for the authorization server
50 // If 'iss' is missing, use 'aud' as the authorization server identifier
51 if claims.Issuer == "" {
52 if len(claims.Audience) > 0 {
53 claims.Issuer = claims.Audience[0]
54 } else {
55 return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
56 }
57 }
58
59 // Validate claims (even in Phase 1, we need basic validation like expiry)
60 if err := validateClaims(claims); err != nil {
61 return nil, err
62 }
63
64 return claims, nil
65}
66
67// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
68// Fetches the public key from the issuer's JWKS endpoint and validates the signature
69func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
70 // First parse to get the issuer
71 claims, err := ParseJWT(tokenString)
72 if err != nil {
73 return nil, err
74 }
75
76 // Fetch the public key from the issuer
77 publicKey, err := keyFetcher.FetchPublicKey(ctx, claims.Issuer, tokenString)
78 if err != nil {
79 return nil, fmt.Errorf("failed to fetch public key: %w", err)
80 }
81
82 // Now parse and verify with the public key
83 tokenString = strings.TrimPrefix(tokenString, "Bearer ")
84 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
85 // Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
86 switch token.Method.(type) {
87 case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
88 // Valid signing methods for atProto
89 default:
90 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
91 }
92 return publicKey, nil
93 })
94 if err != nil {
95 return nil, fmt.Errorf("failed to verify JWT: %w", err)
96 }
97
98 if !token.Valid {
99 return nil, fmt.Errorf("token is invalid")
100 }
101
102 verifiedClaims, ok := token.Claims.(*Claims)
103 if !ok {
104 return nil, fmt.Errorf("invalid claims type after verification")
105 }
106
107 // Additional validation
108 if err := validateClaims(verifiedClaims); err != nil {
109 return nil, err
110 }
111
112 return verifiedClaims, nil
113}
114
115// validateClaims performs additional validation on JWT claims
116func validateClaims(claims *Claims) error {
117 now := time.Now()
118
119 // Check expiration
120 if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
121 return fmt.Errorf("token has expired")
122 }
123
124 // Check not before
125 if claims.NotBefore != nil && claims.NotBefore.After(now) {
126 return fmt.Errorf("token not yet valid")
127 }
128
129 // Validate DID format in sub claim
130 if !strings.HasPrefix(claims.Subject, "did:") {
131 return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
132 }
133
134 // Validate issuer is either an HTTPS URL or a DID
135 // atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
136 if !strings.HasPrefix(claims.Issuer, "https://") && !strings.HasPrefix(claims.Issuer, "did:") {
137 return fmt.Errorf("issuer must be HTTPS URL or DID, got: %s", claims.Issuer)
138 }
139
140 // Parse to ensure it's a valid URL
141 if _, err := url.Parse(claims.Issuer); err != nil {
142 return fmt.Errorf("invalid issuer URL: %w", err)
143 }
144
145 // Validate scope if present (lenient: allow empty, but reject wrong scopes)
146 if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
147 return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
148 }
149
150 return nil
151}
152
// JWKSFetcher supplies the public key used to verify a token's signature.
//
// FetchPublicKey is given the token's issuer and the raw token string. It
// returns the verification key as an untyped value, so an implementation can
// hand back either an RSA or an ECDSA public key.
type JWKSFetcher interface {
	FetchPublicKey(ctx context.Context, issuer string, token string) (interface{}, error)
}
158
// JWK represents a single JSON Web Key as found in a JWKS document.
// Only public-key material is modeled; RSA ("kty":"RSA") and elliptic-curve
// ("kty":"EC") keys are supported. Binary members (n, e, x, y) are
// base64url-encoded without padding.
type JWK struct {
	Kid string `json:"kid"` // Key ID
	Kty string `json:"kty"` // Key type ("RSA" or "EC")
	Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
	Use string `json:"use"` // Public key use (should be "sig" for signatures)
	// RSA fields
	N string `json:"n,omitempty"` // RSA modulus
	E string `json:"e,omitempty"` // RSA exponent
	// EC fields
	Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
	X   string `json:"x,omitempty"`   // EC x coordinate
	Y   string `json:"y,omitempty"`   // EC y coordinate
}

// ToPublicKey converts the JWK into a public key for signature verification:
// *rsa.PublicKey for "RSA" keys and *ecdsa.PublicKey for "EC" keys.
//
// On failure the returned key is an untyped nil, not a typed nil pointer
// wrapped in the interface, so callers may safely test key == nil.
func (j *JWK) ToPublicKey() (interface{}, error) {
	switch j.Kty {
	case "RSA":
		key, err := j.toRSAPublicKey()
		if err != nil {
			return nil, err
		}
		return key, nil
	case "EC":
		key, err := j.toECPublicKey()
		if err != nil {
			return nil, err
		}
		return key, nil
	default:
		return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
	}
}

// toRSAPublicKey decodes the "n" (modulus) and "e" (exponent) members into an
// RSA public key.
//
// It rejects an empty or zero modulus and any exponent outside [2, 2^31-1],
// the range crypto/rsa accepts for public exponents. Decoding "e" through
// big.Int also stops an over-long exponent from silently overflowing int.
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
	}

	eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
	if err != nil {
		return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("invalid RSA modulus: must be non-zero")
	}

	e := new(big.Int).SetBytes(eBytes)
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > 1<<31-1 {
		return nil, fmt.Errorf("invalid RSA exponent: must be between 2 and 2^31-1")
	}

	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}

// toECPublicKey decodes the "crv", "x" and "y" members into an ECDSA public
// key on P-256, P-384 or P-521. It rejects points that do not lie on the named
// curve, so a malformed or hostile JWK cannot produce an invalid key.
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
	var curve elliptic.Curve
	switch j.Crv {
	case "P-256":
		curve = elliptic.P256()
	case "P-384":
		curve = elliptic.P384()
	case "P-521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
	}

	xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
	}

	yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
	if err != nil {
		return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
	}

	x := new(big.Int).SetBytes(xBytes)
	y := new(big.Int).SetBytes(yBytes)
	// IsOnCurve is marked deprecated in newer Go releases (in favor of
	// crypto/ecdh) but remains available and works on every Go version this
	// file may target.
	if !curve.IsOnCurve(x, y) {
		return nil, fmt.Errorf("invalid EC public key: point is not on curve %s", j.Crv)
	}

	return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil
}
246
247// JWKS represents a JSON Web Key Set
248type JWKS struct {
249 Keys []JWK `json:"keys"`
250}
251
252// FindKeyByID finds a key in the JWKS by its key ID
253func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
254 for _, key := range j.Keys {
255 if key.Kid == kid {
256 return &key, nil
257 }
258 }
259 return nil, fmt.Errorf("key with kid %s not found", kid)
260}
261
// ExtractKeyID returns the "kid" (key ID) from a JWT's header without
// verifying the token.
//
// An optional "Bearer " prefix and then surrounding whitespace are stripped,
// in the same order ParseJWT uses, so both functions accept the same inputs.
// It returns an error when the token does not have exactly three
// dot-separated segments, when the header is not unpadded base64url-encoded
// JSON, or when the header has no non-empty "kid".
func ExtractKeyID(tokenString string) (string, error) {
	tokenString = strings.TrimPrefix(tokenString, "Bearer ")
	tokenString = strings.TrimSpace(tokenString)

	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return "", fmt.Errorf("invalid JWT format")
	}

	headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("failed to decode header: %w", err)
	}

	var header struct {
		Kid string `json:"kid"`
	}
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return "", fmt.Errorf("failed to unmarshal header: %w", err)
	}

	if header.Kid == "" {
		return "", fmt.Errorf("missing kid in token header")
	}

	return header.Kid, nil
}