···
11
+
"github.com/golang-jwt/jwt/v5"
14
+
// mockJWKSFetcher is a test double for JWKSFetcher
15
+
type mockJWKSFetcher struct {
19
+
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
21
+
return nil, fmt.Errorf("mock fetch failure")
23
+
// Return nil - we won't actually verify signatures in Phase 1 tests
27
+
// createTestToken creates a test JWT with the given DID
28
+
func createTestToken(did string) string {
29
+
claims := jwt.MapClaims{
31
+
"iss": "https://test.pds.local",
33
+
"exp": time.Now().Add(1 * time.Hour).Unix(),
34
+
"iat": time.Now().Unix(),
37
+
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
38
+
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
42
+
// TestRequireAuth_ValidToken tests that valid tokens are accepted (Phase 1)
43
+
func TestRequireAuth_ValidToken(t *testing.T) {
44
+
fetcher := &mockJWKSFetcher{}
45
+
middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true
47
+
handlerCalled := false
48
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49
+
handlerCalled = true
51
+
// Verify DID was extracted and injected into context
52
+
did := GetUserDID(r)
53
+
if did != "did:plc:test123" {
54
+
t.Errorf("expected DID 'did:plc:test123', got %s", did)
57
+
// Verify claims were injected
58
+
claims := GetJWTClaims(r)
60
+
t.Error("expected claims to be non-nil")
63
+
if claims.Subject != "did:plc:test123" {
64
+
t.Errorf("expected claims.Subject 'did:plc:test123', got %s", claims.Subject)
67
+
w.WriteHeader(http.StatusOK)
70
+
token := createTestToken("did:plc:test123")
71
+
req := httptest.NewRequest("GET", "/test", nil)
72
+
req.Header.Set("Authorization", "Bearer "+token)
73
+
w := httptest.NewRecorder()
75
+
handler.ServeHTTP(w, req)
78
+
t.Error("handler was not called")
81
+
if w.Code != http.StatusOK {
82
+
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
86
+
// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected
87
+
func TestRequireAuth_MissingAuthHeader(t *testing.T) {
88
+
fetcher := &mockJWKSFetcher{}
89
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
91
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
92
+
t.Error("handler should not be called")
95
+
req := httptest.NewRequest("GET", "/test", nil)
96
+
// No Authorization header
97
+
w := httptest.NewRecorder()
99
+
handler.ServeHTTP(w, req)
101
+
if w.Code != http.StatusUnauthorized {
102
+
t.Errorf("expected status 401, got %d", w.Code)
106
+
// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected
107
+
func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) {
108
+
fetcher := &mockJWKSFetcher{}
109
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
111
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
112
+
t.Error("handler should not be called")
115
+
req := httptest.NewRequest("GET", "/test", nil)
116
+
req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0") // Wrong format
117
+
w := httptest.NewRecorder()
119
+
handler.ServeHTTP(w, req)
121
+
if w.Code != http.StatusUnauthorized {
122
+
t.Errorf("expected status 401, got %d", w.Code)
126
+
// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected
127
+
func TestRequireAuth_MalformedToken(t *testing.T) {
128
+
fetcher := &mockJWKSFetcher{}
129
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
131
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132
+
t.Error("handler should not be called")
135
+
req := httptest.NewRequest("GET", "/test", nil)
136
+
req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
137
+
w := httptest.NewRecorder()
139
+
handler.ServeHTTP(w, req)
141
+
if w.Code != http.StatusUnauthorized {
142
+
t.Errorf("expected status 401, got %d", w.Code)
146
+
// TestRequireAuth_ExpiredToken tests that expired tokens are rejected
147
+
func TestRequireAuth_ExpiredToken(t *testing.T) {
148
+
fetcher := &mockJWKSFetcher{}
149
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
151
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152
+
t.Error("handler should not be called for expired token")
155
+
// Create expired token
156
+
claims := jwt.MapClaims{
157
+
"sub": "did:plc:test123",
158
+
"iss": "https://test.pds.local",
159
+
"scope": "atproto",
160
+
"exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago
161
+
"iat": time.Now().Add(-2 * time.Hour).Unix(),
164
+
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
165
+
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
167
+
req := httptest.NewRequest("GET", "/test", nil)
168
+
req.Header.Set("Authorization", "Bearer "+tokenString)
169
+
w := httptest.NewRecorder()
171
+
handler.ServeHTTP(w, req)
173
+
if w.Code != http.StatusUnauthorized {
174
+
t.Errorf("expected status 401, got %d", w.Code)
178
+
// TestRequireAuth_MissingDID tests that tokens without DID are rejected
179
+
func TestRequireAuth_MissingDID(t *testing.T) {
180
+
fetcher := &mockJWKSFetcher{}
181
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
183
+
handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184
+
t.Error("handler should not be called")
187
+
// Create token without sub claim
188
+
claims := jwt.MapClaims{
190
+
"iss": "https://test.pds.local",
191
+
"scope": "atproto",
192
+
"exp": time.Now().Add(1 * time.Hour).Unix(),
193
+
"iat": time.Now().Unix(),
196
+
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
197
+
tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
199
+
req := httptest.NewRequest("GET", "/test", nil)
200
+
req.Header.Set("Authorization", "Bearer "+tokenString)
201
+
w := httptest.NewRecorder()
203
+
handler.ServeHTTP(w, req)
205
+
if w.Code != http.StatusUnauthorized {
206
+
t.Errorf("expected status 401, got %d", w.Code)
210
+
// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid tokens
211
+
func TestOptionalAuth_WithToken(t *testing.T) {
212
+
fetcher := &mockJWKSFetcher{}
213
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
215
+
handlerCalled := false
216
+
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
217
+
handlerCalled = true
219
+
// Verify DID was extracted
220
+
did := GetUserDID(r)
221
+
if did != "did:plc:test123" {
222
+
t.Errorf("expected DID 'did:plc:test123', got %s", did)
225
+
w.WriteHeader(http.StatusOK)
228
+
token := createTestToken("did:plc:test123")
229
+
req := httptest.NewRequest("GET", "/test", nil)
230
+
req.Header.Set("Authorization", "Bearer "+token)
231
+
w := httptest.NewRecorder()
233
+
handler.ServeHTTP(w, req)
235
+
if !handlerCalled {
236
+
t.Error("handler was not called")
239
+
if w.Code != http.StatusOK {
240
+
t.Errorf("expected status 200, got %d", w.Code)
244
+
// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens
245
+
func TestOptionalAuth_WithoutToken(t *testing.T) {
246
+
fetcher := &mockJWKSFetcher{}
247
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
249
+
handlerCalled := false
250
+
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
251
+
handlerCalled = true
253
+
// Verify no DID is set
254
+
did := GetUserDID(r)
256
+
t.Errorf("expected empty DID, got %s", did)
259
+
w.WriteHeader(http.StatusOK)
262
+
req := httptest.NewRequest("GET", "/test", nil)
263
+
// No Authorization header
264
+
w := httptest.NewRecorder()
266
+
handler.ServeHTTP(w, req)
268
+
if !handlerCalled {
269
+
t.Error("handler was not called")
272
+
if w.Code != http.StatusOK {
273
+
t.Errorf("expected status 200, got %d", w.Code)
277
+
// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token
278
+
func TestOptionalAuth_InvalidToken(t *testing.T) {
279
+
fetcher := &mockJWKSFetcher{}
280
+
middleware := NewAtProtoAuthMiddleware(fetcher, true)
282
+
handlerCalled := false
283
+
handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
284
+
handlerCalled = true
286
+
// Verify no DID is set (invalid token ignored)
287
+
did := GetUserDID(r)
289
+
t.Errorf("expected empty DID for invalid token, got %s", did)
292
+
w.WriteHeader(http.StatusOK)
295
+
req := httptest.NewRequest("GET", "/test", nil)
296
+
req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
297
+
w := httptest.NewRecorder()
299
+
handler.ServeHTTP(w, req)
301
+
if !handlerCalled {
302
+
t.Error("handler was not called")
305
+
if w.Code != http.StatusOK {
306
+
t.Errorf("expected status 200, got %d", w.Code)
310
+
// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated
311
+
func TestGetUserDID_NotAuthenticated(t *testing.T) {
312
+
req := httptest.NewRequest("GET", "/test", nil)
313
+
did := GetUserDID(req)
316
+
t.Errorf("expected empty string, got %s", did)
320
+
// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated
321
+
func TestGetJWTClaims_NotAuthenticated(t *testing.T) {
322
+
req := httptest.NewRequest("GET", "/test", nil)
323
+
claims := GetJWTClaims(req)
326
+
t.Errorf("expected nil claims, got %+v", claims)