this repo has no description
1package oauth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/rand" 7 "crypto/sha256" 8 "encoding/base64" 9 "encoding/hex" 10 "encoding/json" 11 "fmt" 12 "io" 13 "net/http" 14 "net/url" 15 "strings" 16 "time" 17 18 "github.com/golang-jwt/jwt/v5" 19 "github.com/google/uuid" 20 "github.com/lestrrat-go/jwx/v2/jwk" 21) 22 23type OauthClient struct { 24 h *http.Client 25 clientPrivateKey *ecdsa.PrivateKey 26 clientKid string 27 clientId string 28 redirectUri string 29} 30 31type OauthClientArgs struct { 32 H *http.Client 33 ClientJwk []byte 34 ClientId string 35 RedirectUri string 36} 37 38func NewOauthClient(args OauthClientArgs) (*OauthClient, error) { 39 if args.ClientId == "" { 40 return nil, fmt.Errorf("no client id provided") 41 } 42 43 if args.RedirectUri == "" { 44 return nil, fmt.Errorf("no redirect uri provided") 45 } 46 47 if args.H == nil { 48 args.H = &http.Client{ 49 Timeout: 5 * time.Second, 50 } 51 } 52 53 clientJwk, err := jwk.ParseKey(args.ClientJwk) 54 if err != nil { 55 return nil, err 56 } 57 58 clientPkey, err := getPrivateKey(clientJwk) 59 if err != nil { 60 return nil, fmt.Errorf("could not load private key from provided client jwk: %w", err) 61 } 62 63 kid := clientJwk.KeyID() 64 65 return &OauthClient{ 66 h: args.H, 67 clientKid: kid, 68 clientPrivateKey: clientPkey, 69 clientId: args.ClientId, 70 redirectUri: args.RedirectUri, 71 }, nil 72} 73 74func (c *OauthClient) ResolvePDSAuthServer(ctx context.Context, ustr string) (string, error) { 75 u, err := isSafeAndParsed(ustr) 76 if err != nil { 77 return "", err 78 } 79 80 u.Path = "/.well-known/oauth-protected-resource" 81 82 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 83 if err != nil { 84 return "", fmt.Errorf("error creating request for oauth protected resource: %w", err) 85 } 86 87 resp, err := c.h.Do(req) 88 if err != nil { 89 return "", fmt.Errorf("could not get response from server: %w", err) 90 } 91 defer resp.Body.Close() 92 93 if resp.StatusCode != http.StatusOK { 94 io.Copy(io.Discard, resp.Body) 95 return "", fmt.Errorf("received non-200 response from pds. code was %d", resp.StatusCode) 96 } 97 98 b, err := io.ReadAll(resp.Body) 99 if err != nil { 100 return "", fmt.Errorf("could not read body: %w", err) 101 } 102 103 var resource OauthProtectedResource 104 if err := resource.UnmarshalJSON(b); err != nil { 105 return "", fmt.Errorf("could not unmarshal json: %w", err) 106 } 107 108 if len(resource.AuthorizationServers) == 0 { 109 return "", fmt.Errorf("oauth protected resource contained no authorization servers") 110 } 111 112 return resource.AuthorizationServers[0], nil 113} 114 115func (c *OauthClient) FetchAuthServerMetadata(ctx context.Context, ustr string) (*OauthAuthorizationMetadata, error) { 116 u, err := isSafeAndParsed(ustr) 117 if err != nil { 118 return nil, err 119 } 120 121 u.Path = "/.well-known/oauth-authorization-server" 122 123 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 124 if err != nil { 125 return nil, fmt.Errorf("error creating request to fetch auth metadata: %w", err) 126 } 127 128 resp, err := c.h.Do(req) 129 if err != nil { 130 return nil, fmt.Errorf("error getting response for auth metadata: %w", err) 131 } 132 defer resp.Body.Close() 133 134 if resp.StatusCode != http.StatusOK { 135 io.Copy(io.Discard, resp.Body) 136 return nil, fmt.Errorf("received non-200 response from pds. status code was %d", resp.StatusCode) 137 } 138 139 b, err := io.ReadAll(resp.Body) 140 if err != nil { 141 return nil, fmt.Errorf("could not read body for metadata response: %w", err) 142 } 143 144 var metadata OauthAuthorizationMetadata 145 if err := metadata.UnmarshalJSON(b); err != nil { 146 return nil, fmt.Errorf("could not unmarshal metadata: %w", err) 147 } 148 149 if err := metadata.Validate(u); err != nil { 150 return nil, fmt.Errorf("could not validate metadata: %w", err) 151 } 152 153 return &metadata, nil 154} 155 156func (c *OauthClient) ClientAssertionJwt(authServerUrl string) (string, error) { 157 claims := jwt.MapClaims{ 158 "iss": c.clientId, 159 "sub": c.clientId, 160 "aud": authServerUrl, 161 "jti": uuid.NewString(), 162 "iat": time.Now().Unix(), 163 } 164 165 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 166 token.Header["kid"] = c.clientKid 167 168 tokenString, err := token.SignedString(c.clientPrivateKey) 169 if err != nil { 170 return "", err 171 } 172 173 return tokenString, nil 174} 175 176func (c *OauthClient) AuthServerDpopJwt(method, url, nonce string, privateJwk jwk.Key) (string, error) { 177 raw, err := jwk.PublicKeyOf(privateJwk) 178 if err != nil { 179 return "", err 180 } 181 182 pubJwk, err := jwk.FromRaw(raw) 183 if err != nil { 184 return "", err 185 } 186 187 b, err := json.Marshal(pubJwk) 188 if err != nil { 189 return "", err 190 } 191 192 var pubMap map[string]interface{} 193 if err := json.Unmarshal(b, &pubMap); err != nil { 194 return "", err 195 } 196 197 now := time.Now().Unix() 198 199 claims := jwt.MapClaims{ 200 "jti": uuid.NewString(), 201 "htm": method, 202 "htu": url, 203 "iat": now, 204 "exp": now + 30, 205 } 206 207 if nonce != "" { 208 claims["nonce"] = nonce 209 } 210 211 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 212 token.Header["typ"] = "dpop+jwt" 213 token.Header["alg"] = "ES256" 214 token.Header["jwk"] = pubMap 215 216 var rawKey interface{} 217 if err := privateJwk.Raw(&rawKey); err != nil { 218 return "", err 219 } 220 221 tokenString, err := token.SignedString(rawKey) 222 if err != nil { 223 return "", fmt.Errorf("failed to sign token: %w", err) 224 } 225 226 return tokenString, nil 227} 228 229type SendParAuthResponse struct { 230 PkceVerifier string 231 State string 232 DpopAuthserverNonce string 233 Resp map[string]string 234} 235 236func (c *OauthClient) SendParAuthRequest(ctx context.Context, authServerUrl string, authServerMeta *OauthAuthorizationMetadata, loginHint, scope string, dpopPrivateKey jwk.Key) (*SendParAuthResponse, error) { 237 if authServerMeta == nil { 238 return nil, fmt.Errorf("nil metadata provided") 239 } 240 241 parUrl := authServerMeta.PushedAuthorizationRequestEndpoint 242 243 state, err := generateToken(10) 244 if err != nil { 245 return nil, fmt.Errorf("could not generate state token: %w", err) 246 } 247 248 pkceVerifier, err := generateToken(48) 249 if err != nil { 250 return nil, fmt.Errorf("could not generate pkce verifier: %w", err) 251 } 252 253 codeChallenge := generateCodeChallenge(pkceVerifier) 254 codeChallengeMethod := "S256" 255 256 clientAssertion, err := c.ClientAssertionJwt(authServerUrl) 257 if err != nil { 258 return nil, err 259 } 260 261 // TODO: ?? 262 nonce := "" 263 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, nonce, dpopPrivateKey) 264 if err != nil { 265 return nil, err 266 } 267 268 params := url.Values{ 269 "response_type": {"code"}, 270 "code_challenge": {codeChallenge}, 271 "code_challenge_method": {codeChallengeMethod}, 272 "client_id": {c.clientId}, 273 "state": {state}, 274 "redirect_uri": {c.redirectUri}, 275 "scope": {scope}, 276 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 277 "client_assertion": {clientAssertion}, 278 } 279 280 if loginHint != "" { 281 params.Set("login_hint", loginHint) 282 } 283 284 _, err = isSafeAndParsed(parUrl) 285 if err != nil { 286 return nil, err 287 } 288 289 req, err := http.NewRequestWithContext(ctx, "POST", parUrl, strings.NewReader(params.Encode())) 290 if err != nil { 291 return nil, err 292 } 293 294 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 295 req.Header.Set("DPoP", dpopProof) 296 297 resp, err := c.h.Do(req) 298 if err != nil { 299 return nil, err 300 } 301 defer resp.Body.Close() 302 303 var rmap map[string]string 304 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil { 305 return nil, err 306 } 307 308 // TODO: there's some logic in the flask example where we retry if the server 309 // asks us to use a dpop nonce. we should add that here eventually, but for now 310 // we'll skip that 311 312 return &SendParAuthResponse{ 313 PkceVerifier: pkceVerifier, 314 State: state, 315 DpopAuthserverNonce: "", // add here later 316 Resp: rmap, 317 }, nil 318} 319 320type TokenResponse struct { 321 DpopAuthserverNonce string 322 Resp map[string]string 323} 324 325func (c *OauthClient) InitialTokenRequest(ctx context.Context, authRequest map[string]string, code, appUrl string) (*TokenResponse, error) { 326 authserverUrl := authRequest["authserver_iss"] 327 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverUrl) 328 if err != nil { 329 return nil, err 330 } 331 332 clientAssertion, err := c.ClientAssertionJwt(authserverUrl) 333 if err != nil { 334 return nil, err 335 } 336 337 params := url.Values{ 338 "client_id": {c.clientId}, 339 "redirect_uri": {c.redirectUri}, 340 "grant_type": {"authorization_code"}, 341 "code": {code}, 342 "code_verifier": {authRequest["pkce_verifier"]}, 343 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 344 "client_assertion": {clientAssertion}, 345 } 346 347 dpopPrivateJwk, err := parsePrivateJwkFromString(authRequest["dpop_private_jwk"]) 348 if err != nil { 349 return nil, err 350 } 351 352 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, authRequest["dpop_authserver_nonce"], dpopPrivateJwk) 353 if err != nil { 354 return nil, err 355 } 356 357 dpopAuthserverNonce := authRequest["dpop_authserver_nonce"] 358 359 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode())) 360 if err != nil { 361 return nil, err 362 } 363 364 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 365 req.Header.Set("DPoP", dpopProof) 366 367 resp, err := c.h.Do(req) 368 if err != nil { 369 return nil, err 370 } 371 defer resp.Body.Close() 372 373 // TODO: use nonce if needed, same as in par 374 375 var rmap map[string]string 376 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil { 377 return nil, err 378 } 379 380 return &TokenResponse{ 381 DpopAuthserverNonce: dpopAuthserverNonce, 382 Resp: rmap, 383 }, nil 384} 385 386type RefreshTokenArgs struct { 387 AuthserverUrl string 388 RefreshToken string 389 DpopPrivateJwk string 390 DpopAuthserverNonce string 391} 392 393func (c *OauthClient) RefreshTokenRequest(ctx context.Context, args RefreshTokenArgs, appUrl string) (any, error) { 394 authserverMeta, err := c.FetchAuthServerMetadata(ctx, args.AuthserverUrl) 395 if err != nil { 396 return nil, err 397 } 398 399 clientAssertion, err := c.ClientAssertionJwt(args.AuthserverUrl) 400 if err != nil { 401 return nil, err 402 } 403 404 params := url.Values{ 405 "client_id": {c.clientId}, 406 "grant_type": {"refresh_token"}, 407 "refresh_token": {args.RefreshToken}, 408 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 409 "client_assertion": {clientAssertion}, 410 } 411 412 dpopPrivateJwk, err := parsePrivateJwkFromString(args.DpopPrivateJwk) 413 if err != nil { 414 return nil, err 415 } 416 417 dpopProof, err := c.AuthServerDpopJwt("POST", authserverMeta.TokenEndpoint, args.DpopAuthserverNonce, dpopPrivateJwk) 418 if err != nil { 419 return nil, err 420 } 421 422 req, err := http.NewRequestWithContext(ctx, "POST", authserverMeta.TokenEndpoint, strings.NewReader(params.Encode())) 423 if err != nil { 424 return nil, err 425 } 426 427 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 428 req.Header.Set("DPoP", dpopProof) 429 430 resp, err := c.h.Do(req) 431 if err != nil { 432 return nil, err 433 } 434 defer resp.Body.Close() 435 436 // TODO: handle same thing as above... 437 438 if resp.StatusCode != 200 && resp.StatusCode != 201 { 439 b, _ := io.ReadAll(resp.Body) 440 return nil, fmt.Errorf("token refresh error: %s", string(b)) 441 } 442 443 var rmap map[string]string 444 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil { 445 return nil, err 446 } 447 448 return &TokenResponse{ 449 DpopAuthserverNonce: args.DpopAuthserverNonce, 450 Resp: rmap, 451 }, nil 452} 453 454func generateToken(len int) (string, error) { 455 b := make([]byte, len) 456 if _, err := rand.Read(b); err != nil { 457 return "", err 458 } 459 460 return hex.EncodeToString(b), nil 461} 462 463func generateCodeChallenge(pkceVerifier string) string { 464 h := sha256.New() 465 h.Write([]byte(pkceVerifier)) 466 hash := h.Sum(nil) 467 return base64.RawURLEncoding.EncodeToString(hash) 468} 469 470func parsePrivateJwkFromString(str string) (jwk.Key, error) { 471 return jwk.ParseKey([]byte(str)) 472}