···
···
48
-
// TestRequireAuth_ValidToken tests that valid tokens are accepted (Phase 1)
50
+
// TestRequireAuth_ValidToken tests that valid tokens are accepted with DPoP scheme (Phase 1)
func TestRequireAuth_ValidToken(t *testing.T) {
fetcher := &mockJWKSFetcher{}
middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true
···
token := createTestToken("did:plc:test123")
req := httptest.NewRequest("GET", "/test", nil)
78
-
req.Header.Set("Authorization", "Bearer "+token)
80
+
req.Header.Set("Authorization", "DPoP "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
112
-
// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected
114
+
// TestRequireAuth_InvalidAuthHeaderFormat tests that non-DPoP tokens are rejected (including Bearer)
func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
fetcher := &mockJWKSFetcher{}
middleware := NewAtProtoAuthMiddleware(fetcher, true)
119
+
tests := []struct {
123
+
{"Basic auth", "Basic dGVzdDp0ZXN0"},
124
+
{"Bearer scheme", "Bearer some-token"},
125
+
{"Invalid format", "InvalidFormat"},
128
+
for _, tt := range tests {
129
+
t.Run(tt.name, func(t *testing.T) {
130
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131
+
t.Error("handler should not be called")
134
+
req := httptest.NewRequest("GET", "/test", nil)
135
+
req.Header.Set("Authorization", tt.header)
136
+
w := httptest.NewRecorder()
138
+
handler.ServeHTTP(w, req)
140
+
if w.Code != http.StatusUnauthorized {
141
+
t.Errorf("expected status 401, got %d", w.Code)
147
+
// TestRequireAuth_BearerRejectionErrorMessage verifies that Bearer tokens are rejected
148
+
// with a helpful error message guiding users to use DPoP scheme
149
+
func TestRequireAuth_BearerRejectionErrorMessage(t *testing.T) {
150
+
fetcher := &mockJWKSFetcher{}
151
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called")
req := httptest.NewRequest("GET", "/test", nil)
122
-
req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0") // Wrong format
158
+
req.Header.Set("Authorization", "Bearer some-token")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected status 401, got %d", w.Code)
167
+
// Verify error message guides user to use DPoP
168
+
body := w.Body.String()
169
+
if !strings.Contains(body, "Expected: DPoP") {
170
+
t.Errorf("error message should guide user to use DPoP, got: %s", body)
174
+
// TestRequireAuth_CaseInsensitiveScheme verifies that DPoP scheme matching is case-insensitive
175
+
// per RFC 7235 which states HTTP auth schemes are case-insensitive
176
+
func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) {
177
+
fetcher := &mockJWKSFetcher{}
178
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
180
+
// Create a valid JWT for testing
181
+
validToken := createValidJWT(t, "did:plc:test123", time.Hour)
183
+
testCases := []struct {
187
+
{"lowercase", "dpop"},
188
+
{"uppercase", "DPOP"},
189
+
{"mixed_case", "DpOp"},
190
+
{"standard", "DPoP"},
193
+
for _, tc := range testCases {
194
+
t.Run(tc.name, func(t *testing.T) {
195
+
handlerCalled := false
196
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
197
+
handlerCalled = true
198
+
w.WriteHeader(http.StatusOK)
201
+
req := httptest.NewRequest("GET", "/test", nil)
202
+
req.Header.Set("Authorization", tc.scheme+" "+validToken)
203
+
w := httptest.NewRecorder()
205
+
handler.ServeHTTP(w, req)
207
+
if !handlerCalled {
208
+
t.Errorf("scheme %q should be accepted (case-insensitive per RFC 7235), got status %d: %s",
209
+
tc.scheme, w.Code, w.Body.String())
···
req := httptest.NewRequest("GET", "/test", nil)
142
-
req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
225
+
req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
req := httptest.NewRequest("GET", "/test", nil)
174
-
req.Header.Set("Authorization", "Bearer "+tokenString)
257
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
req := httptest.NewRequest("GET", "/test", nil)
206
-
req.Header.Set("Authorization", "Bearer "+tokenString)
289
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
216
-
// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid tokens
299
+
// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid DPoP tokens
func TestOptionalAuth_WithToken(t *testing.T) {
fetcher := &mockJWKSFetcher{}
middleware := NewAtProtoAuthMiddleware(fetcher, true)
···
token := createTestToken("did:plc:test123")
req := httptest.NewRequest("GET", "/test", nil)
236
-
req.Header.Set("Authorization", "Bearer "+token)
319
+
req.Header.Set("Authorization", "DPoP "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
req := httptest.NewRequest("GET", "/test", nil)
302
-
req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
385
+
req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
396
-
req.Header.Set("Authorization", "Bearer "+tokenString)
479
+
req.Header.Set("Authorization", "DPoP "+tokenString)
req.Header.Set("DPoP", dpopProof)
w := httptest.NewRecorder()
···
req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
440
-
req.Header.Set("Authorization", "Bearer "+tokenString)
523
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
···
req := httptest.NewRequest("POST", "https://api.example.com/protected", nil)
491
-
req.Header.Set("Authorization", "Bearer "+tokenString)
574
+
req.Header.Set("Authorization", "DPoP "+tokenString)
req.Header.Set("DPoP", dpopProof)
w := httptest.NewRecorder()
···
req.Host = "api.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
540
-
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof)
623
+
// Pass a fake access token - ath verification will pass since we don't include ath in the DPoP proof
624
+
fakeAccessToken := "fake-access-token-for-testing"
625
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
t.Fatalf("expected DPoP verification to succeed with forwarded proto, got %v", err)
···
635
+
// TestVerifyDPoPBinding_UsesForwardedHost ensures we honor X-Forwarded-Host header
636
+
// when behind a TLS-terminating proxy that rewrites the Host header.
637
+
func TestVerifyDPoPBinding_UsesForwardedHost(t *testing.T) {
638
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
640
+
t.Fatalf("failed to generate key: %v", err)
643
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
644
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
646
+
t.Fatalf("failed to calculate thumbprint: %v", err)
649
+
claims := &auth.Claims{
650
+
RegisteredClaims: jwt.RegisteredClaims{
651
+
Subject: "did:plc:test123",
652
+
Issuer: "https://test.pds.local",
653
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
654
+
IssuedAt: jwt.NewNumericDate(time.Now()),
657
+
Confirmation: map[string]interface{}{
662
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
663
+
defer middleware.Stop()
665
+
// External URI that the client uses
666
+
externalURI := "https://api.example.com/protected/resource"
667
+
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
669
+
// Request hits internal service with internal hostname, but X-Forwarded-Host has public hostname
670
+
req := httptest.NewRequest("GET", "http://internal-service:8080/protected/resource", nil)
671
+
req.Host = "internal-service:8080" // Internal host after proxy
672
+
req.Header.Set("X-Forwarded-Proto", "https")
673
+
req.Header.Set("X-Forwarded-Host", "api.example.com") // Original public host
675
+
fakeAccessToken := "fake-access-token-for-testing"
676
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
678
+
t.Fatalf("expected DPoP verification to succeed with X-Forwarded-Host, got %v", err)
681
+
if proof == nil || proof.Claims == nil {
682
+
t.Fatal("expected DPoP proof to be returned")
686
+
// TestVerifyDPoPBinding_UsesStandardForwardedHeader tests RFC 7239 Forwarded header parsing
687
+
func TestVerifyDPoPBinding_UsesStandardForwardedHeader(t *testing.T) {
688
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
690
+
t.Fatalf("failed to generate key: %v", err)
693
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
694
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
696
+
t.Fatalf("failed to calculate thumbprint: %v", err)
699
+
claims := &auth.Claims{
700
+
RegisteredClaims: jwt.RegisteredClaims{
701
+
Subject: "did:plc:test123",
702
+
Issuer: "https://test.pds.local",
703
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
704
+
IssuedAt: jwt.NewNumericDate(time.Now()),
707
+
Confirmation: map[string]interface{}{
712
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
713
+
defer middleware.Stop()
716
+
externalURI := "https://api.example.com/protected/resource"
717
+
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
719
+
// Request with standard Forwarded header (RFC 7239)
720
+
req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
721
+
req.Host = "internal-service"
722
+
req.Header.Set("Forwarded", "for=192.0.2.60;proto=https;host=api.example.com")
724
+
fakeAccessToken := "fake-access-token-for-testing"
725
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
727
+
t.Fatalf("expected DPoP verification to succeed with Forwarded header, got %v", err)
731
+
t.Fatal("expected DPoP proof to be returned")
735
+
// TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes tests RFC 7239 edge cases:
736
+
// mixed-case keys (Proto vs proto) and quoted values (host="example.com")
737
+
func TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes(t *testing.T) {
738
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
740
+
t.Fatalf("failed to generate key: %v", err)
743
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
744
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
746
+
t.Fatalf("failed to calculate thumbprint: %v", err)
749
+
claims := &auth.Claims{
750
+
RegisteredClaims: jwt.RegisteredClaims{
751
+
Subject: "did:plc:test123",
752
+
Issuer: "https://test.pds.local",
753
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
754
+
IssuedAt: jwt.NewNumericDate(time.Now()),
757
+
Confirmation: map[string]interface{}{
762
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
763
+
defer middleware.Stop()
765
+
// External URI that the client uses
766
+
externalURI := "https://api.example.com/protected/resource"
767
+
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
769
+
// Request with RFC 7239 Forwarded header using:
770
+
// - Mixed-case keys: "Proto" instead of "proto", "Host" instead of "host"
771
+
// - Quoted value: Host="api.example.com" (legal per RFC 7239 section 4)
772
+
req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
773
+
req.Host = "internal-service"
774
+
req.Header.Set("Forwarded", `for=192.0.2.60;Proto=https;Host="api.example.com"`)
776
+
fakeAccessToken := "fake-access-token-for-testing"
777
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
779
+
t.Fatalf("expected DPoP verification to succeed with mixed-case/quoted Forwarded header, got %v", err)
783
+
t.Fatal("expected DPoP proof to be returned")
787
+
// TestVerifyDPoPBinding_AthValidation tests access token hash (ath) claim validation
788
+
func TestVerifyDPoPBinding_AthValidation(t *testing.T) {
789
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
791
+
t.Fatalf("failed to generate key: %v", err)
794
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
795
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
797
+
t.Fatalf("failed to calculate thumbprint: %v", err)
800
+
claims := &auth.Claims{
801
+
RegisteredClaims: jwt.RegisteredClaims{
802
+
Subject: "did:plc:test123",
803
+
Issuer: "https://test.pds.local",
804
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
805
+
IssuedAt: jwt.NewNumericDate(time.Now()),
808
+
Confirmation: map[string]interface{}{
813
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
814
+
defer middleware.Stop()
816
+
accessToken := "real-access-token-12345"
818
+
t.Run("ath_matches_access_token", func(t *testing.T) {
819
+
// Create DPoP proof with ath claim matching the access token
820
+
dpopProof := createDPoPProofWithAth(t, privateKey, "GET", "https://api.example.com/resource", accessToken)
822
+
req := httptest.NewRequest("GET", "https://api.example.com/resource", nil)
823
+
req.Host = "api.example.com"
825
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
827
+
t.Fatalf("expected verification to succeed with matching ath, got %v", err)
830
+
t.Fatal("expected proof to be returned")
834
+
t.Run("ath_mismatch_rejected", func(t *testing.T) {
835
+
// Create DPoP proof with ath for a DIFFERENT token
836
+
differentToken := "different-token-67890"
837
+
dpopProof := createDPoPProofWithAth(t, privateKey, "POST", "https://api.example.com/resource", differentToken)
839
+
req := httptest.NewRequest("POST", "https://api.example.com/resource", nil)
840
+
req.Host = "api.example.com"
842
+
// Try to use with the original access token - should fail
843
+
_, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
845
+
t.Fatal("SECURITY: expected verification to fail when ath doesn't match access token")
847
+
if !strings.Contains(err.Error(), "ath") {
848
+
t.Errorf("error should mention ath mismatch, got: %v", err)
// TestMiddlewareStop tests that the middleware can be stopped properly
func TestMiddlewareStop(t *testing.T) {
fetcher := &mockJWKSFetcher{}
···
req := httptest.NewRequest("GET", "/test", nil)
610
-
req.Header.Set("Authorization", "Bearer "+tokenString)
913
+
req.Header.Set("Authorization", "DPoP "+tokenString)
// Deliberately NOT setting DPoP header
w := httptest.NewRecorder()
···
req := httptest.NewRequest("GET", "/test", nil)
642
-
req.Header.Set("Authorization", "Bearer "+tokenString)
945
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
···
1015
+
// Helper: createDPoPProofWithAth creates a DPoP proof JWT with ath (access token hash) claim
1016
+
func createDPoPProofWithAth(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri, accessToken string) string {
1017
+
// Create JWK from public key
1018
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
1020
+
// Calculate ath: base64url(SHA-256(access_token))
1021
+
hash := sha256.Sum256([]byte(accessToken))
1022
+
ath := base64.RawURLEncoding.EncodeToString(hash[:])
1024
+
// Create DPoP claims with ath
1025
+
claims := auth.DPoPClaims{
1026
+
RegisteredClaims: jwt.RegisteredClaims{
1027
+
IssuedAt: jwt.NewNumericDate(time.Now()),
1028
+
ID: uuid.New().String(),
1030
+
HTTPMethod: method,
1032
+
AccessTokenHash: ath,
1035
+
// Create token with custom header
1036
+
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
1037
+
token.Header["typ"] = "dpop+jwt"
1038
+
token.Header["jwk"] = jwk
1040
+
// Sign with private key
1041
+
signedToken, err := token.SignedString(privateKey)
1043
+
t.Fatalf("failed to sign DPoP proof: %v", err)
1046
+
return signedToken
// Helper: ecdsaPublicKeyToJWK converts an ECDSA public key to JWK map
func ecdsaPublicKeyToJWK(pubKey *ecdsa.PublicKey) map[string]interface{} {
···
"y": base64.RawURLEncoding.EncodeToString(yPadded),
1083
+
// Helper: createValidJWT creates a valid unsigned JWT token for testing
1084
+
// This is used with skipVerify=true middleware where signature verification is skipped
1085
+
func createValidJWT(t *testing.T, subject string, expiry time.Duration) string {
1088
+
claims := auth.Claims{
1089
+
RegisteredClaims: jwt.RegisteredClaims{
1091
+
Issuer: "https://test.pds.local",
1092
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
1093
+
IssuedAt: jwt.NewNumericDate(time.Now()),
1098
+
// Create unsigned token (for skipVerify=true tests)
1099
+
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
1100
+
signedToken, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
1102
+
t.Fatalf("failed to create test JWT: %v", err)
1105
+
return signedToken