1package provider
2
3import (
4 "context"
5 "crypto"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "time"
10
11 "github.com/golang-jwt/jwt/v4"
12 "github.com/haileyok/cocoon/oauth/client"
13 "github.com/haileyok/cocoon/oauth/constants"
14 "github.com/haileyok/cocoon/oauth/dpop"
15)
16
type (
	// AuthenticateClientOptions adjusts the checks performed by
	// Provider.AuthenticateClient.
	AuthenticateClientOptions struct {
		// AllowMissingDpopProof lets a client whose metadata sets
		// DpopBoundAccessTokens authenticate without presenting a DPoP proof.
		AllowMissingDpopProof bool
	}

	// AuthenticateClientRequestBase holds the client-authentication
	// parameters of a request. The struct tags allow it to be bound from
	// form-encoded or JSON input.
	AuthenticateClientRequestBase struct {
		// ClientID identifies the client; marked required for validation.
		ClientID string `form:"client_id" json:"client_id" validate:"required"`
		// ClientAssertionType must be the JWT-bearer assertion type when the
		// client uses "private_key_jwt".
		ClientAssertionType *string `form:"client_assertion_type" json:"client_assertion_type,omitempty"`
		// ClientAssertion is the signed JWT presented for "private_key_jwt".
		ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"`
	}
)
26
27func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) {
28 client, err := p.ClientManager.GetClient(ctx, req.ClientID)
29 if err != nil {
30 return nil, nil, fmt.Errorf("failed to get client: %w", err)
31 }
32
33 if client.Metadata.DpopBoundAccessTokens && proof == nil && (opts == nil || !opts.AllowMissingDpopProof) {
34 return nil, nil, errors.New("dpop proof required")
35 }
36
37 if proof != nil && !client.Metadata.DpopBoundAccessTokens {
38 return nil, nil, errors.New("dpop proof not allowed for this client")
39 }
40
41 clientAuth, err := p.Authenticate(ctx, req, client)
42 if err != nil {
43 return nil, nil, err
44 }
45
46 return client, clientAuth, nil
47}
48
49func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *client.Client) (*ClientAuth, error) {
50 metadata := client.Metadata
51
52 if metadata.TokenEndpointAuthMethod == "none" {
53 return &ClientAuth{
54 Method: "none",
55 }, nil
56 }
57
58 if metadata.TokenEndpointAuthMethod == "private_key_jwt" {
59 if req.ClientAssertion == nil {
60 return nil, errors.New(`client authentication method "private_key_jwt" requires a "client_assertion`)
61 }
62
63 if req.ClientAssertionType == nil || *req.ClientAssertionType != constants.ClientAssertionTypeJwtBearer {
64 return nil, fmt.Errorf("unsupported client_assertion_type %s", *req.ClientAssertionType)
65 }
66
67 token, _, err := jwt.NewParser().ParseUnverified(*req.ClientAssertion, jwt.MapClaims{})
68 if err != nil {
69 return nil, fmt.Errorf("error parsing client assertion: %w", err)
70 }
71
72 kid, ok := token.Header["kid"].(string)
73 if !ok || kid == "" {
74 return nil, errors.New(`"kid" required in client_assertion`)
75 }
76
77 var rawKey any
78 if err := client.JWKS.Raw(&rawKey); err != nil {
79 return nil, fmt.Errorf("failed to extract raw key: %w", err)
80 }
81
82 token, err = jwt.Parse(*req.ClientAssertion, func(token *jwt.Token) (any, error) {
83 if token.Method.Alg() != jwt.SigningMethodES256.Alg() {
84 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
85 }
86
87 return rawKey, nil
88 })
89 if err != nil {
90 return nil, fmt.Errorf(`unable to verify "client_assertion" jwt: %w`, err)
91 }
92
93 if !token.Valid {
94 return nil, errors.New("client_assertion jwt is invalid")
95 }
96
97 claims, ok := token.Claims.(jwt.MapClaims)
98 if !ok {
99 return nil, errors.New("no claims in client_assertion jwt")
100 }
101
102 sub, _ := claims["sub"].(string)
103 if sub != metadata.ClientID {
104 return nil, errors.New("subject must be client_id")
105 }
106
107 aud, _ := claims["aud"].(string)
108 if aud != "" && aud != "https://"+p.hostname {
109 return nil, fmt.Errorf("audience must be %s, got %s", "https://"+p.hostname, aud)
110 }
111
112 iat, iatOk := claims["iat"].(float64)
113 if !iatOk {
114 return nil, errors.New(`invalid client_assertion jwt: "iat" is missing`)
115 }
116
117 iatTime := time.Unix(int64(iat), 0)
118 if time.Since(iatTime) > constants.ClientAssertionMaxAge {
119 return nil, errors.New("client_assertion jwt too old")
120 }
121
122 jti, _ := claims["jti"].(string)
123 if jti == "" {
124 return nil, errors.New(`invalid client_assertion jwt: "jti" is missing`)
125 }
126
127 var exp *float64
128 if maybeExp, ok := claims["exp"].(float64); ok {
129 exp = &maybeExp
130 }
131
132 alg := token.Header["alg"].(string)
133
134 thumbBytes, err := client.JWKS.Thumbprint(crypto.SHA256)
135 if err != nil {
136 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
137 }
138
139 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
140
141 return &ClientAuth{
142 Method: "private_key_jwt",
143 Jti: jti,
144 Exp: exp,
145 Jkt: thumb,
146 Alg: alg,
147 Kid: kid,
148 }, nil
149 }
150
151 return nil, fmt.Errorf("auth method %s is not implemented in this pds", metadata.TokenEndpointAuthMethod)
152}