this repo has no description
1package oauth
2
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"sync"
	"time"

	"github.com/bluesky-social/indigo/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/carlmjohnson/versioninfo"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/haileyok/atproto-oauth-golang/internal/helpers"
	"github.com/lestrrat-go/jwx/v2/jwk"
)
22
// XrpcClient is an XRPC client for DPoP-authenticated requests to a PDS.
//
// It is adapted from the indigo xrpc client, with some tweaks:
//   - There is no AuthInfo on the client. Instead, auth is passed with each
//     request to Do, bundled in an XrpcAuthedRequestArgs.
//   - XrpcAuthedRequestArgs carries everything needed to complete an
//     authenticated request.
//   - OnDpopPdsNonceChanged is invoked when the PDS issues a new DPoP nonce,
//     so it can be persisted (for example, in a database).
//   - A request rejected because a new DPoP nonce is required is retried
//     (once) with the updated nonce.
type XrpcClient struct {
	// Client is the HTTP client to use. If nil, indigo's
	// util.RobustHTTPClient() is used.
	Client *http.Client
	// UserAgent overrides the User-Agent header. If nil,
	// "atproto-oauth/<version>" is sent.
	UserAgent *string
	// Headers are extra headers set on every request. They are applied after
	// User-Agent (and can therefore override it) but before the DPoP and
	// Authorization headers.
	Headers map[string]string
	// OnDpopPdsNonceChanged is called with the account DID and the new nonce
	// whenever the PDS answers a request with "use_dpop_nonce".
	OnDpopPdsNonceChanged func(did, newNonce string)
}
37
38type XrpcAuthedRequestArgs struct {
39 Did string
40 PdsUrl string
41 Issuer string
42 AccessToken string
43 DpopPdsNonce string
44 DpopPrivateJwk jwk.Key
45}
46
47func (c *XrpcClient) getClient() *http.Client {
48 if c.Client == nil {
49 return util.RobustHTTPClient()
50 }
51 return c.Client
52}
53
54func errorFromHTTPResponse(resp *http.Response, err error) error {
55 r := &xrpc.Error{
56 StatusCode: resp.StatusCode,
57 Wrapped: err,
58 }
59 if resp.Header.Get("ratelimit-limit") != "" {
60 r.Ratelimit = &xrpc.RatelimitInfo{
61 Policy: resp.Header.Get("ratelimit-policy"),
62 }
63 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-reset"), 10, 64); err == nil {
64 r.Ratelimit.Reset = time.Unix(n, 0)
65 }
66 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-limit"), 10, 64); err == nil {
67 r.Ratelimit.Limit = int(n)
68 }
69 if n, err := strconv.ParseInt(resp.Header.Get("ratelimit-remaining"), 10, 64); err == nil {
70 r.Ratelimit.Remaining = int(n)
71 }
72 }
73 return r
74}
75
// makeParams URL-encodes p as a query string, without the leading "?".
//
// A []string value expands into one key=value pair per element. The key is
// repeated (a=x&a=y), not joined with commas, and an empty slice contributes
// nothing. Every other value is formatted with fmt.Sprint, which suits the
// usual strings, numbers and booleans. Keys come out sorted because
// url.Values.Encode sorts them.
func makeParams(p map[string]any) string {
	params := make(url.Values, len(p))
	for key, val := range p {
		if list, ok := val.([]string); ok {
			for _, item := range list {
				params.Add(key, item)
			}
			continue
		}
		params.Add(key, fmt.Sprint(val))
	}
	return params.Encode()
}
93
94func PdsDpopJwt(method, url, iss, accessToken, nonce string, privateJwk jwk.Key) (string, error) {
95 pubJwk, err := privateJwk.PublicKey()
96 if err != nil {
97 return "", err
98 }
99
100 b, err := json.Marshal(pubJwk)
101 if err != nil {
102 return "", err
103 }
104
105 var pubMap map[string]any
106 if err := json.Unmarshal(b, &pubMap); err != nil {
107 return "", err
108 }
109
110 now := time.Now().Unix()
111
112 claims := jwt.MapClaims{
113 "iss": iss,
114 "iat": now,
115 "exp": now + 30,
116 "jti": uuid.NewString(),
117 "htm": method,
118 "htu": url,
119 "ath": helpers.GenerateCodeChallenge(accessToken),
120 }
121
122 if nonce != "" {
123 claims["nonce"] = nonce
124 }
125
126 token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
127 token.Header["typ"] = "dpop+jwt"
128 token.Header["alg"] = "ES256"
129 token.Header["jwk"] = pubMap
130
131 var rawKey any
132 if err := privateJwk.Raw(&rawKey); err != nil {
133 return "", err
134 }
135
136 tokenString, err := token.SignedString(rawKey)
137 if err != nil {
138 return "", fmt.Errorf("failed to sign token: %w", err)
139 }
140
141 return tokenString, nil
142}
143
144func (c *XrpcClient) Do(ctx context.Context, authedArgs *XrpcAuthedRequestArgs, kind xrpc.XRPCRequestType, inpenc, method string, params map[string]any, bodyobj any, out any) error {
145 // we might have to retry the request if we get a new nonce from the server
146 for range 2 {
147 var body io.Reader
148 if bodyobj != nil {
149 if rr, ok := bodyobj.(io.Reader); ok {
150 body = rr
151 } else {
152 b, err := json.Marshal(bodyobj)
153 if err != nil {
154 return err
155 }
156
157 body = bytes.NewReader(b)
158 }
159 }
160
161 var m string
162 switch kind {
163 case xrpc.Query:
164 m = "GET"
165 case xrpc.Procedure:
166 m = "POST"
167 default:
168 return fmt.Errorf("unsupported request kind: %d", kind)
169 }
170
171 var paramStr string
172 if len(params) > 0 {
173 paramStr = "?" + makeParams(params)
174 }
175
176 ustr := authedArgs.PdsUrl + "/xrpc/" + method + paramStr
177 req, err := http.NewRequest(m, ustr, body)
178 if err != nil {
179 return err
180 }
181
182 if bodyobj != nil && inpenc != "" {
183 req.Header.Set("Content-Type", inpenc)
184 }
185 if c.UserAgent != nil {
186 req.Header.Set("User-Agent", *c.UserAgent)
187 } else {
188 req.Header.Set("User-Agent", "atproto-oauth/"+versioninfo.Short())
189 }
190
191 if c.Headers != nil {
192 for k, v := range c.Headers {
193 req.Header.Set(k, v)
194 }
195 }
196
197 if authedArgs != nil {
198 dpopJwt, err := PdsDpopJwt(m, ustr, authedArgs.Issuer, authedArgs.AccessToken, authedArgs.DpopPdsNonce, authedArgs.DpopPrivateJwk)
199 if err != nil {
200 return err
201 }
202
203 req.Header.Set("DPoP", dpopJwt)
204 req.Header.Set("Authorization", "DPoP "+authedArgs.AccessToken)
205 }
206
207 resp, err := c.getClient().Do(req.WithContext(ctx))
208 if err != nil {
209 return fmt.Errorf("request failed: %w", err)
210 }
211
212 defer resp.Body.Close()
213
214 if resp.StatusCode != 200 {
215 var xe xrpc.XRPCError
216 if err := json.NewDecoder(resp.Body).Decode(&xe); err != nil {
217 return errorFromHTTPResponse(resp, fmt.Errorf("failed to decode xrpc error message: %w", err))
218 }
219
220 // if we get a new nonce, update the nonce and make the request again
221 if (resp.StatusCode == 400 || resp.StatusCode == 401) && xe.ErrStr == "use_dpop_nonce" {
222 authedArgs.DpopPdsNonce = resp.Header.Get("DPoP-Nonce")
223 c.OnDpopPdsNonceChanged(authedArgs.Did, authedArgs.DpopPdsNonce)
224 continue
225 }
226
227 return errorFromHTTPResponse(resp, &xe)
228 }
229
230 if out != nil {
231 if buf, ok := out.(*bytes.Buffer); ok {
232 if resp.ContentLength < 0 {
233 _, err := io.Copy(buf, resp.Body)
234 if err != nil {
235 return fmt.Errorf("reading response body: %w", err)
236 }
237 } else {
238 n, err := io.CopyN(buf, resp.Body, resp.ContentLength)
239 if err != nil {
240 return fmt.Errorf("reading length delimited response body (%d < %d): %w", n, resp.ContentLength, err)
241 }
242 }
243 } else {
244 if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
245 return fmt.Errorf("decoding xrpc response: %w", err)
246 }
247 }
248 }
249
250 return nil
251 }
252
253 return nil
254}