1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 echo_session "github.com/labstack/echo-contrib/session"
42 "github.com/labstack/echo/v4"
43 "github.com/labstack/echo/v4/middleware"
44 slogecho "github.com/samber/slog-echo"
45 "gorm.io/driver/sqlite"
46 "gorm.io/gorm"
47)
48
const (
	// AccountSessionMaxAge is a 30-day duration. Presumably the lifetime of
	// web account sessions; its use site is outside this section — confirm.
	AccountSessionMaxAge = 30 * 24 * time.Hour // 30 days
)
52
// S3Config carries the S3 settings used by the database backup routine.
type S3Config struct {
	// BackupsEnabled gates the backup routine; when false it exits
	// immediately.
	BackupsEnabled bool
	// Endpoint is optional. When non-empty it overrides the default AWS
	// endpoint and forces path-style addressing.
	Endpoint string
	// Region, Bucket, AccessKey and SecretKey must all be non-empty or the
	// backup routine logs a warning and does not run.
	Region    string
	Bucket    string
	AccessKey string
	SecretKey string
}
61
// Server bundles the HTTP stack, storage and identity/OAuth helpers that make
// up a running instance. It is built by New and started with Serve.
type Server struct {
	// http is the shared outbound HTTP client (also handed to the PLC client
	// and the identity passport in New).
	http *http.Client
	// httpd is the listening HTTP server; its handler is echo.
	httpd *http.Server
	// mail is nil unless every SMTP argument was supplied to New.
	mail *mailyak.MailYak
	// mailLk is created together with mail; presumably it serializes use of
	// the shared mail builder — confirm at the call sites.
	mailLk        *sync.Mutex
	echo          *echo.Echo
	db            *db.DB
	plcClient     *plc.Client
	logger        *slog.Logger
	config        *config
	// privateKey is the ECDSA key extracted from the JWK file given to New.
	privateKey    *ecdsa.PrivateKey
	repoman       *RepoMan
	oauthProvider *provider.Provider
	evtman        *events.EventManager
	passport      *identity.Passport

	// dbName is the database file path that doBackup reads.
	dbName string
	// s3Config may be nil, in which case backups are disabled.
	s3Config *S3Config
}
81
// Args is the set of inputs accepted by New.
//
// Addr, DbName, Did, Hostname, ContactEmail, AdminPassword and SessionSecret
// are required by New. Logger is optional and defaults to a text logger on
// stdout.
type Args struct {
	Addr   string
	DbName string
	Logger *slog.Logger
	// Version selects template/static loading: "dev" reads from disk, any
	// other value uses the embedded files.
	Version  string
	Did      string
	Hostname string
	// RotationKeyPath and JwkPath are file paths read by New.
	RotationKeyPath string
	JwkPath         string
	ContactEmail    string
	// Relays are the relay hosts that Serve asks to crawl this server.
	Relays        []string
	AdminPassword string

	// SMTP settings. If any one of them is empty, mailing is disabled and a
	// warning is logged.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config is optional; nil disables backups.
	S3Config *S3Config

	// SessionSecret keys the cookie session store.
	SessionSecret string

	DefaultAtprotoProxy string
}
108
// config is the subset of Args that the server keeps around after New returns.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering      bool
	Relays              []string
	AdminPassword       string
	SmtpEmail           string
	SmtpName            string
	DefaultAtprotoProxy string
}
121
// CustomValidator wraps a validator.Validate; New installs it as the echo
// instance's Validator.
type CustomValidator struct {
	validator *validator.Validate
}
125
// ValidationError is returned by CustomValidator.Validate. It embeds the
// underlying validation error and exposes the first failing field and the tag
// that rejected it.
type ValidationError struct {
	error
	// Field is the name of the first field that failed validation.
	Field string
	// Tag is the validation tag that the field failed.
	Tag string
}
131
132func (cv *CustomValidator) Validate(i any) error {
133 if err := cv.validator.Struct(i); err != nil {
134 var validateErrors validator.ValidationErrors
135 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
136 first := validateErrors[0]
137 return ValidationError{
138 error: err,
139 Field: first.Field(),
140 Tag: first.Tag(),
141 }
142 }
143
144 return err
145 }
146
147 return nil
148}
149
// templateFS holds the embedded contents of the templates directory; it is
// the template source when the server is not running as version "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the embedded contents of the static directory; it backs the
// /static/* route when the server is not running as version "dev".
//
//go:embed static/*
var staticFS embed.FS
155
// TemplateRenderer is the echo renderer used for the server's pages.
type TemplateRenderer struct {
	// templates is the parsed template set executed by Render.
	templates *template.Template
	// isDev makes Render re-parse the templates from templatePath on every
	// call.
	isDev bool
	// templatePath is the glob pattern re-parsed in dev mode; it is left
	// empty otherwise.
	templatePath string
}
161
162func (s *Server) loadTemplates() {
163 absPath, _ := filepath.Abs("server/templates/*.html")
164 if s.config.Version == "dev" {
165 tmpl := template.Must(template.ParseGlob(absPath))
166 s.echo.Renderer = &TemplateRenderer{
167 templates: tmpl,
168 isDev: true,
169 templatePath: absPath,
170 }
171 } else {
172 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
173 s.echo.Renderer = &TemplateRenderer{
174 templates: tmpl,
175 isDev: false,
176 }
177 }
178}
179
180func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
181 if t.isDev {
182 tmpl, err := template.ParseGlob(t.templatePath)
183 if err != nil {
184 return err
185 }
186 t.templates = tmpl
187 }
188
189 if viewContext, isMap := data.(map[string]any); isMap {
190 viewContext["reverse"] = c.Echo().Reverse
191 }
192
193 return t.templates.ExecuteTemplate(w, name, data)
194}
195
196func New(args *Args) (*Server, error) {
197 if args.Addr == "" {
198 return nil, fmt.Errorf("addr must be set")
199 }
200
201 if args.DbName == "" {
202 return nil, fmt.Errorf("db name must be set")
203 }
204
205 if args.Did == "" {
206 return nil, fmt.Errorf("cocoon did must be set")
207 }
208
209 if args.ContactEmail == "" {
210 return nil, fmt.Errorf("cocoon contact email is required")
211 }
212
213 if _, err := syntax.ParseDID(args.Did); err != nil {
214 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
215 }
216
217 if args.Hostname == "" {
218 return nil, fmt.Errorf("cocoon hostname must be set")
219 }
220
221 if args.AdminPassword == "" {
222 return nil, fmt.Errorf("admin password must be set")
223 }
224
225 if args.Logger == nil {
226 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
227 }
228
229 if args.SessionSecret == "" {
230 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
231 }
232
233 e := echo.New()
234
235 e.Pre(middleware.RemoveTrailingSlash())
236 e.Pre(slogecho.New(args.Logger))
237 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
238 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
239 AllowOrigins: []string{"*"},
240 AllowHeaders: []string{"*"},
241 AllowMethods: []string{"*"},
242 AllowCredentials: true,
243 MaxAge: 100_000_000,
244 }))
245
246 vdtor := validator.New()
247 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
248 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
249 return false
250 }
251 return true
252 })
253 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
254 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
255 return false
256 }
257 return true
258 })
259 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
260 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
261 return false
262 }
263 return true
264 })
265 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
266 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
267 return false
268 }
269 return true
270 })
271
272 e.Validator = &CustomValidator{validator: vdtor}
273
274 httpd := &http.Server{
275 Addr: args.Addr,
276 Handler: e,
277 // shitty defaults but okay for now, needed for import repo
278 ReadTimeout: 5 * time.Minute,
279 WriteTimeout: 5 * time.Minute,
280 IdleTimeout: 5 * time.Minute,
281 }
282
283 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
284 if err != nil {
285 return nil, err
286 }
287 dbw := db.NewDB(gdb)
288
289 rkbytes, err := os.ReadFile(args.RotationKeyPath)
290 if err != nil {
291 return nil, err
292 }
293
294 h := util.RobustHTTPClient()
295
296 plcClient, err := plc.NewClient(&plc.ClientArgs{
297 H: h,
298 Service: "https://plc.directory",
299 PdsHostname: args.Hostname,
300 RotationKey: rkbytes,
301 })
302 if err != nil {
303 return nil, err
304 }
305
306 jwkbytes, err := os.ReadFile(args.JwkPath)
307 if err != nil {
308 return nil, err
309 }
310
311 key, err := helpers.ParseJWKFromBytes(jwkbytes)
312 if err != nil {
313 return nil, err
314 }
315
316 var pkey ecdsa.PrivateKey
317 if err := key.Raw(&pkey); err != nil {
318 return nil, err
319 }
320
321 oauthCli := &http.Client{
322 Timeout: 10 * time.Second,
323 }
324
325 var nonceSecret []byte
326 maybeSecret, err := os.ReadFile("nonce.secret")
327 if err != nil && !os.IsNotExist(err) {
328 args.Logger.Error("error attempting to read nonce secret", "error", err)
329 } else {
330 nonceSecret = maybeSecret
331 }
332
333 s := &Server{
334 http: h,
335 httpd: httpd,
336 echo: e,
337 logger: args.Logger,
338 db: dbw,
339 plcClient: plcClient,
340 privateKey: &pkey,
341 config: &config{
342 Version: args.Version,
343 Did: args.Did,
344 Hostname: args.Hostname,
345 ContactEmail: args.ContactEmail,
346 EnforcePeering: false,
347 Relays: args.Relays,
348 AdminPassword: args.AdminPassword,
349 SmtpName: args.SmtpName,
350 SmtpEmail: args.SmtpEmail,
351 DefaultAtprotoProxy: args.DefaultAtprotoProxy,
352 },
353 evtman: events.NewEventManager(events.NewMemPersister()),
354 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
355
356 dbName: args.DbName,
357 s3Config: args.S3Config,
358
359 oauthProvider: provider.NewProvider(provider.Args{
360 Hostname: args.Hostname,
361 ClientManagerArgs: client.ManagerArgs{
362 Cli: oauthCli,
363 Logger: args.Logger,
364 },
365 DpopManagerArgs: dpop.ManagerArgs{
366 NonceSecret: nonceSecret,
367 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
368 OnNonceSecretCreated: func(newNonce []byte) {
369 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
370 args.Logger.Error("error writing new nonce secret", "error", err)
371 }
372 },
373 Logger: args.Logger,
374 Hostname: args.Hostname,
375 },
376 }),
377 }
378
379 s.loadTemplates()
380
381 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
382
383 // TODO: should validate these args
384 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
385 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
386 } else {
387 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
388 mail.From(s.config.SmtpEmail)
389 mail.FromName(s.config.SmtpName)
390
391 s.mail = mail
392 s.mailLk = &sync.Mutex{}
393 }
394
395 return s, nil
396}
397
398func (s *Server) addRoutes() {
399 // static
400 if s.config.Version == "dev" {
401 s.echo.Static("/static", "server/static")
402 } else {
403 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
404 }
405
406 // random stuff
407 s.echo.GET("/", s.handleRoot)
408 s.echo.GET("/xrpc/_health", s.handleHealth)
409 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
410 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
411 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
412 s.echo.GET("/robots.txt", s.handleRobots)
413
414 // public
415 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
416 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
417 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
418 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
419 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
420
421 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
422 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
423 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
424 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
425 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
426 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
427 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
428 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
429 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
430 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
431 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
432 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
433
434 // account
435 s.echo.GET("/account", s.handleAccount)
436 s.echo.POST("/account/revoke", s.handleAccountRevoke)
437 s.echo.GET("/account/signin", s.handleAccountSigninGet)
438 s.echo.POST("/account/signin", s.handleAccountSigninPost)
439 s.echo.GET("/account/signout", s.handleAccountSignout)
440
441 // oauth account
442 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
443 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
444 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
445
446 // oauth authorization
447 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
448 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
449
450 // authed
451 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
452 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
453 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
454 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
455 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
456 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
457 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
458 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
459 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
460 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
461 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
462 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
463
464 // repo
465 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
466 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
467 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
468 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
469 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
470 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
471
472 // stupid silly endpoints
473 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
474 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
475
476 // admin routes
477 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
478 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
479
480 // are there any routes that we should be allowing without auth? i dont think so but idk
481 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
482 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
483}
484
485func (s *Server) Serve(ctx context.Context) error {
486 s.addRoutes()
487
488 s.logger.Info("migrating...")
489
490 s.db.AutoMigrate(
491 &models.Actor{},
492 &models.Repo{},
493 &models.InviteCode{},
494 &models.Token{},
495 &models.RefreshToken{},
496 &models.Block{},
497 &models.Record{},
498 &models.Blob{},
499 &models.BlobPart{},
500 &provider.OauthToken{},
501 &provider.OauthAuthorizationRequest{},
502 )
503
504 s.logger.Info("starting cocoon")
505
506 go func() {
507 if err := s.httpd.ListenAndServe(); err != nil {
508 panic(err)
509 }
510 }()
511
512 go s.backupRoutine()
513
514 for _, relay := range s.config.Relays {
515 cli := xrpc.Client{Host: relay}
516 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
517 Hostname: s.config.Hostname,
518 })
519 }
520
521 <-ctx.Done()
522
523 fmt.Println("shut down")
524
525 return nil
526}
527
528func (s *Server) doBackup() {
529 start := time.Now()
530
531 s.logger.Info("beginning backup to s3...")
532
533 var buf bytes.Buffer
534 if err := func() error {
535 s.logger.Info("reading database bytes...")
536 s.db.Lock()
537 defer s.db.Unlock()
538
539 sf, err := os.Open(s.dbName)
540 if err != nil {
541 return fmt.Errorf("error opening database for backup: %w", err)
542 }
543 defer sf.Close()
544
545 if _, err := io.Copy(&buf, sf); err != nil {
546 return fmt.Errorf("error reading bytes of backup db: %w", err)
547 }
548
549 return nil
550 }(); err != nil {
551 s.logger.Error("error backing up database", "error", err)
552 return
553 }
554
555 if err := func() error {
556 s.logger.Info("sending to s3...")
557
558 currTime := time.Now().Format("2006-01-02_15-04-05")
559 key := "cocoon-backup-" + currTime + ".db"
560
561 config := &aws.Config{
562 Region: aws.String(s.s3Config.Region),
563 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
564 }
565
566 if s.s3Config.Endpoint != "" {
567 config.Endpoint = aws.String(s.s3Config.Endpoint)
568 config.S3ForcePathStyle = aws.Bool(true)
569 }
570
571 sess, err := session.NewSession(config)
572 if err != nil {
573 return err
574 }
575
576 svc := s3.New(sess)
577
578 if _, err := svc.PutObject(&s3.PutObjectInput{
579 Bucket: aws.String(s.s3Config.Bucket),
580 Key: aws.String(key),
581 Body: bytes.NewReader(buf.Bytes()),
582 }); err != nil {
583 return fmt.Errorf("error uploading file to s3: %w", err)
584 }
585
586 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds())
587
588 return nil
589 }(); err != nil {
590 s.logger.Error("error uploading database backup", "error", err)
591 return
592 }
593
594 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644)
595}
596
597func (s *Server) backupRoutine() {
598 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
599 return
600 }
601
602 if s.s3Config.Region == "" {
603 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
604 return
605 }
606
607 if s.s3Config.Bucket == "" {
608 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
609 return
610 }
611
612 if s.s3Config.AccessKey == "" {
613 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
614 return
615 }
616
617 if s.s3Config.SecretKey == "" {
618 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
619 return
620 }
621
622 shouldBackupNow := false
623 lastBackupStr, err := os.ReadFile("last-backup.txt")
624 if err != nil {
625 shouldBackupNow = true
626 } else {
627 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr))
628 if err != nil {
629 shouldBackupNow = true
630 } else if time.Now().Sub(lastBackup).Seconds() > 3600 {
631 shouldBackupNow = true
632 }
633 }
634
635 if shouldBackupNow {
636 go s.doBackup()
637 }
638
639 ticker := time.NewTicker(time.Hour)
640 for range ticker.C {
641 go s.doBackup()
642 }
643}