A community-based topic aggregation platform built on atproto
1package middleware
2
3import (
4 "Coves/internal/atproto/oauth"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 "time"
14
15 oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
16 "github.com/bluesky-social/indigo/atproto/syntax"
17)
18
19// mockOAuthClient is a test double for OAuthClient
20type mockOAuthClient struct {
21 sealSecret []byte
22 shouldFailSeal bool
23}
24
25func newMockOAuthClient() *mockOAuthClient {
26 // Create a 32-byte seal secret for testing
27 secret := []byte("test-secret-key-32-bytes-long!!")
28 return &mockOAuthClient{
29 sealSecret: secret,
30 }
31}
32
33func (m *mockOAuthClient) UnsealSession(token string) (*oauth.SealedSession, error) {
34 if m.shouldFailSeal {
35 return nil, fmt.Errorf("mock unseal failure")
36 }
37
38 // For testing, we'll decode a simple format: base64(did|sessionID|expiresAt)
39 // In production this would be AES-GCM encrypted
40 // Using pipe separator to avoid conflicts with colon in DIDs
41 decoded, err := base64.RawURLEncoding.DecodeString(token)
42 if err != nil {
43 return nil, fmt.Errorf("invalid token encoding: %w", err)
44 }
45
46 parts := strings.Split(string(decoded), "|")
47 if len(parts) != 3 {
48 return nil, fmt.Errorf("invalid token format")
49 }
50
51 var expiresAt int64
52 _, _ = fmt.Sscanf(parts[2], "%d", &expiresAt)
53
54 // Check expiration
55 if expiresAt <= time.Now().Unix() {
56 return nil, fmt.Errorf("token expired")
57 }
58
59 return &oauth.SealedSession{
60 DID: parts[0],
61 SessionID: parts[1],
62 ExpiresAt: expiresAt,
63 }, nil
64}
65
66// Helper to create a test sealed token
67func (m *mockOAuthClient) createTestToken(did, sessionID string, ttl time.Duration) string {
68 expiresAt := time.Now().Add(ttl).Unix()
69 payload := fmt.Sprintf("%s|%s|%d", did, sessionID, expiresAt)
70 return base64.RawURLEncoding.EncodeToString([]byte(payload))
71}
72
73// mockOAuthStore is a test double for ClientAuthStore
74type mockOAuthStore struct {
75 sessions map[string]*oauthlib.ClientSessionData
76}
77
78func newMockOAuthStore() *mockOAuthStore {
79 return &mockOAuthStore{
80 sessions: make(map[string]*oauthlib.ClientSessionData),
81 }
82}
83
84func (m *mockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauthlib.ClientSessionData, error) {
85 key := did.String() + ":" + sessionID
86 session, ok := m.sessions[key]
87 if !ok {
88 return nil, fmt.Errorf("session not found")
89 }
90 return session, nil
91}
92
93func (m *mockOAuthStore) SaveSession(ctx context.Context, session oauthlib.ClientSessionData) error {
94 key := session.AccountDID.String() + ":" + session.SessionID
95 m.sessions[key] = &session
96 return nil
97}
98
99func (m *mockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
100 key := did.String() + ":" + sessionID
101 delete(m.sessions, key)
102 return nil
103}
104
105func (m *mockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauthlib.AuthRequestData, error) {
106 return nil, fmt.Errorf("not implemented")
107}
108
109func (m *mockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauthlib.AuthRequestData) error {
110 return fmt.Errorf("not implemented")
111}
112
113func (m *mockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
114 return fmt.Errorf("not implemented")
115}
116
117// TestRequireAuth_ValidToken tests that valid sealed tokens are accepted
118func TestRequireAuth_ValidToken(t *testing.T) {
119 client := newMockOAuthClient()
120 store := newMockOAuthStore()
121
122 // Create a test session
123 did := syntax.DID("did:plc:test123")
124 sessionID := "session123"
125 session := &oauthlib.ClientSessionData{
126 AccountDID: did,
127 SessionID: sessionID,
128 AccessToken: "test_access_token",
129 HostURL: "https://pds.example.com",
130 }
131 _ = store.SaveSession(context.Background(), *session)
132
133 middleware := NewOAuthAuthMiddleware(client, store)
134
135 handlerCalled := false
136 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
137 handlerCalled = true
138
139 // Verify DID was extracted and injected into context
140 extractedDID := GetUserDID(r)
141 if extractedDID != "did:plc:test123" {
142 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
143 }
144
145 // Verify OAuth session was injected
146 oauthSession := GetOAuthSession(r)
147 if oauthSession == nil {
148 t.Error("expected OAuth session to be non-nil")
149 return
150 }
151 if oauthSession.SessionID != sessionID {
152 t.Errorf("expected session ID '%s', got %s", sessionID, oauthSession.SessionID)
153 }
154
155 // Verify access token is available
156 accessToken := GetUserAccessToken(r)
157 if accessToken != "test_access_token" {
158 t.Errorf("expected access token 'test_access_token', got %s", accessToken)
159 }
160
161 w.WriteHeader(http.StatusOK)
162 }))
163
164 token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
165 req := httptest.NewRequest("GET", "/test", nil)
166 req.Header.Set("Authorization", "Bearer "+token)
167 w := httptest.NewRecorder()
168
169 handler.ServeHTTP(w, req)
170
171 if !handlerCalled {
172 t.Error("handler was not called")
173 }
174
175 if w.Code != http.StatusOK {
176 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
177 }
178}
179
180// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected
181func TestRequireAuth_MissingAuthHeader(t *testing.T) {
182 client := newMockOAuthClient()
183 store := newMockOAuthStore()
184 middleware := NewOAuthAuthMiddleware(client, store)
185
186 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
187 t.Error("handler should not be called")
188 }))
189
190 req := httptest.NewRequest("GET", "/test", nil)
191 // No Authorization header
192 w := httptest.NewRecorder()
193
194 handler.ServeHTTP(w, req)
195
196 if w.Code != http.StatusUnauthorized {
197 t.Errorf("expected status 401, got %d", w.Code)
198 }
199}
200
201// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected
202func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
203 client := newMockOAuthClient()
204 store := newMockOAuthStore()
205 middleware := NewOAuthAuthMiddleware(client, store)
206
207 tests := []struct {
208 name string
209 header string
210 }{
211 {"Basic auth", "Basic dGVzdDp0ZXN0"},
212 {"DPoP scheme", "DPoP some-token"},
213 {"Invalid format", "InvalidFormat"},
214 }
215
216 for _, tt := range tests {
217 t.Run(tt.name, func(t *testing.T) {
218 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
219 t.Error("handler should not be called")
220 }))
221
222 req := httptest.NewRequest("GET", "/test", nil)
223 req.Header.Set("Authorization", tt.header)
224 w := httptest.NewRecorder()
225
226 handler.ServeHTTP(w, req)
227
228 if w.Code != http.StatusUnauthorized {
229 t.Errorf("expected status 401, got %d", w.Code)
230 }
231 })
232 }
233}
234
235// TestRequireAuth_CaseInsensitiveScheme verifies that Bearer scheme matching is case-insensitive
236func TestRequireAuth_CaseInsensitiveScheme(t *testing.T) {
237 client := newMockOAuthClient()
238 store := newMockOAuthStore()
239
240 // Create a test session
241 did := syntax.DID("did:plc:test123")
242 sessionID := "session123"
243 session := &oauthlib.ClientSessionData{
244 AccountDID: did,
245 SessionID: sessionID,
246 AccessToken: "test_access_token",
247 }
248 _ = store.SaveSession(context.Background(), *session)
249
250 middleware := NewOAuthAuthMiddleware(client, store)
251 token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
252
253 testCases := []struct {
254 name string
255 scheme string
256 }{
257 {"lowercase", "bearer"},
258 {"uppercase", "BEARER"},
259 {"mixed_case", "BeArEr"},
260 {"standard", "Bearer"},
261 }
262
263 for _, tc := range testCases {
264 t.Run(tc.name, func(t *testing.T) {
265 handlerCalled := false
266 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
267 handlerCalled = true
268 w.WriteHeader(http.StatusOK)
269 }))
270
271 req := httptest.NewRequest("GET", "/test", nil)
272 req.Header.Set("Authorization", tc.scheme+" "+token)
273 w := httptest.NewRecorder()
274
275 handler.ServeHTTP(w, req)
276
277 if !handlerCalled {
278 t.Errorf("scheme %q should be accepted (case-insensitive per RFC 7235), got status %d: %s",
279 tc.scheme, w.Code, w.Body.String())
280 }
281 })
282 }
283}
284
285// TestRequireAuth_InvalidToken tests that malformed sealed tokens are rejected
286func TestRequireAuth_InvalidToken(t *testing.T) {
287 client := newMockOAuthClient()
288 store := newMockOAuthStore()
289 middleware := NewOAuthAuthMiddleware(client, store)
290
291 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
292 t.Error("handler should not be called")
293 }))
294
295 req := httptest.NewRequest("GET", "/test", nil)
296 req.Header.Set("Authorization", "Bearer not-a-valid-sealed-token")
297 w := httptest.NewRecorder()
298
299 handler.ServeHTTP(w, req)
300
301 if w.Code != http.StatusUnauthorized {
302 t.Errorf("expected status 401, got %d", w.Code)
303 }
304}
305
306// TestRequireAuth_ExpiredToken tests that expired sealed tokens are rejected
307func TestRequireAuth_ExpiredToken(t *testing.T) {
308 client := newMockOAuthClient()
309 store := newMockOAuthStore()
310
311 // Create a test session
312 did := syntax.DID("did:plc:test123")
313 sessionID := "session123"
314 session := &oauthlib.ClientSessionData{
315 AccountDID: did,
316 SessionID: sessionID,
317 AccessToken: "test_access_token",
318 }
319 _ = store.SaveSession(context.Background(), *session)
320
321 middleware := NewOAuthAuthMiddleware(client, store)
322
323 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
324 t.Error("handler should not be called for expired token")
325 }))
326
327 // Create expired token (expired 1 hour ago)
328 token := client.createTestToken("did:plc:test123", sessionID, -time.Hour)
329
330 req := httptest.NewRequest("GET", "/test", nil)
331 req.Header.Set("Authorization", "Bearer "+token)
332 w := httptest.NewRecorder()
333
334 handler.ServeHTTP(w, req)
335
336 if w.Code != http.StatusUnauthorized {
337 t.Errorf("expected status 401, got %d", w.Code)
338 }
339}
340
341// TestRequireAuth_SessionNotFound tests that tokens with non-existent sessions are rejected
342func TestRequireAuth_SessionNotFound(t *testing.T) {
343 client := newMockOAuthClient()
344 store := newMockOAuthStore()
345 middleware := NewOAuthAuthMiddleware(client, store)
346
347 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
348 t.Error("handler should not be called")
349 }))
350
351 // Create token for session that doesn't exist in store
352 token := client.createTestToken("did:plc:nonexistent", "session999", time.Hour)
353
354 req := httptest.NewRequest("GET", "/test", nil)
355 req.Header.Set("Authorization", "Bearer "+token)
356 w := httptest.NewRecorder()
357
358 handler.ServeHTTP(w, req)
359
360 if w.Code != http.StatusUnauthorized {
361 t.Errorf("expected status 401, got %d", w.Code)
362 }
363}
364
365// TestRequireAuth_DIDMismatch tests that session DID must match token DID
366func TestRequireAuth_DIDMismatch(t *testing.T) {
367 client := newMockOAuthClient()
368 store := newMockOAuthStore()
369
370 // Create a session with different DID than token
371 did := syntax.DID("did:plc:different")
372 sessionID := "session123"
373 session := &oauthlib.ClientSessionData{
374 AccountDID: did,
375 SessionID: sessionID,
376 AccessToken: "test_access_token",
377 }
378 // Store with key that matches token DID
379 key := "did:plc:test123:" + sessionID
380 store.sessions[key] = session
381
382 middleware := NewOAuthAuthMiddleware(client, store)
383
384 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
385 t.Error("handler should not be called when DID mismatches")
386 }))
387
388 token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
389
390 req := httptest.NewRequest("GET", "/test", nil)
391 req.Header.Set("Authorization", "Bearer "+token)
392 w := httptest.NewRecorder()
393
394 handler.ServeHTTP(w, req)
395
396 if w.Code != http.StatusUnauthorized {
397 t.Errorf("expected status 401, got %d", w.Code)
398 }
399}
400
401// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid Bearer tokens
402func TestOptionalAuth_WithToken(t *testing.T) {
403 client := newMockOAuthClient()
404 store := newMockOAuthStore()
405
406 // Create a test session
407 did := syntax.DID("did:plc:test123")
408 sessionID := "session123"
409 session := &oauthlib.ClientSessionData{
410 AccountDID: did,
411 SessionID: sessionID,
412 AccessToken: "test_access_token",
413 }
414 _ = store.SaveSession(context.Background(), *session)
415
416 middleware := NewOAuthAuthMiddleware(client, store)
417
418 handlerCalled := false
419 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
420 handlerCalled = true
421
422 // Verify DID was extracted
423 extractedDID := GetUserDID(r)
424 if extractedDID != "did:plc:test123" {
425 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
426 }
427
428 w.WriteHeader(http.StatusOK)
429 }))
430
431 token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
432 req := httptest.NewRequest("GET", "/test", nil)
433 req.Header.Set("Authorization", "Bearer "+token)
434 w := httptest.NewRecorder()
435
436 handler.ServeHTTP(w, req)
437
438 if !handlerCalled {
439 t.Error("handler was not called")
440 }
441
442 if w.Code != http.StatusOK {
443 t.Errorf("expected status 200, got %d", w.Code)
444 }
445}
446
447// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens
448func TestOptionalAuth_WithoutToken(t *testing.T) {
449 client := newMockOAuthClient()
450 store := newMockOAuthStore()
451 middleware := NewOAuthAuthMiddleware(client, store)
452
453 handlerCalled := false
454 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
455 handlerCalled = true
456
457 // Verify no DID is set
458 did := GetUserDID(r)
459 if did != "" {
460 t.Errorf("expected empty DID, got %s", did)
461 }
462
463 w.WriteHeader(http.StatusOK)
464 }))
465
466 req := httptest.NewRequest("GET", "/test", nil)
467 // No Authorization header
468 w := httptest.NewRecorder()
469
470 handler.ServeHTTP(w, req)
471
472 if !handlerCalled {
473 t.Error("handler was not called")
474 }
475
476 if w.Code != http.StatusOK {
477 t.Errorf("expected status 200, got %d", w.Code)
478 }
479}
480
481// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token
482func TestOptionalAuth_InvalidToken(t *testing.T) {
483 client := newMockOAuthClient()
484 store := newMockOAuthStore()
485 middleware := NewOAuthAuthMiddleware(client, store)
486
487 handlerCalled := false
488 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
489 handlerCalled = true
490
491 // Verify no DID is set (invalid token ignored)
492 did := GetUserDID(r)
493 if did != "" {
494 t.Errorf("expected empty DID for invalid token, got %s", did)
495 }
496
497 w.WriteHeader(http.StatusOK)
498 }))
499
500 req := httptest.NewRequest("GET", "/test", nil)
501 req.Header.Set("Authorization", "Bearer not-a-valid-sealed-token")
502 w := httptest.NewRecorder()
503
504 handler.ServeHTTP(w, req)
505
506 if !handlerCalled {
507 t.Error("handler was not called")
508 }
509
510 if w.Code != http.StatusOK {
511 t.Errorf("expected status 200, got %d", w.Code)
512 }
513}
514
515// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated
516func TestGetUserDID_NotAuthenticated(t *testing.T) {
517 req := httptest.NewRequest("GET", "/test", nil)
518 did := GetUserDID(req)
519
520 if did != "" {
521 t.Errorf("expected empty string, got %s", did)
522 }
523}
524
525// TestGetOAuthSession_NotAuthenticated tests that GetOAuthSession returns nil when not authenticated
526func TestGetOAuthSession_NotAuthenticated(t *testing.T) {
527 req := httptest.NewRequest("GET", "/test", nil)
528 session := GetOAuthSession(req)
529
530 if session != nil {
531 t.Errorf("expected nil session, got %+v", session)
532 }
533}
534
535// TestGetUserAccessToken_NotAuthenticated tests that GetUserAccessToken returns empty when not authenticated
536func TestGetUserAccessToken_NotAuthenticated(t *testing.T) {
537 req := httptest.NewRequest("GET", "/test", nil)
538 token := GetUserAccessToken(req)
539
540 if token != "" {
541 t.Errorf("expected empty token, got %s", token)
542 }
543}
544
545// TestSetTestUserDID tests the testing helper function
546func TestSetTestUserDID(t *testing.T) {
547 ctx := context.Background()
548 ctx = SetTestUserDID(ctx, "did:plc:testuser")
549
550 did, ok := ctx.Value(UserDIDKey).(string)
551 if !ok {
552 t.Error("DID not found in context")
553 }
554 if did != "did:plc:testuser" {
555 t.Errorf("expected 'did:plc:testuser', got %s", did)
556 }
557}
558
559// TestExtractBearerToken tests the Bearer token extraction logic
560func TestExtractBearerToken(t *testing.T) {
561 tests := []struct {
562 name string
563 authHeader string
564 expectToken string
565 expectOK bool
566 }{
567 {"valid bearer", "Bearer token123", "token123", true},
568 {"lowercase bearer", "bearer token123", "token123", true},
569 {"uppercase bearer", "BEARER token123", "token123", true},
570 {"mixed case", "BeArEr token123", "token123", true},
571 {"empty header", "", "", false},
572 {"wrong scheme", "DPoP token123", "", false},
573 {"no token", "Bearer", "", false},
574 {"no space", "Bearertoken123", "", false},
575 {"extra spaces", "Bearer token123 ", "token123", true},
576 }
577
578 for _, tt := range tests {
579 t.Run(tt.name, func(t *testing.T) {
580 token, ok := extractBearerToken(tt.authHeader)
581 if ok != tt.expectOK {
582 t.Errorf("expected ok=%v, got %v", tt.expectOK, ok)
583 }
584 if token != tt.expectToken {
585 t.Errorf("expected token '%s', got '%s'", tt.expectToken, token)
586 }
587 })
588 }
589}
590
591// TestRequireAuth_ValidCookie tests that valid session cookies are accepted
592func TestRequireAuth_ValidCookie(t *testing.T) {
593 client := newMockOAuthClient()
594 store := newMockOAuthStore()
595
596 // Create a test session
597 did := syntax.DID("did:plc:test123")
598 sessionID := "session123"
599 session := &oauthlib.ClientSessionData{
600 AccountDID: did,
601 SessionID: sessionID,
602 AccessToken: "test_access_token",
603 HostURL: "https://pds.example.com",
604 }
605 _ = store.SaveSession(context.Background(), *session)
606
607 middleware := NewOAuthAuthMiddleware(client, store)
608
609 handlerCalled := false
610 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
611 handlerCalled = true
612
613 // Verify DID was extracted and injected into context
614 extractedDID := GetUserDID(r)
615 if extractedDID != "did:plc:test123" {
616 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
617 }
618
619 // Verify OAuth session was injected
620 oauthSession := GetOAuthSession(r)
621 if oauthSession == nil {
622 t.Error("expected OAuth session to be non-nil")
623 return
624 }
625 if oauthSession.SessionID != sessionID {
626 t.Errorf("expected session ID '%s', got %s", sessionID, oauthSession.SessionID)
627 }
628
629 w.WriteHeader(http.StatusOK)
630 }))
631
632 token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
633 req := httptest.NewRequest("GET", "/test", nil)
634 req.AddCookie(&http.Cookie{
635 Name: "coves_session",
636 Value: token,
637 })
638 w := httptest.NewRecorder()
639
640 handler.ServeHTTP(w, req)
641
642 if !handlerCalled {
643 t.Error("handler was not called")
644 }
645
646 if w.Code != http.StatusOK {
647 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
648 }
649}
650
651// TestRequireAuth_HeaderPrecedenceOverCookie tests that Authorization header takes precedence over cookie
652func TestRequireAuth_HeaderPrecedenceOverCookie(t *testing.T) {
653 client := newMockOAuthClient()
654 store := newMockOAuthStore()
655
656 // Create two test sessions
657 did1 := syntax.DID("did:plc:header")
658 sessionID1 := "session_header"
659 session1 := &oauthlib.ClientSessionData{
660 AccountDID: did1,
661 SessionID: sessionID1,
662 AccessToken: "header_token",
663 HostURL: "https://pds.example.com",
664 }
665 _ = store.SaveSession(context.Background(), *session1)
666
667 did2 := syntax.DID("did:plc:cookie")
668 sessionID2 := "session_cookie"
669 session2 := &oauthlib.ClientSessionData{
670 AccountDID: did2,
671 SessionID: sessionID2,
672 AccessToken: "cookie_token",
673 HostURL: "https://pds.example.com",
674 }
675 _ = store.SaveSession(context.Background(), *session2)
676
677 middleware := NewOAuthAuthMiddleware(client, store)
678
679 handlerCalled := false
680 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
681 handlerCalled = true
682
683 // Should get header DID, not cookie DID
684 extractedDID := GetUserDID(r)
685 if extractedDID != "did:plc:header" {
686 t.Errorf("expected header DID 'did:plc:header', got %s", extractedDID)
687 }
688
689 w.WriteHeader(http.StatusOK)
690 }))
691
692 headerToken := client.createTestToken("did:plc:header", sessionID1, time.Hour)
693 cookieToken := client.createTestToken("did:plc:cookie", sessionID2, time.Hour)
694
695 req := httptest.NewRequest("GET", "/test", nil)
696 req.Header.Set("Authorization", "Bearer "+headerToken)
697 req.AddCookie(&http.Cookie{
698 Name: "coves_session",
699 Value: cookieToken,
700 })
701 w := httptest.NewRecorder()
702
703 handler.ServeHTTP(w, req)
704
705 if !handlerCalled {
706 t.Error("handler was not called")
707 }
708
709 if w.Code != http.StatusOK {
710 t.Errorf("expected status 200, got %d", w.Code)
711 }
712}
713
714// TestRequireAuth_MissingBothHeaderAndCookie tests that missing both auth methods is rejected
715func TestRequireAuth_MissingBothHeaderAndCookie(t *testing.T) {
716 client := newMockOAuthClient()
717 store := newMockOAuthStore()
718 middleware := NewOAuthAuthMiddleware(client, store)
719
720 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
721 t.Error("handler should not be called")
722 }))
723
724 req := httptest.NewRequest("GET", "/test", nil)
725 // No Authorization header and no cookie
726 w := httptest.NewRecorder()
727
728 handler.ServeHTTP(w, req)
729
730 if w.Code != http.StatusUnauthorized {
731 t.Errorf("expected status 401, got %d", w.Code)
732 }
733}
734
735// TestRequireAuth_InvalidCookie tests that malformed cookie tokens are rejected
736func TestRequireAuth_InvalidCookie(t *testing.T) {
737 client := newMockOAuthClient()
738 store := newMockOAuthStore()
739 middleware := NewOAuthAuthMiddleware(client, store)
740
741 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
742 t.Error("handler should not be called")
743 }))
744
745 req := httptest.NewRequest("GET", "/test", nil)
746 req.AddCookie(&http.Cookie{
747 Name: "coves_session",
748 Value: "not-a-valid-sealed-token",
749 })
750 w := httptest.NewRecorder()
751
752 handler.ServeHTTP(w, req)
753
754 if w.Code != http.StatusUnauthorized {
755 t.Errorf("expected status 401, got %d", w.Code)
756 }
757}
758
759// TestOptionalAuth_WithCookie tests that OptionalAuth accepts valid session cookies
760func TestOptionalAuth_WithCookie(t *testing.T) {
761 client := newMockOAuthClient()
762 store := newMockOAuthStore()
763
764 // Create a test session
765 did := syntax.DID("did:plc:test123")
766 sessionID := "session123"
767 session := &oauthlib.ClientSessionData{
768 AccountDID: did,
769 SessionID: sessionID,
770 AccessToken: "test_access_token",
771 }
772 _ = store.SaveSession(context.Background(), *session)
773
774 middleware := NewOAuthAuthMiddleware(client, store)
775
776 handlerCalled := false
777 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
778 handlerCalled = true
779
780 // Verify DID was extracted
781 extractedDID := GetUserDID(r)
782 if extractedDID != "did:plc:test123" {
783 t.Errorf("expected DID 'did:plc:test123', got %s", extractedDID)
784 }
785
786 w.WriteHeader(http.StatusOK)
787 }))
788
789 token := client.createTestToken("did:plc:test123", sessionID, time.Hour)
790 req := httptest.NewRequest("GET", "/test", nil)
791 req.AddCookie(&http.Cookie{
792 Name: "coves_session",
793 Value: token,
794 })
795 w := httptest.NewRecorder()
796
797 handler.ServeHTTP(w, req)
798
799 if !handlerCalled {
800 t.Error("handler was not called")
801 }
802
803 if w.Code != http.StatusOK {
804 t.Errorf("expected status 200, got %d", w.Code)
805 }
806}
807
808// TestOptionalAuth_InvalidCookie tests that OptionalAuth continues without auth on invalid cookie
809func TestOptionalAuth_InvalidCookie(t *testing.T) {
810 client := newMockOAuthClient()
811 store := newMockOAuthStore()
812 middleware := NewOAuthAuthMiddleware(client, store)
813
814 handlerCalled := false
815 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
816 handlerCalled = true
817
818 // Verify no DID is set (invalid cookie ignored)
819 did := GetUserDID(r)
820 if did != "" {
821 t.Errorf("expected empty DID for invalid cookie, got %s", did)
822 }
823
824 w.WriteHeader(http.StatusOK)
825 }))
826
827 req := httptest.NewRequest("GET", "/test", nil)
828 req.AddCookie(&http.Cookie{
829 Name: "coves_session",
830 Value: "not-a-valid-sealed-token",
831 })
832 w := httptest.NewRecorder()
833
834 handler.ServeHTTP(w, req)
835
836 if !handlerCalled {
837 t.Error("handler was not called")
838 }
839
840 if w.Code != http.StatusOK {
841 t.Errorf("expected status 200, got %d", w.Code)
842 }
843}
844
845// TestWriteAuthError_JSONEscaping tests that writeAuthError properly escapes messages
846func TestWriteAuthError_JSONEscaping(t *testing.T) {
847 tests := []struct {
848 name string
849 message string
850 }{
851 {"simple message", "Missing authentication"},
852 {"message with quotes", `Invalid "token" format`},
853 {"message with newlines", "Invalid\ntoken\nformat"},
854 {"message with backslashes", `Invalid \ token`},
855 {"message with special chars", `Invalid <script>alert("xss")</script> token`},
856 {"message with unicode", "Invalid token: \u2028\u2029"},
857 }
858
859 for _, tt := range tests {
860 t.Run(tt.name, func(t *testing.T) {
861 w := httptest.NewRecorder()
862 writeAuthError(w, tt.message)
863
864 // Verify status code
865 if w.Code != http.StatusUnauthorized {
866 t.Errorf("expected status 401, got %d", w.Code)
867 }
868
869 // Verify content type
870 if ct := w.Header().Get("Content-Type"); ct != "application/json" {
871 t.Errorf("expected Content-Type 'application/json', got %s", ct)
872 }
873
874 // Verify response is valid JSON
875 var response map[string]string
876 if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
877 t.Fatalf("response is not valid JSON: %v\nBody: %s", err, w.Body.String())
878 }
879
880 // Verify fields
881 if response["error"] != "AuthenticationRequired" {
882 t.Errorf("expected error 'AuthenticationRequired', got %s", response["error"])
883 }
884 if response["message"] != tt.message {
885 t.Errorf("expected message %q, got %q", tt.message, response["message"])
886 }
887 })
888 }
889}