this repo has no description
1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "strconv"
12 "time"
13
14 "github.com/bluesky-social/indigo/util"
15 "github.com/bluesky-social/indigo/xrpc"
16 "github.com/carlmjohnson/versioninfo"
17 "github.com/golang-jwt/jwt/v5"
18 "github.com/google/uuid"
19 "github.com/lestrrat-go/jwx/v2/jwk"
20)
21
// XrpcClient sends XRPC requests to a PDS, authenticating them with
// DPoP-bound OAuth access tokens.
type XrpcClient struct {
	// Client is the HTTP client to use. If nil, util.RobustHTTPClient() is
	// used instead.
	Client *http.Client
	// UserAgent, if non-nil, is sent as the User-Agent header. Otherwise
	// "atproto-oauth/" followed by the module's short version is sent.
	UserAgent *string
	// Headers are set on every request after Content-Type and User-Agent, so
	// they can override those two; the DPoP and Authorization headers are set
	// after Headers and therefore always win.
	Headers map[string]string
	// OnDPoPNonceChanged is invoked by Do with the account DID and the new
	// nonce whenever the PDS rejects a request with "use_dpop_nonce".
	OnDPoPNonceChanged func(did, newNonce string)
}
29
30type XrpcAuthedRequestArgs struct {
31 Did string
32 PdsUrl string
33 Issuer string
34 AccessToken string
35 DpopPdsNonce string
36 DpopPrivateJwk jwk.Key
37}
38
39func (c *XrpcClient) getClient() *http.Client {
40 if c.Client == nil {
41 return util.RobustHTTPClient()
42 }
43 return c.Client
44}
45
46func errorFromHTTPResponse(resp *http.Response, err error) error {
47 r := &xrpc.Error{
48 StatusCode: resp.StatusCode,
49 Wrapped: err,
50 }
51 if resp.Header.Get("ratelimit-limit") != "" {
52 r.Ratelimit = &xrpc.RatelimitInfo{
53 Policy: resp.Header.Get("ratelimit-policy"),
54 }
55 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-reset"), 10, 64); err == nil {
56 r.Ratelimit.Reset = time.Unix(n, 0)
57 }
58 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-limit"), 10, 64); err == nil {
59 r.Ratelimit.Limit = int(n)
60 }
61 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-remaining"), 10, 64); err == nil {
62 r.Ratelimit.Remaining = int(n)
63 }
64 }
65 return r
66}
67
// makeParams URL-encodes p into a query string (without the leading "?").
//
// A []string value is expanded into one key=value pair per element with the
// key repeated, e.g. {"uris": []string{"a", "b"}} becomes "uris=a&uris=b"; it
// is not comma-joined, and an empty slice contributes nothing. Every other
// value is formatted with fmt.Sprint, which suits the usual strings, numbers
// and booleans. Keys come out sorted because url.Values.Encode sorts them.
// A nil or empty map yields "".
func makeParams(p map[string]any) string {
	params := make(url.Values, len(p))
	for k, v := range p {
		switch val := v.(type) {
		case []string:
			for _, s := range val {
				params.Add(k, s)
			}
		default:
			params.Add(k, fmt.Sprint(v))
		}
	}
	return params.Encode()
}
85
86func PdsDpopJwt(method, url, iss, accessToken, nonce string, privateJwk jwk.Key) (string, error) {
87 pubJwk, err := privateJwk.PublicKey()
88 if err != nil {
89 return "", err
90 }
91
92 b, err := json.Marshal(pubJwk)
93 if err != nil {
94 return "", err
95 }
96
97 var pubMap map[string]any
98 if err := json.Unmarshal(b, &pubMap); err != nil {
99 return "", err
100 }
101
102 now := time.Now().Unix()
103
104 claims := jwt.MapClaims{
105 "iss": iss,
106 "iat": now,
107 "exp": now + 30,
108 "jti": uuid.NewString(),
109 "htm": method,
110 "htu": url,
111 "ath": generateCodeChallenge(accessToken),
112 }
113
114 if nonce != "" {
115 claims["nonce"] = nonce
116 }
117
118 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
119 token.Header["typ"] = "dpop+jwt"
120 token.Header["alg"] = "ES256"
121 token.Header["jwk"] = pubMap
122
123 var rawKey any
124 if err := privateJwk.Raw(&rawKey); err != nil {
125 return "", err
126 }
127
128 tokenString, err := token.SignedString(rawKey)
129 if err != nil {
130 return "", fmt.Errorf("failed to sign token: %w", err)
131 }
132
133 return tokenString, nil
134}
135
136func (c *XrpcClient) Do(ctx context.Context, authedArgs *XrpcAuthedRequestArgs, kind xrpc.XRPCRequestType, inpenc, method string, params map[string]any, bodyobj any, out any) error {
137 // we might have to retry the request if we get a new nonce from the server
138 for range 2 {
139 var body io.Reader
140 if bodyobj != nil {
141 if rr, ok := bodyobj.(io.Reader); ok {
142 body = rr
143 } else {
144 b, err := json.Marshal(bodyobj)
145 if err != nil {
146 return err
147 }
148
149 body = bytes.NewReader(b)
150 }
151 }
152
153 var m string
154 switch kind {
155 case xrpc.Query:
156 m = "GET"
157 case xrpc.Procedure:
158 m = "POST"
159 default:
160 return fmt.Errorf("unsupported request kind: %d", kind)
161 }
162
163 var paramStr string
164 if len(params) > 0 {
165 paramStr = "?" + makeParams(params)
166 }
167
168 ustr := authedArgs.PdsUrl + "/xrpc/" + method + paramStr
169 req, err := http.NewRequest(m, ustr, body)
170 if err != nil {
171 return err
172 }
173
174 if bodyobj != nil && inpenc != "" {
175 req.Header.Set("Content-Type", inpenc)
176 }
177 if c.UserAgent != nil {
178 req.Header.Set("User-Agent", *c.UserAgent)
179 } else {
180 req.Header.Set("User-Agent", "atproto-oauth/"+versioninfo.Short())
181 }
182
183 if c.Headers != nil {
184 for k, v := range c.Headers {
185 req.Header.Set(k, v)
186 }
187 }
188
189 if authedArgs != nil {
190 dpopJwt, err := PdsDpopJwt(m, ustr, authedArgs.Issuer, authedArgs.AccessToken, authedArgs.DpopPdsNonce, authedArgs.DpopPrivateJwk)
191 if err != nil {
192 return err
193 }
194
195 req.Header.Set("DPoP", dpopJwt)
196 req.Header.Set("Authorization", "DPoP "+authedArgs.AccessToken)
197 }
198
199 resp, err := c.getClient().Do(req.WithContext(ctx))
200 if err != nil {
201 return fmt.Errorf("request failed: %w", err)
202 }
203
204 defer resp.Body.Close()
205
206 if resp.StatusCode != 200 {
207 var xe xrpc.XRPCError
208 if err := json.NewDecoder(resp.Body).Decode(&xe); err != nil {
209 return errorFromHTTPResponse(resp, fmt.Errorf("failed to decode xrpc error message: %w", err))
210 }
211
212 // if we get a new nonce, update the nonce and make the request again
213 if (resp.StatusCode == 400 || resp.StatusCode == 401) && xe.ErrStr == "use_dpop_nonce" {
214 newNonce := resp.Header.Get("DPoP-Nonce")
215 c.OnDPoPNonceChanged(authedArgs.Did, newNonce)
216 authedArgs.DpopPdsNonce = newNonce
217 continue
218 }
219
220 return errorFromHTTPResponse(resp, &xe)
221 }
222
223 if out != nil {
224 if buf, ok := out.(*bytes.Buffer); ok {
225 if resp.ContentLength < 0 {
226 _, err := io.Copy(buf, resp.Body)
227 if err != nil {
228 return fmt.Errorf("reading response body: %w", err)
229 }
230 } else {
231 n, err := io.CopyN(buf, resp.Body, resp.ContentLength)
232 if err != nil {
233 return fmt.Errorf("reading length delimited response body (%d < %d): %w", n, resp.ContentLength, err)
234 }
235 }
236 } else {
237 if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
238 return fmt.Errorf("decoding xrpc response: %w", err)
239 }
240 }
241 }
242
243 break
244 }
245
246 return nil
247}