1package client_manager
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "net/http"
11 "net/url"
12 "slices"
13 "strings"
14 "time"
15
16 cache "github.com/go-pkgz/expirable-cache/v3"
17 "github.com/haileyok/cocoon/internal/helpers"
18 "github.com/haileyok/cocoon/oauth"
19 "github.com/lestrrat-go/jwx/v2/jwk"
20)
21
// ClientManager resolves atproto OAuth clients by fetching and validating
// their metadata documents and keys. Construct it with New.
type ClientManager struct {
	// cli performs the outbound fetches of client metadata and JWKS documents.
	cli *http.Client
	// logger is set by New. NOTE(review): not referenced by the methods in this file.
	logger *slog.Logger
	// jwksCache is looked up by client_id before a client's JWKS is fetched.
	jwksCache cache.Cache[string, jwk.Key]
	// metadataCache is looked up by client_id before client metadata is fetched.
	metadataCache cache.Cache[string, oauth.ClientMetadata]
}
28
// Args configures New. Nil fields are replaced with defaults by New.
type Args struct {
	// Cli is the HTTP client used for outbound fetches; New substitutes a
	// default client when nil.
	Cli *http.Client
	// Logger is the structured logger; New substitutes slog.Default() when nil.
	Logger *slog.Logger
}
33
34func New(args Args) *ClientManager {
35 if args.Logger == nil {
36 args.Logger = slog.Default()
37 }
38
39 if args.Cli == nil {
40 args.Cli = http.DefaultClient
41 }
42
43 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
44 metadataCache := cache.NewCache[string, oauth.ClientMetadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
45
46 return &ClientManager{
47 cli: args.Cli,
48 logger: args.Logger,
49 jwksCache: jwksCache,
50 metadataCache: metadataCache,
51 }
52}
53
54func (cm *ClientManager) GetClient(ctx context.Context, clientId string) (*oauth.Client, error) {
55 metadata, err := cm.getClientMetadata(ctx, clientId)
56 if err != nil {
57 return nil, err
58 }
59
60 var jwks jwk.Key
61 if metadata.JWKS != nil {
62 // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
63 // make sure we use the right one
64 k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0])
65 if err != nil {
66 return nil, err
67 }
68 jwks = k
69 } else if metadata.JWKSURI != nil {
70 maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
71 if err != nil {
72 return nil, err
73 }
74
75 jwks = maybeJwks
76 }
77
78 return &oauth.Client{
79 Metadata: metadata,
80 JWKS: jwks,
81 }, nil
82}
83
84func (cm *ClientManager) getClientMetadata(ctx context.Context, clientId string) (*oauth.ClientMetadata, error) {
85 metadataCached, ok := cm.metadataCache.Get(clientId)
86 if !ok {
87 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
88 if err != nil {
89 return nil, err
90 }
91
92 resp, err := cm.cli.Do(req)
93 if err != nil {
94 return nil, err
95 }
96 defer resp.Body.Close()
97
98 if resp.StatusCode != http.StatusOK {
99 io.Copy(io.Discard, resp.Body)
100 return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode)
101 }
102
103 b, err := io.ReadAll(resp.Body)
104 if err != nil {
105 return nil, fmt.Errorf("error reading bytes from client response: %w", err)
106 }
107
108 validated, err := validateAndParseMetadata(clientId, b)
109 if err != nil {
110 return nil, err
111 }
112
113 return validated, nil
114 } else {
115 return &metadataCached, nil
116 }
117}
118
119func (cm *ClientManager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) {
120 jwks, ok := cm.jwksCache.Get(clientId)
121 if !ok {
122 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil)
123 if err != nil {
124 return nil, err
125 }
126
127 resp, err := cm.cli.Do(req)
128 if err != nil {
129 return nil, err
130 }
131 defer resp.Body.Close()
132
133 if resp.StatusCode != http.StatusOK {
134 io.Copy(io.Discard, resp.Body)
135 return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode)
136 }
137
138 type Keys struct {
139 Keys []map[string]any `json:"keys"`
140 }
141
142 var keys Keys
143 if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
144 return nil, fmt.Errorf("error unmarshaling keys response: %w", err)
145 }
146
147 if len(keys.Keys) == 0 {
148 return nil, errors.New("no keys in jwks response")
149 }
150
151 // TODO: this is again bad, we should be figuring out which one we need to use...
152 b, err := json.Marshal(keys.Keys[0])
153 if err != nil {
154 return nil, fmt.Errorf("could not marshal key: %w", err)
155 }
156
157 k, err := helpers.ParseJWKFromBytes(b)
158 if err != nil {
159 return nil, err
160 }
161
162 jwks = k
163 }
164
165 return jwks, nil
166}
167
168func validateAndParseMetadata(clientId string, b []byte) (*oauth.ClientMetadata, error) {
169 var metadataMap map[string]any
170 if err := json.Unmarshal(b, &metadataMap); err != nil {
171 return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
172 }
173
174 _, jwksOk := metadataMap["jwks"].(string)
175 _, jwksUriOk := metadataMap["jwks_uri"].(string)
176 if jwksOk && jwksUriOk {
177 return nil, errors.New("jwks_uri and jwks are mutually exclusive")
178 }
179
180 for _, k := range []string{
181 "default_max_age",
182 "userinfo_signed_response_alg",
183 "id_token_signed_response_alg",
184 "userinfo_encryhpted_response_alg",
185 "authorization_encrypted_response_enc",
186 "authorization_encrypted_response_alg",
187 "tls_client_certificate_bound_access_tokens",
188 } {
189 _, kOk := metadataMap[k]
190 if kOk {
191 return nil, fmt.Errorf("unsupported `%s` parameter", k)
192 }
193 }
194
195 var metadata oauth.ClientMetadata
196 if err := json.Unmarshal(b, &metadata); err != nil {
197 return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
198 }
199
200 u, err := url.Parse(metadata.ClientURI)
201 if err != nil {
202 return nil, fmt.Errorf("unable to parse client uri: %w", err)
203 }
204
205 if isLocalHostname(u.Hostname()) {
206 return nil, errors.New("`client_uri` hostname is invalid")
207 }
208
209 if metadata.Scope == "" {
210 return nil, errors.New("missing `scopes` scope")
211 }
212
213 scopes := strings.Split(metadata.Scope, " ")
214 if !slices.Contains(scopes, "atproto") {
215 return nil, errors.New("missing `atproto` scope")
216 }
217
218 scopesMap := map[string]bool{}
219 for _, scope := range scopes {
220 if scopesMap[scope] {
221 return nil, fmt.Errorf("duplicate scope `%s`", scope)
222 }
223
224 // TODO: check for unsupported scopes
225
226 scopesMap[scope] = true
227 }
228
229 grantTypesMap := map[string]bool{}
230 for _, gt := range metadata.GrantTypes {
231 if grantTypesMap[gt] {
232 return nil, fmt.Errorf("duplicate grant type `%s`", gt)
233 }
234
235 switch gt {
236 case "implicit":
237 return nil, errors.New("grantg type `implicit` is not allowed")
238 case "authorization_code", "refresh_token":
239 // TODO check if this grant type is supported
240 default:
241 return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt)
242 }
243
244 grantTypesMap[gt] = true
245 }
246
247 if metadata.ClientID != clientId {
248 return nil, errors.New("`client_id` does not match")
249 }
250
251 subjectType, subjectTypeOk := metadataMap["subject_type"].(string)
252 if subjectTypeOk && subjectType != "public" {
253 return nil, errors.New("only public `subject_type` is supported")
254 }
255
256 switch metadata.TokenEndpointAuthMethod {
257 case "none":
258 if metadata.TokenEndpointAuthSigningAlg != "" {
259 return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg")
260 }
261 case "private_key_jwt":
262 if metadata.JWKS == nil && metadata.JWKSURI == nil {
263 return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri")
264 }
265
266 if metadata.JWKS != nil && len(*metadata.JWKS) == 0 {
267 return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks")
268 }
269
270 if metadata.TokenEndpointAuthSigningAlg == "" {
271 return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata")
272 }
273 default:
274 return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod)
275 }
276
277 if !metadata.DpopBoundAccessTokens {
278 return nil, errors.New("dpop_bound_access_tokens must be true")
279 }
280
281 if !slices.Contains(metadata.ResponseTypes, "code") {
282 return nil, errors.New("response_types must inclue `code`")
283 }
284
285 if !slices.Contains(metadata.GrantTypes, "authorization_code") {
286 return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`")
287 }
288
289 if len(metadata.RedirectURIs) == 0 {
290 return nil, errors.New("at least one `redirect_uri` is required")
291 }
292
293 if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod == "none" {
294 return nil, errors.New("native clients must authenticate using `none` method")
295 }
296
297 if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") {
298 for _, ruri := range metadata.RedirectURIs {
299 u, err := url.Parse(ruri)
300 if err != nil {
301 return nil, fmt.Errorf("error parsing redirect uri: %w", err)
302 }
303
304 if u.Scheme != "https" {
305 return nil, errors.New("web clients must use https redirect uris")
306 }
307
308 if u.Hostname() == "localhost" {
309 return nil, errors.New("web clients must not use localhost as the hostname")
310 }
311 }
312 }
313
314 for _, ruri := range metadata.RedirectURIs {
315 u, err := url.Parse(ruri)
316 if err != nil {
317 return nil, fmt.Errorf("error parsing redirect uri: %w", err)
318 }
319
320 if u.User != nil {
321 if u.User.Username() != "" {
322 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
323 }
324
325 if _, hasPass := u.User.Password(); hasPass {
326 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
327 }
328 }
329
330 switch true {
331 case u.Hostname() == "localhost":
332 return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)")
333 case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]":
334 if metadata.ApplicationType != "native" {
335 return nil, errors.New("loopback redirect uris are only allowed for native apps")
336 }
337
338 if u.Port() != "" {
339 // reference impl doesn't do anything with this?
340 }
341
342 if u.Scheme != "http" {
343 return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
344 }
345
346 break
347 case u.Scheme == "http":
348 return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
349 case u.Scheme == "https":
350 if isLocalHostname(u.Hostname()) {
351 return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
352 }
353 break
354 case strings.Contains(u.Scheme, "."):
355 if metadata.ApplicationType != "native" {
356 return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")
357 }
358
359 revdomain := reverseDomain(u.Scheme)
360
361 if isLocalHostname(revdomain) {
362 return nil, errors.New("private use uri scheme redirect uris must not be local hostnames")
363 }
364
365 if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" {
366 return nil, fmt.Errorf("private use uri scheme must be in the form ")
367 }
368 default:
369 return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme)
370 }
371 }
372
373 return &metadata, nil
374}
375
// isLocalHostname reports whether hostname should be treated as non-public.
// That is either a single-label name (no '.' at all, including the empty
// string) or a name whose final label is one of test, local, localhost,
// invalid or example, compared case-insensitively.
//
// IP literals are not recognised specially: "127.0.0.1" has dots and a
// final label of "1", so it is not reported as local.
func isLocalHostname(hostname string) bool {
	lastDot := strings.LastIndexByte(hostname, '.')
	if lastDot == -1 {
		return true
	}

	switch strings.ToLower(hostname[lastDot+1:]) {
	case "test", "local", "localhost", "invalid", "example":
		return true
	default:
		return false
	}
}
385
// reverseDomain returns domain with its dot-separated labels in reverse
// order, e.g. "com.example.app" becomes "app.example.com". Empty labels are
// preserved, so "a." becomes ".a". It turns a reverse-domain private-use URI
// scheme back into an ordinary hostname.
func reverseDomain(domain string) string {
	var labels []string
	rest := domain
	for {
		label, tail, found := strings.Cut(rest, ".")
		labels = append(labels, label)
		if !found {
			break
		}
		rest = tail
	}

	slices.Reverse(labels)
	return strings.Join(labels, ".")
}