1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 "github.com/ipfs/go-cid"
42 echo_session "github.com/labstack/echo-contrib/session"
43 "github.com/labstack/echo/v4"
44 "github.com/labstack/echo/v4/middleware"
45 slogecho "github.com/samber/slog-echo"
46 "gorm.io/driver/sqlite"
47 "gorm.io/gorm"
48)
49
const (
	// AccountSessionMaxAge is 30 days.
	// NOTE(review): the previous comment said "one week", which does not match
	// the value; confirm which of the two was intended.
	AccountSessionMaxAge = 30 * 24 * time.Hour // 30 days
)
53
// S3Config carries the S3 settings handed to the server through Args.S3Config.
type S3Config struct {
	// BackupsEnabled must be true for backupRoutine to run at all.
	BackupsEnabled bool
	// BlobstoreEnabled is not read in this file.
	// NOTE(review): presumably toggles S3-backed blob storage; confirm against its users.
	BlobstoreEnabled bool
	// Endpoint, when non-empty, overrides the SDK's default endpoint and
	// switches the client to path-style addressing.
	Endpoint string
	// Region is required for backups; backupRoutine refuses to run without it.
	Region string
	// Bucket is the destination bucket for backups; required for backups.
	Bucket string
	// AccessKey is the static access key used for uploads; required for backups.
	AccessKey string
	// SecretKey is the static secret key used for uploads; required for backups.
	SecretKey string
}
63
// Server bundles everything the PDS needs at runtime. It is built by New and
// started with Serve.
type Server struct {
	// http is the shared outbound HTTP client; New also hands it to the PLC
	// client and the identity passport.
	http *http.Client
	// httpd is the HTTP server bound to Args.Addr with echo as its handler.
	httpd *http.Server
	// mail is nil unless every SMTP argument was supplied to New.
	mail *mailyak.MailYak
	// mailLk is created together with mail.
	// NOTE(review): presumably serializes use of mail; confirm against its users.
	mailLk *sync.Mutex
	// echo is the router all routes are registered on.
	echo *echo.Echo
	// db wraps the gorm/sqlite handle opened from Args.DbName.
	db *db.DB
	// plcClient talks to https://plc.directory.
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey is the ECDSA key extracted from the JWK at Args.JwkPath.
	privateKey *ecdsa.PrivateKey
	// repoman is created by NewRepoMan(s) after the Server value exists.
	repoman       *RepoMan
	oauthProvider *provider.Provider
	// evtman is backed by an in-memory persister.
	evtman *events.EventManager
	// passport is backed by an in-memory cache constructed with size 10_000.
	passport *identity.Passport
	// NOTE(review): fallbackProxy is never assigned in New; Args.FallbackProxy
	// is copied to config.FallbackProxy instead. Confirm whether this field is used.
	fallbackProxy string

	// lastRequestCrawl is when requestCrawl last finished contacting relays;
	// requestCrawlMu guards it and serializes requestCrawl calls.
	lastRequestCrawl time.Time
	requestCrawlMu   sync.Mutex

	// dbName is the sqlite path; doBackup reads this file directly.
	dbName string
	// s3Config may be nil, in which case backups are skipped.
	s3Config *S3Config
}
87
// Args is the input to New. Addr, DbName, Did, Hostname, ContactEmail,
// AdminPassword and SessionSecret are mandatory.
type Args struct {
	// Addr is the listen address for the HTTP server. Required.
	Addr string
	// DbName is the sqlite database path. Required.
	DbName string
	// Logger defaults to a text handler on stdout when nil.
	Logger *slog.Logger
	// Version of the server; the value "dev" makes templates and static files
	// load from disk instead of the embedded copies.
	Version string
	// Did must be a syntactically valid DID. Required.
	Did string
	// Hostname is the public hostname of this server. Required.
	Hostname string
	// RotationKeyPath is read in full and handed to the PLC client.
	RotationKeyPath string
	// JwkPath points at a JWK file holding an ECDSA private key.
	JwkPath string
	// ContactEmail is required.
	ContactEmail string
	// Relays are the hosts that requestCrawl asks to crawl this server.
	Relays []string
	// AdminPassword is required.
	AdminPassword string

	// All six Smtp* values must be non-empty or mailing is disabled.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config may be nil; backups are skipped in that case.
	S3Config *S3Config

	// SessionSecret keys the cookie session store. Required.
	SessionSecret string

	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string
}
115
// config is the subset of Args that the server keeps after construction.
// New copies each field from the Args field of the same name, except
// EnforcePeering.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering    bool
	Relays            []string
	AdminPassword     string
	SmtpEmail         string
	SmtpName          string
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string
}
129
// CustomValidator adapts a go-playground validator so it can be installed as
// the echo instance's Validator (see New).
type CustomValidator struct {
	validator *validator.Validate
}
133
// ValidationError wraps a validation failure and exposes the field name and
// tag of the first failing rule.
type ValidationError struct {
	// error is the original error returned by the validator.
	error
	// Field is the name of the first field that failed validation.
	Field string
	// Tag is the validation tag that the field failed.
	Tag string
}
139
140func (cv *CustomValidator) Validate(i any) error {
141 if err := cv.validator.Struct(i); err != nil {
142 var validateErrors validator.ValidationErrors
143 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
144 first := validateErrors[0]
145 return ValidationError{
146 error: err,
147 Field: first.Field(),
148 Tag: first.Tag(),
149 }
150 }
151
152 return err
153 }
154
155 return nil
156}
157
// templateFS holds the contents of templates/ embedded at build time.
// loadTemplates parses "templates/*.html" from it when not in dev mode.
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the contents of static/ embedded at build time. addRoutes
// serves it under /static/* when not in dev mode.
//
//go:embed static/*
var staticFS embed.FS
163
// TemplateRenderer is the echo renderer installed by loadTemplates.
//
// NOTE(review): templates are parsed with text/template, which performs no
// HTML escaping, while the files are *.html. Confirm this is intended.
type TemplateRenderer struct {
	// templates is the parsed template set used to execute views.
	templates *template.Template
	// isDev makes Render re-parse templatePath on every call.
	isDev bool
	// templatePath is the glob re-parsed in dev mode; empty otherwise.
	templatePath string
}
169
170func (s *Server) loadTemplates() {
171 absPath, _ := filepath.Abs("server/templates/*.html")
172 if s.config.Version == "dev" {
173 tmpl := template.Must(template.ParseGlob(absPath))
174 s.echo.Renderer = &TemplateRenderer{
175 templates: tmpl,
176 isDev: true,
177 templatePath: absPath,
178 }
179 } else {
180 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
181 s.echo.Renderer = &TemplateRenderer{
182 templates: tmpl,
183 isDev: false,
184 }
185 }
186}
187
188func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
189 if t.isDev {
190 tmpl, err := template.ParseGlob(t.templatePath)
191 if err != nil {
192 return err
193 }
194 t.templates = tmpl
195 }
196
197 if viewContext, isMap := data.(map[string]any); isMap {
198 viewContext["reverse"] = c.Echo().Reverse
199 }
200
201 return t.templates.ExecuteTemplate(w, name, data)
202}
203
204func New(args *Args) (*Server, error) {
205 if args.Addr == "" {
206 return nil, fmt.Errorf("addr must be set")
207 }
208
209 if args.DbName == "" {
210 return nil, fmt.Errorf("db name must be set")
211 }
212
213 if args.Did == "" {
214 return nil, fmt.Errorf("cocoon did must be set")
215 }
216
217 if args.ContactEmail == "" {
218 return nil, fmt.Errorf("cocoon contact email is required")
219 }
220
221 if _, err := syntax.ParseDID(args.Did); err != nil {
222 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
223 }
224
225 if args.Hostname == "" {
226 return nil, fmt.Errorf("cocoon hostname must be set")
227 }
228
229 if args.AdminPassword == "" {
230 return nil, fmt.Errorf("admin password must be set")
231 }
232
233 if args.Logger == nil {
234 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
235 }
236
237 if args.SessionSecret == "" {
238 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
239 }
240
241 e := echo.New()
242
243 e.Pre(middleware.RemoveTrailingSlash())
244 e.Pre(slogecho.New(args.Logger))
245 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
246 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
247 AllowOrigins: []string{"*"},
248 AllowHeaders: []string{"*"},
249 AllowMethods: []string{"*"},
250 AllowCredentials: true,
251 MaxAge: 100_000_000,
252 }))
253
254 vdtor := validator.New()
255 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
256 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
257 return false
258 }
259 return true
260 })
261 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
262 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
263 return false
264 }
265 return true
266 })
267 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
268 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
269 return false
270 }
271 return true
272 })
273 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
274 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
275 return false
276 }
277 return true
278 })
279
280 e.Validator = &CustomValidator{validator: vdtor}
281
282 httpd := &http.Server{
283 Addr: args.Addr,
284 Handler: e,
285 // shitty defaults but okay for now, needed for import repo
286 ReadTimeout: 5 * time.Minute,
287 WriteTimeout: 5 * time.Minute,
288 IdleTimeout: 5 * time.Minute,
289 }
290
291 gdb, err := gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
292 if err != nil {
293 return nil, err
294 }
295 dbw := db.NewDB(gdb)
296
297 rkbytes, err := os.ReadFile(args.RotationKeyPath)
298 if err != nil {
299 return nil, err
300 }
301
302 h := util.RobustHTTPClient()
303
304 plcClient, err := plc.NewClient(&plc.ClientArgs{
305 H: h,
306 Service: "https://plc.directory",
307 PdsHostname: args.Hostname,
308 RotationKey: rkbytes,
309 })
310 if err != nil {
311 return nil, err
312 }
313
314 jwkbytes, err := os.ReadFile(args.JwkPath)
315 if err != nil {
316 return nil, err
317 }
318
319 key, err := helpers.ParseJWKFromBytes(jwkbytes)
320 if err != nil {
321 return nil, err
322 }
323
324 var pkey ecdsa.PrivateKey
325 if err := key.Raw(&pkey); err != nil {
326 return nil, err
327 }
328
329 oauthCli := &http.Client{
330 Timeout: 10 * time.Second,
331 }
332
333 var nonceSecret []byte
334 maybeSecret, err := os.ReadFile("nonce.secret")
335 if err != nil && !os.IsNotExist(err) {
336 args.Logger.Error("error attempting to read nonce secret", "error", err)
337 } else {
338 nonceSecret = maybeSecret
339 }
340
341 s := &Server{
342 http: h,
343 httpd: httpd,
344 echo: e,
345 logger: args.Logger,
346 db: dbw,
347 plcClient: plcClient,
348 privateKey: &pkey,
349 config: &config{
350 Version: args.Version,
351 Did: args.Did,
352 Hostname: args.Hostname,
353 ContactEmail: args.ContactEmail,
354 EnforcePeering: false,
355 Relays: args.Relays,
356 AdminPassword: args.AdminPassword,
357 SmtpName: args.SmtpName,
358 SmtpEmail: args.SmtpEmail,
359 BlockstoreVariant: args.BlockstoreVariant,
360 FallbackProxy: args.FallbackProxy,
361 },
362 evtman: events.NewEventManager(events.NewMemPersister()),
363 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
364
365 dbName: args.DbName,
366 s3Config: args.S3Config,
367
368 oauthProvider: provider.NewProvider(provider.Args{
369 Hostname: args.Hostname,
370 ClientManagerArgs: client.ManagerArgs{
371 Cli: oauthCli,
372 Logger: args.Logger,
373 },
374 DpopManagerArgs: dpop.ManagerArgs{
375 NonceSecret: nonceSecret,
376 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
377 OnNonceSecretCreated: func(newNonce []byte) {
378 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
379 args.Logger.Error("error writing new nonce secret", "error", err)
380 }
381 },
382 Logger: args.Logger,
383 Hostname: args.Hostname,
384 },
385 }),
386 }
387
388 s.loadTemplates()
389
390 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
391
392 // TODO: should validate these args
393 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
394 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
395 } else {
396 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
397 mail.From(s.config.SmtpEmail)
398 mail.FromName(s.config.SmtpName)
399
400 s.mail = mail
401 s.mailLk = &sync.Mutex{}
402 }
403
404 return s, nil
405}
406
407func (s *Server) addRoutes() {
408 // static
409 if s.config.Version == "dev" {
410 s.echo.Static("/static", "server/static")
411 } else {
412 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
413 }
414
415 // random stuff
416 s.echo.GET("/", s.handleRoot)
417 s.echo.GET("/xrpc/_health", s.handleHealth)
418 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
419 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
420 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
421 s.echo.GET("/robots.txt", s.handleRobots)
422
423 // public
424 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
425 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
426 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
427 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
428
429 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
430 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
431 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
432 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
433 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
434 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
435 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
436 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
437 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
438 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
439 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
440 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
441
442 // account
443 s.echo.GET("/account", s.handleAccount)
444 s.echo.POST("/account/revoke", s.handleAccountRevoke)
445 s.echo.GET("/account/signin", s.handleAccountSigninGet)
446 s.echo.POST("/account/signin", s.handleAccountSigninPost)
447 s.echo.GET("/account/signout", s.handleAccountSignout)
448
449 // oauth account
450 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
451 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
452 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
453
454 // oauth authorization
455 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
456 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
457
458 // authed
459 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
460 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
461 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
462 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
463 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
464 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
465 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
466 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
467 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
468 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
469 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
470 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
471 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
472 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
473 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
474 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
475 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
476 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
477
478 // repo
479 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
480 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
481 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
482 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
483 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
484 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
485
486 // stupid silly endpoints
487 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
488 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
489 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
490
491 // admin routes
492 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
493 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
494
495 // are there any routes that we should be allowing without auth? i dont think so but idk
496 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
497 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
498}
499
500func (s *Server) Serve(ctx context.Context) error {
501 s.addRoutes()
502
503 s.logger.Info("migrating...")
504
505 s.db.AutoMigrate(
506 &models.Actor{},
507 &models.Repo{},
508 &models.InviteCode{},
509 &models.Token{},
510 &models.RefreshToken{},
511 &models.Block{},
512 &models.Record{},
513 &models.Blob{},
514 &models.BlobPart{},
515 &provider.OauthToken{},
516 &provider.OauthAuthorizationRequest{},
517 )
518
519 s.logger.Info("starting cocoon")
520
521 go func() {
522 if err := s.httpd.ListenAndServe(); err != nil {
523 panic(err)
524 }
525 }()
526
527 go s.backupRoutine()
528
529 go func() {
530 if err := s.requestCrawl(ctx); err != nil {
531 s.logger.Error("error requesting crawls", "err", err)
532 }
533 }()
534
535 <-ctx.Done()
536
537 fmt.Println("shut down")
538
539 return nil
540}
541
542func (s *Server) requestCrawl(ctx context.Context) error {
543 logger := s.logger.With("component", "request-crawl")
544 s.requestCrawlMu.Lock()
545 defer s.requestCrawlMu.Unlock()
546
547 logger.Info("requesting crawl with configured relays")
548
549 if time.Now().Sub(s.lastRequestCrawl) <= 1*time.Minute {
550 return fmt.Errorf("a crawl request has already been made within the last minute")
551 }
552
553 for _, relay := range s.config.Relays {
554 logger := logger.With("relay", relay)
555 logger.Info("requesting crawl from relay")
556 cli := xrpc.Client{Host: relay}
557 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
558 Hostname: s.config.Hostname,
559 }); err != nil {
560 logger.Error("error requesting crawl", "err", err)
561 } else {
562 logger.Info("crawl requested successfully")
563 }
564 }
565
566 s.lastRequestCrawl = time.Now()
567
568 return nil
569}
570
571func (s *Server) doBackup() {
572 start := time.Now()
573
574 s.logger.Info("beginning backup to s3...")
575
576 var buf bytes.Buffer
577 if err := func() error {
578 s.logger.Info("reading database bytes...")
579 s.db.Lock()
580 defer s.db.Unlock()
581
582 sf, err := os.Open(s.dbName)
583 if err != nil {
584 return fmt.Errorf("error opening database for backup: %w", err)
585 }
586 defer sf.Close()
587
588 if _, err := io.Copy(&buf, sf); err != nil {
589 return fmt.Errorf("error reading bytes of backup db: %w", err)
590 }
591
592 return nil
593 }(); err != nil {
594 s.logger.Error("error backing up database", "error", err)
595 return
596 }
597
598 if err := func() error {
599 s.logger.Info("sending to s3...")
600
601 currTime := time.Now().Format("2006-01-02_15-04-05")
602 key := "cocoon-backup-" + currTime + ".db"
603
604 config := &aws.Config{
605 Region: aws.String(s.s3Config.Region),
606 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
607 }
608
609 if s.s3Config.Endpoint != "" {
610 config.Endpoint = aws.String(s.s3Config.Endpoint)
611 config.S3ForcePathStyle = aws.Bool(true)
612 }
613
614 sess, err := session.NewSession(config)
615 if err != nil {
616 return err
617 }
618
619 svc := s3.New(sess)
620
621 if _, err := svc.PutObject(&s3.PutObjectInput{
622 Bucket: aws.String(s.s3Config.Bucket),
623 Key: aws.String(key),
624 Body: bytes.NewReader(buf.Bytes()),
625 }); err != nil {
626 return fmt.Errorf("error uploading file to s3: %w", err)
627 }
628
629 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds())
630
631 return nil
632 }(); err != nil {
633 s.logger.Error("error uploading database backup", "error", err)
634 return
635 }
636
637 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644)
638}
639
640func (s *Server) backupRoutine() {
641 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
642 return
643 }
644
645 if s.s3Config.Region == "" {
646 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
647 return
648 }
649
650 if s.s3Config.Bucket == "" {
651 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
652 return
653 }
654
655 if s.s3Config.AccessKey == "" {
656 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
657 return
658 }
659
660 if s.s3Config.SecretKey == "" {
661 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
662 return
663 }
664
665 shouldBackupNow := false
666 lastBackupStr, err := os.ReadFile("last-backup.txt")
667 if err != nil {
668 shouldBackupNow = true
669 } else {
670 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr))
671 if err != nil {
672 shouldBackupNow = true
673 } else if time.Now().Sub(lastBackup).Seconds() > 3600 {
674 shouldBackupNow = true
675 }
676 }
677
678 if shouldBackupNow {
679 go s.doBackup()
680 }
681
682 ticker := time.NewTicker(time.Hour)
683 for range ticker.C {
684 go s.doBackup()
685 }
686}
687
688func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
689 if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
690 return err
691 }
692
693 return nil
694}