// An atproto PDS (Personal Data Server) written in Go.
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" 11 "log/slog" 12 "net/http" 13 "net/smtp" 14 "os" 15 "path/filepath" 16 "sync" 17 "text/template" 18 "time" 19 20 "github.com/aws/aws-sdk-go/aws" 21 "github.com/aws/aws-sdk-go/aws/credentials" 22 "github.com/aws/aws-sdk-go/aws/session" 23 "github.com/aws/aws-sdk-go/service/s3" 24 "github.com/bluesky-social/indigo/api/atproto" 25 "github.com/bluesky-social/indigo/atproto/syntax" 26 "github.com/bluesky-social/indigo/events" 27 "github.com/bluesky-social/indigo/util" 28 "github.com/bluesky-social/indigo/xrpc" 29 "github.com/domodwyer/mailyak/v3" 30 "github.com/go-playground/validator" 31 "github.com/gorilla/sessions" 32 "github.com/haileyok/cocoon/identity" 33 "github.com/haileyok/cocoon/internal/db" 34 "github.com/haileyok/cocoon/internal/helpers" 35 "github.com/haileyok/cocoon/models" 36 "github.com/haileyok/cocoon/oauth/client" 37 "github.com/haileyok/cocoon/oauth/constants" 38 "github.com/haileyok/cocoon/oauth/dpop" 39 "github.com/haileyok/cocoon/oauth/provider" 40 "github.com/haileyok/cocoon/plc" 41 "github.com/ipfs/go-cid" 42 echo_session "github.com/labstack/echo-contrib/session" 43 "github.com/labstack/echo/v4" 44 "github.com/labstack/echo/v4/middleware" 45 slogecho "github.com/samber/slog-echo" 46 "gorm.io/driver/postgres" 47 "gorm.io/driver/sqlite" 48 "gorm.io/gorm" 49) 50 51const ( 52 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 53) 54 55type S3Config struct { 56 BackupsEnabled bool 57 BlobstoreEnabled bool 58 Endpoint string 59 Region string 60 Bucket string 61 AccessKey string 62 SecretKey string 63} 64 65type Server struct { 66 http *http.Client 67 httpd *http.Server 68 mail *mailyak.MailYak 69 mailLk *sync.Mutex 70 echo *echo.Echo 71 db *db.DB 72 plcClient *plc.Client 73 logger *slog.Logger 74 config *config 75 privateKey *ecdsa.PrivateKey 76 repoman *RepoMan 77 oauthProvider *provider.Provider 78 evtman *events.EventManager 79 passport 
*identity.Passport 80 fallbackProxy string 81 82 lastRequestCrawl time.Time 83 requestCrawlMu sync.Mutex 84 85 dbName string 86 dbType string 87 s3Config *S3Config 88} 89 90type Args struct { 91 Addr string 92 DbName string 93 DbType string 94 DatabaseURL string 95 Logger *slog.Logger 96 Version string 97 Did string 98 Hostname string 99 RotationKeyPath string 100 JwkPath string 101 ContactEmail string 102 Relays []string 103 AdminPassword string 104 105 SmtpUser string 106 SmtpPass string 107 SmtpHost string 108 SmtpPort string 109 SmtpEmail string 110 SmtpName string 111 112 S3Config *S3Config 113 114 SessionSecret string 115 116 BlockstoreVariant BlockstoreVariant 117 FallbackProxy string 118} 119 120type config struct { 121 Version string 122 Did string 123 Hostname string 124 ContactEmail string 125 EnforcePeering bool 126 Relays []string 127 AdminPassword string 128 SmtpEmail string 129 SmtpName string 130 BlockstoreVariant BlockstoreVariant 131 FallbackProxy string 132} 133 134type CustomValidator struct { 135 validator *validator.Validate 136} 137 138type ValidationError struct { 139 error 140 Field string 141 Tag string 142} 143 144func (cv *CustomValidator) Validate(i any) error { 145 if err := cv.validator.Struct(i); err != nil { 146 var validateErrors validator.ValidationErrors 147 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 148 first := validateErrors[0] 149 return ValidationError{ 150 error: err, 151 Field: first.Field(), 152 Tag: first.Tag(), 153 } 154 } 155 156 return err 157 } 158 159 return nil 160} 161 162//go:embed templates/* 163var templateFS embed.FS 164 165//go:embed static/* 166var staticFS embed.FS 167 168type TemplateRenderer struct { 169 templates *template.Template 170 isDev bool 171 templatePath string 172} 173 174func (s *Server) loadTemplates() { 175 absPath, _ := filepath.Abs("server/templates/*.html") 176 if s.config.Version == "dev" { 177 tmpl := template.Must(template.ParseGlob(absPath)) 178 s.echo.Renderer = 
&TemplateRenderer{ 179 templates: tmpl, 180 isDev: true, 181 templatePath: absPath, 182 } 183 } else { 184 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 185 s.echo.Renderer = &TemplateRenderer{ 186 templates: tmpl, 187 isDev: false, 188 } 189 } 190} 191 192func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 193 if t.isDev { 194 tmpl, err := template.ParseGlob(t.templatePath) 195 if err != nil { 196 return err 197 } 198 t.templates = tmpl 199 } 200 201 if viewContext, isMap := data.(map[string]any); isMap { 202 viewContext["reverse"] = c.Echo().Reverse 203 } 204 205 return t.templates.ExecuteTemplate(w, name, data) 206} 207 208func New(args *Args) (*Server, error) { 209 if args.Addr == "" { 210 return nil, fmt.Errorf("addr must be set") 211 } 212 213 if args.DbName == "" { 214 return nil, fmt.Errorf("db name must be set") 215 } 216 217 if args.Did == "" { 218 return nil, fmt.Errorf("cocoon did must be set") 219 } 220 221 if args.ContactEmail == "" { 222 return nil, fmt.Errorf("cocoon contact email is required") 223 } 224 225 if _, err := syntax.ParseDID(args.Did); err != nil { 226 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 227 } 228 229 if args.Hostname == "" { 230 return nil, fmt.Errorf("cocoon hostname must be set") 231 } 232 233 if args.AdminPassword == "" { 234 return nil, fmt.Errorf("admin password must be set") 235 } 236 237 if args.Logger == nil { 238 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 239 } 240 241 if args.SessionSecret == "" { 242 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. 
") 243 } 244 245 e := echo.New() 246 247 e.Pre(middleware.RemoveTrailingSlash()) 248 e.Pre(slogecho.New(args.Logger)) 249 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 250 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 251 AllowOrigins: []string{"*"}, 252 AllowHeaders: []string{"*"}, 253 AllowMethods: []string{"*"}, 254 AllowCredentials: true, 255 MaxAge: 100_000_000, 256 })) 257 258 vdtor := validator.New() 259 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 260 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 261 return false 262 } 263 return true 264 }) 265 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 266 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 267 return false 268 } 269 return true 270 }) 271 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 272 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 273 return false 274 } 275 return true 276 }) 277 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 278 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 279 return false 280 } 281 return true 282 }) 283 284 e.Validator = &CustomValidator{validator: vdtor} 285 286 httpd := &http.Server{ 287 Addr: args.Addr, 288 Handler: e, 289 // shitty defaults but okay for now, needed for import repo 290 ReadTimeout: 5 * time.Minute, 291 WriteTimeout: 5 * time.Minute, 292 IdleTimeout: 5 * time.Minute, 293 } 294 295 dbType := args.DbType 296 if dbType == "" { 297 dbType = "sqlite" 298 } 299 300 var gdb *gorm.DB 301 var err error 302 switch dbType { 303 case "postgres": 304 if args.DatabaseURL == "" { 305 return nil, fmt.Errorf("database-url must be set when using postgres") 306 } 307 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 308 if err != nil { 309 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 310 } 
311 args.Logger.Info("connected to PostgreSQL database") 312 default: 313 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 314 if err != nil { 315 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 316 } 317 args.Logger.Info("connected to SQLite database", "path", args.DbName) 318 } 319 dbw := db.NewDB(gdb) 320 321 rkbytes, err := os.ReadFile(args.RotationKeyPath) 322 if err != nil { 323 return nil, err 324 } 325 326 h := util.RobustHTTPClient() 327 328 plcClient, err := plc.NewClient(&plc.ClientArgs{ 329 H: h, 330 Service: "https://plc.directory", 331 PdsHostname: args.Hostname, 332 RotationKey: rkbytes, 333 }) 334 if err != nil { 335 return nil, err 336 } 337 338 jwkbytes, err := os.ReadFile(args.JwkPath) 339 if err != nil { 340 return nil, err 341 } 342 343 key, err := helpers.ParseJWKFromBytes(jwkbytes) 344 if err != nil { 345 return nil, err 346 } 347 348 var pkey ecdsa.PrivateKey 349 if err := key.Raw(&pkey); err != nil { 350 return nil, err 351 } 352 353 oauthCli := &http.Client{ 354 Timeout: 10 * time.Second, 355 } 356 357 var nonceSecret []byte 358 maybeSecret, err := os.ReadFile("nonce.secret") 359 if err != nil && !os.IsNotExist(err) { 360 args.Logger.Error("error attempting to read nonce secret", "error", err) 361 } else { 362 nonceSecret = maybeSecret 363 } 364 365 s := &Server{ 366 http: h, 367 httpd: httpd, 368 echo: e, 369 logger: args.Logger, 370 db: dbw, 371 plcClient: plcClient, 372 privateKey: &pkey, 373 config: &config{ 374 Version: args.Version, 375 Did: args.Did, 376 Hostname: args.Hostname, 377 ContactEmail: args.ContactEmail, 378 EnforcePeering: false, 379 Relays: args.Relays, 380 AdminPassword: args.AdminPassword, 381 SmtpName: args.SmtpName, 382 SmtpEmail: args.SmtpEmail, 383 BlockstoreVariant: args.BlockstoreVariant, 384 FallbackProxy: args.FallbackProxy, 385 }, 386 evtman: events.NewEventManager(events.NewMemPersister()), 387 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 388 389 dbName: 
args.DbName, 390 dbType: dbType, 391 s3Config: args.S3Config, 392 393 oauthProvider: provider.NewProvider(provider.Args{ 394 Hostname: args.Hostname, 395 ClientManagerArgs: client.ManagerArgs{ 396 Cli: oauthCli, 397 Logger: args.Logger, 398 }, 399 DpopManagerArgs: dpop.ManagerArgs{ 400 NonceSecret: nonceSecret, 401 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 402 OnNonceSecretCreated: func(newNonce []byte) { 403 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 404 args.Logger.Error("error writing new nonce secret", "error", err) 405 } 406 }, 407 Logger: args.Logger, 408 Hostname: args.Hostname, 409 }, 410 }), 411 } 412 413 s.loadTemplates() 414 415 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 416 417 // TODO: should validate these args 418 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 419 args.Logger.Warn("not enough smtp args were provided. 
mailing will not work for your server.") 420 } else { 421 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 422 mail.From(s.config.SmtpEmail) 423 mail.FromName(s.config.SmtpName) 424 425 s.mail = mail 426 s.mailLk = &sync.Mutex{} 427 } 428 429 return s, nil 430} 431 432func (s *Server) addRoutes() { 433 // static 434 if s.config.Version == "dev" { 435 s.echo.Static("/static", "server/static") 436 } else { 437 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 438 } 439 440 // random stuff 441 s.echo.GET("/", s.handleRoot) 442 s.echo.GET("/xrpc/_health", s.handleHealth) 443 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 444 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 445 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 446 s.echo.GET("/robots.txt", s.handleRobots) 447 448 // public 449 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 450 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 451 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 452 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 453 454 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 455 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 456 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 457 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs) 458 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 459 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 460 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 461 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 462 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 463 
s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 464 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 465 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 466 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 467 468 // account 469 s.echo.GET("/account", s.handleAccount) 470 s.echo.POST("/account/revoke", s.handleAccountRevoke) 471 s.echo.GET("/account/signin", s.handleAccountSigninGet) 472 s.echo.POST("/account/signin", s.handleAccountSigninPost) 473 s.echo.GET("/account/signout", s.handleAccountSignout) 474 475 // oauth account 476 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 477 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 478 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 479 480 // oauth authorization 481 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 482 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 483 484 // authed 485 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 486 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 487 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 488 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 489 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 490 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 491 
s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 492 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 493 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 494 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 495 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 496 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 497 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 498 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 499 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 500 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 501 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 502 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 503 504 // repo 505 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, 
s.handleOauthSessionMiddleware) 506 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 507 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 508 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 509 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 510 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 511 512 // stupid silly endpoints 513 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 514 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 515 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 516 517 // admin routes 518 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 519 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 520 521 // are there any routes that we should be allowing without auth? 
i dont think so but idk 522 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 523 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 524} 525 526func (s *Server) Serve(ctx context.Context) error { 527 s.addRoutes() 528 529 s.logger.Info("migrating...") 530 531 s.db.AutoMigrate( 532 &models.Actor{}, 533 &models.Repo{}, 534 &models.InviteCode{}, 535 &models.Token{}, 536 &models.RefreshToken{}, 537 &models.Block{}, 538 &models.Record{}, 539 &models.Blob{}, 540 &models.BlobPart{}, 541 &provider.OauthToken{}, 542 &provider.OauthAuthorizationRequest{}, 543 ) 544 545 s.logger.Info("starting cocoon") 546 547 go func() { 548 if err := s.httpd.ListenAndServe(); err != nil { 549 panic(err) 550 } 551 }() 552 553 go s.backupRoutine() 554 555 go func() { 556 if err := s.requestCrawl(ctx); err != nil { 557 s.logger.Error("error requesting crawls", "err", err) 558 } 559 }() 560 561 <-ctx.Done() 562 563 fmt.Println("shut down") 564 565 return nil 566} 567 568func (s *Server) requestCrawl(ctx context.Context) error { 569 logger := s.logger.With("component", "request-crawl") 570 s.requestCrawlMu.Lock() 571 defer s.requestCrawlMu.Unlock() 572 573 logger.Info("requesting crawl with configured relays") 574 575 if time.Now().Sub(s.lastRequestCrawl) <= 1*time.Minute { 576 return fmt.Errorf("a crawl request has already been made within the last minute") 577 } 578 579 for _, relay := range s.config.Relays { 580 logger := logger.With("relay", relay) 581 logger.Info("requesting crawl from relay") 582 cli := xrpc.Client{Host: relay} 583 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 584 Hostname: s.config.Hostname, 585 }); err != nil { 586 logger.Error("error requesting crawl", "err", err) 587 } else { 588 logger.Info("crawl requested successfully") 589 } 590 } 591 592 s.lastRequestCrawl = time.Now() 593 594 return nil 595} 596 597func (s *Server) doBackup() { 
598 if s.dbType == "postgres" { 599 s.logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)") 600 return 601 } 602 603 start := time.Now() 604 605 s.logger.Info("beginning backup to s3...") 606 607 var buf bytes.Buffer 608 if err := func() error { 609 s.logger.Info("reading database bytes...") 610 s.db.Lock() 611 defer s.db.Unlock() 612 613 sf, err := os.Open(s.dbName) 614 if err != nil { 615 return fmt.Errorf("error opening database for backup: %w", err) 616 } 617 defer sf.Close() 618 619 if _, err := io.Copy(&buf, sf); err != nil { 620 return fmt.Errorf("error reading bytes of backup db: %w", err) 621 } 622 623 return nil 624 }(); err != nil { 625 s.logger.Error("error backing up database", "error", err) 626 return 627 } 628 629 if err := func() error { 630 s.logger.Info("sending to s3...") 631 632 currTime := time.Now().Format("2006-01-02_15-04-05") 633 key := "cocoon-backup-" + currTime + ".db" 634 635 config := &aws.Config{ 636 Region: aws.String(s.s3Config.Region), 637 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 638 } 639 640 if s.s3Config.Endpoint != "" { 641 config.Endpoint = aws.String(s.s3Config.Endpoint) 642 config.S3ForcePathStyle = aws.Bool(true) 643 } 644 645 sess, err := session.NewSession(config) 646 if err != nil { 647 return err 648 } 649 650 svc := s3.New(sess) 651 652 if _, err := svc.PutObject(&s3.PutObjectInput{ 653 Bucket: aws.String(s.s3Config.Bucket), 654 Key: aws.String(key), 655 Body: bytes.NewReader(buf.Bytes()), 656 }); err != nil { 657 return fmt.Errorf("error uploading file to s3: %w", err) 658 } 659 660 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 661 662 return nil 663 }(); err != nil { 664 s.logger.Error("error uploading database backup", "error", err) 665 return 666 } 667 668 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644) 669} 670 
671func (s *Server) backupRoutine() { 672 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 673 return 674 } 675 676 if s.s3Config.Region == "" { 677 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 678 return 679 } 680 681 if s.s3Config.Bucket == "" { 682 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 683 return 684 } 685 686 if s.s3Config.AccessKey == "" { 687 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 688 return 689 } 690 691 if s.s3Config.SecretKey == "" { 692 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 693 return 694 } 695 696 shouldBackupNow := false 697 lastBackupStr, err := os.ReadFile("last-backup.txt") 698 if err != nil { 699 shouldBackupNow = true 700 } else { 701 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr)) 702 if err != nil { 703 shouldBackupNow = true 704 } else if time.Now().Sub(lastBackup).Seconds() > 3600 { 705 shouldBackupNow = true 706 } 707 } 708 709 if shouldBackupNow { 710 go s.doBackup() 711 } 712 713 ticker := time.NewTicker(time.Hour) 714 for range ticker.C { 715 go s.doBackup() 716 } 717} 718 719func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 720 if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 721 return err 722 } 723 724 return nil 725}