···
16
+
"github.com/golang-jwt/jwt/v5"
19
+
// Claims represents the standard JWT claims we care about
20
+
type Claims struct {
21
+
jwt.RegisteredClaims
22
+
Scope string `json:"scope,omitempty"`
25
+
// ParseJWT parses a JWT token without verification (Phase 1)
26
+
// Returns the claims if the token is valid JSON and has required fields
27
+
func ParseJWT(tokenString string) (*Claims, error) {
28
+
// Remove "Bearer " prefix if present
29
+
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
30
+
tokenString = strings.TrimSpace(tokenString)
32
+
// Parse without verification first to extract claims
33
+
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
34
+
token, _, err := parser.ParseUnverified(tokenString, &Claims{})
36
+
return nil, fmt.Errorf("failed to parse JWT: %w", err)
39
+
claims, ok := token.Claims.(*Claims)
41
+
return nil, fmt.Errorf("invalid claims type")
44
+
// Validate required fields
45
+
if claims.Subject == "" {
46
+
return nil, fmt.Errorf("missing 'sub' claim (user DID)")
49
+
// atProto PDSes may use 'aud' instead of 'iss' for the authorization server
50
+
// If 'iss' is missing, use 'aud' as the authorization server identifier
51
+
if claims.Issuer == "" {
52
+
if len(claims.Audience) > 0 {
53
+
claims.Issuer = claims.Audience[0]
55
+
return nil, fmt.Errorf("missing both 'iss' and 'aud' claims (authorization server)")
59
+
// Validate claims (even in Phase 1, we need basic validation like expiry)
60
+
if err := validateClaims(claims); err != nil {
67
+
// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
68
+
// Fetches the public key from the issuer's JWKS endpoint and validates the signature
69
+
func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
70
+
// First parse to get the issuer
71
+
claims, err := ParseJWT(tokenString)
76
+
// Fetch the public key from the issuer
77
+
publicKey, err := keyFetcher.FetchPublicKey(ctx, claims.Issuer, tokenString)
79
+
return nil, fmt.Errorf("failed to fetch public key: %w", err)
82
+
// Now parse and verify with the public key
83
+
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
84
+
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
85
+
// Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
86
+
switch token.Method.(type) {
87
+
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
88
+
// Valid signing methods for atProto
90
+
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
92
+
return publicKey, nil
95
+
return nil, fmt.Errorf("failed to verify JWT: %w", err)
99
+
return nil, fmt.Errorf("token is invalid")
102
+
verifiedClaims, ok := token.Claims.(*Claims)
104
+
return nil, fmt.Errorf("invalid claims type after verification")
107
+
// Additional validation
108
+
if err := validateClaims(verifiedClaims); err != nil {
112
+
return verifiedClaims, nil
115
+
// validateClaims performs additional validation on JWT claims
116
+
func validateClaims(claims *Claims) error {
119
+
// Check expiration
120
+
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(now) {
121
+
return fmt.Errorf("token has expired")
124
+
// Check not before
125
+
if claims.NotBefore != nil && claims.NotBefore.After(now) {
126
+
return fmt.Errorf("token not yet valid")
129
+
// Validate DID format in sub claim
130
+
if !strings.HasPrefix(claims.Subject, "did:") {
131
+
return fmt.Errorf("invalid DID format in 'sub' claim: %s", claims.Subject)
134
+
// Validate issuer is either an HTTPS URL or a DID
135
+
// atProto uses DIDs (did:web:, did:plc:) or HTTPS URLs as issuer identifiers
136
+
if !strings.HasPrefix(claims.Issuer, "https://") && !strings.HasPrefix(claims.Issuer, "did:") {
137
+
return fmt.Errorf("issuer must be HTTPS URL or DID, got: %s", claims.Issuer)
140
+
// Parse to ensure it's a valid URL
141
+
if _, err := url.Parse(claims.Issuer); err != nil {
142
+
return fmt.Errorf("invalid issuer URL: %w", err)
145
+
// Validate scope if present (lenient: allow empty, but reject wrong scopes)
146
+
if claims.Scope != "" && !strings.Contains(claims.Scope, "atproto") {
147
+
return fmt.Errorf("token missing required 'atproto' scope, got: %s", claims.Scope)
153
+
// JWKSFetcher defines the interface for fetching public keys from JWKS endpoints
154
+
// Returns interface{} to support both RSA and ECDSA keys
155
+
type JWKSFetcher interface {
156
+
FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error)
159
+
// JWK represents a JSON Web Key from a JWKS endpoint
160
+
// Supports both RSA and EC (ECDSA) keys
162
+
Kid string `json:"kid"` // Key ID
163
+
Kty string `json:"kty"` // Key type ("RSA" or "EC")
164
+
Alg string `json:"alg"` // Algorithm (e.g., "RS256", "ES256")
165
+
Use string `json:"use"` // Public key use (should be "sig" for signatures)
167
+
N string `json:"n,omitempty"` // RSA modulus
168
+
E string `json:"e,omitempty"` // RSA exponent
170
+
Crv string `json:"crv,omitempty"` // EC curve (e.g., "P-256")
171
+
X string `json:"x,omitempty"` // EC x coordinate
172
+
Y string `json:"y,omitempty"` // EC y coordinate
175
+
// ToPublicKey converts a JWK to a public key (RSA or ECDSA)
176
+
func (j *JWK) ToPublicKey() (interface{}, error) {
179
+
return j.toRSAPublicKey()
181
+
return j.toECPublicKey()
183
+
return nil, fmt.Errorf("unsupported key type: %s", j.Kty)
187
+
// toRSAPublicKey converts a JWK to an RSA public key
188
+
func (j *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
190
+
nBytes, err := base64.RawURLEncoding.DecodeString(j.N)
192
+
return nil, fmt.Errorf("failed to decode RSA modulus: %w", err)
196
+
eBytes, err := base64.RawURLEncoding.DecodeString(j.E)
198
+
return nil, fmt.Errorf("failed to decode RSA exponent: %w", err)
201
+
// Convert exponent to int
203
+
for _, b := range eBytes {
204
+
eInt = eInt*256 + int(b)
207
+
return &rsa.PublicKey{
208
+
N: new(big.Int).SetBytes(nBytes),
213
+
// toECPublicKey converts a JWK to an ECDSA public key
214
+
func (j *JWK) toECPublicKey() (*ecdsa.PublicKey, error) {
216
+
var curve elliptic.Curve
219
+
curve = elliptic.P256()
221
+
curve = elliptic.P384()
223
+
curve = elliptic.P521()
225
+
return nil, fmt.Errorf("unsupported EC curve: %s", j.Crv)
228
+
// Decode X coordinate
229
+
xBytes, err := base64.RawURLEncoding.DecodeString(j.X)
231
+
return nil, fmt.Errorf("failed to decode EC x coordinate: %w", err)
234
+
// Decode Y coordinate
235
+
yBytes, err := base64.RawURLEncoding.DecodeString(j.Y)
237
+
return nil, fmt.Errorf("failed to decode EC y coordinate: %w", err)
240
+
return &ecdsa.PublicKey{
242
+
X: new(big.Int).SetBytes(xBytes),
243
+
Y: new(big.Int).SetBytes(yBytes),
247
+
// JWKS represents a JSON Web Key Set
249
+
Keys []JWK `json:"keys"`
252
+
// FindKeyByID finds a key in the JWKS by its key ID
253
+
func (j *JWKS) FindKeyByID(kid string) (*JWK, error) {
254
+
for _, key := range j.Keys {
255
+
if key.Kid == kid {
259
+
return nil, fmt.Errorf("key with kid %s not found", kid)
262
+
// ExtractKeyID extracts the key ID from a JWT token header
263
+
func ExtractKeyID(tokenString string) (string, error) {
264
+
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
265
+
parts := strings.Split(tokenString, ".")
266
+
if len(parts) != 3 {
267
+
return "", fmt.Errorf("invalid JWT format")
271
+
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
273
+
return "", fmt.Errorf("failed to decode header: %w", err)
276
+
var header struct {
277
+
Kid string `json:"kid"`
279
+
if err := json.Unmarshal(headerBytes, &header); err != nil {
280
+
return "", fmt.Errorf("failed to unmarshal header: %w", err)
283
+
if header.Kid == "" {
284
+
return "", fmt.Errorf("missing kid in token header")
287
+
return header.Kid, nil