A community-based topic aggregation platform built on atproto
1package middleware 2 3import ( 4 "Coves/internal/atproto/auth" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "crypto/sha256" 10 "encoding/base64" 11 "fmt" 12 "net/http" 13 "net/http/httptest" 14 "strings" 15 "testing" 16 "time" 17 18 "github.com/golang-jwt/jwt/v5" 19 "github.com/google/uuid" 20) 21 22// mockJWKSFetcher is a test double for JWKSFetcher 23type mockJWKSFetcher struct { 24 shouldFail bool 25} 26 27func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) { 28 if m.shouldFail { 29 return nil, fmt.Errorf("mock fetch failure") 30 } 31 // Return nil - we won't actually verify signatures in Phase 1 tests 32 return nil, nil 33} 34 35// createTestToken creates a test JWT with the given DID 36func createTestToken(did string) string { 37 claims := jwt.MapClaims{ 38 "sub": did, 39 "iss": "https://test.pds.local", 40 "scope": "atproto", 41 "exp": time.Now().Add(1 * time.Hour).Unix(), 42 "iat": time.Now().Unix(), 43 } 44 45 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 46 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 47 return tokenString 48} 49 50// TestRequireAuth_ValidToken tests that valid tokens are accepted with DPoP scheme (Phase 1) 51func TestRequireAuth_ValidToken(t *testing.T) { 52 fetcher := &mockJWKSFetcher{} 53 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true 54 55 handlerCalled := false 56 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 57 handlerCalled = true 58 59 // Verify DID was extracted and injected into context 60 did := GetUserDID(r) 61 if did != "did:plc:test123" { 62 t.Errorf("expected DID 'did:plc:test123', got %s", did) 63 } 64 65 // Verify claims were injected 66 claims := GetJWTClaims(r) 67 if claims == nil { 68 t.Error("expected claims to be non-nil") 69 return 70 } 71 if claims.Subject != "did:plc:test123" { 72 t.Errorf("expected claims.Subject 
'did:plc:test123', got %s", claims.Subject) 73 } 74 75 w.WriteHeader(http.StatusOK) 76 })) 77 78 token := createTestToken("did:plc:test123") 79 req := httptest.NewRequest("GET", "/test", nil) 80 req.Header.Set("Authorization", "DPoP "+token) 81 w := httptest.NewRecorder() 82 83 handler.ServeHTTP(w, req) 84 85 if !handlerCalled { 86 t.Error("handler was not called") 87 } 88 89 if w.Code != http.StatusOK { 90 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) 91 } 92} 93 94// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected 95func TestRequireAuth_MissingAuthHeader(t *testing.T) { 96 fetcher := &mockJWKSFetcher{} 97 middleware := NewAtProtoAuthMiddleware(fetcher, true) 98 99 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 100 t.Error("handler should not be called") 101 })) 102 103 req := httptest.NewRequest("GET", "/test", nil) 104 // No Authorization header 105 w := httptest.NewRecorder() 106 107 handler.ServeHTTP(w, req) 108 109 if w.Code != http.StatusUnauthorized { 110 t.Errorf("expected status 401, got %d", w.Code) 111 } 112} 113 114// TestRequireAuth_InvalidAuthHeaderFormat tests that non-DPoP tokens are rejected (including Bearer) 115func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) { 116 fetcher := &mockJWKSFetcher{} 117 middleware := NewAtProtoAuthMiddleware(fetcher, true) 118 119 tests := []struct { 120 name string 121 header string 122 }{ 123 {"Basic auth", "Basic dGVzdDp0ZXN0"}, 124 {"Bearer scheme", "Bearer some-token"}, 125 {"Invalid format", "InvalidFormat"}, 126 } 127 128 for _, tt := range tests { 129 t.Run(tt.name, func(t *testing.T) { 130 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 131 t.Error("handler should not be called") 132 })) 133 134 req := httptest.NewRequest("GET", "/test", nil) 135 req.Header.Set("Authorization", tt.header) 136 w := httptest.NewRecorder() 137 138 
handler.ServeHTTP(w, req) 139 140 if w.Code != http.StatusUnauthorized { 141 t.Errorf("expected status 401, got %d", w.Code) 142 } 143 }) 144 } 145} 146 147// TestRequireAuth_BearerRejectionErrorMessage verifies that Bearer tokens are rejected 148// with a helpful error message guiding users to use DPoP scheme 149func TestRequireAuth_BearerRejectionErrorMessage(t *testing.T) { 150 fetcher := &mockJWKSFetcher{} 151 middleware := NewAtProtoAuthMiddleware(fetcher, true) 152 153 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 154 t.Error("handler should not be called") 155 })) 156 157 req := httptest.NewRequest("GET", "/test", nil) 158 req.Header.Set("Authorization", "Bearer some-token") 159 w := httptest.NewRecorder() 160 161 handler.ServeHTTP(w, req) 162 163 if w.Code != http.StatusUnauthorized { 164 t.Errorf("expected status 401, got %d", w.Code) 165 } 166 167 // Verify error message guides user to use DPoP 168 body := w.Body.String() 169 if !strings.Contains(body, "Expected: DPoP") { 170 t.Errorf("error message should guide user to use DPoP, got: %s", body) 171 } 172} 173 174// TestRequireAuth_CaseInsensitiveScheme verifies that DPoP scheme matching is case-insensitive 175// per RFC 7235 which states HTTP auth schemes are case-insensitive 176func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) { 177 fetcher := &mockJWKSFetcher{} 178 middleware := NewAtProtoAuthMiddleware(fetcher, true) 179 180 // Create a valid JWT for testing 181 validToken := createValidJWT(t, "did:plc:test123", time.Hour) 182 183 testCases := []struct { 184 name string 185 scheme string 186 }{ 187 {"lowercase", "dpop"}, 188 {"uppercase", "DPOP"}, 189 {"mixed_case", "DpOp"}, 190 {"standard", "DPoP"}, 191 } 192 193 for _, tc := range testCases { 194 t.Run(tc.name, func(t *testing.T) { 195 handlerCalled := false 196 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 197 handlerCalled = true 198 
w.WriteHeader(http.StatusOK) 199 })) 200 201 req := httptest.NewRequest("GET", "/test", nil) 202 req.Header.Set("Authorization", tc.scheme+" "+validToken) 203 w := httptest.NewRecorder() 204 205 handler.ServeHTTP(w, req) 206 207 if !handlerCalled { 208 t.Errorf("scheme %q should be accepted (case-insensitive per RFC 7235), got status %d: %s", 209 tc.scheme, w.Code, w.Body.String()) 210 } 211 }) 212 } 213} 214 215// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected 216func TestRequireAuth_MalformedToken(t *testing.T) { 217 fetcher := &mockJWKSFetcher{} 218 middleware := NewAtProtoAuthMiddleware(fetcher, true) 219 220 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 221 t.Error("handler should not be called") 222 })) 223 224 req := httptest.NewRequest("GET", "/test", nil) 225 req.Header.Set("Authorization", "DPoP not-a-valid-jwt") 226 w := httptest.NewRecorder() 227 228 handler.ServeHTTP(w, req) 229 230 if w.Code != http.StatusUnauthorized { 231 t.Errorf("expected status 401, got %d", w.Code) 232 } 233} 234 235// TestRequireAuth_ExpiredToken tests that expired tokens are rejected 236func TestRequireAuth_ExpiredToken(t *testing.T) { 237 fetcher := &mockJWKSFetcher{} 238 middleware := NewAtProtoAuthMiddleware(fetcher, true) 239 240 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 241 t.Error("handler should not be called for expired token") 242 })) 243 244 // Create expired token 245 claims := jwt.MapClaims{ 246 "sub": "did:plc:test123", 247 "iss": "https://test.pds.local", 248 "scope": "atproto", 249 "exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago 250 "iat": time.Now().Add(-2 * time.Hour).Unix(), 251 } 252 253 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 254 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 255 256 req := httptest.NewRequest("GET", "/test", nil) 257 req.Header.Set("Authorization", 
"DPoP "+tokenString) 258 w := httptest.NewRecorder() 259 260 handler.ServeHTTP(w, req) 261 262 if w.Code != http.StatusUnauthorized { 263 t.Errorf("expected status 401, got %d", w.Code) 264 } 265} 266 267// TestRequireAuth_MissingDID tests that tokens without DID are rejected 268func TestRequireAuth_MissingDID(t *testing.T) { 269 fetcher := &mockJWKSFetcher{} 270 middleware := NewAtProtoAuthMiddleware(fetcher, true) 271 272 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 273 t.Error("handler should not be called") 274 })) 275 276 // Create token without sub claim 277 claims := jwt.MapClaims{ 278 // "sub" missing 279 "iss": "https://test.pds.local", 280 "scope": "atproto", 281 "exp": time.Now().Add(1 * time.Hour).Unix(), 282 "iat": time.Now().Unix(), 283 } 284 285 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 286 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 287 288 req := httptest.NewRequest("GET", "/test", nil) 289 req.Header.Set("Authorization", "DPoP "+tokenString) 290 w := httptest.NewRecorder() 291 292 handler.ServeHTTP(w, req) 293 294 if w.Code != http.StatusUnauthorized { 295 t.Errorf("expected status 401, got %d", w.Code) 296 } 297} 298 299// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid DPoP tokens 300func TestOptionalAuth_WithToken(t *testing.T) { 301 fetcher := &mockJWKSFetcher{} 302 middleware := NewAtProtoAuthMiddleware(fetcher, true) 303 304 handlerCalled := false 305 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 306 handlerCalled = true 307 308 // Verify DID was extracted 309 did := GetUserDID(r) 310 if did != "did:plc:test123" { 311 t.Errorf("expected DID 'did:plc:test123', got %s", did) 312 } 313 314 w.WriteHeader(http.StatusOK) 315 })) 316 317 token := createTestToken("did:plc:test123") 318 req := httptest.NewRequest("GET", "/test", nil) 319 req.Header.Set("Authorization", "DPoP "+token) 320 
w := httptest.NewRecorder() 321 322 handler.ServeHTTP(w, req) 323 324 if !handlerCalled { 325 t.Error("handler was not called") 326 } 327 328 if w.Code != http.StatusOK { 329 t.Errorf("expected status 200, got %d", w.Code) 330 } 331} 332 333// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens 334func TestOptionalAuth_WithoutToken(t *testing.T) { 335 fetcher := &mockJWKSFetcher{} 336 middleware := NewAtProtoAuthMiddleware(fetcher, true) 337 338 handlerCalled := false 339 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 340 handlerCalled = true 341 342 // Verify no DID is set 343 did := GetUserDID(r) 344 if did != "" { 345 t.Errorf("expected empty DID, got %s", did) 346 } 347 348 w.WriteHeader(http.StatusOK) 349 })) 350 351 req := httptest.NewRequest("GET", "/test", nil) 352 // No Authorization header 353 w := httptest.NewRecorder() 354 355 handler.ServeHTTP(w, req) 356 357 if !handlerCalled { 358 t.Error("handler was not called") 359 } 360 361 if w.Code != http.StatusOK { 362 t.Errorf("expected status 200, got %d", w.Code) 363 } 364} 365 366// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token 367func TestOptionalAuth_InvalidToken(t *testing.T) { 368 fetcher := &mockJWKSFetcher{} 369 middleware := NewAtProtoAuthMiddleware(fetcher, true) 370 371 handlerCalled := false 372 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 373 handlerCalled = true 374 375 // Verify no DID is set (invalid token ignored) 376 did := GetUserDID(r) 377 if did != "" { 378 t.Errorf("expected empty DID for invalid token, got %s", did) 379 } 380 381 w.WriteHeader(http.StatusOK) 382 })) 383 384 req := httptest.NewRequest("GET", "/test", nil) 385 req.Header.Set("Authorization", "DPoP not-a-valid-jwt") 386 w := httptest.NewRecorder() 387 388 handler.ServeHTTP(w, req) 389 390 if !handlerCalled { 391 t.Error("handler was not 
called") 392 } 393 394 if w.Code != http.StatusOK { 395 t.Errorf("expected status 200, got %d", w.Code) 396 } 397} 398 399// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated 400func TestGetUserDID_NotAuthenticated(t *testing.T) { 401 req := httptest.NewRequest("GET", "/test", nil) 402 did := GetUserDID(req) 403 404 if did != "" { 405 t.Errorf("expected empty string, got %s", did) 406 } 407} 408 409// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated 410func TestGetJWTClaims_NotAuthenticated(t *testing.T) { 411 req := httptest.NewRequest("GET", "/test", nil) 412 claims := GetJWTClaims(req) 413 414 if claims != nil { 415 t.Errorf("expected nil claims, got %+v", claims) 416 } 417} 418 419// TestGetDPoPProof_NotAuthenticated tests that GetDPoPProof returns nil when no DPoP was verified 420func TestGetDPoPProof_NotAuthenticated(t *testing.T) { 421 req := httptest.NewRequest("GET", "/test", nil) 422 proof := GetDPoPProof(req) 423 424 if proof != nil { 425 t.Errorf("expected nil proof, got %+v", proof) 426 } 427} 428 429// TestRequireAuth_WithDPoP_SecurityModel tests the correct DPoP security model: 430// Token MUST be verified first, then DPoP is checked as an additional layer. 431// DPoP is NOT a fallback for failed token verification. 
432func TestRequireAuth_WithDPoP_SecurityModel(t *testing.T) { 433 // Generate an ECDSA key pair for DPoP 434 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 435 if err != nil { 436 t.Fatalf("failed to generate key: %v", err) 437 } 438 439 // Calculate JWK thumbprint for cnf.jkt 440 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 441 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 442 if err != nil { 443 t.Fatalf("failed to calculate thumbprint: %v", err) 444 } 445 446 t.Run("DPoP_is_NOT_fallback_for_failed_verification", func(t *testing.T) { 447 // SECURITY TEST: When token verification fails, DPoP should NOT be used as fallback 448 // This prevents an attacker from forging a token with their own cnf.jkt 449 450 // Create a DPoP-bound access token (unsigned - will fail verification) 451 claims := auth.Claims{ 452 RegisteredClaims: jwt.RegisteredClaims{ 453 Subject: "did:plc:attacker", 454 Issuer: "https://external.pds.local", 455 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 456 IssuedAt: jwt.NewNumericDate(time.Now()), 457 }, 458 Scope: "atproto", 459 Confirmation: map[string]interface{}{ 460 "jkt": thumbprint, 461 }, 462 } 463 464 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 465 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 466 467 // Create valid DPoP proof (attacker has the private key) 468 dpopProof := createDPoPProof(t, privateKey, "GET", "https://test.local/api/endpoint") 469 470 // Mock fetcher that fails (simulating external PDS without JWKS) 471 fetcher := &mockJWKSFetcher{shouldFail: true} 472 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false 473 474 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 475 t.Error("SECURITY VULNERABILITY: handler was called despite token verification failure") 476 })) 477 478 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil) 479 
req.Header.Set("Authorization", "DPoP "+tokenString) 480 req.Header.Set("DPoP", dpopProof) 481 w := httptest.NewRecorder() 482 483 handler.ServeHTTP(w, req) 484 485 // MUST reject - token verification failed, DPoP cannot substitute for signature verification 486 if w.Code != http.StatusUnauthorized { 487 t.Errorf("SECURITY: expected 401 for unverified token, got %d", w.Code) 488 } 489 }) 490 491 t.Run("DPoP_required_when_cnf_jkt_present_in_verified_token", func(t *testing.T) { 492 // When token has cnf.jkt, DPoP header MUST be present 493 // This test uses skipVerify=true to simulate a verified token 494 495 claims := auth.Claims{ 496 RegisteredClaims: jwt.RegisteredClaims{ 497 Subject: "did:plc:test123", 498 Issuer: "https://test.pds.local", 499 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 500 IssuedAt: jwt.NewNumericDate(time.Now()), 501 }, 502 Scope: "atproto", 503 Confirmation: map[string]interface{}{ 504 "jkt": thumbprint, 505 }, 506 } 507 508 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 509 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 510 511 // NO DPoP header - should fail when skipVerify is false 512 // Note: with skipVerify=true, DPoP is not checked 513 fetcher := &mockJWKSFetcher{} 514 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true for parsing 515 516 handlerCalled := false 517 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 518 handlerCalled = true 519 w.WriteHeader(http.StatusOK) 520 })) 521 522 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil) 523 req.Header.Set("Authorization", "DPoP "+tokenString) 524 // No DPoP header 525 w := httptest.NewRecorder() 526 527 handler.ServeHTTP(w, req) 528 529 // With skipVerify=true, DPoP is not checked, so this should succeed 530 if !handlerCalled { 531 t.Error("handler should be called when skipVerify=true") 532 } 533 }) 534} 535 536// 
TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback is the key security test. 537// It ensures that DPoP cannot be used as a fallback when token signature verification fails. 538func TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback(t *testing.T) { 539 // Generate a key pair (attacker's key) 540 attackerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 541 jwk := ecdsaPublicKeyToJWK(&attackerKey.PublicKey) 542 thumbprint, _ := auth.CalculateJWKThumbprint(jwk) 543 544 // Create a FORGED token claiming to be the victim 545 claims := auth.Claims{ 546 RegisteredClaims: jwt.RegisteredClaims{ 547 Subject: "did:plc:victim_user", // Attacker claims to be victim 548 Issuer: "https://untrusted.pds", 549 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 550 IssuedAt: jwt.NewNumericDate(time.Now()), 551 }, 552 Scope: "atproto", 553 Confirmation: map[string]interface{}{ 554 "jkt": thumbprint, // Attacker uses their own key 555 }, 556 } 557 558 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 559 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 560 561 // Attacker creates a valid DPoP proof with their key 562 dpopProof := createDPoPProof(t, attackerKey, "POST", "https://api.example.com/protected") 563 564 // Fetcher fails (external PDS without JWKS) 565 fetcher := &mockJWKSFetcher{shouldFail: true} 566 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false - REAL verification 567 568 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 569 t.Fatalf("CRITICAL SECURITY FAILURE: Request authenticated as %s despite forged token!", 570 GetUserDID(r)) 571 })) 572 573 req := httptest.NewRequest("POST", "https://api.example.com/protected", nil) 574 req.Header.Set("Authorization", "DPoP "+tokenString) 575 req.Header.Set("DPoP", dpopProof) 576 w := httptest.NewRecorder() 577 578 handler.ServeHTTP(w, req) 579 580 // MUST reject - the token signature was 
never verified 581 if w.Code != http.StatusUnauthorized { 582 t.Errorf("SECURITY VULNERABILITY: Expected 401, got %d. Token was not properly verified!", w.Code) 583 } 584} 585 586// TestVerifyDPoPBinding_UsesForwardedProto ensures we honor the external HTTPS 587// scheme when TLS is terminated upstream and X-Forwarded-Proto is present. 588func TestVerifyDPoPBinding_UsesForwardedProto(t *testing.T) { 589 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 590 if err != nil { 591 t.Fatalf("failed to generate key: %v", err) 592 } 593 594 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 595 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 596 if err != nil { 597 t.Fatalf("failed to calculate thumbprint: %v", err) 598 } 599 600 claims := &auth.Claims{ 601 RegisteredClaims: jwt.RegisteredClaims{ 602 Subject: "did:plc:test123", 603 Issuer: "https://test.pds.local", 604 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 605 IssuedAt: jwt.NewNumericDate(time.Now()), 606 }, 607 Scope: "atproto", 608 Confirmation: map[string]interface{}{ 609 "jkt": thumbprint, 610 }, 611 } 612 613 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false) 614 defer middleware.Stop() 615 616 externalURI := "https://api.example.com/protected/resource" 617 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI) 618 619 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil) 620 req.Host = "api.example.com" 621 req.Header.Set("X-Forwarded-Proto", "https") 622 623 // Pass a fake access token - ath verification will pass since we don't include ath in the DPoP proof 624 fakeAccessToken := "fake-access-token-for-testing" 625 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken) 626 if err != nil { 627 t.Fatalf("expected DPoP verification to succeed with forwarded proto, got %v", err) 628 } 629 630 if proof == nil || proof.Claims == nil { 631 t.Fatal("expected DPoP proof to be returned") 632 } 633} 
634 635// TestVerifyDPoPBinding_UsesForwardedHost ensures we honor X-Forwarded-Host header 636// when behind a TLS-terminating proxy that rewrites the Host header. 637func TestVerifyDPoPBinding_UsesForwardedHost(t *testing.T) { 638 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 639 if err != nil { 640 t.Fatalf("failed to generate key: %v", err) 641 } 642 643 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 644 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 645 if err != nil { 646 t.Fatalf("failed to calculate thumbprint: %v", err) 647 } 648 649 claims := &auth.Claims{ 650 RegisteredClaims: jwt.RegisteredClaims{ 651 Subject: "did:plc:test123", 652 Issuer: "https://test.pds.local", 653 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 654 IssuedAt: jwt.NewNumericDate(time.Now()), 655 }, 656 Scope: "atproto", 657 Confirmation: map[string]interface{}{ 658 "jkt": thumbprint, 659 }, 660 } 661 662 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false) 663 defer middleware.Stop() 664 665 // External URI that the client uses 666 externalURI := "https://api.example.com/protected/resource" 667 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI) 668 669 // Request hits internal service with internal hostname, but X-Forwarded-Host has public hostname 670 req := httptest.NewRequest("GET", "http://internal-service:8080/protected/resource", nil) 671 req.Host = "internal-service:8080" // Internal host after proxy 672 req.Header.Set("X-Forwarded-Proto", "https") 673 req.Header.Set("X-Forwarded-Host", "api.example.com") // Original public host 674 675 fakeAccessToken := "fake-access-token-for-testing" 676 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken) 677 if err != nil { 678 t.Fatalf("expected DPoP verification to succeed with X-Forwarded-Host, got %v", err) 679 } 680 681 if proof == nil || proof.Claims == nil { 682 t.Fatal("expected DPoP proof to be returned") 683 } 684} 685 686// 
TestVerifyDPoPBinding_UsesStandardForwardedHeader tests RFC 7239 Forwarded header parsing 687func TestVerifyDPoPBinding_UsesStandardForwardedHeader(t *testing.T) { 688 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 689 if err != nil { 690 t.Fatalf("failed to generate key: %v", err) 691 } 692 693 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 694 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 695 if err != nil { 696 t.Fatalf("failed to calculate thumbprint: %v", err) 697 } 698 699 claims := &auth.Claims{ 700 RegisteredClaims: jwt.RegisteredClaims{ 701 Subject: "did:plc:test123", 702 Issuer: "https://test.pds.local", 703 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 704 IssuedAt: jwt.NewNumericDate(time.Now()), 705 }, 706 Scope: "atproto", 707 Confirmation: map[string]interface{}{ 708 "jkt": thumbprint, 709 }, 710 } 711 712 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false) 713 defer middleware.Stop() 714 715 // External URI 716 externalURI := "https://api.example.com/protected/resource" 717 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI) 718 719 // Request with standard Forwarded header (RFC 7239) 720 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil) 721 req.Host = "internal-service" 722 req.Header.Set("Forwarded", "for=192.0.2.60;proto=https;host=api.example.com") 723 724 fakeAccessToken := "fake-access-token-for-testing" 725 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken) 726 if err != nil { 727 t.Fatalf("expected DPoP verification to succeed with Forwarded header, got %v", err) 728 } 729 730 if proof == nil { 731 t.Fatal("expected DPoP proof to be returned") 732 } 733} 734 735// TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes tests RFC 7239 edge cases: 736// mixed-case keys (Proto vs proto) and quoted values (host="example.com") 737func TestVerifyDPoPBinding_ForwardedMixedCaseAndQuotes(t *testing.T) { 738 privateKey, 
err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 739 if err != nil { 740 t.Fatalf("failed to generate key: %v", err) 741 } 742 743 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 744 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 745 if err != nil { 746 t.Fatalf("failed to calculate thumbprint: %v", err) 747 } 748 749 claims := &auth.Claims{ 750 RegisteredClaims: jwt.RegisteredClaims{ 751 Subject: "did:plc:test123", 752 Issuer: "https://test.pds.local", 753 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 754 IssuedAt: jwt.NewNumericDate(time.Now()), 755 }, 756 Scope: "atproto", 757 Confirmation: map[string]interface{}{ 758 "jkt": thumbprint, 759 }, 760 } 761 762 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false) 763 defer middleware.Stop() 764 765 // External URI that the client uses 766 externalURI := "https://api.example.com/protected/resource" 767 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI) 768 769 // Request with RFC 7239 Forwarded header using: 770 // - Mixed-case keys: "Proto" instead of "proto", "Host" instead of "host" 771 // - Quoted value: Host="api.example.com" (legal per RFC 7239 section 4) 772 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil) 773 req.Host = "internal-service" 774 req.Header.Set("Forwarded", `for=192.0.2.60;Proto=https;Host="api.example.com"`) 775 776 fakeAccessToken := "fake-access-token-for-testing" 777 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, fakeAccessToken) 778 if err != nil { 779 t.Fatalf("expected DPoP verification to succeed with mixed-case/quoted Forwarded header, got %v", err) 780 } 781 782 if proof == nil { 783 t.Fatal("expected DPoP proof to be returned") 784 } 785} 786 787// TestVerifyDPoPBinding_AthValidation tests access token hash (ath) claim validation 788func TestVerifyDPoPBinding_AthValidation(t *testing.T) { 789 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 790 if err != nil 
{ 791 t.Fatalf("failed to generate key: %v", err) 792 } 793 794 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 795 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 796 if err != nil { 797 t.Fatalf("failed to calculate thumbprint: %v", err) 798 } 799 800 claims := &auth.Claims{ 801 RegisteredClaims: jwt.RegisteredClaims{ 802 Subject: "did:plc:test123", 803 Issuer: "https://test.pds.local", 804 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 805 IssuedAt: jwt.NewNumericDate(time.Now()), 806 }, 807 Scope: "atproto", 808 Confirmation: map[string]interface{}{ 809 "jkt": thumbprint, 810 }, 811 } 812 813 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false) 814 defer middleware.Stop() 815 816 accessToken := "real-access-token-12345" 817 818 t.Run("ath_matches_access_token", func(t *testing.T) { 819 // Create DPoP proof with ath claim matching the access token 820 dpopProof := createDPoPProofWithAth(t, privateKey, "GET", "https://api.example.com/resource", accessToken) 821 822 req := httptest.NewRequest("GET", "https://api.example.com/resource", nil) 823 req.Host = "api.example.com" 824 825 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken) 826 if err != nil { 827 t.Fatalf("expected verification to succeed with matching ath, got %v", err) 828 } 829 if proof == nil { 830 t.Fatal("expected proof to be returned") 831 } 832 }) 833 834 t.Run("ath_mismatch_rejected", func(t *testing.T) { 835 // Create DPoP proof with ath for a DIFFERENT token 836 differentToken := "different-token-67890" 837 dpopProof := createDPoPProofWithAth(t, privateKey, "POST", "https://api.example.com/resource", differentToken) 838 839 req := httptest.NewRequest("POST", "https://api.example.com/resource", nil) 840 req.Host = "api.example.com" 841 842 // Try to use with the original access token - should fail 843 _, err := middleware.verifyDPoPBinding(req, claims, dpopProof, accessToken) 844 if err == nil { 845 t.Fatal("SECURITY: expected 
verification to fail when ath doesn't match access token")
		}
		if !strings.Contains(err.Error(), "ath") {
			t.Errorf("error should mention ath mismatch, got: %v", err)
		}
	})
}

