An atproto PDS written in Go
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/sha256" 8 "embed" 9 "encoding/base64" 10 "errors" 11 "fmt" 12 "io" 13 "log/slog" 14 "net/http" 15 "net/smtp" 16 "os" 17 "path/filepath" 18 "strings" 19 "sync" 20 "text/template" 21 "time" 22 23 "github.com/Azure/go-autorest/autorest/to" 24 "github.com/aws/aws-sdk-go/aws" 25 "github.com/aws/aws-sdk-go/aws/credentials" 26 "github.com/aws/aws-sdk-go/aws/session" 27 "github.com/aws/aws-sdk-go/service/s3" 28 "github.com/bluesky-social/indigo/api/atproto" 29 "github.com/bluesky-social/indigo/atproto/syntax" 30 "github.com/bluesky-social/indigo/events" 31 "github.com/bluesky-social/indigo/util" 32 "github.com/bluesky-social/indigo/xrpc" 33 "github.com/domodwyer/mailyak/v3" 34 "github.com/go-playground/validator" 35 "github.com/golang-jwt/jwt/v4" 36 "github.com/gorilla/sessions" 37 "github.com/haileyok/cocoon/identity" 38 "github.com/haileyok/cocoon/internal/db" 39 "github.com/haileyok/cocoon/internal/helpers" 40 "github.com/haileyok/cocoon/models" 41 "github.com/haileyok/cocoon/oauth/client" 42 "github.com/haileyok/cocoon/oauth/constants" 43 "github.com/haileyok/cocoon/oauth/dpop" 44 "github.com/haileyok/cocoon/oauth/provider" 45 "github.com/haileyok/cocoon/plc" 46 echo_session "github.com/labstack/echo-contrib/session" 47 "github.com/labstack/echo/v4" 48 "github.com/labstack/echo/v4/middleware" 49 slogecho "github.com/samber/slog-echo" 50 "gitlab.com/yawning/secp256k1-voi" 51 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 52 "gorm.io/driver/sqlite" 53 "gorm.io/gorm" 54) 55 56const ( 57 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 58) 59 60type S3Config struct { 61 BackupsEnabled bool 62 Endpoint string 63 Region string 64 Bucket string 65 AccessKey string 66 SecretKey string 67} 68 69type Server struct { 70 http *http.Client 71 httpd *http.Server 72 mail *mailyak.MailYak 73 mailLk *sync.Mutex 74 echo *echo.Echo 75 db *db.DB 76 plcClient *plc.Client 77 logger *slog.Logger 78 config *config 79 privateKey *ecdsa.PrivateKey 80 repoman *RepoMan 81 oauthProvider *provider.Provider 82 evtman *events.EventManager 83 passport *identity.Passport 84 85 dbName string 86 s3Config *S3Config 87} 88 89type Args struct { 90 Addr string 91 DbName string 92 Logger *slog.Logger 93 Version string 94 Did string 95 Hostname string 96 RotationKeyPath string 97 JwkPath string 98 ContactEmail string 99 Relays []string 100 AdminPassword string 101 102 SmtpUser string 103 SmtpPass string 104 SmtpHost string 105 SmtpPort string 106 SmtpEmail string 107 SmtpName string 108 109 S3Config *S3Config 110 111 SessionSecret string 112} 113 114type config struct { 115 Version string 116 Did string 117 Hostname string 118 ContactEmail string 119 EnforcePeering bool 120 Relays []string 121 AdminPassword string 122 SmtpEmail string 123 SmtpName string 124} 125 126type CustomValidator struct { 127 validator *validator.Validate 128} 129 130type ValidationError struct { 131 error 132 Field string 133 Tag string 134} 135 136func (cv *CustomValidator) Validate(i any) error { 137 if err := cv.validator.Struct(i); err != nil { 138 var validateErrors validator.ValidationErrors 139 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 140 first := validateErrors[0] 141 return ValidationError{ 142 error: err, 143 Field: first.Field(), 144 Tag: first.Tag(), 145 } 146 } 147 148 return err 149 } 150 151 return nil 152} 153 154//go:embed templates/* 155var templateFS embed.FS 156 157//go:embed static/* 158var staticFS embed.FS 159 160type TemplateRenderer struct { 161 templates *template.Template 162 isDev bool 163 templatePath string 164} 165 166func (s *Server) loadTemplates() { 167 absPath, _ := filepath.Abs("server/templates/*.html") 168 if s.config.Version == "dev" { 169 tmpl := template.Must(template.ParseGlob(absPath)) 170 s.echo.Renderer = &TemplateRenderer{ 171 templates: tmpl, 172 isDev: true, 173 templatePath: absPath, 174 } 175 } else { 176 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 177 s.echo.Renderer = &TemplateRenderer{ 178 templates: tmpl, 179 isDev: false, 180 } 181 } 182} 183 184func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 185 if t.isDev { 186 tmpl, err := template.ParseGlob(t.templatePath) 187 if err != nil { 188 return err 189 } 190 t.templates = tmpl 191 } 192 193 if viewContext, isMap := data.(map[string]any); isMap { 194 viewContext["reverse"] = c.Echo().Reverse 195 } 196 197 return t.templates.ExecuteTemplate(w, name, data) 198} 199 200func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 201 return func(e echo.Context) error { 202 username, password, ok := e.Request().BasicAuth() 203 if !ok || username != "admin" || password != s.config.AdminPassword { 204 return helpers.InputError(e, to.StringPtr("Unauthorized")) 205 } 206 207 if err := next(e); err != nil { 208 e.Error(err) 209 } 210 211 return nil 212 } 213} 214 215func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 216 return func(e echo.Context) error { 217 authheader := e.Request().Header.Get("authorization") 218 if authheader == "" { 219 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 220 } 221 222 pts := strings.Split(authheader, " ") 223 if len(pts) != 2 { 224 return helpers.ServerError(e, nil) 225 } 226 227 // move on to oauth session middleware if this is a dpop token 228 if pts[0] == "DPoP" { 229 return next(e) 230 } 231 232 tokenstr := pts[1] 233 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 234 claims, ok := token.Claims.(jwt.MapClaims) 235 if !ok { 236 return helpers.InputError(e, to.StringPtr("InvalidToken")) 237 } 238 239 var did string 240 var repo *models.RepoActor 241 242 // service auth tokens 243 lxm, hasLxm := claims["lxm"] 244 if hasLxm { 245 pts := strings.Split(e.Request().URL.String(), "/") 246 if lxm != pts[len(pts)-1] { 247 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 248 return helpers.InputError(e, nil) 249 } 250 251 maybeDid, ok := claims["iss"].(string) 252 if !ok { 253 s.logger.Error("no iss in service auth token", "error", err) 254 return helpers.InputError(e, nil) 255 } 256 did = maybeDid 257 258 maybeRepo, err := s.getRepoActorByDid(did) 259 if err != nil { 260 s.logger.Error("error fetching repo", "error", err) 261 return helpers.ServerError(e, nil) 262 } 263 repo = maybeRepo 264 } 265 266 if token.Header["alg"] != "ES256K" { 267 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 268 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 269 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 270 } 271 return s.privateKey.Public(), nil 272 }) 273 if err != nil { 274 s.logger.Error("error parsing jwt", "error", err) 275 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 276 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 277 } 278 279 if !token.Valid { 280 return helpers.InputError(e, to.StringPtr("InvalidToken")) 281 } 282 } else { 283 kpts := strings.Split(tokenstr, ".") 284 signingInput := kpts[0] + "." + kpts[1] 285 hash := sha256.Sum256([]byte(signingInput)) 286 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 287 if err != nil { 288 s.logger.Error("error decoding signature bytes", "error", err) 289 return helpers.ServerError(e, nil) 290 } 291 292 if len(sigBytes) != 64 { 293 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 294 return helpers.ServerError(e, nil) 295 } 296 297 rBytes := sigBytes[:32] 298 sBytes := sigBytes[32:] 299 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 300 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 301 302 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 303 if err != nil { 304 s.logger.Error("can't load private key", "error", err) 305 return err 306 } 307 308 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 309 if !ok { 310 s.logger.Error("error getting public key from sk") 311 return helpers.ServerError(e, nil) 312 } 313 314 verified := pubKey.VerifyRaw(hash[:], rr, ss) 315 if !verified { 316 s.logger.Error("error verifying", "error", err) 317 return helpers.ServerError(e, nil) 318 } 319 } 320 321 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 322 scope, _ := claims["scope"].(string) 323 324 if isRefresh && scope != "com.atproto.refresh" { 325 return helpers.InputError(e, to.StringPtr("InvalidToken")) 326 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 327 return helpers.InputError(e, to.StringPtr("InvalidToken")) 328 } 329 330 table := "tokens" 331 if isRefresh { 332 table = "refresh_tokens" 333 } 334 335 if isRefresh { 336 type Result struct { 337 Found bool 338 } 339 var result Result 340 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 341 if err == gorm.ErrRecordNotFound { 342 return helpers.InputError(e, to.StringPtr("InvalidToken")) 343 } 344 345 s.logger.Error("error getting token from db", "error", err) 346 return helpers.ServerError(e, nil) 347 } 348 349 if !result.Found { 350 return helpers.InputError(e, to.StringPtr("InvalidToken")) 351 } 352 } 353 354 exp, ok := claims["exp"].(float64) 355 if !ok { 356 s.logger.Error("error getting iat from token") 357 return helpers.ServerError(e, nil) 358 } 359 360 if exp < float64(time.Now().UTC().Unix()) { 361 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 362 } 363 364 if repo == nil { 365 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 366 if err != nil { 367 s.logger.Error("error fetching repo", "error", err) 368 return helpers.ServerError(e, nil) 369 } 370 repo = maybeRepo 371 did = repo.Repo.Did 372 } 373 374 e.Set("repo", repo) 375 e.Set("did", did) 376 e.Set("token", tokenstr) 377 378 if err := next(e); err != nil { 379 e.Error(err) 380 } 381 382 return nil 383 } 384} 385 386func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 387 return func(e echo.Context) error { 388 authheader := e.Request().Header.Get("authorization") 389 if authheader == "" { 390 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 391 } 392 393 pts := strings.Split(authheader, " ") 394 if len(pts) != 2 { 395 return helpers.ServerError(e, nil) 396 } 397 398 if pts[0] != "DPoP" { 399 return next(e) 400 } 401 402 accessToken := pts[1] 403 404 nonce := s.oauthProvider.NextNonce() 405 if nonce != "" { 406 e.Response().Header().Set("DPoP-Nonce", nonce) 407 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 408 } 409 410 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 411 if err != nil { 412 s.logger.Error("invalid dpop proof", "error", err) 413 return helpers.InputError(e, to.StringPtr(err.Error())) 414 } 415 416 var oauthToken provider.OauthToken 417 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 418 s.logger.Error("error finding access token in db", "error", err) 419 return helpers.InputError(e, nil) 420 } 421 422 if oauthToken.Token == "" { 423 return helpers.InputError(e, to.StringPtr("InvalidToken")) 424 } 425 426 if *oauthToken.Parameters.DpopJkt != proof.JKT { 427 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 428 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 429 } 430 431 if time.Now().After(oauthToken.ExpiresAt) { 432 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 433 } 434 435 repo, err := s.getRepoActorByDid(oauthToken.Sub) 436 if err != nil { 437 s.logger.Error("could not find actor in db", "error", err) 438 return helpers.ServerError(e, nil) 439 } 440 441 e.Set("repo", repo) 442 e.Set("did", repo.Repo.Did) 443 e.Set("token", accessToken) 444 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 445 446 return next(e) 447 } 448} 449 450func New(args *Args) (*Server, error) { 451 if args.Addr == "" { 452 return nil, fmt.Errorf("addr must be set") 453 } 454 455 if args.DbName == "" { 456 return nil, fmt.Errorf("db name must be set") 457 } 458 459 if args.Did == "" { 460 return nil, fmt.Errorf("cocoon did must be set") 461 } 462 463 if args.ContactEmail == "" { 464 return nil, fmt.Errorf("cocoon contact email is required") 465 } 466 467 if _, err := syntax.ParseDID(args.Did); err != nil { 468 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 469 } 470 471 if args.Hostname == "" { 472 return nil, fmt.Errorf("cocoon hostname must be set") 473 } 474 475 if args.AdminPassword == "" { 476 return nil, fmt.Errorf("admin password must be set") 477 } 478 479 if args.Logger == nil { 480 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 481 } 482 483 if args.SessionSecret == "" { 484 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 485 } 486 487 e := echo.New() 488 489 e.Pre(middleware.RemoveTrailingSlash()) 490 e.Pre(slogecho.New(args.Logger)) 491 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 492 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 493 AllowOrigins: []string{"*"}, 494 AllowHeaders: []string{"*"}, 495 AllowMethods: []string{"*"}, 496 AllowCredentials: true, 497 MaxAge: 100_000_000, 498 })) 499 500 vdtor := validator.New() 501 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 502 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 503 return false 504 } 505 return true 506 }) 507 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 508 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 509 return false 510 } 511 return true 512 }) 513 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 514 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 515 return false 516 } 517 return true 518 }) 519 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 520 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 521 return false 522 } 523 return true 524 }) 525 526 e.Validator = &CustomValidator{validator: vdtor} 527 528 httpd := &http.Server{ 529 Addr: args.Addr, 530 Handler: e, 531 // shitty defaults but okay for now, needed for import repo 532 ReadTimeout: 5 * time.Minute, 533 WriteTimeout: 5 * time.Minute, 534 IdleTimeout: 5 * time.Minute, 535 } 536 537 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 538 if err != nil { 539 return nil, err 540 } 541 dbw := db.NewDB(gdb) 542 543 rkbytes, err := os.ReadFile(args.RotationKeyPath) 544 if err != nil { 545 return nil, err 546 } 547 548 h := util.RobustHTTPClient() 549 550 plcClient, err := plc.NewClient(&plc.ClientArgs{ 551 H: h, 552 Service: "https://plc.directory", 553 PdsHostname: args.Hostname, 554 RotationKey: rkbytes, 555 }) 556 if err != nil { 557 return nil, err 558 } 559 560 jwkbytes, err := os.ReadFile(args.JwkPath) 561 if err != nil { 562 return nil, err 563 } 564 565 key, err := helpers.ParseJWKFromBytes(jwkbytes) 566 if err != nil { 567 return nil, err 568 } 569 570 var pkey ecdsa.PrivateKey 571 if err := key.Raw(&pkey); err != nil { 572 return nil, err 573 } 574 575 oauthCli := &http.Client{ 576 Timeout: 10 * time.Second, 577 } 578 579 var nonceSecret []byte 580 maybeSecret, err := os.ReadFile("nonce.secret") 581 if err != nil && !os.IsNotExist(err) { 582 args.Logger.Error("error attempting to read nonce secret", "error", err) 583 } else { 584 nonceSecret = maybeSecret 585 } 586 587 s := &Server{ 588 http: h, 589 httpd: httpd, 590 echo: e, 591 logger: args.Logger, 592 db: dbw, 593 plcClient: plcClient, 594 privateKey: &pkey, 595 config: &config{ 596 Version: args.Version, 597 Did: args.Did, 598 Hostname: args.Hostname, 599 ContactEmail: args.ContactEmail, 600 EnforcePeering: false, 601 Relays: args.Relays, 602 AdminPassword: args.AdminPassword, 603 SmtpName: args.SmtpName, 604 SmtpEmail: args.SmtpEmail, 605 }, 606 evtman: events.NewEventManager(events.NewMemPersister()), 607 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 608 609 dbName: args.DbName, 610 s3Config: args.S3Config, 611 612 oauthProvider: provider.NewProvider(provider.Args{ 613 Hostname: args.Hostname, 614 ClientManagerArgs: client.ManagerArgs{ 615 Cli: oauthCli, 616 Logger: args.Logger, 617 }, 618 DpopManagerArgs: dpop.ManagerArgs{ 619 NonceSecret: nonceSecret, 620 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 621 OnNonceSecretCreated: func(newNonce []byte) { 622 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 623 args.Logger.Error("error writing new nonce secret", "error", err) 624 } 625 }, 626 Logger: args.Logger, 627 Hostname: args.Hostname, 628 }, 629 }), 630 } 631 632 s.loadTemplates() 633 634 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 635 636 // TODO: should validate these args 637 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 638 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.") 639 } else { 640 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 641 mail.From(s.config.SmtpEmail) 642 mail.FromName(s.config.SmtpName) 643 644 s.mail = mail 645 s.mailLk = &sync.Mutex{} 646 } 647 648 return s, nil 649} 650 651func (s *Server) addRoutes() { 652 // static 653 if s.config.Version == "dev" { 654 s.echo.Static("/static", "server/static") 655 } else { 656 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 657 } 658 659 // random stuff 660 s.echo.GET("/", s.handleRoot) 661 s.echo.GET("/xrpc/_health", s.handleHealth) 662 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 663 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 664 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 665 s.echo.GET("/robots.txt", s.handleRobots) 666 667 // public 668 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 669 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 670 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 671 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 672 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 673 674 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 675 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 676 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 677 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 678 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 679 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 680 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 681 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 682 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 683 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 684 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 685 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 686 687 // account 688 s.echo.GET("/account", s.handleAccount) 689 s.echo.POST("/account/revoke", s.handleAccountRevoke) 690 s.echo.GET("/account/signin", s.handleAccountSigninGet) 691 s.echo.POST("/account/signin", s.handleAccountSigninPost) 692 s.echo.GET("/account/signout", s.handleAccountSignout) 693 694 // oauth account 695 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 696 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 697 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 698 699 // oauth authorization 700 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 701 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 702 703 // authed 704 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 705 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 706 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 707 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 708 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 709 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 710 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 711 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 712 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 713 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 714 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 715 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 716 717 // repo 718 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 719 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 720 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 721 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 722 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 723 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 724 725 // stupid silly endpoints 726 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 727 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 728 729 // are there any routes that we should be allowing without auth? i dont think so but idk 730 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 731 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 732 733 // admin routes 734 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 735 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 736} 737 738func (s *Server) Serve(ctx context.Context) error { 739 s.addRoutes() 740 741 s.logger.Info("migrating...") 742 743 s.db.AutoMigrate( 744 &models.Actor{}, 745 &models.Repo{}, 746 &models.InviteCode{}, 747 &models.Token{}, 748 &models.RefreshToken{}, 749 &models.Block{}, 750 &models.Record{}, 751 &models.Blob{}, 752 &models.BlobPart{}, 753 &provider.OauthToken{}, 754 &provider.OauthAuthorizationRequest{}, 755 ) 756 757 s.logger.Info("starting cocoon") 758 759 go func() { 760 if err := s.httpd.ListenAndServe(); err != nil { 761 panic(err) 762 } 763 }() 764 765 go s.backupRoutine() 766 767 for _, relay := range s.config.Relays { 768 cli := xrpc.Client{Host: relay} 769 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 770 Hostname: s.config.Hostname, 771 }) 772 } 773 774 <-ctx.Done() 775 776 fmt.Println("shut down") 777 778 return nil 779} 780 781func (s *Server) doBackup() { 782 start := time.Now() 783 784 s.logger.Info("beginning backup to s3...") 785 786 var buf bytes.Buffer 787 if err := func() error { 788 s.logger.Info("reading database bytes...") 789 s.db.Lock() 790 defer s.db.Unlock() 791 792 sf, err := os.Open(s.dbName) 793 if err != nil { 794 return fmt.Errorf("error opening database for backup: %w", err) 795 } 796 defer sf.Close() 797 798 if _, err := io.Copy(&buf, sf); err != nil { 799 return fmt.Errorf("error reading bytes of backup db: %w", err) 800 } 801 802 return nil 803 }(); err != nil { 804 s.logger.Error("error backing up database", "error", err) 805 return 806 } 807 808 if err := func() error { 809 s.logger.Info("sending to s3...") 810 811 currTime := time.Now().Format("2006-01-02_15-04-05") 812 key := "cocoon-backup-" + currTime + ".db" 813 814 config := &aws.Config{ 815 Region: aws.String(s.s3Config.Region), 816 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 817 } 818 819 if s.s3Config.Endpoint != "" { 820 config.Endpoint = aws.String(s.s3Config.Endpoint) 821 config.S3ForcePathStyle = aws.Bool(true) 822 } 823 824 sess, err := session.NewSession(config) 825 if err != nil { 826 return err 827 } 828 829 svc := s3.New(sess) 830 831 if _, err := svc.PutObject(&s3.PutObjectInput{ 832 Bucket: aws.String(s.s3Config.Bucket), 833 Key: aws.String(key), 834 Body: bytes.NewReader(buf.Bytes()), 835 }); err != nil { 836 return fmt.Errorf("error uploading file to s3: %w", err) 837 } 838 839 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 840 841 return nil 842 }(); err != nil { 843 s.logger.Error("error uploading database backup", "error", err) 844 return 845 } 846 847 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644) 848} 849 850func (s *Server) backupRoutine() { 851 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 852 return 853 } 854 855 if s.s3Config.Region == "" { 856 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 857 return 858 } 859 860 if s.s3Config.Bucket == "" { 861 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 862 return 863 } 864 865 if s.s3Config.AccessKey == "" { 866 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 867 return 868 } 869 870 if s.s3Config.SecretKey == "" { 871 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 872 return 873 } 874 875 shouldBackupNow := false 876 lastBackupStr, err := os.ReadFile("last-backup.txt") 877 if err != nil { 878 shouldBackupNow = true 879 } else { 880 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr)) 881 if err != nil { 882 shouldBackupNow = true 883 } else if time.Now().Sub(lastBackup).Seconds() > 3600 { 884 shouldBackupNow = true 885 } 886 } 887 888 if shouldBackupNow { 889 go s.doBackup() 890 } 891 892 ticker := time.NewTicker(time.Hour) 893 for range ticker.C { 894 go s.doBackup() 895 } 896}