A community-based topic aggregation platform built on atproto
1package middleware
2
3import (
4 "Coves/internal/atproto/auth"
5 "context"
6 "crypto/ecdsa"
7 "crypto/elliptic"
8 "crypto/rand"
9 "crypto/sha256"
10 "encoding/base64"
11 "fmt"
12 "net/http"
13 "net/http/httptest"
14 "strings"
15 "testing"
16 "time"
17
18 "github.com/golang-jwt/jwt/v5"
19 "github.com/google/uuid"
20)
21
// mockJWKSFetcher is a test double for JWKSFetcher. It never hands out a real
// key: it either reports a simulated fetch failure or returns a nil key.
type mockJWKSFetcher struct {
	// shouldFail makes FetchPublicKey return an error instead of a nil key.
	shouldFail bool
}

// FetchPublicKey ignores its arguments. When shouldFail is set it returns the
// error "mock fetch failure"; otherwise it returns a nil key and a nil error,
// because the Phase 1 tests never actually verify signatures.
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (any, error) {
	if !m.shouldFail {
		return nil, nil
	}
	return nil, fmt.Errorf("mock fetch failure")
}
34
35// createTestToken creates a test JWT with the given DID
36func createTestToken(did string) string {
37 claims := jwt.MapClaims{
38 "sub": did,
39 "iss": "https://test.pds.local",
40 "scope": "atproto",
41 "exp": time.Now().Add(1 * time.Hour).Unix(),
42 "iat": time.Now().Unix(),
43 }
44
45 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
46 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
47 return tokenString
48}
49
50// TestRequireAuth_ValidToken tests that valid tokens are accepted with DPoP scheme (Phase 1)
51func TestRequireAuth_ValidToken(t *testing.T) {
52 fetcher := &mockJWKSFetcher{}
53 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true
54
55 handlerCalled := false
56 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57 handlerCalled = true
58
59 // Verify DID was extracted and injected into context
60 did := GetUserDID(r)
61 if did != "did:plc:test123" {
62 t.Errorf("expected DID 'did:plc:test123', got %s", did)
63 }
64
65 // Verify claims were injected
66 claims := GetJWTClaims(r)
67 if claims == nil {
68 t.Error("expected claims to be non-nil")
69 return
70 }
71 if claims.Subject != "did:plc:test123" {
72 t.Errorf("expected claims.Subject 'did:plc:test123', got %s", claims.Subject)
73 }
74
75 w.WriteHeader(http.StatusOK)
76 }))
77
78 token := createTestToken("did:plc:test123")
79 req := httptest.NewRequest("GET", "/test", nil)
80 req.Header.Set("Authorization", "DPoP "+token)
81 w := httptest.NewRecorder()
82
83 handler.ServeHTTP(w, req)
84
85 if !handlerCalled {
86 t.Error("handler was not called")
87 }
88
89 if w.Code != http.StatusOK {
90 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
91 }
92}
93
94// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected
95func TestRequireAuth_MissingAuthHeader(t *testing.T) {
96 fetcher := &mockJWKSFetcher{}
97 middleware := NewAtProtoAuthMiddleware(fetcher, true)
98
99 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
100 t.Error("handler should not be called")
101 }))
102
103 req := httptest.NewRequest("GET", "/test", nil)
104 // No Authorization header
105 w := httptest.NewRecorder()
106
107 handler.ServeHTTP(w, req)
108
109 if w.Code != http.StatusUnauthorized {
110 t.Errorf("expected status 401, got %d", w.Code)
111 }
112}
113
114// TestRequireAuth_InvalidAuthHeaderFormat tests that non-DPoP tokens are rejected (including Bearer)
115func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
116 fetcher := &mockJWKSFetcher{}
117 middleware := NewAtProtoAuthMiddleware(fetcher, true)
118
119 tests := []struct {
120 name string
121 header string
122 }{
123 {"Basic auth", "Basic dGVzdDp0ZXN0"},
124 {"Bearer scheme", "Bearer some-token"},
125 {"Invalid format", "InvalidFormat"},
126 }
127
128 for _, tt := range tests {
129 t.Run(tt.name, func(t *testing.T) {
130 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131 t.Error("handler should not be called")
132 }))
133
134 req := httptest.NewRequest("GET", "/test", nil)
135 req.Header.Set("Authorization", tt.header)
136 w := httptest.NewRecorder()
137
138 handler.ServeHTTP(w, req)
139
140 if w.Code != http.StatusUnauthorized {
141 t.Errorf("expected status 401, got %d", w.Code)
142 }
143 })
144 }
145}
146
147// TestRequireAuth_BearerRejectionErrorMessage verifies that Bearer tokens are rejected
148// with a helpful error message guiding users to use DPoP scheme
149func TestRequireAuth_BearerRejectionErrorMessage(t *testing.T) {
150 fetcher := &mockJWKSFetcher{}
151 middleware := NewAtProtoAuthMiddleware(fetcher, true)
152
153 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
154 t.Error("handler should not be called")
155 }))
156
157 req := httptest.NewRequest("GET", "/test", nil)
158 req.Header.Set("Authorization", "Bearer some-token")
159 w := httptest.NewRecorder()
160
161 handler.ServeHTTP(w, req)
162
163 if w.Code != http.StatusUnauthorized {
164 t.Errorf("expected status 401, got %d", w.Code)
165 }
166
167 // Verify error message guides user to use DPoP
168 body := w.Body.String()
169 if !strings.Contains(body, "Expected: DPoP") {
170 t.Errorf("error message should guide user to use DPoP, got: %s", body)
171 }
172}
173
174// TestRequireAuth_CaseInsensitiveScheme verifies that DPoP scheme matching is case-insensitive
175// per RFC 7235 which states HTTP auth schemes are case-insensitive
176func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) {
177 fetcher := &mockJWKSFetcher{}
178 middleware := NewAtProtoAuthMiddleware(fetcher, true)
179
180 // Create a valid JWT for testing
181 validToken := createValidJWT(t, "did:plc:test123", time.Hour)
182
183 testCases := []struct {
184 name string
185 scheme string
186 }{
187 {"lowercase", "dpop"},
188 {"uppercase", "DPOP"},
189 {"mixed_case", "DpOp"},
190 {"standard", "DPoP"},
191 }
192
193 for _, tc := range testCases {
194 t.Run(tc.name, func(t *testing.T) {
195 handlerCalled := false
196 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
197 handlerCalled = true
198 w.WriteHeader(http.StatusOK)
199 }))
200
201 req := httptest.NewRequest("GET", "/test", nil)
202 req.Header.Set("Authorization", tc.scheme+" "+validToken)
203 w := httptest.NewRecorder()
204
205 handler.ServeHTTP(w, req)
206
207 if !handlerCalled {
208 t.Errorf("scheme %q should be accepted (case-insensitive per RFC 7235), got status %d: %s",
209 tc.scheme, w.Code, w.Body.String())
210 }
211 })
212 }
213}
214
215// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected
216func TestRequireAuth_MalformedToken(t *testing.T) {
217 fetcher := &mockJWKSFetcher{}
218 middleware := NewAtProtoAuthMiddleware(fetcher, true)
219
220 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
221 t.Error("handler should not be called")
222 }))
223
224 req := httptest.NewRequest("GET", "/test", nil)
225 req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
226 w := httptest.NewRecorder()
227
228 handler.ServeHTTP(w, req)
229
230 if w.Code != http.StatusUnauthorized {
231 t.Errorf("expected status 401, got %d", w.Code)
232 }
233}
234
235// TestRequireAuth_ExpiredToken tests that expired tokens are rejected
236func TestRequireAuth_ExpiredToken(t *testing.T) {
237 fetcher := &mockJWKSFetcher{}
238 middleware := NewAtProtoAuthMiddleware(fetcher, true)
239
240 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
241 t.Error("handler should not be called for expired token")
242 }))
243
244 // Create expired token
245 claims := jwt.MapClaims{
246 "sub": "did:plc:test123",
247 "iss": "https://test.pds.local",
248 "scope": "atproto",
249 "exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago
250 "iat": time.Now().Add(-2 * time.Hour).Unix(),
251 }
252
253 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
254 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
255
256 req := httptest.NewRequest("GET", "/test", nil)
257 req.Header.Set("Authorization", "DPoP "+tokenString)
258 w := httptest.NewRecorder()
259
260 handler.ServeHTTP(w, req)
261
262 if w.Code != http.StatusUnauthorized {
263 t.Errorf("expected status 401, got %d", w.Code)
264 }
265}
266
267// TestRequireAuth_MissingDID tests that tokens without DID are rejected
268func TestRequireAuth_MissingDID(t *testing.T) {
269 fetcher := &mockJWKSFetcher{}
270 middleware := NewAtProtoAuthMiddleware(fetcher, true)
271
272 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
273 t.Error("handler should not be called")
274 }))
275
276 // Create token without sub claim
277 claims := jwt.MapClaims{
278 // "sub" missing
279 "iss": "https://test.pds.local",
280 "scope": "atproto",
281 "exp": time.Now().Add(1 * time.Hour).Unix(),
282 "iat": time.Now().Unix(),
283 }
284
285 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
286 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
287
288 req := httptest.NewRequest("GET", "/test", nil)
289 req.Header.Set("Authorization", "DPoP "+tokenString)
290 w := httptest.NewRecorder()
291
292 handler.ServeHTTP(w, req)
293
294 if w.Code != http.StatusUnauthorized {
295 t.Errorf("expected status 401, got %d", w.Code)
296 }
297}
298
299// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid DPoP tokens
300func TestOptionalAuth_WithToken(t *testing.T) {
301 fetcher := &mockJWKSFetcher{}
302 middleware := NewAtProtoAuthMiddleware(fetcher, true)
303
304 handlerCalled := false
305 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
306 handlerCalled = true
307
308 // Verify DID was extracted
309 did := GetUserDID(r)
310 if did != "did:plc:test123" {
311 t.Errorf("expected DID 'did:plc:test123', got %s", did)
312 }
313
314 w.WriteHeader(http.StatusOK)
315 }))
316
317 token := createTestToken("did:plc:test123")
318 req := httptest.NewRequest("GET", "/test", nil)
319 req.Header.Set("Authorization", "DPoP "+token)
320 w := httptest.NewRecorder()
321
322 handler.ServeHTTP(w, req)
323
324 if !handlerCalled {
325 t.Error("handler was not called")
326 }
327
328 if w.Code != http.StatusOK {
329 t.Errorf("expected status 200, got %d", w.Code)
330 }
331}
332
333// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens
334func TestOptionalAuth_WithoutToken(t *testing.T) {
335 fetcher := &mockJWKSFetcher{}
336 middleware := NewAtProtoAuthMiddleware(fetcher, true)
337
338 handlerCalled := false
339 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
340 handlerCalled = true
341
342 // Verify no DID is set
343 did := GetUserDID(r)
344 if did != "" {
345 t.Errorf("expected empty DID, got %s", did)
346 }
347
348 w.WriteHeader(http.StatusOK)
349 }))
350
351 req := httptest.NewRequest("GET", "/test", nil)
352 // No Authorization header
353 w := httptest.NewRecorder()
354
355 handler.ServeHTTP(w, req)
356
357 if !handlerCalled {
358 t.Error("handler was not called")
359 }
360
361 if w.Code != http.StatusOK {
362 t.Errorf("expected status 200, got %d", w.Code)
363 }
364}
365
366// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token
367func TestOptionalAuth_InvalidToken(t *testing.T) {
368 fetcher := &mockJWKSFetcher{}
369 middleware := NewAtProtoAuthMiddleware(fetcher, true)
370
371 handlerCalled := false
372 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
373 handlerCalled = true
374
375 // Verify no DID is set (invalid token ignored)
376 did := GetUserDID(r)
377 if did != "" {
378 t.Errorf("expected empty DID for invalid token, got %s", did)
379 }
380
381 w.WriteHeader(http.StatusOK)
382 }))
383
384 req := httptest.NewRequest("GET", "/test", nil)
385 req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
386 w := httptest.NewRecorder()
387
388 handler.ServeHTTP(w, req)
389
390 if !handlerCalled {
391 t.Error("handler was not called")
392 }
393
394 if w.Code != http.StatusOK {
395 t.Errorf("expected status 200, got %d", w.Code)
396 }
397}
398
399// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated
400func TestGetUserDID_NotAuthenticated(t *testing.T) {
401 req := httptest.NewRequest("GET", "/test", nil)
402 did := GetUserDID(req)
403
404 if did != "" {
405 t.Errorf("expected empty string, got %s", did)
406 }
407}
408
409// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated
410func TestGetJWTClaims_NotAuthenticated(t *testing.T) {
411 req := httptest.NewRequest("GET", "/test", nil)
412 claims := GetJWTClaims(req)
413
414 if claims != nil {
415 t.Errorf("expected nil claims, got %+v", claims)
416 }
417}
418
419// TestGetDPoPProof_NotAuthenticated tests that GetDPoPProof returns nil when no DPoP was verified
420func TestGetDPoPProof_NotAuthenticated(t *testing.T) {
421 req := httptest.NewRequest("GET", "/test", nil)
422 proof := GetDPoPProof(req)
423
424 if proof != nil {
425 t.Errorf("expected nil proof, got %+v", proof)
426 }
427}
428
429// TestRequireAuth_WithDPoP_SecurityModel tests the correct DPoP security model:
430// Token MUST be verified first, then DPoP is checked as an additional layer.
431// DPoP is NOT a fallback for failed token verification.
432func TestRequireAuth_WithDPoP_SecurityModel(t *testing.T) {
433 // Generate an ECDSA key pair for DPoP
434 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
435 if err != nil {
436 t.Fatalf("failed to generate key: %v", err)
437 }
438
439 // Calculate JWK thumbprint for cnf.jkt
440 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
441 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
442 if err != nil {
443 t.Fatalf("failed to calculate thumbprint: %v", err)
444 }
445
446 t.Run("DPoP_is_NOT_fallback_for_failed_verification", func(t *testing.T) {
447 // SECURITY TEST: When token verification fails, DPoP should NOT be used as fallback
448 // This prevents an attacker from forging a token with their own cnf.jkt
449
450 // Create a DPoP-bound access token (unsigned - will fail verification)
451 claims := auth.Claims{
452 RegisteredClaims: jwt.RegisteredClaims{
453 Subject: "did:plc:attacker",
454 Issuer: "https://external.pds.local",
455 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
456 IssuedAt: jwt.NewNumericDate(time.Now()),
457 },
458 Scope: "atproto",
459 Confirmation: map[string]interface{}{
460 "jkt": thumbprint,
461 },
462 }
463
464 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
465 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
466
467 // Create valid DPoP proof (attacker has the private key)
468 dpopProof := createDPoPProof(t, privateKey, "GET", "https://test.local/api/endpoint")
469
470 // Mock fetcher that fails (simulating external PDS without JWKS)
471 fetcher := &mockJWKSFetcher{shouldFail: true}
472 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false
473
474 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
475 t.Error("SECURITY VULNERABILITY: handler was called despite token verification failure")
476 }))
477
478 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
479 req.Header.Set("Authorization", "DPoP "+tokenString)
480 req.Header.Set("DPoP", dpopProof)
481 w := httptest.NewRecorder()
482
483 handler.ServeHTTP(w, req)
484
485 // MUST reject - token verification failed, DPoP cannot substitute for signature verification
486 if w.Code != http.StatusUnauthorized {
487 t.Errorf("SECURITY: expected 401 for unverified token, got %d", w.Code)
488 }
489 })
490
491 t.Run("DPoP_required_when_cnf_jkt_present_in_verified_token", func(t *testing.T) {
492 // When token has cnf.jkt, DPoP header MUST be present
493 // This test uses skipVerify=true to simulate a verified token
494
495 claims := auth.Claims{
496 RegisteredClaims: jwt.RegisteredClaims{
497 Subject: "did:plc:test123",
498 Issuer: "https://test.pds.local",
499 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
500 IssuedAt: jwt.NewNumericDate(time.Now()),
501 },
502 Scope: "atproto",
503 Confirmation: map[string]interface{}{
504 "jkt": thumbprint,
505 },
506 }
507
508 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
509 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
510
511 // NO DPoP header - should fail when skipVerify is false
512 // Note: with skipVerify=true, DPoP is not checked
513 fetcher := &mockJWKSFetcher{}
514 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true for parsing
515
516 handlerCalled := false
517 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
518 handlerCalled = true
519 w.WriteHeader(http.StatusOK)
520 }))
521
522 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
523 req.Header.Set("Authorization", "DPoP "+tokenString)
524 // No DPoP header
525 w := httptest.NewRecorder()
526
527 handler.ServeHTTP(w, req)
528
529 // With skipVerify=true, DPoP is not checked, so this should succeed
530 if !handlerCalled {
531 t.Error("handler should be called when skipVerify=true")
532 }
533 })
534}
535
536// TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback is the key security test.
537// It ensures that DPoP cannot be used as a fallback when token signature verification fails.
538func TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback(t *testing.T) {
539 // Generate a key pair (attacker's key)
540 attackerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
541 jwk := ecdsaPublicKeyToJWK(&attackerKey.PublicKey)
542 thumbprint, _ := auth.CalculateJWKThumbprint(jwk)
543
544 // Create a FORGED token claiming to be the victim
545 claims := auth.Claims{
546 RegisteredClaims: jwt.RegisteredClaims{
547 Subject: "did:plc:victim_user", // Attacker claims to be victim
548 Issuer: "https://untrusted.pds",
549 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
550 IssuedAt: jwt.NewNumericDate(time.Now()),
551 },
552 Scope: "atproto",
553 Confirmation: map[string]interface{}{
554 "jkt": thumbprint, // Attacker uses their own key
555 },
556 }
557
558 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
559 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
560
561 // Attacker creates a valid DPoP proof with their key
562 dpopProof := createDPoPProof(t, attackerKey, "POST", "https://api.example.com/protected")
563
564 // Fetcher fails (external PDS without JWKS)
565 fetcher := &mockJWKSFetcher{shouldFail: true}
566 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false - REAL verification
567
568 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
569 t.Fatalf("CRITICAL SECURITY FAILURE: Request authenticated as %s despite forged token!",
570 GetUserDID(r))
571 }))
572
573 req := httptest.NewRequest("POST", "https://api.example.com/protected", nil)
574 req.Header.Set("Authorization", "DPoP "+tokenString)
575 req.Header.Set("DPoP", dpopProof)
576 w := httptest.NewRecorder()
577
578 handler.ServeHTTP(w, req)
579
580 // MUST reject - the token signature was never verified
581 if w.Code != http.StatusUnauthorized {
582 t.Errorf("SECURITY VULNERABILITY: Expected 401, got %d. Token was not properly verified!", w.Code)
583 }
584}
585
586// TestVerifyDPoPBinding_UsesForwardedProto ensures we honor the external HTTPS
587// scheme when TLS is terminated upstream and X-Forwarded-Proto is present.
588func TestVerifyDPoPBinding_UsesForwardedProto(t *testing.T) {
589 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
590 if err != nil {
591 t.Fatalf("failed to generate key: %v", err)
592 }
593
594 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
595 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
596 if err != nil {
597 t.Fatalf("failed to calculate thumbprint: %v", err)
598 }
599
600 claims := &auth.Claims{
601 RegisteredClaims: jwt.RegisteredClaims{
602 Subject: "did:plc:test123",
603 Issuer: "https://test.pds.local",
604 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
605 IssuedAt: jwt.NewNumericDate(time.Now()),
606 },
607 Scope: "atproto",
608 Confirmation: map[string]interface{}{
609 "jkt": thumbprint,
610 },
611 }
612
613 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
614 defer middleware.Stop()
615
616 externalURI := "https://api.example.com/protected/resource"
617 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
618
619 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
620 req.Host = "api.example.com"
621 req.Header.Set("X-Forwarded-Proto", "https")
622
623 // Pass a fake access token - ath verification will pass since we don't include ath in the DPoP proof
624 fakeAccessToken := "fake-access-token-for-testing"
625 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
626 if err != nil {
627 t.Fatalf("expected DPoP verification to succeed with forwarded proto, got %v", err)
628 }
629
630 if proof == nil || proof.Claims == nil {
631 t.Fatal("expected DPoP proof to be returned")
632 }
633}
634
635// TestVerifyDPoPBinding_UsesForwardedHost ensures we honor X-Forwarded-Host header
636// when behind a TLS-terminating proxy that rewrites the Host header.
637func TestVerifyDPoPBinding_UsesForwardedHost(t *testing.T) {
638 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
639 if err != nil {
640 t.Fatalf("failed to generate key: %v", err)
641 }
642
643 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
644 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
645 if err != nil {
646 t.Fatalf("failed to calculate thumbprint: %v", err)
647 }
648
649 claims := &auth.Claims{
650 RegisteredClaims: jwt.RegisteredClaims{
651 Subject: "did:plc:test123",
652 Issuer: "https://test.pds.local",
653 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
654 IssuedAt: jwt.NewNumericDate(time.Now()),
655 },
656 Scope: "atproto",
657 Confirmation: map[string]interface{}{
658 "jkt": thumbprint,
659 },
660 }
661
662 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
663 defer middleware.Stop()
664
665 // External URI that the client uses
666 externalURI := "https://api.example.com/protected/resource"
667 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
668
669 // Request hits internal service with internal hostname, but X-Forwarded-Host has public hostname
670 req := httptest.NewRequest("GET", "http://internal-service:8080/protected/resource", nil)
671 req.Host = "internal-service:8080" // Internal host after proxy
672 req.Header.Set("X-Forwarded-Proto", "https")
673 req.Header.Set("X-Forwarded-Host", "api.example.com") // Original public host
674
675 fakeAccessToken := "fake-access-token-for-testing"
676 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
677 if err != nil {
678 t.Fatalf("expected DPoP verification to succeed with X-Forwarded-Host, got %v", err)
679 }
680
681 if proof == nil || proof.Claims == nil {
682 t.Fatal("expected DPoP proof to be returned")
683 }
684}
685
686// TestVerifyDPoPBinding_UsesStandardForwardedHeader tests RFC 7239 Forwarded header parsing
687func TestVerifyDPoPBinding_UsesStandardForwardedHeader(t *testing.T) {
688 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
689 if err != nil {
690 t.Fatalf("failed to generate key: %v", err)
691 }
692
693 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
694 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
695 if err != nil {
696 t.Fatalf("failed to calculate thumbprint: %v", err)
697 }
698
699 claims := &auth.Claims{
700 RegisteredClaims: jwt.RegisteredClaims{
701 Subject: "did:plc:test123",
702 Issuer: "https://test.pds.local",
703 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
704 IssuedAt: jwt.NewNumericDate(time.Now()),
705 },
706 Scope: "atproto",
707 Confirmation: map[string]interface{}{
708 "jkt": thumbprint,
709 },
710 }
711
712 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
713 defer middleware.Stop()
714
715 // External URI
716 externalURI := "https://api.example.com/protected/resource"
717 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
718
719 // Request with standard Forwarded header (RFC 7239)
720 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
721 req.Host = "internal-service"
722 req.Header.Set("Forwarded", "for=192.0.2.60;proto=https;host=api.example.com")
723
724 fakeAccessToken := "fake-access-token-for-testing"
725 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
726 if err != nil {
727 t.Fatalf("expected DPoP verification to succeed with Forwarded header, got %v", err)
728 }
729
730 if proof == nil {
731 t.Fatal("expected DPoP proof to be returned")
732 }
733}
734
735// TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes tests RFC 7239 edge cases:
736// mixed-case keys (Proto vs proto) and quoted values (host="example.com")
737func TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes(t *testing.T) {
738 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
739 if err != nil {
740 t.Fatalf("failed to generate key: %v", err)
741 }
742
743 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
744 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
745 if err != nil {
746 t.Fatalf("failed to calculate thumbprint: %v", err)
747 }
748
749 claims := &auth.Claims{
750 RegisteredClaims: jwt.RegisteredClaims{
751 Subject: "did:plc:test123",
752 Issuer: "https://test.pds.local",
753 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
754 IssuedAt: jwt.NewNumericDate(time.Now()),
755 },
756 Scope: "atproto",
757 Confirmation: map[string]interface{}{
758 "jkt": thumbprint,
759 },
760 }
761
762 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
763 defer middleware.Stop()
764
765 // External URI that the client uses
766 externalURI := "https://api.example.com/protected/resource"
767 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
768
769 // Request with RFC 7239 Forwarded header using:
770 // - Mixed-case keys: "Proto" instead of "proto", "Host" instead of "host"
771 // - Quoted value: Host="api.example.com" (legal per RFC 7239 section 4)
772 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
773 req.Host = "internal-service"
774 req.Header.Set("Forwarded", `for=192.0.2.60;Proto=https;Host="api.example.com"`)
775
776 fakeAccessToken := "fake-access-token-for-testing"
777 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
778 if err != nil {
779 t.Fatalf("expected DPoP verification to succeed with mixed-case/quoted Forwarded header, got %v", err)
780 }
781
782 if proof == nil {
783 t.Fatal("expected DPoP proof to be returned")
784 }
785}
786
787// TestVerifyDPoPBinding_AthValidation tests access token hash (ath) claim validation
788func TestVerifyDPoPBinding_AthValidation(t *testing.T) {
789 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
790 if err != nil {
791 t.Fatalf("failed to generate key: %v", err)
792 }
793
794 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
795 thumbprint, err := auth.CalculateJWKThumbprint(jwk)
796 if err != nil {
797 t.Fatalf("failed to calculate thumbprint: %v", err)
798 }
799
800 claims := &auth.Claims{
801 RegisteredClaims: jwt.RegisteredClaims{
802 Subject: "did:plc:test123",
803 Issuer: "https://test.pds.local",
804 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
805 IssuedAt: jwt.NewNumericDate(time.Now()),
806 },
807 Scope: "atproto",
808 Confirmation: map[string]interface{}{
809 "jkt": thumbprint,
810 },
811 }
812
813 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
814 defer middleware.Stop()
815
816 accessToken := "real-access-token-12345"
817
818 t.Run("ath_matches_access_token", func(t *testing.T) {
819 // Create DPoP proof with ath claim matching the access token
820 dpopProof := createDPoPProofWithAth(t, privateKey, "GET", "https://api.example.com/resource", accessToken)
821
822 req := httptest.NewRequest("GET", "https://api.example.com/resource", nil)
823 req.Host = "api.example.com"
824
825 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
826 if err != nil {
827 t.Fatalf("expected verification to succeed with matching ath, got %v", err)
828 }
829 if proof == nil {
830 t.Fatal("expected proof to be returned")
831 }
832 })
833
834 t.Run("ath_mismatch_rejected", func(t *testing.T) {
835 // Create DPoP proof with ath for a DIFFERENT token
836 differentToken := "different-token-67890"
837 dpopProof := createDPoPProofWithAth(t, privateKey, "POST", "https://api.example.com/resource", differentToken)
838
839 req := httptest.NewRequest("POST", "https://api.example.com/resource", nil)
840 req.Host = "api.example.com"
841
842 // Try to use with the original access token - should fail
843 _, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
844 if err == nil {
845 t.Fatal("SECURITY: expected verification to fail when ath doesn't match access token")
846 }
847 if !strings.Contains(err.Error(), "ath") {
848 t.Errorf("error should mention ath mismatch, got: %v", err)
849 }
850 })
851}
852
853// TestMiddlewareStop tests that the middleware can be stopped properly
854func TestMiddlewareStop(t *testing.T) {
855 fetcher := &mockJWKSFetcher{}
856 middleware := NewAtProtoAuthMiddleware(fetcher, false)
857
858 // Stop should not panic and should clean up resources
859 middleware.Stop()
860
861 // Calling Stop again should also be safe (idempotent-ish)
862 // Note: The underlying DPoPVerifier.Stop() closes a channel, so this might panic
863 // if not handled properly. We test that at least one Stop works.
864}
865
866// TestOptionalAuth_DPoPBoundToken_NoDPoPHeader tests that OptionalAuth treats
867// tokens with cnf.jkt but no DPoP header as unauthenticated (potential token theft)
868func TestOptionalAuth_DPoPBoundToken_NoDPoPHeader(t *testing.T) {
869 // Generate a key pair for DPoP binding
870 privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
871 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
872 thumbprint, _ := auth.CalculateJWKThumbprint(jwk)
873
874 // Create a DPoP-bound token (has cnf.jkt)
875 claims := auth.Claims{
876 RegisteredClaims: jwt.RegisteredClaims{
877 Subject: "did:plc:user123",
878 Issuer: "https://test.pds.local",
879 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
880 IssuedAt: jwt.NewNumericDate(time.Now()),
881 },
882 Scope: "atproto",
883 Confirmation: map[string]interface{}{
884 "jkt": thumbprint,
885 },
886 }
887
888 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
889 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
890
891 // Use skipVerify=true to simulate a verified token
892 // (In production, skipVerify would be false and VerifyJWT would be called)
893 // However, for this test we need skipVerify=false to trigger DPoP checking
894 // But the fetcher will fail, so let's use skipVerify=true and verify the logic
895 // Actually, the DPoP check only happens when skipVerify=false
896
897 t.Run("with_skipVerify_false", func(t *testing.T) {
898 // This will fail at JWT verification level, but that's expected
899 // The important thing is the code path for DPoP checking
900 fetcher := &mockJWKSFetcher{shouldFail: true}
901 middleware := NewAtProtoAuthMiddleware(fetcher, false)
902 defer middleware.Stop()
903
904 handlerCalled := false
905 var capturedDID string
906 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
907 handlerCalled = true
908 capturedDID = GetUserDID(r)
909 w.WriteHeader(http.StatusOK)
910 }))
911
912 req := httptest.NewRequest("GET", "/test", nil)
913 req.Header.Set("Authorization", "DPoP "+tokenString)
914 // Deliberately NOT setting DPoP header
915 w := httptest.NewRecorder()
916
917 handler.ServeHTTP(w, req)
918
919 // Handler should be called (optional auth doesn't block)
920 if !handlerCalled {
921 t.Error("handler should be called")
922 }
923
924 // But since JWT verification fails, user should not be authenticated
925 if capturedDID != "" {
926 t.Errorf("expected empty DID when verification fails, got %s", capturedDID)
927 }
928 })
929
930 t.Run("with_skipVerify_true_dpop_not_checked", func(t *testing.T) {
931 // When skipVerify=true, DPoP is not checked (Phase 1 mode)
932 fetcher := &mockJWKSFetcher{}
933 middleware := NewAtProtoAuthMiddleware(fetcher, true)
934 defer middleware.Stop()
935
936 handlerCalled := false
937 var capturedDID string
938 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
939 handlerCalled = true
940 capturedDID = GetUserDID(r)
941 w.WriteHeader(http.StatusOK)
942 }))
943
944 req := httptest.NewRequest("GET", "/test", nil)
945 req.Header.Set("Authorization", "DPoP "+tokenString)
946 // No DPoP header
947 w := httptest.NewRecorder()
948
949 handler.ServeHTTP(w, req)
950
951 if !handlerCalled {
952 t.Error("handler should be called")
953 }
954
955 // With skipVerify=true, DPoP check is bypassed - token is trusted
956 if capturedDID != "did:plc:user123" {
957 t.Errorf("expected DID when skipVerify=true, got %s", capturedDID)
958 }
959 })
960}
961
962// TestDPoPReplayProtection tests that the same DPoP proof cannot be used twice
963func TestDPoPReplayProtection(t *testing.T) {
964 // This tests the NonceCache functionality
965 cache := auth.NewNonceCache(5 * time.Minute)
966 defer cache.Stop()
967
968 jti := "unique-proof-id-123"
969
970 // First use should succeed
971 if !cache.CheckAndStore(jti) {
972 t.Error("First use of jti should succeed")
973 }
974
975 // Second use should fail (replay detected)
976 if cache.CheckAndStore(jti) {
977 t.Error("SECURITY: Replay attack not detected - same jti accepted twice")
978 }
979
980 // Different jti should succeed
981 if !cache.CheckAndStore("different-jti-456") {
982 t.Error("Different jti should succeed")
983 }
984}
985
986// Helper: createDPoPProof creates a DPoP proof JWT for testing
987func createDPoPProof(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri string) string {
988 // Create JWK from public key
989 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
990
991 // Create DPoP claims with UUID for jti to ensure uniqueness across tests
992 claims := auth.DPoPClaims{
993 RegisteredClaims: jwt.RegisteredClaims{
994 IssuedAt: jwt.NewNumericDate(time.Now()),
995 ID: uuid.New().String(),
996 },
997 HTTPMethod: method,
998 HTTPURI: uri,
999 }
1000
1001 // Create token with custom header
1002 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
1003 token.Header["typ"] = "dpop+jwt"
1004 token.Header["jwk"] = jwk
1005
1006 // Sign with private key
1007 signedToken, err := token.SignedString(privateKey)
1008 if err != nil {
1009 t.Fatalf("failed to sign DPoP proof: %v", err)
1010 }
1011
1012 return signedToken
1013}
1014
1015// Helper: createDPoPProofWithAth creates a DPoP proof JWT with ath (access token hash) claim
1016func createDPoPProofWithAth(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri, accessToken string) string {
1017 // Create JWK from public key
1018 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
1019
1020 // Calculate ath: base64url(SHA-256(access_token))
1021 hash := sha256.Sum256([]byte(accessToken))
1022 ath := base64.RawURLEncoding.EncodeToString(hash[:])
1023
1024 // Create DPoP claims with ath
1025 claims := auth.DPoPClaims{
1026 RegisteredClaims: jwt.RegisteredClaims{
1027 IssuedAt: jwt.NewNumericDate(time.Now()),
1028 ID: uuid.New().String(),
1029 },
1030 HTTPMethod: method,
1031 HTTPURI: uri,
1032 AccessTokenHash: ath,
1033 }
1034
1035 // Create token with custom header
1036 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
1037 token.Header["typ"] = "dpop+jwt"
1038 token.Header["jwk"] = jwk
1039
1040 // Sign with private key
1041 signedToken, err := token.SignedString(privateKey)
1042 if err != nil {
1043 t.Fatalf("failed to sign DPoP proof: %v", err)
1044 }
1045
1046 return signedToken
1047}
1048
// ecdsaPublicKeyToJWK renders an ECDSA public key as a JWK-style map with the
// members "kty" ("EC"), "crv", "x" and "y". Each coordinate is written
// big-endian, left-padded with zeros to the curve's byte length, and encoded
// as unpadded base64url. It panics for curves other than P-256, P-384 and P-521.
func ecdsaPublicKeyToJWK(pubKey *ecdsa.PublicKey) map[string]interface{} {
	curve := pubKey.Curve

	var curveName string
	switch {
	case curve == elliptic.P256():
		curveName = "P-256"
	case curve == elliptic.P384():
		curveName = "P-384"
	case curve == elliptic.P521():
		curveName = "P-521"
	default:
		panic("unsupported curve")
	}

	// Fixed-width coordinates: e.g. P-521 needs 66 bytes even when the value's
	// minimal big-endian form is shorter.
	width := (curve.Params().BitSize + 7) / 8
	x := pubKey.X.FillBytes(make([]byte, width))
	y := pubKey.Y.FillBytes(make([]byte, width))

	enc := base64.RawURLEncoding
	return map[string]interface{}{
		"kty": "EC",
		"crv": curveName,
		"x":   enc.EncodeToString(x),
		"y":   enc.EncodeToString(y),
	}
}
1082
1083// Helper: createValidJWT creates a valid unsigned JWT token for testing
1084// This is used with skipVerify=true middleware where signature verification is skipped
1085func createValidJWT(t *testing.T, subject string, expiry time.Duration) string {
1086 t.Helper()
1087
1088 claims := auth.Claims{
1089 RegisteredClaims: jwt.RegisteredClaims{
1090 Subject: subject,
1091 Issuer: "https://test.pds.local",
1092 ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
1093 IssuedAt: jwt.NewNumericDate(time.Now()),
1094 },
1095 Scope: "atproto",
1096 }
1097
1098 // Create unsigned token (for skipVerify=true tests)
1099 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
1100 signedToken, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
1101 if err != nil {
1102 t.Fatalf("failed to create test JWT: %v", err)
1103 }
1104
1105 return signedToken
1106}