this repo has no description
1package oauth
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/rand"
7 "crypto/sha256"
8 "encoding/base64"
9 "encoding/hex"
10 "encoding/json"
11 "fmt"
12 "io"
13 "net/http"
14 "net/url"
15 "strings"
16 "time"
17
18 "github.com/golang-jwt/jwt/v5"
19 "github.com/google/uuid"
20 "github.com/lestrrat-go/jwx/v2/jwk"
21)
22
23type OauthClient struct {
24 h *http.Client
25 clientPrivateKey *ecdsa.PrivateKey
26 clientKid string
27 clientId string
28 redirectUri string
29}
30
// OauthClientArgs is the configuration accepted by NewOauthClient.
type OauthClientArgs struct {
	// H is the HTTP client used for all requests. When nil, NewOauthClient
	// substitutes a client with a 5 second timeout.
	H *http.Client

	// ClientJwk is the serialized JWK from which the client's private signing
	// key and key ID are loaded.
	ClientJwk []byte

	// ClientId and RedirectUri must both be non-empty.
	ClientId, RedirectUri string
}
37
38func NewOauthClient(args OauthClientArgs) (*OauthClient, error) {
39 if args.ClientId == "" {
40 return nil, fmt.Errorf("no client id provided")
41 }
42
43 if args.RedirectUri == "" {
44 return nil, fmt.Errorf("no redirect uri provided")
45 }
46
47 if args.H == nil {
48 args.H = &http.Client{
49 Timeout: 5 * time.Second,
50 }
51 }
52
53 clientJwk, err := jwk.ParseKey(args.ClientJwk)
54 if err != nil {
55 return nil, err
56 }
57
58 clientPkey, err := getPrivateKey(clientJwk)
59 if err != nil {
60 return nil, fmt.Errorf("could not load private key from provided client jwk: %w", err)
61 }
62
63 kid := clientJwk.KeyID()
64
65 return &OauthClient{
66 h: args.H,
67 clientKid: kid,
68 clientPrivateKey: clientPkey,
69 clientId: args.ClientId,
70 redirectUri: args.RedirectUri,
71 }, nil
72}
73
74func (c *OauthClient) ResolvePDSAuthServer(ctx context.Context, ustr string) (string, error) {
75 u, err := isSafeAndParsed(ustr)
76 if err != nil {
77 return "", err
78 }
79
80 u.Path = "/.well-known/oauth-protected-resource"
81
82 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
83 if err != nil {
84 return "", fmt.Errorf("error creating request for oauth protected resource: %w", err)
85 }
86
87 resp, err := c.h.Do(req)
88 if err != nil {
89 return "", fmt.Errorf("could not get response from server: %w", err)
90 }
91 defer resp.Body.Close()
92
93 if resp.StatusCode != http.StatusOK {
94 io.Copy(io.Discard, resp.Body)
95 return "", fmt.Errorf("received non-200 response from pds. code was %d", resp.StatusCode)
96 }
97
98 b, err := io.ReadAll(resp.Body)
99 if err != nil {
100 return "", fmt.Errorf("could not read body: %w", err)
101 }
102
103 var resource OauthProtectedResource
104 if err := resource.UnmarshalJSON(b); err != nil {
105 return "", fmt.Errorf("could not unmarshal json: %w", err)
106 }
107
108 if len(resource.AuthorizationServers) == 0 {
109 return "", fmt.Errorf("oauth protected resource contained no authorization servers")
110 }
111
112 return resource.AuthorizationServers[0], nil
113}
114
115func (c *OauthClient) FetchAuthServerMetadata(ctx context.Context, ustr string) (*OauthAuthorizationMetadata, error) {
116 u, err := isSafeAndParsed(ustr)
117 if err != nil {
118 return nil, err
119 }
120
121 u.Path = "/.well-known/oauth-authorization-server"
122
123 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
124 if err != nil {
125 return nil, fmt.Errorf("error creating request to fetch auth metadata: %w", err)
126 }
127
128 resp, err := c.h.Do(req)
129 if err != nil {
130 return nil, fmt.Errorf("error getting response for auth metadata: %w", err)
131 }
132 defer resp.Body.Close()
133
134 if resp.StatusCode != http.StatusOK {
135 io.Copy(io.Discard, resp.Body)
136 return nil, fmt.Errorf("received non-200 response from pds. status code was %d", resp.StatusCode)
137 }
138
139 b, err := io.ReadAll(resp.Body)
140 if err != nil {
141 return nil, fmt.Errorf("could not read body for metadata response: %w", err)
142 }
143
144 var metadata OauthAuthorizationMetadata
145 if err := metadata.UnmarshalJSON(b); err != nil {
146 return nil, fmt.Errorf("could not unmarshal metadata: %w", err)
147 }
148
149 if err := metadata.Validate(u); err != nil {
150 return nil, fmt.Errorf("could not validate metadata: %w", err)
151 }
152
153 return &metadata, nil
154}
155
156func (c *OauthClient) ClientAssertionJwt(authServerUrl string) (string, error) {
157 claims := jwt.MapClaims{
158 "iss": c.clientId,
159 "sub": c.clientId,
160 "aud": authServerUrl,
161 "jti": uuid.NewString(),
162 "iat": time.Now().Unix(),
163 }
164
165 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
166 token.Header["kid"] = c.clientKid
167
168 tokenString, err := token.SignedString(c.clientPrivateKey)
169 if err != nil {
170 return "", err
171 }
172
173 return tokenString, nil
174}
175
176func (c *OauthClient) AuthServerDpopJwt(method, url, nonce string, privateJwk jwk.Key) (string, error) {
177 raw, err := jwk.PublicKeyOf(privateJwk)
178 if err != nil {
179 return "", err
180 }
181
182 pubJwk, err := jwk.FromRaw(raw)
183 if err != nil {
184 return "", err
185 }
186
187 b, err := json.Marshal(pubJwk)
188 if err != nil {
189 return "", err
190 }
191
192 var pubMap map[string]interface{}
193 if err := json.Unmarshal(b, &pubMap); err != nil {
194 return "", err
195 }
196
197 now := time.Now().Unix()
198
199 claims := jwt.MapClaims{
200 "jti": uuid.NewString(),
201 "htm": method,
202 "htu": url,
203 "iat": now,
204 "exp": now + 30,
205 }
206
207 if nonce != "" {
208 claims["nonce"] = nonce
209 }
210
211 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
212 token.Header["typ"] = "dpop+jwt"
213 token.Header["alg"] = "ES256"
214 token.Header["jwk"] = pubMap
215
216 var rawKey interface{}
217 if err := privateJwk.Raw(&rawKey); err != nil {
218 return "", err
219 }
220
221 tokenString, err := token.SignedString(rawKey)
222 if err != nil {
223 return "", fmt.Errorf("failed to sign token: %w", err)
224 }
225
226 return tokenString, nil
227}
228
// SendParAuthResponse is what SendParAuthRequest hands back to the caller.
type SendParAuthResponse struct {
	// PkceVerifier is the PKCE verifier generated for this request (later sent
	// as code_verifier). State is the random state value that was sent.
	// DpopAuthserverNonce is the authorization server's DPoP nonce and may be
	// empty.
	PkceVerifier, State, DpopAuthserverNonce string

	// Resp is the decoded JSON body of the PAR response.
	Resp map[string]string
}
235
236func (c *OauthClient) SendParAuthRequest(ctx context.Context, authServerUrl string, authServerMeta *OauthAuthorizationMetadata, loginHint, scope string, dpopPrivateKey jwk.Key) (*SendParAuthResponse, error) {
237 if authServerMeta == nil {
238 return nil, fmt.Errorf("nil metadata provided")
239 }
240
241 parUrl := authServerMeta.PushedAuthorizationRequestEndpoint
242
243 state, err := generateToken(10)
244 if err != nil {
245 return nil, fmt.Errorf("could not generate state token: %w", err)
246 }
247
248 pkceVerifier, err := generateToken(48)
249 if err != nil {
250 return nil, fmt.Errorf("could not generate pkce verifier: %w", err)
251 }
252
253 codeChallenge := generateCodeChallenge(pkceVerifier)
254 codeChallengeMethod := "S256"
255
256 clientAssertion, err := c.ClientAssertionJwt(authServerUrl)
257 if err != nil {
258 return nil, err
259 }
260
261 // TODO: ??
262 nonce := ""
263 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, nonce, dpopPrivateKey)
264 if err != nil {
265 return nil, err
266 }
267
268 params := url.Values{
269 "response_type": {"code"},
270 "code_challenge": {codeChallenge},
271 "code_challenge_method": {codeChallengeMethod},
272 "client_id": {c.clientId},
273 "state": {state},
274 "redirect_uri": {c.redirectUri},
275 "scope": {scope},
276 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
277 "client_assertion": {clientAssertion},
278 }
279
280 if loginHint != "" {
281 params.Set("login_hint", loginHint)
282 }
283
284 _, err = isSafeAndParsed(parUrl)
285 if err != nil {
286 return nil, err
287 }
288
289 req, err := http.NewRequestWithContext(ctx, "POST", parUrl, strings.NewReader(params.Encode()))
290 if err != nil {
291 return nil, err
292 }
293
294 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
295 req.Header.Set("DPoP", dpopProof)
296
297 resp, err := c.h.Do(req)
298 if err != nil {
299 return nil, err
300 }
301 defer resp.Body.Close()
302
303 var rmap map[string]string
304 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil {
305 return nil, err
306 }
307
308 // TODO: there's some logic in the flask example where we retry if the server
309 // asks us to use a dpop nonce. we should add that here eventually, but for now
310 // we'll skip that
311
312 return &SendParAuthResponse{
313 PkceVerifier: pkceVerifier,
314 State: state,
315 DpopAuthserverNonce: "", // add here later
316 Resp: rmap,
317 }, nil
318}
319
// TokenResponse is the result of a token endpoint request.
type TokenResponse struct {
	// DpopAuthserverNonce is the authorization server DPoP nonce to use in
	// subsequent proofs; it may be empty.
	DpopAuthserverNonce string
	// Resp is the decoded JSON body returned by the token endpoint.
	Resp map[string]string
}
324
325func (c *OauthClient) InitialTokenRequest(ctx context.Context, authRequest map[string]string, code, appUrl string) (*TokenResponse, error) {
326 authserverUrl := authRequest["authserver_iss"]
327 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverUrl)
328 if err != nil {
329 return nil, err
330 }
331
332 clientAssertion, err := c.ClientAssertionJwt(authserverUrl)
333 if err != nil {
334 return nil, err
335 }
336
337 params := url.Values{
338 "client_id": {c.clientId},
339 "redirect_uri": {c.redirectUri},
340 "grant_type": {"authorization_code"},
341 "code": {code},
342 "code_verifier": {authRequest["pkce_verifier"]},
343 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
344 "client_assertion": {clientAssertion},
345 }
346
347 dpopPrivateJwk, err := parsePrivateJwkFromString(authRequest["dpop_private_jwk"])
348 if err != nil {
349 return nil, err
350 }
351
352 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, authRequest["dpop_authserver_nonce"], dpopPrivateJwk)
353 if err != nil {
354 return nil, err
355 }
356
357 dpopAuthserverNonce := authRequest["dpop_authserver_nonce"]
358
359 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode()))
360 if err != nil {
361 return nil, err
362 }
363
364 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
365 req.Header.Set("DPoP", dpopProof)
366
367 resp, err := c.h.Do(req)
368 if err != nil {
369 return nil, err
370 }
371 defer resp.Body.Close()
372
373 // TODO: use nonce if needed, same as in par
374
375 var rmap map[string]string
376 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil {
377 return nil, err
378 }
379
380 return &TokenResponse{
381 DpopAuthserverNonce: dpopAuthserverNonce,
382 Resp: rmap,
383 }, nil
384}
385
// RefreshTokenArgs carries the inputs of RefreshTokenRequest: the
// authorization server URL, the refresh token to exchange, the serialized
// DPoP private JWK, and the last known authorization server DPoP nonce (which
// may be empty).
type RefreshTokenArgs struct {
	AuthserverUrl, RefreshToken, DpopPrivateJwk, DpopAuthserverNonce string
}
392
393func (c *OauthClient) RefreshTokenRequest(ctx context.Context, args RefreshTokenArgs, appUrl string) (any, error) {
394 authserverMeta, err := c.FetchAuthServerMetadata(ctx, args.AuthserverUrl)
395 if err != nil {
396 return nil, err
397 }
398
399 clientAssertion, err := c.ClientAssertionJwt(args.AuthserverUrl)
400 if err != nil {
401 return nil, err
402 }
403
404 params := url.Values{
405 "client_id": {c.clientId},
406 "grant_type": {"refresh_token"},
407 "refresh_token": {args.RefreshToken},
408 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
409 "client_assertion": {clientAssertion},
410 }
411
412 dpopPrivateJwk, err := parsePrivateJwkFromString(args.DpopPrivateJwk)
413 if err != nil {
414 return nil, err
415 }
416
417 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, args.DpopAuthserverNonce, dpopPrivateJwk)
418 if err != nil {
419 return nil, err
420 }
421
422 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode()))
423 if err != nil {
424 return nil, err
425 }
426
427 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
428 req.Header.Set("DPoP", dpopProof)
429
430 resp, err := c.h.Do(req)
431 if err != nil {
432 return nil, err
433 }
434 defer resp.Body.Close()
435
436 // TODO: handle same thing as above...
437
438 if resp.StatusCode != 200 && resp.StatusCode != 201 {
439 b, _ := io.ReadAll(resp.Body)
440 return nil, fmt.Errorf("token refresh error: %s", string(b))
441 }
442
443 var rmap map[string]string
444 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil {
445 return nil, err
446 }
447
448 return &TokenResponse{
449 DpopAuthserverNonce: args.DpopAuthserverNonce,
450 Resp: rmap,
451 }, nil
452}
453
// generateToken returns n bytes from crypto/rand encoded as lowercase hex, so
// the result is 2*n characters long. The parameter is named n rather than
// len so it does not shadow the builtin.
func generateToken(n int) (string, error) {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}

	return hex.EncodeToString(b), nil
}
462
// generateCodeChallenge derives the PKCE "S256" code_challenge for
// pkceVerifier: the unpadded base64url encoding of SHA-256(pkceVerifier).
func generateCodeChallenge(pkceVerifier string) string {
	digest := sha256.Sum256([]byte(pkceVerifier))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}
469
470func parsePrivateJwkFromString(str string) (jwk.Key, error) {
471 return jwk.ParseKey([]byte(str))
472}