A community-based topic aggregation platform built on atproto
1package middleware 2 3import ( 4 "context" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "testing" 9 "time" 10 11 "github.com/golang-jwt/jwt/v5" 12) 13 14// mockJWKSFetcher is a test double for JWKSFetcher 15type mockJWKSFetcher struct { 16 shouldFail bool 17} 18 19func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) { 20 if m.shouldFail { 21 return nil, fmt.Errorf("mock fetch failure") 22 } 23 // Return nil - we won't actually verify signatures in Phase 1 tests 24 return nil, nil 25} 26 27// createTestToken creates a test JWT with the given DID 28func createTestToken(did string) string { 29 claims := jwt.MapClaims{ 30 "sub": did, 31 "iss": "https://test.pds.local", 32 "scope": "atproto", 33 "exp": time.Now().Add(1 * time.Hour).Unix(), 34 "iat": time.Now().Unix(), 35 } 36 37 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 38 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 39 return tokenString 40} 41 42// TestRequireAuth_ValidToken tests that valid tokens are accepted (Phase 1) 43func TestRequireAuth_ValidToken(t *testing.T) { 44 fetcher := &mockJWKSFetcher{} 45 middleware := NewAtProtoAuthMiddleware(fetcher, true) // skipVerify=true 46 47 handlerCalled := false 48 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 49 handlerCalled = true 50 51 // Verify DID was extracted and injected into context 52 did := GetUserDID(r) 53 if did != "did:plc:test123" { 54 t.Errorf("expected DID 'did:plc:test123', got %s", did) 55 } 56 57 // Verify claims were injected 58 claims := GetJWTClaims(r) 59 if claims == nil { 60 t.Error("expected claims to be non-nil") 61 return 62 } 63 if claims.Subject != "did:plc:test123" { 64 t.Errorf("expected claims.Subject 'did:plc:test123', got %s", claims.Subject) 65 } 66 67 w.WriteHeader(http.StatusOK) 68 })) 69 70 token := createTestToken("did:plc:test123") 71 req := httptest.NewRequest("GET", "/test", nil) 72 
req.Header.Set("Authorization", "Bearer "+token) 73 w := httptest.NewRecorder() 74 75 handler.ServeHTTP(w, req) 76 77 if !handlerCalled { 78 t.Error("handler was not called") 79 } 80 81 if w.Code != http.StatusOK { 82 t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) 83 } 84} 85 86// TestRequireAuth_MissingAuthHeader tests that missing Authorization header is rejected 87func TestRequireAuth_MissingAuthHeader(t *testing.T) { 88 fetcher := &mockJWKSFetcher{} 89 middleware := NewAtProtoAuthMiddleware(fetcher, true) 90 91 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 92 t.Error("handler should not be called") 93 })) 94 95 req := httptest.NewRequest("GET", "/test", nil) 96 // No Authorization header 97 w := httptest.NewRecorder() 98 99 handler.ServeHTTP(w, req) 100 101 if w.Code != http.StatusUnauthorized { 102 t.Errorf("expected status 401, got %d", w.Code) 103 } 104} 105 106// TestRequireAuth_InvalidAuthHeaderFormat tests that non-Bearer tokens are rejected 107func TestRequireAuth_InvalidAuthHeaderFormat(t *testing.T) { 108 fetcher := &mockJWKSFetcher{} 109 middleware := NewAtProtoAuthMiddleware(fetcher, true) 110 111 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 112 t.Error("handler should not be called") 113 })) 114 115 req := httptest.NewRequest("GET", "/test", nil) 116 req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0") // Wrong format 117 w := httptest.NewRecorder() 118 119 handler.ServeHTTP(w, req) 120 121 if w.Code != http.StatusUnauthorized { 122 t.Errorf("expected status 401, got %d", w.Code) 123 } 124} 125 126// TestRequireAuth_MalformedToken tests that malformed JWTs are rejected 127func TestRequireAuth_MalformedToken(t *testing.T) { 128 fetcher := &mockJWKSFetcher{} 129 middleware := NewAtProtoAuthMiddleware(fetcher, true) 130 131 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 132 
t.Error("handler should not be called") 133 })) 134 135 req := httptest.NewRequest("GET", "/test", nil) 136 req.Header.Set("Authorization", "Bearer not-a-valid-jwt") 137 w := httptest.NewRecorder() 138 139 handler.ServeHTTP(w, req) 140 141 if w.Code != http.StatusUnauthorized { 142 t.Errorf("expected status 401, got %d", w.Code) 143 } 144} 145 146// TestRequireAuth_ExpiredToken tests that expired tokens are rejected 147func TestRequireAuth_ExpiredToken(t *testing.T) { 148 fetcher := &mockJWKSFetcher{} 149 middleware := NewAtProtoAuthMiddleware(fetcher, true) 150 151 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 152 t.Error("handler should not be called for expired token") 153 })) 154 155 // Create expired token 156 claims := jwt.MapClaims{ 157 "sub": "did:plc:test123", 158 "iss": "https://test.pds.local", 159 "scope": "atproto", 160 "exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago 161 "iat": time.Now().Add(-2 * time.Hour).Unix(), 162 } 163 164 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 165 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 166 167 req := httptest.NewRequest("GET", "/test", nil) 168 req.Header.Set("Authorization", "Bearer "+tokenString) 169 w := httptest.NewRecorder() 170 171 handler.ServeHTTP(w, req) 172 173 if w.Code != http.StatusUnauthorized { 174 t.Errorf("expected status 401, got %d", w.Code) 175 } 176} 177 178// TestRequireAuth_MissingDID tests that tokens without DID are rejected 179func TestRequireAuth_MissingDID(t *testing.T) { 180 fetcher := &mockJWKSFetcher{} 181 middleware := NewAtProtoAuthMiddleware(fetcher, true) 182 183 handler := middleware.RequireAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 184 t.Error("handler should not be called") 185 })) 186 187 // Create token without sub claim 188 claims := jwt.MapClaims{ 189 // "sub" missing 190 "iss": "https://test.pds.local", 191 "scope": "atproto", 192 "exp": 
time.Now().Add(1 * time.Hour).Unix(), 193 "iat": time.Now().Unix(), 194 } 195 196 token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) 197 tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) 198 199 req := httptest.NewRequest("GET", "/test", nil) 200 req.Header.Set("Authorization", "Bearer "+tokenString) 201 w := httptest.NewRecorder() 202 203 handler.ServeHTTP(w, req) 204 205 if w.Code != http.StatusUnauthorized { 206 t.Errorf("expected status 401, got %d", w.Code) 207 } 208} 209 210// TestOptionalAuth_WithToken tests that OptionalAuth accepts valid tokens 211func TestOptionalAuth_WithToken(t *testing.T) { 212 fetcher := &mockJWKSFetcher{} 213 middleware := NewAtProtoAuthMiddleware(fetcher, true) 214 215 handlerCalled := false 216 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 217 handlerCalled = true 218 219 // Verify DID was extracted 220 did := GetUserDID(r) 221 if did != "did:plc:test123" { 222 t.Errorf("expected DID 'did:plc:test123', got %s", did) 223 } 224 225 w.WriteHeader(http.StatusOK) 226 })) 227 228 token := createTestToken("did:plc:test123") 229 req := httptest.NewRequest("GET", "/test", nil) 230 req.Header.Set("Authorization", "Bearer "+token) 231 w := httptest.NewRecorder() 232 233 handler.ServeHTTP(w, req) 234 235 if !handlerCalled { 236 t.Error("handler was not called") 237 } 238 239 if w.Code != http.StatusOK { 240 t.Errorf("expected status 200, got %d", w.Code) 241 } 242} 243 244// TestOptionalAuth_WithoutToken tests that OptionalAuth allows requests without tokens 245func TestOptionalAuth_WithoutToken(t *testing.T) { 246 fetcher := &mockJWKSFetcher{} 247 middleware := NewAtProtoAuthMiddleware(fetcher, true) 248 249 handlerCalled := false 250 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 251 handlerCalled = true 252 253 // Verify no DID is set 254 did := GetUserDID(r) 255 if did != "" { 256 t.Errorf("expected empty 
DID, got %s", did) 257 } 258 259 w.WriteHeader(http.StatusOK) 260 })) 261 262 req := httptest.NewRequest("GET", "/test", nil) 263 // No Authorization header 264 w := httptest.NewRecorder() 265 266 handler.ServeHTTP(w, req) 267 268 if !handlerCalled { 269 t.Error("handler was not called") 270 } 271 272 if w.Code != http.StatusOK { 273 t.Errorf("expected status 200, got %d", w.Code) 274 } 275} 276 277// TestOptionalAuth_InvalidToken tests that OptionalAuth continues without auth on invalid token 278func TestOptionalAuth_InvalidToken(t *testing.T) { 279 fetcher := &mockJWKSFetcher{} 280 middleware := NewAtProtoAuthMiddleware(fetcher, true) 281 282 handlerCalled := false 283 handler := middleware.OptionalAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 284 handlerCalled = true 285 286 // Verify no DID is set (invalid token ignored) 287 did := GetUserDID(r) 288 if did != "" { 289 t.Errorf("expected empty DID for invalid token, got %s", did) 290 } 291 292 w.WriteHeader(http.StatusOK) 293 })) 294 295 req := httptest.NewRequest("GET", "/test", nil) 296 req.Header.Set("Authorization", "Bearer not-a-valid-jwt") 297 w := httptest.NewRecorder() 298 299 handler.ServeHTTP(w, req) 300 301 if !handlerCalled { 302 t.Error("handler was not called") 303 } 304 305 if w.Code != http.StatusOK { 306 t.Errorf("expected status 200, got %d", w.Code) 307 } 308} 309 310// TestGetUserDID_NotAuthenticated tests that GetUserDID returns empty string when not authenticated 311func TestGetUserDID_NotAuthenticated(t *testing.T) { 312 req := httptest.NewRequest("GET", "/test", nil) 313 did := GetUserDID(req) 314 315 if did != "" { 316 t.Errorf("expected empty string, got %s", did) 317 } 318} 319 320// TestGetJWTClaims_NotAuthenticated tests that GetJWTClaims returns nil when not authenticated 321func TestGetJWTClaims_NotAuthenticated(t *testing.T) { 322 req := httptest.NewRequest("GET", "/test", nil) 323 claims := GetJWTClaims(req) 324 325 if claims != nil { 326 t.Errorf("expected 
nil claims, got %+v", claims) 327 } 328}