this repo has no description
1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "strconv"
12 "time"
13
14 "github.com/bluesky-social/indigo/util"
15 "github.com/bluesky-social/indigo/xrpc"
16 "github.com/carlmjohnson/versioninfo"
17 "github.com/golang-jwt/jwt/v5"
18 "github.com/google/uuid"
19 "github.com/lestrrat-go/jwx/v2/jwk"
20)
21
// This xrpc client is copied from the indigo xrpc client, with some tweaks:
// - There is no `AuthInfo` on the client. Instead, you pass auth _with the request_ in the `Do()` function.
// - There is an `XrpcAuthedRequestArgs` struct that contains all the info you need to complete an authed request.
// - There is an `OnDpopPdsNonceChanged` callback that runs whenever the PDS issues a new DPoP nonce. You can
// use this to update a database, for example.
// - A request is retried (once) when the PDS responds that a new DPoP nonce must be used.
28
// XrpcClient sends XRPC requests to a PDS using DPoP-bound OAuth access
// tokens. Unlike indigo's xrpc client it holds no auth state; credentials are
// supplied with each request via XrpcAuthedRequestArgs.
type XrpcClient struct {
	// Client is the HTTP client to use. If nil, util.RobustHTTPClient() is
	// used.
	Client *http.Client
	// UserAgent, when non-nil, replaces the default
	// "atproto-oauth/<version>" User-Agent header.
	UserAgent *string
	// Headers are set on every request after User-Agent (so they may
	// override it) and before the DPoP/Authorization headers.
	Headers map[string]string
	// OnDpopPdsNonceChanged receives the account DID and the new nonce when
	// the PDS hands out a fresh DPoP nonce, e.g. so it can be persisted.
	OnDpopPdsNonceChanged func(did, newNonce string)
}
36
37type XrpcAuthedRequestArgs struct {
38 Did string
39 PdsUrl string
40 Issuer string
41 AccessToken string
42 DpopPdsNonce string
43 DpopPrivateJwk jwk.Key
44}
45
46func (c *XrpcClient) getClient() *http.Client {
47 if c.Client == nil {
48 return util.RobustHTTPClient()
49 }
50 return c.Client
51}
52
53func errorFromHTTPResponse(resp *http.Response, err error) error {
54 r := &xrpc.Error{
55 StatusCode: resp.StatusCode,
56 Wrapped: err,
57 }
58 if resp.Header.Get("ratelimit-limit") != "" {
59 r.Ratelimit = &xrpc.RatelimitInfo{
60 Policy: resp.Header.Get("ratelimit-policy"),
61 }
62 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-reset"), 10, 64); err == nil {
63 r.Ratelimit.Reset = time.Unix(n, 0)
64 }
65 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-limit"), 10, 64); err == nil {
66 r.Ratelimit.Limit = int(n)
67 }
68 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-remaining"), 10, 64); err == nil {
69 r.Ratelimit.Remaining = int(n)
70 }
71 }
72 return r
73}
74
// makeParams encodes p as a URL query string (without the leading "?").
// A []string value adds one entry per element, producing a repeated key
// (e.g. "uris=a&uris=b"); it is NOT comma-joined, and an empty slice adds
// nothing. Every other value is formatted with fmt.Sprint. Keys appear in
// sorted order, as produced by url.Values.Encode.
func makeParams(p map[string]any) string {
	params := make(url.Values, len(p))
	for k, v := range p {
		if items, ok := v.([]string); ok {
			for _, item := range items {
				params.Add(k, item)
			}
			continue
		}
		params.Add(k, fmt.Sprint(v))
	}
	return params.Encode()
}
92
93func PdsDpopJwt(method, url, iss, accessToken, nonce string, privateJwk jwk.Key) (string, error) {
94 pubJwk, err := privateJwk.PublicKey()
95 if err != nil {
96 return "", err
97 }
98
99 b, err := json.Marshal(pubJwk)
100 if err != nil {
101 return "", err
102 }
103
104 var pubMap map[string]any
105 if err := json.Unmarshal(b, &pubMap); err != nil {
106 return "", err
107 }
108
109 now := time.Now().Unix()
110
111 claims := jwt.MapClaims{
112 "iss": iss,
113 "iat": now,
114 "exp": now + 30,
115 "jti": uuid.NewString(),
116 "htm": method,
117 "htu": url,
118 "ath": generateCodeChallenge(accessToken),
119 }
120
121 if nonce != "" {
122 claims["nonce"] = nonce
123 }
124
125 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
126 token.Header["typ"] = "dpop+jwt"
127 token.Header["alg"] = "ES256"
128 token.Header["jwk"] = pubMap
129
130 var rawKey any
131 if err := privateJwk.Raw(&rawKey); err != nil {
132 return "", err
133 }
134
135 tokenString, err := token.SignedString(rawKey)
136 if err != nil {
137 return "", fmt.Errorf("failed to sign token: %w", err)
138 }
139
140 return tokenString, nil
141}
142
143func (c *XrpcClient) Do(ctx context.Context, authedArgs *XrpcAuthedRequestArgs, kind xrpc.XRPCRequestType, inpenc, method string, params map[string]any, bodyobj any, out any) error {
144 // we might have to retry the request if we get a new nonce from the server
145 for range 2 {
146 var body io.Reader
147 if bodyobj != nil {
148 if rr, ok := bodyobj.(io.Reader); ok {
149 body = rr
150 } else {
151 b, err := json.Marshal(bodyobj)
152 if err != nil {
153 return err
154 }
155
156 body = bytes.NewReader(b)
157 }
158 }
159
160 var m string
161 switch kind {
162 case xrpc.Query:
163 m = "GET"
164 case xrpc.Procedure:
165 m = "POST"
166 default:
167 return fmt.Errorf("unsupported request kind: %d", kind)
168 }
169
170 var paramStr string
171 if len(params) > 0 {
172 paramStr = "?" + makeParams(params)
173 }
174
175 ustr := authedArgs.PdsUrl + "/xrpc/" + method + paramStr
176 req, err := http.NewRequest(m, ustr, body)
177 if err != nil {
178 return err
179 }
180
181 if bodyobj != nil && inpenc != "" {
182 req.Header.Set("Content-Type", inpenc)
183 }
184 if c.UserAgent != nil {
185 req.Header.Set("User-Agent", *c.UserAgent)
186 } else {
187 req.Header.Set("User-Agent", "atproto-oauth/"+versioninfo.Short())
188 }
189
190 if c.Headers != nil {
191 for k, v := range c.Headers {
192 req.Header.Set(k, v)
193 }
194 }
195
196 if authedArgs != nil {
197 dpopJwt, err := PdsDpopJwt(m, ustr, authedArgs.Issuer, authedArgs.AccessToken, authedArgs.DpopPdsNonce, authedArgs.DpopPrivateJwk)
198 if err != nil {
199 return err
200 }
201
202 req.Header.Set("DPoP", dpopJwt)
203 req.Header.Set("Authorization", "DPoP "+authedArgs.AccessToken)
204 }
205
206 resp, err := c.getClient().Do(req.WithContext(ctx))
207 if err != nil {
208 return fmt.Errorf("request failed: %w", err)
209 }
210
211 defer resp.Body.Close()
212
213 if resp.StatusCode != 200 {
214 var xe xrpc.XRPCError
215 if err := json.NewDecoder(resp.Body).Decode(&xe); err != nil {
216 return errorFromHTTPResponse(resp, fmt.Errorf("failed to decode xrpc error message: %w", err))
217 }
218
219 // if we get a new nonce, update the nonce and make the request again
220 if (resp.StatusCode == 400 || resp.StatusCode == 401) && xe.ErrStr == "use_dpop_nonce" {
221 authedArgs.DpopPdsNonce = resp.Header.Get("DPoP-Nonce")
222 c.OnDpopPdsNonceChanged(authedArgs.Did, authedArgs.DpopPdsNonce)
223 continue
224 }
225
226 return errorFromHTTPResponse(resp, &xe)
227 }
228
229 if out != nil {
230 if buf, ok := out.(*bytes.Buffer); ok {
231 if resp.ContentLength < 0 {
232 _, err := io.Copy(buf, resp.Body)
233 if err != nil {
234 return fmt.Errorf("reading response body: %w", err)
235 }
236 } else {
237 n, err := io.CopyN(buf, resp.Body, resp.ContentLength)
238 if err != nil {
239 return fmt.Errorf("reading length delimited response body (%d < %d): %w", n, resp.ContentLength, err)
240 }
241 }
242 } else {
243 if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
244 return fmt.Errorf("decoding xrpc response: %w", err)
245 }
246 }
247 }
248
249 return nil
250 }
251
252 return nil
253}