// TestMiddlewareStop tests that the middleware can be stopped properly
func TestMiddlewareStop(t *testing.T) {
	fetcher := &mockJWKSFetcher{}
	middleware := NewAtProtoAuthMiddleware(fetcher, false)

	// Stop should not panic and should clean up resources
	middleware.Stop()

	// Only a single Stop call is exercised. A second call is deliberately not
	// made: the underlying DPoPVerifier.Stop() closes a channel, so calling it
	// twice might panic if that is not guarded.
}

// TestOptionalAuth_DPoPBoundToken_NoDPoPHeader tests that OptionalAuth treats
// tokens with cnf.jkt but no DPoP header as unauthenticated (potential token theft)
func TestOptionalAuth_DPoPBoundToken_NoDPoPHeader(t *testing.T) {
	// Generate a key pair for DPoP binding. Setup errors abort the test
	// immediately so they are not mistaken for assertion failures below.
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("failed to generate ECDSA key: %v", err)
	}
	jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)
	thumbprint, err := auth.CalculateJWKThumbprint(jwk)
	if err != nil {
		t.Fatalf("failed to calculate JWK thumbprint: %v", err)
	}

	// Create a DPoP-bound token (has cnf.jkt)
	claims := auth.Claims{
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   "did:plc:user123",
			Issuer:    "https://test.pds.local",
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(time.Now()),
		},
		Scope: "atproto",
		Confirmation: map[string]interface{}{
			"jkt": thumbprint,
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
	tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
	if err != nil {
		t.Fatalf("failed to create test token: %v", err)
	}

	// The DPoP check only happens when skipVerify=false, so both modes are
	// exercised below with the same DPoP-bound token and no DPoP header.

	t.Run("with_skipVerify_false", func(t *testing.T) {
		// This will fail at JWT verification level, but that's expected
		// The important thing is the code path for DPoP checking
		fetcher := &mockJWKSFetcher{shouldFail: true}
		middleware := NewAtProtoAuthMiddleware(fetcher, false)
		defer middleware.Stop()

		handlerCalled := false
		var capturedDID string
		handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			handlerCalled = true
			capturedDID = GetUserDID(r)
			w.WriteHeader(http.StatusOK)
		}))

		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("Authorization", "DPoP "+tokenString)
		// Deliberately NOT setting DPoP header
		w := httptest.NewRecorder()

		handler.ServeHTTP(w, req)

		// Handler should be called (optional auth doesn't block)
		if !handlerCalled {
			t.Error("handler should be called")
		}

		// But since JWT verification fails, user should not be authenticated
		if capturedDID != "" {
			t.Errorf("expected empty DID when verification fails, got %s", capturedDID)
		}
	})

	t.Run("with_skipVerify_true_dpop_not_checked", func(t *testing.T) {
		// When skipVerify=true, DPoP is not checked (Phase 1 mode)
		fetcher := &mockJWKSFetcher{}
		middleware := NewAtProtoAuthMiddleware(fetcher, true)
		defer middleware.Stop()

		handlerCalled := false
		var capturedDID string
		handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			handlerCalled = true
			capturedDID = GetUserDID(r)
			w.WriteHeader(http.StatusOK)
		}))

		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("Authorization", "DPoP "+tokenString)
		// No DPoP header
		w := httptest.NewRecorder()

		handler.ServeHTTP(w, req)

		if !handlerCalled {
			t.Error("handler should be called")
		}

		// With skipVerify=true, DPoP check is bypassed - token is trusted
		if capturedDID != "did:plc:user123" {
			t.Errorf("expected DID when skipVerify=true, got %s", capturedDID)
		}
	})
}

