A community-based topic aggregation platform built on atproto
1package auth 2 3import ( 4 "context" 5 "testing" 6 "time" 7 8 "github.com/golang-jwt/jwt/v5" 9) 10 11func TestParseJWT(t *testing.T) { 12 // Create a test JWT token 13 claims := &Claims{ 14 RegisteredClaims: jwt.RegisteredClaims{ 15 Subject: "did:plc:test123", 16 Issuer: "https://test-pds.example.com", 17 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 18 IssuedAt: jwt.NewNumericDate(time.Now()), 19 }, 20 Scope: "atproto transition:generic", 21 } 22 23 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 24 tokenString, err := token.SignedString([]byte("test-secret")) 25 if err != nil { 26 t.Fatalf("Failed to create test token: %v", err) 27 } 28 29 // Test parsing 30 parsedClaims, err := ParseJWT(tokenString) 31 if err != nil { 32 t.Fatalf("ParseJWT failed: %v", err) 33 } 34 35 if parsedClaims.Subject != "did:plc:test123" { 36 t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject) 37 } 38 39 if parsedClaims.Issuer != "https://test-pds.example.com" { 40 t.Errorf("Expected issuer 'https://test-pds.example.com', got '%s'", parsedClaims.Issuer) 41 } 42 43 if parsedClaims.Scope != "atproto transition:generic" { 44 t.Errorf("Expected scope 'atproto transition:generic', got '%s'", parsedClaims.Scope) 45 } 46} 47 48func TestParseJWT_MissingSubject(t *testing.T) { 49 // Create a token without subject 50 claims := &Claims{ 51 RegisteredClaims: jwt.RegisteredClaims{ 52 Issuer: "https://test-pds.example.com", 53 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 54 }, 55 } 56 57 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 58 tokenString, err := token.SignedString([]byte("test-secret")) 59 if err != nil { 60 t.Fatalf("Failed to create test token: %v", err) 61 } 62 63 // Test parsing - should fail 64 _, err = ParseJWT(tokenString) 65 if err == nil { 66 t.Error("Expected error for missing subject, got nil") 67 } 68} 69 70func TestParseJWT_MissingIssuer(t *testing.T) { 71 // Create a token without issuer 72 
claims := &Claims{ 73 RegisteredClaims: jwt.RegisteredClaims{ 74 Subject: "did:plc:test123", 75 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 76 }, 77 } 78 79 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 80 tokenString, err := token.SignedString([]byte("test-secret")) 81 if err != nil { 82 t.Fatalf("Failed to create test token: %v", err) 83 } 84 85 // Test parsing - should fail 86 _, err = ParseJWT(tokenString) 87 if err == nil { 88 t.Error("Expected error for missing issuer, got nil") 89 } 90} 91 92func TestParseJWT_WithBearerPrefix(t *testing.T) { 93 // Create a test JWT token 94 claims := &Claims{ 95 RegisteredClaims: jwt.RegisteredClaims{ 96 Subject: "did:plc:test123", 97 Issuer: "https://test-pds.example.com", 98 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 99 }, 100 } 101 102 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 103 tokenString, err := token.SignedString([]byte("test-secret")) 104 if err != nil { 105 t.Fatalf("Failed to create test token: %v", err) 106 } 107 108 // Test parsing with Bearer prefix 109 parsedClaims, err := ParseJWT("Bearer " + tokenString) 110 if err != nil { 111 t.Fatalf("ParseJWT failed with Bearer prefix: %v", err) 112 } 113 114 if parsedClaims.Subject != "did:plc:test123" { 115 t.Errorf("Expected subject 'did:plc:test123', got '%s'", parsedClaims.Subject) 116 } 117} 118 119func TestValidateClaims_Expired(t *testing.T) { 120 claims := &Claims{ 121 RegisteredClaims: jwt.RegisteredClaims{ 122 Subject: "did:plc:test123", 123 Issuer: "https://test-pds.example.com", 124 ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), // Expired 125 }, 126 } 127 128 err := validateClaims(claims) 129 if err == nil { 130 t.Error("Expected error for expired token, got nil") 131 } 132} 133 134func TestValidateClaims_InvalidDID(t *testing.T) { 135 claims := &Claims{ 136 RegisteredClaims: jwt.RegisteredClaims{ 137 Subject: "invalid-did-format", 138 Issuer: "https://test-pds.example.com", 
139 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 140 }, 141 } 142 143 err := validateClaims(claims) 144 if err == nil { 145 t.Error("Expected error for invalid DID format, got nil") 146 } 147} 148 149func TestExtractKeyID(t *testing.T) { 150 // Create a test JWT token with kid in header 151 token := jwt.New(jwt.SigningMethodRS256) 152 token.Header["kid"] = "test-key-id" 153 token.Claims = &Claims{ 154 RegisteredClaims: jwt.RegisteredClaims{ 155 Subject: "did:plc:test123", 156 Issuer: "https://test-pds.example.com", 157 }, 158 } 159 160 // Sign with a dummy RSA key (we just need a valid token structure) 161 tokenString, err := token.SignedString([]byte("dummy")) 162 if err == nil { 163 // If it succeeds (shouldn't with wrong key type, but let's handle it) 164 kid, err := ExtractKeyID(tokenString) 165 if err != nil { 166 t.Logf("ExtractKeyID failed (expected if signing fails): %v", err) 167 } else if kid != "test-key-id" { 168 t.Errorf("Expected kid 'test-key-id', got '%s'", kid) 169 } 170 } 171} 172 173// === HS256 Verification Tests === 174 175// mockJWKSFetcher is a mock implementation of JWKSFetcher for testing 176type mockJWKSFetcher struct { 177 publicKey interface{} 178 err error 179} 180 181func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) { 182 return m.publicKey, m.err 183} 184 185func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string { 186 t.Helper() 187 claims := &Claims{ 188 RegisteredClaims: jwt.RegisteredClaims{ 189 Subject: subject, 190 Issuer: issuer, 191 ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)), 192 IssuedAt: jwt.NewNumericDate(time.Now()), 193 }, 194 Scope: "atproto transition:generic", 195 } 196 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 197 tokenString, err := token.SignedString([]byte(secret)) 198 if err != nil { 199 t.Fatalf("Failed to create test token: %v", err) 200 } 201 return tokenString 202} 203 
204func TestVerifyJWT_HS256_Valid(t *testing.T) { 205 // Setup: Configure environment for HS256 verification 206 secret := "test-jwt-secret-key-12345" 207 issuer := "https://pds.coves.social" 208 209 ResetJWTConfigForTesting() 210 t.Setenv("PDS_JWT_SECRET", secret) 211 t.Setenv("HS256_ISSUERS", issuer) 212 t.Cleanup(ResetJWTConfigForTesting) 213 214 tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour) 215 216 // Verify token 217 claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{}) 218 if err != nil { 219 t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err) 220 } 221 222 if claims.Subject != "did:plc:test123" { 223 t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject) 224 } 225 if claims.Issuer != issuer { 226 t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer) 227 } 228} 229 230func TestVerifyJWT_HS256_WrongSecret(t *testing.T) { 231 // Setup: Configure environment with one secret, sign with another 232 issuer := "https://pds.coves.social" 233 234 ResetJWTConfigForTesting() 235 t.Setenv("PDS_JWT_SECRET", "correct-secret") 236 t.Setenv("HS256_ISSUERS", issuer) 237 t.Cleanup(ResetJWTConfigForTesting) 238 239 // Create token with wrong secret 240 tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour) 241 242 // Verify should fail 243 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{}) 244 if err == nil { 245 t.Error("Expected error for HS256 token with wrong secret, got nil") 246 } 247} 248 249func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) { 250 // Setup: Whitelist issuer but don't configure secret 251 issuer := "https://pds.coves.social" 252 253 ResetJWTConfigForTesting() 254 t.Setenv("PDS_JWT_SECRET", "") // Ensure secret is not set (empty = not configured) 255 t.Setenv("HS256_ISSUERS", issuer) 256 t.Cleanup(ResetJWTConfigForTesting) 257 258 tokenString := createHS256Token(t, "did:plc:test123", 
issuer, "any-secret", 1*time.Hour) 259 260 // Verify should fail with descriptive error 261 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{}) 262 if err == nil { 263 t.Error("Expected error when PDS_JWT_SECRET not configured, got nil") 264 } 265 if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") { 266 t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err) 267 } 268} 269 270// === Algorithm Confusion Attack Prevention Tests === 271 272func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) { 273 // SECURITY TEST: This tests the algorithm confusion attack prevention 274 // An attacker tries to use HS256 with an issuer that should use RS256/ES256 275 276 ResetJWTConfigForTesting() 277 t.Setenv("PDS_JWT_SECRET", "some-secret") 278 t.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer 279 t.Cleanup(ResetJWTConfigForTesting) 280 281 // Create HS256 token with non-whitelisted issuer (simulating attack) 282 tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour) 283 284 // Verify should fail because issuer is not in HS256 whitelist 285 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{}) 286 if err == nil { 287 t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer") 288 } 289 if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") { 290 t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err) 291 } 292} 293 294func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) { 295 // SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected 296 297 ResetJWTConfigForTesting() 298 t.Setenv("PDS_JWT_SECRET", "some-secret") 299 t.Setenv("HS256_ISSUERS", "") // Empty whitelist 300 t.Cleanup(ResetJWTConfigForTesting) 301 302 tokenString := createHS256Token(t, 
"did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour) 303 304 // Verify should fail because no issuers are whitelisted for HS256 305 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{}) 306 if err == nil { 307 t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist") 308 } 309} 310 311func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) { 312 // Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256 313 issuer := "https://pds.coves.social" 314 315 ResetJWTConfigForTesting() 316 t.Setenv("PDS_JWT_SECRET", "test-secret") 317 t.Setenv("HS256_ISSUERS", issuer) 318 t.Cleanup(ResetJWTConfigForTesting) 319 320 // Create RS256-signed token (can't actually sign without RSA key, but we can test the header check) 321 claims := &Claims{ 322 RegisteredClaims: jwt.RegisteredClaims{ 323 Subject: "did:plc:test123", 324 Issuer: issuer, 325 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), 326 }, 327 } 328 token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) 329 // This will create an invalid signature but valid header structure 330 // The test should fail at algorithm check, not signature verification 331 tokenString, _ := token.SignedString([]byte("dummy-key")) 332 333 if tokenString != "" { 334 _, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{}) 335 if err == nil { 336 t.Error("Expected error when HS256 issuer receives non-HS256 token") 337 } 338 } 339} 340 341// === ParseJWTHeader Tests === 342 343func TestParseJWTHeader_Valid(t *testing.T) { 344 tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour) 345 346 header, err := ParseJWTHeader(tokenString) 347 if err != nil { 348 t.Fatalf("ParseJWTHeader failed: %v", err) 349 } 350 351 if header.Alg != AlgorithmHS256 { 352 t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg) 353 } 354} 355 356func 
TestParseJWTHeader_WithBearerPrefix(t *testing.T) { 357 tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour) 358 359 header, err := ParseJWTHeader("Bearer " + tokenString) 360 if err != nil { 361 t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err) 362 } 363 364 if header.Alg != AlgorithmHS256 { 365 t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg) 366 } 367} 368 369func TestParseJWTHeader_InvalidFormat(t *testing.T) { 370 testCases := []struct { 371 name string 372 input string 373 }{ 374 {"empty string", ""}, 375 {"single part", "abc"}, 376 {"two parts", "abc.def"}, 377 {"too many parts", "a.b.c.d"}, 378 } 379 380 for _, tc := range testCases { 381 t.Run(tc.name, func(t *testing.T) { 382 _, err := ParseJWTHeader(tc.input) 383 if err == nil { 384 t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input) 385 } 386 }) 387 } 388} 389 390// === shouldUseHS256 and isHS256IssuerWhitelisted Tests === 391 392func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) { 393 ResetJWTConfigForTesting() 394 t.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com") 395 t.Cleanup(ResetJWTConfigForTesting) 396 397 if !isHS256IssuerWhitelisted("https://pds1.example.com") { 398 t.Error("Expected pds1 to be whitelisted") 399 } 400 if !isHS256IssuerWhitelisted("https://pds2.example.com") { 401 t.Error("Expected pds2 to be whitelisted") 402 } 403} 404 405func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) { 406 ResetJWTConfigForTesting() 407 t.Setenv("HS256_ISSUERS", "https://pds1.example.com") 408 t.Cleanup(ResetJWTConfigForTesting) 409 410 if isHS256IssuerWhitelisted("https://attacker.example.com") { 411 t.Error("Expected non-whitelisted issuer to return false") 412 } 413} 414 415func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) { 416 ResetJWTConfigForTesting() 417 t.Setenv("HS256_ISSUERS", "") // Empty whitelist 418 
t.Cleanup(ResetJWTConfigForTesting) 419 420 if isHS256IssuerWhitelisted("https://any.example.com") { 421 t.Error("Expected false when whitelist is empty (safe default)") 422 } 423} 424 425func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) { 426 ResetJWTConfigForTesting() 427 t.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ") 428 t.Cleanup(ResetJWTConfigForTesting) 429 430 if !isHS256IssuerWhitelisted("https://pds1.example.com") { 431 t.Error("Expected whitespace-trimmed issuer to be whitelisted") 432 } 433} 434 435// === shouldUseHS256 Tests (kid-based logic) === 436 437func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) { 438 // Tokens with kid should NEVER use HS256, regardless of issuer whitelist 439 ResetJWTConfigForTesting() 440 t.Setenv("HS256_ISSUERS", "https://whitelisted.example.com") 441 t.Cleanup(ResetJWTConfigForTesting) 442 443 header := &JWTHeader{ 444 Alg: AlgorithmHS256, 445 Kid: "some-key-id", // Has kid 446 } 447 448 // Even whitelisted issuer should not use HS256 if token has kid 449 if shouldUseHS256(header, "https://whitelisted.example.com") { 450 t.Error("Tokens with kid should never use HS256 (supports federation)") 451 } 452} 453 454func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) { 455 ResetJWTConfigForTesting() 456 t.Setenv("HS256_ISSUERS", "https://my-pds.example.com") 457 t.Cleanup(ResetJWTConfigForTesting) 458 459 header := &JWTHeader{ 460 Alg: AlgorithmHS256, 461 Kid: "", // No kid 462 } 463 464 if !shouldUseHS256(header, "https://my-pds.example.com") { 465 t.Error("Token without kid from whitelisted issuer should use HS256") 466 } 467} 468 469func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) { 470 ResetJWTConfigForTesting() 471 t.Setenv("HS256_ISSUERS", "https://my-pds.example.com") 472 t.Cleanup(ResetJWTConfigForTesting) 473 474 header := &JWTHeader{ 475 Alg: AlgorithmHS256, 476 Kid: "", // No kid 477 } 478 479 if shouldUseHS256(header, 
"https://external-pds.example.com") { 480 t.Error("Token without kid from non-whitelisted issuer should NOT use HS256") 481 } 482} 483 484// Helper function 485func contains(s, substr string) bool { 486 return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) 487} 488 489func containsHelper(s, substr string) bool { 490 for i := 0; i <= len(s)-len(substr); i++ { 491 if s[i:i+len(substr)] == substr { 492 return true 493 } 494 } 495 return false 496}