this repo has no description
1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "strconv"
12 "time"
13
14 "github.com/bluesky-social/indigo/util"
15 "github.com/bluesky-social/indigo/xrpc"
16 "github.com/carlmjohnson/versioninfo"
17 "github.com/golang-jwt/jwt/v5"
18 "github.com/google/uuid"
19 "github.com/lestrrat-go/jwx/v2/jwk"
20)
21
// XrpcClient sends DPoP-authenticated XRPC requests to an atproto PDS.
type XrpcClient struct {
	// Client is the HTTP client used to send requests. If nil, a client
	// created by util.RobustHTTPClient() is used instead.
	Client *http.Client

	// UserAgent overrides the User-Agent header. If nil, requests are sent
	// with "atproto-oauth/" followed by versioninfo.Short().
	UserAgent *string

	// Headers are extra headers set on every request. They are applied after
	// Content-Type and User-Agent (and so can replace them) but before the
	// DPoP and Authorization headers, which always win.
	Headers map[string]string

	// OnDPoPNonceChanged is called by Do with the account DID and the new
	// nonce whenever the PDS rejects a request with "use_dpop_nonce", so the
	// caller can persist the nonce for later requests.
	OnDPoPNonceChanged func(did, newNonce string)
}
29
30func (c *XrpcClient) getClient() *http.Client {
31 if c.Client == nil {
32 return util.RobustHTTPClient()
33 }
34 return c.Client
35}
36
37func errorFromHTTPResponse(resp *http.Response, err error) error {
38 r := &xrpc.Error{
39 StatusCode: resp.StatusCode,
40 Wrapped: err,
41 }
42 if resp.Header.Get("ratelimit-limit") != "" {
43 r.Ratelimit = &xrpc.RatelimitInfo{
44 Policy: resp.Header.Get("ratelimit-policy"),
45 }
46 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-reset"), 10, 64); err == nil {
47 r.Ratelimit.Reset = time.Unix(n, 0)
48 }
49 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-limit"), 10, 64); err == nil {
50 r.Ratelimit.Limit = int(n)
51 }
52 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-remaining"), 10, 64); err == nil {
53 r.Ratelimit.Remaining = int(n)
54 }
55 }
56 return r
57}
58
// makeParams encodes p as a URL query string (without the leading "?").
//
// A []string value is emitted as one key=value pair per element, i.e. the key
// is repeated (a=x&a=y); an empty slice emits nothing for that key. A nil value
// is skipped entirely. Every other value is formatted with fmt.Sprint, so
// strings, numbers and booleans render in their natural form. Keys appear in
// sorted order because url.Values.Encode sorts them, which makes the output
// deterministic despite Go's randomized map iteration.
func makeParams(p map[string]any) string {
	params := make(url.Values, len(p))
	for k, v := range p {
		switch val := v.(type) {
		case nil:
			// Unset parameter: fmt.Sprint would otherwise send "<nil>".
		case []string:
			for _, s := range val {
				params.Add(k, s)
			}
		default:
			params.Add(k, fmt.Sprint(val))
		}
	}
	return params.Encode()
}
76
77func PdsDpopJwt(method, url, iss, accessToken, nonce string, privateJwk jwk.Key) (string, error) {
78 pubJwk, err := privateJwk.PublicKey()
79 if err != nil {
80 return "", err
81 }
82
83 b, err := json.Marshal(pubJwk)
84 if err != nil {
85 return "", err
86 }
87
88 var pubMap map[string]any
89 if err := json.Unmarshal(b, &pubMap); err != nil {
90 return "", err
91 }
92
93 now := time.Now().Unix()
94
95 claims := jwt.MapClaims{
96 "iss": iss,
97 "iat": now,
98 "exp": now + 30,
99 "jti": uuid.NewString(),
100 "htm": method,
101 "htu": url,
102 "ath": generateCodeChallenge(accessToken),
103 }
104
105 if nonce != "" {
106 claims["nonce"] = nonce
107 }
108
109 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
110 token.Header["typ"] = "dpop+jwt"
111 token.Header["alg"] = "ES256"
112 token.Header["jwk"] = pubMap
113
114 var rawKey any
115 if err := privateJwk.Raw(&rawKey); err != nil {
116 return "", err
117 }
118
119 tokenString, err := token.SignedString(rawKey)
120 if err != nil {
121 return "", fmt.Errorf("failed to sign token: %w", err)
122 }
123
124 return tokenString, nil
125}
126
127type XrpcAuthedRequestArgs struct {
128 Did string
129 PdsUrl string
130 Issuer string
131 AccessToken string
132 DpopPdsNonce string
133 DpopPrivateJwk jwk.Key
134}
135
136func (c *XrpcClient) Do(ctx context.Context, authedArgs *XrpcAuthedRequestArgs, kind xrpc.XRPCRequestType, inpenc, method string, params map[string]any, bodyobj any, out any) error {
137 // we might have to retry the request if we get a new nonce from the server
138 for range 2 {
139 var body io.Reader
140 if bodyobj != nil {
141 if rr, ok := bodyobj.(io.Reader); ok {
142 body = rr
143 } else {
144 b, err := json.Marshal(bodyobj)
145 if err != nil {
146 return err
147 }
148
149 body = bytes.NewReader(b)
150 }
151 }
152
153 var m string
154 switch kind {
155 case xrpc.Query:
156 m = "GET"
157 case xrpc.Procedure:
158 m = "POST"
159 default:
160 return fmt.Errorf("unsupported request kind: %d", kind)
161 }
162
163 var paramStr string
164 if len(params) > 0 {
165 paramStr = "?" + makeParams(params)
166 }
167
168 ustr := authedArgs.PdsUrl + "/xrpc/" + method + paramStr
169 req, err := http.NewRequest(m, ustr, body)
170 if err != nil {
171 return err
172 }
173
174 if bodyobj != nil && inpenc != "" {
175 req.Header.Set("Content-Type", inpenc)
176 }
177 if c.UserAgent != nil {
178 req.Header.Set("User-Agent", *c.UserAgent)
179 } else {
180 req.Header.Set("User-Agent", "atproto-oauth/"+versioninfo.Short())
181 }
182
183 if c.Headers != nil {
184 for k, v := range c.Headers {
185 req.Header.Set(k, v)
186 }
187 }
188
189 if authedArgs != nil {
190 dpopJwt, err := PdsDpopJwt(m, ustr, authedArgs.Issuer, authedArgs.AccessToken, authedArgs.DpopPdsNonce, authedArgs.DpopPrivateJwk)
191 if err != nil {
192 return err
193 }
194
195 req.Header.Set("DPoP", dpopJwt)
196 req.Header.Set("Authorization", "DPoP "+authedArgs.AccessToken)
197 }
198
199 resp, err := c.getClient().Do(req.WithContext(ctx))
200 if err != nil {
201 return fmt.Errorf("request failed: %w", err)
202 }
203
204 defer resp.Body.Close()
205
206 if resp.StatusCode != 200 {
207 var xe xrpc.XRPCError
208 if err := json.NewDecoder(resp.Body).Decode(&xe); err != nil {
209 return errorFromHTTPResponse(resp, fmt.Errorf("failed to decode xrpc error message: %w", err))
210 }
211
212 // if we get a new nonce, update the nonce and make the request again
213 if (resp.StatusCode == 400 || resp.StatusCode == 401) && xe.ErrStr == "use_dpop_nonce" {
214 newNonce := resp.Header.Get("DPoP-Nonce")
215 c.OnDPoPNonceChanged(authedArgs.Did, newNonce)
216 authedArgs.DpopPdsNonce = newNonce
217 continue
218 }
219
220 return errorFromHTTPResponse(resp, &xe)
221 }
222
223 if out != nil {
224 if buf, ok := out.(*bytes.Buffer); ok {
225 if resp.ContentLength < 0 {
226 _, err := io.Copy(buf, resp.Body)
227 if err != nil {
228 return fmt.Errorf("reading response body: %w", err)
229 }
230 } else {
231 n, err := io.CopyN(buf, resp.Body, resp.ContentLength)
232 if err != nil {
233 return fmt.Errorf("reading length delimited response body (%d < %d): %w", n, resp.ContentLength, err)
234 }
235 }
236 } else {
237 if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
238 return fmt.Errorf("decoding xrpc response: %w", err)
239 }
240 }
241 }
242
243 break
244 }
245
246 return nil
247}