1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 "github.com/ipfs/go-cid"
42 echo_session "github.com/labstack/echo-contrib/session"
43 "github.com/labstack/echo/v4"
44 "github.com/labstack/echo/v4/middleware"
45 slogecho "github.com/samber/slog-echo"
46 "gorm.io/driver/postgres"
47 "gorm.io/driver/sqlite"
48 "gorm.io/gorm"
49)
50
const (
	// AccountSessionMaxAge is 30 days (30 * 24 hours).
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
54
// S3Config carries the S3 connection settings used for database backups.
type S3Config struct {
	// BackupsEnabled must be true for backupRoutine to schedule any backups.
	BackupsEnabled bool
	// BlobstoreEnabled is not read in this file.
	// NOTE(review): presumably toggles S3-backed blob storage elsewhere; confirm.
	BlobstoreEnabled bool
	// Endpoint is optional. When non-empty, doBackup uses it as a custom S3
	// endpoint and forces path-style addressing.
	Endpoint string
	// Region, Bucket, AccessKey and SecretKey must all be non-empty or
	// backupRoutine logs a warning and returns without running backups.
	Region    string
	Bucket    string
	AccessKey string
	SecretKey string
}
64
// Server holds the state of a running cocoon instance: the HTTP stack, the
// database handle, the signing key, the OAuth provider and the backup
// settings. It is assembled by New and started with Serve.
type Server struct {
	// http is the outbound client created in New via util.RobustHTTPClient;
	// New also hands the same client to the PLC client and the passport.
	http *http.Client
	// httpd is the net/http server whose handler is the echo instance.
	httpd *http.Server
	// mail and mailLk are populated together in New only when every SMTP
	// argument was supplied; both stay nil otherwise.
	mail   *mailyak.MailYak
	mailLk *sync.Mutex
	echo   *echo.Echo
	db     *db.DB
	// plcClient is created in New against https://plc.directory.
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey is the ECDSA key extracted from the JWK file at Args.JwkPath.
	privateKey *ecdsa.PrivateKey
	// repoman is created in New via NewRepoMan(s), after the Server exists.
	repoman       *RepoMan
	oauthProvider *provider.Provider
	// evtman is created in New with an in-memory persister.
	evtman *events.EventManager
	// passport is created in New with an in-memory cache sized 10_000.
	passport *identity.Passport
	// fallbackProxy is not assigned by New in this file; Args.FallbackProxy
	// is copied into config.FallbackProxy instead.
	fallbackProxy string

	// lastRequestCrawl is the time of the last completed requestCrawl pass;
	// requestCrawlMu is held for the whole of requestCrawl.
	lastRequestCrawl time.Time
	requestCrawlMu   sync.Mutex

	// dbName is the SQLite file path that doBackup reads; dbType is
	// "postgres" or the SQLite default; s3Config is the backup target.
	dbName   string
	dbType   string
	s3Config *S3Config
}
89
// Args is the input to New.
type Args struct {
	// Addr is the listen address given to the http.Server. Required.
	Addr string
	// DbName is required. It is the path opened for SQLite and the file that
	// doBackup reads.
	DbName string
	// DbType selects the database driver: "postgres" opens DatabaseURL, and
	// any other value (including empty) opens DbName with SQLite.
	DbType string
	// DatabaseURL is required when DbType is "postgres".
	DatabaseURL string
	// Logger defaults to a text handler on stdout when nil.
	Logger *slog.Logger
	// Version "dev" makes the server load templates and static files from
	// disk instead of the embedded file systems.
	Version string
	// Did is required and must parse with syntax.ParseDID.
	Did string
	// Hostname is required.
	Hostname string
	// RotationKeyPath is a file whose bytes are passed to the PLC client as
	// the rotation key.
	RotationKeyPath string
	// JwkPath is a file holding the JWK from which the server's ECDSA
	// private key is extracted.
	JwkPath string
	// ContactEmail is required.
	ContactEmail string
	// Relays lists the hosts that requestCrawl contacts.
	Relays []string
	// AdminPassword is required.
	AdminPassword string

	// All six SMTP fields must be non-empty for mail to be configured;
	// otherwise New logs a warning and leaves mailing disabled.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config is stored on the Server for the backup routine; may be nil.
	S3Config *S3Config

	// SessionSecret keys the cookie session store. Required: New refuses to
	// build a Server without it.
	SessionSecret string

	// BlockstoreVariant and FallbackProxy are copied into the server config
	// unchanged.
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string
}
119
// config is the subset of Args that the Server keeps after New returns.
// Every field except EnforcePeering is copied from the Args field of the
// same name.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering    bool
	Relays            []string
	AdminPassword     string
	SmtpEmail         string
	SmtpName          string
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string
}
133
// CustomValidator wraps a validator.Validate so that it can be installed as
// the echo instance's Validator.
type CustomValidator struct {
	validator *validator.Validate
}
137
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails with at least one field error. The embedded error is the
// validator's original error; Field and Tag come from the first field error.
type ValidationError struct {
	error
	Field string
	Tag   string
}
143
144func (cv *CustomValidator) Validate(i any) error {
145 if err := cv.validator.Struct(i); err != nil {
146 var validateErrors validator.ValidationErrors
147 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
148 first := validateErrors[0]
149 return ValidationError{
150 error: err,
151 Field: first.Field(),
152 Tag: first.Tag(),
153 }
154 }
155
156 return err
157 }
158
159 return nil
160}
161
// templateFS embeds the templates directory. loadTemplates parses the *.html
// files from it whenever the server version is not "dev".
//
//go:embed templates/*
var templateFS embed.FS
164
// staticFS embeds the static directory. addRoutes serves it under /static/*
// whenever the server version is not "dev".
//
//go:embed static/*
var staticFS embed.FS
167
// TemplateRenderer is the echo renderer installed by loadTemplates.
//
// NOTE(review): the *.html templates are parsed with text/template, which
// performs no HTML escaping. Confirm that all template data is trusted, or
// consider html/template.
type TemplateRenderer struct {
	// templates is the parsed template set executed by Render.
	templates *template.Template
	// isDev makes Render re-parse templatePath on every call.
	isDev bool
	// templatePath is the glob pattern re-parsed in dev mode; it is empty
	// otherwise.
	templatePath string
}
173
174func (s *Server) loadTemplates() {
175 absPath, _ := filepath.Abs("server/templates/*.html")
176 if s.config.Version == "dev" {
177 tmpl := template.Must(template.ParseGlob(absPath))
178 s.echo.Renderer = &TemplateRenderer{
179 templates: tmpl,
180 isDev: true,
181 templatePath: absPath,
182 }
183 } else {
184 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
185 s.echo.Renderer = &TemplateRenderer{
186 templates: tmpl,
187 isDev: false,
188 }
189 }
190}
191
192func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
193 if t.isDev {
194 tmpl, err := template.ParseGlob(t.templatePath)
195 if err != nil {
196 return err
197 }
198 t.templates = tmpl
199 }
200
201 if viewContext, isMap := data.(map[string]any); isMap {
202 viewContext["reverse"] = c.Echo().Reverse
203 }
204
205 return t.templates.ExecuteTemplate(w, name, data)
206}
207
208func New(args *Args) (*Server, error) {
209 if args.Addr == "" {
210 return nil, fmt.Errorf("addr must be set")
211 }
212
213 if args.DbName == "" {
214 return nil, fmt.Errorf("db name must be set")
215 }
216
217 if args.Did == "" {
218 return nil, fmt.Errorf("cocoon did must be set")
219 }
220
221 if args.ContactEmail == "" {
222 return nil, fmt.Errorf("cocoon contact email is required")
223 }
224
225 if _, err := syntax.ParseDID(args.Did); err != nil {
226 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
227 }
228
229 if args.Hostname == "" {
230 return nil, fmt.Errorf("cocoon hostname must be set")
231 }
232
233 if args.AdminPassword == "" {
234 return nil, fmt.Errorf("admin password must be set")
235 }
236
237 if args.Logger == nil {
238 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
239 }
240
241 if args.SessionSecret == "" {
242 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
243 }
244
245 e := echo.New()
246
247 e.Pre(middleware.RemoveTrailingSlash())
248 e.Pre(slogecho.New(args.Logger))
249 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
250 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
251 AllowOrigins: []string{"*"},
252 AllowHeaders: []string{"*"},
253 AllowMethods: []string{"*"},
254 AllowCredentials: true,
255 MaxAge: 100_000_000,
256 }))
257
258 vdtor := validator.New()
259 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
260 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
261 return false
262 }
263 return true
264 })
265 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
266 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
267 return false
268 }
269 return true
270 })
271 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
272 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
273 return false
274 }
275 return true
276 })
277 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
278 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
279 return false
280 }
281 return true
282 })
283
284 e.Validator = &CustomValidator{validator: vdtor}
285
286 httpd := &http.Server{
287 Addr: args.Addr,
288 Handler: e,
289 // shitty defaults but okay for now, needed for import repo
290 ReadTimeout: 5 * time.Minute,
291 WriteTimeout: 5 * time.Minute,
292 IdleTimeout: 5 * time.Minute,
293 }
294
295 dbType := args.DbType
296 if dbType == "" {
297 dbType = "sqlite"
298 }
299
300 var gdb *gorm.DB
301 var err error
302 switch dbType {
303 case "postgres":
304 if args.DatabaseURL == "" {
305 return nil, fmt.Errorf("database-url must be set when using postgres")
306 }
307 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
308 if err != nil {
309 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
310 }
311 args.Logger.Info("connected to PostgreSQL database")
312 default:
313 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
314 if err != nil {
315 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
316 }
317 args.Logger.Info("connected to SQLite database", "path", args.DbName)
318 }
319 dbw := db.NewDB(gdb)
320
321 rkbytes, err := os.ReadFile(args.RotationKeyPath)
322 if err != nil {
323 return nil, err
324 }
325
326 h := util.RobustHTTPClient()
327
328 plcClient, err := plc.NewClient(&plc.ClientArgs{
329 H: h,
330 Service: "https://plc.directory",
331 PdsHostname: args.Hostname,
332 RotationKey: rkbytes,
333 })
334 if err != nil {
335 return nil, err
336 }
337
338 jwkbytes, err := os.ReadFile(args.JwkPath)
339 if err != nil {
340 return nil, err
341 }
342
343 key, err := helpers.ParseJWKFromBytes(jwkbytes)
344 if err != nil {
345 return nil, err
346 }
347
348 var pkey ecdsa.PrivateKey
349 if err := key.Raw(&pkey); err != nil {
350 return nil, err
351 }
352
353 oauthCli := &http.Client{
354 Timeout: 10 * time.Second,
355 }
356
357 var nonceSecret []byte
358 maybeSecret, err := os.ReadFile("nonce.secret")
359 if err != nil && !os.IsNotExist(err) {
360 args.Logger.Error("error attempting to read nonce secret", "error", err)
361 } else {
362 nonceSecret = maybeSecret
363 }
364
365 s := &Server{
366 http: h,
367 httpd: httpd,
368 echo: e,
369 logger: args.Logger,
370 db: dbw,
371 plcClient: plcClient,
372 privateKey: &pkey,
373 config: &config{
374 Version: args.Version,
375 Did: args.Did,
376 Hostname: args.Hostname,
377 ContactEmail: args.ContactEmail,
378 EnforcePeering: false,
379 Relays: args.Relays,
380 AdminPassword: args.AdminPassword,
381 SmtpName: args.SmtpName,
382 SmtpEmail: args.SmtpEmail,
383 BlockstoreVariant: args.BlockstoreVariant,
384 FallbackProxy: args.FallbackProxy,
385 },
386 evtman: events.NewEventManager(events.NewMemPersister()),
387 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
388
389 dbName: args.DbName,
390 dbType: dbType,
391 s3Config: args.S3Config,
392
393 oauthProvider: provider.NewProvider(provider.Args{
394 Hostname: args.Hostname,
395 ClientManagerArgs: client.ManagerArgs{
396 Cli: oauthCli,
397 Logger: args.Logger,
398 },
399 DpopManagerArgs: dpop.ManagerArgs{
400 NonceSecret: nonceSecret,
401 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
402 OnNonceSecretCreated: func(newNonce []byte) {
403 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
404 args.Logger.Error("error writing new nonce secret", "error", err)
405 }
406 },
407 Logger: args.Logger,
408 Hostname: args.Hostname,
409 },
410 }),
411 }
412
413 s.loadTemplates()
414
415 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
416
417 // TODO: should validate these args
418 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
419 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
420 } else {
421 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
422 mail.From(s.config.SmtpEmail)
423 mail.FromName(s.config.SmtpName)
424
425 s.mail = mail
426 s.mailLk = &sync.Mutex{}
427 }
428
429 return s, nil
430}
431
432func (s *Server) addRoutes() {
433 // static
434 if s.config.Version == "dev" {
435 s.echo.Static("/static", "server/static")
436 } else {
437 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
438 }
439
440 // random stuff
441 s.echo.GET("/", s.handleRoot)
442 s.echo.GET("/xrpc/_health", s.handleHealth)
443 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
444 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
445 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
446 s.echo.GET("/robots.txt", s.handleRobots)
447
448 // public
449 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
450 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
451 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
452 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
453
454 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
455 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
456 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
457 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs)
458 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
459 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
460 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
461 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
462 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
463 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
464 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
465 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
466 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
467
468 // account
469 s.echo.GET("/account", s.handleAccount)
470 s.echo.POST("/account/revoke", s.handleAccountRevoke)
471 s.echo.GET("/account/signin", s.handleAccountSigninGet)
472 s.echo.POST("/account/signin", s.handleAccountSigninPost)
473 s.echo.GET("/account/signout", s.handleAccountSignout)
474
475 // oauth account
476 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
477 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
478 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
479
480 // oauth authorization
481 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
482 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
483
484 // authed
485 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
486 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
487 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
488 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
489 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
490 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
491 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
492 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
493 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
494 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
495 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
496 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
497 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
498 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
499 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
500 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
501 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
502 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
503
504 // repo
505 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
506 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
507 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
508 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
509 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
510 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
511
512 // stupid silly endpoints
513 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
514 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
515 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
516
517 // admin routes
518 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
519 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
520
521 // are there any routes that we should be allowing without auth? i dont think so but idk
522 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
523 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
524}
525
526func (s *Server) Serve(ctx context.Context) error {
527 s.addRoutes()
528
529 s.logger.Info("migrating...")
530
531 s.db.AutoMigrate(
532 &models.Actor{},
533 &models.Repo{},
534 &models.InviteCode{},
535 &models.Token{},
536 &models.RefreshToken{},
537 &models.Block{},
538 &models.Record{},
539 &models.Blob{},
540 &models.BlobPart{},
541 &provider.OauthToken{},
542 &provider.OauthAuthorizationRequest{},
543 )
544
545 s.logger.Info("starting cocoon")
546
547 go func() {
548 if err := s.httpd.ListenAndServe(); err != nil {
549 panic(err)
550 }
551 }()
552
553 go s.backupRoutine()
554
555 go func() {
556 if err := s.requestCrawl(ctx); err != nil {
557 s.logger.Error("error requesting crawls", "err", err)
558 }
559 }()
560
561 <-ctx.Done()
562
563 fmt.Println("shut down")
564
565 return nil
566}
567
568func (s *Server) requestCrawl(ctx context.Context) error {
569 logger := s.logger.With("component", "request-crawl")
570 s.requestCrawlMu.Lock()
571 defer s.requestCrawlMu.Unlock()
572
573 logger.Info("requesting crawl with configured relays")
574
575 if time.Now().Sub(s.lastRequestCrawl) <= 1*time.Minute {
576 return fmt.Errorf("a crawl request has already been made within the last minute")
577 }
578
579 for _, relay := range s.config.Relays {
580 logger := logger.With("relay", relay)
581 logger.Info("requesting crawl from relay")
582 cli := xrpc.Client{Host: relay}
583 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
584 Hostname: s.config.Hostname,
585 }); err != nil {
586 logger.Error("error requesting crawl", "err", err)
587 } else {
588 logger.Info("crawl requested successfully")
589 }
590 }
591
592 s.lastRequestCrawl = time.Now()
593
594 return nil
595}
596
597func (s *Server) doBackup() {
598 if s.dbType == "postgres" {
599 s.logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)")
600 return
601 }
602
603 start := time.Now()
604
605 s.logger.Info("beginning backup to s3...")
606
607 var buf bytes.Buffer
608 if err := func() error {
609 s.logger.Info("reading database bytes...")
610 s.db.Lock()
611 defer s.db.Unlock()
612
613 sf, err := os.Open(s.dbName)
614 if err != nil {
615 return fmt.Errorf("error opening database for backup: %w", err)
616 }
617 defer sf.Close()
618
619 if _, err := io.Copy(&buf, sf); err != nil {
620 return fmt.Errorf("error reading bytes of backup db: %w", err)
621 }
622
623 return nil
624 }(); err != nil {
625 s.logger.Error("error backing up database", "error", err)
626 return
627 }
628
629 if err := func() error {
630 s.logger.Info("sending to s3...")
631
632 currTime := time.Now().Format("2006-01-02_15-04-05")
633 key := "cocoon-backup-" + currTime + ".db"
634
635 config := &aws.Config{
636 Region: aws.String(s.s3Config.Region),
637 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
638 }
639
640 if s.s3Config.Endpoint != "" {
641 config.Endpoint = aws.String(s.s3Config.Endpoint)
642 config.S3ForcePathStyle = aws.Bool(true)
643 }
644
645 sess, err := session.NewSession(config)
646 if err != nil {
647 return err
648 }
649
650 svc := s3.New(sess)
651
652 if _, err := svc.PutObject(&s3.PutObjectInput{
653 Bucket: aws.String(s.s3Config.Bucket),
654 Key: aws.String(key),
655 Body: bytes.NewReader(buf.Bytes()),
656 }); err != nil {
657 return fmt.Errorf("error uploading file to s3: %w", err)
658 }
659
660 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds())
661
662 return nil
663 }(); err != nil {
664 s.logger.Error("error uploading database backup", "error", err)
665 return
666 }
667
668 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644)
669}
670
671func (s *Server) backupRoutine() {
672 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
673 return
674 }
675
676 if s.s3Config.Region == "" {
677 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
678 return
679 }
680
681 if s.s3Config.Bucket == "" {
682 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
683 return
684 }
685
686 if s.s3Config.AccessKey == "" {
687 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
688 return
689 }
690
691 if s.s3Config.SecretKey == "" {
692 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
693 return
694 }
695
696 shouldBackupNow := false
697 lastBackupStr, err := os.ReadFile("last-backup.txt")
698 if err != nil {
699 shouldBackupNow = true
700 } else {
701 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr))
702 if err != nil {
703 shouldBackupNow = true
704 } else if time.Now().Sub(lastBackup).Seconds() > 3600 {
705 shouldBackupNow = true
706 }
707 }
708
709 if shouldBackupNow {
710 go s.doBackup()
711 }
712
713 ticker := time.NewTicker(time.Hour)
714 for range ticker.C {
715 go s.doBackup()
716 }
717}
718
719func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
720 if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
721 return err
722 }
723
724 return nil
725}