A community-based topic aggregation platform built on atproto
1package auth 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net/http" 8 "strings" 9 "sync" 10 "time" 11) 12 13// CachedJWKSFetcher fetches and caches JWKS from authorization servers 14type CachedJWKSFetcher struct { 15 cache map[string]*cachedJWKS 16 httpClient *http.Client 17 cacheMutex sync.RWMutex 18 cacheTTL time.Duration 19} 20 21type cachedJWKS struct { 22 jwks *JWKS 23 expiresAt time.Time 24} 25 26// NewCachedJWKSFetcher creates a new JWKS fetcher with caching 27func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher { 28 return &CachedJWKSFetcher{ 29 cache: make(map[string]*cachedJWKS), 30 httpClient: &http.Client{ 31 Timeout: 10 * time.Second, 32 }, 33 cacheTTL: cacheTTL, 34 } 35} 36 37// FetchPublicKey fetches the public key for verifying a JWT from the issuer 38// Implements JWKSFetcher interface 39// Returns interface{} to support both RSA and ECDSA keys 40func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) { 41 // Extract key ID from token 42 kid, err := ExtractKeyID(token) 43 if err != nil { 44 return nil, fmt.Errorf("failed to extract key ID: %w", err) 45 } 46 47 // Get JWKS from cache or fetch 48 jwks, err := f.getJWKS(ctx, issuer) 49 if err != nil { 50 return nil, err 51 } 52 53 // Find the key by ID 54 jwk, err := jwks.FindKeyByID(kid) 55 if err != nil { 56 // Key not found in cache - try refreshing 57 jwks, err = f.fetchJWKS(ctx, issuer) 58 if err != nil { 59 return nil, fmt.Errorf("failed to refresh JWKS: %w", err) 60 } 61 f.cacheJWKS(issuer, jwks) 62 63 // Try again with fresh JWKS 64 jwk, err = jwks.FindKeyByID(kid) 65 if err != nil { 66 return nil, err 67 } 68 } 69 70 // Convert JWK to public key (RSA or ECDSA) 71 return jwk.ToPublicKey() 72} 73 74// getJWKS gets JWKS from cache or fetches if not cached/expired 75func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) { 76 // Check cache first 77 f.cacheMutex.RLock() 78 cached, exists := 
f.cache[issuer] 79 f.cacheMutex.RUnlock() 80 81 if exists && time.Now().Before(cached.expiresAt) { 82 return cached.jwks, nil 83 } 84 85 // Not in cache or expired - fetch from issuer 86 jwks, err := f.fetchJWKS(ctx, issuer) 87 if err != nil { 88 return nil, err 89 } 90 91 // Cache it 92 f.cacheJWKS(issuer, jwks) 93 94 return jwks, nil 95} 96 97// fetchJWKS fetches JWKS from the authorization server 98func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) { 99 // Step 1: Fetch OAuth server metadata to get JWKS URI 100 metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server" 101 102 req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) 103 if err != nil { 104 return nil, fmt.Errorf("failed to create metadata request: %w", err) 105 } 106 107 resp, err := f.httpClient.Do(req) 108 if err != nil { 109 return nil, fmt.Errorf("failed to fetch metadata: %w", err) 110 } 111 defer func() { 112 _ = resp.Body.Close() 113 }() 114 115 if resp.StatusCode != http.StatusOK { 116 return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode) 117 } 118 119 var metadata struct { 120 JWKSURI string `json:"jwks_uri"` 121 } 122 if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { 123 return nil, fmt.Errorf("failed to decode metadata: %w", err) 124 } 125 126 if metadata.JWKSURI == "" { 127 return nil, fmt.Errorf("jwks_uri not found in metadata") 128 } 129 130 // Step 2: Fetch JWKS from the JWKS URI 131 jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil) 132 if err != nil { 133 return nil, fmt.Errorf("failed to create JWKS request: %w", err) 134 } 135 136 jwksResp, err := f.httpClient.Do(jwksReq) 137 if err != nil { 138 return nil, fmt.Errorf("failed to fetch JWKS: %w", err) 139 } 140 defer func() { 141 _ = jwksResp.Body.Close() 142 }() 143 144 if jwksResp.StatusCode != http.StatusOK { 145 return nil, fmt.Errorf("JWKS endpoint returned status %d", 
jwksResp.StatusCode) 146 } 147 148 var jwks JWKS 149 if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil { 150 return nil, fmt.Errorf("failed to decode JWKS: %w", err) 151 } 152 153 if len(jwks.Keys) == 0 { 154 return nil, fmt.Errorf("no keys found in JWKS") 155 } 156 157 return &jwks, nil 158} 159 160// cacheJWKS stores JWKS in the cache 161func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) { 162 f.cacheMutex.Lock() 163 defer f.cacheMutex.Unlock() 164 165 f.cache[issuer] = &cachedJWKS{ 166 jwks: jwks, 167 expiresAt: time.Now().Add(f.cacheTTL), 168 } 169} 170 171// ClearCache clears the entire JWKS cache 172func (f *CachedJWKSFetcher) ClearCache() { 173 f.cacheMutex.Lock() 174 defer f.cacheMutex.Unlock() 175 f.cache = make(map[string]*cachedJWKS) 176} 177 178// CleanupExpiredCache removes expired entries from the cache 179func (f *CachedJWKSFetcher) CleanupExpiredCache() { 180 f.cacheMutex.Lock() 181 defer f.cacheMutex.Unlock() 182 183 now := time.Now() 184 for issuer, cached := range f.cache { 185 if now.After(cached.expiresAt) { 186 delete(f.cache, issuer) 187 } 188 } 189}