1package server
2
3import (
4 "crypto/rand"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "net/http"
10 "strings"
11 "time"
12
13 "github.com/google/uuid"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/models"
16 "github.com/labstack/echo/v4"
17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18)
19
20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) {
21 svc := e.Request().Header.Get("atproto-proxy")
22
23 svcPts := strings.Split(svc, "#")
24 if len(svcPts) != 2 {
25 return "", "", fmt.Errorf("invalid service header")
26 }
27
28 svcDid := svcPts[0]
29 svcId := "#" + svcPts[1]
30
31 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
32 if err != nil {
33 return "", "", err
34 }
35
36 var endpoint string
37 for _, s := range doc.Service {
38 if s.Id == svcId {
39 endpoint = s.ServiceEndpoint
40 }
41 }
42
43 return endpoint, svcDid, nil
44}
45
46func (s *Server) handleProxy(e echo.Context) error {
47 lgr := s.logger.With("handler", "handleProxy")
48
49 repo, isAuthed := e.Get("repo").(*models.RepoActor)
50
51 pts := strings.Split(e.Request().URL.Path, "/")
52 if len(pts) != 3 {
53 return fmt.Errorf("incorrect number of parts")
54 }
55
56 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e)
57 if err != nil {
58 lgr.Error("could not get atproto proxy", "error", err)
59 return helpers.ServerError(e, nil)
60 }
61
62 requrl := e.Request().URL
63 requrl.Host = strings.TrimPrefix(endpoint, "https://")
64 requrl.Scheme = "https"
65
66 body := e.Request().Body
67 if e.Request().Method == "GET" {
68 body = nil
69 }
70
71 req, err := http.NewRequest(e.Request().Method, requrl.String(), body)
72 if err != nil {
73 return err
74 }
75
76 req.Header = e.Request().Header.Clone()
77
78 if isAuthed {
79 // this is a little dumb. i should probably figure out a better way to do this, and use
80 // a single way of creating/signing jwts throughout the pds. kinda limited here because
81 // im using the atproto crypto lib for this though. will come back to it
82
83 header := map[string]string{
84 "alg": "ES256K",
85 "crv": "secp256k1",
86 "typ": "JWT",
87 }
88 hj, err := json.Marshal(header)
89 if err != nil {
90 lgr.Error("error marshaling header", "error", err)
91 return helpers.ServerError(e, nil)
92 }
93
94 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
95
96 payload := map[string]any{
97 "iss": repo.Repo.Did,
98 "aud": svcDid,
99 "lxm": pts[2],
100 "jti": uuid.NewString(),
101 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(),
102 }
103 pj, err := json.Marshal(payload)
104 if err != nil {
105 lgr.Error("error marashaling payload", "error", err)
106 return helpers.ServerError(e, nil)
107 }
108
109 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
110
111 input := fmt.Sprintf("%s.%s", encheader, encpayload)
112 hash := sha256.Sum256([]byte(input))
113
114 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
115 if err != nil {
116 lgr.Error("can't load private key", "error", err)
117 return err
118 }
119
120 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
121 if err != nil {
122 lgr.Error("error signing", "error", err)
123 }
124
125 rBytes := R.Bytes()
126 sBytes := S.Bytes()
127
128 rPadded := make([]byte, 32)
129 sPadded := make([]byte, 32)
130 copy(rPadded[32-len(rBytes):], rBytes)
131 copy(sPadded[32-len(sBytes):], sBytes)
132
133 rawsig := append(rPadded, sPadded...)
134 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
135 token := fmt.Sprintf("%s.%s", input, encsig)
136
137 req.Header.Set("authorization", "Bearer "+token)
138 } else {
139 req.Header.Del("authorization")
140 }
141
142 resp, err := http.DefaultClient.Do(req)
143 if err != nil {
144 return err
145 }
146 defer resp.Body.Close()
147
148 for k, v := range resp.Header {
149 e.Response().Header().Set(k, strings.Join(v, ","))
150 }
151
152 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body)
153}