An atproto PDS written in Go
1package server 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "errors" 7 "fmt" 8 "log/slog" 9 "net/http" 10 "net/smtp" 11 "os" 12 "strings" 13 "sync" 14 "time" 15 16 "github.com/Azure/go-autorest/autorest/to" 17 "github.com/bluesky-social/indigo/api/atproto" 18 "github.com/bluesky-social/indigo/atproto/syntax" 19 "github.com/bluesky-social/indigo/events" 20 "github.com/bluesky-social/indigo/util" 21 "github.com/bluesky-social/indigo/xrpc" 22 "github.com/domodwyer/mailyak/v3" 23 "github.com/go-playground/validator" 24 "github.com/golang-jwt/jwt/v4" 25 "github.com/haileyok/cocoon/identity" 26 "github.com/haileyok/cocoon/internal/helpers" 27 "github.com/haileyok/cocoon/models" 28 "github.com/haileyok/cocoon/plc" 29 "github.com/labstack/echo/v4" 30 "github.com/labstack/echo/v4/middleware" 31 "github.com/lestrrat-go/jwx/v2/jwk" 32 slogecho "github.com/samber/slog-echo" 33 "gorm.io/driver/sqlite" 34 "gorm.io/gorm" 35) 36 37type Server struct { 38 http *http.Client 39 httpd *http.Server 40 mail *mailyak.MailYak 41 mailLk *sync.Mutex 42 echo *echo.Echo 43 db *gorm.DB 44 plcClient *plc.Client 45 logger *slog.Logger 46 config *config 47 privateKey *ecdsa.PrivateKey 48 repoman *RepoMan 49 evtman *events.EventManager 50 passport *identity.Passport 51} 52 53type Args struct { 54 Addr string 55 DbName string 56 Logger *slog.Logger 57 Version string 58 Did string 59 Hostname string 60 RotationKeyPath string 61 JwkPath string 62 ContactEmail string 63 Relays []string 64 AdminPassword string 65 66 SmtpUser string 67 SmtpPass string 68 SmtpHost string 69 SmtpPort string 70 SmtpEmail string 71 SmtpName string 72} 73 74type config struct { 75 Version string 76 Did string 77 Hostname string 78 ContactEmail string 79 EnforcePeering bool 80 Relays []string 81 AdminPassword string 82 SmtpEmail string 83 SmtpName string 84} 85 86type CustomValidator struct { 87 validator *validator.Validate 88} 89 90type ValidationError struct { 91 error 92 Field string 93 Tag string 94} 95 96func (cv *CustomValidator) Validate(i any) error { 97 if err := cv.validator.Struct(i); err != nil { 98 var validateErrors validator.ValidationErrors 99 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 100 first := validateErrors[0] 101 return ValidationError{ 102 error: err, 103 Field: first.Field(), 104 Tag: first.Tag(), 105 } 106 } 107 108 return err 109 } 110 111 return nil 112} 113 114func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 115 return func(e echo.Context) error { 116 username, password, ok := e.Request().BasicAuth() 117 if !ok || username != "admin" || password != s.config.AdminPassword { 118 return helpers.InputError(e, to.StringPtr("Unauthorized")) 119 } 120 121 if err := next(e); err != nil { 122 e.Error(err) 123 } 124 125 return nil 126 } 127} 128 129func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 130 return func(e echo.Context) error { 131 authheader := e.Request().Header.Get("authorization") 132 if authheader == "" { 133 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 134 } 135 136 pts := strings.Split(authheader, " ") 137 if len(pts) != 2 { 138 return helpers.ServerError(e, nil) 139 } 140 141 tokenstr := pts[1] 142 143 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 144 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 145 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 146 } 147 148 return s.privateKey.Public(), nil 149 }) 150 if err != nil { 151 s.logger.Error("error parsing jwt", "error", err) 152 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 153 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 154 } 155 156 claims, ok := token.Claims.(jwt.MapClaims) 157 if !ok || !token.Valid { 158 return helpers.InputError(e, to.StringPtr("InvalidToken")) 159 } 160 161 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 162 scope := claims["scope"].(string) 163 164 if isRefresh && scope != "com.atproto.refresh" { 165 return helpers.InputError(e, to.StringPtr("InvalidToken")) 166 } else if !isRefresh && scope != "com.atproto.access" { 167 return helpers.InputError(e, to.StringPtr("InvalidToken")) 168 } 169 170 table := "tokens" 171 if isRefresh { 172 table = "refresh_tokens" 173 } 174 175 type Result struct { 176 Found bool 177 } 178 var result Result 179 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil { 180 if err == gorm.ErrRecordNotFound { 181 return helpers.InputError(e, to.StringPtr("InvalidToken")) 182 } 183 184 s.logger.Error("error getting token from db", "error", err) 185 return helpers.ServerError(e, nil) 186 } 187 188 if !result.Found { 189 return helpers.InputError(e, to.StringPtr("InvalidToken")) 190 } 191 192 exp, ok := claims["exp"].(float64) 193 if !ok { 194 s.logger.Error("error getting iat from token") 195 return helpers.ServerError(e, nil) 196 } 197 198 if exp < float64(time.Now().UTC().Unix()) { 199 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 200 } 201 202 repo, err := s.getRepoActorByDid(claims["sub"].(string)) 203 if err != nil { 204 s.logger.Error("error fetching repo", "error", err) 205 return helpers.ServerError(e, nil) 206 } 207 208 e.Set("repo", repo) 209 e.Set("did", claims["sub"]) 210 e.Set("token", tokenstr) 211 212 if err := next(e); err != nil { 213 e.Error(err) 214 } 215 216 return nil 217 } 218} 219 220func New(args *Args) (*Server, error) { 221 if args.Addr == "" { 222 return nil, fmt.Errorf("addr must be set") 223 } 224 225 if args.DbName == "" { 226 return nil, fmt.Errorf("db name must be set") 227 } 228 229 if args.Did == "" { 230 return nil, fmt.Errorf("cocoon did must be set") 231 } 232 233 if args.ContactEmail == "" { 234 return nil, fmt.Errorf("cocoon contact email is required") 235 } 236 237 if _, err := syntax.ParseDID(args.Did); err != nil { 238 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 239 } 240 241 if args.Hostname == "" { 242 return nil, fmt.Errorf("cocoon hostname must be set") 243 } 244 245 if args.AdminPassword == "" { 246 return nil, fmt.Errorf("admin password must be set") 247 } 248 249 if args.Logger == nil { 250 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 251 } 252 253 e := echo.New() 254 255 e.Pre(middleware.RemoveTrailingSlash()) 256 e.Pre(slogecho.New(args.Logger)) 257 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 258 AllowOrigins: []string{"*"}, 259 AllowHeaders: []string{"*"}, 260 AllowMethods: []string{"*"}, 261 AllowCredentials: true, 262 MaxAge: 100_000_000, 263 })) 264 265 vdtor := validator.New() 266 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 267 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 268 return false 269 } 270 return true 271 }) 272 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 273 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 274 return false 275 } 276 return true 277 }) 278 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 279 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 280 return false 281 } 282 return true 283 }) 284 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 285 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 286 return false 287 } 288 return true 289 }) 290 291 e.Validator = &CustomValidator{validator: vdtor} 292 293 httpd := &http.Server{ 294 Addr: args.Addr, 295 Handler: e, 296 // shitty defaults but okay for now, needed for import repo 297 ReadTimeout: 5 * time.Minute, 298 WriteTimeout: 5 * time.Minute, 299 IdleTimeout: 5 * time.Minute, 300 } 301 302 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 303 if err != nil { 304 return nil, err 305 } 306 307 rkbytes, err := os.ReadFile(args.RotationKeyPath) 308 if err != nil { 309 return nil, err 310 } 311 312 h := util.RobustHTTPClient() 313 314 plcClient, err := plc.NewClient(&plc.ClientArgs{ 315 H: h, 316 Service: "https://plc.directory", 317 PdsHostname: args.Hostname, 318 RotationKey: rkbytes, 319 }) 320 if err != nil { 321 return nil, err 322 } 323 324 jwkbytes, err := os.ReadFile(args.JwkPath) 325 if err != nil { 326 return nil, err 327 } 328 329 key, err := jwk.ParseKey(jwkbytes) 330 if err != nil { 331 return nil, err 332 } 333 334 var pkey ecdsa.PrivateKey 335 if err := key.Raw(&pkey); err != nil { 336 return nil, err 337 } 338 339 s := &Server{ 340 http: h, 341 httpd: httpd, 342 echo: e, 343 logger: args.Logger, 344 db: db, 345 plcClient: plcClient, 346 privateKey: &pkey, 347 config: &config{ 348 Version: args.Version, 349 Did: args.Did, 350 Hostname: args.Hostname, 351 ContactEmail: args.ContactEmail, 352 EnforcePeering: false, 353 Relays: args.Relays, 354 AdminPassword: args.AdminPassword, 355 SmtpName: args.SmtpName, 356 SmtpEmail: args.SmtpEmail, 357 }, 358 evtman: events.NewEventManager(events.NewMemPersister()), 359 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 360 } 361 362 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 363 364 // TODO: should validate these args 365 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 366 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.") 367 } else { 368 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 369 mail.From(s.config.SmtpEmail) 370 mail.FromName(s.config.SmtpName) 371 372 s.mail = mail 373 s.mailLk = &sync.Mutex{} 374 } 375 376 return s, nil 377} 378 379func (s *Server) addRoutes() { 380 // random stuff 381 s.echo.GET("/", s.handleRoot) 382 s.echo.GET("/xrpc/_health", s.handleHealth) 383 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 384 s.echo.GET("/robots.txt", s.handleRobots) 385 386 // public 387 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 388 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 389 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 390 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 391 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 392 393 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 394 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 395 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 396 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 397 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 398 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 399 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 400 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 401 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 402 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 403 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 404 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 405 406 // authed 407 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware) 408 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware) 409 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware) 410 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware) 411 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware) 412 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware) 413 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 414 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware) 415 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware) 416 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware) 417 418 // repo 419 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware) 420 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware) 421 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware) 422 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware) 423 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware) 424 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleSessionMiddleware) 425 426 // stupid silly endpoints 427 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware) 428 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware) 429 430 // are there any routes that we should be allowing without auth? i dont think so but idk 431 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 432 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 433 434 // admin routes 435 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 436 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 437} 438 439func (s *Server) Serve(ctx context.Context) error { 440 s.addRoutes() 441 442 s.logger.Info("migrating...") 443 444 s.db.AutoMigrate( 445 &models.Actor{}, 446 &models.Repo{}, 447 &models.InviteCode{}, 448 &models.Token{}, 449 &models.RefreshToken{}, 450 &models.Block{}, 451 &models.Record{}, 452 &models.Blob{}, 453 &models.BlobPart{}, 454 ) 455 456 s.logger.Info("starting cocoon") 457 458 go func() { 459 if err := s.httpd.ListenAndServe(); err != nil { 460 panic(err) 461 } 462 }() 463 464 for _, relay := range s.config.Relays { 465 cli := xrpc.Client{Host: relay} 466 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 467 Hostname: s.config.Hostname, 468 }) 469 } 470 471 <-ctx.Done() 472 473 fmt.Println("shut down") 474 475 return nil 476}