An atproto PDS written in Go
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" 11 "log/slog" 12 "net/http" 13 "net/smtp" 14 "os" 15 "path/filepath" 16 "sync" 17 "text/template" 18 "time" 19 20 "github.com/aws/aws-sdk-go/aws" 21 "github.com/aws/aws-sdk-go/aws/credentials" 22 "github.com/aws/aws-sdk-go/aws/session" 23 "github.com/aws/aws-sdk-go/service/s3" 24 "github.com/bluesky-social/indigo/api/atproto" 25 "github.com/bluesky-social/indigo/atproto/syntax" 26 "github.com/bluesky-social/indigo/events" 27 "github.com/bluesky-social/indigo/util" 28 "github.com/bluesky-social/indigo/xrpc" 29 "github.com/domodwyer/mailyak/v3" 30 "github.com/go-playground/validator" 31 "github.com/gorilla/sessions" 32 "github.com/haileyok/cocoon/identity" 33 "github.com/haileyok/cocoon/internal/db" 34 "github.com/haileyok/cocoon/internal/helpers" 35 "github.com/haileyok/cocoon/models" 36 "github.com/haileyok/cocoon/oauth/client" 37 "github.com/haileyok/cocoon/oauth/constants" 38 "github.com/haileyok/cocoon/oauth/dpop" 39 "github.com/haileyok/cocoon/oauth/provider" 40 "github.com/haileyok/cocoon/plc" 41 echo_session "github.com/labstack/echo-contrib/session" 42 "github.com/labstack/echo/v4" 43 "github.com/labstack/echo/v4/middleware" 44 slogecho "github.com/samber/slog-echo" 45 "gorm.io/driver/sqlite" 46 "gorm.io/gorm" 47) 48 49const ( 50 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 51) 52 53type S3Config struct { 54 BackupsEnabled bool 55 Endpoint string 56 Region string 57 Bucket string 58 AccessKey string 59 SecretKey string 60} 61 62type Server struct { 63 http *http.Client 64 httpd *http.Server 65 mail *mailyak.MailYak 66 mailLk *sync.Mutex 67 echo *echo.Echo 68 db *db.DB 69 plcClient *plc.Client 70 logger *slog.Logger 71 config *config 72 privateKey *ecdsa.PrivateKey 73 repoman *RepoMan 74 oauthProvider *provider.Provider 75 evtman *events.EventManager 76 passport *identity.Passport 77 78 dbName string 79 s3Config *S3Config 80} 81 82type Args struct { 83 Addr string 84 DbName string 85 Logger *slog.Logger 86 Version string 87 Did string 88 Hostname string 89 RotationKeyPath string 90 JwkPath string 91 ContactEmail string 92 Relays []string 93 AdminPassword string 94 95 SmtpUser string 96 SmtpPass string 97 SmtpHost string 98 SmtpPort string 99 SmtpEmail string 100 SmtpName string 101 102 S3Config *S3Config 103 104 SessionSecret string 105} 106 107type config struct { 108 Version string 109 Did string 110 Hostname string 111 ContactEmail string 112 EnforcePeering bool 113 Relays []string 114 AdminPassword string 115 SmtpEmail string 116 SmtpName string 117} 118 119type CustomValidator struct { 120 validator *validator.Validate 121} 122 123type ValidationError struct { 124 error 125 Field string 126 Tag string 127} 128 129func (cv *CustomValidator) Validate(i any) error { 130 if err := cv.validator.Struct(i); err != nil { 131 var validateErrors validator.ValidationErrors 132 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 133 first := validateErrors[0] 134 return ValidationError{ 135 error: err, 136 Field: first.Field(), 137 Tag: first.Tag(), 138 } 139 } 140 141 return err 142 } 143 144 return nil 145} 146 147//go:embed templates/* 148var templateFS embed.FS 149 150//go:embed static/* 151var staticFS embed.FS 152 153type TemplateRenderer struct { 154 templates *template.Template 155 isDev bool 156 templatePath string 157} 158 159func (s *Server) loadTemplates() { 160 absPath, _ := filepath.Abs("server/templates/*.html") 161 if s.config.Version == "dev" { 162 tmpl := template.Must(template.ParseGlob(absPath)) 163 s.echo.Renderer = &TemplateRenderer{ 164 templates: tmpl, 165 isDev: true, 166 templatePath: absPath, 167 } 168 } else { 169 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 170 s.echo.Renderer = &TemplateRenderer{ 171 templates: tmpl, 172 isDev: false, 173 } 174 } 175} 176 177func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 178 if t.isDev { 179 tmpl, err := template.ParseGlob(t.templatePath) 180 if err != nil { 181 return err 182 } 183 t.templates = tmpl 184 } 185 186 if viewContext, isMap := data.(map[string]any); isMap { 187 viewContext["reverse"] = c.Echo().Reverse 188 } 189 190 return t.templates.ExecuteTemplate(w, name, data) 191} 192 193func New(args *Args) (*Server, error) { 194 if args.Addr == "" { 195 return nil, fmt.Errorf("addr must be set") 196 } 197 198 if args.DbName == "" { 199 return nil, fmt.Errorf("db name must be set") 200 } 201 202 if args.Did == "" { 203 return nil, fmt.Errorf("cocoon did must be set") 204 } 205 206 if args.ContactEmail == "" { 207 return nil, fmt.Errorf("cocoon contact email is required") 208 } 209 210 if _, err := syntax.ParseDID(args.Did); err != nil { 211 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 212 } 213 214 if args.Hostname == "" { 215 return nil, fmt.Errorf("cocoon hostname must be set") 216 } 217 218 if args.AdminPassword == "" { 219 return nil, fmt.Errorf("admin password must be set") 220 } 221 222 if args.Logger == nil { 223 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 224 } 225 226 if args.SessionSecret == "" { 227 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 228 } 229 230 e := echo.New() 231 232 e.Pre(middleware.RemoveTrailingSlash()) 233 e.Pre(slogecho.New(args.Logger)) 234 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 235 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 236 AllowOrigins: []string{"*"}, 237 AllowHeaders: []string{"*"}, 238 AllowMethods: []string{"*"}, 239 AllowCredentials: true, 240 MaxAge: 100_000_000, 241 })) 242 243 vdtor := validator.New() 244 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 245 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 246 return false 247 } 248 return true 249 }) 250 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 251 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 252 return false 253 } 254 return true 255 }) 256 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 257 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 258 return false 259 } 260 return true 261 }) 262 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 263 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 264 return false 265 } 266 return true 267 }) 268 269 e.Validator = &CustomValidator{validator: vdtor} 270 271 httpd := &http.Server{ 272 Addr: args.Addr, 273 Handler: e, 274 // shitty defaults but okay for now, needed for import repo 275 ReadTimeout: 5 * time.Minute, 276 WriteTimeout: 5 * time.Minute, 277 IdleTimeout: 5 * time.Minute, 278 } 279 280 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 281 if err != nil { 282 return nil, err 283 } 284 dbw := db.NewDB(gdb) 285 286 rkbytes, err := os.ReadFile(args.RotationKeyPath) 287 if err != nil { 288 return nil, err 289 } 290 291 h := util.RobustHTTPClient() 292 293 plcClient, err := plc.NewClient(&plc.ClientArgs{ 294 H: h, 295 Service: "https://plc.directory", 296 PdsHostname: args.Hostname, 297 RotationKey: rkbytes, 298 }) 299 if err != nil { 300 return nil, err 301 } 302 303 jwkbytes, err := os.ReadFile(args.JwkPath) 304 if err != nil { 305 return nil, err 306 } 307 308 key, err := helpers.ParseJWKFromBytes(jwkbytes) 309 if err != nil { 310 return nil, err 311 } 312 313 var pkey ecdsa.PrivateKey 314 if err := key.Raw(&pkey); err != nil { 315 return nil, err 316 } 317 318 oauthCli := &http.Client{ 319 Timeout: 10 * time.Second, 320 } 321 322 var nonceSecret []byte 323 maybeSecret, err := os.ReadFile("nonce.secret") 324 if err != nil && !os.IsNotExist(err) { 325 args.Logger.Error("error attempting to read nonce secret", "error", err) 326 } else { 327 nonceSecret = maybeSecret 328 } 329 330 s := &Server{ 331 http: h, 332 httpd: httpd, 333 echo: e, 334 logger: args.Logger, 335 db: dbw, 336 plcClient: plcClient, 337 privateKey: &pkey, 338 config: &config{ 339 Version: args.Version, 340 Did: args.Did, 341 Hostname: args.Hostname, 342 ContactEmail: args.ContactEmail, 343 EnforcePeering: false, 344 Relays: args.Relays, 345 AdminPassword: args.AdminPassword, 346 SmtpName: args.SmtpName, 347 SmtpEmail: args.SmtpEmail, 348 }, 349 evtman: events.NewEventManager(events.NewMemPersister()), 350 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 351 352 dbName: args.DbName, 353 s3Config: args.S3Config, 354 355 oauthProvider: provider.NewProvider(provider.Args{ 356 Hostname: args.Hostname, 357 ClientManagerArgs: client.ManagerArgs{ 358 Cli: oauthCli, 359 Logger: args.Logger, 360 }, 361 DpopManagerArgs: dpop.ManagerArgs{ 362 NonceSecret: nonceSecret, 363 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 364 OnNonceSecretCreated: func(newNonce []byte) { 365 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 366 args.Logger.Error("error writing new nonce secret", "error", err) 367 } 368 }, 369 Logger: args.Logger, 370 Hostname: args.Hostname, 371 }, 372 }), 373 } 374 375 s.loadTemplates() 376 377 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 378 379 // TODO: should validate these args 380 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 381 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.") 382 } else { 383 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 384 mail.From(s.config.SmtpEmail) 385 mail.FromName(s.config.SmtpName) 386 387 s.mail = mail 388 s.mailLk = &sync.Mutex{} 389 } 390 391 return s, nil 392} 393 394func (s *Server) addRoutes() { 395 // static 396 if s.config.Version == "dev" { 397 s.echo.Static("/static", "server/static") 398 } else { 399 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 400 } 401 402 // random stuff 403 s.echo.GET("/", s.handleRoot) 404 s.echo.GET("/xrpc/_health", s.handleHealth) 405 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 406 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 407 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 408 s.echo.GET("/robots.txt", s.handleRobots) 409 410 // public 411 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 412 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 413 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 414 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 415 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 416 417 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 418 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 419 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 420 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 421 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 422 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 423 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 424 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 425 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 426 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 427 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 428 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 429 430 // account 431 s.echo.GET("/account", s.handleAccount) 432 s.echo.POST("/account/revoke", s.handleAccountRevoke) 433 s.echo.GET("/account/signin", s.handleAccountSigninGet) 434 s.echo.POST("/account/signin", s.handleAccountSigninPost) 435 s.echo.GET("/account/signout", s.handleAccountSignout) 436 437 // oauth account 438 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 439 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 440 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 441 442 // oauth authorization 443 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 444 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 445 446 // authed 447 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 448 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 449 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 450 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 451 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 452 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 453 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 454 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 455 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 456 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 457 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 458 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 459 460 // repo 461 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 462 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 463 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 464 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 465 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 466 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 467 468 // stupid silly endpoints 469 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 470 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 471 472 // admin routes 473 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 474 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 475 476 // are there any routes that we should be allowing without auth? i dont think so but idk 477 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 478 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 479} 480 481func (s *Server) Serve(ctx context.Context) error { 482 s.addRoutes() 483 484 s.logger.Info("migrating...") 485 486 s.db.AutoMigrate( 487 &models.Actor{}, 488 &models.Repo{}, 489 &models.InviteCode{}, 490 &models.Token{}, 491 &models.RefreshToken{}, 492 &models.Block{}, 493 &models.Record{}, 494 &models.Blob{}, 495 &models.BlobPart{}, 496 &provider.OauthToken{}, 497 &provider.OauthAuthorizationRequest{}, 498 ) 499 500 s.logger.Info("starting cocoon") 501 502 go func() { 503 if err := s.httpd.ListenAndServe(); err != nil { 504 panic(err) 505 } 506 }() 507 508 go s.backupRoutine() 509 510 for _, relay := range s.config.Relays { 511 cli := xrpc.Client{Host: relay} 512 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 513 Hostname: s.config.Hostname, 514 }) 515 } 516 517 <-ctx.Done() 518 519 fmt.Println("shut down") 520 521 return nil 522} 523 524func (s *Server) doBackup() { 525 start := time.Now() 526 527 s.logger.Info("beginning backup to s3...") 528 529 var buf bytes.Buffer 530 if err := func() error { 531 s.logger.Info("reading database bytes...") 532 s.db.Lock() 533 defer s.db.Unlock() 534 535 sf, err := os.Open(s.dbName) 536 if err != nil { 537 return fmt.Errorf("error opening database for backup: %w", err) 538 } 539 defer sf.Close() 540 541 if _, err := io.Copy(&buf, sf); err != nil { 542 return fmt.Errorf("error reading bytes of backup db: %w", err) 543 } 544 545 return nil 546 }(); err != nil { 547 s.logger.Error("error backing up database", "error", err) 548 return 549 } 550 551 if err := func() error { 552 s.logger.Info("sending to s3...") 553 554 currTime := time.Now().Format("2006-01-02_15-04-05") 555 key := "cocoon-backup-" + currTime + ".db" 556 557 config := &aws.Config{ 558 Region: aws.String(s.s3Config.Region), 559 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 560 } 561 562 if s.s3Config.Endpoint != "" { 563 config.Endpoint = aws.String(s.s3Config.Endpoint) 564 config.S3ForcePathStyle = aws.Bool(true) 565 } 566 567 sess, err := session.NewSession(config) 568 if err != nil { 569 return err 570 } 571 572 svc := s3.New(sess) 573 574 if _, err := svc.PutObject(&s3.PutObjectInput{ 575 Bucket: aws.String(s.s3Config.Bucket), 576 Key: aws.String(key), 577 Body: bytes.NewReader(buf.Bytes()), 578 }); err != nil { 579 return fmt.Errorf("error uploading file to s3: %w", err) 580 } 581 582 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 583 584 return nil 585 }(); err != nil { 586 s.logger.Error("error uploading database backup", "error", err) 587 return 588 } 589 590 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644) 591} 592 593func (s *Server) backupRoutine() { 594 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 595 return 596 } 597 598 if s.s3Config.Region == "" { 599 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 600 return 601 } 602 603 if s.s3Config.Bucket == "" { 604 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 605 return 606 } 607 608 if s.s3Config.AccessKey == "" { 609 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 610 return 611 } 612 613 if s.s3Config.SecretKey == "" { 614 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 615 return 616 } 617 618 shouldBackupNow := false 619 lastBackupStr, err := os.ReadFile("last-backup.txt") 620 if err != nil { 621 shouldBackupNow = true 622 } else { 623 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr)) 624 if err != nil { 625 shouldBackupNow = true 626 } else if time.Now().Sub(lastBackup).Seconds() > 3600 { 627 shouldBackupNow = true 628 } 629 } 630 631 if shouldBackupNow { 632 go s.doBackup() 633 } 634 635 ticker := time.NewTicker(time.Hour) 636 for range ticker.C { 637 go s.doBackup() 638 } 639}