// TestDPoPReplayProtection tests that the same DPoP proof cannot be used twice
func TestDPoPReplayProtection(t *testing.T) {
	// This tests the NonceCache functionality
	cache := auth.NewNonceCache(5 * time.Minute)
	defer cache.Stop()

	jti := "unique-proof-id-123"

	// First use should succeed
	if !cache.CheckAndStore(jti) {
		t.Error("First use of jti should succeed")
	}

	// Second use should fail (replay detected)
	if cache.CheckAndStore(jti) {
		t.Error("SECURITY: Replay attack not detected - same jti accepted twice")
	}

	// Different jti should succeed
	if !cache.CheckAndStore("different-jti-456") {
		t.Error("Different jti should succeed")
	}
}

// Helper: createDPoPProof creates a DPoP proof JWT for testing
//
// The proof carries the htm/htu values given by method and uri, a fresh UUID
// as jti, and the public JWK of privateKey in its header. It is signed with
// jwt.SigningMethodES256; a signing failure aborts the calling test.
func createDPoPProof(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri string) string {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	// Create JWK from public key
	jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey)

	// Create DPoP claims with UUID for jti to ensure uniqueness across tests
	claims := auth.DPoPClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			IssuedAt: jwt.NewNumericDate(time.Now()),
			ID:       uuid.New().String(),
		},
		HTTPMethod: method,
		HTTPURI:    uri,
	}

	// Create token with custom header
	token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
	token.Header["typ"] = "dpop+jwt"
	token.Header["jwk"] = jwk

	// Sign with private key
	signedToken, err := token.SignedString(privateKey)
	if err != nil {
		t.Fatalf("failed to sign DPoP proof: %v", err)
	}

	return signedToken
}

