A community based topic aggregation platform built on atproto

perf(auth): cache JWT config at startup

Cache HS256_ISSUERS, PDS_JWT_SECRET, and IS_DEV_ENV at startup instead
of reading environment variables on every token verification request.

- Add jwtConfig struct with sync.Once initialization
- Use map[string]struct{} for O(1) issuer whitelist lookups
- Add InitJWTConfig() for explicit startup initialization
- Add ResetJWTConfigForTesting() for test isolation
- Update main.go to call InitJWTConfig() at startup

Before: 2-3 os.Getenv() calls + O(n) string iteration per request
After: Single pointer dereference + O(1) map lookup per request

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Changed files
+600 -32
cmd
server
internal
atproto
+24 -3
cmd/server/main.go
···
commentsAPI "Coves/internal/api/handlers/comments"
postgresRepo "Coves/internal/db/postgres"
+
+
indigoIdentity "github.com/bluesky-social/indigo/atproto/identity"
)
func main() {
···
log.Println(" Set AUTH_SKIP_VERIFY=false for production")
}
-
jwksCacheTTL := 1 * time.Hour // Cache public keys for 1 hour
+
// Initialize Indigo directory for DID resolution (used by auth)
+
plcURL := os.Getenv("PLC_DIRECTORY_URL")
+
if plcURL == "" {
+
plcURL = "https://plc.directory"
+
}
+
indigoDir := &indigoIdentity.BaseDirectory{
+
PLCURL: plcURL,
+
HTTPClient: http.Client{Timeout: 10 * time.Second},
+
}
+
+
// Initialize JWT config early to cache HS256_ISSUERS and PDS_JWT_SECRET
+
// This avoids reading env vars on every request
+
auth.InitJWTConfig()
+
+
// Create combined key fetcher for both DID and URL issuers
+
// - DID issuers (did:plc:, did:web:) → resolved via DID document keys (ES256)
+
// - URL issuers → JWKS endpoint (fallback for legacy tokens)
+
jwksCacheTTL := 1 * time.Hour
jwksFetcher := auth.NewCachedJWKSFetcher(jwksCacheTTL)
-
authMiddleware := middleware.NewAtProtoAuthMiddleware(jwksFetcher, skipVerify)
-
log.Println("✅ atProto auth middleware initialized")
+
keyFetcher := auth.NewCombinedKeyFetcher(indigoDir, jwksFetcher)
+
+
authMiddleware := middleware.NewAtProtoAuthMiddleware(keyFetcher, skipVerify)
+
log.Println("✅ atProto auth middleware initialized (DID + JWKS key resolution)")
// Initialize repositories and services
userRepo := postgresRepo.NewUserRepository(db)
+209 -29
internal/atproto/auth/jwt.go
···
"net/url"
"os"
"strings"
+
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
)
+
// jwtConfig holds cached JWT configuration to avoid reading env vars on every request
+
type jwtConfig struct {
+
hs256Issuers map[string]struct{} // Set of whitelisted HS256 issuers
+
pdsJWTSecret []byte // Cached PDS_JWT_SECRET
+
isDevEnv bool // Cached IS_DEV_ENV
+
}
+
+
var (
+
cachedConfig *jwtConfig
+
configOnce sync.Once
+
)
+
+
// InitJWTConfig initializes the JWT configuration from environment variables.
+
// This should be called once at startup. If not called explicitly, it will be
+
// initialized lazily on first use.
+
func InitJWTConfig() {
+
configOnce.Do(func() {
+
cachedConfig = &jwtConfig{
+
hs256Issuers: make(map[string]struct{}),
+
isDevEnv: os.Getenv("IS_DEV_ENV") == "true",
+
}
+
+
// Parse HS256_ISSUERS into a set for O(1) lookup
+
if issuers := os.Getenv("HS256_ISSUERS"); issuers != "" {
+
for _, issuer := range strings.Split(issuers, ",") {
+
issuer = strings.TrimSpace(issuer)
+
if issuer != "" {
+
cachedConfig.hs256Issuers[issuer] = struct{}{}
+
}
+
}
+
}
+
+
// Cache PDS_JWT_SECRET
+
if secret := os.Getenv("PDS_JWT_SECRET"); secret != "" {
+
cachedConfig.pdsJWTSecret = []byte(secret)
+
}
+
})
+
}
+
+
// getConfig returns the cached config, initializing if needed
+
func getConfig() *jwtConfig {
+
InitJWTConfig()
+
return cachedConfig
+
}
+
+
// ResetJWTConfigForTesting resets the cached config to allow re-initialization.
+
// This should ONLY be used in tests.
+
func ResetJWTConfigForTesting() {
+
cachedConfig = nil
+
configOnce = sync.Once{}
+
}
+
+
// Algorithm constants for JWT signing methods
+
const (
+
AlgorithmHS256 = "HS256"
+
AlgorithmRS256 = "RS256"
+
AlgorithmES256 = "ES256"
+
)
+
+
// JWTHeader represents the parsed JWT header
+
type JWTHeader struct {
+
Alg string `json:"alg"`
+
Kid string `json:"kid"`
+
Typ string `json:"typ,omitempty"`
+
}
+
// Claims represents the standard JWT claims we care about
type Claims struct {
jwt.RegisteredClaims
Scope string `json:"scope,omitempty"`
}
+
// stripBearerPrefix removes the "Bearer " prefix from a token string
+
func stripBearerPrefix(tokenString string) string {
+
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
+
return strings.TrimSpace(tokenString)
+
}
+
+
// ParseJWTHeader extracts and parses the JWT header from a token string
+
// This is a reusable function for getting algorithm and key ID information
+
func ParseJWTHeader(tokenString string) (*JWTHeader, error) {
+
tokenString = stripBearerPrefix(tokenString)
+
+
parts := strings.Split(tokenString, ".")
+
if len(parts) != 3 {
+
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
+
}
+
+
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
+
if err != nil {
+
return nil, fmt.Errorf("failed to decode JWT header: %w", err)
+
}
+
+
var header JWTHeader
+
if err := json.Unmarshal(headerBytes, &header); err != nil {
+
return nil, fmt.Errorf("failed to parse JWT header: %w", err)
+
}
+
+
return &header, nil
+
}
+
+
// shouldUseHS256 determines if a token should use HS256 verification
+
// This prevents algorithm confusion attacks by using multiple signals:
+
// 1. If the token has a `kid` (key ID), it MUST use asymmetric verification
+
// 2. If no `kid`, only allow HS256 from whitelisted issuers (your own PDS)
+
//
+
// This approach supports open federation because:
+
// - External PDSes publish keys via JWKS and include `kid` in their tokens
+
// - Only your own PDS (which shares PDS_JWT_SECRET) uses HS256 without `kid`
+
func shouldUseHS256(header *JWTHeader, issuer string) bool {
+
// If token has a key ID, it MUST use asymmetric verification
+
// This is the primary defense against algorithm confusion attacks
+
if header.Kid != "" {
+
return false
+
}
+
+
// No kid - check if issuer is whitelisted for HS256
+
// This should only include your own PDS URL(s)
+
return isHS256IssuerWhitelisted(issuer)
+
}
+
+
// isHS256IssuerWhitelisted checks if the issuer is in the HS256 whitelist
+
// Only your own PDS should be in this list - external PDSes should use JWKS
+
func isHS256IssuerWhitelisted(issuer string) bool {
+
cfg := getConfig()
+
_, whitelisted := cfg.hs256Issuers[issuer]
+
return whitelisted
+
}
+
// ParseJWT parses a JWT token without verification (Phase 1)
// Returns the claims if the token is valid JSON and has required fields
func ParseJWT(tokenString string) (*Claims, error) {
// Remove "Bearer " prefix if present
-
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
-
tokenString = strings.TrimSpace(tokenString)
+
tokenString = stripBearerPrefix(tokenString)
// Parse without verification first to extract claims
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
···
// VerifyJWT verifies a JWT token's signature and claims (Phase 2)
// Fetches the public key from the issuer's JWKS endpoint and validates the signature
+
// For HS256 tokens from whitelisted issuers, uses the shared PDS_JWT_SECRET
+
//
+
// SECURITY: Algorithm is determined by the issuer whitelist, NOT the token header,
+
// to prevent algorithm confusion attacks where an attacker could re-sign a token
+
// with HS256 using a public key as the secret.
func VerifyJWT(ctx context.Context, tokenString string, keyFetcher JWKSFetcher) (*Claims, error) {
-
// First parse to get the issuer
+
// Strip Bearer prefix once at the start
+
tokenString = stripBearerPrefix(tokenString)
+
+
// First parse to get the issuer (needed to determine expected algorithm)
claims, err := ParseJWT(tokenString)
if err != nil {
return nil, err
}
-
// Fetch the public key from the issuer
-
publicKey, err := keyFetcher.FetchPublicKey(ctx, claims.Issuer, tokenString)
+
// Parse header to get the claimed algorithm (for validation)
+
header, err := ParseJWTHeader(tokenString)
+
if err != nil {
+
return nil, err
+
}
+
+
// SECURITY: Determine verification method based on token characteristics
+
// 1. Tokens with `kid` MUST use asymmetric verification (supports federation)
+
// 2. Tokens without `kid` can use HS256 only from whitelisted issuers (your own PDS)
+
useHS256 := shouldUseHS256(header, claims.Issuer)
+
+
if useHS256 {
+
// Verify token actually claims to use HS256
+
if header.Alg != AlgorithmHS256 {
+
return nil, fmt.Errorf("expected HS256 for issuer %s but token uses %s", claims.Issuer, header.Alg)
+
}
+
return verifyHS256Token(tokenString)
+
}
+
+
// Token must use asymmetric verification
+
// Reject HS256 tokens that don't meet the criteria above
+
if header.Alg == AlgorithmHS256 {
+
if header.Kid != "" {
+
return nil, fmt.Errorf("HS256 tokens with kid must use asymmetric verification")
+
}
+
return nil, fmt.Errorf("HS256 not allowed for issuer %s (not in HS256_ISSUERS whitelist)", claims.Issuer)
+
}
+
+
// For RSA/ECDSA, fetch public key from JWKS and verify
+
return verifyAsymmetricToken(ctx, tokenString, claims.Issuer, keyFetcher)
+
}
+
+
// verifyHS256Token verifies a JWT using HMAC-SHA256 with the shared secret
+
func verifyHS256Token(tokenString string) (*Claims, error) {
+
cfg := getConfig()
+
if len(cfg.pdsJWTSecret) == 0 {
+
return nil, fmt.Errorf("HS256 verification failed: PDS_JWT_SECRET not configured")
+
}
+
+
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
+
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+
}
+
return cfg.pdsJWTSecret, nil
+
})
+
if err != nil {
+
return nil, fmt.Errorf("HS256 verification failed: %w", err)
+
}
+
+
if !token.Valid {
+
return nil, fmt.Errorf("HS256 verification failed: token signature invalid")
+
}
+
+
verifiedClaims, ok := token.Claims.(*Claims)
+
if !ok {
+
return nil, fmt.Errorf("HS256 verification failed: invalid claims type")
+
}
+
+
if err := validateClaims(verifiedClaims); err != nil {
+
return nil, err
+
}
+
+
return verifiedClaims, nil
+
}
+
+
// verifyAsymmetricToken verifies a JWT using RSA or ECDSA with a public key from JWKS
+
func verifyAsymmetricToken(ctx context.Context, tokenString, issuer string, keyFetcher JWKSFetcher) (*Claims, error) {
+
publicKey, err := keyFetcher.FetchPublicKey(ctx, issuer, tokenString)
if err != nil {
return nil, fmt.Errorf("failed to fetch public key: %w", err)
}
-
// Now parse and verify with the public key
-
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate signing method - support both RSA and ECDSA (atProto uses ES256 primarily)
switch token.Method.(type) {
···
return publicKey, nil
})
if err != nil {
-
return nil, fmt.Errorf("failed to verify JWT: %w", err)
+
return nil, fmt.Errorf("asymmetric verification failed: %w", err)
}
if !token.Valid {
-
return nil, fmt.Errorf("token is invalid")
+
return nil, fmt.Errorf("asymmetric verification failed: token signature invalid")
}
verifiedClaims, ok := token.Claims.(*Claims)
if !ok {
-
return nil, fmt.Errorf("invalid claims type after verification")
+
return nil, fmt.Errorf("asymmetric verification failed: invalid claims type")
}
-
// Additional validation
if err := validateClaims(verifiedClaims); err != nil {
return nil, err
}
···
}
// In production, reject HTTP issuers (only for non-dev environments)
-
// Check IS_DEV_ENV environment variable
-
if isHTTP && os.Getenv("IS_DEV_ENV") != "true" {
+
cfg := getConfig()
+
if isHTTP && !cfg.isDevEnv {
return fmt.Errorf("HTTP issuer not allowed in production, got: %s", claims.Issuer)
}
···
// ExtractKeyID extracts the key ID from a JWT token header
func ExtractKeyID(tokenString string) (string, error) {
-
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
-
parts := strings.Split(tokenString, ".")
-
if len(parts) != 3 {
-
return "", fmt.Errorf("invalid JWT format")
-
}
-
-
// Decode header
-
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
+
header, err := ParseJWTHeader(tokenString)
if err != nil {
-
return "", fmt.Errorf("failed to decode header: %w", err)
-
}
-
-
var header struct {
-
Kid string `json:"kid"`
-
}
-
if err := json.Unmarshal(headerBytes, &header); err != nil {
-
return "", fmt.Errorf("failed to unmarshal header: %w", err)
+
return "", err
}
if header.Kid == "" {
+367
internal/atproto/auth/jwt_test.go
···
package auth
import (
+
"context"
+
"os"
"testing"
"time"
···
}
}
}
+
+
// === HS256 Verification Tests ===
+
+
// mockJWKSFetcher is a mock implementation of JWKSFetcher for testing
+
type mockJWKSFetcher struct {
+
publicKey interface{}
+
err error
+
}
+
+
func (m *mockJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
+
return m.publicKey, m.err
+
}
+
+
func createHS256Token(t *testing.T, subject, issuer, secret string, expiry time.Duration) string {
+
t.Helper()
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: subject,
+
Issuer: issuer,
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
+
IssuedAt: jwt.NewNumericDate(time.Now()),
+
},
+
Scope: "atproto transition:generic",
+
}
+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+
tokenString, err := token.SignedString([]byte(secret))
+
if err != nil {
+
t.Fatalf("Failed to create test token: %v", err)
+
}
+
return tokenString
+
}
+
+
func TestVerifyJWT_HS256_Valid(t *testing.T) {
+
// Setup: Configure environment for HS256 verification
+
secret := "test-jwt-secret-key-12345"
+
issuer := "https://pds.coves.social"
+
+
ResetJWTConfigForTesting()
+
os.Setenv("PDS_JWT_SECRET", secret)
+
os.Setenv("HS256_ISSUERS", issuer)
+
defer func() {
+
os.Unsetenv("PDS_JWT_SECRET")
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
tokenString := createHS256Token(t, "did:plc:test123", issuer, secret, 1*time.Hour)
+
+
// Verify token
+
claims, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
+
if err != nil {
+
t.Fatalf("VerifyJWT failed for valid HS256 token: %v", err)
+
}
+
+
if claims.Subject != "did:plc:test123" {
+
t.Errorf("Expected subject 'did:plc:test123', got '%s'", claims.Subject)
+
}
+
if claims.Issuer != issuer {
+
t.Errorf("Expected issuer '%s', got '%s'", issuer, claims.Issuer)
+
}
+
}
+
+
func TestVerifyJWT_HS256_WrongSecret(t *testing.T) {
+
// Setup: Configure environment with one secret, sign with another
+
issuer := "https://pds.coves.social"
+
+
ResetJWTConfigForTesting()
+
os.Setenv("PDS_JWT_SECRET", "correct-secret")
+
os.Setenv("HS256_ISSUERS", issuer)
+
defer func() {
+
os.Unsetenv("PDS_JWT_SECRET")
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
// Create token with wrong secret
+
tokenString := createHS256Token(t, "did:plc:test123", issuer, "wrong-secret", 1*time.Hour)
+
+
// Verify should fail
+
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
+
if err == nil {
+
t.Error("Expected error for HS256 token with wrong secret, got nil")
+
}
+
}
+
+
func TestVerifyJWT_HS256_SecretNotConfigured(t *testing.T) {
+
// Setup: Whitelist issuer but don't configure secret
+
issuer := "https://pds.coves.social"
+
+
ResetJWTConfigForTesting()
+
os.Unsetenv("PDS_JWT_SECRET") // Ensure secret is not set
+
os.Setenv("HS256_ISSUERS", issuer)
+
defer func() {
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
tokenString := createHS256Token(t, "did:plc:test123", issuer, "any-secret", 1*time.Hour)
+
+
// Verify should fail with descriptive error
+
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
+
if err == nil {
+
t.Error("Expected error when PDS_JWT_SECRET not configured, got nil")
+
}
+
if err != nil && !contains(err.Error(), "PDS_JWT_SECRET not configured") {
+
t.Errorf("Expected error about PDS_JWT_SECRET not configured, got: %v", err)
+
}
+
}
+
+
// === Algorithm Confusion Attack Prevention Tests ===
+
+
func TestVerifyJWT_AlgorithmConfusionAttack_HS256WithNonWhitelistedIssuer(t *testing.T) {
+
// SECURITY TEST: This tests the algorithm confusion attack prevention
+
// An attacker tries to use HS256 with an issuer that should use RS256/ES256
+
+
ResetJWTConfigForTesting()
+
os.Setenv("PDS_JWT_SECRET", "some-secret")
+
os.Setenv("HS256_ISSUERS", "https://trusted.example.com") // Different from token issuer
+
defer func() {
+
os.Unsetenv("PDS_JWT_SECRET")
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
// Create HS256 token with non-whitelisted issuer (simulating attack)
+
tokenString := createHS256Token(t, "did:plc:attacker", "https://victim-pds.example.com", "some-secret", 1*time.Hour)
+
+
// Verify should fail because issuer is not in HS256 whitelist
+
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
+
if err == nil {
+
t.Error("SECURITY VULNERABILITY: HS256 token accepted for non-whitelisted issuer")
+
}
+
if err != nil && !contains(err.Error(), "not in HS256_ISSUERS whitelist") {
+
t.Errorf("Expected error about HS256 not allowed for issuer, got: %v", err)
+
}
+
}
+
+
func TestVerifyJWT_AlgorithmConfusionAttack_EmptyWhitelist(t *testing.T) {
+
// SECURITY TEST: When no issuers are whitelisted for HS256, all HS256 tokens should be rejected
+
+
ResetJWTConfigForTesting()
+
os.Setenv("PDS_JWT_SECRET", "some-secret")
+
os.Unsetenv("HS256_ISSUERS") // Empty whitelist
+
defer func() {
+
os.Unsetenv("PDS_JWT_SECRET")
+
ResetJWTConfigForTesting()
+
}()
+
+
tokenString := createHS256Token(t, "did:plc:test123", "https://any-pds.example.com", "some-secret", 1*time.Hour)
+
+
// Verify should fail because no issuers are whitelisted for HS256
+
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
+
if err == nil {
+
t.Error("SECURITY VULNERABILITY: HS256 token accepted with empty issuer whitelist")
+
}
+
}
+
+
func TestVerifyJWT_IssuerRequiresHS256ButTokenUsesRS256(t *testing.T) {
+
// Test that issuer whitelisted for HS256 rejects tokens claiming to use RS256
+
issuer := "https://pds.coves.social"
+
+
ResetJWTConfigForTesting()
+
os.Setenv("PDS_JWT_SECRET", "test-secret")
+
os.Setenv("HS256_ISSUERS", issuer)
+
defer func() {
+
os.Unsetenv("PDS_JWT_SECRET")
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
// Create RS256-signed token (can't actually sign without RSA key, but we can test the header check)
+
claims := &Claims{
+
RegisteredClaims: jwt.RegisteredClaims{
+
Subject: "did:plc:test123",
+
Issuer: issuer,
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
+
},
+
}
+
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+
// This will create an invalid signature but valid header structure
+
// The test should fail at algorithm check, not signature verification
+
tokenString, _ := token.SignedString([]byte("dummy-key"))
+
+
if tokenString != "" {
+
_, err := VerifyJWT(context.Background(), tokenString, &mockJWKSFetcher{})
+
if err == nil {
+
t.Error("Expected error when HS256 issuer receives non-HS256 token")
+
}
+
}
+
}
+
+
// === ParseJWTHeader Tests ===
+
+
func TestParseJWTHeader_Valid(t *testing.T) {
+
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
+
+
header, err := ParseJWTHeader(tokenString)
+
if err != nil {
+
t.Fatalf("ParseJWTHeader failed: %v", err)
+
}
+
+
if header.Alg != AlgorithmHS256 {
+
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
+
}
+
}
+
+
func TestParseJWTHeader_WithBearerPrefix(t *testing.T) {
+
tokenString := createHS256Token(t, "did:plc:test123", "https://test.example.com", "secret", 1*time.Hour)
+
+
header, err := ParseJWTHeader("Bearer " + tokenString)
+
if err != nil {
+
t.Fatalf("ParseJWTHeader failed with Bearer prefix: %v", err)
+
}
+
+
if header.Alg != AlgorithmHS256 {
+
t.Errorf("Expected alg '%s', got '%s'", AlgorithmHS256, header.Alg)
+
}
+
}
+
+
func TestParseJWTHeader_InvalidFormat(t *testing.T) {
+
testCases := []struct {
+
name string
+
input string
+
}{
+
{"empty string", ""},
+
{"single part", "abc"},
+
{"two parts", "abc.def"},
+
{"too many parts", "a.b.c.d"},
+
}
+
+
for _, tc := range testCases {
+
t.Run(tc.name, func(t *testing.T) {
+
_, err := ParseJWTHeader(tc.input)
+
if err == nil {
+
t.Errorf("Expected error for invalid JWT format '%s', got nil", tc.input)
+
}
+
})
+
}
+
}
+
+
// === shouldUseHS256 and isHS256IssuerWhitelisted Tests ===
+
+
func TestIsHS256IssuerWhitelisted_Whitelisted(t *testing.T) {
+
ResetJWTConfigForTesting()
+
os.Setenv("HS256_ISSUERS", "https://pds1.example.com,https://pds2.example.com")
+
defer func() {
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
+
t.Error("Expected pds1 to be whitelisted")
+
}
+
if !isHS256IssuerWhitelisted("https://pds2.example.com") {
+
t.Error("Expected pds2 to be whitelisted")
+
}
+
}
+
+
func TestIsHS256IssuerWhitelisted_NotWhitelisted(t *testing.T) {
+
ResetJWTConfigForTesting()
+
os.Setenv("HS256_ISSUERS", "https://pds1.example.com")
+
defer func() {
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
if isHS256IssuerWhitelisted("https://attacker.example.com") {
+
t.Error("Expected non-whitelisted issuer to return false")
+
}
+
}
+
+
func TestIsHS256IssuerWhitelisted_EmptyWhitelist(t *testing.T) {
+
ResetJWTConfigForTesting()
+
os.Unsetenv("HS256_ISSUERS")
+
defer ResetJWTConfigForTesting()
+
+
if isHS256IssuerWhitelisted("https://any.example.com") {
+
t.Error("Expected false when whitelist is empty (safe default)")
+
}
+
}
+
+
func TestIsHS256IssuerWhitelisted_WhitespaceHandling(t *testing.T) {
+
ResetJWTConfigForTesting()
+
os.Setenv("HS256_ISSUERS", " https://pds1.example.com , https://pds2.example.com ")
+
defer func() {
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
if !isHS256IssuerWhitelisted("https://pds1.example.com") {
+
t.Error("Expected whitespace-trimmed issuer to be whitelisted")
+
}
+
}
+
+
// === shouldUseHS256 Tests (kid-based logic) ===
+
+
func TestShouldUseHS256_WithKid_AlwaysFalse(t *testing.T) {
+
// Tokens with kid should NEVER use HS256, regardless of issuer whitelist
+
ResetJWTConfigForTesting()
+
os.Setenv("HS256_ISSUERS", "https://whitelisted.example.com")
+
defer func() {
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
header := &JWTHeader{
+
Alg: AlgorithmHS256,
+
Kid: "some-key-id", // Has kid
+
}
+
+
// Even whitelisted issuer should not use HS256 if token has kid
+
if shouldUseHS256(header, "https://whitelisted.example.com") {
+
t.Error("Tokens with kid should never use HS256 (supports federation)")
+
}
+
}
+
+
func TestShouldUseHS256_WithoutKid_WhitelistedIssuer(t *testing.T) {
+
ResetJWTConfigForTesting()
+
os.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
+
defer func() {
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
header := &JWTHeader{
+
Alg: AlgorithmHS256,
+
Kid: "", // No kid
+
}
+
+
if !shouldUseHS256(header, "https://my-pds.example.com") {
+
t.Error("Token without kid from whitelisted issuer should use HS256")
+
}
+
}
+
+
func TestShouldUseHS256_WithoutKid_NotWhitelisted(t *testing.T) {
+
ResetJWTConfigForTesting()
+
os.Setenv("HS256_ISSUERS", "https://my-pds.example.com")
+
defer func() {
+
os.Unsetenv("HS256_ISSUERS")
+
ResetJWTConfigForTesting()
+
}()
+
+
header := &JWTHeader{
+
Alg: AlgorithmHS256,
+
Kid: "", // No kid
+
}
+
+
if shouldUseHS256(header, "https://external-pds.example.com") {
+
t.Error("Token without kid from non-whitelisted issuer should NOT use HS256")
+
}
+
}
+
+
// Helper function
+
func contains(s, substr string) bool {
+
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
+
}
+
+
func containsHelper(s, substr string) bool {
+
for i := 0; i <= len(s)-len(substr); i++ {
+
if s[i:i+len(substr)] == substr {
+
return true
+
}
+
}
+
return false
+
}