1package client
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "net/http"
11 "net/url"
12 "slices"
13 "strings"
14 "time"
15
16 cache "github.com/go-pkgz/expirable-cache/v3"
17 "github.com/haileyok/cocoon/internal/helpers"
18 "github.com/lestrrat-go/jwx/v2/jwk"
19)
20
21type Manager struct {
22 cli *http.Client
23 logger *slog.Logger
24 jwksCache cache.Cache[string, jwk.Key]
25 metadataCache cache.Cache[string, Metadata]
26}
27
28type ManagerArgs struct {
29 Cli *http.Client
30 Logger *slog.Logger
31}
32
33func NewManager(args ManagerArgs) *Manager {
34 if args.Logger == nil {
35 args.Logger = slog.Default()
36 }
37
38 if args.Cli == nil {
39 args.Cli = http.DefaultClient
40 }
41
42 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
43 metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
44
45 return &Manager{
46 cli: args.Cli,
47 logger: args.Logger,
48 jwksCache: jwksCache,
49 metadataCache: metadataCache,
50 }
51}
52
53func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) {
54 metadata, err := cm.getClientMetadata(ctx, clientId)
55 if err != nil {
56 return nil, err
57 }
58
59 var jwks jwk.Key
60 if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 {
61 // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
62 // make sure we use the right one
63 b, err := json.Marshal(metadata.JWKS.Keys[0])
64 if err != nil {
65 return nil, err
66 }
67
68 k, err := helpers.ParseJWKFromBytes(b)
69 if err != nil {
70 return nil, err
71 }
72
73 jwks = k
74 } else if metadata.JWKSURI != nil {
75 maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
76 if err != nil {
77 return nil, err
78 }
79
80 jwks = maybeJwks
81 } else {
82 return nil, fmt.Errorf("no valid jwks found in oauth client metadata")
83 }
84
85 return &Client{
86 Metadata: metadata,
87 JWKS: jwks,
88 }, nil
89}
90
91func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) {
92 metadataCached, ok := cm.metadataCache.Get(clientId)
93 if !ok {
94 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
95 if err != nil {
96 return nil, err
97 }
98
99 resp, err := cm.cli.Do(req)
100 if err != nil {
101 return nil, err
102 }
103 defer resp.Body.Close()
104
105 if resp.StatusCode != http.StatusOK {
106 io.Copy(io.Discard, resp.Body)
107 return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode)
108 }
109
110 b, err := io.ReadAll(resp.Body)
111 if err != nil {
112 return nil, fmt.Errorf("error reading bytes from client response: %w", err)
113 }
114
115 validated, err := validateAndParseMetadata(clientId, b)
116 if err != nil {
117 return nil, err
118 }
119
120 return validated, nil
121 } else {
122 return &metadataCached, nil
123 }
124}
125
126func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) {
127 jwks, ok := cm.jwksCache.Get(clientId)
128 if !ok {
129 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil)
130 if err != nil {
131 return nil, err
132 }
133
134 resp, err := cm.cli.Do(req)
135 if err != nil {
136 return nil, err
137 }
138 defer resp.Body.Close()
139
140 if resp.StatusCode != http.StatusOK {
141 io.Copy(io.Discard, resp.Body)
142 return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode)
143 }
144
145 type Keys struct {
146 Keys []map[string]any `json:"keys"`
147 }
148
149 var keys Keys
150 if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
151 return nil, fmt.Errorf("error unmarshaling keys response: %w", err)
152 }
153
154 if len(keys.Keys) == 0 {
155 return nil, errors.New("no keys in jwks response")
156 }
157
158 // TODO: this is again bad, we should be figuring out which one we need to use...
159 b, err := json.Marshal(keys.Keys[0])
160 if err != nil {
161 return nil, fmt.Errorf("could not marshal key: %w", err)
162 }
163
164 k, err := helpers.ParseJWKFromBytes(b)
165 if err != nil {
166 return nil, err
167 }
168
169 jwks = k
170 }
171
172 return jwks, nil
173}
174
175func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) {
176 var metadataMap map[string]any
177 if err := json.Unmarshal(b, &metadataMap); err != nil {
178 return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
179 }
180
181 _, jwksOk := metadataMap["jwks"].(string)
182 _, jwksUriOk := metadataMap["jwks_uri"].(string)
183 if jwksOk && jwksUriOk {
184 return nil, errors.New("jwks_uri and jwks are mutually exclusive")
185 }
186
187 for _, k := range []string{
188 "default_max_age",
189 "userinfo_signed_response_alg",
190 "id_token_signed_response_alg",
191 "userinfo_encryhpted_response_alg",
192 "authorization_encrypted_response_enc",
193 "authorization_encrypted_response_alg",
194 "tls_client_certificate_bound_access_tokens",
195 } {
196 _, kOk := metadataMap[k]
197 if kOk {
198 return nil, fmt.Errorf("unsupported `%s` parameter", k)
199 }
200 }
201
202 var metadata Metadata
203 if err := json.Unmarshal(b, &metadata); err != nil {
204 return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
205 }
206
207 u, err := url.Parse(metadata.ClientURI)
208 if err != nil {
209 return nil, fmt.Errorf("unable to parse client uri: %w", err)
210 }
211
212 if isLocalHostname(u.Hostname()) {
213 return nil, errors.New("`client_uri` hostname is invalid")
214 }
215
216 if metadata.Scope == "" {
217 return nil, errors.New("missing `scopes` scope")
218 }
219
220 scopes := strings.Split(metadata.Scope, " ")
221 if !slices.Contains(scopes, "atproto") {
222 return nil, errors.New("missing `atproto` scope")
223 }
224
225 scopesMap := map[string]bool{}
226 for _, scope := range scopes {
227 if scopesMap[scope] {
228 return nil, fmt.Errorf("duplicate scope `%s`", scope)
229 }
230
231 // TODO: check for unsupported scopes
232
233 scopesMap[scope] = true
234 }
235
236 grantTypesMap := map[string]bool{}
237 for _, gt := range metadata.GrantTypes {
238 if grantTypesMap[gt] {
239 return nil, fmt.Errorf("duplicate grant type `%s`", gt)
240 }
241
242 switch gt {
243 case "implicit":
244 return nil, errors.New("grantg type `implicit` is not allowed")
245 case "authorization_code", "refresh_token":
246 // TODO check if this grant type is supported
247 default:
248 return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt)
249 }
250
251 grantTypesMap[gt] = true
252 }
253
254 if metadata.ClientID != clientId {
255 return nil, errors.New("`client_id` does not match")
256 }
257
258 subjectType, subjectTypeOk := metadataMap["subject_type"].(string)
259 if subjectTypeOk && subjectType != "public" {
260 return nil, errors.New("only public `subject_type` is supported")
261 }
262
263 switch metadata.TokenEndpointAuthMethod {
264 case "none":
265 if metadata.TokenEndpointAuthSigningAlg != "" {
266 return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg")
267 }
268 case "private_key_jwt":
269 if metadata.JWKS == nil && metadata.JWKSURI == nil {
270 return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri")
271 }
272
273 if metadata.JWKS != nil && len(metadata.JWKS.Keys) == 0 {
274 return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks")
275 }
276
277 if metadata.TokenEndpointAuthSigningAlg == "" {
278 return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata")
279 }
280 default:
281 return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod)
282 }
283
284 if !metadata.DpopBoundAccessTokens {
285 return nil, errors.New("dpop_bound_access_tokens must be true")
286 }
287
288 if !slices.Contains(metadata.ResponseTypes, "code") {
289 return nil, errors.New("response_types must inclue `code`")
290 }
291
292 if !slices.Contains(metadata.GrantTypes, "authorization_code") {
293 return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`")
294 }
295
296 if len(metadata.RedirectURIs) == 0 {
297 return nil, errors.New("at least one `redirect_uri` is required")
298 }
299
300 if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" {
301 return nil, errors.New("native clients must authenticate using `none` method")
302 }
303
304 if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") {
305 for _, ruri := range metadata.RedirectURIs {
306 u, err := url.Parse(ruri)
307 if err != nil {
308 return nil, fmt.Errorf("error parsing redirect uri: %w", err)
309 }
310
311 if u.Scheme != "https" {
312 return nil, errors.New("web clients must use https redirect uris")
313 }
314
315 if u.Hostname() == "localhost" {
316 return nil, errors.New("web clients must not use localhost as the hostname")
317 }
318 }
319 }
320
321 for _, ruri := range metadata.RedirectURIs {
322 u, err := url.Parse(ruri)
323 if err != nil {
324 return nil, fmt.Errorf("error parsing redirect uri: %w", err)
325 }
326
327 if u.User != nil {
328 if u.User.Username() != "" {
329 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
330 }
331
332 if _, hasPass := u.User.Password(); hasPass {
333 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
334 }
335 }
336
337 switch true {
338 case u.Hostname() == "localhost":
339 return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)")
340 case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]":
341 if metadata.ApplicationType != "native" {
342 return nil, errors.New("loopback redirect uris are only allowed for native apps")
343 }
344
345 if u.Port() != "" {
346 // reference impl doesn't do anything with this?
347 }
348
349 if u.Scheme != "http" {
350 return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
351 }
352
353 break
354 case u.Scheme == "http":
355 return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
356 case u.Scheme == "https":
357 if isLocalHostname(u.Hostname()) {
358 return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
359 }
360 break
361 case strings.Contains(u.Scheme, "."):
362 if metadata.ApplicationType != "native" {
363 return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")
364 }
365
366 revdomain := reverseDomain(u.Scheme)
367
368 if isLocalHostname(revdomain) {
369 return nil, errors.New("private use uri scheme redirect uris must not be local hostnames")
370 }
371
372 if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" {
373 return nil, fmt.Errorf("private use uri scheme must be in the form ")
374 }
375 default:
376 return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme)
377 }
378 }
379
380 return &metadata, nil
381}
382
// isLocalHostname reports whether hostname is treated as local / non-public.
// That covers any name without a dot (including the empty string), and any
// name whose final label is one of test, local, localhost, invalid or example,
// compared case-insensitively.
func isLocalHostname(hostname string) bool {
	lastDot := strings.LastIndexByte(hostname, '.')
	if lastDot < 0 {
		// Single-label names such as "localhost" or "" are never public.
		return true
	}

	switch strings.ToLower(hostname[lastDot+1:]) {
	case "test", "local", "localhost", "invalid", "example":
		return true
	default:
		return false
	}
}
392
// reverseDomain reverses the order of the dot-separated labels in domain, so
// "com.example.app" becomes "app.example.com". It is used to turn a private-use
// redirect URI scheme back into the domain name it was derived from.
func reverseDomain(domain string) string {
	labels := strings.Split(domain, ".")
	reversed := make([]string, 0, len(labels))
	for i := len(labels) - 1; i >= 0; i-- {
		reversed = append(reversed, labels[i])
	}
	return strings.Join(reversed, ".")
}