An atproto PDS written in Go
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "net/http" 10 "strings" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 21 svc := e.Request().Header.Get("atproto-proxy") 22 23 svcPts := strings.Split(svc, "#") 24 if len(svcPts) != 2 { 25 return "", "", fmt.Errorf("invalid service header") 26 } 27 28 svcDid := svcPts[0] 29 svcId := "#" + svcPts[1] 30 31 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 32 if err != nil { 33 return "", "", err 34 } 35 36 var endpoint string 37 for _, s := range doc.Service { 38 if s.Id == svcId { 39 endpoint = s.ServiceEndpoint 40 } 41 } 42 43 return endpoint, svcDid, nil 44} 45 46func (s *Server) handleProxy(e echo.Context) error { 47 lgr := s.logger.With("handler", "handleProxy") 48 49 repo, isAuthed := e.Get("repo").(*models.RepoActor) 50 51 pts := strings.Split(e.Request().URL.Path, "/") 52 if len(pts) != 3 { 53 return fmt.Errorf("incorrect number of parts") 54 } 55 56 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 57 if err != nil { 58 lgr.Error("could not get atproto proxy", "error", err) 59 return helpers.ServerError(e, nil) 60 } 61 62 requrl := e.Request().URL 63 requrl.Host = strings.TrimPrefix(endpoint, "https://") 64 requrl.Scheme = "https" 65 66 body := e.Request().Body 67 if e.Request().Method == "GET" { 68 body = nil 69 } 70 71 req, err := http.NewRequest(e.Request().Method, requrl.String(), body) 72 if err != nil { 73 return err 74 } 75 76 req.Header = e.Request().Header.Clone() 77 78 if isAuthed { 79 // this is a little dumb. i should probably figure out a better way to do this, and use 80 // a single way of creating/signing jwts throughout the pds. kinda limited here because 81 // im using the atproto crypto lib for this though. will come back to it 82 83 header := map[string]string{ 84 "alg": "ES256K", 85 "crv": "secp256k1", 86 "typ": "JWT", 87 } 88 hj, err := json.Marshal(header) 89 if err != nil { 90 lgr.Error("error marshaling header", "error", err) 91 return helpers.ServerError(e, nil) 92 } 93 94 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 95 96 payload := map[string]any{ 97 "iss": repo.Repo.Did, 98 "aud": svcDid, 99 "lxm": pts[2], 100 "jti": uuid.NewString(), 101 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(), 102 } 103 pj, err := json.Marshal(payload) 104 if err != nil { 105 lgr.Error("error marashaling payload", "error", err) 106 return helpers.ServerError(e, nil) 107 } 108 109 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 110 111 input := fmt.Sprintf("%s.%s", encheader, encpayload) 112 hash := sha256.Sum256([]byte(input)) 113 114 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 115 if err != nil { 116 lgr.Error("can't load private key", "error", err) 117 return err 118 } 119 120 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 121 if err != nil { 122 lgr.Error("error signing", "error", err) 123 } 124 125 rBytes := R.Bytes() 126 sBytes := S.Bytes() 127 128 rPadded := make([]byte, 32) 129 sPadded := make([]byte, 32) 130 copy(rPadded[32-len(rBytes):], rBytes) 131 copy(sPadded[32-len(sBytes):], sBytes) 132 133 rawsig := append(rPadded, sPadded...) 134 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 135 token := fmt.Sprintf("%s.%s", input, encsig) 136 137 req.Header.Set("authorization", "Bearer "+token) 138 } else { 139 req.Header.Del("authorization") 140 } 141 142 resp, err := http.DefaultClient.Do(req) 143 if err != nil { 144 return err 145 } 146 defer resp.Body.Close() 147 148 for k, v := range resp.Header { 149 e.Response().Header().Set(k, strings.Join(v, ",")) 150 } 151 152 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body) 153}