An atproto PDS written in Go
1package client_manager 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "log/slog" 10 "net/http" 11 "net/url" 12 "slices" 13 "strings" 14 "time" 15 16 cache "github.com/go-pkgz/expirable-cache/v3" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/haileyok/cocoon/oauth" 19 "github.com/lestrrat-go/jwx/v2/jwk" 20) 21 22type ClientManager struct { 23 cli *http.Client 24 logger *slog.Logger 25 jwksCache cache.Cache[string, jwk.Key] 26 metadataCache cache.Cache[string, oauth.ClientMetadata] 27} 28 29type Args struct { 30 Cli *http.Client 31 Logger *slog.Logger 32} 33 34func New(args Args) *ClientManager { 35 if args.Logger == nil { 36 args.Logger = slog.Default() 37 } 38 39 if args.Cli == nil { 40 args.Cli = http.DefaultClient 41 } 42 43 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 metadataCache := cache.NewCache[string, oauth.ClientMetadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 45 46 return &ClientManager{ 47 cli: args.Cli, 48 logger: args.Logger, 49 jwksCache: jwksCache, 50 metadataCache: metadataCache, 51 } 52} 53 54func (cm *ClientManager) GetClient(ctx context.Context, clientId string) (*oauth.Client, error) { 55 metadata, err := cm.getClientMetadata(ctx, clientId) 56 if err != nil { 57 return nil, err 58 } 59 60 var jwks jwk.Key 61 if metadata.JWKS != nil { 62 // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to 63 // make sure we use the right one 64 k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0]) 65 if err != nil { 66 return nil, err 67 } 68 jwks = k 69 } else if metadata.JWKSURI != nil { 70 maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 71 if err != nil { 72 return nil, err 73 } 74 75 jwks = maybeJwks 76 } 77 78 return &oauth.Client{ 79 Metadata: metadata, 80 JWKS: jwks, 81 }, nil 82} 83 84func (cm *ClientManager) getClientMetadata(ctx context.Context, clientId string) (*oauth.ClientMetadata, error) { 85 metadataCached, ok := cm.metadataCache.Get(clientId) 86 if !ok { 87 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 88 if err != nil { 89 return nil, err 90 } 91 92 resp, err := cm.cli.Do(req) 93 if err != nil { 94 return nil, err 95 } 96 defer resp.Body.Close() 97 98 if resp.StatusCode != http.StatusOK { 99 io.Copy(io.Discard, resp.Body) 100 return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 101 } 102 103 b, err := io.ReadAll(resp.Body) 104 if err != nil { 105 return nil, fmt.Errorf("error reading bytes from client response: %w", err) 106 } 107 108 validated, err := validateAndParseMetadata(clientId, b) 109 if err != nil { 110 return nil, err 111 } 112 113 return validated, nil 114 } else { 115 return &metadataCached, nil 116 } 117} 118 119func (cm *ClientManager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 120 jwks, ok := cm.jwksCache.Get(clientId) 121 if !ok { 122 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 123 if err != nil { 124 return nil, err 125 } 126 127 resp, err := cm.cli.Do(req) 128 if err != nil { 129 return nil, err 130 } 131 defer resp.Body.Close() 132 133 if resp.StatusCode != http.StatusOK { 134 io.Copy(io.Discard, resp.Body) 135 return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 136 } 137 138 type Keys struct { 139 Keys []map[string]any `json:"keys"` 140 } 141 142 var keys Keys 143 if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 144 return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 145 } 146 147 if len(keys.Keys) == 0 { 148 return nil, errors.New("no keys in jwks response") 149 } 150 151 // TODO: this is again bad, we should be figuring out which one we need to use... 152 b, err := json.Marshal(keys.Keys[0]) 153 if err != nil { 154 return nil, fmt.Errorf("could not marshal key: %w", err) 155 } 156 157 k, err := helpers.ParseJWKFromBytes(b) 158 if err != nil { 159 return nil, err 160 } 161 162 jwks = k 163 } 164 165 return jwks, nil 166} 167 168func validateAndParseMetadata(clientId string, b []byte) (*oauth.ClientMetadata, error) { 169 var metadataMap map[string]any 170 if err := json.Unmarshal(b, &metadataMap); err != nil { 171 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 172 } 173 174 _, jwksOk := metadataMap["jwks"].(string) 175 _, jwksUriOk := metadataMap["jwks_uri"].(string) 176 if jwksOk && jwksUriOk { 177 return nil, errors.New("jwks_uri and jwks are mutually exclusive") 178 } 179 180 for _, k := range []string{ 181 "default_max_age", 182 "userinfo_signed_response_alg", 183 "id_token_signed_response_alg", 184 "userinfo_encryhpted_response_alg", 185 "authorization_encrypted_response_enc", 186 "authorization_encrypted_response_alg", 187 "tls_client_certificate_bound_access_tokens", 188 } { 189 _, kOk := metadataMap[k] 190 if kOk { 191 return nil, fmt.Errorf("unsupported `%s` parameter", k) 192 } 193 } 194 195 var metadata oauth.ClientMetadata 196 if err := json.Unmarshal(b, &metadata); err != nil { 197 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 198 } 199 200 u, err := url.Parse(metadata.ClientURI) 201 if err != nil { 202 return nil, fmt.Errorf("unable to parse client uri: %w", err) 203 } 204 205 if isLocalHostname(u.Hostname()) { 206 return nil, errors.New("`client_uri` hostname is invalid") 207 } 208 209 if metadata.Scope == "" { 210 return nil, errors.New("missing `scopes` scope") 211 } 212 213 scopes := strings.Split(metadata.Scope, " ") 214 if !slices.Contains(scopes, "atproto") { 215 return nil, errors.New("missing `atproto` scope") 216 } 217 218 scopesMap := map[string]bool{} 219 for _, scope := range scopes { 220 if scopesMap[scope] { 221 return nil, fmt.Errorf("duplicate scope `%s`", scope) 222 } 223 224 // TODO: check for unsupported scopes 225 226 scopesMap[scope] = true 227 } 228 229 grantTypesMap := map[string]bool{} 230 for _, gt := range metadata.GrantTypes { 231 if grantTypesMap[gt] { 232 return nil, fmt.Errorf("duplicate grant type `%s`", gt) 233 } 234 235 switch gt { 236 case "implicit": 237 return nil, errors.New("grantg type `implicit` is not allowed") 238 case "authorization_code", "refresh_token": 239 // TODO check if this grant type is supported 240 default: 241 return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 242 } 243 244 grantTypesMap[gt] = true 245 } 246 247 if metadata.ClientID != clientId { 248 return nil, errors.New("`client_id` does not match") 249 } 250 251 subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 252 if subjectTypeOk && subjectType != "public" { 253 return nil, errors.New("only public `subject_type` is supported") 254 } 255 256 switch metadata.TokenEndpointAuthMethod { 257 case "none": 258 if metadata.TokenEndpointAuthSigningAlg != "" { 259 return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 260 } 261 case "private_key_jwt": 262 if metadata.JWKS == nil && metadata.JWKSURI == nil { 263 return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 264 } 265 266 if metadata.JWKS != nil && len(*metadata.JWKS) == 0 { 267 return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 268 } 269 270 if metadata.TokenEndpointAuthSigningAlg == "" { 271 return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 272 } 273 default: 274 return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 275 } 276 277 if !metadata.DpopBoundAccessTokens { 278 return nil, errors.New("dpop_bound_access_tokens must be true") 279 } 280 281 if !slices.Contains(metadata.ResponseTypes, "code") { 282 return nil, errors.New("response_types must inclue `code`") 283 } 284 285 if !slices.Contains(metadata.GrantTypes, "authorization_code") { 286 return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 287 } 288 289 if len(metadata.RedirectURIs) == 0 { 290 return nil, errors.New("at least one `redirect_uri` is required") 291 } 292 293 if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod == "none" { 294 return nil, errors.New("native clients must authenticate using `none` method") 295 } 296 297 if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 298 for _, ruri := range metadata.RedirectURIs { 299 u, err := url.Parse(ruri) 300 if err != nil { 301 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 302 } 303 304 if u.Scheme != "https" { 305 return nil, errors.New("web clients must use https redirect uris") 306 } 307 308 if u.Hostname() == "localhost" { 309 return nil, errors.New("web clients must not use localhost as the hostname") 310 } 311 } 312 } 313 314 for _, ruri := range metadata.RedirectURIs { 315 u, err := url.Parse(ruri) 316 if err != nil { 317 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 318 } 319 320 if u.User != nil { 321 if u.User.Username() != "" { 322 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 323 } 324 325 if _, hasPass := u.User.Password(); hasPass { 326 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 327 } 328 } 329 330 switch true { 331 case u.Hostname() == "localhost": 332 return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 333 case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 334 if metadata.ApplicationType != "native" { 335 return nil, errors.New("loopback redirect uris are only allowed for native apps") 336 } 337 338 if u.Port() != "" { 339 // reference impl doesn't do anything with this? 340 } 341 342 if u.Scheme != "http" { 343 return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 344 } 345 346 break 347 case u.Scheme == "http": 348 return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 349 case u.Scheme == "https": 350 if isLocalHostname(u.Hostname()) { 351 return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 352 } 353 break 354 case strings.Contains(u.Scheme, "."): 355 if metadata.ApplicationType != "native" { 356 return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 357 } 358 359 revdomain := reverseDomain(u.Scheme) 360 361 if isLocalHostname(revdomain) { 362 return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 363 } 364 365 if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 366 return nil, fmt.Errorf("private use uri scheme must be in the form ") 367 } 368 default: 369 return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 370 } 371 } 372 373 return &metadata, nil 374} 375 376func isLocalHostname(hostname string) bool { 377 pts := strings.Split(hostname, ".") 378 if len(pts) < 2 { 379 return true 380 } 381 382 tld := strings.ToLower(pts[len(pts)-1]) 383 return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 384} 385 386func reverseDomain(domain string) string { 387 pts := strings.Split(domain, ".") 388 slices.Reverse(pts) 389 return strings.Join(pts, ".") 390}