1package server
2
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/labstack/echo/v4"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
)
19
20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) {
21 svc := e.Request().Header.Get("atproto-proxy")
22 if svc == "" {
23 svc = s.config.DefaultAtprotoProxy
24 }
25
26 svcPts := strings.Split(svc, "#")
27 if len(svcPts) != 2 {
28 return "", "", fmt.Errorf("invalid service header")
29 }
30
31 svcDid := svcPts[0]
32 svcId := "#" + svcPts[1]
33
34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
35 if err != nil {
36 return "", "", err
37 }
38
39 var endpoint string
40 for _, s := range doc.Service {
41 if s.Id == svcId {
42 endpoint = s.ServiceEndpoint
43 }
44 }
45
46 return endpoint, svcDid, nil
47}
48
49func (s *Server) handleProxy(e echo.Context) error {
50 lgr := s.logger.With("handler", "handleProxy")
51
52 repo, isAuthed := e.Get("repo").(*models.RepoActor)
53
54 pts := strings.Split(e.Request().URL.Path, "/")
55 if len(pts) != 3 {
56 return fmt.Errorf("incorrect number of parts")
57 }
58
59 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e)
60 if err != nil {
61 lgr.Error("could not get atproto proxy", "error", err)
62 return helpers.ServerError(e, nil)
63 }
64
65 requrl := e.Request().URL
66 requrl.Host = strings.TrimPrefix(endpoint, "https://")
67 requrl.Scheme = "https"
68
69 body := e.Request().Body
70 if e.Request().Method == "GET" {
71 body = nil
72 }
73
74 req, err := http.NewRequest(e.Request().Method, requrl.String(), body)
75 if err != nil {
76 return err
77 }
78
79 req.Header = e.Request().Header.Clone()
80
81 if isAuthed {
82 // this is a little dumb. i should probably figure out a better way to do this, and use
83 // a single way of creating/signing jwts throughout the pds. kinda limited here because
84 // im using the atproto crypto lib for this though. will come back to it
85
86 header := map[string]string{
87 "alg": "ES256K",
88 "crv": "secp256k1",
89 "typ": "JWT",
90 }
91 hj, err := json.Marshal(header)
92 if err != nil {
93 lgr.Error("error marshaling header", "error", err)
94 return helpers.ServerError(e, nil)
95 }
96
97 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
98
99 payload := map[string]any{
100 "iss": repo.Repo.Did,
101 "aud": svcDid,
102 "lxm": pts[2],
103 "jti": uuid.NewString(),
104 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(),
105 }
106 pj, err := json.Marshal(payload)
107 if err != nil {
108 lgr.Error("error marashaling payload", "error", err)
109 return helpers.ServerError(e, nil)
110 }
111
112 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
113
114 input := fmt.Sprintf("%s.%s", encheader, encpayload)
115 hash := sha256.Sum256([]byte(input))
116
117 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
118 if err != nil {
119 lgr.Error("can't load private key", "error", err)
120 return err
121 }
122
123 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
124 if err != nil {
125 lgr.Error("error signing", "error", err)
126 }
127
128 rBytes := R.Bytes()
129 sBytes := S.Bytes()
130
131 rPadded := make([]byte, 32)
132 sPadded := make([]byte, 32)
133 copy(rPadded[32-len(rBytes):], rBytes)
134 copy(sPadded[32-len(sBytes):], sBytes)
135
136 rawsig := append(rPadded, sPadded...)
137 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
138 token := fmt.Sprintf("%s.%s", input, encsig)
139
140 req.Header.Set("authorization", "Bearer "+token)
141 } else {
142 req.Header.Del("authorization")
143 }
144
145 resp, err := http.DefaultClient.Do(req)
146 if err != nil {
147 return err
148 }
149 defer resp.Body.Close()
150
151 for k, v := range resp.Header {
152 e.Response().Header().Set(k, strings.Join(v, ","))
153 }
154
155 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body)
156}