An atproto PDS written in Go

oauth: fix `none` token method auth (#36)

* fix none method auth

* error log

* fix empty client uri

* fix missing client name

Changed files
+44 -28
oauth
client
+44 -28
oauth/client/manager.go
···
cli *http.Client
logger *slog.Logger
jwksCache cache.Cache[string, jwk.Key]
-
metadataCache cache.Cache[string, Metadata]
+
metadataCache cache.Cache[string, *Metadata]
}
type ManagerArgs struct {
···
}
jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
-
metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
+
metadataCache := cache.NewCache[string, *Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
return &Manager{
cli: args.Cli,
···
}
var jwks jwk.Key
-
if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 {
-
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
-
// make sure we use the right one
-
b, err := json.Marshal(metadata.JWKS.Keys[0])
-
if err != nil {
-
return nil, err
-
}
+
if metadata.TokenEndpointAuthMethod == "private_key_jwt" {
+
if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 {
+
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
+
// make sure we use the right one
+
b, err := json.Marshal(metadata.JWKS.Keys[0])
+
if err != nil {
+
return nil, err
+
}
-
k, err := helpers.ParseJWKFromBytes(b)
-
if err != nil {
-
return nil, err
-
}
+
k, err := helpers.ParseJWKFromBytes(b)
+
if err != nil {
+
return nil, err
+
}
-
jwks = k
-
} else if metadata.JWKSURI != nil {
-
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
-
if err != nil {
-
return nil, err
-
}
+
jwks = k
+
} else if metadata.JWKS != nil {
+
} else if metadata.JWKSURI != nil {
+
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
+
if err != nil {
+
return nil, err
+
}
-
jwks = maybeJwks
-
} else {
-
return nil, fmt.Errorf("no valid jwks found in oauth client metadata")
+
jwks = maybeJwks
+
} else {
+
return nil, fmt.Errorf("no valid jwks found in oauth client metadata")
+
}
}
return &Client{
···
}
func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) {
-
metadataCached, ok := cm.metadataCache.Get(clientId)
+
cached, ok := cm.metadataCache.Get(clientId)
if !ok {
req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
if err != nil {
···
return nil, err
}
+
cm.metadataCache.Set(clientId, validated, 10*time.Minute)
+
return validated, nil
} else {
-
return &metadataCached, nil
+
return cached, nil
}
}
···
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
}
+
if metadata.ClientURI == "" {
+
u, err := url.Parse(metadata.ClientID)
+
if err != nil {
+
return nil, fmt.Errorf("unable to parse client id: %w", err)
+
}
+
u.RawPath = ""
+
u.RawQuery = ""
+
metadata.ClientURI = u.String()
+
}
+
u, err := url.Parse(metadata.ClientURI)
if err != nil {
return nil, fmt.Errorf("unable to parse client uri: %w", err)
}
+
if metadata.ClientName == "" {
+
metadata.ClientName = metadata.ClientURI
+
}
+
if isLocalHostname(u.Hostname()) {
-
return nil, errors.New("`client_uri` hostname is invalid")
+
return nil, fmt.Errorf("`client_uri` hostname is invalid: %s", u.Hostname())
}
if metadata.Scope == "" {
···
if u.Scheme != "http" {
return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
}
-
-
break
case u.Scheme == "http":
return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
case u.Scheme == "https":
if isLocalHostname(u.Hostname()) {
return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
}
-
break
case strings.Contains(u.Scheme, "."):
if metadata.ApplicationType != "native" {
return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")