A community-based topic aggregation platform built on atproto
1package auth
2
3import (
4 "crypto/ecdsa"
5 "crypto/elliptic"
6 "crypto/sha256"
7 "encoding/base64"
8 "encoding/json"
9 "fmt"
10 "math/big"
11 "strings"
12 "sync"
13 "time"
14
15 indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
16 "github.com/golang-jwt/jwt/v5"
17)
18
// NonceCache provides replay protection for DPoP proofs by tracking seen jti values.
// This prevents an attacker from reusing a captured DPoP proof within the validity window.
// Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks.
//
// A NonceCache runs a background cleanup goroutine; call Stop to end it.
type NonceCache struct {
	seen     map[string]time.Time // jti -> expiration time
	stopCh   chan struct{}
	maxAge   time.Duration // How long to keep entries
	cleanup  time.Duration // How often to clean up expired entries
	mu       sync.RWMutex
	stopOnce sync.Once // makes Stop safe to call more than once
}

// NewNonceCache creates a new nonce cache for DPoP replay protection and
// starts its cleanup goroutine.
//
// maxAge is how long a jti is remembered after it is first seen. To close the
// replay window it should be at least DPoPVerifier.MaxProofAge plus
// DPoPVerifier.MaxClockSkew, since a proof may be future-dated by up to the
// clock skew and still be accepted for MaxProofAge after its iat.
func NewNonceCache(maxAge time.Duration) *NonceCache {
	// Sweep at half the max age, but never faster than once per second:
	// time.NewTicker panics on a non-positive interval (maxAge/2 is 0 for
	// maxAge < 2ns), and a near-zero interval would spin the goroutine.
	// CheckAndStore compares expiries itself, so a slower sweep only delays
	// freeing memory; it never lets a replay through.
	const minCleanupInterval = time.Second
	cleanup := maxAge / 2
	if cleanup < minCleanupInterval {
		cleanup = minCleanupInterval
	}

	nc := &NonceCache{
		seen:    make(map[string]time.Time),
		maxAge:  maxAge,
		cleanup: cleanup,
		stopCh:  make(chan struct{}),
	}

	go nc.cleanupLoop()

	return nc
}

// CheckAndStore checks if a jti has been seen before and stores it if not.
// Returns true if the jti is fresh (not a replay), false if it's a replay.
// An entry whose expiry has passed is treated as unseen and is refreshed.
func (nc *NonceCache) CheckAndStore(jti string) bool {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	now := time.Now()
	if expiry, seen := nc.seen[jti]; seen && expiry.After(now) {
		// Still within its retention window: this is a replay.
		return false
	}

	nc.seen[jti] = now.Add(nc.maxAge)
	return true
}

// cleanupLoop periodically removes expired entries until Stop is called.
func (nc *NonceCache) cleanupLoop() {
	ticker := time.NewTicker(nc.cleanup)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			nc.cleanupExpired()
		case <-nc.stopCh:
			return
		}
	}
}

// cleanupExpired removes expired entries from the cache. An entry counts as
// expired under the same rule CheckAndStore uses (expiry not after now).
func (nc *NonceCache) cleanupExpired() {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	now := time.Now()
	for jti, expiry := range nc.seen {
		if !expiry.After(now) {
			delete(nc.seen, jti)
		}
	}
}

// Stop stops the cleanup goroutine. Call this when done with the cache.
// It is safe to call Stop more than once.
func (nc *NonceCache) Stop() {
	nc.stopOnce.Do(func() { close(nc.stopCh) })
}

