A community-based topic aggregation platform built on atproto.
1package middleware 2 3import ( 4 "Coves/internal/atproto/oauth" 5 "context" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "net/http" 10 "net/http/httptest" 11 "strings" 12 "testing" 13 "time" 14 15 oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth" 16 "github.com/bluesky-social/indigo/atproto/syntax" 17) 18 19// mockOAuthClient is a test double for OAuthClient 20type mockOAuthClient struct { 21 sealSecret []byte 22 shouldFailSeal bool 23} 24 25func newMockOAuthClient() *mockOAuthClient { 26 // Create a 32-byte seal secret for testing 27 secret := []byte("test-secret-key-32-bytes-long!!") 28 return &mockOAuthClient{ 29 sealSecret: secret, 30 } 31} 32 33func (m *mockOAuthClient) UnsealSession(token string) (*oauth.SealedSession, error) { 34 if m.shouldFailSeal { 35 return nil, fmt.Errorf("mock unseal failure") 36 } 37 38 // For testing, we'll decode a simple format: base64(did|sessionID|expiresAt) 39 // In production this would be AES-GCM encrypted 40 // Using pipe separator to avoid conflicts with colon in DIDs 41 decoded, err := base64.RawURLEncoding.DecodeString(token) 42 if err != nil { 43 return nil, fmt.Errorf("invalid token encoding: %w", err) 44 } 45 46 parts := strings.Split(string(decoded), "|") 47 if len(parts) != 3 { 48 return nil, fmt.Errorf("invalid token format") 49 } 50 51 var expiresAt int64 52 _, _ = fmt.Sscanf(parts[2], "%d", &expiresAt) 53 54 // Check expiration 55 if expiresAt <= time.Now().Unix() { 56 return nil, fmt.Errorf("token expired") 57 } 58 59 return &oauth.SealedSession{ 60 DID: parts[0], 61 SessionID: parts[1], 62 ExpiresAt: expiresAt, 63 }, nil 64} 65 66// Helper to create a test sealed token 67func (m *mockOAuthClient) createTestToken(did, sessionID string, ttl time.Duration) string { 68 expiresAt := time.Now().Add(ttl).Unix() 69 payload := fmt.Sprintf("%s|%s|%d", did, sessionID, expiresAt) 70 return base64.RawURLEncoding.EncodeToString([]byte(payload)) 71} 72 73// mockOAuthStore is a test double for ClientAuthStore 
74type mockOAuthStore struct { 75 sessions map[string]*oauthlib.ClientSessionData 76} 77 78func newMockOAuthStore() *mockOAuthStore { 79 return &mockOAuthStore{ 80 sessions: make(map[string]*oauthlib.ClientSessionData), 81 } 82} 83 84func (m *mockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauthlib.ClientSessionData, error) { 85 key := did.String() + ":" + sessionID 86 session, ok := m.sessions[key] 87 if !ok { 88 return nil, fmt.Errorf("session not found") 89 } 90 return session, nil 91} 92 93func (m *mockOAuthStore) SaveSession(ctx context.Context, session oauthlib.ClientSessionData) error { 94 key := session.AccountDID.String() + ":" + session.SessionID 95 m.sessions[key] = &session 96 return nil 97} 98 99func (m *mockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 100 key := did.String() + ":" + sessionID 101 delete(m.sessions, key) 102 return nil 103} 104 105func (m *mockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauthlib.AuthRequestData, error) { 106 return nil, fmt.Errorf("not implemented") 107} 108 109func (m *mockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauthlib.AuthRequestData) error { 110 return fmt.Errorf("not implemented") 111} 112 113func (m *mockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 114 return fmt.Errorf("not implemented") 115} 116 117// TestRequireAuth_ValidToken tests that valid sealed tokens are accepted 118func TestRequireAuth_ValidToken(t *testing.T) { 119 client := newMockOAuthClient() 120 store := newMockOAuthStore() 121 122 // Create a test session 123 did := syntax.DID("did:plc:test123") 124 sessionID := "session123" 125 session := &oauthlib.ClientSessionData{ 126 AccountDID: did, 127 SessionID: sessionID, 128 AccessToken: "test_access_token", 129 HostURL: "https://pds.example.com", 130 } 131 _ = store.SaveSession(context.Background(), *session) 132 133 middleware := 
NewOAuthAuthMiddleware(client, store) 134 135 handlerCalled := false 136 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 137 handlerCalled = true 138 139 // Verify DID was extracted and injected into context 140 extractedDID := GetUserDID(r) 141 if extractedDID != "did:plc:test123" { 142 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID) 143 } 144 145 // Verify OAuth session was injected 146 oauthSession := GetOAuthSession(r) 147 if oauthSession == nil { 148 t.Error("expected OAuth session to be non-nil") 149 return 150 } 151 if oauthSession.SessionID != sessionID { 152 t.Errorf("expected session ID '%s', got %s", sessionID, oauthSession.SessionID) 153 } 154 155 // Verify access token is available 156 accessToken := GetUserAccessToken(r) 157 if accessToken != "test_access_token" { 158 t.Errorf("expected access token 'test_access_token', got %s", accessToken) 159 } 160 161 w.WriteHeader(http.StatusOK) 162 })) 163 164 token := client.createTestToken("did:plc:test123", sessionID, time.Hour) 165 req := httptest.NewRequest("GET", "/test", nil) 166 req.Header.Set("Authorization", "Bearer "+token) 167 w := httptest.NewRecorder() 168 169 handler.ServeHTTP(w, req) 170 171 if !handlerCalled { 172 t.Error("handler was not called") 173 } 174 175 if w.Code != http.StatusOK { 176 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) 177 } 178} 179 180// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected 181func TestRequireAuth_MissingAuthHeader(t *testing.T) { 182 client := newMockOAuthClient() 183 store := newMockOAuthStore() 184 middleware := NewOAuthAuthMiddleware(client, store) 185 186 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 187 t.Error("handler should not be called") 188 })) 189 190 req := httptest.NewRequest("GET", "/test", nil) 191 // No Authorization header 192 w := httptest.NewRecorder() 193 194 
handler.ServeHTTP(w, req) 195 196 if w.Code != http.StatusUnauthorized { 197 t.Errorf("expected status 401, got %d", w.Code) 198 } 199} 200 201// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected 202func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) { 203 client := newMockOAuthClient() 204 store := newMockOAuthStore() 205 middleware := NewOAuthAuthMiddleware(client, store) 206 207 tests := []struct { 208 name string 209 header string 210 }{ 211 {"Basic auth", "Basic dGVzdDp0ZXN0"}, 212 {"DPoP scheme", "DPoP some-token"}, 213 {"Invalid format", "InvalidFormat"}, 214 } 215 216 for _, tt := range tests { 217 t.Run(tt.name, func(t *testing.T) { 218 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 219 t.Error("handler should not be called") 220 })) 221 222 req := httptest.NewRequest("GET", "/test", nil) 223 req.Header.Set("Authorization", tt.header) 224 w := httptest.NewRecorder() 225 226 handler.ServeHTTP(w, req) 227 228 if w.Code != http.StatusUnauthorized { 229 t.Errorf("expected status 401, got %d", w.Code) 230 } 231 }) 232 } 233} 234 235// TestRequireAuth_CaseInsensitiveScheme verifies that Bearer scheme matching is case-insensitive 236func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) { 237 client := newMockOAuthClient() 238 store := newMockOAuthStore() 239 240 // Create a test session 241 did := syntax.DID("did:plc:test123") 242 sessionID := "session123" 243 session := &oauthlib.ClientSessionData{ 244 AccountDID: did, 245 SessionID: sessionID, 246 AccessToken: "test_access_token", 247 } 248 _ = store.SaveSession(context.Background(), *session) 249 250 middleware := NewOAuthAuthMiddleware(client, store) 251 token := client.createTestToken("did:plc:test123", sessionID, time.Hour) 252 253 testCases := []struct { 254 name string 255 scheme string 256 }{ 257 {"lowercase", "bearer"}, 258 {"uppercase", "BEARER"}, 259 {"mixed_case", "BeArEr"}, 260 {"standard", "Bearer"}, 261 } 
262 263 for _, tc := range testCases { 264 t.Run(tc.name, func(t *testing.T) { 265 handlerCalled := false 266 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 267 handlerCalled = true 268 w.WriteHeader(http.StatusOK) 269 })) 270 271 req := httptest.NewRequest("GET", "/test", nil) 272 req.Header.Set("Authorization", tc.scheme+" "+token) 273 w := httptest.NewRecorder() 274 275 handler.ServeHTTP(w, req) 276 277 if !handlerCalled { 278 t.Errorf("scheme %q should be accepted (case-insensitive per RFC 7235), got status %d: %s", 279 tc.scheme, w.Code, w.Body.String()) 280 } 281 }) 282 } 283} 284 285// TestRequireAuth_InvalidToken tests that malformed sealed tokens are rejected 286func TestRequireAuth_InvalidToken(t *testing.T) { 287 client := newMockOAuthClient() 288 store := newMockOAuthStore() 289 middleware := NewOAuthAuthMiddleware(client, store) 290 291 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 292 t.Error("handler should not be called") 293 })) 294 295 req := httptest.NewRequest("GET", "/test", nil) 296 req.Header.Set("Authorization", "Bearer not-a-valid-sealed-token") 297 w := httptest.NewRecorder() 298 299 handler.ServeHTTP(w, req) 300 301 if w.Code != http.StatusUnauthorized { 302 t.Errorf("expected status 401, got %d", w.Code) 303 } 304} 305 306// TestRequireAuth_ExpiredToken tests that expired sealed tokens are rejected 307func TestRequireAuth_ExpiredToken(t *testing.T) { 308 client := newMockOAuthClient() 309 store := newMockOAuthStore() 310 311 // Create a test session 312 did := syntax.DID("did:plc:test123") 313 sessionID := "session123" 314 session := &oauthlib.ClientSessionData{ 315 AccountDID: did, 316 SessionID: sessionID, 317 AccessToken: "test_access_token", 318 } 319 _ = store.SaveSession(context.Background(), *session) 320 321 middleware := NewOAuthAuthMiddleware(client, store) 322 323 handler := middleware.RequireAuth(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { 324 t.Error("handler should not be called for expired token") 325 })) 326 327 // Create expired token (expired 1 hour ago) 328 token := client.createTestToken("did:plc:test123", sessionID, -time.Hour) 329 330 req := httptest.NewRequest("GET", "/test", nil) 331 req.Header.Set("Authorization", "Bearer "+token) 332 w := httptest.NewRecorder() 333 334 handler.ServeHTTP(w, req) 335 336 if w.Code != http.StatusUnauthorized { 337 t.Errorf("expected status 401, got %d", w.Code) 338 } 339} 340 341// TestRequireAuth_SessionNotFound tests that tokens with non-existent sessions are rejected 342func TestRequireAuth_SessionNotFound(t *testing.T) { 343 client := newMockOAuthClient() 344 store := newMockOAuthStore() 345 middleware := NewOAuthAuthMiddleware(client, store) 346 347 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 348 t.Error("handler should not be called") 349 })) 350 351 // Create token for session that doesn't exist in store 352 token := client.createTestToken("did:plc:nonexistent", "session999", time.Hour) 353 354 req := httptest.NewRequest("GET", "/test", nil) 355 req.Header.Set("Authorization", "Bearer "+token) 356 w := httptest.NewRecorder() 357 358 handler.ServeHTTP(w, req) 359 360 if w.Code != http.StatusUnauthorized { 361 t.Errorf("expected status 401, got %d", w.Code) 362 } 363} 364 365// TestRequireAuth_DIDMismatch tests that session DID must match token DID 366func TestRequireAuth_DIDMismatch(t *testing.T) { 367 client := newMockOAuthClient() 368 store := newMockOAuthStore() 369 370 // Create a session with different DID than token 371 did := syntax.DID("did:plc:different") 372 sessionID := "session123" 373 session := &oauthlib.ClientSessionData{ 374 AccountDID: did, 375 SessionID: sessionID, 376 AccessToken: "test_access_token", 377 } 378 // Store with key that matches token DID 379 key := "did:plc:test123:" + sessionID 380 store.sessions[key] = session 381 382 
middleware := NewOAuthAuthMiddleware(client, store) 383 384 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 385 t.Error("handler should not be called when DID mismatches") 386 })) 387 388 token := client.createTestToken("did:plc:test123", sessionID, time.Hour) 389 390 req := httptest.NewRequest("GET", "/test", nil) 391 req.Header.Set("Authorization", "Bearer "+token) 392 w := httptest.NewRecorder() 393 394 handler.ServeHTTP(w, req) 395 396 if w.Code != http.StatusUnauthorized { 397 t.Errorf("expected status 401, got %d", w.Code) 398 } 399} 400 401// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid Bearer tokens 402func TestOptionalAuth_WithToken(t *testing.T) { 403 client := newMockOAuthClient() 404 store := newMockOAuthStore() 405 406 // Create a test session 407 did := syntax.DID("did:plc:test123") 408 sessionID := "session123" 409 session := &oauthlib.ClientSessionData{ 410 AccountDID: did, 411 SessionID: sessionID, 412 AccessToken: "test_access_token", 413 } 414 _ = store.SaveSession(context.Background(), *session) 415 416 middleware := NewOAuthAuthMiddleware(client, store) 417 418 handlerCalled := false 419 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 420 handlerCalled = true 421 422 // Verify DID was extracted 423 extractedDID := GetUserDID(r) 424 if extractedDID != "did:plc:test123" { 425 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID) 426 } 427 428 w.WriteHeader(http.StatusOK) 429 })) 430 431 token := client.createTestToken("did:plc:test123", sessionID, time.Hour) 432 req := httptest.NewRequest("GET", "/test", nil) 433 req.Header.Set("Authorization", "Bearer "+token) 434 w := httptest.NewRecorder() 435 436 handler.ServeHTTP(w, req) 437 438 if !handlerCalled { 439 t.Error("handler was not called") 440 } 441 442 if w.Code != http.StatusOK { 443 t.Errorf("expected status 200, got %d", w.Code) 444 } 445} 446 447// 
TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens 448func TestOptionalAuth_WithoutToken(t *testing.T) { 449 client := newMockOAuthClient() 450 store := newMockOAuthStore() 451 middleware := NewOAuthAuthMiddleware(client, store) 452 453 handlerCalled := false 454 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 455 handlerCalled = true 456 457 // Verify no DID is set 458 did := GetUserDID(r) 459 if did != "" { 460 t.Errorf("expected empty DID, got %s", did) 461 } 462 463 w.WriteHeader(http.StatusOK) 464 })) 465 466 req := httptest.NewRequest("GET", "/test", nil) 467 // No Authorization header 468 w := httptest.NewRecorder() 469 470 handler.ServeHTTP(w, req) 471 472 if !handlerCalled { 473 t.Error("handler was not called") 474 } 475 476 if w.Code != http.StatusOK { 477 t.Errorf("expected status 200, got %d", w.Code) 478 } 479} 480 481// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token 482func TestOptionalAuth_InvalidToken(t *testing.T) { 483 client := newMockOAuthClient() 484 store := newMockOAuthStore() 485 middleware := NewOAuthAuthMiddleware(client, store) 486 487 handlerCalled := false 488 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 489 handlerCalled = true 490 491 // Verify no DID is set (invalid token ignored) 492 did := GetUserDID(r) 493 if did != "" { 494 t.Errorf("expected empty DID for invalid token, got %s", did) 495 } 496 497 w.WriteHeader(http.StatusOK) 498 })) 499 500 req := httptest.NewRequest("GET", "/test", nil) 501 req.Header.Set("Authorization", "Bearer not-a-valid-sealed-token") 502 w := httptest.NewRecorder() 503 504 handler.ServeHTTP(w, req) 505 506 if !handlerCalled { 507 t.Error("handler was not called") 508 } 509 510 if w.Code != http.StatusOK { 511 t.Errorf("expected status 200, got %d", w.Code) 512 } 513} 514 515// TestGetUserDID_NotAuthenticated tests that 
GetUserDID returns empty string when not authenticated 516func TestGetUserDID_NotAuthenticated(t *testing.T) { 517 req := httptest.NewRequest("GET", "/test", nil) 518 did := GetUserDID(req) 519 520 if did != "" { 521 t.Errorf("expected empty string, got %s", did) 522 } 523} 524 525// TestGetOAuthSession_NotAuthenticated tests that GetOAuthSession returns nil when not authenticated 526func TestGetOAuthSession_NotAuthenticated(t *testing.T) { 527 req := httptest.NewRequest("GET", "/test", nil) 528 session := GetOAuthSession(req) 529 530 if session != nil { 531 t.Errorf("expected nil session, got %+v", session) 532 } 533} 534 535// TestGetUserAccessToken_NotAuthenticated tests that GetUserAccessToken returns empty when not authenticated 536func TestGetUserAccessToken_NotAuthenticated(t *testing.T) { 537 req := httptest.NewRequest("GET", "/test", nil) 538 token := GetUserAccessToken(req) 539 540 if token != "" { 541 t.Errorf("expected empty token, got %s", token) 542 } 543} 544 545// TestSetTestUserDID tests the testing helper function 546func TestSetTestUserDID(t *testing.T) { 547 ctx := context.Background() 548 ctx = SetTestUserDID(ctx, "did:plc:testuser") 549 550 did, ok := ctx.Value(UserDIDKey).(string) 551 if !ok { 552 t.Error("DID not found in context") 553 } 554 if did != "did:plc:testuser" { 555 t.Errorf("expected 'did:plc:testuser', got %s", did) 556 } 557} 558 559// TestExtractBearerToken tests the Bearer token extraction logic 560func TestExtractBearerToken(t *testing.T) { 561 tests := []struct { 562 name string 563 authHeader string 564 expectToken string 565 expectOK bool 566 }{ 567 {"valid bearer", "Bearer token123", "token123", true}, 568 {"lowercase bearer", "bearer token123", "token123", true}, 569 {"uppercase bearer", "BEARER token123", "token123", true}, 570 {"mixed case", "BeArEr token123", "token123", true}, 571 {"empty header", "", "", false}, 572 {"wrong scheme", "DPoP token123", "", false}, 573 {"no token", "Bearer", "", false}, 574 {"no space", 
"Bearertoken123", "", false}, 575 {"extra spaces", "Bearer token123 ", "token123", true}, 576 } 577 578 for _, tt := range tests { 579 t.Run(tt.name, func(t *testing.T) { 580 token, ok := extractBearerToken(tt.authHeader) 581 if ok != tt.expectOK { 582 t.Errorf("expected ok=%v, got %v", tt.expectOK, ok) 583 } 584 if token != tt.expectToken { 585 t.Errorf("expected token '%s', got '%s'", tt.expectToken, token) 586 } 587 }) 588 } 589} 590 591// TestRequireAuth_ValidCookie tests that valid session cookies are accepted 592func TestRequireAuth_ValidCookie(t *testing.T) { 593 client := newMockOAuthClient() 594 store := newMockOAuthStore() 595 596 // Create a test session 597 did := syntax.DID("did:plc:test123") 598 sessionID := "session123" 599 session := &oauthlib.ClientSessionData{ 600 AccountDID: did, 601 SessionID: sessionID, 602 AccessToken: "test_access_token", 603 HostURL: "https://pds.example.com", 604 } 605 _ = store.SaveSession(context.Background(), *session) 606 607 middleware := NewOAuthAuthMiddleware(client, store) 608 609 handlerCalled := false 610 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 611 handlerCalled = true 612 613 // Verify DID was extracted and injected into context 614 extractedDID := GetUserDID(r) 615 if extractedDID != "did:plc:test123" { 616 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID) 617 } 618 619 // Verify OAuth session was injected 620 oauthSession := GetOAuthSession(r) 621 if oauthSession == nil { 622 t.Error("expected OAuth session to be non-nil") 623 return 624 } 625 if oauthSession.SessionID != sessionID { 626 t.Errorf("expected session ID '%s', got %s", sessionID, oauthSession.SessionID) 627 } 628 629 w.WriteHeader(http.StatusOK) 630 })) 631 632 token := client.createTestToken("did:plc:test123", sessionID, time.Hour) 633 req := httptest.NewRequest("GET", "/test", nil) 634 req.AddCookie(&http.Cookie{ 635 Name: "coves_session", 636 Value: token, 637 }) 638 w := 
httptest.NewRecorder() 639 640 handler.ServeHTTP(w, req) 641 642 if !handlerCalled { 643 t.Error("handler was not called") 644 } 645 646 if w.Code != http.StatusOK { 647 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) 648 } 649} 650 651// TestRequireAuth_HeaderPrecedenceOverCookie tests that Authorization header takes precedence over cookie 652func TestRequireAuth_HeaderPrecedenceOverCookie(t *testing.T) { 653 client := newMockOAuthClient() 654 store := newMockOAuthStore() 655 656 // Create two test sessions 657 did1 := syntax.DID("did:plc:header") 658 sessionID1 := "session_header" 659 session1 := &oauthlib.ClientSessionData{ 660 AccountDID: did1, 661 SessionID: sessionID1, 662 AccessToken: "header_token", 663 HostURL: "https://pds.example.com", 664 } 665 _ = store.SaveSession(context.Background(), *session1) 666 667 did2 := syntax.DID("did:plc:cookie") 668 sessionID2 := "session_cookie" 669 session2 := &oauthlib.ClientSessionData{ 670 AccountDID: did2, 671 SessionID: sessionID2, 672 AccessToken: "cookie_token", 673 HostURL: "https://pds.example.com", 674 } 675 _ = store.SaveSession(context.Background(), *session2) 676 677 middleware := NewOAuthAuthMiddleware(client, store) 678 679 handlerCalled := false 680 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 681 handlerCalled = true 682 683 // Should get header DID, not cookie DID 684 extractedDID := GetUserDID(r) 685 if extractedDID != "did:plc:header" { 686 t.Errorf("expected header DID 'did:plc:header', got %s", extractedDID) 687 } 688 689 w.WriteHeader(http.StatusOK) 690 })) 691 692 headerToken := client.createTestToken("did:plc:header", sessionID1, time.Hour) 693 cookieToken := client.createTestToken("did:plc:cookie", sessionID2, time.Hour) 694 695 req := httptest.NewRequest("GET", "/test", nil) 696 req.Header.Set("Authorization", "Bearer "+headerToken) 697 req.AddCookie(&http.Cookie{ 698 Name: "coves_session", 699 Value: cookieToken, 700 }) 
701 w := httptest.NewRecorder() 702 703 handler.ServeHTTP(w, req) 704 705 if !handlerCalled { 706 t.Error("handler was not called") 707 } 708 709 if w.Code != http.StatusOK { 710 t.Errorf("expected status 200, got %d", w.Code) 711 } 712} 713 714// TestRequireAuth_MissingBothHeaderAndCookie tests that missing both auth methods is rejected 715func TestRequireAuth_MissingBothHeaderAndCookie(t *testing.T) { 716 client := newMockOAuthClient() 717 store := newMockOAuthStore() 718 middleware := NewOAuthAuthMiddleware(client, store) 719 720 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 721 t.Error("handler should not be called") 722 })) 723 724 req := httptest.NewRequest("GET", "/test", nil) 725 // No Authorization header and no cookie 726 w := httptest.NewRecorder() 727 728 handler.ServeHTTP(w, req) 729 730 if w.Code != http.StatusUnauthorized { 731 t.Errorf("expected status 401, got %d", w.Code) 732 } 733} 734 735// TestRequireAuth_InvalidCookie tests that malformed cookie tokens are rejected 736func TestRequireAuth_InvalidCookie(t *testing.T) { 737 client := newMockOAuthClient() 738 store := newMockOAuthStore() 739 middleware := NewOAuthAuthMiddleware(client, store) 740 741 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 742 t.Error("handler should not be called") 743 })) 744 745 req := httptest.NewRequest("GET", "/test", nil) 746 req.AddCookie(&http.Cookie{ 747 Name: "coves_session", 748 Value: "not-a-valid-sealed-token", 749 }) 750 w := httptest.NewRecorder() 751 752 handler.ServeHTTP(w, req) 753 754 if w.Code != http.StatusUnauthorized { 755 t.Errorf("expected status 401, got %d", w.Code) 756 } 757} 758 759// TestOptionalAuth_WithCookie tests that OptionalAuth accepts valid session cookies 760func TestOptionalAuth_WithCookie(t *testing.T) { 761 client := newMockOAuthClient() 762 store := newMockOAuthStore() 763 764 // Create a test session 765 did := 
syntax.DID("did:plc:test123") 766 sessionID := "session123" 767 session := &oauthlib.ClientSessionData{ 768 AccountDID: did, 769 SessionID: sessionID, 770 AccessToken: "test_access_token", 771 } 772 _ = store.SaveSession(context.Background(), *session) 773 774 middleware := NewOAuthAuthMiddleware(client, store) 775 776 handlerCalled := false 777 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 778 handlerCalled = true 779 780 // Verify DID was extracted 781 extractedDID := GetUserDID(r) 782 if extractedDID != "did:plc:test123" { 783 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID) 784 } 785 786 w.WriteHeader(http.StatusOK) 787 })) 788 789 token := client.createTestToken("did:plc:test123", sessionID, time.Hour) 790 req := httptest.NewRequest("GET", "/test", nil) 791 req.AddCookie(&http.Cookie{ 792 Name: "coves_session", 793 Value: token, 794 }) 795 w := httptest.NewRecorder() 796 797 handler.ServeHTTP(w, req) 798 799 if !handlerCalled { 800 t.Error("handler was not called") 801 } 802 803 if w.Code != http.StatusOK { 804 t.Errorf("expected status 200, got %d", w.Code) 805 } 806} 807 808// TestOptionalAuth_InvalidCookie tests that OptionalAuth continues without auth on invalid cookie 809func TestOptionalAuth_InvalidCookie(t *testing.T) { 810 client := newMockOAuthClient() 811 store := newMockOAuthStore() 812 middleware := NewOAuthAuthMiddleware(client, store) 813 814 handlerCalled := false 815 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 816 handlerCalled = true 817 818 // Verify no DID is set (invalid cookie ignored) 819 did := GetUserDID(r) 820 if did != "" { 821 t.Errorf("expected empty DID for invalid cookie, got %s", did) 822 } 823 824 w.WriteHeader(http.StatusOK) 825 })) 826 827 req := httptest.NewRequest("GET", "/test", nil) 828 req.AddCookie(&http.Cookie{ 829 Name: "coves_session", 830 Value: "not-a-valid-sealed-token", 831 }) 832 w := 
httptest.NewRecorder() 833 834 handler.ServeHTTP(w, req) 835 836 if !handlerCalled { 837 t.Error("handler was not called") 838 } 839 840 if w.Code != http.StatusOK { 841 t.Errorf("expected status 200, got %d", w.Code) 842 } 843} 844 845// TestWriteAuthError_JSONEscaping tests that writeAuthError properly escapes messages 846func TestWriteAuthError_JSONEscaping(t *testing.T) { 847 tests := []struct { 848 name string 849 message string 850 }{ 851 {"simple message", "Missing authentication"}, 852 {"message with quotes", `Invalid "token" format`}, 853 {"message with newlines", "Invalid\ntoken\nformat"}, 854 {"message with backslashes", `Invalid \ token`}, 855 {"message with special chars", `Invalid <script>alert("xss")</script> token`}, 856 {"message with unicode", "Invalid token: \u2028\u2029"}, 857 } 858 859 for _, tt := range tests { 860 t.Run(tt.name, func(t *testing.T) { 861 w := httptest.NewRecorder() 862 writeAuthError(w, tt.message) 863 864 // Verify status code 865 if w.Code != http.StatusUnauthorized { 866 t.Errorf("expected status 401, got %d", w.Code) 867 } 868 869 // Verify content type 870 if ct := w.Header().Get("Content-Type"); ct != "application/json" { 871 t.Errorf("expected Content-Type 'application/json', got %s", ct) 872 } 873 874 // Verify response is valid JSON 875 var response map[string]string 876 if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { 877 t.Fatalf("response is not valid JSON: %v\nBody: %s", err, w.Body.String()) 878 } 879 880 // Verify fields 881 if response["error"] != "AuthenticationRequired" { 882 t.Errorf("expected error 'AuthenticationRequired', got %s", response["error"]) 883 } 884 if response["message"] != tt.message { 885 t.Errorf("expected message %q, got %q", tt.message, response["message"]) 886 } 887 }) 888 } 889}