this repo has no description
1package oauth 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/rand" 7 "crypto/sha256" 8 "encoding/base64" 9 "encoding/hex" 10 "encoding/json" 11 "fmt" 12 "io" 13 "net/http" 14 "net/url" 15 "strings" 16 "time" 17 18 "github.com/golang-jwt/jwt/v5" 19 "github.com/google/uuid" 20 "github.com/lestrrat-go/jwx/v2/jwk" 21) 22 23type Client struct { 24 h *http.Client 25 clientPrivateKey *ecdsa.PrivateKey 26 clientKid string 27 clientId string 28 redirectUri string 29} 30 31type ClientArgs struct { 32 H *http.Client 33 ClientJwk jwk.Key 34 ClientId string 35 RedirectUri string 36} 37 38func NewClient(args ClientArgs) (*Client, error) { 39 if args.ClientId == "" { 40 return nil, fmt.Errorf("no client id provided") 41 } 42 43 if args.RedirectUri == "" { 44 return nil, fmt.Errorf("no redirect uri provided") 45 } 46 47 if args.H == nil { 48 args.H = &http.Client{ 49 Timeout: 5 * time.Second, 50 } 51 } 52 53 clientPkey, err := getPrivateKey(args.ClientJwk) 54 if err != nil { 55 return nil, fmt.Errorf("could not load private key from provided client jwk: %w", err) 56 } 57 58 kid := args.ClientJwk.KeyID() 59 60 return &Client{ 61 h: args.H, 62 clientKid: kid, 63 clientPrivateKey: clientPkey, 64 clientId: args.ClientId, 65 redirectUri: args.RedirectUri, 66 }, nil 67} 68 69func (c *Client) ResolvePDSAuthServer(ctx context.Context, ustr string) (string, error) { 70 u, err := isSafeAndParsed(ustr) 71 if err != nil { 72 return "", err 73 } 74 75 u.Path = "/.well-known/oauth-protected-resource" 76 77 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 78 if err != nil { 79 return "", fmt.Errorf("error creating request for oauth protected resource: %w", err) 80 } 81 82 resp, err := c.h.Do(req) 83 if err != nil { 84 return "", fmt.Errorf("could not get response from server: %w", err) 85 } 86 defer resp.Body.Close() 87 88 if resp.StatusCode != http.StatusOK { 89 io.Copy(io.Discard, resp.Body) 90 return "", fmt.Errorf("received non-200 response from pds. code was %d", resp.StatusCode) 91 } 92 93 b, err := io.ReadAll(resp.Body) 94 if err != nil { 95 return "", fmt.Errorf("could not read body: %w", err) 96 } 97 98 var resource OauthProtectedResource 99 if err := resource.UnmarshalJSON(b); err != nil { 100 return "", fmt.Errorf("could not unmarshal json: %w", err) 101 } 102 103 if len(resource.AuthorizationServers) == 0 { 104 return "", fmt.Errorf("oauth protected resource contained no authorization servers") 105 } 106 107 return resource.AuthorizationServers[0], nil 108} 109 110func (c *Client) FetchAuthServerMetadata(ctx context.Context, ustr string) (*OauthAuthorizationMetadata, error) { 111 u, err := isSafeAndParsed(ustr) 112 if err != nil { 113 return nil, err 114 } 115 116 u.Path = "/.well-known/oauth-authorization-server" 117 118 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 119 if err != nil { 120 return nil, fmt.Errorf("error creating request to fetch auth metadata: %w", err) 121 } 122 123 resp, err := c.h.Do(req) 124 if err != nil { 125 return nil, fmt.Errorf("error getting response for auth metadata: %w", err) 126 } 127 defer resp.Body.Close() 128 129 if resp.StatusCode != http.StatusOK { 130 io.Copy(io.Discard, resp.Body) 131 return nil, fmt.Errorf( 132 "received non-200 response from pds. status code was %d", 133 resp.StatusCode, 134 ) 135 } 136 137 b, err := io.ReadAll(resp.Body) 138 if err != nil { 139 return nil, fmt.Errorf("could not read body for metadata response: %w", err) 140 } 141 142 var metadata OauthAuthorizationMetadata 143 if err := metadata.UnmarshalJSON(b); err != nil { 144 return nil, fmt.Errorf("could not unmarshal metadata: %w", err) 145 } 146 147 if err := metadata.Validate(u); err != nil { 148 return nil, fmt.Errorf("could not validate metadata: %w", err) 149 } 150 151 return &metadata, nil 152} 153 154func (c *Client) ClientAssertionJwt(authServerUrl string) (string, error) { 155 claims := jwt.MapClaims{ 156 "iss": c.clientId, 157 "sub": c.clientId, 158 "aud": authServerUrl, 159 "jti": uuid.NewString(), 160 "iat": time.Now().Unix(), 161 } 162 163 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 164 token.Header["kid"] = c.clientKid 165 166 tokenString, err := token.SignedString(c.clientPrivateKey) 167 if err != nil { 168 return "", err 169 } 170 171 return tokenString, nil 172} 173 174func (c *Client) AuthServerDpopJwt(method, url, nonce string, privateJwk jwk.Key) (string, error) { 175 pubJwk, err := privateJwk.PublicKey() 176 if err != nil { 177 return "", err 178 } 179 180 b, err := json.Marshal(pubJwk) 181 if err != nil { 182 return "", err 183 } 184 185 var pubMap map[string]any 186 if err := json.Unmarshal(b, &pubMap); err != nil { 187 return "", err 188 } 189 190 now := time.Now().Unix() 191 192 claims := jwt.MapClaims{ 193 "jti": uuid.NewString(), 194 "htm": method, 195 "htu": url, 196 "iat": now, 197 "exp": now + 30, 198 } 199 200 if nonce != "" { 201 claims["nonce"] = nonce 202 } 203 204 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 205 token.Header["typ"] = "dpop+jwt" 206 token.Header["alg"] = "ES256" 207 token.Header["jwk"] = pubMap 208 209 var rawKey any 210 if err := privateJwk.Raw(&rawKey); err != nil { 211 return "", err 212 } 213 214 tokenString, err := token.SignedString(rawKey) 215 if err != nil { 216 return "", fmt.Errorf("failed to sign token: %w", err) 217 } 218 219 return tokenString, nil 220} 221 222func (c *Client) SendParAuthRequest(ctx context.Context, authServerUrl string, authServerMeta *OauthAuthorizationMetadata, loginHint, scope string, dpopPrivateKey jwk.Key) (*SendParAuthResponse, error) { 223 if authServerMeta == nil { 224 return nil, fmt.Errorf("nil metadata provided") 225 } 226 227 parUrl := authServerMeta.PushedAuthorizationRequestEndpoint 228 229 state, err := generateToken(10) 230 if err != nil { 231 return nil, fmt.Errorf("could not generate state token: %w", err) 232 } 233 234 pkceVerifier, err := generateToken(48) 235 if err != nil { 236 return nil, fmt.Errorf("could not generate pkce verifier: %w", err) 237 } 238 239 codeChallenge := generateCodeChallenge(pkceVerifier) 240 codeChallengeMethod := "S256" 241 242 clientAssertion, err := c.ClientAssertionJwt(authServerUrl) 243 if err != nil { 244 return nil, err 245 } 246 247 dpopAuthserverNonce := "" 248 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey) 249 if err != nil { 250 return nil, fmt.Errorf("error getting dpop proof: %w", err) 251 } 252 253 params := url.Values{ 254 "response_type": {"code"}, 255 "code_challenge": {codeChallenge}, 256 "code_challenge_method": {codeChallengeMethod}, 257 "client_id": {c.clientId}, 258 "state": {state}, 259 "redirect_uri": {c.redirectUri}, 260 "scope": {scope}, 261 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 262 "client_assertion": {clientAssertion}, 263 } 264 265 if loginHint != "" { 266 params.Set("login_hint", loginHint) 267 } 268 269 _, err = isSafeAndParsed(parUrl) 270 if err != nil { 271 return nil, err 272 } 273 274 req, err := http.NewRequestWithContext(ctx, "POST", parUrl, strings.NewReader(params.Encode())) 275 if err != nil { 276 return nil, err 277 } 278 279 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 280 req.Header.Set("DPoP", dpopProof) 281 282 resp, err := c.h.Do(req) 283 if err != nil { 284 return nil, err 285 } 286 defer resp.Body.Close() 287 288 var rmap map[string]any 289 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil { 290 return nil, err 291 } 292 293 if resp.StatusCode == 400 && rmap["error"] == "use_dpop_nonce" { 294 dpopAuthserverNonce = resp.Header.Get("DPoP-Nonce") 295 dpopProof, err := c.AuthServerDpopJwt("POST", parUrl, dpopAuthserverNonce, dpopPrivateKey) 296 if err != nil { 297 return nil, err 298 } 299 300 req2, err := http.NewRequestWithContext( 301 ctx, 302 "POST", 303 parUrl, 304 strings.NewReader(params.Encode()), 305 ) 306 if err != nil { 307 return nil, err 308 } 309 310 req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") 311 req2.Header.Set("DPoP", dpopProof) 312 313 resp2, err := c.h.Do(req2) 314 if err != nil { 315 return nil, err 316 } 317 defer resp2.Body.Close() 318 319 rmap = map[string]any{} 320 if err := json.NewDecoder(resp2.Body).Decode(&rmap); err != nil { 321 return nil, err 322 } 323 } 324 325 return &SendParAuthResponse{ 326 PkceVerifier: pkceVerifier, 327 State: state, 328 DpopAuthserverNonce: dpopAuthserverNonce, 329 Resp: rmap, 330 }, nil 331} 332 333func (c *Client) InitialTokenRequest( 334 ctx context.Context, 335 code, 336 appUrl, 337 authserverIss, 338 pkceVerifier, 339 dpopAuthserverNonce string, 340 dpopPrivateJwk jwk.Key, 341) (*TokenResponse, error) { 342 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverIss) 343 if err != nil { 344 return nil, err 345 } 346 347 clientAssertion, err := c.ClientAssertionJwt(authserverIss) 348 if err != nil { 349 return nil, err 350 } 351 352 params := url.Values{ 353 "client_id": {c.clientId}, 354 "redirect_uri": {c.redirectUri}, 355 "grant_type": {"authorization_code"}, 356 "code": {code}, 357 "code_verifier": {pkceVerifier}, 358 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 359 "client_assertion": {clientAssertion}, 360 } 361 362 dpopProof, err := c.AuthServerDpopJwt( 363 "POST", 364 authserverMeta.TokenEndpoint, 365 dpopAuthserverNonce, 366 dpopPrivateJwk, 367 ) 368 if err != nil { 369 return nil, err 370 } 371 372 req, err := http.NewRequestWithContext( 373 ctx, 374 "POST", 375 authserverMeta.TokenEndpoint, 376 strings.NewReader(params.Encode()), 377 ) 378 if err != nil { 379 return nil, err 380 } 381 382 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 383 req.Header.Set("DPoP", dpopProof) 384 385 resp, err := c.h.Do(req) 386 if err != nil { 387 return nil, err 388 } 389 defer resp.Body.Close() 390 391 var rmap map[string]any 392 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil { 393 return nil, err 394 } 395 396 return &TokenResponse{ 397 DpopAuthserverNonce: dpopAuthserverNonce, 398 Resp: rmap, 399 }, nil 400} 401 402func (c *Client) RefreshTokenRequest(ctx context.Context, authserverUrl, refreshToken, dpopAuthserverNonce string, dpopPrivateJwk jwk.Key) (*TokenResponse, error) { 403 authserverMeta, err := c.FetchAuthServerMetadata(ctx, authserverUrl) 404 if err != nil { 405 return nil, err 406 } 407 408 clientAssertion, err := c.ClientAssertionJwt(authserverUrl) 409 if err != nil { 410 return nil, err 411 } 412 413 params := url.Values{ 414 "client_id": {c.clientId}, 415 "grant_type": {"refresh_token"}, 416 "refresh_token": {refreshToken}, 417 "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, 418 "client_assertion": {clientAssertion}, 419 } 420 421 dpopProof, err := c.AuthServerDpopJwt( 422 "POST", 423 authserverMeta.TokenEndpoint, 424 dpopAuthserverNonce, 425 dpopPrivateJwk, 426 ) 427 if err != nil { 428 return nil, err 429 } 430 431 req, err := http.NewRequestWithContext( 432 ctx, 433 "POST", 434 authserverMeta.TokenEndpoint, 435 strings.NewReader(params.Encode()), 436 ) 437 if err != nil { 438 return nil, err 439 } 440 441 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 442 req.Header.Set("DPoP", dpopProof) 443 444 resp, err := c.h.Do(req) 445 if err != nil { 446 return nil, err 447 } 448 defer resp.Body.Close() 449 450 // TODO: handle same thing as above... 451 452 if resp.StatusCode != 200 && resp.StatusCode != 201 { 453 b, _ := io.ReadAll(resp.Body) 454 return nil, fmt.Errorf("token refresh error: %s", string(b)) 455 } 456 457 var rmap map[string]any 458 if err := json.NewDecoder(resp.Body).Decode(&rmap); err != nil { 459 return nil, err 460 } 461 462 return &TokenResponse{ 463 DpopAuthserverNonce: dpopAuthserverNonce, 464 Resp: rmap, 465 }, nil 466} 467 468func generateToken(len int) (string, error) { 469 b := make([]byte, len) 470 if _, err := rand.Read(b); err != nil { 471 return "", err 472 } 473 474 return hex.EncodeToString(b), nil 475} 476 477func generateCodeChallenge(pkceVerifier string) string { 478 h := sha256.New() 479 h.Write([]byte(pkceVerifier)) 480 hash := h.Sum(nil) 481 return base64.RawURLEncoding.EncodeToString(hash) 482} 483 484func parsePrivateJwkFromString(str string) (jwk.Key, error) { 485 return jwk.ParseKey([]byte(str)) 486}