// Size returns the number of entries in the cache (for testing/monitoring).
// Expired entries not yet removed by the cleanup sweep are included.
func (nc *NonceCache) Size() int {
	nc.mu.RLock()
	defer nc.mu.RUnlock()
	return len(nc.seen)
}
108
109// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449)
110type DPoPClaims struct {
111 jwt.RegisteredClaims
112
113 // HTTP method of the request (e.g., "GET", "POST")
114 HTTPMethod string `json:"htm"`
115
116 // HTTP URI of the request (without query and fragment parts)
117 HTTPURI string `json:"htu"`
118
119 // Access token hash (optional, for token binding)
120 AccessTokenHash string `json:"ath,omitempty"`
121}
122
123// DPoPProof represents a parsed and verified DPoP proof
124type DPoPProof struct {
125 RawPublicJWK map[string]interface{}
126 Claims *DPoPClaims
127 PublicKey interface{} // *ecdsa.PublicKey or similar
128 Thumbprint string // JWK thumbprint (base64url)
129}
130
131// DPoPVerifier verifies DPoP proofs for OAuth token binding
132type DPoPVerifier struct {
133 // Optional: custom nonce validation function (for server-issued nonces)
134 ValidateNonce func(nonce string) bool
135
136 // NonceCache for replay protection (optional but recommended)
137 // If nil, jti replay protection is disabled
138 NonceCache *NonceCache
139
140 // Maximum allowed clock skew for timestamp validation
141 MaxClockSkew time.Duration
142
143 // Maximum age of DPoP proof (prevents replay with old proofs)
144 MaxProofAge time.Duration
145}
146
147// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
148func NewDPoPVerifier() *DPoPVerifier {
149 maxProofAge := 5 * time.Minute
150 return &DPoPVerifier{
151 MaxClockSkew: 30 * time.Second,
152 MaxProofAge: maxProofAge,
153 NonceCache: NewNonceCache(maxProofAge),
154 }
155}
156
157// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
158// This should only be used in testing or when replay protection is handled externally.
159func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
160 return &DPoPVerifier{
161 MaxClockSkew: 30 * time.Second,
162 MaxProofAge: 5 * time.Minute,
163 NonceCache: nil, // No replay protection
164 }
165}
166
167// Stop stops background goroutines. Call this when shutting down.
168func (v *DPoPVerifier) Stop() {
169 if v.NonceCache != nil {
170 v.NonceCache.Stop()
171 }
172}
173
174// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof
175func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
176 // Parse the DPoP JWT without verification first to extract the header
177 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
178 token, _, err := parser.ParseUnverified(dpopProof, &DPoPClaims{})
179 if err != nil {
180 return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
181 }
182
183 // Extract and validate the header
184 header, ok := token.Header["typ"].(string)
185 if !ok || header != "dpop+jwt" {
186 return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", header)
187 }
188
189 alg, ok := token.Header["alg"].(string)
190 if !ok {
191 return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
192 }
193
194 // Extract the JWK from the header
195 jwkRaw, ok := token.Header["jwk"]
196 if !ok {
197 return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
198 }
199
200 jwkMap, ok := jwkRaw.(map[string]interface{})
201 if !ok {
202 return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
203 }
204
205 // Parse the public key from JWK
206 publicKey, err := parseJWKToPublicKey(jwkMap)
207 if err != nil {
208 return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
209 }
210
211 // Calculate the JWK thumbprint
212 thumbprint, err := CalculateJWKThumbprint(jwkMap)
213 if err != nil {
214 return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
215 }
216
217 // Now verify the signature
218 verifiedToken, err := jwt.ParseWithClaims(dpopProof, &DPoPClaims{}, func(token *jwt.Token) (interface{}, error) {
219 // Verify the signing method matches what we expect
220 switch alg {
221 case "ES256":
222 if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
223 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
224 }
225 case "ES384", "ES512":
226 if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
227 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
228 }
229 case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
230 // RSA methods - we primarily support ES256 for atproto
231 return nil, fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg)
232 default:
233 return nil, fmt.Errorf("unsupported DPoP algorithm: %s", alg)
234 }
235 return publicKey, nil
236 })
237 if err != nil {
238 return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
239 }
240
241 claims, ok := verifiedToken.Claims.(*DPoPClaims)
242 if !ok {
243 return nil, fmt.Errorf("invalid DPoP claims type")
244 }
245
246 // Validate the claims
247 if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
248 return nil, err
249 }
250
251 return &DPoPProof{
252 Claims: claims,
253 PublicKey: publicKey,
254 Thumbprint: thumbprint,
255 RawPublicJWK: jwkMap,
256 }, nil
257}
258
259// validateDPoPClaims validates the DPoP proof claims
260func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
261 // Validate jti (unique identifier) is present
262 if claims.ID == "" {
263 return fmt.Errorf("DPoP proof missing jti claim")
264 }
265
266 // Validate htm (HTTP method)
267 if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
268 return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
269 }
270
271 // Validate htu (HTTP URI) - compare without query/fragment
272 expectedURIBase := stripQueryFragment(expectedURI)
273 claimURIBase := stripQueryFragment(claims.HTTPURI)
274 if expectedURIBase != claimURIBase {
275 return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
276 }
277
278 // Validate iat (issued at) is present and recent
279 if claims.IssuedAt == nil {
280 return fmt.Errorf("DPoP proof missing iat claim")
281 }
282
283 now := time.Now()
284 iat := claims.IssuedAt.Time
285
286 // Check clock skew (not too far in the future)
287 if iat.After(now.Add(v.MaxClockSkew)) {
288 return fmt.Errorf("DPoP proof iat is in the future")
289 }
290
291 // Check proof age (not too old)
292 if now.Sub(iat) > v.MaxProofAge {
293 return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
294 }
295
296 // SECURITY: Check for replay attack using jti
297 // Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks
298 if v.NonceCache != nil {
299 if !v.NonceCache.CheckAndStore(claims.ID) {
300 return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
301 }
302 }
303
304 return nil
305}
306
307// VerifyTokenBinding verifies that the DPoP proof binds to the access token
308// by comparing the proof's thumbprint to the token's cnf.jkt claim
309func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
310 if proof.Thumbprint != expectedThumbprint {
311 return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
312 expectedThumbprint, proof.Thumbprint)
313 }
314 return nil
315}
316
// CalculateJWKThumbprint computes the RFC 7638 JWK thumbprint: the unpadded
// base64url SHA-256 digest of the JWK's required members (plus kty) serialized
// as compact JSON with lexicographically ordered keys. Any other members of
// jwk (e.g. "use", "kid") are ignored. Supported key types are EC, RSA and OKP.
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
	kty, ok := jwk["kty"].(string)
	if !ok {
		return "", fmt.Errorf("JWK missing kty")
	}

	// Required members per key type, besides kty, in the order they are checked.
	var required []string
	switch kty {
	case "EC":
		required = []string{"crv", "x", "y"}
	case "RSA":
		required = []string{"e", "n"}
	case "OKP":
		required = []string{"crv", "x"}
	default:
		return "", fmt.Errorf("unsupported JWK key type: %s", kty)
	}

	members := map[string]string{"kty": kty}
	for _, name := range required {
		value, present := jwk[name].(string)
		if !present {
			return "", fmt.Errorf("%s JWK missing %s", kty, name)
		}
		members[name] = value
	}

	// encoding/json emits map keys in sorted order, which is exactly the
	// lexicographic member ordering RFC 7638 prescribes.
	encoded, err := json.Marshal(members)
	if err != nil {
		return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
	}

	digest := sha256.Sum256(encoded)
	return base64.RawURLEncoding.EncodeToString(digest[:]), nil
}
398
399// parseJWKToPublicKey parses a JWK map to a Go public key
400func parseJWKToPublicKey(jwkMap map[string]interface{}) (interface{}, error) {
401 // Convert map to JSON bytes for indigo's parser
402 jwkBytes, err := json.Marshal(jwkMap)
403 if err != nil {
404 return nil, fmt.Errorf("failed to serialize JWK: %w", err)
405 }
406
407 // Try to parse with indigo's crypto package
408 pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
409 if err != nil {
410 return nil, fmt.Errorf("failed to parse JWK: %w", err)
411 }
412
413 // Convert indigo's PublicKey to Go's ecdsa.PublicKey
414 jwk, err := pubKey.JWK()
415 if err != nil {
416 return nil, fmt.Errorf("failed to get JWK from public key: %w", err)
417 }
418
419 // Use our existing conversion function
420 return atcryptoJWKToECDSAFromIndigoJWK(jwk)
421}
422
423// atcryptoJWKToECDSAFromIndigoJWK converts an indigo JWK to Go ecdsa.PublicKey
424func atcryptoJWKToECDSAFromIndigoJWK(jwk *indigoCrypto.JWK) (*ecdsa.PublicKey, error) {
425 if jwk.KeyType != "EC" {
426 return nil, fmt.Errorf("unsupported JWK key type: %s (expected EC)", jwk.KeyType)
427 }
428
429 xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
430 if err != nil {
431 return nil, fmt.Errorf("invalid JWK X coordinate: %w", err)
432 }
433 yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
434 if err != nil {
435 return nil, fmt.Errorf("invalid JWK Y coordinate: %w", err)
436 }
437
438 var curve ecdsa.PublicKey
439 switch jwk.Curve {
440 case "P-256":
441 curve.Curve = ecdsaP256Curve()
442 case "P-384":
443 curve.Curve = ecdsaP384Curve()
444 case "P-521":
445 curve.Curve = ecdsaP521Curve()
446 default:
447 return nil, fmt.Errorf("unsupported curve: %s", jwk.Curve)
448 }
449
450 curve.X = new(big.Int).SetBytes(xBytes)
451 curve.Y = new(big.Int).SetBytes(yBytes)
452
453 return &curve, nil
454}
455
// Helper functions mapping JWK "crv" names to Go elliptic curves.

// ecdsaP256Curve returns the NIST P-256 curve (JWK crv "P-256").
func ecdsaP256Curve() elliptic.Curve {
	return elliptic.P256()
}

// ecdsaP384Curve returns the NIST P-384 curve (JWK crv "P-384").
func ecdsaP384Curve() elliptic.Curve {
	return elliptic.P384()
}

// ecdsaP521Curve returns the NIST P-521 curve (JWK crv "P-521").
func ecdsaP521Curve() elliptic.Curve {
	return elliptic.P521()
}
460
// stripQueryFragment returns uri truncated at its first '?' or '#', dropping
// the query and fragment components. A uri containing neither character is
// returned unchanged.
func stripQueryFragment(uri string) string {
	cut := strings.IndexAny(uri, "?#")
	if cut < 0 {
		return uri
	}
	return uri[:cut]
}
471
472// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims
473func ExtractCnfJkt(claims *Claims) (string, error) {
474 if claims.Confirmation == nil {
475 return "", fmt.Errorf("token missing cnf claim (no DPoP binding)")
476 }
477
478 jkt, ok := claims.Confirmation["jkt"].(string)
479 if !ok || jkt == "" {
480 return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)")
481 }
482
483 return jkt, nil
484}