An atproto PDS written in Go
1package client 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "log/slog" 10 "net/http" 11 "net/url" 12 "slices" 13 "strings" 14 "time" 15 16 cache "github.com/go-pkgz/expirable-cache/v3" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/lestrrat-go/jwx/v2/jwk" 19) 20 21type Manager struct { 22 cli *http.Client 23 logger *slog.Logger 24 jwksCache cache.Cache[string, jwk.Key] 25 metadataCache cache.Cache[string, Metadata] 26} 27 28type ManagerArgs struct { 29 Cli *http.Client 30 Logger *slog.Logger 31} 32 33func NewManager(args ManagerArgs) *Manager { 34 if args.Logger == nil { 35 args.Logger = slog.Default() 36 } 37 38 if args.Cli == nil { 39 args.Cli = http.DefaultClient 40 } 41 42 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 43 metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 45 return &Manager{ 46 cli: args.Cli, 47 logger: args.Logger, 48 jwksCache: jwksCache, 49 metadataCache: metadataCache, 50 } 51} 52 53func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) { 54 metadata, err := cm.getClientMetadata(ctx, clientId) 55 if err != nil { 56 return nil, err 57 } 58 59 var jwks jwk.Key 60 if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 { 61 // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to 62 // make sure we use the right one 63 b, err := json.Marshal(metadata.JWKS.Keys[0]) 64 if err != nil { 65 return nil, err 66 } 67 68 k, err := helpers.ParseJWKFromBytes(b) 69 if err != nil { 70 return nil, err 71 } 72 73 jwks = k 74 } else if metadata.JWKSURI != nil { 75 maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 76 if err != nil { 77 return nil, err 78 } 79 80 jwks = maybeJwks 81 } else { 82 return nil, fmt.Errorf("no valid jwks found in oauth client metadata") 83 } 84 85 return &Client{ 86 Metadata: metadata, 87 JWKS: jwks, 88 }, nil 89} 90 91func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) { 92 metadataCached, ok := cm.metadataCache.Get(clientId) 93 if !ok { 94 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 95 if err != nil { 96 return nil, err 97 } 98 99 resp, err := cm.cli.Do(req) 100 if err != nil { 101 return nil, err 102 } 103 defer resp.Body.Close() 104 105 if resp.StatusCode != http.StatusOK { 106 io.Copy(io.Discard, resp.Body) 107 return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 108 } 109 110 b, err := io.ReadAll(resp.Body) 111 if err != nil { 112 return nil, fmt.Errorf("error reading bytes from client response: %w", err) 113 } 114 115 validated, err := validateAndParseMetadata(clientId, b) 116 if err != nil { 117 return nil, err 118 } 119 120 return validated, nil 121 } else { 122 return &metadataCached, nil 123 } 124} 125 126func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 127 jwks, ok := cm.jwksCache.Get(clientId) 128 if !ok { 129 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 130 if err != nil { 131 return nil, err 132 } 133 134 resp, err := cm.cli.Do(req) 135 if err != nil { 136 return nil, err 137 } 138 defer resp.Body.Close() 139 140 if resp.StatusCode != http.StatusOK { 141 io.Copy(io.Discard, resp.Body) 142 return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 143 } 144 145 type Keys struct { 146 Keys []map[string]any `json:"keys"` 147 } 148 149 var keys Keys 150 if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 151 return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 152 } 153 154 if len(keys.Keys) == 0 { 155 return nil, errors.New("no keys in jwks response") 156 } 157 158 // TODO: this is again bad, we should be figuring out which one we need to use... 159 b, err := json.Marshal(keys.Keys[0]) 160 if err != nil { 161 return nil, fmt.Errorf("could not marshal key: %w", err) 162 } 163 164 k, err := helpers.ParseJWKFromBytes(b) 165 if err != nil { 166 return nil, err 167 } 168 169 jwks = k 170 } 171 172 return jwks, nil 173} 174 175func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) { 176 var metadataMap map[string]any 177 if err := json.Unmarshal(b, &metadataMap); err != nil { 178 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 179 } 180 181 _, jwksOk := metadataMap["jwks"].(string) 182 _, jwksUriOk := metadataMap["jwks_uri"].(string) 183 if jwksOk && jwksUriOk { 184 return nil, errors.New("jwks_uri and jwks are mutually exclusive") 185 } 186 187 for _, k := range []string{ 188 "default_max_age", 189 "userinfo_signed_response_alg", 190 "id_token_signed_response_alg", 191 "userinfo_encryhpted_response_alg", 192 "authorization_encrypted_response_enc", 193 "authorization_encrypted_response_alg", 194 "tls_client_certificate_bound_access_tokens", 195 } { 196 _, kOk := metadataMap[k] 197 if kOk { 198 return nil, fmt.Errorf("unsupported `%s` parameter", k) 199 } 200 } 201 202 var metadata Metadata 203 if err := json.Unmarshal(b, &metadata); err != nil { 204 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 205 } 206 207 u, err := url.Parse(metadata.ClientURI) 208 if err != nil { 209 return nil, fmt.Errorf("unable to parse client uri: %w", err) 210 } 211 212 if isLocalHostname(u.Hostname()) { 213 return nil, errors.New("`client_uri` hostname is invalid") 214 } 215 216 if metadata.Scope == "" { 217 return nil, errors.New("missing `scopes` scope") 218 } 219 220 scopes := strings.Split(metadata.Scope, " ") 221 if !slices.Contains(scopes, "atproto") { 222 return nil, errors.New("missing `atproto` scope") 223 } 224 225 scopesMap := map[string]bool{} 226 for _, scope := range scopes { 227 if scopesMap[scope] { 228 return nil, fmt.Errorf("duplicate scope `%s`", scope) 229 } 230 231 // TODO: check for unsupported scopes 232 233 scopesMap[scope] = true 234 } 235 236 grantTypesMap := map[string]bool{} 237 for _, gt := range metadata.GrantTypes { 238 if grantTypesMap[gt] { 239 return nil, fmt.Errorf("duplicate grant type `%s`", gt) 240 } 241 242 switch gt { 243 case "implicit": 244 return nil, errors.New("grantg type `implicit` is not allowed") 245 case "authorization_code", "refresh_token": 246 // TODO check if this grant type is supported 247 default: 248 return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 249 } 250 251 grantTypesMap[gt] = true 252 } 253 254 if metadata.ClientID != clientId { 255 return nil, errors.New("`client_id` does not match") 256 } 257 258 subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 259 if subjectTypeOk && subjectType != "public" { 260 return nil, errors.New("only public `subject_type` is supported") 261 } 262 263 switch metadata.TokenEndpointAuthMethod { 264 case "none": 265 if metadata.TokenEndpointAuthSigningAlg != "" { 266 return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 267 } 268 case "private_key_jwt": 269 if metadata.JWKS == nil && metadata.JWKSURI == nil { 270 return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 271 } 272 273 if metadata.JWKS != nil && len(metadata.JWKS.Keys) == 0 { 274 return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 275 } 276 277 if metadata.TokenEndpointAuthSigningAlg == "" { 278 return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 279 } 280 default: 281 return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 282 } 283 284 if !metadata.DpopBoundAccessTokens { 285 return nil, errors.New("dpop_bound_access_tokens must be true") 286 } 287 288 if !slices.Contains(metadata.ResponseTypes, "code") { 289 return nil, errors.New("response_types must inclue `code`") 290 } 291 292 if !slices.Contains(metadata.GrantTypes, "authorization_code") { 293 return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 294 } 295 296 if len(metadata.RedirectURIs) == 0 { 297 return nil, errors.New("at least one `redirect_uri` is required") 298 } 299 300 if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" { 301 return nil, errors.New("native clients must authenticate using `none` method") 302 } 303 304 if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 305 for _, ruri := range metadata.RedirectURIs { 306 u, err := url.Parse(ruri) 307 if err != nil { 308 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 309 } 310 311 if u.Scheme != "https" { 312 return nil, errors.New("web clients must use https redirect uris") 313 } 314 315 if u.Hostname() == "localhost" { 316 return nil, errors.New("web clients must not use localhost as the hostname") 317 } 318 } 319 } 320 321 for _, ruri := range metadata.RedirectURIs { 322 u, err := url.Parse(ruri) 323 if err != nil { 324 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 325 } 326 327 if u.User != nil { 328 if u.User.Username() != "" { 329 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 330 } 331 332 if _, hasPass := u.User.Password(); hasPass { 333 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 334 } 335 } 336 337 switch true { 338 case u.Hostname() == "localhost": 339 return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 340 case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 341 if metadata.ApplicationType != "native" { 342 return nil, errors.New("loopback redirect uris are only allowed for native apps") 343 } 344 345 if u.Port() != "" { 346 // reference impl doesn't do anything with this? 347 } 348 349 if u.Scheme != "http" { 350 return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 351 } 352 353 break 354 case u.Scheme == "http": 355 return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 356 case u.Scheme == "https": 357 if isLocalHostname(u.Hostname()) { 358 return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 359 } 360 break 361 case strings.Contains(u.Scheme, "."): 362 if metadata.ApplicationType != "native" { 363 return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 364 } 365 366 revdomain := reverseDomain(u.Scheme) 367 368 if isLocalHostname(revdomain) { 369 return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 370 } 371 372 if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 373 return nil, fmt.Errorf("private use uri scheme must be in the form ") 374 } 375 default: 376 return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 377 } 378 } 379 380 return &metadata, nil 381} 382 383func isLocalHostname(hostname string) bool { 384 pts := strings.Split(hostname, ".") 385 if len(pts) < 2 { 386 return true 387 } 388 389 tld := strings.ToLower(pts[len(pts)-1]) 390 return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 391} 392 393func reverseDomain(domain string) string { 394 pts := strings.Split(domain, ".") 395 slices.Reverse(pts) 396 return strings.Join(pts, ".") 397}