A community-based topic aggregation platform built on atproto.
1package middleware 2 3import ( 4 "Coves/internal/atproto/auth" 5 "context" 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "crypto/rand" 9 "encoding/base64" 10 "fmt" 11 "net/http" 12 "net/http/httptest" 13 "testing" 14 "time" 15 16 "github.com/golang-jwt/jwt/v5" 17 "github.com/google/uuid" 18) 19 20// mockJWKSFetcher is a test double for JWKSFetcher 21type mockJWKSFetcher struct { 22 shouldFail bool 23} 24 25func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) { 26 if m.shouldFail { 27 return nil, fmt.Errorf("mock fetch failure") 28 } 29 // Return nil - we won't actually verify signatures in Phase 1 tests 30 return nil, nil 31} 32 33// createTestToken creates a test JWT with the given DID 34func createTestToken(did string) string { 35 claims := jwt.MapClaims{ 36 "sub": did, 37 "iss": "https://test.pds.local", 38 "scope": "atproto", 39 "exp": time.Now().Add(1 * time.Hour).Unix(), 40 "iat": time.Now().Unix(), 41 } 42 43 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 44 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 45 return tokenString 46} 47 48// TestRequireAuth_ValidToken tests that valid tokens are accepted (Phase 1) 49func TestRequireAuth_ValidToken(t *testing.T) { 50 fetcher := &mockJWKSFetcher{} 51 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true 52 53 handlerCalled := false 54 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 55 handlerCalled = true 56 57 // Verify DID was extracted and injected into context 58 did := GetUserDID(r) 59 if did != "did:plc:test123" { 60 t.Errorf("expected DID 'did:plc:test123', got %s", did) 61 } 62 63 // Verify claims were injected 64 claims := GetJWTClaims(r) 65 if claims == nil { 66 t.Error("expected claims to be non-nil") 67 return 68 } 69 if claims.Subject != "did:plc:test123" { 70 t.Errorf("expected claims.Subject 'did:plc:test123', got %s", claims.Subject) 71 } 72 73 
w.WriteHeader(http.StatusOK) 74 })) 75 76 token := createTestToken("did:plc:test123") 77 req := httptest.NewRequest("GET", "/test", nil) 78 req.Header.Set("Authorization", "Bearer "+token) 79 w := httptest.NewRecorder() 80 81 handler.ServeHTTP(w, req) 82 83 if !handlerCalled { 84 t.Error("handler was not called") 85 } 86 87 if w.Code != http.StatusOK { 88 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) 89 } 90} 91 92// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected 93func TestRequireAuth_MissingAuthHeader(t *testing.T) { 94 fetcher := &mockJWKSFetcher{} 95 middleware := NewAtProtoAuthMiddleware(fetcher, true) 96 97 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 98 t.Error("handler should not be called") 99 })) 100 101 req := httptest.NewRequest("GET", "/test", nil) 102 // No Authorization header 103 w := httptest.NewRecorder() 104 105 handler.ServeHTTP(w, req) 106 107 if w.Code != http.StatusUnauthorized { 108 t.Errorf("expected status 401, got %d", w.Code) 109 } 110} 111 112// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected 113func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) { 114 fetcher := &mockJWKSFetcher{} 115 middleware := NewAtProtoAuthMiddleware(fetcher, true) 116 117 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 118 t.Error("handler should not be called") 119 })) 120 121 req := httptest.NewRequest("GET", "/test", nil) 122 req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0") // Wrong format 123 w := httptest.NewRecorder() 124 125 handler.ServeHTTP(w, req) 126 127 if w.Code != http.StatusUnauthorized { 128 t.Errorf("expected status 401, got %d", w.Code) 129 } 130} 131 132// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected 133func TestRequireAuth_MalformedToken(t *testing.T) { 134 fetcher := &mockJWKSFetcher{} 135 middleware := 
NewAtProtoAuthMiddleware(fetcher, true) 136 137 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 138 t.Error("handler should not be called") 139 })) 140 141 req := httptest.NewRequest("GET", "/test", nil) 142 req.Header.Set("Authorization", "Bearer not-a-valid-jwt") 143 w := httptest.NewRecorder() 144 145 handler.ServeHTTP(w, req) 146 147 if w.Code != http.StatusUnauthorized { 148 t.Errorf("expected status 401, got %d", w.Code) 149 } 150} 151 152// TestRequireAuth_ExpiredToken tests that expired tokens are rejected 153func TestRequireAuth_ExpiredToken(t *testing.T) { 154 fetcher := &mockJWKSFetcher{} 155 middleware := NewAtProtoAuthMiddleware(fetcher, true) 156 157 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 158 t.Error("handler should not be called for expired token") 159 })) 160 161 // Create expired token 162 claims := jwt.MapClaims{ 163 "sub": "did:plc:test123", 164 "iss": "https://test.pds.local", 165 "scope": "atproto", 166 "exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago 167 "iat": time.Now().Add(-2 * time.Hour).Unix(), 168 } 169 170 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 171 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 172 173 req := httptest.NewRequest("GET", "/test", nil) 174 req.Header.Set("Authorization", "Bearer "+tokenString) 175 w := httptest.NewRecorder() 176 177 handler.ServeHTTP(w, req) 178 179 if w.Code != http.StatusUnauthorized { 180 t.Errorf("expected status 401, got %d", w.Code) 181 } 182} 183 184// TestRequireAuth_MissingDID tests that tokens without DID are rejected 185func TestRequireAuth_MissingDID(t *testing.T) { 186 fetcher := &mockJWKSFetcher{} 187 middleware := NewAtProtoAuthMiddleware(fetcher, true) 188 189 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 190 t.Error("handler should not be called") 191 })) 192 193 // 
Create token without sub claim 194 claims := jwt.MapClaims{ 195 // "sub" missing 196 "iss": "https://test.pds.local", 197 "scope": "atproto", 198 "exp": time.Now().Add(1 * time.Hour).Unix(), 199 "iat": time.Now().Unix(), 200 } 201 202 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 203 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 204 205 req := httptest.NewRequest("GET", "/test", nil) 206 req.Header.Set("Authorization", "Bearer "+tokenString) 207 w := httptest.NewRecorder() 208 209 handler.ServeHTTP(w, req) 210 211 if w.Code != http.StatusUnauthorized { 212 t.Errorf("expected status 401, got %d", w.Code) 213 } 214} 215 216// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid tokens 217func TestOptionalAuth_WithToken(t *testing.T) { 218 fetcher := &mockJWKSFetcher{} 219 middleware := NewAtProtoAuthMiddleware(fetcher, true) 220 221 handlerCalled := false 222 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 223 handlerCalled = true 224 225 // Verify DID was extracted 226 did := GetUserDID(r) 227 if did != "did:plc:test123" { 228 t.Errorf("expected DID 'did:plc:test123', got %s", did) 229 } 230 231 w.WriteHeader(http.StatusOK) 232 })) 233 234 token := createTestToken("did:plc:test123") 235 req := httptest.NewRequest("GET", "/test", nil) 236 req.Header.Set("Authorization", "Bearer "+token) 237 w := httptest.NewRecorder() 238 239 handler.ServeHTTP(w, req) 240 241 if !handlerCalled { 242 t.Error("handler was not called") 243 } 244 245 if w.Code != http.StatusOK { 246 t.Errorf("expected status 200, got %d", w.Code) 247 } 248} 249 250// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens 251func TestOptionalAuth_WithoutToken(t *testing.T) { 252 fetcher := &mockJWKSFetcher{} 253 middleware := NewAtProtoAuthMiddleware(fetcher, true) 254 255 handlerCalled := false 256 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, 
r *http.Request) { 257 handlerCalled = true 258 259 // Verify no DID is set 260 did := GetUserDID(r) 261 if did != "" { 262 t.Errorf("expected empty DID, got %s", did) 263 } 264 265 w.WriteHeader(http.StatusOK) 266 })) 267 268 req := httptest.NewRequest("GET", "/test", nil) 269 // No Authorization header 270 w := httptest.NewRecorder() 271 272 handler.ServeHTTP(w, req) 273 274 if !handlerCalled { 275 t.Error("handler was not called") 276 } 277 278 if w.Code != http.StatusOK { 279 t.Errorf("expected status 200, got %d", w.Code) 280 } 281} 282 283// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token 284func TestOptionalAuth_InvalidToken(t *testing.T) { 285 fetcher := &mockJWKSFetcher{} 286 middleware := NewAtProtoAuthMiddleware(fetcher, true) 287 288 handlerCalled := false 289 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 290 handlerCalled = true 291 292 // Verify no DID is set (invalid token ignored) 293 did := GetUserDID(r) 294 if did != "" { 295 t.Errorf("expected empty DID for invalid token, got %s", did) 296 } 297 298 w.WriteHeader(http.StatusOK) 299 })) 300 301 req := httptest.NewRequest("GET", "/test", nil) 302 req.Header.Set("Authorization", "Bearer not-a-valid-jwt") 303 w := httptest.NewRecorder() 304 305 handler.ServeHTTP(w, req) 306 307 if !handlerCalled { 308 t.Error("handler was not called") 309 } 310 311 if w.Code != http.StatusOK { 312 t.Errorf("expected status 200, got %d", w.Code) 313 } 314} 315 316// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated 317func TestGetUserDID_NotAuthenticated(t *testing.T) { 318 req := httptest.NewRequest("GET", "/test", nil) 319 did := GetUserDID(req) 320 321 if did != "" { 322 t.Errorf("expected empty string, got %s", did) 323 } 324} 325 326// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated 327func 
TestGetJWTClaims_NotAuthenticated(t *testing.T) { 328 req := httptest.NewRequest("GET", "/test", nil) 329 claims := GetJWTClaims(req) 330 331 if claims != nil { 332 t.Errorf("expected nil claims, got %+v", claims) 333 } 334} 335 336// TestGetDPoPProof_NotAuthenticated tests that GetDPoPProof returns nil when no DPoP was verified 337func TestGetDPoPProof_NotAuthenticated(t *testing.T) { 338 req := httptest.NewRequest("GET", "/test", nil) 339 proof := GetDPoPProof(req) 340 341 if proof != nil { 342 t.Errorf("expected nil proof, got %+v", proof) 343 } 344} 345 346// TestRequireAuth_WithDPoP_SecurityModel tests the correct DPoP security model: 347// Token MUST be verified first, then DPoP is checked as an additional layer. 348// DPoP is NOT a fallback for failed token verification. 349func TestRequireAuth_WithDPoP_SecurityModel(t *testing.T) { 350 // Generate an ECDSA key pair for DPoP 351 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 352 if err != nil { 353 t.Fatalf("failed to generate key: %v", err) 354 } 355 356 // Calculate JWK thumbprint for cnf.jkt 357 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 358 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 359 if err != nil { 360 t.Fatalf("failed to calculate thumbprint: %v", err) 361 } 362 363 t.Run("DPoP_is_NOT_fallback_for_failed_verification", func(t *testing.T) { 364 // SECURITY TEST: When token verification fails, DPoP should NOT be used as fallback 365 // This prevents an attacker from forging a token with their own cnf.jkt 366 367 // Create a DPoP-bound access token (unsigned - will fail verification) 368 claims := auth.Claims{ 369 RegisteredClaims: jwt.RegisteredClaims{ 370 Subject: "did:plc:attacker", 371 Issuer: "https://external.pds.local", 372 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 373 IssuedAt: jwt.NewNumericDate(time.Now()), 374 }, 375 Scope: "atproto", 376 Confirmation: map[string]interface{}{ 377 "jkt": thumbprint, 378 }, 379 } 380 381 token := 
jwt.NewWithClaims(jwt.SigningMethodNone, claims) 382 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 383 384 // Create valid DPoP proof (attacker has the private key) 385 dpopProof := createDPoPProof(t, privateKey, "GET", "https://test.local/api/endpoint") 386 387 // Mock fetcher that fails (simulating external PDS without JWKS) 388 fetcher := &mockJWKSFetcher{shouldFail: true} 389 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false 390 391 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 392 t.Error("SECURITY VULNERABILITY: handler was called despite token verification failure") 393 })) 394 395 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil) 396 req.Header.Set("Authorization", "Bearer "+tokenString) 397 req.Header.Set("DPoP", dpopProof) 398 w := httptest.NewRecorder() 399 400 handler.ServeHTTP(w, req) 401 402 // MUST reject - token verification failed, DPoP cannot substitute for signature verification 403 if w.Code != http.StatusUnauthorized { 404 t.Errorf("SECURITY: expected 401 for unverified token, got %d", w.Code) 405 } 406 }) 407 408 t.Run("DPoP_required_when_cnf_jkt_present_in_verified_token", func(t *testing.T) { 409 // When token has cnf.jkt, DPoP header MUST be present 410 // This test uses skipVerify=true to simulate a verified token 411 412 claims := auth.Claims{ 413 RegisteredClaims: jwt.RegisteredClaims{ 414 Subject: "did:plc:test123", 415 Issuer: "https://test.pds.local", 416 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 417 IssuedAt: jwt.NewNumericDate(time.Now()), 418 }, 419 Scope: "atproto", 420 Confirmation: map[string]interface{}{ 421 "jkt": thumbprint, 422 }, 423 } 424 425 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 426 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 427 428 // NO DPoP header - should fail when skipVerify is false 429 // Note: with skipVerify=true, DPoP is 
not checked 430 fetcher := &mockJWKSFetcher{} 431 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true for parsing 432 433 handlerCalled := false 434 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 435 handlerCalled = true 436 w.WriteHeader(http.StatusOK) 437 })) 438 439 req := httptest.NewRequest("GET", "https://test.local/api/endpoint", nil) 440 req.Header.Set("Authorization", "Bearer "+tokenString) 441 // No DPoP header 442 w := httptest.NewRecorder() 443 444 handler.ServeHTTP(w, req) 445 446 // With skipVerify=true, DPoP is not checked, so this should succeed 447 if !handlerCalled { 448 t.Error("handler should be called when skipVerify=true") 449 } 450 }) 451} 452 453// TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback is the key security test. 454// It ensures that DPoP cannot be used as a fallback when token signature verification fails. 455func TestRequireAuth_TokenVerificationFails_DPoPNotUsedAsFallback(t *testing.T) { 456 // Generate a key pair (attacker's key) 457 attackerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 458 jwk := ecdsaPublicKeyToJWK(&attackerKey.PublicKey) 459 thumbprint, _ := auth.CalculateJWKThumbprint(jwk) 460 461 // Create a FORGED token claiming to be the victim 462 claims := auth.Claims{ 463 RegisteredClaims: jwt.RegisteredClaims{ 464 Subject: "did:plc:victim_user", // Attacker claims to be victim 465 Issuer: "https://untrusted.pds", 466 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 467 IssuedAt: jwt.NewNumericDate(time.Now()), 468 }, 469 Scope: "atproto", 470 Confirmation: map[string]interface{}{ 471 "jkt": thumbprint, // Attacker uses their own key 472 }, 473 } 474 475 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 476 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 477 478 // Attacker creates a valid DPoP proof with their key 479 dpopProof := createDPoPProof(t, attackerKey, "POST", 
"https://api.example.com/protected") 480 481 // Fetcher fails (external PDS without JWKS) 482 fetcher := &mockJWKSFetcher{shouldFail: true} 483 middleware := NewAtProtoAuthMiddleware(fetcher, false) // skipVerify=false - REAL verification 484 485 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 486 t.Fatalf("CRITICAL SECURITY FAILURE: Request authenticated as %s despite forged token!", 487 GetUserDID(r)) 488 })) 489 490 req := httptest.NewRequest("POST", "https://api.example.com/protected", nil) 491 req.Header.Set("Authorization", "Bearer "+tokenString) 492 req.Header.Set("DPoP", dpopProof) 493 w := httptest.NewRecorder() 494 495 handler.ServeHTTP(w, req) 496 497 // MUST reject - the token signature was never verified 498 if w.Code != http.StatusUnauthorized { 499 t.Errorf("SECURITY VULNERABILITY: Expected 401, got %d. Token was not properly verified!", w.Code) 500 } 501} 502 503// TestVerifyDPoPBinding_UsesForwardedProto ensures we honor the external HTTPS 504// scheme when TLS is terminated upstream and X-Forwarded-Proto is present. 
505func TestVerifyDPoPBinding_UsesForwardedProto(t *testing.T) { 506 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 507 if err != nil { 508 t.Fatalf("failed to generate key: %v", err) 509 } 510 511 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 512 thumbprint, err := auth.CalculateJWKThumbprint(jwk) 513 if err != nil { 514 t.Fatalf("failed to calculate thumbprint: %v", err) 515 } 516 517 claims := &auth.Claims{ 518 RegisteredClaims: jwt.RegisteredClaims{ 519 Subject: "did:plc:test123", 520 Issuer: "https://test.pds.local", 521 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 522 IssuedAt: jwt.NewNumericDate(time.Now()), 523 }, 524 Scope: "atproto", 525 Confirmation: map[string]interface{}{ 526 "jkt": thumbprint, 527 }, 528 } 529 530 middleware := NewAtProtoAuthMiddleware(&mockJWKSFetcher{}, false) 531 defer middleware.Stop() 532 533 externalURI := "https://api.example.com/protected/resource" 534 dpopProof := createDPoPProof(t, privateKey, "GET", externalURI) 535 536 req := httptest.NewRequest("GET", "http://internal-service/protected/resource", nil) 537 req.Host = "api.example.com" 538 req.Header.Set("X-Forwarded-Proto", "https") 539 540 proof, err := middleware.verifyDPoPBinding(req, claims, dpopProof) 541 if err != nil { 542 t.Fatalf("expected DPoP verification to succeed with forwarded proto, got %v", err) 543 } 544 545 if proof == nil || proof.Claims == nil { 546 t.Fatal("expected DPoP proof to be returned") 547 } 548} 549 550// TestMiddlewareStop tests that the middleware can be stopped properly 551func TestMiddlewareStop(t *testing.T) { 552 fetcher := &mockJWKSFetcher{} 553 middleware := NewAtProtoAuthMiddleware(fetcher, false) 554 555 // Stop should not panic and should clean up resources 556 middleware.Stop() 557 558 // Calling Stop again should also be safe (idempotent-ish) 559 // Note: The underlying DPoPVerifier.Stop() closes a channel, so this might panic 560 // if not handled properly. 
We test that at least one Stop works. 561} 562 563// TestOptionalAuth_DPoPBoundToken_NoDPoPHeader tests that OptionalAuth treats 564// tokens with cnf.jkt but no DPoP header as unauthenticated (potential token theft) 565func TestOptionalAuth_DPoPBoundToken_NoDPoPHeader(t *testing.T) { 566 // Generate a key pair for DPoP binding 567 privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 568 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 569 thumbprint, _ := auth.CalculateJWKThumbprint(jwk) 570 571 // Create a DPoP-bound token (has cnf.jkt) 572 claims := auth.Claims{ 573 RegisteredClaims: jwt.RegisteredClaims{ 574 Subject: "did:plc:user123", 575 Issuer: "https://test.pds.local", 576 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 577 IssuedAt: jwt.NewNumericDate(time.Now()), 578 }, 579 Scope: "atproto", 580 Confirmation: map[string]interface{}{ 581 "jkt": thumbprint, 582 }, 583 } 584 585 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 586 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 587 588 // Use skipVerify=true to simulate a verified token 589 // (In production, skipVerify would be false and VerifyJWT would be called) 590 // However, for this test we need skipVerify=false to trigger DPoP checking 591 // But the fetcher will fail, so let's use skipVerify=true and verify the logic 592 // Actually, the DPoP check only happens when skipVerify=false 593 594 t.Run("with_skipVerify_false", func(t *testing.T) { 595 // This will fail at JWT verification level, but that's expected 596 // The important thing is the code path for DPoP checking 597 fetcher := &mockJWKSFetcher{shouldFail: true} 598 middleware := NewAtProtoAuthMiddleware(fetcher, false) 599 defer middleware.Stop() 600 601 handlerCalled := false 602 var capturedDID string 603 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 604 handlerCalled = true 605 capturedDID = GetUserDID(r) 606 
w.WriteHeader(http.StatusOK) 607 })) 608 609 req := httptest.NewRequest("GET", "/test", nil) 610 req.Header.Set("Authorization", "Bearer "+tokenString) 611 // Deliberately NOT setting DPoP header 612 w := httptest.NewRecorder() 613 614 handler.ServeHTTP(w, req) 615 616 // Handler should be called (optional auth doesn't block) 617 if !handlerCalled { 618 t.Error("handler should be called") 619 } 620 621 // But since JWT verification fails, user should not be authenticated 622 if capturedDID != "" { 623 t.Errorf("expected empty DID when verification fails, got %s", capturedDID) 624 } 625 }) 626 627 t.Run("with_skipVerify_true_dpop_not_checked", func(t *testing.T) { 628 // When skipVerify=true, DPoP is not checked (Phase 1 mode) 629 fetcher := &mockJWKSFetcher{} 630 middleware := NewAtProtoAuthMiddleware(fetcher, true) 631 defer middleware.Stop() 632 633 handlerCalled := false 634 var capturedDID string 635 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 636 handlerCalled = true 637 capturedDID = GetUserDID(r) 638 w.WriteHeader(http.StatusOK) 639 })) 640 641 req := httptest.NewRequest("GET", "/test", nil) 642 req.Header.Set("Authorization", "Bearer "+tokenString) 643 // No DPoP header 644 w := httptest.NewRecorder() 645 646 handler.ServeHTTP(w, req) 647 648 if !handlerCalled { 649 t.Error("handler should be called") 650 } 651 652 // With skipVerify=true, DPoP check is bypassed - token is trusted 653 if capturedDID != "did:plc:user123" { 654 t.Errorf("expected DID when skipVerify=true, got %s", capturedDID) 655 } 656 }) 657} 658 659// TestDPoPReplayProtection tests that the same DPoP proof cannot be used twice 660func TestDPoPReplayProtection(t *testing.T) { 661 // This tests the NonceCache functionality 662 cache := auth.NewNonceCache(5 * time.Minute) 663 defer cache.Stop() 664 665 jti := "unique-proof-id-123" 666 667 // First use should succeed 668 if !cache.CheckAndStore(jti) { 669 t.Error("First use of jti should 
succeed") 670 } 671 672 // Second use should fail (replay detected) 673 if cache.CheckAndStore(jti) { 674 t.Error("SECURITY: Replay attack not detected - same jti accepted twice") 675 } 676 677 // Different jti should succeed 678 if !cache.CheckAndStore("different-jti-456") { 679 t.Error("Different jti should succeed") 680 } 681} 682 683// Helper: createDPoPProof creates a DPoP proof JWT for testing 684func createDPoPProof(t *testing.T, privateKey *ecdsa.PrivateKey, method, uri string) string { 685 // Create JWK from public key 686 jwk := ecdsaPublicKeyToJWK(&privateKey.PublicKey) 687 688 // Create DPoP claims with UUID for jti to ensure uniqueness across tests 689 claims := auth.DPoPClaims{ 690 RegisteredClaims: jwt.RegisteredClaims{ 691 IssuedAt: jwt.NewNumericDate(time.Now()), 692 ID: uuid.New().String(), 693 }, 694 HTTPMethod: method, 695 HTTPURI: uri, 696 } 697 698 // Create token with custom header 699 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 700 token.Header["typ"] = "dpop+jwt" 701 token.Header["jwk"] = jwk 702 703 // Sign with private key 704 signedToken, err := token.SignedString(privateKey) 705 if err != nil { 706 t.Fatalf("failed to sign DPoP proof: %v", err) 707 } 708 709 return signedToken 710} 711 712// Helper: ecdsaPublicKeyToJWK converts an ECDSA public key to JWK map 713func ecdsaPublicKeyToJWK(pubKey *ecdsa.PublicKey) map[string]interface{} { 714 // Get curve name 715 var crv string 716 switch pubKey.Curve { 717 case elliptic.P256(): 718 crv = "P-256" 719 case elliptic.P384(): 720 crv = "P-384" 721 case elliptic.P521(): 722 crv = "P-521" 723 default: 724 panic("unsupported curve") 725 } 726 727 // Encode coordinates 728 xBytes := pubKey.X.Bytes() 729 yBytes := pubKey.Y.Bytes() 730 731 // Ensure proper byte length (pad if needed) 732 keySize := (pubKey.Curve.Params().BitSize + 7) / 8 733 xPadded := make([]byte, keySize) 734 yPadded := make([]byte, keySize) 735 copy(xPadded[keySize-len(xBytes):], xBytes) 736 
copy(yPadded[keySize-len(yBytes):], yBytes) 737 738 return map[string]interface{}{ 739 "kty": "EC", 740 "crv": crv, 741 "x": base64.RawURLEncoding.EncodeToString(xPadded), 742 "y": base64.RawURLEncoding.EncodeToString(yPadded), 743 } 744}