An atproto PDS written in Go
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "errors" 8 "fmt" 9 "io" 10 "log/slog" 11 "net/http" 12 "net/smtp" 13 "os" 14 "strings" 15 "sync" 16 "time" 17 18 "github.com/Azure/go-autorest/autorest/to" 19 "github.com/aws/aws-sdk-go/aws" 20 "github.com/aws/aws-sdk-go/aws/credentials" 21 "github.com/aws/aws-sdk-go/aws/session" 22 "github.com/aws/aws-sdk-go/service/s3" 23 "github.com/bluesky-social/indigo/api/atproto" 24 "github.com/bluesky-social/indigo/atproto/syntax" 25 "github.com/bluesky-social/indigo/events" 26 "github.com/bluesky-social/indigo/util" 27 "github.com/bluesky-social/indigo/xrpc" 28 "github.com/domodwyer/mailyak/v3" 29 "github.com/go-playground/validator" 30 "github.com/golang-jwt/jwt/v4" 31 "github.com/haileyok/cocoon/identity" 32 "github.com/haileyok/cocoon/internal/db" 33 "github.com/haileyok/cocoon/internal/helpers" 34 "github.com/haileyok/cocoon/models" 35 "github.com/haileyok/cocoon/plc" 36 "github.com/labstack/echo/v4" 37 "github.com/labstack/echo/v4/middleware" 38 "github.com/lestrrat-go/jwx/v2/jwk" 39 slogecho "github.com/samber/slog-echo" 40 "gorm.io/driver/sqlite" 41 "gorm.io/gorm" 42) 43 44type S3Config struct { 45 BackupsEnabled bool 46 Endpoint string 47 Region string 48 Bucket string 49 AccessKey string 50 SecretKey string 51} 52 53type Server struct { 54 http *http.Client 55 httpd *http.Server 56 mail *mailyak.MailYak 57 mailLk *sync.Mutex 58 echo *echo.Echo 59 db *db.DB 60 plcClient *plc.Client 61 logger *slog.Logger 62 config *config 63 privateKey *ecdsa.PrivateKey 64 repoman *RepoMan 65 evtman *events.EventManager 66 passport *identity.Passport 67 68 dbName string 69 s3Config *S3Config 70} 71 72type Args struct { 73 Addr string 74 DbName string 75 Logger *slog.Logger 76 Version string 77 Did string 78 Hostname string 79 RotationKeyPath string 80 JwkPath string 81 ContactEmail string 82 Relays []string 83 AdminPassword string 84 85 SmtpUser string 86 SmtpPass string 87 SmtpHost string 88 SmtpPort string 89 SmtpEmail string 90 SmtpName string 91 92 S3Config *S3Config 93} 94 95type config struct { 96 Version string 97 Did string 98 Hostname string 99 ContactEmail string 100 EnforcePeering bool 101 Relays []string 102 AdminPassword string 103 SmtpEmail string 104 SmtpName string 105} 106 107type CustomValidator struct { 108 validator *validator.Validate 109} 110 111type ValidationError struct { 112 error 113 Field string 114 Tag string 115} 116 117func (cv *CustomValidator) Validate(i any) error { 118 if err := cv.validator.Struct(i); err != nil { 119 var validateErrors validator.ValidationErrors 120 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 121 first := validateErrors[0] 122 return ValidationError{ 123 error: err, 124 Field: first.Field(), 125 Tag: first.Tag(), 126 } 127 } 128 129 return err 130 } 131 132 return nil 133} 134 135func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 136 return func(e echo.Context) error { 137 username, password, ok := e.Request().BasicAuth() 138 if !ok || username != "admin" || password != s.config.AdminPassword { 139 return helpers.InputError(e, to.StringPtr("Unauthorized")) 140 } 141 142 if err := next(e); err != nil { 143 e.Error(err) 144 } 145 146 return nil 147 } 148} 149 150func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 151 return func(e echo.Context) error { 152 authheader := e.Request().Header.Get("authorization") 153 if authheader == "" { 154 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 155 } 156 157 pts := strings.Split(authheader, " ") 158 if len(pts) != 2 { 159 return helpers.ServerError(e, nil) 160 } 161 162 tokenstr := pts[1] 163 164 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 165 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 166 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 167 } 168 169 return s.privateKey.Public(), nil 170 }) 171 if err != nil { 172 s.logger.Error("error parsing jwt", "error", err) 173 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 174 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 175 } 176 177 claims, ok := token.Claims.(jwt.MapClaims) 178 if !ok || !token.Valid { 179 return helpers.InputError(e, to.StringPtr("InvalidToken")) 180 } 181 182 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 183 scope := claims["scope"].(string) 184 185 if isRefresh && scope != "com.atproto.refresh" { 186 return helpers.InputError(e, to.StringPtr("InvalidToken")) 187 } else if !isRefresh && scope != "com.atproto.access" { 188 return helpers.InputError(e, to.StringPtr("InvalidToken")) 189 } 190 191 table := "tokens" 192 if isRefresh { 193 table = "refresh_tokens" 194 } 195 196 type Result struct { 197 Found bool 198 } 199 var result Result 200 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 201 if err == gorm.ErrRecordNotFound { 202 return helpers.InputError(e, to.StringPtr("InvalidToken")) 203 } 204 205 s.logger.Error("error getting token from db", "error", err) 206 return helpers.ServerError(e, nil) 207 } 208 209 if !result.Found { 210 return helpers.InputError(e, to.StringPtr("InvalidToken")) 211 } 212 213 exp, ok := claims["exp"].(float64) 214 if !ok { 215 s.logger.Error("error getting iat from token") 216 return helpers.ServerError(e, nil) 217 } 218 219 if exp < float64(time.Now().UTC().Unix()) { 220 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 221 } 222 223 repo, err := s.getRepoActorByDid(claims["sub"].(string)) 224 if err != nil { 225 s.logger.Error("error fetching repo", "error", err) 226 return helpers.ServerError(e, nil) 227 } 228 229 e.Set("repo", repo) 230 e.Set("did", claims["sub"]) 231 e.Set("token", tokenstr) 232 233 if err := next(e); err != nil { 234 e.Error(err) 235 } 236 237 return nil 238 } 239} 240 241func New(args *Args) (*Server, error) { 242 if args.Addr == "" { 243 return nil, fmt.Errorf("addr must be set") 244 } 245 246 if args.DbName == "" { 247 return nil, fmt.Errorf("db name must be set") 248 } 249 250 if args.Did == "" { 251 return nil, fmt.Errorf("cocoon did must be set") 252 } 253 254 if args.ContactEmail == "" { 255 return nil, fmt.Errorf("cocoon contact email is required") 256 } 257 258 if _, err := syntax.ParseDID(args.Did); err != nil { 259 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 260 } 261 262 if args.Hostname == "" { 263 return nil, fmt.Errorf("cocoon hostname must be set") 264 } 265 266 if args.AdminPassword == "" { 267 return nil, fmt.Errorf("admin password must be set") 268 } 269 270 if args.Logger == nil { 271 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 272 } 273 274 e := echo.New() 275 276 e.Pre(middleware.RemoveTrailingSlash()) 277 e.Pre(slogecho.New(args.Logger)) 278 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 279 AllowOrigins: []string{"*"}, 280 AllowHeaders: []string{"*"}, 281 AllowMethods: []string{"*"}, 282 AllowCredentials: true, 283 MaxAge: 100_000_000, 284 })) 285 286 vdtor := validator.New() 287 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 288 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 289 return false 290 } 291 return true 292 }) 293 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 294 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 295 return false 296 } 297 return true 298 }) 299 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 300 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 301 return false 302 } 303 return true 304 }) 305 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 306 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 307 return false 308 } 309 return true 310 }) 311 312 e.Validator = &CustomValidator{validator: vdtor} 313 314 httpd := &http.Server{ 315 Addr: args.Addr, 316 Handler: e, 317 // shitty defaults but okay for now, needed for import repo 318 ReadTimeout: 5 * time.Minute, 319 WriteTimeout: 5 * time.Minute, 320 IdleTimeout: 5 * time.Minute, 321 } 322 323 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 324 if err != nil { 325 return nil, err 326 } 327 dbw := db.NewDB(gdb) 328 329 rkbytes, err := os.ReadFile(args.RotationKeyPath) 330 if err != nil { 331 return nil, err 332 } 333 334 h := util.RobustHTTPClient() 335 336 plcClient, err := plc.NewClient(&plc.ClientArgs{ 337 H: h, 338 Service: "https://plc.directory", 339 PdsHostname: args.Hostname, 340 RotationKey: rkbytes, 341 }) 342 if err != nil { 343 return nil, err 344 } 345 346 jwkbytes, err := os.ReadFile(args.JwkPath) 347 if err != nil { 348 return nil, err 349 } 350 351 key, err := jwk.ParseKey(jwkbytes) 352 if err != nil { 353 return nil, err 354 } 355 356 var pkey ecdsa.PrivateKey 357 if err := key.Raw(&pkey); err != nil { 358 return nil, err 359 } 360 361 s := &Server{ 362 http: h, 363 httpd: httpd, 364 echo: e, 365 logger: args.Logger, 366 db: dbw, 367 plcClient: plcClient, 368 privateKey: &pkey, 369 config: &config{ 370 Version: args.Version, 371 Did: args.Did, 372 Hostname: args.Hostname, 373 ContactEmail: args.ContactEmail, 374 EnforcePeering: false, 375 Relays: args.Relays, 376 AdminPassword: args.AdminPassword, 377 SmtpName: args.SmtpName, 378 SmtpEmail: args.SmtpEmail, 379 }, 380 evtman: events.NewEventManager(events.NewMemPersister()), 381 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 382 383 dbName: args.DbName, 384 s3Config: args.S3Config, 385 } 386 387 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 388 389 // TODO: should validate these args 390 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 391 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.") 392 } else { 393 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 394 mail.From(s.config.SmtpEmail) 395 mail.FromName(s.config.SmtpName) 396 397 s.mail = mail 398 s.mailLk = &sync.Mutex{} 399 } 400 401 return s, nil 402} 403 404func (s *Server) addRoutes() { 405 // random stuff 406 s.echo.GET("/", s.handleRoot) 407 s.echo.GET("/xrpc/_health", s.handleHealth) 408 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 409 s.echo.GET("/robots.txt", s.handleRobots) 410 411 // public 412 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 413 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 414 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 415 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 416 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 417 418 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 419 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 420 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 421 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 422 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 423 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 424 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 425 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 426 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 427 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 428 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 429 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 430 431 // authed 432 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware) 433 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware) 434 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware) 435 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware) 436 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware) 437 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware) 438 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 439 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware) 440 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware) 441 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware) 442 443 // repo 444 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware) 445 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware) 446 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware) 447 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware) 448 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware) 449 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleSessionMiddleware) 450 451 // stupid silly endpoints 452 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware) 453 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware) 454 455 // are there any routes that we should be allowing without auth? i dont think so but idk 456 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 457 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 458 459 // admin routes 460 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 461 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 462} 463 464func (s *Server) Serve(ctx context.Context) error { 465 s.addRoutes() 466 467 s.logger.Info("migrating...") 468 469 s.db.AutoMigrate( 470 &models.Actor{}, 471 &models.Repo{}, 472 &models.InviteCode{}, 473 &models.Token{}, 474 &models.RefreshToken{}, 475 &models.Block{}, 476 &models.Record{}, 477 &models.Blob{}, 478 &models.BlobPart{}, 479 ) 480 481 s.logger.Info("starting cocoon") 482 483 go func() { 484 if err := s.httpd.ListenAndServe(); err != nil { 485 panic(err) 486 } 487 }() 488 489 go s.backupRoutine() 490 491 for _, relay := range s.config.Relays { 492 cli := xrpc.Client{Host: relay} 493 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 494 Hostname: s.config.Hostname, 495 }) 496 } 497 498 <-ctx.Done() 499 500 fmt.Println("shut down") 501 502 return nil 503} 504 505func (s *Server) doBackup() { 506 start := time.Now() 507 508 s.logger.Info("beginning backup to s3...") 509 510 var buf bytes.Buffer 511 if err := func() error { 512 s.logger.Info("reading database bytes...") 513 s.db.Lock() 514 defer s.db.Unlock() 515 516 sf, err := os.Open(s.dbName) 517 if err != nil { 518 return fmt.Errorf("error opening database for backup: %w", err) 519 } 520 defer sf.Close() 521 522 if _, err := io.Copy(&buf, sf); err != nil { 523 return fmt.Errorf("error reading bytes of backup db: %w", err) 524 } 525 526 return nil 527 }(); err != nil { 528 s.logger.Error("error backing up database", "error", err) 529 return 530 } 531 532 if err := func() error { 533 s.logger.Info("sending to s3...") 534 535 currTime := time.Now().Format("2006-01-02_15-04-05") 536 key := "cocoon-backup-" + currTime + ".db" 537 538 config := &aws.Config{ 539 Region: aws.String(s.s3Config.Region), 540 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 541 } 542 543 if s.s3Config.Endpoint != "" { 544 config.Endpoint = aws.String(s.s3Config.Endpoint) 545 config.S3ForcePathStyle = aws.Bool(true) 546 } 547 548 sess, err := session.NewSession(config) 549 if err != nil { 550 return err 551 } 552 553 svc := s3.New(sess) 554 555 if _, err := svc.PutObject(&s3.PutObjectInput{ 556 Bucket: aws.String(s.s3Config.Bucket), 557 Key: aws.String(key), 558 Body: bytes.NewReader(buf.Bytes()), 559 }); err != nil { 560 return fmt.Errorf("error uploading file to s3: %w", err) 561 } 562 563 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 564 565 return nil 566 }(); err != nil { 567 s.logger.Error("error uploading database backup", "error", err) 568 return 569 } 570 571 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644) 572} 573 574func (s *Server) backupRoutine() { 575 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 576 return 577 } 578 579 if s.s3Config.Region == "" { 580 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 581 return 582 } 583 584 if s.s3Config.Bucket == "" { 585 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 586 return 587 } 588 589 if s.s3Config.AccessKey == "" { 590 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 591 return 592 } 593 594 if s.s3Config.SecretKey == "" { 595 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 596 return 597 } 598 599 shouldBackupNow := false 600 lastBackupStr, err := os.ReadFile("last-backup.txt") 601 if err != nil { 602 shouldBackupNow = true 603 } else { 604 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr)) 605 if err != nil { 606 shouldBackupNow = true 607 } else if time.Now().Sub(lastBackup).Seconds() > 3600 { 608 shouldBackupNow = true 609 } 610 } 611 612 if shouldBackupNow { 613 go s.doBackup() 614 } 615 616 ticker := time.NewTicker(time.Hour) 617 for range ticker.C { 618 go s.doBackup() 619 } 620}