1package server
2
3import (
4 "crypto/rand"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "net/http"
10 "strings"
11 "time"
12
13 "github.com/google/uuid"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/models"
16 "github.com/labstack/echo/v4"
17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18)
19
20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) {
21 svc := e.Request().Header.Get("atproto-proxy")
22 if svc == "" && s.config.FallbackProxy != "" {
23 svc = s.config.FallbackProxy
24 }
25
26 svcPts := strings.Split(svc, "#")
27 if len(svcPts) != 2 {
28 return "", "", fmt.Errorf("invalid service header")
29 }
30
31 svcDid := svcPts[0]
32 svcId := "#" + svcPts[1]
33
34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
35 if err != nil {
36 return "", "", err
37 }
38
39 var endpoint string
40 for _, s := range doc.Service {
41 if s.Id == svcId {
42 endpoint = s.ServiceEndpoint
43 }
44 }
45
46 return endpoint, svcDid, nil
47}
48
49func (s *Server) handleProxy(e echo.Context) error {
50 lgr := s.logger.With("handler", "handleProxy")
51
52 repo, isAuthed := e.Get("repo").(*models.RepoActor)
53
54 pts := strings.Split(e.Request().URL.Path, "/")
55 if len(pts) != 3 {
56 return fmt.Errorf("incorrect number of parts")
57 }
58
59 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e)
60 if err != nil {
61 lgr.Error("could not get atproto proxy", "error", err)
62 return helpers.ServerError(e, nil)
63 }
64
65 requrl := e.Request().URL
66 requrl.Host = strings.TrimPrefix(endpoint, "https://")
67 requrl.Scheme = "https"
68
69 body := e.Request().Body
70 if e.Request().Method == "GET" {
71 body = nil
72 }
73
74 req, err := http.NewRequest(e.Request().Method, requrl.String(), body)
75 if err != nil {
76 return err
77 }
78
79 req.Header = e.Request().Header.Clone()
80
81 if isAuthed {
82 // this is a little dumb. i should probably figure out a better way to do this, and use
83 // a single way of creating/signing jwts throughout the pds. kinda limited here because
84 // im using the atproto crypto lib for this though. will come back to it
85
86 header := map[string]string{
87 "alg": "ES256K",
88 "crv": "secp256k1",
89 "typ": "JWT",
90 }
91 hj, err := json.Marshal(header)
92 if err != nil {
93 lgr.Error("error marshaling header", "error", err)
94 return helpers.ServerError(e, nil)
95 }
96
97 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
98
99 // When proxying app.bsky.feed.getFeed the token is actually issued for the
100 // underlying feed generator and the app view passes it on. This allows the
101 // getFeed implementation to pass in the desired lxm and aud for the token
102 // and then just delegate to the general proxying logic
103 lxm, proxyTokenLxmExists := e.Get("proxyTokenLxm").(string)
104 if !proxyTokenLxmExists || lxm == "" {
105 lxm = pts[2]
106 }
107 aud, proxyTokenAudExists := e.Get("proxyTokenAud").(string)
108 if !proxyTokenAudExists || aud == "" {
109 aud = svcDid
110 }
111
112 payload := map[string]any{
113 "iss": repo.Repo.Did,
114 "aud": aud,
115 "lxm": lxm,
116 "jti": uuid.NewString(),
117 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(),
118 }
119 pj, err := json.Marshal(payload)
120 if err != nil {
121 lgr.Error("error marashaling payload", "error", err)
122 return helpers.ServerError(e, nil)
123 }
124
125 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
126
127 input := fmt.Sprintf("%s.%s", encheader, encpayload)
128 hash := sha256.Sum256([]byte(input))
129
130 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
131 if err != nil {
132 lgr.Error("can't load private key", "error", err)
133 return err
134 }
135
136 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
137 if err != nil {
138 lgr.Error("error signing", "error", err)
139 }
140
141 rBytes := R.Bytes()
142 sBytes := S.Bytes()
143
144 rPadded := make([]byte, 32)
145 sPadded := make([]byte, 32)
146 copy(rPadded[32-len(rBytes):], rBytes)
147 copy(sPadded[32-len(sBytes):], sBytes)
148
149 rawsig := append(rPadded, sPadded...)
150 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
151 token := fmt.Sprintf("%s.%s", input, encsig)
152
153 req.Header.Set("authorization", "Bearer "+token)
154 } else {
155 req.Header.Del("authorization")
156 }
157
158 resp, err := http.DefaultClient.Do(req)
159 if err != nil {
160 return err
161 }
162 defer resp.Body.Close()
163
164 for k, v := range resp.Header {
165 e.Response().Header().Set(k, strings.Join(v, ","))
166 }
167
168 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body)
169}