// This repository has no description.

1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "net/http" 10 "net/url" 11 "strconv" 12 "time" 13 14 "github.com/bluesky-social/indigo/util" 15 "github.com/bluesky-social/indigo/xrpc" 16 "github.com/carlmjohnson/versioninfo" 17 "github.com/golang-jwt/jwt/v5" 18 "github.com/google/uuid" 19 "github.com/lestrrat-go/jwx/v2/jwk" 20) 21 22type XrpcClient struct { 23 // Client is an HTTP client to use. If not set, defaults to http.RobustHTTPClient(). 24 Client *http.Client 25 UserAgent *string 26 Headers map[string]string 27 OnDPoPNonceChanged func(did, newNonce string) 28} 29 30func (c *XrpcClient) getClient() *http.Client { 31 if c.Client == nil { 32 return util.RobustHTTPClient() 33 } 34 return c.Client 35} 36 37func errorFromHTTPResponse(resp *http.Response, err error) error { 38 r := &xrpc.Error{ 39 StatusCode: resp.StatusCode, 40 Wrapped: err, 41 } 42 if resp.Header.Get("ratelimit-limit") != "" { 43 r.Ratelimit = &xrpc.RatelimitInfo{ 44 Policy: resp.Header.Get("ratelimit-policy"), 45 } 46 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-reset"), 10, 64); err == nil { 47 r.Ratelimit.Reset = time.Unix(n, 0) 48 } 49 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-limit"), 10, 64); err == nil { 50 r.Ratelimit.Limit = int(n) 51 } 52 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-remaining"), 10, 64); err == nil { 53 r.Ratelimit.Remaining = int(n) 54 } 55 } 56 return r 57} 58 59// makeParams converts a map of string keys and any values into a URL-encoded string. 60// If a value is a slice of strings, it will be joined with commas. 
61// Generally the values will be strings, numbers, booleans, or slices of strings 62func makeParams(p map[string]any) string { 63 params := url.Values{} 64 for k, v := range p { 65 if s, ok := v.([]string); ok { 66 for _, v := range s { 67 params.Add(k, v) 68 } 69 } else { 70 params.Add(k, fmt.Sprint(v)) 71 } 72 } 73 74 return params.Encode() 75} 76 77func PdsDpopJwt(method, url, iss, accessToken, nonce string, privateJwk jwk.Key) (string, error) { 78 pubJwk, err := privateJwk.PublicKey() 79 if err != nil { 80 return "", err 81 } 82 83 b, err := json.Marshal(pubJwk) 84 if err != nil { 85 return "", err 86 } 87 88 var pubMap map[string]any 89 if err := json.Unmarshal(b, &pubMap); err != nil { 90 return "", err 91 } 92 93 now := time.Now().Unix() 94 95 claims := jwt.MapClaims{ 96 "iss": iss, 97 "iat": now, 98 "exp": now + 30, 99 "jti": uuid.NewString(), 100 "htm": method, 101 "htu": url, 102 "ath": generateCodeChallenge(accessToken), 103 } 104 105 if nonce != "" { 106 claims["nonce"] = nonce 107 } 108 109 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) 110 token.Header["typ"] = "dpop+jwt" 111 token.Header["alg"] = "ES256" 112 token.Header["jwk"] = pubMap 113 114 var rawKey any 115 if err := privateJwk.Raw(&rawKey); err != nil { 116 return "", err 117 } 118 119 tokenString, err := token.SignedString(rawKey) 120 if err != nil { 121 return "", fmt.Errorf("failed to sign token: %w", err) 122 } 123 124 return tokenString, nil 125} 126 127type XrpcAuthedRequestArgs struct { 128 Did string 129 PdsUrl string 130 Issuer string 131 AccessToken string 132 DpopPdsNonce string 133 DpopPrivateJwk jwk.Key 134} 135 136func (c *XrpcClient) Do(ctx context.Context, authedArgs *XrpcAuthedRequestArgs, kind xrpc.XRPCRequestType, inpenc, method string, params map[string]any, bodyobj any, out any) error { 137 // we might have to retry the request if we get a new nonce from the server 138 for range 2 { 139 var body io.Reader 140 if bodyobj != nil { 141 if rr, ok := 
bodyobj.(io.Reader); ok { 142 body = rr 143 } else { 144 b, err := json.Marshal(bodyobj) 145 if err != nil { 146 return err 147 } 148 149 body = bytes.NewReader(b) 150 } 151 } 152 153 var m string 154 switch kind { 155 case xrpc.Query: 156 m = "GET" 157 case xrpc.Procedure: 158 m = "POST" 159 default: 160 return fmt.Errorf("unsupported request kind: %d", kind) 161 } 162 163 var paramStr string 164 if len(params) > 0 { 165 paramStr = "?" + makeParams(params) 166 } 167 168 ustr := authedArgs.PdsUrl + "/xrpc/" + method + paramStr 169 req, err := http.NewRequest(m, ustr, body) 170 if err != nil { 171 return err 172 } 173 174 if bodyobj != nil && inpenc != "" { 175 req.Header.Set("Content-Type", inpenc) 176 } 177 if c.UserAgent != nil { 178 req.Header.Set("User-Agent", *c.UserAgent) 179 } else { 180 req.Header.Set("User-Agent", "atproto-oauth/"+versioninfo.Short()) 181 } 182 183 if c.Headers != nil { 184 for k, v := range c.Headers { 185 req.Header.Set(k, v) 186 } 187 } 188 189 if authedArgs != nil { 190 dpopJwt, err := PdsDpopJwt(m, ustr, authedArgs.Issuer, authedArgs.AccessToken, authedArgs.DpopPdsNonce, authedArgs.DpopPrivateJwk) 191 if err != nil { 192 return err 193 } 194 195 req.Header.Set("DPoP", dpopJwt) 196 req.Header.Set("Authorization", "DPoP "+authedArgs.AccessToken) 197 } 198 199 resp, err := c.getClient().Do(req.WithContext(ctx)) 200 if err != nil { 201 return fmt.Errorf("request failed: %w", err) 202 } 203 204 defer resp.Body.Close() 205 206 if resp.StatusCode != 200 { 207 var xe xrpc.XRPCError 208 if err := json.NewDecoder(resp.Body).Decode(&xe); err != nil { 209 return errorFromHTTPResponse(resp, fmt.Errorf("failed to decode xrpc error message: %w", err)) 210 } 211 212 // if we get a new nonce, update the nonce and make the request again 213 if (resp.StatusCode == 400 || resp.StatusCode == 401) && xe.ErrStr == "use_dpop_nonce" { 214 newNonce := resp.Header.Get("DPoP-Nonce") 215 c.OnDPoPNonceChanged(authedArgs.Did, newNonce) 216 authedArgs.DpopPdsNonce = 
newNonce 217 continue 218 } 219 220 return errorFromHTTPResponse(resp, &xe) 221 } 222 223 if out != nil { 224 if buf, ok := out.(*bytes.Buffer); ok { 225 if resp.ContentLength < 0 { 226 _, err := io.Copy(buf, resp.Body) 227 if err != nil { 228 return fmt.Errorf("reading response body: %w", err) 229 } 230 } else { 231 n, err := io.CopyN(buf, resp.Body, resp.ContentLength) 232 if err != nil { 233 return fmt.Errorf("reading length delimited response body (%d < %d): %w", n, resp.ContentLength, err) 234 } 235 } 236 } else { 237 if err := json.NewDecoder(resp.Body).Decode(out); err != nil { 238 return fmt.Errorf("decoding xrpc response: %w", err) 239 } 240 } 241 } 242 243 break 244 } 245 246 return nil 247}