An atproto PDS written in Go
Version 0.0.2 · 3.5 kB · view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "net/http" 10 "strings" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20func (s *Server) handleProxy(e echo.Context) error { 21 repo, isAuthed := e.Get("repo").(*models.RepoActor) 22 23 pts := strings.Split(e.Request().URL.Path, "/") 24 if len(pts) != 3 { 25 return fmt.Errorf("incorrect number of parts") 26 } 27 28 svc := e.Request().Header.Get("atproto-proxy") 29 if svc == "" { 30 svc = "did:web:api.bsky.app#bsky_appview" // TODO: should be a config var probably 31 } 32 33 svcPts := strings.Split(svc, "#") 34 if len(svcPts) != 2 { 35 return fmt.Errorf("invalid service header") 36 } 37 38 svcDid := svcPts[0] 39 svcId := "#" + svcPts[1] 40 41 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 42 if err != nil { 43 return err 44 } 45 46 var endpoint string 47 for _, s := range doc.Service { 48 if s.Id == svcId { 49 endpoint = s.ServiceEndpoint 50 } 51 } 52 53 requrl := e.Request().URL 54 requrl.Host = strings.TrimPrefix(endpoint, "https://") 55 requrl.Scheme = "https" 56 57 body := e.Request().Body 58 if e.Request().Method == "GET" { 59 body = nil 60 } 61 62 req, err := http.NewRequest(e.Request().Method, requrl.String(), body) 63 if err != nil { 64 return err 65 } 66 67 req.Header = e.Request().Header.Clone() 68 69 if isAuthed { 70 // this is a little dumb. i should probably figure out a better way to do this, and use 71 // a single way of creating/signing jwts throughout the pds. kinda limited here because 72 // im using the atproto crypto lib for this though. 
will come back to it 73 74 header := map[string]string{ 75 "alg": "ES256K", 76 "crv": "secp256k1", 77 "typ": "JWT", 78 } 79 hj, err := json.Marshal(header) 80 if err != nil { 81 s.logger.Error("error marshaling header", "error", err) 82 return helpers.ServerError(e, nil) 83 } 84 85 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 86 87 payload := map[string]any{ 88 "iss": repo.Repo.Did, 89 "aud": svcDid, 90 "lxm": pts[2], 91 "jti": uuid.NewString(), 92 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(), 93 } 94 pj, err := json.Marshal(payload) 95 if err != nil { 96 s.logger.Error("error marashaling payload", "error", err) 97 return helpers.ServerError(e, nil) 98 } 99 100 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 101 102 input := fmt.Sprintf("%s.%s", encheader, encpayload) 103 hash := sha256.Sum256([]byte(input)) 104 105 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 106 if err != nil { 107 s.logger.Error("can't load private key", "error", err) 108 return err 109 } 110 111 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 112 if err != nil { 113 s.logger.Error("error signing", "error", err) 114 } 115 116 rBytes := R.Bytes() 117 sBytes := S.Bytes() 118 119 rPadded := make([]byte, 32) 120 sPadded := make([]byte, 32) 121 copy(rPadded[32-len(rBytes):], rBytes) 122 copy(sPadded[32-len(sBytes):], sBytes) 123 124 rawsig := append(rPadded, sPadded...) 
125 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 126 token := fmt.Sprintf("%s.%s", input, encsig) 127 128 req.Header.Set("authorization", "Bearer "+token) 129 } else { 130 req.Header.Del("authorization") 131 } 132 133 resp, err := http.DefaultClient.Do(req) 134 if err != nil { 135 return err 136 } 137 defer resp.Body.Close() 138 139 for k, v := range resp.Header { 140 e.Response().Header().Set(k, strings.Join(v, ",")) 141 } 142 143 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body) 144}