A community-based topic aggregation platform built on atproto
1package auth
2
3import (
4 "crypto/ecdsa"
5 "crypto/elliptic"
6 "crypto/sha256"
7 "encoding/base64"
8 "encoding/json"
9 "fmt"
10 "math/big"
11 "strings"
12 "sync"
13 "time"
14
15 indigoCrypto "github.com/bluesky-social/indigo/atproto/atcrypto"
16 "github.com/golang-jwt/jwt/v5"
17)
18
// NonceCache remembers the jti values of DPoP proofs it has accepted so that a
// captured proof cannot be replayed while it is still inside its validity
// window. RFC 9449 Section 11.1 recommends that servers prevent such replays.
type NonceCache struct {
	// seen maps each recorded jti to the instant its record expires.
	seen map[string]time.Time
	// stopCh is closed to terminate the background cleanup goroutine.
	stopCh chan struct{}
	// maxAge is how long a jti is retained; cleanup is the sweep interval.
	maxAge, cleanup time.Duration
	// mu guards seen.
	mu sync.RWMutex
}
29
30// NewNonceCache creates a new nonce cache for DPoP replay protection.
31// maxAge should match or exceed DPoPVerifier.MaxProofAge.
32func NewNonceCache(maxAge time.Duration) *NonceCache {
33 nc := &NonceCache{
34 seen: make(map[string]time.Time),
35 maxAge: maxAge,
36 cleanup: maxAge / 2, // Clean up at half the max age
37 stopCh: make(chan struct{}),
38 }
39
40 // Start background cleanup goroutine
41 go nc.cleanupLoop()
42
43 return nc
44}
45
46// CheckAndStore checks if a jti has been seen before and stores it if not.
47// Returns true if the jti is fresh (not a replay), false if it's a replay.
48func (nc *NonceCache) CheckAndStore(jti string) bool {
49 nc.mu.Lock()
50 defer nc.mu.Unlock()
51
52 now := time.Now()
53 expiry := now.Add(nc.maxAge)
54
55 // Check if already seen
56 if existingExpiry, seen := nc.seen[jti]; seen {
57 // Still valid (not expired) - this is a replay
58 if existingExpiry.After(now) {
59 return false
60 }
61 // Expired entry - allow reuse and update expiry
62 }
63
64 // Store the new jti
65 nc.seen[jti] = expiry
66 return true
67}
68
69// cleanupLoop periodically removes expired entries from the cache
70func (nc *NonceCache) cleanupLoop() {
71 ticker := time.NewTicker(nc.cleanup)
72 defer ticker.Stop()
73
74 for {
75 select {
76 case <-ticker.C:
77 nc.cleanupExpired()
78 case <-nc.stopCh:
79 return
80 }
81 }
82}
83
84// cleanupExpired removes expired entries from the cache
85func (nc *NonceCache) cleanupExpired() {
86 nc.mu.Lock()
87 defer nc.mu.Unlock()
88
89 now := time.Now()
90 for jti, expiry := range nc.seen {
91 if expiry.Before(now) {
92 delete(nc.seen, jti)
93 }
94 }
95}
96
97// Stop stops the cleanup goroutine. Call this when done with the cache.
98func (nc *NonceCache) Stop() {
99 close(nc.stopCh)
100}
101
102// Size returns the number of entries in the cache (for testing/monitoring)
103func (nc *NonceCache) Size() int {
104 nc.mu.RLock()
105 defer nc.mu.RUnlock()
106 return len(nc.seen)
107}
108
109// DPoPClaims represents the claims in a DPoP proof JWT (RFC 9449)
110type DPoPClaims struct {
111 jwt.RegisteredClaims
112
113 // HTTP method of the request (e.g., "GET", "POST")
114 HTTPMethod string `json:"htm"`
115
116 // HTTP URI of the request (without query and fragment parts)
117 HTTPURI string `json:"htu"`
118
119 // Access token hash (optional, for token binding)
120 AccessTokenHash string `json:"ath,omitempty"`
121}
122
123// DPoPProof represents a parsed and verified DPoP proof
124type DPoPProof struct {
125 RawPublicJWK map[string]interface{}
126 Claims *DPoPClaims
127 PublicKey interface{} // *ecdsa.PublicKey or similar
128 Thumbprint string // JWK thumbprint (base64url)
129}
130
131// DPoPVerifier verifies DPoP proofs for OAuth token binding
132type DPoPVerifier struct {
133 // Optional: custom nonce validation function (for server-issued nonces)
134 ValidateNonce func(nonce string) bool
135
136 // NonceCache for replay protection (optional but recommended)
137 // If nil, jti replay protection is disabled
138 NonceCache *NonceCache
139
140 // Maximum allowed clock skew for timestamp validation
141 MaxClockSkew time.Duration
142
143 // Maximum age of DPoP proof (prevents replay with old proofs)
144 MaxProofAge time.Duration
145}
146
147// NewDPoPVerifier creates a DPoP verifier with sensible defaults including replay protection
148func NewDPoPVerifier() *DPoPVerifier {
149 maxProofAge := 5 * time.Minute
150 return &DPoPVerifier{
151 MaxClockSkew: 30 * time.Second,
152 MaxProofAge: maxProofAge,
153 NonceCache: NewNonceCache(maxProofAge),
154 }
155}
156
157// NewDPoPVerifierWithoutReplayProtection creates a DPoP verifier without replay protection.
158// This should only be used in testing or when replay protection is handled externally.
159func NewDPoPVerifierWithoutReplayProtection() *DPoPVerifier {
160 return &DPoPVerifier{
161 MaxClockSkew: 30 * time.Second,
162 MaxProofAge: 5 * time.Minute,
163 NonceCache: nil, // No replay protection
164 }
165}
166
167// Stop stops background goroutines. Call this when shutting down.
168func (v *DPoPVerifier) Stop() {
169 if v.NonceCache != nil {
170 v.NonceCache.Stop()
171 }
172}
173
174// VerifyDPoPProof verifies a DPoP proof JWT and returns the parsed proof
175func (v *DPoPVerifier) VerifyDPoPProof(dpopProof, httpMethod, httpURI string) (*DPoPProof, error) {
176 // Parse the DPoP JWT without verification first to extract the header
177 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
178 token, _, err := parser.ParseUnverified(dpopProof, &DPoPClaims{})
179 if err != nil {
180 return nil, fmt.Errorf("failed to parse DPoP proof: %w", err)
181 }
182
183 // Extract and validate the header
184 header, ok := token.Header["typ"].(string)
185 if !ok || header != "dpop+jwt" {
186 return nil, fmt.Errorf("invalid DPoP proof: typ must be 'dpop+jwt', got '%s'", header)
187 }
188
189 alg, ok := token.Header["alg"].(string)
190 if !ok {
191 return nil, fmt.Errorf("invalid DPoP proof: missing alg header")
192 }
193
194 // Extract the JWK from the header
195 jwkRaw, ok := token.Header["jwk"]
196 if !ok {
197 return nil, fmt.Errorf("invalid DPoP proof: missing jwk header")
198 }
199
200 jwkMap, ok := jwkRaw.(map[string]interface{})
201 if !ok {
202 return nil, fmt.Errorf("invalid DPoP proof: jwk must be an object")
203 }
204
205 // Parse the public key from JWK
206 publicKey, err := parseJWKToPublicKey(jwkMap)
207 if err != nil {
208 return nil, fmt.Errorf("invalid DPoP proof JWK: %w", err)
209 }
210
211 // Calculate the JWK thumbprint
212 thumbprint, err := CalculateJWKThumbprint(jwkMap)
213 if err != nil {
214 return nil, fmt.Errorf("failed to calculate JWK thumbprint: %w", err)
215 }
216
217 // Now verify the signature
218 verifiedToken, err := jwt.ParseWithClaims(dpopProof, &DPoPClaims{}, func(token *jwt.Token) (interface{}, error) {
219 // Verify the signing method matches what we expect
220 switch alg {
221 case "ES256":
222 if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
223 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
224 }
225 case "ES384", "ES512":
226 if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
227 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
228 }
229 case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
230 // RSA methods - we primarily support ES256 for atproto
231 return nil, fmt.Errorf("RSA algorithms not yet supported for DPoP: %s", alg)
232 default:
233 return nil, fmt.Errorf("unsupported DPoP algorithm: %s", alg)
234 }
235 return publicKey, nil
236 })
237 if err != nil {
238 return nil, fmt.Errorf("DPoP proof signature verification failed: %w", err)
239 }
240
241 claims, ok := verifiedToken.Claims.(*DPoPClaims)
242 if !ok {
243 return nil, fmt.Errorf("invalid DPoP claims type")
244 }
245
246 // Validate the claims
247 if err := v.validateDPoPClaims(claims, httpMethod, httpURI); err != nil {
248 return nil, err
249 }
250
251 return &DPoPProof{
252 Claims: claims,
253 PublicKey: publicKey,
254 Thumbprint: thumbprint,
255 RawPublicJWK: jwkMap,
256 }, nil
257}
258
259// validateDPoPClaims validates the DPoP proof claims
260func (v *DPoPVerifier) validateDPoPClaims(claims *DPoPClaims, expectedMethod, expectedURI string) error {
261 // Validate jti (unique identifier) is present
262 if claims.ID == "" {
263 return fmt.Errorf("DPoP proof missing jti claim")
264 }
265
266 // Validate htm (HTTP method)
267 if !strings.EqualFold(claims.HTTPMethod, expectedMethod) {
268 return fmt.Errorf("DPoP proof htm mismatch: expected %s, got %s", expectedMethod, claims.HTTPMethod)
269 }
270
271 // Validate htu (HTTP URI) - compare without query/fragment
272 expectedURIBase := stripQueryFragment(expectedURI)
273 claimURIBase := stripQueryFragment(claims.HTTPURI)
274 if expectedURIBase != claimURIBase {
275 return fmt.Errorf("DPoP proof htu mismatch: expected %s, got %s", expectedURIBase, claimURIBase)
276 }
277
278 // Validate iat (issued at) is present and recent
279 if claims.IssuedAt == nil {
280 return fmt.Errorf("DPoP proof missing iat claim")
281 }
282
283 now := time.Now()
284 iat := claims.IssuedAt.Time
285
286 // Check clock skew (not too far in the future)
287 if iat.After(now.Add(v.MaxClockSkew)) {
288 return fmt.Errorf("DPoP proof iat is in the future")
289 }
290
291 // Check proof age (not too old)
292 if now.Sub(iat) > v.MaxProofAge {
293 return fmt.Errorf("DPoP proof is too old (issued %v ago, max %v)", now.Sub(iat), v.MaxProofAge)
294 }
295
296 // SECURITY: Check for replay attack using jti
297 // Per RFC 9449 Section 11.1, servers SHOULD prevent replay attacks
298 if v.NonceCache != nil {
299 if !v.NonceCache.CheckAndStore(claims.ID) {
300 return fmt.Errorf("DPoP proof replay detected: jti %s already used", claims.ID)
301 }
302 }
303
304 return nil
305}
306
307// VerifyTokenBinding verifies that the DPoP proof binds to the access token
308// by comparing the proof's thumbprint to the token's cnf.jkt claim
309func (v *DPoPVerifier) VerifyTokenBinding(proof *DPoPProof, expectedThumbprint string) error {
310 if proof.Thumbprint != expectedThumbprint {
311 return fmt.Errorf("DPoP proof thumbprint mismatch: token expects %s, proof has %s",
312 expectedThumbprint, proof.Thumbprint)
313 }
314 return nil
315}
316
317// VerifyAccessTokenHash verifies the DPoP proof's ath (access token hash) claim
318// matches the SHA-256 hash of the presented access token.
319// Per RFC 9449 section 4.2, if ath is present, the RS MUST verify it.
320func (v *DPoPVerifier) VerifyAccessTokenHash(proof *DPoPProof, accessToken string) error {
321 // If ath claim is not present, that's acceptable per RFC 9449
322 // (ath is only required when the RS mandates it)
323 if proof.Claims.AccessTokenHash == "" {
324 return nil
325 }
326
327 // Calculate the expected ath: base64url(SHA-256(access_token))
328 hash := sha256.Sum256([]byte(accessToken))
329 expectedAth := base64.RawURLEncoding.EncodeToString(hash[:])
330
331 if proof.Claims.AccessTokenHash != expectedAth {
332 return fmt.Errorf("DPoP proof ath mismatch: proof bound to different access token")
333 }
334
335 return nil
336}
337
// CalculateJWKThumbprint computes the RFC 7638 thumbprint of jwk. The result is
// the unpadded base64url SHA-256 digest of a canonical JSON object holding only
// the members required for the key type.
//
// Supported key types and their required members (besides kty):
//   - EC:  crv, x, y
//   - RSA: e, n
//   - OKP: crv, x
//
// Other members (alg, kid, use, ...) are ignored. An error is returned when kty
// is missing or unsupported, or a required member is absent or not a string.
func CalculateJWKThumbprint(jwk map[string]interface{}) (string, error) {
	kty, ok := jwk["kty"].(string)
	if !ok {
		return "", fmt.Errorf("JWK missing kty")
	}

	// Required members other than kty, listed in the order they are validated.
	var required []string
	switch kty {
	case "EC":
		required = []string{"crv", "x", "y"}
	case "RSA":
		required = []string{"e", "n"}
	case "OKP":
		required = []string{"crv", "x"}
	default:
		return "", fmt.Errorf("unsupported JWK key type: %s", kty)
	}

	members := make(map[string]string, len(required)+1)
	members["kty"] = kty
	for _, name := range required {
		value, ok := jwk[name].(string)
		if !ok {
			return "", fmt.Errorf("%s JWK missing %s", kty, name)
		}
		members[name] = value
	}

	// encoding/json writes map keys in sorted order with no whitespace, which
	// yields the lexicographically ordered form that RFC 7638 requires.
	canonicalJSON, err := json.Marshal(members)
	if err != nil {
		return "", fmt.Errorf("failed to serialize canonical JWK: %w", err)
	}

	digest := sha256.Sum256(canonicalJSON)
	return base64.RawURLEncoding.EncodeToString(digest[:]), nil
}
419
420// parseJWKToPublicKey parses a JWK map to a Go public key
421func parseJWKToPublicKey(jwkMap map[string]interface{}) (interface{}, error) {
422 // Convert map to JSON bytes for indigo's parser
423 jwkBytes, err := json.Marshal(jwkMap)
424 if err != nil {
425 return nil, fmt.Errorf("failed to serialize JWK: %w", err)
426 }
427
428 // Try to parse with indigo's crypto package
429 pubKey, err := indigoCrypto.ParsePublicJWKBytes(jwkBytes)
430 if err != nil {
431 return nil, fmt.Errorf("failed to parse JWK: %w", err)
432 }
433
434 // Convert indigo's PublicKey to Go's ecdsa.PublicKey
435 jwk, err := pubKey.JWK()
436 if err != nil {
437 return nil, fmt.Errorf("failed to get JWK from public key: %w", err)
438 }
439
440 // Use our existing conversion function
441 return atcryptoJWKToECDSAFromIndigoJWK(jwk)
442}
443
444// atcryptoJWKToECDSAFromIndigoJWK converts an indigo JWK to Go ecdsa.PublicKey
445func atcryptoJWKToECDSAFromIndigoJWK(jwk *indigoCrypto.JWK) (*ecdsa.PublicKey, error) {
446 if jwk.KeyType != "EC" {
447 return nil, fmt.Errorf("unsupported JWK key type: %s (expected EC)", jwk.KeyType)
448 }
449
450 xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
451 if err != nil {
452 return nil, fmt.Errorf("invalid JWK X coordinate: %w", err)
453 }
454 yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
455 if err != nil {
456 return nil, fmt.Errorf("invalid JWK Y coordinate: %w", err)
457 }
458
459 var curve ecdsa.PublicKey
460 switch jwk.Curve {
461 case "P-256":
462 curve.Curve = ecdsaP256Curve()
463 case "P-384":
464 curve.Curve = ecdsaP384Curve()
465 case "P-521":
466 curve.Curve = ecdsaP521Curve()
467 default:
468 return nil, fmt.Errorf("unsupported curve: %s", jwk.Curve)
469 }
470
471 curve.X = new(big.Int).SetBytes(xBytes)
472 curve.Y = new(big.Int).SetBytes(yBytes)
473
474 return &curve, nil
475}
476
// ecdsaP256Curve returns the NIST P-256 curve (JWK crv "P-256").
func ecdsaP256Curve() elliptic.Curve {
	return elliptic.P256()
}

// ecdsaP384Curve returns the NIST P-384 curve (JWK crv "P-384").
func ecdsaP384Curve() elliptic.Curve {
	return elliptic.P384()
}

// ecdsaP521Curve returns the NIST P-521 curve (JWK crv "P-521").
func ecdsaP521Curve() elliptic.Curve {
	return elliptic.P521()
}
481
// stripQueryFragment returns uri truncated at its first '?' or '#'. This removes
// any query string and fragment, which is the form compared for the DPoP htu
// claim.
func stripQueryFragment(uri string) string {
	if cut := strings.IndexAny(uri, "?#"); cut >= 0 {
		return uri[:cut]
	}
	return uri
}
492
493// ExtractCnfJkt extracts the cnf.jkt (confirmation key thumbprint) from JWT claims
494func ExtractCnfJkt(claims *Claims) (string, error) {
495 if claims.Confirmation == nil {
496 return "", fmt.Errorf("token missing cnf claim (no DPoP binding)")
497 }
498
499 jkt, ok := claims.Confirmation["jkt"].(string)
500 if !ok || jkt == "" {
501 return "", fmt.Errorf("token cnf claim missing jkt (DPoP key thumbprint)")
502 }
503
504 return jkt, nil
505}