1package server
2
3import (
4 "crypto/rand"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "net/http"
10 "strings"
11 "time"
12
13 "github.com/google/uuid"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/models"
16 "github.com/labstack/echo/v4"
17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18)
19
20func (s *Server) handleProxy(e echo.Context) error {
21 repo, isAuthed := e.Get("repo").(*models.RepoActor)
22
23 pts := strings.Split(e.Request().URL.Path, "/")
24 if len(pts) != 3 {
25 return fmt.Errorf("incorrect number of parts")
26 }
27
28 svc := e.Request().Header.Get("atproto-proxy")
29 if svc == "" {
30 svc = "did:web:api.bsky.app#bsky_appview" // TODO: should be a config var probably
31 }
32
33 svcPts := strings.Split(svc, "#")
34 if len(svcPts) != 2 {
35 return fmt.Errorf("invalid service header")
36 }
37
38 svcDid := svcPts[0]
39 svcId := "#" + svcPts[1]
40
41 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
42 if err != nil {
43 return err
44 }
45
46 var endpoint string
47 for _, s := range doc.Service {
48 if s.Id == svcId {
49 endpoint = s.ServiceEndpoint
50 }
51 }
52
53 requrl := e.Request().URL
54 requrl.Host = strings.TrimPrefix(endpoint, "https://")
55 requrl.Scheme = "https"
56
57 body := e.Request().Body
58 if e.Request().Method == "GET" {
59 body = nil
60 }
61
62 req, err := http.NewRequest(e.Request().Method, requrl.String(), body)
63 if err != nil {
64 return err
65 }
66
67 req.Header = e.Request().Header.Clone()
68
69 if isAuthed {
70 // this is a little dumb. i should probably figure out a better way to do this, and use
71 // a single way of creating/signing jwts throughout the pds. kinda limited here because
72 // im using the atproto crypto lib for this though. will come back to it
73
74 header := map[string]string{
75 "alg": "ES256K",
76 "crv": "secp256k1",
77 "typ": "JWT",
78 }
79 hj, err := json.Marshal(header)
80 if err != nil {
81 s.logger.Error("error marshaling header", "error", err)
82 return helpers.ServerError(e, nil)
83 }
84
85 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
86
87 payload := map[string]any{
88 "iss": repo.Repo.Did,
89 "aud": svcDid,
90 "lxm": pts[2],
91 "jti": uuid.NewString(),
92 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(),
93 }
94 pj, err := json.Marshal(payload)
95 if err != nil {
96 s.logger.Error("error marashaling payload", "error", err)
97 return helpers.ServerError(e, nil)
98 }
99
100 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
101
102 input := fmt.Sprintf("%s.%s", encheader, encpayload)
103 hash := sha256.Sum256([]byte(input))
104
105 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
106 if err != nil {
107 s.logger.Error("can't load private key", "error", err)
108 return err
109 }
110
111 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
112 if err != nil {
113 s.logger.Error("error signing", "error", err)
114 }
115
116 rBytes := R.Bytes()
117 sBytes := S.Bytes()
118
119 rPadded := make([]byte, 32)
120 sPadded := make([]byte, 32)
121 copy(rPadded[32-len(rBytes):], rBytes)
122 copy(sPadded[32-len(sBytes):], sBytes)
123
124 rawsig := append(rPadded, sPadded...)
125 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
126 token := fmt.Sprintf("%s.%s", input, encsig)
127
128 req.Header.Set("authorization", "Bearer "+token)
129 } else {
130 req.Header.Del("authorization")
131 }
132
133 resp, err := http.DefaultClient.Do(req)
134 if err != nil {
135 return err
136 }
137 defer resp.Body.Close()
138
139 for k, v := range resp.Header {
140 e.Response().Header().Set(k, strings.Join(v, ","))
141 }
142
143 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body)
144}