An atproto PDS written in Go
1package server 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "errors" 7 "fmt" 8 "log/slog" 9 "net/http" 10 "os" 11 "strings" 12 "time" 13 14 "github.com/Azure/go-autorest/autorest/to" 15 "github.com/bluesky-social/indigo/api/atproto" 16 "github.com/bluesky-social/indigo/atproto/syntax" 17 "github.com/bluesky-social/indigo/events" 18 "github.com/bluesky-social/indigo/xrpc" 19 "github.com/go-playground/validator" 20 "github.com/golang-jwt/jwt/v4" 21 "github.com/haileyok/cocoon/identity" 22 "github.com/haileyok/cocoon/internal/helpers" 23 "github.com/haileyok/cocoon/models" 24 "github.com/haileyok/cocoon/plc" 25 "github.com/labstack/echo/v4" 26 "github.com/labstack/echo/v4/middleware" 27 "github.com/lestrrat-go/jwx/v2/jwk" 28 slogecho "github.com/samber/slog-echo" 29 "gorm.io/driver/sqlite" 30 "gorm.io/gorm" 31) 32 33type Server struct { 34 httpd *http.Server 35 echo *echo.Echo 36 db *gorm.DB 37 plcClient *plc.Client 38 logger *slog.Logger 39 config *config 40 privateKey *ecdsa.PrivateKey 41 repoman *RepoMan 42 evtman *events.EventManager 43 passport *identity.Passport 44} 45 46type Args struct { 47 Addr string 48 DbName string 49 Logger *slog.Logger 50 Version string 51 Did string 52 Hostname string 53 RotationKeyPath string 54 JwkPath string 55 ContactEmail string 56 Relays []string 57} 58 59type config struct { 60 Version string 61 Did string 62 Hostname string 63 ContactEmail string 64 EnforcePeering bool 65 Relays []string 66} 67 68type CustomValidator struct { 69 validator *validator.Validate 70} 71 72type ValidationError struct { 73 error 74 Field string 75 Tag string 76} 77 78func (cv *CustomValidator) Validate(i any) error { 79 if err := cv.validator.Struct(i); err != nil { 80 var validateErrors validator.ValidationErrors 81 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 82 first := validateErrors[0] 83 return ValidationError{ 84 error: err, 85 Field: first.Field(), 86 Tag: first.Tag(), 87 } 88 } 89 90 return err 91 } 92 93 return nil 94} 95 96func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 97 return func(e echo.Context) error { 98 authheader := e.Request().Header.Get("authorization") 99 if authheader == "" { 100 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 101 } 102 103 pts := strings.Split(authheader, " ") 104 if len(pts) != 2 { 105 return helpers.ServerError(e, nil) 106 } 107 108 tokenstr := pts[1] 109 110 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 111 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 112 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 113 } 114 115 return s.privateKey.Public(), nil 116 }) 117 if err != nil { 118 s.logger.Error("error parsing jwt", "error", err) 119 return helpers.InputError(e, to.StringPtr("InvalidToken")) 120 } 121 122 claims, ok := token.Claims.(jwt.MapClaims) 123 if !ok || !token.Valid { 124 return helpers.InputError(e, to.StringPtr("InvalidToken")) 125 } 126 127 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 128 scope := claims["scope"].(string) 129 130 if isRefresh && scope != "com.atproto.refresh" { 131 return helpers.InputError(e, to.StringPtr("InvalidToken")) 132 } else if !isRefresh && scope != "com.atproto.access" { 133 return helpers.InputError(e, to.StringPtr("InvalidToken")) 134 } 135 136 table := "tokens" 137 if isRefresh { 138 table = "refresh_tokens" 139 } 140 141 type Result struct { 142 Found bool 143 } 144 var result Result 145 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil { 146 if err == gorm.ErrRecordNotFound { 147 return helpers.InputError(e, to.StringPtr("InvalidToken")) 148 } 149 150 s.logger.Error("error getting token from db", "error", err) 151 return helpers.ServerError(e, nil) 152 } 153 154 if !result.Found { 155 return helpers.InputError(e, to.StringPtr("InvalidToken")) 156 } 157 158 exp, ok := claims["exp"].(float64) 159 if !ok { 160 s.logger.Error("error getting iat from token") 161 return helpers.ServerError(e, nil) 162 } 163 164 if exp < float64(time.Now().UTC().Unix()) { 165 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 166 } 167 168 e.Set("did", claims["sub"]) 169 170 repo, err := s.getRepoActorByDid(claims["sub"].(string)) 171 if err != nil { 172 s.logger.Error("error fetching repo", "error", err) 173 return helpers.ServerError(e, nil) 174 } 175 e.Set("repo", repo) 176 177 e.Set("token", tokenstr) 178 179 if err := next(e); err != nil { 180 e.Error(err) 181 } 182 183 return nil 184 } 185} 186 187func New(args *Args) (*Server, error) { 188 if args.Addr == "" { 189 return nil, fmt.Errorf("addr must be set") 190 } 191 192 if args.DbName == "" { 193 return nil, fmt.Errorf("db name must be set") 194 } 195 196 if args.Did == "" { 197 return nil, fmt.Errorf("cocoon did must be set") 198 } 199 200 if args.ContactEmail == "" { 201 return nil, fmt.Errorf("cocoon contact email is required") 202 } 203 204 if _, err := syntax.ParseDID(args.Did); err != nil { 205 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 206 } 207 208 if args.Hostname == "" { 209 return nil, fmt.Errorf("cocoon hostname must be set") 210 } 211 212 if args.Logger == nil { 213 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 214 } 215 216 e := echo.New() 217 218 e.Pre(middleware.RemoveTrailingSlash()) 219 e.Pre(slogecho.New(args.Logger)) 220 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 221 AllowOrigins: []string{"*"}, 222 AllowHeaders: []string{"*"}, 223 AllowMethods: []string{"*"}, 224 AllowCredentials: true, 225 MaxAge: 100_000_000, 226 })) 227 228 vdtor := validator.New() 229 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 230 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 231 return false 232 } 233 return true 234 }) 235 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 236 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 237 return false 238 } 239 return true 240 }) 241 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 242 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 243 return false 244 } 245 return true 246 }) 247 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 248 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 249 return false 250 } 251 return true 252 }) 253 254 e.Validator = &CustomValidator{validator: vdtor} 255 256 httpd := &http.Server{ 257 Addr: args.Addr, 258 Handler: e, 259 } 260 261 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 262 if err != nil { 263 return nil, err 264 } 265 266 rkbytes, err := os.ReadFile(args.RotationKeyPath) 267 if err != nil { 268 return nil, err 269 } 270 271 plcClient, err := plc.NewClient(&plc.ClientArgs{ 272 Service: "https://plc.directory", 273 PdsHostname: args.Hostname, 274 RotationKey: rkbytes, 275 }) 276 if err != nil { 277 return nil, err 278 } 279 280 jwkbytes, err := os.ReadFile(args.JwkPath) 281 if err != nil { 282 return nil, err 283 } 284 285 key, err := jwk.ParseKey(jwkbytes) 286 if err != nil { 287 return nil, err 288 } 289 290 var pkey ecdsa.PrivateKey 291 if err := key.Raw(&pkey); err != nil { 292 return nil, err 293 } 294 295 s := &Server{ 296 httpd: httpd, 297 echo: e, 298 logger: args.Logger, 299 db: db, 300 plcClient: plcClient, 301 privateKey: &pkey, 302 config: &config{ 303 Version: args.Version, 304 Did: args.Did, 305 Hostname: args.Hostname, 306 ContactEmail: args.ContactEmail, 307 EnforcePeering: false, 308 Relays: args.Relays, 309 }, 310 evtman: events.NewEventManager(events.NewMemPersister()), 311 passport: identity.NewPassport(identity.NewMemCache(10_000)), 312 } 313 314 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 315 316 return s, nil 317} 318 319func (s *Server) addRoutes() { 320 s.echo.GET("/", s.handleRoot) 321 s.echo.GET("/xrpc/_health", s.handleHealth) 322 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 323 s.echo.GET("/robots.txt", s.handleRobots) 324 325 // public 326 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 327 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 328 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 329 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 330 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 331 332 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 333 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 334 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 335 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 336 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 337 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 338 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 339 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 340 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 341 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 342 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 343 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 344 345 // authed 346 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware) 347 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware) 348 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware) 349 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware) 350 351 // repo 352 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware) 353 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware) 354 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware) 355 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware) 356 357 // stupid silly endpoints 358 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware) 359 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware) 360 361 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 362 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 363} 364 365func (s *Server) Serve(ctx context.Context) error { 366 s.addRoutes() 367 368 s.logger.Info("migrating...") 369 370 s.db.AutoMigrate( 371 &models.Actor{}, 372 &models.Repo{}, 373 &models.InviteCode{}, 374 &models.Token{}, 375 &models.RefreshToken{}, 376 &models.Block{}, 377 &models.Record{}, 378 &models.Blob{}, 379 &models.BlobPart{}, 380 ) 381 382 s.logger.Info("starting cocoon") 383 384 go func() { 385 if err := s.httpd.ListenAndServe(); err != nil { 386 panic(err) 387 } 388 }() 389 390 for _, relay := range s.config.Relays { 391 cli := xrpc.Client{Host: relay} 392 atproto.SyncRequestCrawl(context.TODO(), &cli, &atproto.SyncRequestCrawl_Input{ 393 Hostname: s.config.Hostname, 394 }) 395 } 396 397 <-ctx.Done() 398 399 fmt.Println("shut down") 400 401 return nil 402}