A community-based topic aggregation platform built on atproto.
1package auth
2
import (
	"context"
	"strings"
	"testing"
	"time"

	"github.com/golang-jwt/jwt/v5"
)
10
11func TestParseJWT(t *testing.T) {
12 // Create a test JWT token
13 claims := &Claims{
14 RegisteredClaims: jwt.RegisteredClaims{
15 Subject: "did:plc:test123",
16 Issuer: "https://test-pds.example.com",
17 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
18 IssuedAt: jwt.NewNumericDate(time.Now()),
19 },
20 Scope: "atproto transition:generic",
21 }
22
23 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
24 tokenString, err := token.SignedString([]byte("test-secret"))
25 if err != nil {
26 t.Fatalf("Failed to create test token: %v", err)
27 }
28
29 // Test parsing
30 parsedClaims, err := ParseJWT(tokenString)
31 if err != nil {
32 t.Fatalf("ParseJWT failed: %v", err)
33 }
34
35 if parsedClaims.Subject != "did:plc:test123" {
36 t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
37 }
38
39 if parsedClaims.Issuer != "https://test-pds.example.com" {
40 t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer)
41 }
42
43 if parsedClaims.Scope != "atproto transition:generic" {
44 t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope)
45 }
46}
47
48func TestParseJWT_MissingSubject(t *testing.T) {
49 // Create a token without subject
50 claims := &Claims{
51 RegisteredClaims: jwt.RegisteredClaims{
52 Issuer: "https://test-pds.example.com",
53 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
54 },
55 }
56
57 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
58 tokenString, err := token.SignedString([]byte("test-secret"))
59 if err != nil {
60 t.Fatalf("Failed to create test token: %v", err)
61 }
62
63 // Test parsing - should fail
64 _, err = ParseJWT(tokenString)
65 if err == nil {
66 t.Error("Expected error for missing subject, got nil")
67 }
68}
69
70func TestParseJWT_MissingIssuer(t *testing.T) {
71 // Create a token without issuer
72 claims := &Claims{
73 RegisteredClaims: jwt.RegisteredClaims{
74 Subject: "did:plc:test123",
75 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
76 },
77 }
78
79 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
80 tokenString, err := token.SignedString([]byte("test-secret"))
81 if err != nil {
82 t.Fatalf("Failed to create test token: %v", err)
83 }
84
85 // Test parsing - should fail
86 _, err = ParseJWT(tokenString)
87 if err == nil {
88 t.Error("Expected error for missing issuer, got nil")
89 }
90}
91
92func TestParseJWT_WithBearerPrefix(t *testing.T) {
93 // Create a test JWT token
94 claims := &Claims{
95 RegisteredClaims: jwt.RegisteredClaims{
96 Subject: "did:plc:test123",
97 Issuer: "https://test-pds.example.com",
98 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
99 },
100 }
101
102 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
103 tokenString, err := token.SignedString([]byte("test-secret"))
104 if err != nil {
105 t.Fatalf("Failed to create test token: %v", err)
106 }
107
108 // Test parsing with Bearer prefix
109 parsedClaims, err := ParseJWT("Bearer " + tokenString)
110 if err != nil {
111 t.Fatalf("ParseJWT failed with Bearer prefix: %v", err)
112 }
113
114 if parsedClaims.Subject != "did:plc:test123" {
115 t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject)
116 }
117}
118
119func TestValidateClaims_Expired(t *testing.T) {
120 claims := &Claims{
121 RegisteredClaims: jwt.RegisteredClaims{
122 Subject: "did:plc:test123",
123 Issuer: "https://test-pds.example.com",
124 ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired
125 },
126 }
127
128 err := validateClaims(claims)
129 if err == nil {
130 t.Error("Expected error for expired token, got nil")
131 }
132}
133
134func TestValidateClaims_InvalidDID(t *testing.T) {
135 claims := &Claims{
136 RegisteredClaims: jwt.RegisteredClaims{
137 Subject: "invalid-did-format",
138 Issuer: "https://test-pds.example.com",
139 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
140 },
141 }
142
143 err := validateClaims(claims)
144 if err == nil {
145 t.Error("Expected error for invalid DID format, got nil")
146 }
147}
148
149func TestExtractKeyID(t *testing.T) {
150 // Create a test JWT token with kid in header
151 token := jwt.New(jwt.SigningMethodRS256)
152 token.Header["kid"] = "test-key-id"
153 token.Claims = &Claims{
154 RegisteredClaims: jwt.RegisteredClaims{
155 Subject: "did:plc:test123",
156 Issuer: "https://test-pds.example.com",
157 },
158 }
159
160 // Sign with a dummy RSA key (we just need a valid token structure)
161 tokenString, err := token.SignedString([]byte("dummy"))
162 if err == nil {
163 // If it succeeds (shouldn't with wrong key type, but let's handle it)
164 kid, err := ExtractKeyID(tokenString)
165 if err != nil {
166 t.Logf("ExtractKeyID failed (expected if signing fails): %v", err)
167 } else if kid != "test-key-id" {
168 t.Errorf("Expected kid 'test-key-id', got '%s'", kid)
169 }
170 }
171}
172
173// === HS256 Verification Tests ===
174
// mockJWKSFetcher is a test double for JWKSFetcher. Every FetchPublicKey call
// ignores its arguments and hands back the canned publicKey and err fields.
type mockJWKSFetcher struct {
	publicKey any
	err       error
}

