1package server
2
import (
	"context"
	"crypto/ecdsa"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/go-playground/validator"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/identity"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/plc"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"github.com/lestrrat-go/jwx/v2/jwk"
	slogecho "github.com/samber/slog-echo"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
32
// Server is the cocoon PDS. It bundles the HTTP server and echo router, the
// database handle, and the clients and keys the request handlers use.
// Build one with New and run it with Serve.
type Server struct {
	// httpd listens on Args.Addr and uses echo as its handler.
	httpd *http.Server

	// echo is the router and middleware chain; routes are added by addRoutes.
	echo *echo.Echo

	// db is the gorm handle on the SQLite database.
	db *gorm.DB

	// plcClient talks to https://plc.directory using the configured rotation key.
	plcClient *plc.Client

	// logger is used for request logging and handler error reporting.
	logger *slog.Logger

	// config holds settings copied from Args in New.
	config *config

	// privateKey is the ECDSA key parsed from the JWK at Args.JwkPath; its
	// public half verifies session JWTs in handleSessionMiddleware.
	privateKey *ecdsa.PrivateKey

	// repoman is created by NewRepoMan after the Server value exists.
	repoman *RepoMan

	// evtman is an event manager backed by an in-memory persister.
	evtman *events.EventManager

	// passport is an identity resolver backed by an in-memory cache.
	passport *identity.Passport
}
45
// Args is the configuration accepted by New.
type Args struct {
	// Addr is the listen address for the HTTP server; required.
	Addr string

	// DbName names the database; required (New rejects an empty value).
	DbName string

	// Logger is used for server logging. When nil, New falls back to a text
	// logger writing to stdout.
	Logger *slog.Logger

	// Version is copied into the server config.
	Version string

	// Did is this PDS's own DID; required and must parse as a DID.
	Did string

	// Hostname is this PDS's public hostname; required. It is passed to the
	// PLC client and sent to relays in crawl requests.
	Hostname string

	// RotationKeyPath is a file whose raw bytes are given to the PLC client as
	// its rotation key.
	RotationKeyPath string

	// JwkPath is a file containing a JWK holding the server's ECDSA private key.
	JwkPath string

	// ContactEmail is the operator contact address; required.
	ContactEmail string

	// Relays are hosts asked to crawl this PDS when Serve starts.
	Relays []string
}
58
// config is the server-wide settings snapshot that New builds from Args.
type config struct {
	// Identity and contact details, copied verbatim from Args.
	Version, Did, Hostname, ContactEmail string

	// EnforcePeering is always set to false by New; its consumer is not
	// visible in this file.
	EnforcePeering bool

	// Relays lists the hosts Serve asks to crawl this PDS.
	Relays []string
}
67
// CustomValidator adapts a go-playground validator to echo; New installs it
// as the router's e.Validator.
type CustomValidator struct {
	// validator carries the custom atproto-* tags registered in New.
	validator *validator.Validate
}
71
// ValidationError is what CustomValidator.Validate returns when struct
// validation fails on one or more fields. The embedded error is the full
// validator error (so Error() yields its message); Field and Tag describe the
// first failing field and the validation tag it violated.
type ValidationError struct {
	error
	Field, Tag string
}
77
// Validate checks i against its struct validation tags.
//
// It returns nil when i is valid. When validation fails with field-level
// errors, it returns a ValidationError that wraps the full error and names
// the first failing field and tag. Any other failure is returned unchanged.
func (cv *CustomValidator) Validate(i any) error {
	err := cv.validator.Struct(i)
	if err == nil {
		return nil
	}

	var fieldErrs validator.ValidationErrors
	if !errors.As(err, &fieldErrs) || len(fieldErrs) == 0 {
		return err
	}

	head := fieldErrs[0]
	return ValidationError{error: err, Field: head.Field(), Tag: head.Tag()}
}
95
// handleSessionMiddleware returns echo middleware that authenticates a request
// using the session JWT in the Authorization header ("<scheme> <token>"; only
// the token part is inspected).
//
// A request is accepted when the token:
//   - verifies against the server's ECDSA public key,
//   - carries scope "com.atproto.refresh" on the refreshSession route and
//     "com.atproto.access" on every other route,
//   - is present in the matching table (refresh_tokens or tokens),
//   - has a numeric exp that has not passed and a string sub claim.
//
// On success it stores the subject DID ("did"), the repo actor ("repo") and
// the raw token ("token") on the echo context, then calls next. An error from
// next is handed to e.Error rather than returned.
//
// On failure: a missing header gets a 401 JSON body. An expired but otherwise
// valid token gets ExpiredToken. Any other unacceptable token gets
// InvalidToken. Database or repo lookup failures get a server error.
func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(e echo.Context) error {
		authheader := e.Request().Header.Get("authorization")
		if authheader == "" {
			return e.JSON(401, map[string]string{"error": "Unauthorized"})
		}

		pts := strings.Split(authheader, " ")
		if len(pts) != 2 {
			// A malformed header is a client error, not a server failure.
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		tokenstr := pts[1]

		token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
			if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
				return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
			}

			return s.privateKey.Public(), nil
		})
		if err != nil {
			// jwt/v4 validates exp inside Parse, so an expired token fails here
			// and never reaches the explicit exp check further down. Report
			// ExpiredToken (so the client knows to refresh rather than log in
			// again) only when expiry is the sole problem, i.e. the signature
			// itself verified.
			var verr *jwt.ValidationError
			if errors.As(err, &verr) && verr.Errors == jwt.ValidationErrorExpired {
				return helpers.InputError(e, to.StringPtr("ExpiredToken"))
			}

			s.logger.Error("error parsing jwt", "error", err)
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok || !token.Valid {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"

		// The route decides which scope is acceptable and which table must
		// contain the token.
		wantScope, table := "com.atproto.access", "tokens"
		if isRefresh {
			wantScope, table = "com.atproto.refresh", "refresh_tokens"
		}

		// Checked assertion: a missing or non-string scope is rejected rather
		// than panicking the handler.
		if scope, _ := claims["scope"].(string); scope != wantScope {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		type Result struct {
			Found bool
		}
		var result Result
		// table is one of the two literals chosen above, never request data,
		// so concatenating it is safe; the token itself is a bound parameter.
		if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return helpers.InputError(e, to.StringPtr("InvalidToken"))
			}

			s.logger.Error("error getting token from db", "error", err)
			return helpers.ServerError(e, nil)
		}

		if !result.Found {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		// Parse does not require exp to be present, so check it explicitly.
		exp, ok := claims["exp"].(float64)
		if !ok {
			s.logger.Error("error getting exp from token")
			return helpers.ServerError(e, nil)
		}

		if exp < float64(time.Now().UTC().Unix()) {
			return helpers.InputError(e, to.StringPtr("ExpiredToken"))
		}

		did, ok := claims["sub"].(string)
		if !ok || did == "" {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		e.Set("did", did)

		repo, err := s.getRepoActorByDid(did)
		if err != nil {
			s.logger.Error("error fetching repo", "error", err)
			return helpers.ServerError(e, nil)
		}
		e.Set("repo", repo)

		e.Set("token", tokenstr)

		if err := next(e); err != nil {
			e.Error(err)
		}

		return nil
	}
}
186
// New validates args and builds a Server without starting it.
//
// Addr, DbName, Did (which must parse as a DID), Hostname and ContactEmail
// are required. New reads the PLC rotation key from args.RotationKeyPath,
// loads the server's ECDSA private key from the JWK at args.JwkPath, and
// opens the SQLite database named by args.DbName. When args.Logger is nil, a
// text logger on stdout is used; args itself is not modified. Routes are
// registered later by Serve.
func New(args *Args) (*Server, error) {
	if args.Addr == "" {
		return nil, fmt.Errorf("addr must be set")
	}

	if args.DbName == "" {
		return nil, fmt.Errorf("db name must be set")
	}

	if args.Did == "" {
		return nil, fmt.Errorf("cocoon did must be set")
	}

	if args.ContactEmail == "" {
		return nil, fmt.Errorf("cocoon contact email is required")
	}

	if _, err := syntax.ParseDID(args.Did); err != nil {
		return nil, fmt.Errorf("error parsing cocoon did: %w", err)
	}

	if args.Hostname == "" {
		return nil, fmt.Errorf("cocoon hostname must be set")
	}

	// Use a local so the caller's Args is left untouched.
	logger := args.Logger
	if logger == nil {
		logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
	}

	e := echo.New()

	e.Pre(middleware.RemoveTrailingSlash())
	e.Pre(slogecho.New(logger))
	e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
		AllowOrigins:     []string{"*"},
		AllowHeaders:     []string{"*"},
		AllowMethods:     []string{"*"},
		AllowCredentials: true,
		MaxAge:           100_000_000,
	}))

	vdtor := validator.New()

	// Custom struct tags: a field passes when the matching atproto syntax
	// parser accepts its string value.
	atprotoTags := []struct {
		tag   string
		parse func(string) error
	}{
		{"atproto-handle", func(v string) error {
			_, err := syntax.ParseHandle(v)
			return err
		}},
		{"atproto-did", func(v string) error {
			_, err := syntax.ParseDID(v)
			return err
		}},
		{"atproto-rkey", func(v string) error {
			_, err := syntax.ParseRecordKey(v)
			return err
		}},
		{"atproto-nsid", func(v string) error {
			_, err := syntax.ParseNSID(v)
			return err
		}},
	}
	for _, ct := range atprotoTags {
		parse := ct.parse // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
		check := func(fl validator.FieldLevel) bool {
			return parse(fl.Field().String()) == nil
		}
		if err := vdtor.RegisterValidation(ct.tag, check); err != nil {
			return nil, fmt.Errorf("registering %s validation: %w", ct.tag, err)
		}
	}

	e.Validator = &CustomValidator{validator: vdtor}

	httpd := &http.Server{
		Addr:    args.Addr,
		Handler: e,

		// Bound header reads so slow clients cannot pin connections open.
		// Whole-request Read/Write timeouts are deliberately not set: they
		// would also cut long-lived responses such as the subscribeRepos route.
		ReadHeaderTimeout: 30 * time.Second,
	}

	rkbytes, err := os.ReadFile(args.RotationKeyPath)
	if err != nil {
		return nil, fmt.Errorf("reading rotation key: %w", err)
	}

	plcClient, err := plc.NewClient(&plc.ClientArgs{
		Service:     "https://plc.directory",
		PdsHostname: args.Hostname,
		RotationKey: rkbytes,
	})
	if err != nil {
		return nil, fmt.Errorf("creating plc client: %w", err)
	}

	jwkbytes, err := os.ReadFile(args.JwkPath)
	if err != nil {
		return nil, fmt.Errorf("reading jwk: %w", err)
	}

	key, err := jwk.ParseKey(jwkbytes)
	if err != nil {
		return nil, fmt.Errorf("parsing jwk: %w", err)
	}

	var pkey ecdsa.PrivateKey
	if err := key.Raw(&pkey); err != nil {
		return nil, fmt.Errorf("extracting ecdsa private key from jwk: %w", err)
	}

	// Opened last so that a failure in any earlier step cannot leave an
	// unreachable database handle open.
	db, err := gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("opening database %q: %w", args.DbName, err)
	}

	s := &Server{
		httpd:      httpd,
		echo:       e,
		logger:     logger,
		db:         db,
		plcClient:  plcClient,
		privateKey: &pkey,
		config: &config{
			Version:        args.Version,
			Did:            args.Did,
			Hostname:       args.Hostname,
			ContactEmail:   args.ContactEmail,
			EnforcePeering: false,
			Relays:         args.Relays,
		},
		evtman:   events.NewEventManager(events.NewMemPersister()),
		passport: identity.NewPassport(identity.NewMemCache(10_000)),
	}

	s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it

	return s, nil
}
318
319func (s *Server) addRoutes() {
320 s.echo.GET("/", s.handleRoot)
321 s.echo.GET("/xrpc/_health", s.handleHealth)
322 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
323 s.echo.GET("/robots.txt", s.handleRobots)
324
325 // public
326 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
327 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
328 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
329 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
330 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
331
332 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
333 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
334 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
335 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
336 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
337 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
338 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
339 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
340 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
341 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
342 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
343 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
344
345 // authed
346 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware)
347 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware)
348 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware)
349 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware)
350
351 // repo
352 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware)
353 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware)
354 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware)
355 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware)
356
357 // stupid silly endpoints
358 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware)
359 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware)
360
361 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
362 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
363}
364
// Serve runs the server until ctx is cancelled or the HTTP server fails.
//
// In order, it:
//  1. registers routes,
//  2. migrates the database schema,
//  3. binds the listen address and starts serving,
//  4. asks each configured relay (best effort) to crawl this PDS.
//
// When ctx is cancelled, the HTTP server is shut down gracefully within a
// bounded time. Serve returns nil after a clean shutdown. Otherwise it
// returns the first error from migration, binding, serving or shutdown.
func (s *Server) Serve(ctx context.Context) error {
	s.addRoutes()

	s.logger.Info("migrating...")

	if err := s.db.AutoMigrate(
		&models.Actor{},
		&models.Repo{},
		&models.InviteCode{},
		&models.Token{},
		&models.RefreshToken{},
		&models.Block{},
		&models.Record{},
		&models.Blob{},
		&models.BlobPart{},
	); err != nil {
		return fmt.Errorf("migrating database: %w", err)
	}

	// Bind before serving. A bad or busy address is then returned to the
	// caller instead of panicking in a goroutine, and the listener is
	// accepting connections before any relay is asked to crawl us.
	ln, err := net.Listen("tcp", s.httpd.Addr)
	if err != nil {
		return fmt.Errorf("listening on %s: %w", s.httpd.Addr, err)
	}

	s.logger.Info("starting cocoon")

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even once nobody is receiving.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- s.httpd.Serve(ln)
	}()

	for _, relay := range s.config.Relays {
		cli := xrpc.Client{Host: relay}
		reqCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
		crawlErr := atproto.SyncRequestCrawl(reqCtx, &cli, &atproto.SyncRequestCrawl_Input{
			Hostname: s.config.Hostname,
		})
		cancel()
		if crawlErr != nil {
			// Best effort: an unreachable relay must not stop the PDS from serving.
			s.logger.Error("error requesting crawl", "relay", relay, "error", crawlErr)
		}
	}

	select {
	case err := <-serveErr:
		// Shutdown has not been called yet, so any return from Serve is a failure.
		return fmt.Errorf("serving http: %w", err)
	case <-ctx.Done():
	}

	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := s.httpd.Shutdown(shutdownCtx); err != nil {
		// Graceful shutdown timed out; force-close the remaining connections.
		// The Close error is ignored because the Shutdown error is the one reported.
		_ = s.httpd.Close()
		return fmt.Errorf("shutting down http server: %w", err)
	}

	s.logger.Info("shut down")

	return nil
}