An atproto PDS written in Go
at v0.5.1 4.4 kB view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "net/http" 10 "strings" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 21 svc := e.Request().Header.Get("atproto-proxy") 22 if svc == "" && s.config.FallbackProxy != "" { 23 svc = s.config.FallbackProxy 24 } 25 26 svcPts := strings.Split(svc, "#") 27 if len(svcPts) != 2 { 28 return "", "", fmt.Errorf("invalid service header") 29 } 30 31 svcDid := svcPts[0] 32 svcId := "#" + svcPts[1] 33 34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 35 if err != nil { 36 return "", "", err 37 } 38 39 var endpoint string 40 for _, s := range doc.Service { 41 if s.Id == svcId { 42 endpoint = s.ServiceEndpoint 43 } 44 } 45 46 return endpoint, svcDid, nil 47} 48 49func (s *Server) handleProxy(e echo.Context) error { 50 lgr := s.logger.With("handler", "handleProxy") 51 52 repo, isAuthed := e.Get("repo").(*models.RepoActor) 53 54 pts := strings.Split(e.Request().URL.Path, "/") 55 if len(pts) != 3 { 56 return fmt.Errorf("incorrect number of parts") 57 } 58 59 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 60 if err != nil { 61 lgr.Error("could not get atproto proxy", "error", err) 62 return helpers.ServerError(e, nil) 63 } 64 65 requrl := e.Request().URL 66 requrl.Host = strings.TrimPrefix(endpoint, "https://") 67 requrl.Scheme = "https" 68 69 body := e.Request().Body 70 if e.Request().Method == "GET" { 71 body = nil 72 } 73 74 req, err := http.NewRequest(e.Request().Method, requrl.String(), body) 75 if err != nil { 76 return err 77 } 78 79 req.Header = e.Request().Header.Clone() 80 81 if isAuthed { 82 // this is a little dumb. i should probably figure out a better way to do this, and use 83 // a single way of creating/signing jwts throughout the pds. kinda limited here because 84 // im using the atproto crypto lib for this though. will come back to it 85 86 header := map[string]string{ 87 "alg": "ES256K", 88 "crv": "secp256k1", 89 "typ": "JWT", 90 } 91 hj, err := json.Marshal(header) 92 if err != nil { 93 lgr.Error("error marshaling header", "error", err) 94 return helpers.ServerError(e, nil) 95 } 96 97 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 98 99 // When proxying app.bsky.feed.getFeed the token is actually issued for the 100 // underlying feed generator and the app view passes it on. This allows the 101 // getFeed implementation to pass in the desired lxm and aud for the token 102 // and then just delegate to the general proxying logic 103 lxm, proxyTokenLxmExists := e.Get("proxyTokenLxm").(string) 104 if !proxyTokenLxmExists || lxm == "" { 105 lxm = pts[2] 106 } 107 aud, proxyTokenAudExists := e.Get("proxyTokenAud").(string) 108 if !proxyTokenAudExists || aud == "" { 109 aud = svcDid 110 } 111 112 payload := map[string]any{ 113 "iss": repo.Repo.Did, 114 "aud": aud, 115 "lxm": lxm, 116 "jti": uuid.NewString(), 117 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(), 118 } 119 pj, err := json.Marshal(payload) 120 if err != nil { 121 lgr.Error("error marashaling payload", "error", err) 122 return helpers.ServerError(e, nil) 123 } 124 125 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 126 127 input := fmt.Sprintf("%s.%s", encheader, encpayload) 128 hash := sha256.Sum256([]byte(input)) 129 130 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 131 if err != nil { 132 lgr.Error("can't load private key", "error", err) 133 return err 134 } 135 136 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 137 if err != nil { 138 lgr.Error("error signing", "error", err) 139 } 140 141 rBytes := R.Bytes() 142 sBytes := S.Bytes() 143 144 rPadded := make([]byte, 32) 145 sPadded := make([]byte, 32) 146 copy(rPadded[32-len(rBytes):], rBytes) 147 copy(sPadded[32-len(sBytes):], sBytes) 148 149 rawsig := append(rPadded, sPadded...) 150 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 151 token := fmt.Sprintf("%s.%s", input, encsig) 152 153 req.Header.Set("authorization", "Bearer "+token) 154 } else { 155 req.Header.Del("authorization") 156 } 157 158 resp, err := http.DefaultClient.Do(req) 159 if err != nil { 160 return err 161 } 162 defer resp.Body.Close() 163 164 for k, v := range resp.Header { 165 e.Response().Header().Set(k, strings.Join(v, ",")) 166 } 167 168 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body) 169}