1package provider
2
3import (
4 "context"
5 "crypto"
6 "database/sql/driver"
7 "encoding/base64"
8 "encoding/json"
9 "errors"
10 "fmt"
11 "time"
12
13 "github.com/golang-jwt/jwt/v4"
14 "github.com/haileyok/cocoon/oauth"
15 "github.com/haileyok/cocoon/oauth/constants"
16 "github.com/haileyok/cocoon/oauth/dpop"
17)
18
// ClientAuth records how a client authenticated itself to the provider.
// It is persisted as a JSON document via the sql.Scanner / driver.Valuer
// implementations below.
type ClientAuth struct {
	// Method is the authentication method that succeeded, as set by
	// Provider.Authenticate ("none" or "private_key_jwt").
	Method string
	// Alg is the signing algorithm of the verified client assertion.
	Alg string
	// Kid is the "kid" header of the client assertion.
	Kid string
	// Jkt is the unpadded base64url SHA-256 thumbprint of the client's key.
	Jkt string
	// Jti is the "jti" claim of the client assertion.
	Jti string
	// Exp is the "exp" claim of the client assertion, or nil when absent.
	Exp *float64
}

// Scan implements sql.Scanner by decoding the JSON document produced by
// Value. Both []byte and string column values are accepted, since drivers
// differ in which of the two they return for text/JSON columns; any other
// source type (including NULL) is an error.
func (ca *ClientAuth) Scan(value any) error {
	var b []byte
	switch v := value.(type) {
	case []byte:
		b = v
	case string:
		b = []byte(v)
	default:
		return fmt.Errorf("failed to unmarshal ClientAuth value: unsupported type %T", value)
	}
	return json.Unmarshal(b, ca)
}

// Value implements driver.Valuer by encoding the ClientAuth as JSON.
func (ca ClientAuth) Value() (driver.Value, error) {
	return json.Marshal(ca)
}
39
type (
	// AuthenticateClientOptions adjusts the checks performed by
	// Provider.AuthenticateClient.
	AuthenticateClientOptions struct {
		// AllowMissingDpopProof lets a client whose metadata asks for
		// DPoP-bound access tokens authenticate without presenting a proof.
		AllowMissingDpopProof bool
	}

	// AuthenticateClientRequestBase carries the client-authentication
	// parameters shared by requests that identify a client, bindable from
	// form or JSON bodies.
	AuthenticateClientRequestBase struct {
		ClientID            string  `form:"client_id" json:"client_id" validate:"required"`
		ClientAssertionType *string `form:"client_assertion_type" json:"client_assertion_type,omitempty"`
		ClientAssertion     *string `form:"client_assertion" json:"client_assertion,omitempty"`
	}
)
49
50func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*oauth.Client, *ClientAuth, error) {
51 client, err := p.ClientManager.GetClient(ctx, req.ClientID)
52 if err != nil {
53 return nil, nil, fmt.Errorf("failed to get client: %w", err)
54 }
55
56 if client.Metadata.DpopBoundAccessTokens && proof == nil && (opts == nil || !opts.AllowMissingDpopProof) {
57 return nil, nil, errors.New("dpop proof required")
58 }
59
60 if proof != nil && !client.Metadata.DpopBoundAccessTokens {
61 return nil, nil, errors.New("dpop proof not allowed for this client")
62 }
63
64 clientAuth, err := p.Authenticate(ctx, req, client)
65 if err != nil {
66 return nil, nil, err
67 }
68
69 return client, clientAuth, nil
70}
71
72func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *oauth.Client) (*ClientAuth, error) {
73 metadata := client.Metadata
74
75 if metadata.TokenEndpointAuthMethod == "none" {
76 return &ClientAuth{
77 Method: "none",
78 }, nil
79 }
80
81 if metadata.TokenEndpointAuthMethod == "private_key_jwt" {
82 if req.ClientAssertion == nil {
83 return nil, errors.New(`client authentication method "private_key_jwt" requires a "client_assertion`)
84 }
85
86 if req.ClientAssertionType == nil || *req.ClientAssertionType != constants.ClientAssertionTypeJwtBearer {
87 return nil, fmt.Errorf("unsupported client_assertion_type %s", *req.ClientAssertionType)
88 }
89
90 token, _, err := jwt.NewParser().ParseUnverified(*req.ClientAssertion, jwt.MapClaims{})
91 if err != nil {
92 return nil, fmt.Errorf("error parsing client assertion: %w", err)
93 }
94
95 kid, ok := token.Header["kid"].(string)
96 if !ok || kid == "" {
97 return nil, errors.New(`"kid" required in client_assertion`)
98 }
99
100 var rawKey any
101 if err := client.JWKS.Raw(&rawKey); err != nil {
102 return nil, fmt.Errorf("failed to extract raw key: %w", err)
103 }
104
105 token, err = jwt.Parse(*req.ClientAssertion, func(token *jwt.Token) (any, error) {
106 if token.Method.Alg() != jwt.SigningMethodES256.Alg() {
107 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
108 }
109
110 return rawKey, nil
111 })
112 if err != nil {
113 return nil, fmt.Errorf(`unable to verify "client_assertion" jwt: %w`, err)
114 }
115
116 if !token.Valid {
117 return nil, errors.New("client_assertion jwt is invalid")
118 }
119
120 claims, ok := token.Claims.(jwt.MapClaims)
121 if !ok {
122 return nil, errors.New("no claims in client_assertion jwt")
123 }
124
125 sub, _ := claims["sub"].(string)
126 if sub != metadata.ClientID {
127 return nil, errors.New("subject must be client_id")
128 }
129
130 aud, _ := claims["aud"].(string)
131 if aud != "" && aud != "https://"+p.hostname {
132 return nil, fmt.Errorf("audience must be %s, got %s", "https://"+p.hostname, aud)
133 }
134
135 iat, iatOk := claims["iat"].(float64)
136 if !iatOk {
137 return nil, errors.New(`invalid client_assertion jwt: "iat" is missing`)
138 }
139
140 iatTime := time.Unix(int64(iat), 0)
141 if time.Since(iatTime) > constants.ClientAssertionMaxAge {
142 return nil, errors.New("client_assertion jwt too old")
143 }
144
145 jti, _ := claims["jti"].(string)
146 if jti == "" {
147 return nil, errors.New(`invalid client_assertion jwt: "jti" is missing`)
148 }
149
150 var exp *float64
151 if maybeExp, ok := claims["exp"].(float64); ok {
152 exp = &maybeExp
153 }
154
155 alg := token.Header["alg"].(string)
156
157 thumbBytes, err := client.JWKS.Thumbprint(crypto.SHA256)
158 if err != nil {
159 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
160 }
161
162 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
163
164 return &ClientAuth{
165 Method: "private_key_jwt",
166 Jti: jti,
167 Exp: exp,
168 Jkt: thumb,
169 Alg: alg,
170 Kid: kid,
171 }, nil
172 }
173
174 return nil, fmt.Errorf("auth method %s is not implemented in this pds", metadata.TokenEndpointAuthMethod)
175}