1package server
2
import (
	"context"
	"crypto/ecdsa"
	"crypto/subtle"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"net/smtp"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/domodwyer/mailyak/v3"
	"github.com/go-playground/validator"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/identity"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/plc"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"github.com/lestrrat-go/jwx/v2/jwk"
	slogecho "github.com/samber/slog-echo"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
36
37type Server struct {
38 http *http.Client
39 httpd *http.Server
40 mail *mailyak.MailYak
41 mailLk *sync.Mutex
42 echo *echo.Echo
43 db *gorm.DB
44 plcClient *plc.Client
45 logger *slog.Logger
46 config *config
47 privateKey *ecdsa.PrivateKey
48 repoman *RepoMan
49 evtman *events.EventManager
50 passport *identity.Passport
51}
52
// Args configures a Server built by New.
//
// New rejects empty Addr, DbName, Did, Hostname, ContactEmail and
// AdminPassword values. Mail delivery is only enabled when all six Smtp*
// fields are non-empty.
type Args struct {
	Addr, DbName string

	// Logger defaults to a text logger on stdout when nil.
	Logger *slog.Logger

	Version, Did, Hostname, RotationKeyPath, JwkPath, ContactEmail string

	// Relays are asked to crawl Hostname when Serve starts.
	Relays []string

	AdminPassword string

	SmtpUser, SmtpPass, SmtpHost, SmtpPort, SmtpEmail, SmtpName string
}
73
// config holds the runtime settings New copies out of Args for use by
// handlers and Serve.
type config struct {
	Version, Did, Hostname, ContactEmail string

	// EnforcePeering is always set to false by New in this file.
	EnforcePeering bool

	Relays []string

	AdminPassword string

	// Sender address and display name applied to the mailer.
	SmtpEmail, SmtpName string
}
85
86type CustomValidator struct {
87 validator *validator.Validate
88}
89
// ValidationError reports the first failing rule found by
// CustomValidator.Validate. The embedded error is the validator's complete
// error value, so Error() returns its message.
type ValidationError struct {
	error

	// Field is the name of the first field that failed validation.
	Field string
	// Tag is the struct-tag rule that Field failed.
	Tag string
}
95
96func (cv *CustomValidator) Validate(i any) error {
97 if err := cv.validator.Struct(i); err != nil {
98 var validateErrors validator.ValidationErrors
99 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
100 first := validateErrors[0]
101 return ValidationError{
102 error: err,
103 Field: first.Field(),
104 Tag: first.Tag(),
105 }
106 }
107
108 return err
109 }
110
111 return nil
112}
113
114func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
115 return func(e echo.Context) error {
116 username, password, ok := e.Request().BasicAuth()
117 if !ok || username != "admin" || password != s.config.AdminPassword {
118 return helpers.InputError(e, to.StringPtr("Unauthorized"))
119 }
120
121 if err := next(e); err != nil {
122 e.Error(err)
123 }
124
125 return nil
126 }
127}
128
129func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
130 return func(e echo.Context) error {
131 authheader := e.Request().Header.Get("authorization")
132 if authheader == "" {
133 return e.JSON(401, map[string]string{"error": "Unauthorized"})
134 }
135
136 pts := strings.Split(authheader, " ")
137 if len(pts) != 2 {
138 return helpers.ServerError(e, nil)
139 }
140
141 tokenstr := pts[1]
142
143 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
144 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
145 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
146 }
147
148 return s.privateKey.Public(), nil
149 })
150 if err != nil {
151 s.logger.Error("error parsing jwt", "error", err)
152 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319
153 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
154 }
155
156 claims, ok := token.Claims.(jwt.MapClaims)
157 if !ok || !token.Valid {
158 return helpers.InputError(e, to.StringPtr("InvalidToken"))
159 }
160
161 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
162 scope := claims["scope"].(string)
163
164 if isRefresh && scope != "com.atproto.refresh" {
165 return helpers.InputError(e, to.StringPtr("InvalidToken"))
166 } else if !isRefresh && scope != "com.atproto.access" {
167 return helpers.InputError(e, to.StringPtr("InvalidToken"))
168 }
169
170 table := "tokens"
171 if isRefresh {
172 table = "refresh_tokens"
173 }
174
175 type Result struct {
176 Found bool
177 }
178 var result Result
179 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil {
180 if err == gorm.ErrRecordNotFound {
181 return helpers.InputError(e, to.StringPtr("InvalidToken"))
182 }
183
184 s.logger.Error("error getting token from db", "error", err)
185 return helpers.ServerError(e, nil)
186 }
187
188 if !result.Found {
189 return helpers.InputError(e, to.StringPtr("InvalidToken"))
190 }
191
192 exp, ok := claims["exp"].(float64)
193 if !ok {
194 s.logger.Error("error getting iat from token")
195 return helpers.ServerError(e, nil)
196 }
197
198 if exp < float64(time.Now().UTC().Unix()) {
199 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
200 }
201
202 repo, err := s.getRepoActorByDid(claims["sub"].(string))
203 if err != nil {
204 s.logger.Error("error fetching repo", "error", err)
205 return helpers.ServerError(e, nil)
206 }
207
208 e.Set("repo", repo)
209 e.Set("did", claims["sub"])
210 e.Set("token", tokenstr)
211
212 if err := next(e); err != nil {
213 e.Error(err)
214 }
215
216 return nil
217 }
218}
219
220func New(args *Args) (*Server, error) {
221 if args.Addr == "" {
222 return nil, fmt.Errorf("addr must be set")
223 }
224
225 if args.DbName == "" {
226 return nil, fmt.Errorf("db name must be set")
227 }
228
229 if args.Did == "" {
230 return nil, fmt.Errorf("cocoon did must be set")
231 }
232
233 if args.ContactEmail == "" {
234 return nil, fmt.Errorf("cocoon contact email is required")
235 }
236
237 if _, err := syntax.ParseDID(args.Did); err != nil {
238 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
239 }
240
241 if args.Hostname == "" {
242 return nil, fmt.Errorf("cocoon hostname must be set")
243 }
244
245 if args.AdminPassword == "" {
246 return nil, fmt.Errorf("admin password must be set")
247 }
248
249 if args.Logger == nil {
250 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
251 }
252
253 e := echo.New()
254
255 e.Pre(middleware.RemoveTrailingSlash())
256 e.Pre(slogecho.New(args.Logger))
257 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
258 AllowOrigins: []string{"*"},
259 AllowHeaders: []string{"*"},
260 AllowMethods: []string{"*"},
261 AllowCredentials: true,
262 MaxAge: 100_000_000,
263 }))
264
265 vdtor := validator.New()
266 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
267 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
268 return false
269 }
270 return true
271 })
272 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
273 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
274 return false
275 }
276 return true
277 })
278 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
279 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
280 return false
281 }
282 return true
283 })
284 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
285 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
286 return false
287 }
288 return true
289 })
290
291 e.Validator = &CustomValidator{validator: vdtor}
292
293 httpd := &http.Server{
294 Addr: args.Addr,
295 Handler: e,
296 // shitty defaults but okay for now, needed for import repo
297 ReadTimeout: 5 * time.Minute,
298 WriteTimeout: 5 * time.Minute,
299 IdleTimeout: 5 * time.Minute,
300 }
301
302 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
303 if err != nil {
304 return nil, err
305 }
306
307 rkbytes, err := os.ReadFile(args.RotationKeyPath)
308 if err != nil {
309 return nil, err
310 }
311
312 h := util.RobustHTTPClient()
313
314 plcClient, err := plc.NewClient(&plc.ClientArgs{
315 H: h,
316 Service: "https://plc.directory",
317 PdsHostname: args.Hostname,
318 RotationKey: rkbytes,
319 })
320 if err != nil {
321 return nil, err
322 }
323
324 jwkbytes, err := os.ReadFile(args.JwkPath)
325 if err != nil {
326 return nil, err
327 }
328
329 key, err := jwk.ParseKey(jwkbytes)
330 if err != nil {
331 return nil, err
332 }
333
334 var pkey ecdsa.PrivateKey
335 if err := key.Raw(&pkey); err != nil {
336 return nil, err
337 }
338
339 s := &Server{
340 http: h,
341 httpd: httpd,
342 echo: e,
343 logger: args.Logger,
344 db: db,
345 plcClient: plcClient,
346 privateKey: &pkey,
347 config: &config{
348 Version: args.Version,
349 Did: args.Did,
350 Hostname: args.Hostname,
351 ContactEmail: args.ContactEmail,
352 EnforcePeering: false,
353 Relays: args.Relays,
354 AdminPassword: args.AdminPassword,
355 SmtpName: args.SmtpName,
356 SmtpEmail: args.SmtpEmail,
357 },
358 evtman: events.NewEventManager(events.NewMemPersister()),
359 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
360 }
361
362 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
363
364 // TODO: should validate these args
365 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
366 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
367 } else {
368 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
369 mail.From(s.config.SmtpEmail)
370 mail.FromName(s.config.SmtpName)
371
372 s.mail = mail
373 s.mailLk = &sync.Mutex{}
374 }
375
376 return s, nil
377}
378
379func (s *Server) addRoutes() {
380 // random stuff
381 s.echo.GET("/", s.handleRoot)
382 s.echo.GET("/xrpc/_health", s.handleHealth)
383 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
384 s.echo.GET("/robots.txt", s.handleRobots)
385
386 // public
387 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
388 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
389 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
390 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
391 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
392
393 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
394 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
395 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
396 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
397 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
398 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
399 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
400 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
401 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
402 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
403 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
404 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
405
406 // authed
407 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware)
408 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware)
409 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware)
410 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware)
411 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware)
412 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware)
413 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
414 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware)
415 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware)
416 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware)
417
418 // repo
419 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware)
420 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware)
421 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware)
422 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware)
423 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware)
424 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleSessionMiddleware)
425
426 // stupid silly endpoints
427 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware)
428 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware)
429
430 // are there any routes that we should be allowing without auth? i dont think so but idk
431 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
432 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
433
434 // admin routes
435 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
436 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
437}
438
439func (s *Server) Serve(ctx context.Context) error {
440 s.addRoutes()
441
442 s.logger.Info("migrating...")
443
444 s.db.AutoMigrate(
445 &models.Actor{},
446 &models.Repo{},
447 &models.InviteCode{},
448 &models.Token{},
449 &models.RefreshToken{},
450 &models.Block{},
451 &models.Record{},
452 &models.Blob{},
453 &models.BlobPart{},
454 )
455
456 s.logger.Info("starting cocoon")
457
458 go func() {
459 if err := s.httpd.ListenAndServe(); err != nil {
460 panic(err)
461 }
462 }()
463
464 for _, relay := range s.config.Relays {
465 cli := xrpc.Client{Host: relay}
466 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
467 Hostname: s.config.Hostname,
468 })
469 }
470
471 <-ctx.Done()
472
473 fmt.Println("shut down")
474
475 return nil
476}