1package server
2
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/sha256"
	"crypto/subtle"
	"embed"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/smtp"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"text/template"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/domodwyer/mailyak/v3"
	"github.com/go-playground/validator"
	"github.com/golang-jwt/jwt/v4"
	"github.com/gorilla/sessions"
	"github.com/haileyok/cocoon/identity"
	"github.com/haileyok/cocoon/internal/db"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/oauth/client_manager"
	"github.com/haileyok/cocoon/oauth/constants"
	"github.com/haileyok/cocoon/oauth/dpop/dpop_manager"
	"github.com/haileyok/cocoon/oauth/provider"
	"github.com/haileyok/cocoon/plc"
	echo_session "github.com/labstack/echo-contrib/session"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	slogecho "github.com/samber/slog-echo"
	"gitlab.com/yawning/secp256k1-voi"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
55
const (
	// AccountSessionMaxAge is the maximum age of an account session: 30 days.
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
59
// S3Config holds the settings used to upload database backups to S3.
type S3Config struct {
	// BackupsEnabled gates the backup routine; when false no backups run.
	BackupsEnabled bool
	// Endpoint is optional. When non-empty it overrides the AWS endpoint and
	// path-style addressing is forced.
	Endpoint string
	// Region, Bucket, AccessKey and SecretKey must all be non-empty for the
	// backup routine to run.
	Region    string
	Bucket    string
	AccessKey string
	SecretKey string
}
68
// Server bundles the HTTP server, its dependencies and its configuration.
// Instances are built by New and started with Serve.
type Server struct {
	// http is the outbound HTTP client; httpd is the listening HTTP server
	// whose handler is the echo instance below.
	http  *http.Client
	httpd *http.Server
	// mail and mailLk are only set by New when every SMTP argument was
	// provided; otherwise both stay nil.
	mail          *mailyak.MailYak
	mailLk        *sync.Mutex
	echo          *echo.Echo
	db            *db.DB
	plcClient     *plc.Client
	logger        *slog.Logger
	config        *config
	// privateKey is loaded from the JWK file; its public half verifies
	// session JWTs that are not ES256K-signed.
	privateKey    *ecdsa.PrivateKey
	repoman       *RepoMan
	oauthProvider *provider.Provider
	evtman        *events.EventManager
	passport      *identity.Passport

	// dbName is the database file path read by doBackup.
	dbName   string
	s3Config *S3Config
}
88
// Args are the inputs to New.
type Args struct {
	// Addr is the listen address for the HTTP server. Required.
	Addr string
	// DbName is required; it is stored as Server.dbName, which doBackup opens
	// when backing up the database.
	DbName string
	// Logger is optional; New falls back to a text logger on stdout.
	Logger *slog.Logger
	// Version selects dev behaviour when set to "dev" (templates and static
	// files are served from disk instead of the embedded copies).
	Version string
	// Did must parse as a DID. Required.
	Did string
	// Hostname is required.
	Hostname string
	// RotationKeyPath and JwkPath are files read by New.
	RotationKeyPath string
	JwkPath         string
	// ContactEmail is required.
	ContactEmail string
	// Relays are hosts that Serve sends a crawl request to on startup.
	Relays []string
	// AdminPassword is required; it is the basic-auth password checked by the
	// admin middleware.
	AdminPassword string

	// SMTP settings. Mailing is only configured when all six are non-empty.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config may be nil, in which case no backups run.
	S3Config *S3Config

	// SessionSecret keys the cookie session store. Required.
	SessionSecret string
}
113
// config is the subset of Args that the server keeps for use at request time.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering bool
	Relays         []string
	AdminPassword  string
	SmtpEmail      string
	SmtpName       string
}
125
// CustomValidator wraps a validator.Validate so it can be installed as the
// echo instance's Validator.
type CustomValidator struct {
	validator *validator.Validate
}
129
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails. It embeds the underlying error and records the field name
// and tag of the first failed validation.
type ValidationError struct {
	error
	Field string
	Tag   string
}
135
136func (cv *CustomValidator) Validate(i any) error {
137 if err := cv.validator.Struct(i); err != nil {
138 var validateErrors validator.ValidationErrors
139 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
140 first := validateErrors[0]
141 return ValidationError{
142 error: err,
143 Field: first.Field(),
144 Tag: first.Tag(),
145 }
146 }
147
148 return err
149 }
150
151 return nil
152}
153
// templateFS holds the files under templates/, embedded at build time.
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the files under static/, embedded at build time.
//
//go:embed static/*
var staticFS embed.FS
159
// TemplateRenderer is the echo renderer for the server's HTML templates.
type TemplateRenderer struct {
	// templates is the parsed template set executed by Render.
	templates *template.Template
	// isDev makes Render re-parse templatePath on every call.
	isDev bool
	// templatePath is the glob pattern re-parsed in dev mode; it is left empty
	// otherwise.
	templatePath string
}
165
166func (s *Server) loadTemplates() {
167 absPath, _ := filepath.Abs("server/templates/*.html")
168 if s.config.Version == "dev" {
169 tmpl := template.Must(template.ParseGlob(absPath))
170 s.echo.Renderer = &TemplateRenderer{
171 templates: tmpl,
172 isDev: true,
173 templatePath: absPath,
174 }
175 } else {
176 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
177 s.echo.Renderer = &TemplateRenderer{
178 templates: tmpl,
179 isDev: false,
180 }
181 }
182}
183
184func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
185 if t.isDev {
186 tmpl, err := template.ParseGlob(t.templatePath)
187 if err != nil {
188 return err
189 }
190 t.templates = tmpl
191 }
192
193 if viewContext, isMap := data.(map[string]any); isMap {
194 viewContext["reverse"] = c.Echo().Reverse
195 }
196
197 return t.templates.ExecuteTemplate(w, name, data)
198}
199
200func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
201 return func(e echo.Context) error {
202 username, password, ok := e.Request().BasicAuth()
203 if !ok || username != "admin" || password != s.config.AdminPassword {
204 return helpers.InputError(e, to.StringPtr("Unauthorized"))
205 }
206
207 if err := next(e); err != nil {
208 e.Error(err)
209 }
210
211 return nil
212 }
213}
214
215func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
216 return func(e echo.Context) error {
217 authheader := e.Request().Header.Get("authorization")
218 if authheader == "" {
219 return e.JSON(401, map[string]string{"error": "Unauthorized"})
220 }
221
222 pts := strings.Split(authheader, " ")
223 if len(pts) != 2 {
224 return helpers.ServerError(e, nil)
225 }
226
227 // move on to oauth session middleware if this is a dpop token
228 if pts[0] == "DPoP" {
229 return next(e)
230 }
231
232 tokenstr := pts[1]
233 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
234 claims, ok := token.Claims.(jwt.MapClaims)
235 if !ok {
236 return helpers.InputError(e, to.StringPtr("InvalidToken"))
237 }
238
239 var did string
240 var repo *models.RepoActor
241
242 // service auth tokens
243 lxm, hasLxm := claims["lxm"]
244 if hasLxm {
245 pts := strings.Split(e.Request().URL.String(), "/")
246 if lxm != pts[len(pts)-1] {
247 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
248 return helpers.InputError(e, nil)
249 }
250
251 maybeDid, ok := claims["iss"].(string)
252 if !ok {
253 s.logger.Error("no iss in service auth token", "error", err)
254 return helpers.InputError(e, nil)
255 }
256 did = maybeDid
257
258 maybeRepo, err := s.getRepoActorByDid(did)
259 if err != nil {
260 s.logger.Error("error fetching repo", "error", err)
261 return helpers.ServerError(e, nil)
262 }
263 repo = maybeRepo
264 }
265
266 if token.Header["alg"] != "ES256K" {
267 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
268 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
269 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
270 }
271 return s.privateKey.Public(), nil
272 })
273 if err != nil {
274 s.logger.Error("error parsing jwt", "error", err)
275 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319
276 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
277 }
278
279 if !token.Valid {
280 return helpers.InputError(e, to.StringPtr("InvalidToken"))
281 }
282 } else {
283 kpts := strings.Split(tokenstr, ".")
284 signingInput := kpts[0] + "." + kpts[1]
285 hash := sha256.Sum256([]byte(signingInput))
286 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
287 if err != nil {
288 s.logger.Error("error decoding signature bytes", "error", err)
289 return helpers.ServerError(e, nil)
290 }
291
292 if len(sigBytes) != 64 {
293 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
294 return helpers.ServerError(e, nil)
295 }
296
297 rBytes := sigBytes[:32]
298 sBytes := sigBytes[32:]
299 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
300 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
301
302 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
303 if err != nil {
304 s.logger.Error("can't load private key", "error", err)
305 return err
306 }
307
308 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
309 if !ok {
310 s.logger.Error("error getting public key from sk")
311 return helpers.ServerError(e, nil)
312 }
313
314 verified := pubKey.VerifyRaw(hash[:], rr, ss)
315 if !verified {
316 s.logger.Error("error verifying", "error", err)
317 return helpers.ServerError(e, nil)
318 }
319 }
320
321 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
322 scope, _ := claims["scope"].(string)
323
324 if isRefresh && scope != "com.atproto.refresh" {
325 return helpers.InputError(e, to.StringPtr("InvalidToken"))
326 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
327 return helpers.InputError(e, to.StringPtr("InvalidToken"))
328 }
329
330 table := "tokens"
331 if isRefresh {
332 table = "refresh_tokens"
333 }
334
335 if isRefresh {
336 type Result struct {
337 Found bool
338 }
339 var result Result
340 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
341 if err == gorm.ErrRecordNotFound {
342 return helpers.InputError(e, to.StringPtr("InvalidToken"))
343 }
344
345 s.logger.Error("error getting token from db", "error", err)
346 return helpers.ServerError(e, nil)
347 }
348
349 if !result.Found {
350 return helpers.InputError(e, to.StringPtr("InvalidToken"))
351 }
352 }
353
354 exp, ok := claims["exp"].(float64)
355 if !ok {
356 s.logger.Error("error getting iat from token")
357 return helpers.ServerError(e, nil)
358 }
359
360 if exp < float64(time.Now().UTC().Unix()) {
361 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
362 }
363
364 if repo == nil {
365 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
366 if err != nil {
367 s.logger.Error("error fetching repo", "error", err)
368 return helpers.ServerError(e, nil)
369 }
370 repo = maybeRepo
371 did = repo.Repo.Did
372 }
373
374 e.Set("repo", repo)
375 e.Set("did", did)
376 e.Set("token", tokenstr)
377
378 if err := next(e); err != nil {
379 e.Error(err)
380 }
381
382 return nil
383 }
384}
385
386func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
387 return func(e echo.Context) error {
388 authheader := e.Request().Header.Get("authorization")
389 if authheader == "" {
390 return e.JSON(401, map[string]string{"error": "Unauthorized"})
391 }
392
393 pts := strings.Split(authheader, " ")
394 if len(pts) != 2 {
395 return helpers.ServerError(e, nil)
396 }
397
398 if pts[0] != "DPoP" {
399 return next(e)
400 }
401
402 accessToken := pts[1]
403
404 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
405 if err != nil {
406 s.logger.Error("invalid dpop proof", "error", err)
407 return helpers.InputError(e, to.StringPtr(err.Error()))
408 }
409
410 var oauthToken provider.OauthToken
411 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
412 s.logger.Error("error finding access token in db", "error", err)
413 return helpers.InputError(e, nil)
414 }
415
416 if oauthToken.Token == "" {
417 return helpers.InputError(e, to.StringPtr("InvalidToken"))
418 }
419
420 if *oauthToken.Parameters.DpopJkt != proof.JKT {
421 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
422 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
423 }
424
425 if time.Now().After(oauthToken.ExpiresAt) {
426 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
427 }
428
429 repo, err := s.getRepoActorByDid(oauthToken.Sub)
430 if err != nil {
431 s.logger.Error("could not find actor in db", "error", err)
432 return helpers.ServerError(e, nil)
433 }
434
435 nonce := s.oauthProvider.NextNonce()
436 if nonce != "" {
437 e.Response().Header().Set("DPoP-Nonce", nonce)
438 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
439 }
440
441 e.Set("repo", repo)
442 e.Set("did", repo.Repo.Did)
443 e.Set("token", accessToken)
444 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
445
446 return next(e)
447 }
448}
449
450func New(args *Args) (*Server, error) {
451 if args.Addr == "" {
452 return nil, fmt.Errorf("addr must be set")
453 }
454
455 if args.DbName == "" {
456 return nil, fmt.Errorf("db name must be set")
457 }
458
459 if args.Did == "" {
460 return nil, fmt.Errorf("cocoon did must be set")
461 }
462
463 if args.ContactEmail == "" {
464 return nil, fmt.Errorf("cocoon contact email is required")
465 }
466
467 if _, err := syntax.ParseDID(args.Did); err != nil {
468 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
469 }
470
471 if args.Hostname == "" {
472 return nil, fmt.Errorf("cocoon hostname must be set")
473 }
474
475 if args.AdminPassword == "" {
476 return nil, fmt.Errorf("admin password must be set")
477 }
478
479 if args.Logger == nil {
480 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
481 }
482
483 if args.SessionSecret == "" {
484 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
485 }
486
487 e := echo.New()
488
489 e.Pre(middleware.RemoveTrailingSlash())
490 e.Pre(slogecho.New(args.Logger))
491 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
492 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
493 AllowOrigins: []string{"*"},
494 AllowHeaders: []string{"*"},
495 AllowMethods: []string{"*"},
496 AllowCredentials: true,
497 MaxAge: 100_000_000,
498 }))
499
500 vdtor := validator.New()
501 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
502 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
503 return false
504 }
505 return true
506 })
507 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
508 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
509 return false
510 }
511 return true
512 })
513 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
514 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
515 return false
516 }
517 return true
518 })
519 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
520 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
521 return false
522 }
523 return true
524 })
525
526 e.Validator = &CustomValidator{validator: vdtor}
527
528 httpd := &http.Server{
529 Addr: args.Addr,
530 Handler: e,
531 // shitty defaults but okay for now, needed for import repo
532 ReadTimeout: 5 * time.Minute,
533 WriteTimeout: 5 * time.Minute,
534 IdleTimeout: 5 * time.Minute,
535 }
536
537 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
538 if err != nil {
539 return nil, err
540 }
541 dbw := db.NewDB(gdb)
542
543 rkbytes, err := os.ReadFile(args.RotationKeyPath)
544 if err != nil {
545 return nil, err
546 }
547
548 h := util.RobustHTTPClient()
549
550 plcClient, err := plc.NewClient(&plc.ClientArgs{
551 H: h,
552 Service: "https://plc.directory",
553 PdsHostname: args.Hostname,
554 RotationKey: rkbytes,
555 })
556 if err != nil {
557 return nil, err
558 }
559
560 jwkbytes, err := os.ReadFile(args.JwkPath)
561 if err != nil {
562 return nil, err
563 }
564
565 key, err := helpers.ParseJWKFromBytes(jwkbytes)
566 if err != nil {
567 return nil, err
568 }
569
570 var pkey ecdsa.PrivateKey
571 if err := key.Raw(&pkey); err != nil {
572 return nil, err
573 }
574
575 oauthCli := &http.Client{
576 Timeout: 10 * time.Second,
577 }
578
579 var nonceSecret []byte
580 maybeSecret, err := os.ReadFile("nonce.secret")
581 if err != nil && !os.IsNotExist(err) {
582 args.Logger.Error("error attempting to read nonce secret", "error", err)
583 } else {
584 nonceSecret = maybeSecret
585 }
586
587 s := &Server{
588 http: h,
589 httpd: httpd,
590 echo: e,
591 logger: args.Logger,
592 db: dbw,
593 plcClient: plcClient,
594 privateKey: &pkey,
595 config: &config{
596 Version: args.Version,
597 Did: args.Did,
598 Hostname: args.Hostname,
599 ContactEmail: args.ContactEmail,
600 EnforcePeering: false,
601 Relays: args.Relays,
602 AdminPassword: args.AdminPassword,
603 SmtpName: args.SmtpName,
604 SmtpEmail: args.SmtpEmail,
605 },
606 evtman: events.NewEventManager(events.NewMemPersister()),
607 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
608
609 dbName: args.DbName,
610 s3Config: args.S3Config,
611
612 oauthProvider: provider.NewProvider(provider.Args{
613 Hostname: args.Hostname,
614 ClientManagerArgs: client_manager.Args{
615 Cli: oauthCli,
616 Logger: args.Logger,
617 },
618 DpopManagerArgs: dpop_manager.Args{
619 NonceSecret: nonceSecret,
620 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
621 OnNonceSecretCreated: func(newNonce []byte) {
622 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
623 args.Logger.Error("error writing new nonce secret", "error", err)
624 }
625 },
626 Logger: args.Logger,
627 Hostname: args.Hostname,
628 },
629 }),
630 }
631
632 s.loadTemplates()
633
634 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
635
636 // TODO: should validate these args
637 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
638 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
639 } else {
640 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
641 mail.From(s.config.SmtpEmail)
642 mail.FromName(s.config.SmtpName)
643
644 s.mail = mail
645 s.mailLk = &sync.Mutex{}
646 }
647
648 return s, nil
649}
650
651func (s *Server) addRoutes() {
652 // static
653 if s.config.Version == "dev" {
654 s.echo.Static("/static", "server/static")
655 } else {
656 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
657 }
658
659 // random stuff
660 s.echo.GET("/", s.handleRoot)
661 s.echo.GET("/xrpc/_health", s.handleHealth)
662 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
663 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
664 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
665 s.echo.GET("/robots.txt", s.handleRobots)
666
667 // public
668 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
669 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
670 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
671 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
672 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
673
674 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
675 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
676 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
677 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
678 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
679 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
680 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
681 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
682 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
683 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
684 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
685 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
686
687 // account
688 s.echo.GET("/account", s.handleAccount)
689 s.echo.POST("/account/revoke", s.handleAccountRevoke)
690 s.echo.GET("/account/signin", s.handleAccountSigninGet)
691 s.echo.POST("/account/signin", s.handleAccountSigninPost)
692 s.echo.GET("/account/signout", s.handleAccountSignout)
693
694 // oauth account
695 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
696 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
697 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
698
699 // oauth authorization
700 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
701 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
702
703 // authed
704 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
705 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
706 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
707 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
708 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
709 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
710 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
711 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
712 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
713 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
714 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
715
716 // repo
717 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
718 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
719 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
720 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
721 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
722 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
723
724 // stupid silly endpoints
725 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
726 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
727
728 // are there any routes that we should be allowing without auth? i dont think so but idk
729 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
730 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
731
732 // admin routes
733 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
734 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
735}
736
737func (s *Server) Serve(ctx context.Context) error {
738 s.addRoutes()
739
740 s.logger.Info("migrating...")
741
742 s.db.AutoMigrate(
743 &models.Actor{},
744 &models.Repo{},
745 &models.InviteCode{},
746 &models.Token{},
747 &models.RefreshToken{},
748 &models.Block{},
749 &models.Record{},
750 &models.Blob{},
751 &models.BlobPart{},
752 &provider.OauthToken{},
753 &provider.OauthAuthorizationRequest{},
754 )
755
756 s.logger.Info("starting cocoon")
757
758 go func() {
759 if err := s.httpd.ListenAndServe(); err != nil {
760 panic(err)
761 }
762 }()
763
764 go s.backupRoutine()
765
766 for _, relay := range s.config.Relays {
767 cli := xrpc.Client{Host: relay}
768 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
769 Hostname: s.config.Hostname,
770 })
771 }
772
773 <-ctx.Done()
774
775 fmt.Println("shut down")
776
777 return nil
778}
779
780func (s *Server) doBackup() {
781 start := time.Now()
782
783 s.logger.Info("beginning backup to s3...")
784
785 var buf bytes.Buffer
786 if err := func() error {
787 s.logger.Info("reading database bytes...")
788 s.db.Lock()
789 defer s.db.Unlock()
790
791 sf, err := os.Open(s.dbName)
792 if err != nil {
793 return fmt.Errorf("error opening database for backup: %w", err)
794 }
795 defer sf.Close()
796
797 if _, err := io.Copy(&buf, sf); err != nil {
798 return fmt.Errorf("error reading bytes of backup db: %w", err)
799 }
800
801 return nil
802 }(); err != nil {
803 s.logger.Error("error backing up database", "error", err)
804 return
805 }
806
807 if err := func() error {
808 s.logger.Info("sending to s3...")
809
810 currTime := time.Now().Format("2006-01-02_15-04-05")
811 key := "cocoon-backup-" + currTime + ".db"
812
813 config := &aws.Config{
814 Region: aws.String(s.s3Config.Region),
815 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
816 }
817
818 if s.s3Config.Endpoint != "" {
819 config.Endpoint = aws.String(s.s3Config.Endpoint)
820 config.S3ForcePathStyle = aws.Bool(true)
821 }
822
823 sess, err := session.NewSession(config)
824 if err != nil {
825 return err
826 }
827
828 svc := s3.New(sess)
829
830 if _, err := svc.PutObject(&s3.PutObjectInput{
831 Bucket: aws.String(s.s3Config.Bucket),
832 Key: aws.String(key),
833 Body: bytes.NewReader(buf.Bytes()),
834 }); err != nil {
835 return fmt.Errorf("error uploading file to s3: %w", err)
836 }
837
838 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds())
839
840 return nil
841 }(); err != nil {
842 s.logger.Error("error uploading database backup", "error", err)
843 return
844 }
845
846 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644)
847}
848
849func (s *Server) backupRoutine() {
850 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
851 return
852 }
853
854 if s.s3Config.Region == "" {
855 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
856 return
857 }
858
859 if s.s3Config.Bucket == "" {
860 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
861 return
862 }
863
864 if s.s3Config.AccessKey == "" {
865 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
866 return
867 }
868
869 if s.s3Config.SecretKey == "" {
870 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
871 return
872 }
873
874 shouldBackupNow := false
875 lastBackupStr, err := os.ReadFile("last-backup.txt")
876 if err != nil {
877 shouldBackupNow = true
878 } else {
879 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr))
880 if err != nil {
881 shouldBackupNow = true
882 } else if time.Now().Sub(lastBackup).Seconds() > 3600 {
883 shouldBackupNow = true
884 }
885 }
886
887 if shouldBackupNow {
888 go s.doBackup()
889 }
890
891 ticker := time.NewTicker(time.Hour)
892 for range ticker.C {
893 go s.doBackup()
894 }
895}