this repo has no description
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "net/http" 10 "net/url" 11 "strconv" 12 "time" 13 14 "github.com/bluesky-social/indigo/util" 15 "github.com/bluesky-social/indigo/xrpc" 16 "github.com/carlmjohnson/versioninfo" 17 "github.com/golang-jwt/jwt/v5" 18 "github.com/google/uuid" 19 "github.com/lestrrat-go/jwx/v2/jwk" 20) 21 22type XrpcClient struct { 23 // Client is an HTTP client to use. If not set, defaults to http.RobustHTTPClient(). 24 Client *http.Client 25 UserAgent *string 26 Headers map[string]string 27 OnDPoPNonceChanged func(did, newNonce string) 28} 29 30type XrpcAuthedRequestArgs struct { 31 Did string 32 PdsUrl string 33 Issuer string 34 AccessToken string 35 DpopPdsNonce string 36 DpopPrivateJwk jwk.Key 37} 38 39func (c *XrpcClient) getClient() *http.Client { 40 if c.Client == nil { 41 return util.RobustHTTPClient() 42 } 43 return c.Client 44} 45 46func errorFromHTTPResponse(resp *http.Response, err error) error { 47 r := &xrpc.Error{ 48 StatusCode: resp.StatusCode, 49 Wrapped: err, 50 } 51 if resp.Header.Get("ratelimit-limit") != "" { 52 r.Ratelimit = &xrpc.RatelimitInfo{ 53 Policy: resp.Header.Get("ratelimit-policy"), 54 } 55 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-reset"), 10, 64); err == nil { 56 r.Ratelimit.Reset = time.Unix(n, 0) 57 } 58 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-limit"), 10, 64); err == nil { 59 r.Ratelimit.Limit = int(n) 60 } 61 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-remaining"), 10, 64); err == nil { 62 r.Ratelimit.Remaining = int(n) 63 } 64 } 65 return r 66} 67 68// makeParams converts a map of string keys and any values into a URL-encoded string. 69// If a value is a slice of strings, it will be joined with commas. 70// Generally the values will be strings, numbers, booleans, or slices of strings 71func makeParams(p map[string]any) string { 72 params := url.Values{} 73 for k, v := range p { 74 if s, ok := v.([]string); ok { 75 for _, v := range s { 76 params.Add(k, v) 77 } 78 } else { 79 params.Add(k, fmt.Sprint(v)) 80 } 81 } 82 83 return params.Encode() 84} 85 86func PdsDpopJwt(method, url, iss, accessToken, nonce string, privateJwk jwk.Key) (string, error) { 87 pubJwk, err := privateJwk.PublicKey() 88 if err != nil { 89 return "", err 90 } 91 92 b, err := json.Marshal(pubJwk) 93 if err != nil { 94 return "", err 95 } 96 97 var pubMap map[string]any 98 if err := json.Unmarshal(b, &pubMap); err != nil { 99 return "", err 100 } 101 102 now := time.Now().Unix() 103 104 claims := jwt.MapClaims{ 105 "iss": iss, 106 "iat": now, 107 "exp": now + 30, 108 "jti": uuid.NewString(), 109 "htm": method, 110 "htu": url, 111 "ath": generateCodeChallenge(accessToken), 112 } 113 114 if nonce != "" { 115 claims["nonce"] = nonce 116 } 117 118 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 119 token.Header["typ"] = "dpop+jwt" 120 token.Header["alg"] = "ES256" 121 token.Header["jwk"] = pubMap 122 123 var rawKey any 124 if err := privateJwk.Raw(&rawKey); err != nil { 125 return "", err 126 } 127 128 tokenString, err := token.SignedString(rawKey) 129 if err != nil { 130 return "", fmt.Errorf("failed to sign token: %w", err) 131 } 132 133 return tokenString, nil 134} 135 136func (c *XrpcClient) Do(ctx context.Context, authedArgs *XrpcAuthedRequestArgs, kind xrpc.XRPCRequestType, inpenc, method string, params map[string]any, bodyobj any, out any) error { 137 // we might have to retry the request if we get a new nonce from the server 138 for range 2 { 139 var body io.Reader 140 if bodyobj != nil { 141 if rr, ok := bodyobj.(io.Reader); ok { 142 body = rr 143 } else { 144 b, err := json.Marshal(bodyobj) 145 if err != nil { 146 return err 147 } 148 149 body = bytes.NewReader(b) 150 } 151 } 152 153 var m string 154 switch kind { 155 case xrpc.Query: 156 m = "GET" 157 case xrpc.Procedure: 158 m = "POST" 159 default: 160 return fmt.Errorf("unsupported request kind: %d", kind) 161 } 162 163 var paramStr string 164 if len(params) > 0 { 165 paramStr = "?" + makeParams(params) 166 } 167 168 ustr := authedArgs.PdsUrl + "/xrpc/" + method + paramStr 169 req, err := http.NewRequest(m, ustr, body) 170 if err != nil { 171 return err 172 } 173 174 if bodyobj != nil && inpenc != "" { 175 req.Header.Set("Content-Type", inpenc) 176 } 177 if c.UserAgent != nil { 178 req.Header.Set("User-Agent", *c.UserAgent) 179 } else { 180 req.Header.Set("User-Agent", "atproto-oauth/"+versioninfo.Short()) 181 } 182 183 if c.Headers != nil { 184 for k, v := range c.Headers { 185 req.Header.Set(k, v) 186 } 187 } 188 189 if authedArgs != nil { 190 dpopJwt, err := PdsDpopJwt(m, ustr, authedArgs.Issuer, authedArgs.AccessToken, authedArgs.DpopPdsNonce, authedArgs.DpopPrivateJwk) 191 if err != nil { 192 return err 193 } 194 195 req.Header.Set("DPoP", dpopJwt) 196 req.Header.Set("Authorization", "DPoP "+authedArgs.AccessToken) 197 } 198 199 resp, err := c.getClient().Do(req.WithContext(ctx)) 200 if err != nil { 201 return fmt.Errorf("request failed: %w", err) 202 } 203 204 defer resp.Body.Close() 205 206 if resp.StatusCode != 200 { 207 var xe xrpc.XRPCError 208 if err := json.NewDecoder(resp.Body).Decode(&xe); err != nil { 209 return errorFromHTTPResponse(resp, fmt.Errorf("failed to decode xrpc error message: %w", err)) 210 } 211 212 // if we get a new nonce, update the nonce and make the request again 213 if (resp.StatusCode == 400 || resp.StatusCode == 401) && xe.ErrStr == "use_dpop_nonce" { 214 newNonce := resp.Header.Get("DPoP-Nonce") 215 c.OnDPoPNonceChanged(authedArgs.Did, newNonce) 216 authedArgs.DpopPdsNonce = newNonce 217 continue 218 } 219 220 return errorFromHTTPResponse(resp, &xe) 221 } 222 223 if out != nil { 224 if buf, ok := out.(*bytes.Buffer); ok { 225 if resp.ContentLength < 0 { 226 _, err := io.Copy(buf, resp.Body) 227 if err != nil { 228 return fmt.Errorf("reading response body: %w", err) 229 } 230 } else { 231 n, err := io.CopyN(buf, resp.Body, resp.ContentLength) 232 if err != nil { 233 return fmt.Errorf("reading length delimited response body (%d < %d): %w", n, resp.ContentLength, err) 234 } 235 } 236 } else { 237 if err := json.NewDecoder(resp.Body).Decode(out); err != nil { 238 return fmt.Errorf("decoding xrpc response: %w", err) 239 } 240 } 241 } 242 243 break 244 } 245 246 return nil 247}