this repo has no description
1package oauth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/rand"
7 "crypto/sha256"
8 "encoding/base64"
9 "encoding/hex"
10 "encoding/json"
11 "fmt"
12 "io"
13 "net/http"
14 "net/url"
15 "strings"
16 "time"
17
18 "github.com/golang-jwt/jwt/v5"
19 "github.com/google/uuid"
20 "github.com/lestrrat-go/jwx/v2/jwk"
21)
22
23type Client struct {
24 h *http.Client
25 clientPrivateKey *ecdsa.PrivateKey
26 clientKid string
27 clientId string
28 redirectUri string
29}
30
31type ClientArgs struct {
32 H *http.Client
33 ClientJwk jwk.Key
34 ClientId string
35 RedirectUri string
36}
37
38func NewClient(args ClientArgs) (*Client, error) {
39 if args.ClientId == "" {
40 return nil, fmt.Errorf("no client id provided")
41 }
42
43 if args.RedirectUri == "" {
44 return nil, fmt.Errorf("no redirect uri provided")
45 }
46
47 if args.H == nil {
48 args.H = &http.Client{
49 Timeout: 5 * time.Second,
50 }
51 }
52
53 clientPkey, err := getPrivateKey(args.ClientJwk)
54 if err != nil {
55 return nil, fmt.Errorf("could not load private key from provided client jwk: %w", err)
56 }
57
58 kid := args.ClientJwk.KeyID()
59
60 return &Client{
61 h: args.H,
62 clientKid: kid,
63 clientPrivateKey: clientPkey,
64 clientId: args.ClientId,
65 redirectUri: args.RedirectUri,
66 }, nil
67}
68
69func (c *Client) ResolvePDSAuthServer(ctx context.Context, ustr string) (string, error) {
70 u, err := isSafeAndParsed(ustr)
71 if err != nil {
72 return "", err
73 }
74
75 u.Path = "/.well-known/oauth-protected-resource"
76
77 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
78 if err != nil {
79 return "", fmt.Errorf("error creating request for oauth protected resource: %w", err)
80 }
81
82 resp, err := c.h.Do(req)
83 if err != nil {
84 return "", fmt.Errorf("could not get response from server: %w", err)
85 }
86 defer resp.Body.Close()
87
88 if resp.StatusCode != http.StatusOK {
89 io.Copy(io.Discard, resp.Body)
90 return "", fmt.Errorf("received non-200 response from pds. code was %d", resp.StatusCode)
91 }
92
93 b, err := io.ReadAll(resp.Body)
94 if err != nil {
95 return "", fmt.Errorf("could not read body: %w", err)
96 }
97
98 var resource OauthProtectedResource
99 if err := resource.UnmarshalJSON(b); err != nil {
100 return "", fmt.Errorf("could not unmarshal json: %w", err)
101 }
102
103 if len(resource.AuthorizationServers) == 0 {
104 return "", fmt.Errorf("oauth protected resource contained no authorization servers")
105 }
106
107 return resource.AuthorizationServers[0], nil
108}
109
110func (c *Client) FetchAuthServerMetadata(ctx context.Context, ustr string) (*OauthAuthorizationMetadata, error) {
111 u, err := isSafeAndParsed(ustr)
112 if err != nil {
113 return nil, err
114 }
115
116 u.Path = "/.well-known/oauth-authorization-server"
117
118 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
119 if err != nil {
120 return nil, fmt.Errorf("error creating request to fetch auth metadata: %w", err)
121 }
122
123 resp, err := c.h.Do(req)
124 if err != nil {
125 return nil, fmt.Errorf("error getting response for auth metadata: %w", err)
126 }
127 defer resp.Body.Close()
128
129 if resp.StatusCode != http.StatusOK {
130 io.Copy(io.Discard, resp.Body)
131 return nil, fmt.Errorf(
132 "received non-200 response from pds. status code was %d",
133 resp.StatusCode,
134 )
135 }
136
137 b, err := io.ReadAll(resp.Body)
138 if err != nil {
139 return nil, fmt.Errorf("could not read body for metadata response: %w", err)
140 }
141
142 var metadata OauthAuthorizationMetadata
143 if err := metadata.UnmarshalJSON(b); err != nil {
144 return nil, fmt.Errorf("could not unmarshal metadata: %w", err)
145 }
146
147 if err := metadata.Validate(u); err != nil {
148 return nil, fmt.Errorf("could not validate metadata: %w", err)
149 }
150
151 return &metadata, nil
152}
153
154func (c *Client) ClientAssertionJwt(authServerUrl string) (string, error) {
155 claims := jwt.MapClaims{
156 "iss": c.clientId,
157 "sub": c.clientId,
158 "aud": authServerUrl,
159 "jti": uuid.NewString(),
160 "iat": time.Now().Unix(),
161 }
162
163 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
164 token.Header["kid"] = c.clientKid
165
166 tokenString, err := token.SignedString(c.clientPrivateKey)
167 if err != nil {
168 return "", err
169 }
170
171 return tokenString, nil
172}
173
174func (c *Client) AuthServerDpopJwt(method, url, nonce string, privateJwk jwk.Key) (string, error) {
175 pubJwk, err := privateJwk.PublicKey()
176 if err != nil {
177 return "", err
178 }
179
180 b, err := json.Marshal(pubJwk)
181 if err != nil {
182 return "", err
183 }
184
185 var pubMap map[string]any
186 if err := json.Unmarshal(b, &pubMap); err != nil {
187 return "", err
188 }
189
190 now := time.Now().Unix()
191
192 claims := jwt.MapClaims{
193 "jti": uuid.NewString(),
194 "htm": method,
195 "htu": url,
196 "iat": now,
197 "exp": now + 30,
198 }
199
200 if nonce != "" {
201 claims["nonce"] = nonce
202 }
203
204 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
205 token.Header["typ"] = "dpop+jwt"
206 token.Header["alg"] = "ES256"
207 token.Header["jwk"] = pubMap
208
209 var rawKey any
210 if err := privateJwk.Raw(&rawKey); err != nil {
211 return "", err
212 }
213
214 tokenString, err := token.SignedString(rawKey)
215 if err != nil {
216 return "", fmt.Errorf("failed to sign token: %w", err)
217 }
218
219 return tokenString, nil
220}
221
222func (c *Client) SendParAuthRequest(ctx context.Context, authServerUrl string, authServerMeta *OauthAuthorizationMetadata, loginHint, scope string, dpopPrivateKey jwk.Key) (*SendParAuthResponse, error) {
223 if authServerMeta == nil {
224 return nil, fmt.Errorf("nil metadata provided")
225 }
226
227 parUrl := authServerMeta.PushedAuthorizationRequestEndpoint
228
229 state, err := generateToken(10)
230 if err != nil {
231 return nil, fmt.Errorf("could not generate state token: %w", err)
232 }
233
234 pkceVerifier, err := generateToken(48)
235 if err != nil {
236 return nil, fmt.Errorf("could not generate pkce verifier: %w", err)
237 }
238
239 codeChallenge := generateCodeChallenge(pkceVerifier)
240 codeChallengeMethod := "S256"
241
242 clientAssertion, err := c.ClientAssertionJwt(authServerUrl)
243 if err != nil {
244 return nil, err
245 }
246
247 dpopAuthserverNonce := ""
248 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey)
249 if err != nil {
250 return nil, fmt.Errorf("error getting dpop proof: %w", err)
251 }
252
253 params := url.Values{
254 "response_type": {"code"},
255 "code_challenge": {codeChallenge},
256 "code_challenge_method": {codeChallengeMethod},
257 "client_id": {c.clientId},
258 "state": {state},
259 "redirect_uri": {c.redirectUri},
260 "scope": {scope},
261 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
262 "client_assertion": {clientAssertion},
263 }
264
265 if loginHint != "" {
266 params.Set("login_hint", loginHint)
267 }
268
269 _, err = isSafeAndParsed(parUrl)
270 if err != nil {
271 return nil, err
272 }
273
274 req, err := http.NewRequestWithContext(ctx, "POST", parUrl, strings.NewReader(params.Encode()))
275 if err != nil {
276 return nil, err
277 }
278
279 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
280 req.Header.Set("DPoP", dpopProof)
281
282 resp, err := c.h.Do(req)
283 if err != nil {
284 return nil, err
285 }
286 defer resp.Body.Close()
287
288 var rmap map[string]any
289 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil {
290 return nil, err
291 }
292
293 if resp.StatusCode == 400 && rmap["error"] == "use_dpop_nonce" {
294 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce")
295 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey)
296 if err != nil {
297 return nil, err
298 }
299
300 req2, err := http.NewRequestWithContext(
301 ctx,
302 "POST",
303 parUrl,
304 strings.NewReader(params.Encode()),
305 )
306 if err != nil {
307 return nil, err
308 }
309
310 req2.Header.Set("Content-Type", "application/x-www-form-urlencoded")
311 req2.Header.Set("DPoP", dpopProof)
312
313 resp2, err := c.h.Do(req2)
314 if err != nil {
315 return nil, err
316 }
317 defer resp2.Body.Close()
318
319 rmap = map[string]any{}
320 if err := json.NewDecoder(resp2.Body).Decode(&rmap); err != nil {
321 return nil, err
322 }
323 }
324
325 return &SendParAuthResponse{
326 PkceVerifier: pkceVerifier,
327 State: state,
328 DpopAuthserverNonce: dpopAuthserverNonce,
329 Resp: rmap,
330 }, nil
331}
332
333func (c *Client) InitialTokenRequest(
334 ctx context.Context,
335 code,
336 authserverIss,
337 pkceVerifier,
338 dpopAuthserverNonce string,
339 dpopPrivateJwk jwk.Key,
340) (*TokenResponse, error) {
341 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverIss)
342 if err != nil {
343 return nil, err
344 }
345
346 clientAssertion, err := c.ClientAssertionJwt(authserverIss)
347 if err != nil {
348 return nil, err
349 }
350
351 params := url.Values{
352 "client_id": {c.clientId},
353 "redirect_uri": {c.redirectUri},
354 "grant_type": {"authorization_code"},
355 "code": {code},
356 "code_verifier": {pkceVerifier},
357 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
358 "client_assertion": {clientAssertion},
359 }
360
361 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, dpopAuthserverNonce, dpopPrivateJwk)
362 if err != nil {
363 return nil, err
364 }
365
366 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode()))
367 if err != nil {
368 return nil, err
369 }
370
371 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
372 req.Header.Set("DPoP", dpopProof)
373
374 resp, err := c.h.Do(req)
375 if err != nil {
376 return nil, err
377 }
378 defer resp.Body.Close()
379
380 var tokenResponse TokenResponse
381 if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
382 return nil, err
383 }
384
385 tokenResponse.DpopAuthserverNonce = dpopAuthserverNonce
386
387 return &tokenResponse, nil
388}
389
390func (c *Client) RefreshTokenRequest(
391 ctx context.Context,
392 refreshToken,
393 authserverIss,
394 dpopAuthserverNonce string,
395 dpopPrivateJwk jwk.Key,
396) (*TokenResponse, error) {
397 // we may need to update the dpop nonce
398 for range 2 {
399 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverIss)
400 if err != nil {
401 return nil, err
402 }
403
404 clientAssertion, err := c.ClientAssertionJwt(authserverIss)
405 if err != nil {
406 return nil, err
407 }
408
409 params := url.Values{
410 "client_id": {c.clientId},
411 "grant_type": {"refresh_token"},
412 "refresh_token": {refreshToken},
413 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
414 "client_assertion": {clientAssertion},
415 }
416
417 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, dpopAuthserverNonce, dpopPrivateJwk)
418 if err != nil {
419 return nil, err
420 }
421
422 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode()))
423 if err != nil {
424 return nil, err
425 }
426
427 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
428 req.Header.Set("DPoP", dpopProof)
429
430 resp, err := c.h.Do(req)
431 if err != nil {
432 return nil, err
433 }
434 defer resp.Body.Close()
435
436 if resp.StatusCode != 200 && resp.StatusCode != 201 {
437 var respMap map[string]string
438 if err := json.NewDecoder(resp.Body).Decode(&respMap); err != nil {
439 return nil, err
440 }
441
442 if resp.StatusCode == 400 && respMap["error"] == "use_dpop_nonce" {
443 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce")
444 continue
445 }
446
447 return nil, fmt.Errorf("token refresh error: %s", respMap["error"])
448 }
449
450 var tokenResponse TokenResponse
451 if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
452 return nil, err
453 }
454
455 // set the nonce so that updates are reflected in response
456 tokenResponse.DpopAuthserverNonce = dpopAuthserverNonce
457
458 return &tokenResponse, nil
459 }
460
461 return nil, nil
462}
463
// generateToken returns n bytes from crypto/rand encoded as a lowercase hex
// string, so the result is 2*n characters long. It returns an error if the
// random source fails.
//
// The parameter is named n rather than len to avoid shadowing the builtin.
func generateToken(n int) (string, error) {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}

	return hex.EncodeToString(b), nil
}
472
473func generateCodeChallenge(pkceVerifier string) string {
474 h := sha256.New()
475 h.Write([]byte(pkceVerifier))
476 hash := h.Sum(nil)
477 return base64.RawURLEncoding.EncodeToString(hash)
478}
479
480func parsePrivateJwkFromString(str string) (jwk.Key, error) {
481 return jwk.ParseKey([]byte(str))
482}