A community-based topic aggregation platform built on atproto
1package middleware
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/http/httptest"
8 "testing"
9 "time"
10
11 "github.com/golang-jwt/jwt/v5"
12)
13
// mockJWKSFetcher is a stub JWKSFetcher for the middleware tests.
// With shouldFail set, FetchPublicKey reports an error. Otherwise it hands back
// a nil key and a nil error.
type mockJWKSFetcher struct {
	shouldFail bool // force FetchPublicKey to return an error
}

// FetchPublicKey ignores ctx, issuer and token entirely. It returns
// (nil, nil) unless shouldFail is set, in which case it returns the
// "mock fetch failure" error.
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
	if !m.shouldFail {
		// No key material: these (Phase 1) tests construct the middleware with
		// skipVerify=true, so signatures are never checked against a key.
		return nil, nil
	}
	return nil, fmt.Errorf("mock fetch failure")
}
26
27// createTestToken creates a test JWT with the given DID
28func createTestToken(did string) string {
29 claims := jwt.MapClaims{
30 "sub": did,
31 "iss": "https://test.pds.local",
32 "scope": "atproto",
33 "exp": time.Now().Add(1 * time.Hour).Unix(),
34 "iat": time.Now().Unix(),
35 }
36
37 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
38 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
39 return tokenString
40}
41
42// TestRequireAuth_ValidToken tests that valid tokens are accepted (Phase 1)
43func TestRequireAuth_ValidToken(t *testing.T) {
44 fetcher := &mockJWKSFetcher{}
45 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true
46
47 handlerCalled := false
48 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49 handlerCalled = true
50
51 // Verify DID was extracted and injected into context
52 did := GetUserDID(r)
53 if did != "did:plc:test123" {
54 t.Errorf("expected DID 'did:plc:test123', got %s", did)
55 }
56
57 // Verify claims were injected
58 claims := GetJWTClaims(r)
59 if claims == nil {
60 t.Error("expected claims to be non-nil")
61 return
62 }
63 if claims.Subject != "did:plc:test123" {
64 t.Errorf("expected claims.Subject 'did:plc:test123', got %s", claims.Subject)
65 }
66
67 w.WriteHeader(http.StatusOK)
68 }))
69
70 token := createTestToken("did:plc:test123")
71 req := httptest.NewRequest("GET", "/test", nil)
72 req.Header.Set("Authorization", "Bearer "+token)
73 w := httptest.NewRecorder()
74
75 handler.ServeHTTP(w, req)
76
77 if !handlerCalled {
78 t.Error("handler was not called")
79 }
80
81 if w.Code != http.StatusOK {
82 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
83 }
84}
85
86// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected
87func TestRequireAuth_MissingAuthHeader(t *testing.T) {
88 fetcher := &mockJWKSFetcher{}
89 middleware := NewAtProtoAuthMiddleware(fetcher, true)
90
91 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
92 t.Error("handler should not be called")
93 }))
94
95 req := httptest.NewRequest("GET", "/test", nil)
96 // No Authorization header
97 w := httptest.NewRecorder()
98
99 handler.ServeHTTP(w, req)
100
101 if w.Code != http.StatusUnauthorized {
102 t.Errorf("expected status 401, got %d", w.Code)
103 }
104}
105
106// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected
107func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
108 fetcher := &mockJWKSFetcher{}
109 middleware := NewAtProtoAuthMiddleware(fetcher, true)
110
111 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
112 t.Error("handler should not be called")
113 }))
114
115 req := httptest.NewRequest("GET", "/test", nil)
116 req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0") // Wrong format
117 w := httptest.NewRecorder()
118
119 handler.ServeHTTP(w, req)
120
121 if w.Code != http.StatusUnauthorized {
122 t.Errorf("expected status 401, got %d", w.Code)
123 }
124}
125
126// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected
127func TestRequireAuth_MalformedToken(t *testing.T) {
128 fetcher := &mockJWKSFetcher{}
129 middleware := NewAtProtoAuthMiddleware(fetcher, true)
130
131 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132 t.Error("handler should not be called")
133 }))
134
135 req := httptest.NewRequest("GET", "/test", nil)
136 req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
137 w := httptest.NewRecorder()
138
139 handler.ServeHTTP(w, req)
140
141 if w.Code != http.StatusUnauthorized {
142 t.Errorf("expected status 401, got %d", w.Code)
143 }
144}
145
146// TestRequireAuth_ExpiredToken tests that expired tokens are rejected
147func TestRequireAuth_ExpiredToken(t *testing.T) {
148 fetcher := &mockJWKSFetcher{}
149 middleware := NewAtProtoAuthMiddleware(fetcher, true)
150
151 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152 t.Error("handler should not be called for expired token")
153 }))
154
155 // Create expired token
156 claims := jwt.MapClaims{
157 "sub": "did:plc:test123",
158 "iss": "https://test.pds.local",
159 "scope": "atproto",
160 "exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago
161 "iat": time.Now().Add(-2 * time.Hour).Unix(),
162 }
163
164 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
165 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
166
167 req := httptest.NewRequest("GET", "/test", nil)
168 req.Header.Set("Authorization", "Bearer "+tokenString)
169 w := httptest.NewRecorder()
170
171 handler.ServeHTTP(w, req)
172
173 if w.Code != http.StatusUnauthorized {
174 t.Errorf("expected status 401, got %d", w.Code)
175 }
176}
177
178// TestRequireAuth_MissingDID tests that tokens without DID are rejected
179func TestRequireAuth_MissingDID(t *testing.T) {
180 fetcher := &mockJWKSFetcher{}
181 middleware := NewAtProtoAuthMiddleware(fetcher, true)
182
183 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184 t.Error("handler should not be called")
185 }))
186
187 // Create token without sub claim
188 claims := jwt.MapClaims{
189 // "sub" missing
190 "iss": "https://test.pds.local",
191 "scope": "atproto",
192 "exp": time.Now().Add(1 * time.Hour).Unix(),
193 "iat": time.Now().Unix(),
194 }
195
196 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
197 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
198
199 req := httptest.NewRequest("GET", "/test", nil)
200 req.Header.Set("Authorization", "Bearer "+tokenString)
201 w := httptest.NewRecorder()
202
203 handler.ServeHTTP(w, req)
204
205 if w.Code != http.StatusUnauthorized {
206 t.Errorf("expected status 401, got %d", w.Code)
207 }
208}
209
210// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid tokens
211func TestOptionalAuth_WithToken(t *testing.T) {
212 fetcher := &mockJWKSFetcher{}
213 middleware := NewAtProtoAuthMiddleware(fetcher, true)
214
215 handlerCalled := false
216 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
217 handlerCalled = true
218
219 // Verify DID was extracted
220 did := GetUserDID(r)
221 if did != "did:plc:test123" {
222 t.Errorf("expected DID 'did:plc:test123', got %s", did)
223 }
224
225 w.WriteHeader(http.StatusOK)
226 }))
227
228 token := createTestToken("did:plc:test123")
229 req := httptest.NewRequest("GET", "/test", nil)
230 req.Header.Set("Authorization", "Bearer "+token)
231 w := httptest.NewRecorder()
232
233 handler.ServeHTTP(w, req)
234
235 if !handlerCalled {
236 t.Error("handler was not called")
237 }
238
239 if w.Code != http.StatusOK {
240 t.Errorf("expected status 200, got %d", w.Code)
241 }
242}
243
244// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens
245func TestOptionalAuth_WithoutToken(t *testing.T) {
246 fetcher := &mockJWKSFetcher{}
247 middleware := NewAtProtoAuthMiddleware(fetcher, true)
248
249 handlerCalled := false
250 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
251 handlerCalled = true
252
253 // Verify no DID is set
254 did := GetUserDID(r)
255 if did != "" {
256 t.Errorf("expected empty DID, got %s", did)
257 }
258
259 w.WriteHeader(http.StatusOK)
260 }))
261
262 req := httptest.NewRequest("GET", "/test", nil)
263 // No Authorization header
264 w := httptest.NewRecorder()
265
266 handler.ServeHTTP(w, req)
267
268 if !handlerCalled {
269 t.Error("handler was not called")
270 }
271
272 if w.Code != http.StatusOK {
273 t.Errorf("expected status 200, got %d", w.Code)
274 }
275}
276
277// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token
278func TestOptionalAuth_InvalidToken(t *testing.T) {
279 fetcher := &mockJWKSFetcher{}
280 middleware := NewAtProtoAuthMiddleware(fetcher, true)
281
282 handlerCalled := false
283 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
284 handlerCalled = true
285
286 // Verify no DID is set (invalid token ignored)
287 did := GetUserDID(r)
288 if did != "" {
289 t.Errorf("expected empty DID for invalid token, got %s", did)
290 }
291
292 w.WriteHeader(http.StatusOK)
293 }))
294
295 req := httptest.NewRequest("GET", "/test", nil)
296 req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
297 w := httptest.NewRecorder()
298
299 handler.ServeHTTP(w, req)
300
301 if !handlerCalled {
302 t.Error("handler was not called")
303 }
304
305 if w.Code != http.StatusOK {
306 t.Errorf("expected status 200, got %d", w.Code)
307 }
308}
309
310// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated
311func TestGetUserDID_NotAuthenticated(t *testing.T) {
312 req := httptest.NewRequest("GET", "/test", nil)
313 did := GetUserDID(req)
314
315 if did != "" {
316 t.Errorf("expected empty string, got %s", did)
317 }
318}
319
320// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated
321func TestGetJWTClaims_NotAuthenticated(t *testing.T) {
322 req := httptest.NewRequest("GET", "/test", nil)
323 claims := GetJWTClaims(req)
324
325 if claims != nil {
326 t.Errorf("expected nil claims, got %+v", claims)
327 }
328}