···
···
+
// TestRequireAuth_ValidToken tests that valid tokens are accepted with DPoP scheme (Phase 1)
func TestRequireAuth_ValidToken(t *testing.T) {
fetcher := &mockJWKSFetcher{}
middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true
···
token := createTestToken("did:plc:test123")
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
+
// TestRequireAuth_InvalidAuthHeaderFormat tests that non-DPoP tokens are rejected (including Bearer)
func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
fetcher := &mockJWKSFetcher{}
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
{"Basic auth", "Basic dGVzdDp0ZXN0"},
+
{"Bearer scheme", "Bearer some-token"},
+
{"Invalid format", "InvalidFormat"},
+
for _, tt := range tests {
+
t.Run(tt.name, func(t *testing.T) {
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
t.Error("handler should not be called")
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", tt.header)
+
w := httptest.NewRecorder()
+
handler.ServeHTTP(w, req)
+
if w.Code != http.StatusUnauthorized {
+
t.Errorf("expected status 401, got %d", w.Code)
+
// TestRequireAuth_BearerRejectionErrorMessage verifies that Bearer tokens are rejected
+
// with a helpful error message guiding users to use DPoP scheme
+
func TestRequireAuth_BearerRejectionErrorMessage(t *testing.T) {
+
fetcher := &mockJWKSFetcher{}
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called")
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "Bearer some-token")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected status 401, got %d", w.Code)
+
// Verify error message guides user to use DPoP
+
body := w.Body.String()
+
if !strings.Contains(body, "Expected: DPoP") {
+
t.Errorf("error message should guide user to use DPoP, got: %s", body)
+
// TestRequireAuth_CaseInsensitiveScheme verifies that DPoP scheme matching is case-insensitive
+
// per RFC 7235 which states HTTP auth schemes are case-insensitive
+
func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) {
+
fetcher := &mockJWKSFetcher{}
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
+
// Create a valid JWT for testing
+
validToken := createValidJWT(t, "did:plc:test123", time.Hour)
+
testCases := []struct {
+
{"mixed_case", "DpOp"},
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
w.WriteHeader(http.StatusOK)
+
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", tc.scheme+" "+validToken)
+
w := httptest.NewRecorder()
+
handler.ServeHTTP(w, req)
+
t.Errorf("scheme %q should be accepted (case-insensitive per RFC 7235), got status %d: %s",
+
tc.scheme, w.Code, w.Body.String())
···
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
+
// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid DPoP tokens
func TestOptionalAuth_WithToken(t *testing.T) {
fetcher := &mockJWKSFetcher{}
middleware := NewAtProtoAuthMiddleware(fetcher, true)
···
token := createTestToken("did:plc:test123")
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP "+token)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP not-a-valid-jwt")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
···
req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
+
req.Header.Set("Authorization", "DPoP "+tokenString)
req.Header.Set("DPoP", dpopProof)
w := httptest.NewRecorder()
···
req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil)
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
···
req := httptest.NewRequest("POST", "https://api.example.com/protected", nil)
+
req.Header.Set("Authorization", "DPoP "+tokenString)
req.Header.Set("DPoP", dpopProof)
w := httptest.NewRecorder()
···
req.Host = "api.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
+
// Pass a fake access token - ath verification will pass since we don't include ath in the DPoP proof
+
fakeAccessToken := "fake-access-token-for-testing"
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
t.Fatalf("expected DPoP verification to succeed with forwarded proto, got %v", err)
···
+
// TestVerifyDPoPBinding_UsesForwardedHost ensures we honor X-Forwarded-Host header
+
// when behind a TLS-terminating proxy that rewrites the Host header.
+
func TestVerifyDPoPBinding_UsesForwardedHost(t *testing.T) {
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+
t.Fatalf("failed to generate key: %v", err)
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
+
t.Fatalf("failed to calculate thumbprint: %v", err)
+
claims := &auth.Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test.pds.local",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
Confirmation: map[string]interface{}{
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
+
defer middleware.Stop()
+
// External URI that the client uses
+
externalURI := "https://api.example.com/protected/resource"
+
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
+
// Request hits internal service with internal hostname, but X-Forwarded-Host has public hostname
+
req := httptest.NewRequest("GET", "http://internal-service:8080/protected/resource", nil)
+
req.Host = "internal-service:8080" // Internal host after proxy
+
req.Header.Set("X-Forwarded-Proto", "https")
+
req.Header.Set("X-Forwarded-Host", "api.example.com") // Original public host
+
fakeAccessToken := "fake-access-token-for-testing"
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
+
t.Fatalf("expected DPoP verification to succeed with X-Forwarded-Host, got %v", err)
+
if proof == nil || proof.Claims == nil {
+
t.Fatal("expected DPoP proof to be returned")
+
// TestVerifyDPoPBinding_UsesStandardForwardedHeader tests RFC 7239 Forwarded header parsing
+
func TestVerifyDPoPBinding_UsesStandardForwardedHeader(t *testing.T) {
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+
t.Fatalf("failed to generate key: %v", err)
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
+
t.Fatalf("failed to calculate thumbprint: %v", err)
+
claims := &auth.Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test.pds.local",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
Confirmation: map[string]interface{}{
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
+
defer middleware.Stop()
+
externalURI := "https://api.example.com/protected/resource"
+
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
+
// Request with standard Forwarded header (RFC 7239)
+
req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
+
req.Host = "internal-service"
+
req.Header.Set("Forwarded", "for=192.0.2.60;proto=https;host=api.example.com")
+
fakeAccessToken := "fake-access-token-for-testing"
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
+
t.Fatalf("expected DPoP verification to succeed with Forwarded header, got %v", err)
+
t.Fatal("expected DPoP proof to be returned")
+
// TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes tests RFC 7239 edge cases:
+
// mixed-case keys (Proto vs proto) and quoted values (host="example.com")
+
func TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes(t *testing.T) {
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+
t.Fatalf("failed to generate key: %v", err)
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
+
t.Fatalf("failed to calculate thumbprint: %v", err)
+
claims := &auth.Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test.pds.local",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
Confirmation: map[string]interface{}{
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
+
defer middleware.Stop()
+
// External URI that the client uses
+
externalURI := "https://api.example.com/protected/resource"
+
dpopProof := createDPoPProof(t, privateKey, "GET", externalURI)
+
// Request with RFC 7239 Forwarded header using:
+
// - Mixed-case keys: "Proto" instead of "proto", "Host" instead of "host"
+
// - Quoted value: Host="api.example.com" (legal per RFC 7239 section 4)
+
req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil)
+
req.Host = "internal-service"
+
req.Header.Set("Forwarded", `for=192.0.2.60;Proto=https;Host="api.example.com"`)
+
fakeAccessToken := "fake-access-token-for-testing"
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken)
+
t.Fatalf("expected DPoP verification to succeed with mixed-case/quoted Forwarded header, got %v", err)
+
t.Fatal("expected DPoP proof to be returned")
+
// TestVerifyDPoPBinding_AthValidation tests access token hash (ath) claim validation
+
func TestVerifyDPoPBinding_AthValidation(t *testing.T) {
+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+
t.Fatalf("failed to generate key: %v", err)
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
+
thumbprint, err := auth.CalculateJWKThumbprint(jwk)
+
t.Fatalf("failed to calculate thumbprint: %v", err)
+
claims := &auth.Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: "https://test.pds.local",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
Confirmation: map[string]interface{}{
+
middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false)
+
defer middleware.Stop()
+
accessToken := "real-access-token-12345"
+
t.Run("ath_matches_access_token", func(t *testing.T) {
+
// Create DPoP proof with ath claim matching the access token
+
dpopProof := createDPoPProofWithAth(t, privateKey, "GET", "https://api.example.com/resource", accessToken)
+
req := httptest.NewRequest("GET", "https://api.example.com/resource", nil)
+
req.Host = "api.example.com"
+
proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
+
t.Fatalf("expected verification to succeed with matching ath, got %v", err)
+
t.Fatal("expected proof to be returned")
+
t.Run("ath_mismatch_rejected", func(t *testing.T) {
+
// Create DPoP proof with ath for a DIFFERENT token
+
differentToken := "different-token-67890"
+
dpopProof := createDPoPProofWithAth(t, privateKey, "POST", "https://api.example.com/resource", differentToken)
+
req := httptest.NewRequest("POST", "https://api.example.com/resource", nil)
+
req.Host = "api.example.com"
+
// Try to use with the original access token - should fail
+
_, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken)
+
t.Fatal("SECURITY: expected verification to fail when ath doesn't match access token")
+
if !strings.Contains(err.Error(), "ath") {
+
t.Errorf("error should mention ath mismatch, got: %v", err)
// TestMiddlewareStop tests that the middleware can be stopped properly
func TestMiddlewareStop(t *testing.T) {
fetcher := &mockJWKSFetcher{}
···
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP "+tokenString)
// Deliberately NOT setting DPoP header
w := httptest.NewRecorder()
···
req := httptest.NewRequest("GET", "/test", nil)
+
req.Header.Set("Authorization", "DPoP "+tokenString)
w := httptest.NewRecorder()
···
+
// Helper: createDPoPProofWithAth creates a DPoP proof JWT with ath (access token hash) claim
+
func createDPoPProofWithAth(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri, accessToken string) string {
+
// Create JWK from public key
+
jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
+
// Calculate ath: base64url(SHA-256(access_token))
+
hash := sha256.Sum256([]byte(accessToken))
+
ath := base64.RawURLEncoding.EncodeToString(hash[:])
+
// Create DPoP claims with ath
+
claims := auth.DPoPClaims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
ID: uuid.New().String(),
+
// Create token with custom header
+
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
+
token.Header["typ"] = "dpop+jwt"
+
token.Header["jwk"] = jwk
+
// Sign with private key
+
signedToken, err := token.SignedString(privateKey)
+
t.Fatalf("failed to sign DPoP proof: %v", err)
// Helper: ecdsaPublicKeyToJWK converts an ECDSA public key to JWK map
func ecdsaPublicKeyToJWK(pubKey *ecdsa.PublicKey) map[string]interface{} {
···
"y": base64.RawURLEncoding.EncodeToString(yPadded),
+
// Helper: createValidJWT creates a valid unsigned JWT token for testing
+
// This is used with skipVerify=true middleware where signature verification is skipped
+
func createValidJWT(t *testing.T, subject string, expiry time.Duration) string {
+
RegisteredClaims: jwt.RegisteredClaims{
+
Issuer: "https://test.pds.local",
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
// Create unsigned token (for skipVerify=true tests)
+
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
+
signedToken, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
+
t.Fatalf("failed to create test JWT: %v", err)