A community-based topic aggregation platform built on atproto.
1package middleware
2
3import (
4 "Coves/internal/atproto/auth"
5 "context"
6 "crypto/ecdsa"
7 "crypto/elliptic"
8 "crypto/rand"
9 "encoding/base64"
10 "fmt"
11 "net/http"
12 "net/http/httptest"
13 "testing"
14 "time"
15
16 "github.com/golang-jwt/jwt/v5"
17 "github.com/google/uuid"
18)
19
// mockJWKSFetcher is a stand-in JWKSFetcher for tests. It never yields a real
// key: depending on shouldFail it returns either an error or (nil, nil).
type mockJWKSFetcher struct {
	// shouldFail makes every FetchPublicKey call return an error.
	shouldFail bool
}

// FetchPublicKey ignores ctx, issuer and token. When shouldFail is set it
// returns the error "mock fetch failure"; otherwise it returns a nil key and
// a nil error, because these Phase 1 tests are not meant to verify real
// signatures.
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
	if !m.shouldFail {
		// No key material: signature checks are not exercised by these tests.
		return nil, nil
	}
	return nil, fmt.Errorf("mock fetch failure")
}
32
33// createTestToken creates a test JWT with the given DID
34func createTestToken(did string) string {
35 claims := jwt.MapClaims{
36 "sub": did,
37 "iss": "https://test.pds.local",
38 "scope": "atproto",
39 "exp": time.Now().Add(1 * time.Hour).Unix(),
40 "iat": time.Now().Unix(),
41 }
42
43 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
44 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
45 return tokenString
46}
47
48// TestRequireAuth_ValidToken tests that valid tokens are accepted (Phase 1)
49func TestRequireAuth_ValidToken(t *testing.T) {
50 fetcher := &mockJWKSFetcher{}
51 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true
52
53 handlerCalled := false
54 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
55 handlerCalled = true
56
57 // Verify DID was extracted and injected into context
58 did := GetUserDID(r)
59 if did != "did:plc:test123" {
60 t.Errorf("expected DID 'did:plc:test123', got %s", did)
61 }
62
63 // Verify claims were injected
64 claims := GetJWTClaims(r)
65 if claims == nil {
66 t.Error("expected claims to be non-nil")
67 return
68 }
69 if claims.Subject != "did:plc:test123" {
70 t.Errorf("expected claims.Subject 'did:plc:test123', got %s", claims.Subject)
71 }
72
73 w.WriteHeader(http.StatusOK)
74 }))
75
76 token := createTestToken("did:plc:test123")
77 req := httptest.NewRequest("GET", "/test", nil)
78 req.Header.Set("Authorization", "Bearer "+token)
79 w := httptest.NewRecorder()
80
81 handler.ServeHTTP(w, req)
82
83 if !handlerCalled {
84 t.Error("handler was not called")
85 }
86
87 if w.Code != http.StatusOK {
88 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
89 }
90}
91
92// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected
93func TestRequireAuth_MissingAuthHeader(t *testing.T) {
94 fetcher := &mockJWKSFetcher{}
95 middleware := NewAtProtoAuthMiddleware(fetcher, true)
96
97 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98 t.Error("handler should not be called")
99 }))
100
101 req := httptest.NewRequest("GET", "/test", nil)
102 // No Authorization header
103 w := httptest.NewRecorder()
104
105 handler.ServeHTTP(w, req)
106
107 if w.Code != http.StatusUnauthorized {
108 t.Errorf("expected status 401, got %d", w.Code)
109 }
110}
111
112// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected
113func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
114 fetcher := &mockJWKSFetcher{}
115 middleware := NewAtProtoAuthMiddleware(fetcher, true)
116
117 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118 t.Error("handler should not be called")
119 }))
120
121 req := httptest.NewRequest("GET", "/test", nil)
122 req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0") // Wrong format
123 w := httptest.NewRecorder()
124
125 handler.ServeHTTP(w, req)
126
127 if w.Code != http.StatusUnauthorized {
128 t.Errorf("expected status 401, got %d", w.Code)
129 }
130}
131
132// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected
133func TestRequireAuth_MalformedToken(t *testing.T) {
134 fetcher := &mockJWKSFetcher{}
135 middleware := NewAtProtoAuthMiddleware(fetcher, true)
136
137 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
138 t.Error("handler should not be called")
139 }))
140
141 req := httptest.NewRequest("GET", "/test", nil)
142 req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
143 w := httptest.NewRecorder()
144
145 handler.ServeHTTP(w, req)
146
147 if w.Code != http.StatusUnauthorized {
148 t.Errorf("expected status 401, got %d", w.Code)
149 }
150}
151
152// TestRequireAuth_ExpiredToken tests that expired tokens are rejected
153func TestRequireAuth_ExpiredToken(t *testing.T) {
154 fetcher := &mockJWKSFetcher{}
155 middleware := NewAtProtoAuthMiddleware(fetcher, true)
156
157 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
158 t.Error("handler should not be called for expired token")
159 }))
160
161 // Create expired token
162 claims := jwt.MapClaims{
163 "sub": "did:plc:test123",
164 "iss": "https://test.pds.local",
165 "scope": "atproto",
166 "exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago
167 "iat": time.Now().Add(-2 * time.Hour).Unix(),
168 }
169
170 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
171 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
172
173 req := httptest.NewRequest("GET", "/test", nil)
174 req.Header.Set("Authorization", "Bearer "+tokenString)
175 w := httptest.NewRecorder()
176
177 handler.ServeHTTP(w, req)
178
179 if w.Code != http.StatusUnauthorized {
180 t.Errorf("expected status 401, got %d", w.Code)
181 }
182}
183
184// TestRequireAuth_MissingDID tests that tokens without DID are rejected
185func TestRequireAuth_MissingDID(t *testing.T) {
186 fetcher := &mockJWKSFetcher{}
187 middleware := NewAtProtoAuthMiddleware(fetcher, true)
188
189 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
190 t.Error("handler should not be called")
191 }))
192
193 // Create token without sub claim
194 claims := jwt.MapClaims{
195 // "sub" missing
196 "iss": "https://test.pds.local",
197 "scope": "atproto",
198 "exp": time.Now().Add(1 * time.Hour).Unix(),
199 "iat": time.Now().Unix(),
200 }
201
202 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
203 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
204
205 req := httptest.NewRequest("GET", "/test", nil)
206 req.Header.Set("Authorization", "Bearer "+tokenString)
207 w := httptest.NewRecorder()
208
209 handler.ServeHTTP(w, req)
210
211 if w.Code != http.StatusUnauthorized {
212 t.Errorf("expected status 401, got %d", w.Code)
213 }
214}
215
216// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid tokens
217func TestOptionalAuth_WithToken(t *testing.T) {
218 fetcher := &mockJWKSFetcher{}
219 middleware := NewAtProtoAuthMiddleware(fetcher, true)
220
221 handlerCalled := false
222 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
223 handlerCalled = true
224
225 // Verify DID was extracted
226 did := GetUserDID(r)
227 if did != "did:plc:test123" {
228 t.Errorf("expected DID 'did:plc:test123', got %s", did)
229 }
230
231 w.WriteHeader(http.StatusOK)
232 }))
233
234 token := createTestToken("did:plc:test123")
235 req := httptest.NewRequest("GET", "/test", nil)
236 req.Header.Set("Authorization", "Bearer "+token)
237 w := httptest.NewRecorder()
238
239 handler.ServeHTTP(w, req)
240
241 if !handlerCalled {
242 t.Error("handler was not called")
243 }
244
245 if w.Code != http.StatusOK {
246 t.Errorf("expected status 200, got %d", w.Code)
247 }
248}
249
250// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens
251func TestOptionalAuth_WithoutToken(t *testing.T) {
252 fetcher := &mockJWKSFetcher{}
253 middleware := NewAtProtoAuthMiddleware(fetcher, true)
254
255 handlerCalled := false
256 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
257 handlerCalled = true
258
259 // Verify no DID is set
260 did := GetUserDID(r)
261 if did != "" {
262 t.Errorf("expected empty DID, got %s", did)
263 }
264
265 w.WriteHeader(http.StatusOK)
266 }))
267
268 req := httptest.NewRequest("GET", "/test", nil)
269 // No Authorization header
270 w := httptest.NewRecorder()
271
272 handler.ServeHTTP(w, req)
273
274 if !handlerCalled {
275 t.Error("handler was not called")
276 }
277
278 if w.Code != http.StatusOK {
279 t.Errorf("expected status 200, got %d", w.Code)
280 }
281}
282
283// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token
284func TestOptionalAuth_InvalidToken(t *testing.T) {
285 fetcher := &mockJWKSFetcher{}
286 middleware := NewAtProtoAuthMiddleware(fetcher, true)
287
288 handlerCalled := false
289 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
290 handlerCalled = true
291
292 // Verify no DID is set (invalid token ignored)
293 did := GetUserDID(r)
294 if did != "" {
295 t.Errorf("expected empty DID for invalid token, got %s", did)
296 }
297
298 w.WriteHeader(http.StatusOK)
299 }))
300
301 req := httptest.NewRequest("GET", "/test", nil)
302 req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
303 w := httptest.NewRecorder()
304
305 handler.ServeHTTP(w, req)
306
307 if !handlerCalled {
308 t.Error("handler was not called")
309 }
310
311 if w.Code != http.StatusOK {
312 t.Errorf("expected status 200, got %d", w.Code)
313 }
314}
315
316// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated
317func TestGetUserDID_NotAuthenticated(t *testing.T) {
318 req := httptest.NewRequest("GET", "/test", nil)
319 did := GetUserDID(req)
320
321 if did != "" {
322 t.Errorf("expected empty string, got %s", did)
323 }
324}
325
326// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated
327func TestGetJWTClaims_NotAuthenticated(t *testing.T) {
328 req := httptest.NewRequest("GET", "/test", nil)
329 claims := GetJWTClaims(req)
330
331 if claims != nil {
332 t.Errorf("expected nil claims, got %+v", claims)
333 }
334}
335
336// TestGetDPoPProof_NotAuthenticated tests that GetDPoPProof returns nil when no DPoP was verified
337func TestGetDPoPProof_NotAuthenticated(t *testing.T) {
338 req := httptest.NewRequest("GET", "/test", nil)
339 proof := GetDPoPProof(req)
340
341 if proof != nil {
342 t.Errorf("expected nil proof, got %+v", proof)
343 }
344}
345
346// TestRequireAuth_WithDPoP_SecurityModel tests the correct DPoP security model:
347// Token MUST be verified first, then DPoP is checked as an additional layer.
348// DPoP is NOT a fallback for failed token verification.
349func TestRequireAuth_WithDPoP_SecurityModel(t *testing.T) {
350 // Generate an ECDSA key pair for DPoP
351 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
352 if err != nil {
353 t.Fatalf("failed to generate key: %v", err)
354 }
355
356 // Calculate JWK thumbprint for cnf.jkt
357 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
358 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
359 if err != nil {
360 t.Fatalf("failed to calculate thumbprint: %v", err)
361 }
362
363 t.Run("DPoP_is_NOT_fallback_for_failed_verification", func(t *testing.T) {
364 // SECURITY TEST: When token verification fails, DPoP should NOT be used as fallback
365 // This prevents an attacker from forging a token with their own cnf.jkt
366
367 // Create a DPoP-bound access token (unsigned - will fail verification)
368 claims := auth.Claims{
369 RegisteredClaims: jwt.RegisteredClaims{
370 Subject: "did:plc:attacker",
371 Issuer: "https://external.pds.local",
372 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
373 IssuedAt: jwt.NewNumericDate(time.Now()),
374 },
375 Scope: "atproto",
376 Confirmation: map[string]interface{}{
377 "jkt": thumbprint,
378 },
379 }
380
381 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
382 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
383
384 // Create valid DPoP proof (attacker has the private key)
385 dpopProof := createDPoPProof(t, privateKey, "GET", "https://test.local/api/endpoint")
386
387 // Mock fetcher that fails (simulating external PDS without JWKS)
388 fetcher := &mockJWKSFetcher{shouldFail: true}
389 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false
390
391 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
392 t.Error("SECURITY VULNERABILITY: handler was called despite token verification failure")
393 }))
394
395 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
396 req.Header.Set("Authorization", "Bearer "+tokenString)
397 req.Header.Set("DPoP", dpopProof)
398 w := httptest.NewRecorder()
399
400 handler.ServeHTTP(w, req)
401
402 // MUST reject - token verification failed, DPoP cannot substitute for signature verification
403 if w.Code != http.StatusUnauthorized {
404 t.Errorf("SECURITY: expected 401 for unverified token, got %d", w.Code)
405 }
406 })
407
408 t.Run("DPoP_required_when_cnf_jkt_present_in_verified_token", func(t *testing.T) {
409 // When token has cnf.jkt, DPoP header MUST be present
410 // This test uses skipVerify=true to simulate a verified token
411
412 claims := auth.Claims{
413 RegisteredClaims: jwt.RegisteredClaims{
414 Subject: "did:plc:test123",
415 Issuer: "https://test.pds.local",
416 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
417 IssuedAt: jwt.NewNumericDate(time.Now()),
418 },
419 Scope: "atproto",
420 Confirmation: map[string]interface{}{
421 "jkt": thumbprint,
422 },
423 }
424
425 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
426 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
427
428 // NO DPoP header - should fail when skipVerify is false
429 // Note: with skipVerify=true, DPoP is not checked
430 fetcher := &mockJWKSFetcher{}
431 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true for parsing
432
433 handlerCalled := false
434 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
435 handlerCalled = true
436 w.WriteHeader(http.StatusOK)
437 }))
438
439 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
440 req.Header.Set("Authorization", "Bearer "+tokenString)
441 // No DPoP header
442 w := httptest.NewRecorder()
443
444 handler.ServeHTTP(w, req)
445
446 // With skipVerify=true, DPoP is not checked, so this should succeed
447 if !handlerCalled {
448 t.Error("handler should be called when skipVerify=true")
449 }
450 })
451}
452
453// TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback is the key security test.
454// It ensures that DPoP cannot be used as a fallback when token signature verification fails.
455func TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback(t *testing.T) {
456 // Generate a key pair (attacker's key)
457 attackerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
458 jwk := ecdsaPublicKeyToJWK(&attackerKey.PublicKey)
459 thumbprint, _ := auth.CalculateJWKThumbprint(jwk)
460
461 // Create a FORGED token claiming to be the victim
462 claims := auth.Claims{
463 RegisteredClaims: jwt.RegisteredClaims{
464 Subject: "did:plc:victim_user", // Attacker claims to be victim
465 Issuer: "https://untrusted.pds",
466 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
467 IssuedAt: jwt.NewNumericDate(time.Now()),
468 },
469 Scope: "atproto",
470 Confirmation: map[string]interface{}{
471 "jkt": thumbprint, // Attacker uses their own key
472 },
473 }
474
475 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
476 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
477
478 // Attacker creates a valid DPoP proof with their key
479 dpopProof := createDPoPProof(t, attackerKey, "POST", "https://api.example.com/protected")
480
481 // Fetcher fails (external PDS without JWKS)
482 fetcher := &mockJWKSFetcher{shouldFail: true}
483 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false - REAL verification
484
485 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
486 t.Fatalf("CRITICAL SECURITY FAILURE: Request authenticated as %s despite forged token!",
487 GetUserDID(r))
488 }))
489
490 req := httptest.NewRequest("POST", "https://api.example.com/protected", nil)
491 req.Header.Set("Authorization", "Bearer "+tokenString)
492 req.Header.Set("DPoP", dpopProof)
493 w := httptest.NewRecorder()
494
495 handler.ServeHTTP(w, req)
496
497 // MUST reject - the token signature was never verified
498 if w.Code != http.StatusUnauthorized {
499 t.Errorf("SECURITY VULNERABILITY: Expected 401, got %d. Token was not properly verified!", w.Code)
500 }
501}
502
503// TestVerifyDPoPBinding_UsesForwardedProto ensures we honor the external HTTPS
504// scheme when TLS is terminated upstream and X-Forwarded-Proto is present.
505func TestVerifyDPoPBinding_UsesForwardedProto(t *testing.T) {
506 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
507 if err != nil {
508 t.Fatalf("failed to generate key: %v", err)
509 }
510
511 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
512 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
513 if err != nil {
514 t.Fatalf("failed to calculate thumbprint: %v", err)
515 }
516
517 claims := &auth.Claims{
518 RegisteredClaims: jwt.RegisteredClaims{
519 Subject: "did:plc:test123",
520 Issuer: "https://test.pds.local",
521 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
522 IssuedAt: jwt.NewNumericDate(time.Now()),
523 },
524 Scope: "atproto",
525 Confirmation: map[string]interface{}{
526 "jkt": thumbprint,
527 },
528 }
529
530 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
531 defer middleware.Stop()
532
533 externalURI := "https://api.example.com/protected/resource"
534 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
535
536 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
537 req.Host = "api.example.com"
538 req.Header.Set("X-Forwarded-Proto", "https")
539
540 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof)
541 if err != nil {
542 t.Fatalf("expected DPoP verification to succeed with forwarded proto, got %v", err)
543 }
544
545 if proof == nil || proof.Claims == nil {
546 t.Fatal("expected DPoP proof to be returned")
547 }
548}
549
550// TestMiddlewareStop tests that the middleware can be stopped properly
551func TestMiddlewareStop(t *testing.T) {
552 fetcher := &mockJWKSFetcher{}
553 middleware := NewAtProtoAuthMiddleware(fetcher, false)
554
555 // Stop should not panic and should clean up resources
556 middleware.Stop()
557
558 // Calling Stop again should also be safe (idempotent-ish)
559 // Note: The underlying DPoPVerifier.Stop() closes a channel, so this might panic
560 // if not handled properly. We test that at least one Stop works.
561}
562
563// TestOptionalAuth_DPoPBoundToken_NoDPoPHeader tests that OptionalAuth treats
564// tokens with cnf.jkt but no DPoP header as unauthenticated (potential token theft)
565func TestOptionalAuth_DPoPBoundToken_NoDPoPHeader(t *testing.T) {
566 // Generate a key pair for DPoP binding
567 privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
568 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
569 thumbprint, _ := auth.CalculateJWKThumbprint(jwk)
570
571 // Create a DPoP-bound token (has cnf.jkt)
572 claims := auth.Claims{
573 RegisteredClaims: jwt.RegisteredClaims{
574 Subject: "did:plc:user123",
575 Issuer: "https://test.pds.local",
576 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
577 IssuedAt: jwt.NewNumericDate(time.Now()),
578 },
579 Scope: "atproto",
580 Confirmation: map[string]interface{}{
581 "jkt": thumbprint,
582 },
583 }
584
585 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
586 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
587
588 // Use skipVerify=true to simulate a verified token
589 // (In production, skipVerify would be false and VerifyJWT would be called)
590 // However, for this test we need skipVerify=false to trigger DPoP checking
591 // But the fetcher will fail, so let's use skipVerify=true and verify the logic
592 // Actually, the DPoP check only happens when skipVerify=false
593
594 t.Run("with_skipVerify_false", func(t *testing.T) {
595 // This will fail at JWT verification level, but that's expected
596 // The important thing is the code path for DPoP checking
597 fetcher := &mockJWKSFetcher{shouldFail: true}
598 middleware := NewAtProtoAuthMiddleware(fetcher, false)
599 defer middleware.Stop()
600
601 handlerCalled := false
602 var capturedDID string
603 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
604 handlerCalled = true
605 capturedDID = GetUserDID(r)
606 w.WriteHeader(http.StatusOK)
607 }))
608
609 req := httptest.NewRequest("GET", "/test", nil)
610 req.Header.Set("Authorization", "Bearer "+tokenString)
611 // Deliberately NOT setting DPoP header
612 w := httptest.NewRecorder()
613
614 handler.ServeHTTP(w, req)
615
616 // Handler should be called (optional auth doesn't block)
617 if !handlerCalled {
618 t.Error("handler should be called")
619 }
620
621 // But since JWT verification fails, user should not be authenticated
622 if capturedDID != "" {
623 t.Errorf("expected empty DID when verification fails, got %s", capturedDID)
624 }
625 })
626
627 t.Run("with_skipVerify_true_dpop_not_checked", func(t *testing.T) {
628 // When skipVerify=true, DPoP is not checked (Phase 1 mode)
629 fetcher := &mockJWKSFetcher{}
630 middleware := NewAtProtoAuthMiddleware(fetcher, true)
631 defer middleware.Stop()
632
633 handlerCalled := false
634 var capturedDID string
635 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
636 handlerCalled = true
637 capturedDID = GetUserDID(r)
638 w.WriteHeader(http.StatusOK)
639 }))
640
641 req := httptest.NewRequest("GET", "/test", nil)
642 req.Header.Set("Authorization", "Bearer "+tokenString)
643 // No DPoP header
644 w := httptest.NewRecorder()
645
646 handler.ServeHTTP(w, req)
647
648 if !handlerCalled {
649 t.Error("handler should be called")
650 }
651
652 // With skipVerify=true, DPoP check is bypassed - token is trusted
653 if capturedDID != "did:plc:user123" {
654 t.Errorf("expected DID when skipVerify=true, got %s", capturedDID)
655 }
656 })
657}
658
659// TestDPoPReplayProtection tests that the same DPoP proof cannot be used twice
660func TestDPoPReplayProtection(t *testing.T) {
661 // This tests the NonceCache functionality
662 cache := auth.NewNonceCache(5 * time.Minute)
663 defer cache.Stop()
664
665 jti := "unique-proof-id-123"
666
667 // First use should succeed
668 if !cache.CheckAndStore(jti) {
669 t.Error("First use of jti should succeed")
670 }
671
672 // Second use should fail (replay detected)
673 if cache.CheckAndStore(jti) {
674 t.Error("SECURITY: Replay attack not detected - same jti accepted twice")
675 }
676
677 // Different jti should succeed
678 if !cache.CheckAndStore("different-jti-456") {
679 t.Error("Different jti should succeed")
680 }
681}
682
683// Helper: createDPoPProof creates a DPoP proof JWT for testing
684func createDPoPProof(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri string) string {
685 // Create JWK from public key
686 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
687
688 // Create DPoP claims with UUID for jti to ensure uniqueness across tests
689 claims := auth.DPoPClaims{
690 RegisteredClaims: jwt.RegisteredClaims{
691 IssuedAt: jwt.NewNumericDate(time.Now()),
692 ID: uuid.New().String(),
693 },
694 HTTPMethod: method,
695 HTTPURI: uri,
696 }
697
698 // Create token with custom header
699 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
700 token.Header["typ"] = "dpop+jwt"
701 token.Header["jwk"] = jwk
702
703 // Sign with private key
704 signedToken, err := token.SignedString(privateKey)
705 if err != nil {
706 t.Fatalf("failed to sign DPoP proof: %v", err)
707 }
708
709 return signedToken
710}
711
// ecdsaPublicKeyToJWK encodes an ECDSA public key as a JWK map (RFC 7518
// §6.2.1). The map holds kty "EC", the curve name as crv, and the x/y
// coordinates as unpadded base64url. Each coordinate is left-padded with zeros
// to the curve's full byte length. It panics for curves other than P-256,
// P-384 and P-521.
func ecdsaPublicKeyToJWK(pubKey *ecdsa.PublicKey) map[string]interface{} {
	curve := pubKey.Curve

	var crvName string
	switch {
	case curve == elliptic.P256():
		crvName = "P-256"
	case curve == elliptic.P384():
		crvName = "P-384"
	case curve == elliptic.P521():
		crvName = "P-521"
	default:
		panic("unsupported curve")
	}

	// JWK requires each coordinate to span the full field size, so small
	// values must keep their leading zero bytes.
	width := (curve.Params().BitSize + 7) / 8
	xBuf := make([]byte, width)
	yBuf := make([]byte, width)
	pubKey.X.FillBytes(xBuf)
	pubKey.Y.FillBytes(yBuf)

	enc := base64.RawURLEncoding
	return map[string]interface{}{
		"kty": "EC",
		"crv": crvName,
		"x":   enc.EncodeToString(xBuf),
		"y":   enc.EncodeToString(yBuf),
	}
}
744}