// Helper: createDPoPProofWithAth creates a DPoP proof JWT
with ath (access token hash) claim 1016func createDPoPProofWithAth(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri, accessToken string) string { 1017 // Create JWK from public key 1018 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 1019 1020 // Calculate ath: base64url(SHA-256(access_token)) 1021 hash := sha256.Sum256([]byte(accessToken)) 1022 ath := base64.RawURLEncoding.EncodeToString(hash[:]) 1023 1024 // Create DPoP claims with ath 1025 claims := auth.DPoPClaims{ 1026 RegisteredClaims: jwt.RegisteredClaims{ 1027 IssuedAt: jwt.NewNumericDate(time.Now()), 1028 ID: uuid.New().String(), 1029 }, 1030 HTTPMethod: method, 1031 HTTPURI: uri, 1032 AccessTokenHash: ath, 1033 } 1034 1035 // Create token with custom header 1036 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 1037 token.Header["typ"] = "dpop+jwt" 1038 token.Header["jwk"] = jwk 1039 1040 // Sign with private key 1041 signedToken, err := token.SignedString(privateKey) 1042 if err != nil { 1043 t.Fatalf("failed to sign DPoP proof: %v", err) 1044 } 1045 1046 return signedToken 1047} 1048 1049// Helper: ecdsaPublicKeyToJWK converts an ECDSA public key to JWK map 1050func ecdsaPublicKeyToJWK(pubKey *ecdsa.PublicKey) map[string]interface{} { 1051 // Get curve name 1052 var crv string 1053 switch pubKey.Curve { 1054 case elliptic.P256(): 1055 crv = "P-256" 1056 case elliptic.P384(): 1057 crv = "P-384" 1058 case elliptic.P521(): 1059 crv = "P-521" 1060 default: 1061 panic("unsupported curve") 1062 } 1063 1064 // Encode coordinates 1065 xBytes := pubKey.X.Bytes() 1066 yBytes := pubKey.Y.Bytes() 1067 1068 // Ensure proper byte length (pad if needed) 1069 keySize := (pubKey.Curve.Params().BitSize + 7) / 8 1070 xPadded := make([]byte, keySize) 1071 yPadded := make([]byte, keySize) 1072 copy(xPadded[keySize-len(xBytes):], xBytes) 1073 copy(yPadded[keySize-len(yBytes):], yBytes) 1074 1075 return map[string]interface{}{ 1076 "kty": "EC", 1077 "crv": crv, 1078 "x": 
base64.RawURLEncoding.EncodeToString(xPadded), 1079 "y": base64.RawURLEncoding.EncodeToString(yPadded), 1080 } 1081} 1082 1083// Helper: createValidJWT creates a valid unsigned JWT token for testing 1084// This is used with skipVerify=true middleware where signature verification is skipped 1085func createValidJWT(t *testing.T, subject string, expiry time.Duration) string { 1086 t.Helper() 1087 1088 claims := auth.Claims{ 1089 RegisteredClaims: jwt.RegisteredClaims{ 1090 Subject: subject, 1091 Issuer: "https://test.pds.local", 1092 ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)), 1093 IssuedAt: jwt.NewNumericDate(time.Now()), 1094 }, 1095 Scope: "atproto", 1096 } 1097 1098 // Create unsigned token (for skipVerify=true tests) 1099 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 1100 signedToken, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 1101 if err != nil { 1102 t.Fatalf("failed to create test JWT: %v", err) 1103 } 1104 1105 return signedToken 1106}