// FetchPublicKey returns the configured key and error unchanged.
func (f *mockJWKSFetcher) FetchPublicKey(_ context.Context, _, _ string) (any, error) {
	return f.publicKey, f.err
}
184
185func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string {
186 t.Helper()
187 claims := &Claims{
188 RegisteredClaims: jwt.RegisteredClaims{
189 Subject: subject,
190 Issuer: issuer,
191 ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
192 IssuedAt: jwt.NewNumericDate(time.Now()),
193 },
194 Scope: "atproto transition:generic",
195 }
196 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
197 tokenString, err := token.SignedString([]byte(secret))
198 if err != nil {
199 t.Fatalf("Failed to create test token: %v", err)
200 }
201 return tokenString
202}
203
204func TestVerifyJWT_HS256_Valid(t *testing.T) {
205 // Setup: Configure environment for HS256 verification
206 secret := "test-jwt-secret-key-12345"
207 issuer := "https://pds.coves.social"
208
209 ResetJWTConfigForTesting()
210 t.Setenv("PDS_JWT_SECRET", secret)
211 t.Setenv("HS256_ISSUERS", issuer)
212 t.Cleanup(ResetJWTConfigForTesting)
213
214 tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour)
215
216 // Verify token
217 claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
218 if err != nil {
219 t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err)
220 }
221
222 if claims.Subject != "did:plc:test123" {
223 t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject)
224 }
225 if claims.Issuer != issuer {
226 t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer)
227 }
228}
229
230func TestVerifyJWT_HS256_WrongSecret(t *testing.T) {
231 // Setup: Configure environment with one secret, sign with another
232 issuer := "https://pds.coves.social"
233
234 ResetJWTConfigForTesting()
235 t.Setenv("PDS_JWT_SECRET", "correct-secret")
236 t.Setenv("HS256_ISSUERS", issuer)
237 t.Cleanup(ResetJWTConfigForTesting)
238
239 // Create token with wrong secret
240 tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour)
241
242 // Verify should fail
243 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
244 if err == nil {
245 t.Error("Expected error for HS256 token with wrong secret, got nil")
246 }
247}
248
249func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) {
250 // Setup: Whitelist issuer but don't configure secret
251 issuer := "https://pds.coves.social"
252
253 ResetJWTConfigForTesting()
254 t.Setenv("PDS_JWT_SECRET", "") // Ensure secret is not set (empty = not configured)
255 t.Setenv("HS256_ISSUERS", issuer)
256 t.Cleanup(ResetJWTConfigForTesting)
257
258 tokenString := createHS256Token(t, "did:plc:test123", issuer, "any-secret", 1*time.Hour)
259
260 // Verify should fail with descriptive error
261 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
262 if err == nil {
263 t.Error("Expected error when PDS_JWT_SECRET not configured, got nil")
264 }
265 if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") {
266 t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err)
267 }
268}
269
270// === Algorithm Confusion Attack Prevention Tests ===
271
272func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) {
273 // SECURITY TEST: This tests the algorithm confusion attack prevention
274 // An attacker tries to use HS256 with an issuer that should use RS256/ES256
275
276 ResetJWTConfigForTesting()
277 t.Setenv("PDS_JWT_SECRET", "some-secret")
278 t.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer
279 t.Cleanup(ResetJWTConfigForTesting)
280
281 // Create HS256 token with non-whitelisted issuer (simulating attack)
282 tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour)
283
284 // Verify should fail because issuer is not in HS256 whitelist
285 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
286 if err == nil {
287 t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer")
288 }
289 if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") {
290 t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err)
291 }
292}
293
294func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) {
295 // SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected
296
297 ResetJWTConfigForTesting()
298 t.Setenv("PDS_JWT_SECRET", "some-secret")
299 t.Setenv("HS256_ISSUERS", "") // Empty whitelist
300 t.Cleanup(ResetJWTConfigForTesting)
301
302 tokenString := createHS256Token(t, "did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour)
303
304 // Verify should fail because no issuers are whitelisted for HS256
305 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
306 if err == nil {
307 t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist")
308 }
309}
310
311func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) {
312 // Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256
313 issuer := "https://pds.coves.social"
314
315 ResetJWTConfigForTesting()
316 t.Setenv("PDS_JWT_SECRET", "test-secret")
317 t.Setenv("HS256_ISSUERS", issuer)
318 t.Cleanup(ResetJWTConfigForTesting)
319
320 // Create RS256-signed token (can't actually sign without RSA key, but we can test the header check)
321 claims := &Claims{
322 RegisteredClaims: jwt.RegisteredClaims{
323 Subject: "did:plc:test123",
324 Issuer: issuer,
325 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
326 },
327 }
328 token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
329 // This will create an invalid signature but valid header structure
330 // The test should fail at algorithm check, not signature verification
331 tokenString, _ := token.SignedString([]byte("dummy-key"))
332
333 if tokenString != "" {
334 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
335 if err == nil {
336 t.Error("Expected error when HS256 issuer receives non-HS256 token")
337 }
338 }
339}
340
341// === ParseJWTHeader Tests ===
342
343func TestParseJWTHeader_Valid(t *testing.T) {
344 tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
345
346 header, err := ParseJWTHeader(tokenString)
347 if err != nil {
348 t.Fatalf("ParseJWTHeader failed: %v", err)
349 }
350
351 if header.Alg != AlgorithmHS256 {
352 t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
353 }
354}
355
356func TestParseJWTHeader_WithBearerPrefix(t *testing.T) {
357 tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
358
359 header, err := ParseJWTHeader("Bearer " + tokenString)
360 if err != nil {
361 t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err)
362 }
363
364 if header.Alg != AlgorithmHS256 {
365 t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
366 }
367}
368
369func TestParseJWTHeader_InvalidFormat(t *testing.T) {
370 testCases := []struct {
371 name string
372 input string
373 }{
374 {"empty string", ""},
375 {"single part", "abc"},
376 {"two parts", "abc.def"},
377 {"too many parts", "a.b.c.d"},
378 }
379
380 for _, tc := range testCases {
381 t.Run(tc.name, func(t *testing.T) {
382 _, err := ParseJWTHeader(tc.input)
383 if err == nil {
384 t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input)
385 }
386 })
387 }
388}
389
// === isHS256IssuerWhitelisted Tests ===
391
392func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) {
393 ResetJWTConfigForTesting()
394 t.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com")
395 t.Cleanup(ResetJWTConfigForTesting)
396
397 if !isHS256IssuerWhitelisted("https://pds1.example.com") {
398 t.Error("Expected pds1 to be whitelisted")
399 }
400 if !isHS256IssuerWhitelisted("https://pds2.example.com") {
401 t.Error("Expected pds2 to be whitelisted")
402 }
403}
404
405func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) {
406 ResetJWTConfigForTesting()
407 t.Setenv("HS256_ISSUERS", "https://pds1.example.com")
408 t.Cleanup(ResetJWTConfigForTesting)
409
410 if isHS256IssuerWhitelisted("https://attacker.example.com") {
411 t.Error("Expected non-whitelisted issuer to return false")
412 }
413}
414
415func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) {
416 ResetJWTConfigForTesting()
417 t.Setenv("HS256_ISSUERS", "") // Empty whitelist
418 t.Cleanup(ResetJWTConfigForTesting)
419
420 if isHS256IssuerWhitelisted("https://any.example.com") {
421 t.Error("Expected false when whitelist is empty (safe default)")
422 }
423}
424
425func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) {
426 ResetJWTConfigForTesting()
427 t.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ")
428 t.Cleanup(ResetJWTConfigForTesting)
429
430 if !isHS256IssuerWhitelisted("https://pds1.example.com") {
431 t.Error("Expected whitespace-trimmed issuer to be whitelisted")
432 }
433}
434
435// === shouldUseHS256 Tests (kid-based logic) ===
436
437func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) {
438 // Tokens with kid should NEVER use HS256, regardless of issuer whitelist
439 ResetJWTConfigForTesting()
440 t.Setenv("HS256_ISSUERS", "https://whitelisted.example.com")
441 t.Cleanup(ResetJWTConfigForTesting)
442
443 header := &JWTHeader{
444 Alg: AlgorithmHS256,
445 Kid: "some-key-id", // Has kid
446 }
447
448 // Even whitelisted issuer should not use HS256 if token has kid
449 if shouldUseHS256(header, "https://whitelisted.example.com") {
450 t.Error("Tokens with kid should never use HS256 (supports federation)")
451 }
452}
453
454func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) {
455 ResetJWTConfigForTesting()
456 t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
457 t.Cleanup(ResetJWTConfigForTesting)
458
459 header := &JWTHeader{
460 Alg: AlgorithmHS256,
461 Kid: "", // No kid
462 }
463
464 if !shouldUseHS256(header, "https://my-pds.example.com") {
465 t.Error("Token without kid from whitelisted issuer should use HS256")
466 }
467}
468
469func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) {
470 ResetJWTConfigForTesting()
471 t.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
472 t.Cleanup(ResetJWTConfigForTesting)
473
474 header := &JWTHeader{
475 Alg: AlgorithmHS256,
476 Kid: "", // No kid
477 }
478
479 if shouldUseHS256(header, "https://external-pds.example.com") {
480 t.Error("Token without kid from non-whitelisted issuer should NOT use HS256")
481 }
482}
483
// contains reports whether substr occurs anywhere in s. It delegates to
// strings.Contains. The empty substring is contained in every string,
// including "".
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// containsHelper reports whether substr occurs anywhere in s.
//
// Deprecated: use strings.Contains directly. Kept only so that any other test
// file in this package still referencing it continues to compile.
func containsHelper(s, substr string) bool {
	return strings.Contains(s, substr)
}