A community-based topic aggregation platform built on atproto.
1package auth
2
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"
)
12
13// CachedJWKSFetcher fetches and caches JWKS from authorization servers
14type CachedJWKSFetcher struct {
15 cache map[string]*cachedJWKS
16 httpClient *http.Client
17 cacheMutex sync.RWMutex
18 cacheTTL time.Duration
19}
20
21type cachedJWKS struct {
22 jwks *JWKS
23 expiresAt time.Time
24}
25
26// NewCachedJWKSFetcher creates a new JWKS fetcher with caching
27func NewCachedJWKSFetcher(cacheTTL time.Duration) *CachedJWKSFetcher {
28 return &CachedJWKSFetcher{
29 cache: make(map[string]*cachedJWKS),
30 httpClient: &http.Client{
31 Timeout: 10 * time.Second,
32 },
33 cacheTTL: cacheTTL,
34 }
35}
36
37// FetchPublicKey fetches the public key for verifying a JWT from the issuer
38// Implements JWKSFetcher interface
39// Returns interface{} to support both RSA and ECDSA keys
40func (f *CachedJWKSFetcher) FetchPublicKey(ctx context.Context, issuer, token string) (interface{}, error) {
41 // Extract key ID from token
42 kid, err := ExtractKeyID(token)
43 if err != nil {
44 return nil, fmt.Errorf("failed to extract key ID: %w", err)
45 }
46
47 // Get JWKS from cache or fetch
48 jwks, err := f.getJWKS(ctx, issuer)
49 if err != nil {
50 return nil, err
51 }
52
53 // Find the key by ID
54 jwk, err := jwks.FindKeyByID(kid)
55 if err != nil {
56 // Key not found in cache - try refreshing
57 jwks, err = f.fetchJWKS(ctx, issuer)
58 if err != nil {
59 return nil, fmt.Errorf("failed to refresh JWKS: %w", err)
60 }
61 f.cacheJWKS(issuer, jwks)
62
63 // Try again with fresh JWKS
64 jwk, err = jwks.FindKeyByID(kid)
65 if err != nil {
66 return nil, err
67 }
68 }
69
70 // Convert JWK to public key (RSA or ECDSA)
71 return jwk.ToPublicKey()
72}
73
74// getJWKS gets JWKS from cache or fetches if not cached/expired
75func (f *CachedJWKSFetcher) getJWKS(ctx context.Context, issuer string) (*JWKS, error) {
76 // Check cache first
77 f.cacheMutex.RLock()
78 cached, exists := f.cache[issuer]
79 f.cacheMutex.RUnlock()
80
81 if exists && time.Now().Before(cached.expiresAt) {
82 return cached.jwks, nil
83 }
84
85 // Not in cache or expired - fetch from issuer
86 jwks, err := f.fetchJWKS(ctx, issuer)
87 if err != nil {
88 return nil, err
89 }
90
91 // Cache it
92 f.cacheJWKS(issuer, jwks)
93
94 return jwks, nil
95}
96
97// fetchJWKS fetches JWKS from the authorization server
98func (f *CachedJWKSFetcher) fetchJWKS(ctx context.Context, issuer string) (*JWKS, error) {
99 // Step 1: Fetch OAuth server metadata to get JWKS URI
100 metadataURL := strings.TrimSuffix(issuer, "/") + "/.well-known/oauth-authorization-server"
101
102 req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
103 if err != nil {
104 return nil, fmt.Errorf("failed to create metadata request: %w", err)
105 }
106
107 resp, err := f.httpClient.Do(req)
108 if err != nil {
109 return nil, fmt.Errorf("failed to fetch metadata: %w", err)
110 }
111 defer func() {
112 _ = resp.Body.Close()
113 }()
114
115 if resp.StatusCode != http.StatusOK {
116 return nil, fmt.Errorf("metadata endpoint returned status %d", resp.StatusCode)
117 }
118
119 var metadata struct {
120 JWKSURI string `json:"jwks_uri"`
121 }
122 if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
123 return nil, fmt.Errorf("failed to decode metadata: %w", err)
124 }
125
126 if metadata.JWKSURI == "" {
127 return nil, fmt.Errorf("jwks_uri not found in metadata")
128 }
129
130 // Step 2: Fetch JWKS from the JWKS URI
131 jwksReq, err := http.NewRequestWithContext(ctx, "GET", metadata.JWKSURI, nil)
132 if err != nil {
133 return nil, fmt.Errorf("failed to create JWKS request: %w", err)
134 }
135
136 jwksResp, err := f.httpClient.Do(jwksReq)
137 if err != nil {
138 return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
139 }
140 defer func() {
141 _ = jwksResp.Body.Close()
142 }()
143
144 if jwksResp.StatusCode != http.StatusOK {
145 return nil, fmt.Errorf("JWKS endpoint returned status %d", jwksResp.StatusCode)
146 }
147
148 var jwks JWKS
149 if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
150 return nil, fmt.Errorf("failed to decode JWKS: %w", err)
151 }
152
153 if len(jwks.Keys) == 0 {
154 return nil, fmt.Errorf("no keys found in JWKS")
155 }
156
157 return &jwks, nil
158}
159
160// cacheJWKS stores JWKS in the cache
161func (f *CachedJWKSFetcher) cacheJWKS(issuer string, jwks *JWKS) {
162 f.cacheMutex.Lock()
163 defer f.cacheMutex.Unlock()
164
165 f.cache[issuer] = &cachedJWKS{
166 jwks: jwks,
167 expiresAt: time.Now().Add(f.cacheTTL),
168 }
169}
170
171// ClearCache clears the entire JWKS cache
172func (f *CachedJWKSFetcher) ClearCache() {
173 f.cacheMutex.Lock()
174 defer f.cacheMutex.Unlock()
175 f.cache = make(map[string]*cachedJWKS)
176}
177
178// CleanupExpiredCache removes expired entries from the cache
179func (f *CachedJWKSFetcher) CleanupExpiredCache() {
180 f.cacheMutex.Lock()
181 defer f.cacheMutex.Unlock()
182
183 now := time.Now()
184 for issuer, cached := range f.cache {
185 if now.After(cached.expiresAt) {
186 delete(f.cache, issuer)
187 }
188 }
189}