An atproto PDS written in Go
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "net/http" 10 "strings" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 21 svc := e.Request().Header.Get("atproto-proxy") 22 if svc == "" { 23 svc = s.config.DefaultAtprotoProxy 24 } 25 26 svcPts := strings.Split(svc, "#") 27 if len(svcPts) != 2 { 28 return "", "", fmt.Errorf("invalid service header") 29 } 30 31 svcDid := svcPts[0] 32 svcId := "#" + svcPts[1] 33 34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 35 if err != nil { 36 return "", "", err 37 } 38 39 var endpoint string 40 for _, s := range doc.Service { 41 if s.Id == svcId { 42 endpoint = s.ServiceEndpoint 43 } 44 } 45 46 return endpoint, "", nil 47} 48 49func (s *Server) handleProxy(e echo.Context) error { 50 lgr := s.logger.With("handler", "handleProxy") 51 52 repo, isAuthed := e.Get("repo").(*models.RepoActor) 53 54 pts := strings.Split(e.Request().URL.Path, "/") 55 if len(pts) != 3 { 56 return fmt.Errorf("incorrect number of parts") 57 } 58 59 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 60 if err != nil { 61 lgr.Error("could not get atproto proxy", "error", err) 62 return helpers.ServerError(e, nil) 63 } 64 65 requrl := e.Request().URL 66 requrl.Host = strings.TrimPrefix(endpoint, "https://") 67 requrl.Scheme = "https" 68 69 body := e.Request().Body 70 if e.Request().Method == "GET" { 71 body = nil 72 } 73 74 req, err := http.NewRequest(e.Request().Method, requrl.String(), body) 75 if err != nil { 76 return err 77 } 78 79 req.Header = e.Request().Header.Clone() 80 81 if isAuthed { 82 // this is a little dumb. i should probably figure out a better way to do this, and use 83 // a single way of creating/signing jwts throughout the pds. kinda limited here because 84 // im using the atproto crypto lib for this though. will come back to it 85 86 header := map[string]string{ 87 "alg": "ES256K", 88 "crv": "secp256k1", 89 "typ": "JWT", 90 } 91 hj, err := json.Marshal(header) 92 if err != nil { 93 lgr.Error("error marshaling header", "error", err) 94 return helpers.ServerError(e, nil) 95 } 96 97 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 98 99 payload := map[string]any{ 100 "iss": repo.Repo.Did, 101 "aud": svcDid, 102 "lxm": pts[2], 103 "jti": uuid.NewString(), 104 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(), 105 } 106 pj, err := json.Marshal(payload) 107 if err != nil { 108 lgr.Error("error marashaling payload", "error", err) 109 return helpers.ServerError(e, nil) 110 } 111 112 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 113 114 input := fmt.Sprintf("%s.%s", encheader, encpayload) 115 hash := sha256.Sum256([]byte(input)) 116 117 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 118 if err != nil { 119 lgr.Error("can't load private key", "error", err) 120 return err 121 } 122 123 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 124 if err != nil { 125 lgr.Error("error signing", "error", err) 126 } 127 128 rBytes := R.Bytes() 129 sBytes := S.Bytes() 130 131 rPadded := make([]byte, 32) 132 sPadded := make([]byte, 32) 133 copy(rPadded[32-len(rBytes):], rBytes) 134 copy(sPadded[32-len(sBytes):], sBytes) 135 136 rawsig := append(rPadded, sPadded...) 137 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 138 token := fmt.Sprintf("%s.%s", input, encsig) 139 140 req.Header.Set("authorization", "Bearer "+token) 141 } else { 142 req.Header.Del("authorization") 143 } 144 145 resp, err := http.DefaultClient.Do(req) 146 if err != nil { 147 return err 148 } 149 defer resp.Body.Close() 150 151 for k, v := range resp.Header { 152 e.Response().Header().Set(k, strings.Join(v, ",")) 153 } 154 155 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body) 156}