this repo has no description
1package oauth
2
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/lestrrat-go/jwx/v2/jwk"
)
21
// OauthClient bundles what this package needs to talk to OAuth servers on a
// client's behalf: the HTTP client, the client's signing key and key id, its
// client id, and its redirect URI. Its fields are unexported; build one with
// NewOauthClient.
type OauthClient struct {
	// h performs every outbound HTTP request made by the client's methods.
	h *http.Client
	// clientPrivateKey signs the client assertion JWTs.
	clientPrivateKey *ecdsa.PrivateKey
	// clientKid is the key id read from the client JWK.
	clientKid string
	// clientId is sent as client_id and used as iss/sub of client assertions.
	clientId string
	// redirectUri is sent as redirect_uri in authorization requests.
	redirectUri string
}
29
// OauthClientArgs carries the inputs to NewOauthClient.
type OauthClientArgs struct {
	// H is the HTTP client to use. When nil, NewOauthClient substitutes a
	// client with a 5 second timeout.
	H *http.Client
	// ClientJwk is the serialized JWK holding the client's private signing key.
	ClientJwk []byte
	// ClientId is the OAuth client id. Required.
	ClientId string
	// RedirectUri is sent as redirect_uri in authorization requests. Required.
	RedirectUri string
}
36
37func NewOauthClient(args OauthClientArgs) (*OauthClient, error) {
38 if args.ClientId == "" {
39 return nil, fmt.Errorf("no client id provided")
40 }
41
42 if args.RedirectUri == "" {
43 return nil, fmt.Errorf("no redirect uri provided")
44 }
45
46 if args.H == nil {
47 args.H = &http.Client{
48 Timeout: 5 * time.Second,
49 }
50 }
51
52 clientJwk, err := jwk.ParseKey(args.ClientJwk)
53 if err != nil {
54 return nil, err
55 }
56
57 clientPkey, err := getPrivateKey(clientJwk)
58 if err != nil {
59 return nil, fmt.Errorf("could not load private key from provided client jwk: %w", err)
60 }
61
62 kid := clientJwk.KeyID()
63
64 return &OauthClient{
65 h: args.H,
66 clientKid: kid,
67 clientPrivateKey: clientPkey,
68 clientId: args.ClientId,
69 redirectUri: args.RedirectUri,
70 }, nil
71}
72
73func (o *OauthClient) ResolvePDSAuthServer(ctx context.Context, ustr string) (string, error) {
74 u, err := isSafeAndParsed(ustr)
75 if err != nil {
76 return "", err
77 }
78
79 u.Path = "/.well-known/oauth-protected-resource"
80
81 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
82 if err != nil {
83 return "", fmt.Errorf("error creating request for oauth protected resource: %w", err)
84 }
85
86 resp, err := o.h.Do(req)
87 if err != nil {
88 return "", fmt.Errorf("could not get response from server: %w", err)
89 }
90 defer resp.Body.Close()
91
92 if resp.StatusCode != http.StatusOK {
93 io.Copy(io.Discard, resp.Body)
94 return "", fmt.Errorf("received non-200 response from pds. code was %d", resp.StatusCode)
95 }
96
97 b, err := io.ReadAll(resp.Body)
98 if err != nil {
99 return "", fmt.Errorf("could not read body: %w", err)
100 }
101
102 var resource OauthProtectedResource
103 if err := resource.UnmarshalJSON(b); err != nil {
104 return "", fmt.Errorf("could not unmarshal json: %w", err)
105 }
106
107 if len(resource.AuthorizationServers) == 0 {
108 return "", fmt.Errorf("oauth protected resource contained no authorization servers")
109 }
110
111 return resource.AuthorizationServers[0], nil
112}
113
114func (o *OauthClient) FetchAuthServerMetadata(ctx context.Context, ustr string) (any, error) {
115 u, err := isSafeAndParsed(ustr)
116 if err != nil {
117 return nil, err
118 }
119
120 u.Path = "/.well-known/oauth-authorization-server"
121
122 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
123 if err != nil {
124 return nil, fmt.Errorf("error creating request to fetch auth metadata: %w", err)
125 }
126
127 resp, err := o.h.Do(req)
128 if err != nil {
129 return nil, fmt.Errorf("error getting response for auth metadata: %w", err)
130 }
131 defer resp.Body.Close()
132
133 if resp.StatusCode != http.StatusOK {
134 io.Copy(io.Discard, resp.Body)
135 return nil, fmt.Errorf("received non-200 response from pds. status code was %d", resp.StatusCode)
136 }
137
138 b, err := io.ReadAll(resp.Body)
139 if err != nil {
140 return nil, fmt.Errorf("could not read body for metadata response: %w", err)
141 }
142
143 var metadata OauthAuthorizationMetadata
144 if err := metadata.UnmarshalJSON(b); err != nil {
145 return nil, fmt.Errorf("could not unmarshal metadata: %w", err)
146 }
147
148 if err := metadata.Validate(u); err != nil {
149 return nil, fmt.Errorf("could not validate metadata: %w", err)
150 }
151
152 return metadata, nil
153}
154
155func (o *OauthClient) ClientAssertionJwt(authServerUrl string) (string, error) {
156 claims := jwt.MapClaims{
157 "iss": o.clientId,
158 "sub": o.clientId,
159 "aud": authServerUrl,
160 "jti": uuid.NewString(),
161 "iat": time.Now().Unix(),
162 }
163
164 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
165 token.Header["kid"] = o.clientKid
166
167 tokenString, err := token.SignedString(o.clientPrivateKey)
168 if err != nil {
169 return "", err
170 }
171
172 return tokenString, nil
173}
174
175func (o *OauthClient) AuthServerDpopJwt(method, url, nonce string, privateJwk jwk.Key) (string, error) {
176 raw, err := jwk.PublicKeyOf(privateJwk)
177 if err != nil {
178 return "", err
179 }
180
181 pubJwk, err := jwk.FromRaw(raw)
182 if err != nil {
183 return "", err
184 }
185
186 b, err := json.Marshal(pubJwk)
187 if err != nil {
188 return "", err
189 }
190
191 var pubMap map[string]interface{}
192 if err := json.Unmarshal(b, &pubMap); err != nil {
193 return "", err
194 }
195
196 now := time.Now().Unix()
197
198 claims := jwt.MapClaims{
199 "jti": uuid.NewString(),
200 "htm": method,
201 "htu": url,
202 "iat": now,
203 "exp": now + 30,
204 }
205
206 if nonce != "" {
207 claims["nonce"] = nonce
208 }
209
210 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
211 token.Header["typ"] = "dpop+jwt"
212 token.Header["alg"] = "ES256"
213 token.Header["jwk"] = pubMap
214
215 var rawKey interface{}
216 if err := privateJwk.Raw(&rawKey); err != nil {
217 return "", err
218 }
219
220 tokenString, err := token.SignedString(rawKey)
221 if err != nil {
222 return "", fmt.Errorf("failed to sign token: %w", err)
223 }
224
225 return tokenString, nil
226}
227
228func (o *OauthClient) SendParAuthRequest(ctx context.Context, authServerUrl string, authServerMeta *OauthAuthorizationMetadata, loginHint, scope string, dpopPrivateKey jwk.Key) (any, error) {
229 if authServerMeta == nil {
230 return nil, fmt.Errorf("nil metadata provided")
231 }
232
233 parUrl := authServerMeta.PushedAuthorizationRequestEndpoint
234
235 state, err := generateToken(10)
236 if err != nil {
237 return nil, fmt.Errorf("could not generate state token: %w", err)
238 }
239
240 pkceVerifier, err := generateToken(48)
241 if err != nil {
242 return nil, fmt.Errorf("could not generate pkce verifier: %w", err)
243 }
244
245 codeChallenge := generateCodeChallenge(pkceVerifier)
246 codeChallengeMethod := "S256"
247
248 clientAssertion, err := o.ClientAssertionJwt(authServerUrl)
249 if err != nil {
250 return nil, err
251 }
252
253 // TODO: ??
254 nonce := ""
255 dpopProof, err := o.AuthServerDpopJwt("POST", parUrl, nonce, dpopPrivateKey)
256 if err != nil {
257 return nil, err
258 }
259
260 parBody := map[string]string{
261 "response_type": "code",
262 "code_challenge": codeChallenge,
263 "code_challenge_method": codeChallengeMethod,
264 "client_id": o.clientId,
265 "state": state,
266 "redirect_uri": o.redirectUri,
267 "scope": scope,
268 "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
269 "client_assertion": clientAssertion,
270 }
271
272 if loginHint != "" {
273 parBody["login_hint"] = loginHint
274 }
275
276 _, err = isSafeAndParsed(parUrl)
277 if err != nil {
278 return nil, err
279 }
280
281 b, err := json.Marshal(parBody)
282 if err != nil {
283 return nil, err
284 }
285
286 req, err := http.NewRequestWithContext(ctx, "POST", parUrl, bytes.NewReader(b))
287 if err != nil {
288 return nil, err
289 }
290
291 req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
292 req.Header.Add("DPoP", dpopProof)
293
294 return nil, nil
295}
296
// generateToken returns the hex encoding of n cryptographically random
// bytes, so the result is 2*n characters long (empty for n == 0). It returns
// an error if n is negative or the system random source fails.
func generateToken(n int) (string, error) {
	// make would panic on a negative length; report it as an error instead.
	if n < 0 {
		return "", fmt.Errorf("token length must be non-negative, got %d", n)
	}

	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("could not read random bytes: %w", err)
	}

	return hex.EncodeToString(b), nil
}
305
// generateCodeChallenge derives the PKCE "S256" code challenge for
// pkceVerifier: the SHA-256 digest of the verifier's bytes, encoded as
// unpadded base64url.
func generateCodeChallenge(pkceVerifier string) string {
	digest := sha256.Sum256([]byte(pkceVerifier))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}