An atproto PDS written in Go
1package provider 2 3import ( 4 "context" 5 "crypto" 6 "encoding/base64" 7 "errors" 8 "fmt" 9 "time" 10 11 "github.com/golang-jwt/jwt/v4" 12 "github.com/haileyok/cocoon/oauth/client" 13 "github.com/haileyok/cocoon/oauth/constants" 14 "github.com/haileyok/cocoon/oauth/dpop" 15) 16 17type AuthenticateClientOptions struct { 18 AllowMissingDpopProof bool 19} 20 21type AuthenticateClientRequestBase struct { 22 ClientID string `form:"client_id" json:"client_id" validate:"required"` 23 ClientAssertionType *string `form:"client_assertion_type" json:"client_assertion_type,omitempty"` 24 ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"` 25} 26 27func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) { 28 client, err := p.ClientManager.GetClient(ctx, req.ClientID) 29 if err != nil { 30 return nil, nil, fmt.Errorf("failed to get client: %w", err) 31 } 32 33 if client.Metadata.DpopBoundAccessTokens && proof == nil && (opts == nil || !opts.AllowMissingDpopProof) { 34 return nil, nil, errors.New("dpop proof required") 35 } 36 37 if proof != nil && !client.Metadata.DpopBoundAccessTokens { 38 return nil, nil, errors.New("dpop proof not allowed for this client") 39 } 40 41 clientAuth, err := p.Authenticate(ctx, req, client) 42 if err != nil { 43 return nil, nil, err 44 } 45 46 return client, clientAuth, nil 47} 48 49func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *client.Client) (*ClientAuth, error) { 50 metadata := client.Metadata 51 52 if metadata.TokenEndpointAuthMethod == "none" { 53 return &ClientAuth{ 54 Method: "none", 55 }, nil 56 } 57 58 if metadata.TokenEndpointAuthMethod == "private_key_jwt" { 59 if req.ClientAssertion == nil { 60 return nil, errors.New(`client authentication method "private_key_jwt" requires a "client_assertion`) 61 } 62 63 if req.ClientAssertionType == nil || *req.ClientAssertionType != constants.ClientAssertionTypeJwtBearer { 64 return nil, fmt.Errorf("unsupported client_assertion_type %s", *req.ClientAssertionType) 65 } 66 67 token, _, err := jwt.NewParser().ParseUnverified(*req.ClientAssertion, jwt.MapClaims{}) 68 if err != nil { 69 return nil, fmt.Errorf("error parsing client assertion: %w", err) 70 } 71 72 kid, ok := token.Header["kid"].(string) 73 if !ok || kid == "" { 74 return nil, errors.New(`"kid" required in client_assertion`) 75 } 76 77 var rawKey any 78 if err := client.JWKS.Raw(&rawKey); err != nil { 79 return nil, fmt.Errorf("failed to extract raw key: %w", err) 80 } 81 82 token, err = jwt.Parse(*req.ClientAssertion, func(token *jwt.Token) (any, error) { 83 if token.Method.Alg() != jwt.SigningMethodES256.Alg() { 84 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 85 } 86 87 return rawKey, nil 88 }) 89 if err != nil { 90 return nil, fmt.Errorf(`unable to verify "client_assertion" jwt: %w`, err) 91 } 92 93 if !token.Valid { 94 return nil, errors.New("client_assertion jwt is invalid") 95 } 96 97 claims, ok := token.Claims.(jwt.MapClaims) 98 if !ok { 99 return nil, errors.New("no claims in client_assertion jwt") 100 } 101 102 sub, _ := claims["sub"].(string) 103 if sub != metadata.ClientID { 104 return nil, errors.New("subject must be client_id") 105 } 106 107 aud, _ := claims["aud"].(string) 108 if aud != "" && aud != "https://"+p.hostname { 109 return nil, fmt.Errorf("audience must be %s, got %s", "https://"+p.hostname, aud) 110 } 111 112 iat, iatOk := claims["iat"].(float64) 113 if !iatOk { 114 return nil, errors.New(`invalid client_assertion jwt: "iat" is missing`) 115 } 116 117 iatTime := time.Unix(int64(iat), 0) 118 if time.Since(iatTime) > constants.ClientAssertionMaxAge { 119 return nil, errors.New("client_assertion jwt too old") 120 } 121 122 jti, _ := claims["jti"].(string) 123 if jti == "" { 124 return nil, errors.New(`invalid client_assertion jwt: "jti" is missing`) 125 } 126 127 var exp *float64 128 if maybeExp, ok := claims["exp"].(float64); ok { 129 exp = &maybeExp 130 } 131 132 alg := token.Header["alg"].(string) 133 134 thumbBytes, err := client.JWKS.Thumbprint(crypto.SHA256) 135 if err != nil { 136 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 137 } 138 139 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 140 141 return &ClientAuth{ 142 Method: "private_key_jwt", 143 Jti: jti, 144 Exp: exp, 145 Jkt: thumb, 146 Alg: alg, 147 Kid: kid, 148 }, nil 149 } 150 151 return nil, fmt.Errorf("auth method %s is not implemented in this pds", metadata.TokenEndpointAuthMethod) 152}