1package server
2
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/smtp"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/domodwyer/mailyak/v3"
	"github.com/go-playground/validator"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/identity"
	"github.com/haileyok/cocoon/internal/db"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/plc"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"github.com/lestrrat-go/jwx/v2/jwk"
	slogecho "github.com/samber/slog-echo"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
43
// S3Config configures periodic database backups to an S3-compatible store.
// Backups run only when BackupsEnabled is true and Region, Bucket, AccessKey
// and SecretKey are all non-empty; otherwise the backup routine logs a
// warning (for a missing setting) and does nothing.
type S3Config struct {
	// BackupsEnabled turns the hourly backup routine on.
	BackupsEnabled bool
	// Endpoint optionally overrides the S3 endpoint (e.g. for a non-AWS
	// provider). When set, path-style addressing is forced.
	Endpoint string
	// Region is the S3 region; required when backups are enabled.
	Region string
	// Bucket receives the backup objects; required when backups are enabled.
	Bucket string
	// AccessKey and SecretKey are static credentials; both are required
	// when backups are enabled.
	AccessKey string
	SecretKey string
}
52
// Server is the cocoon PDS. It owns the HTTP listener and echo router, the
// database handle, the signing key, outbound clients, and the optional mail
// and S3 backup configuration. Construct it with New and run it with Serve.
type Server struct {
	// http is the shared outbound HTTP client; New also hands it to the PLC
	// client and the identity passport.
	http *http.Client
	// httpd is the listening server; its Handler is the echo router.
	httpd *http.Server
	// mail is nil unless every SMTP setting was supplied to New.
	mail *mailyak.MailYak
	// mailLk is created alongside mail. Presumably it serializes use of the
	// shared mailyak instance — confirm at the call sites.
	mailLk    *sync.Mutex
	echo      *echo.Echo
	db        *db.DB
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey's public half verifies session JWTs in
	// handleSessionMiddleware. Presumably the private half signs them
	// elsewhere — verify.
	privateKey *ecdsa.PrivateKey
	repoman    *RepoMan
	// evtman is backed by an in-memory persister.
	evtman *events.EventManager
	// passport resolves identities using an in-memory cache of 10,000 entries.
	passport *identity.Passport

	// dbName is the path of the SQLite file that doBackup copies.
	dbName string
	// s3Config may be nil, in which case backups are disabled.
	s3Config *S3Config
}
71
// Args holds the construction parameters passed to New.
type Args struct {
	// Addr is the HTTP listen address (required).
	Addr string
	// DbName is the SQLite database path (required).
	DbName string
	// Logger defaults to a text logger on stdout when nil.
	Logger *slog.Logger
	// Version is copied into the server config.
	Version string
	// Did is the server's own DID (required; must parse as a DID).
	Did string
	// Hostname is the public hostname (required). It is passed to the PLC
	// client and sent to relays in requestCrawl.
	Hostname string
	// RotationKeyPath is a file whose raw bytes are given to the PLC client
	// as its rotation key.
	RotationKeyPath string
	// JwkPath is a JWK file holding the ECDSA private key.
	JwkPath string
	// ContactEmail is the operator contact address (required).
	ContactEmail string
	// Relays are hosts asked to crawl this server when Serve starts.
	Relays []string
	// AdminPassword is the HTTP Basic password for the "admin" user on admin
	// routes (required).
	AdminPassword string

	// All six SMTP settings must be non-empty for mail to be enabled.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config configures database backups; nil disables them.
	S3Config *S3Config
}
94
// config is the subset of Args retained on the Server after construction.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering bool
	Relays         []string
	// AdminPassword is checked by handleAdminMiddleware.
	AdminPassword string
	// SmtpEmail and SmtpName are the From address and display name for
	// outgoing mail.
	SmtpEmail string
	SmtpName  string
}
106
// CustomValidator adapts a go-playground validator to echo's Validator
// interface. New installs it as the router's e.Validator, with extra
// atproto-handle, atproto-did, atproto-rkey and atproto-nsid tags registered.
type CustomValidator struct {
	validator *validator.Validate
}
110
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails. The embedded error is the complete validator error.
// Field and Tag describe only the first failing field and the tag it failed.
type ValidationError struct {
	error
	Field string
	Tag   string
}
116
117func (cv *CustomValidator) Validate(i any) error {
118 if err := cv.validator.Struct(i); err != nil {
119 var validateErrors validator.ValidationErrors
120 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
121 first := validateErrors[0]
122 return ValidationError{
123 error: err,
124 Field: first.Field(),
125 Tag: first.Tag(),
126 }
127 }
128
129 return err
130 }
131
132 return nil
133}
134
135func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
136 return func(e echo.Context) error {
137 username, password, ok := e.Request().BasicAuth()
138 if !ok || username != "admin" || password != s.config.AdminPassword {
139 return helpers.InputError(e, to.StringPtr("Unauthorized"))
140 }
141
142 if err := next(e); err != nil {
143 e.Error(err)
144 }
145
146 return nil
147 }
148}
149
150func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
151 return func(e echo.Context) error {
152 authheader := e.Request().Header.Get("authorization")
153 if authheader == "" {
154 return e.JSON(401, map[string]string{"error": "Unauthorized"})
155 }
156
157 pts := strings.Split(authheader, " ")
158 if len(pts) != 2 {
159 return helpers.ServerError(e, nil)
160 }
161
162 tokenstr := pts[1]
163
164 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
165 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
166 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
167 }
168
169 return s.privateKey.Public(), nil
170 })
171 if err != nil {
172 s.logger.Error("error parsing jwt", "error", err)
173 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319
174 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
175 }
176
177 claims, ok := token.Claims.(jwt.MapClaims)
178 if !ok || !token.Valid {
179 return helpers.InputError(e, to.StringPtr("InvalidToken"))
180 }
181
182 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
183 scope := claims["scope"].(string)
184
185 if isRefresh && scope != "com.atproto.refresh" {
186 return helpers.InputError(e, to.StringPtr("InvalidToken"))
187 } else if !isRefresh && scope != "com.atproto.access" {
188 return helpers.InputError(e, to.StringPtr("InvalidToken"))
189 }
190
191 table := "tokens"
192 if isRefresh {
193 table = "refresh_tokens"
194 }
195
196 type Result struct {
197 Found bool
198 }
199 var result Result
200 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
201 if err == gorm.ErrRecordNotFound {
202 return helpers.InputError(e, to.StringPtr("InvalidToken"))
203 }
204
205 s.logger.Error("error getting token from db", "error", err)
206 return helpers.ServerError(e, nil)
207 }
208
209 if !result.Found {
210 return helpers.InputError(e, to.StringPtr("InvalidToken"))
211 }
212
213 exp, ok := claims["exp"].(float64)
214 if !ok {
215 s.logger.Error("error getting iat from token")
216 return helpers.ServerError(e, nil)
217 }
218
219 if exp < float64(time.Now().UTC().Unix()) {
220 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
221 }
222
223 repo, err := s.getRepoActorByDid(claims["sub"].(string))
224 if err != nil {
225 s.logger.Error("error fetching repo", "error", err)
226 return helpers.ServerError(e, nil)
227 }
228
229 e.Set("repo", repo)
230 e.Set("did", claims["sub"])
231 e.Set("token", tokenstr)
232
233 if err := next(e); err != nil {
234 e.Error(err)
235 }
236
237 return nil
238 }
239}
240
241func New(args *Args) (*Server, error) {
242 if args.Addr == "" {
243 return nil, fmt.Errorf("addr must be set")
244 }
245
246 if args.DbName == "" {
247 return nil, fmt.Errorf("db name must be set")
248 }
249
250 if args.Did == "" {
251 return nil, fmt.Errorf("cocoon did must be set")
252 }
253
254 if args.ContactEmail == "" {
255 return nil, fmt.Errorf("cocoon contact email is required")
256 }
257
258 if _, err := syntax.ParseDID(args.Did); err != nil {
259 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
260 }
261
262 if args.Hostname == "" {
263 return nil, fmt.Errorf("cocoon hostname must be set")
264 }
265
266 if args.AdminPassword == "" {
267 return nil, fmt.Errorf("admin password must be set")
268 }
269
270 if args.Logger == nil {
271 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
272 }
273
274 e := echo.New()
275
276 e.Pre(middleware.RemoveTrailingSlash())
277 e.Pre(slogecho.New(args.Logger))
278 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
279 AllowOrigins: []string{"*"},
280 AllowHeaders: []string{"*"},
281 AllowMethods: []string{"*"},
282 AllowCredentials: true,
283 MaxAge: 100_000_000,
284 }))
285
286 vdtor := validator.New()
287 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
288 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
289 return false
290 }
291 return true
292 })
293 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
294 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
295 return false
296 }
297 return true
298 })
299 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
300 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
301 return false
302 }
303 return true
304 })
305 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
306 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
307 return false
308 }
309 return true
310 })
311
312 e.Validator = &CustomValidator{validator: vdtor}
313
314 httpd := &http.Server{
315 Addr: args.Addr,
316 Handler: e,
317 // shitty defaults but okay for now, needed for import repo
318 ReadTimeout: 5 * time.Minute,
319 WriteTimeout: 5 * time.Minute,
320 IdleTimeout: 5 * time.Minute,
321 }
322
323 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
324 if err != nil {
325 return nil, err
326 }
327 dbw := db.NewDB(gdb)
328
329 rkbytes, err := os.ReadFile(args.RotationKeyPath)
330 if err != nil {
331 return nil, err
332 }
333
334 h := util.RobustHTTPClient()
335
336 plcClient, err := plc.NewClient(&plc.ClientArgs{
337 H: h,
338 Service: "https://plc.directory",
339 PdsHostname: args.Hostname,
340 RotationKey: rkbytes,
341 })
342 if err != nil {
343 return nil, err
344 }
345
346 jwkbytes, err := os.ReadFile(args.JwkPath)
347 if err != nil {
348 return nil, err
349 }
350
351 key, err := jwk.ParseKey(jwkbytes)
352 if err != nil {
353 return nil, err
354 }
355
356 var pkey ecdsa.PrivateKey
357 if err := key.Raw(&pkey); err != nil {
358 return nil, err
359 }
360
361 s := &Server{
362 http: h,
363 httpd: httpd,
364 echo: e,
365 logger: args.Logger,
366 db: dbw,
367 plcClient: plcClient,
368 privateKey: &pkey,
369 config: &config{
370 Version: args.Version,
371 Did: args.Did,
372 Hostname: args.Hostname,
373 ContactEmail: args.ContactEmail,
374 EnforcePeering: false,
375 Relays: args.Relays,
376 AdminPassword: args.AdminPassword,
377 SmtpName: args.SmtpName,
378 SmtpEmail: args.SmtpEmail,
379 },
380 evtman: events.NewEventManager(events.NewMemPersister()),
381 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
382
383 dbName: args.DbName,
384 s3Config: args.S3Config,
385 }
386
387 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
388
389 // TODO: should validate these args
390 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
391 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
392 } else {
393 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
394 mail.From(s.config.SmtpEmail)
395 mail.FromName(s.config.SmtpName)
396
397 s.mail = mail
398 s.mailLk = &sync.Mutex{}
399 }
400
401 return s, nil
402}
403
404func (s *Server) addRoutes() {
405 // random stuff
406 s.echo.GET("/", s.handleRoot)
407 s.echo.GET("/xrpc/_health", s.handleHealth)
408 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
409 s.echo.GET("/robots.txt", s.handleRobots)
410
411 // public
412 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
413 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
414 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
415 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
416 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
417
418 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
419 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
420 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
421 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
422 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
423 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
424 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
425 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
426 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
427 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
428 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
429 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
430
431 // authed
432 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware)
433 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware)
434 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware)
435 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware)
436 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware)
437 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware)
438 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
439 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware)
440 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware)
441 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware)
442
443 // repo
444 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware)
445 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware)
446 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware)
447 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware)
448 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware)
449 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleSessionMiddleware)
450
451 // stupid silly endpoints
452 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware)
453 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware)
454
455 // are there any routes that we should be allowing without auth? i dont think so but idk
456 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
457 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
458
459 // admin routes
460 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
461 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
462}
463
464func (s *Server) Serve(ctx context.Context) error {
465 s.addRoutes()
466
467 s.logger.Info("migrating...")
468
469 s.db.AutoMigrate(
470 &models.Actor{},
471 &models.Repo{},
472 &models.InviteCode{},
473 &models.Token{},
474 &models.RefreshToken{},
475 &models.Block{},
476 &models.Record{},
477 &models.Blob{},
478 &models.BlobPart{},
479 )
480
481 s.logger.Info("starting cocoon")
482
483 go func() {
484 if err := s.httpd.ListenAndServe(); err != nil {
485 panic(err)
486 }
487 }()
488
489 go s.backupRoutine()
490
491 for _, relay := range s.config.Relays {
492 cli := xrpc.Client{Host: relay}
493 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
494 Hostname: s.config.Hostname,
495 })
496 }
497
498 <-ctx.Done()
499
500 fmt.Println("shut down")
501
502 return nil
503}
504
505func (s *Server) doBackup() {
506 start := time.Now()
507
508 s.logger.Info("beginning backup to s3...")
509
510 var buf bytes.Buffer
511 if err := func() error {
512 s.logger.Info("reading database bytes...")
513 s.db.Lock()
514 defer s.db.Unlock()
515
516 sf, err := os.Open(s.dbName)
517 if err != nil {
518 return fmt.Errorf("error opening database for backup: %w", err)
519 }
520 defer sf.Close()
521
522 if _, err := io.Copy(&buf, sf); err != nil {
523 return fmt.Errorf("error reading bytes of backup db: %w", err)
524 }
525
526 return nil
527 }(); err != nil {
528 s.logger.Error("error backing up database", "error", err)
529 return
530 }
531
532 if err := func() error {
533 s.logger.Info("sending to s3...")
534
535 currTime := time.Now().Format("2006-01-02_15-04-05")
536 key := "cocoon-backup-" + currTime + ".db"
537
538 config := &aws.Config{
539 Region: aws.String(s.s3Config.Region),
540 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
541 }
542
543 if s.s3Config.Endpoint != "" {
544 config.Endpoint = aws.String(s.s3Config.Endpoint)
545 config.S3ForcePathStyle = aws.Bool(true)
546 }
547
548 sess, err := session.NewSession(config)
549 if err != nil {
550 return err
551 }
552
553 svc := s3.New(sess)
554
555 if _, err := svc.PutObject(&s3.PutObjectInput{
556 Bucket: aws.String(s.s3Config.Bucket),
557 Key: aws.String(key),
558 Body: bytes.NewReader(buf.Bytes()),
559 }); err != nil {
560 return fmt.Errorf("error uploading file to s3: %w", err)
561 }
562
563 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds())
564
565 return nil
566 }(); err != nil {
567 s.logger.Error("error uploading database backup", "error", err)
568 return
569 }
570
571 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644)
572}
573
574func (s *Server) backupRoutine() {
575 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
576 return
577 }
578
579 if s.s3Config.Region == "" {
580 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
581 return
582 }
583
584 if s.s3Config.Bucket == "" {
585 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
586 return
587 }
588
589 if s.s3Config.AccessKey == "" {
590 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
591 return
592 }
593
594 if s.s3Config.SecretKey == "" {
595 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
596 return
597 }
598
599 shouldBackupNow := false
600 lastBackupStr, err := os.ReadFile("last-backup.txt")
601 if err != nil {
602 shouldBackupNow = true
603 } else {
604 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr))
605 if err != nil {
606 shouldBackupNow = true
607 } else if time.Now().Sub(lastBackup).Seconds() > 3600 {
608 shouldBackupNow = true
609 }
610 }
611
612 if shouldBackupNow {
613 go s.doBackup()
614 }
615
616 ticker := time.NewTicker(time.Hour)
617 for range ticker.C {
618 go s.doBackup()
619 }
620}