1package server
2
import (
	"context"
	"crypto/ecdsa"
	"crypto/subtle"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"net/smtp"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/domodwyer/mailyak/v3"
	"github.com/go-playground/validator"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/identity"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/plc"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"github.com/lestrrat-go/jwx/v2/jwk"
	slogecho "github.com/samber/slog-echo"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
36
// Server holds everything the cocoon PDS needs at runtime: the HTTP server and
// echo router, the database handle, outbound clients (HTTP, PLC, identity
// resolution), the ECDSA signing key and the optional SMTP mailer. Construct
// it with New and run it with Serve.
type Server struct {
	// http is the outbound HTTP client; New shares it with the PLC client and
	// the identity passport.
	http *http.Client
	// httpd is the inbound HTTP server whose handler is echo.
	httpd *http.Server
	// mail is nil unless New received every SMTP argument.
	mail *mailyak.MailYak
	// mailLk is created together with mail. NOTE(review): presumably it
	// serializes use of the shared mailyak builder — confirm against the
	// mail-sending code.
	mailLk *sync.Mutex
	echo   *echo.Echo
	db     *gorm.DB
	// plcClient talks to the PLC directory configured in New.
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey is parsed from the JWK file; its public half verifies
	// session JWTs in handleSessionMiddleware.
	privateKey *ecdsa.PrivateKey
	repoman    *RepoMan
	// evtman is backed by an in-memory persister (see New).
	evtman *events.EventManager
	// passport resolves identities using the shared HTTP client and an
	// in-memory cache.
	passport *identity.Passport
}
52
// Args configures New. Addr, DbName, Did, Hostname, ContactEmail and
// AdminPassword are required. Did must also parse as a DID.
type Args struct {
	// Addr is the listen address handed to http.Server.
	Addr string
	// DbName names the database; New rejects an empty value.
	DbName string
	// Logger is optional; New substitutes a stdout text logger when nil.
	Logger  *slog.Logger
	Version string
	// Did is the server's own DID.
	Did      string
	Hostname string
	// RotationKeyPath is a file whose raw bytes are passed to the PLC client
	// as its rotation key.
	RotationKeyPath string
	// JwkPath is a JWK file that New converts into an ECDSA private key.
	JwkPath      string
	ContactEmail string
	// Relays are hosts that Serve asks to crawl this server at startup.
	Relays []string
	// AdminPassword is the Basic-auth password (username "admin") for admin
	// routes.
	AdminPassword string

	// SMTP settings. Mail is enabled only when all six are non-empty.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string
}
73
// config is the subset of Args the Server keeps after construction.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering bool
	Relays         []string
	AdminPassword  string
	// SmtpEmail and SmtpName become the mailer's From address and display name.
	SmtpEmail string
	SmtpName  string
}
85
// CustomValidator adapts a go-playground validator to echo's Validator
// interface; New installs it as e.Validator.
type CustomValidator struct {
	validator *validator.Validate
}
89
// ValidationError wraps the full validation error and records the first
// failing field together with the validation tag it failed.
type ValidationError struct {
	error
	// Field is the name of the first field that failed validation.
	Field string
	// Tag is the validation tag (e.g. "atproto-did") that Field failed.
	Tag string
}
95
96func (cv *CustomValidator) Validate(i any) error {
97 if err := cv.validator.Struct(i); err != nil {
98 var validateErrors validator.ValidationErrors
99 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
100 first := validateErrors[0]
101 return ValidationError{
102 error: err,
103 Field: first.Field(),
104 Tag: first.Tag(),
105 }
106 }
107
108 return err
109 }
110
111 return nil
112}
113
114func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
115 return func(e echo.Context) error {
116 username, password, ok := e.Request().BasicAuth()
117 if !ok || username != "admin" || password != s.config.AdminPassword {
118 return helpers.InputError(e, to.StringPtr("Unauthorized"))
119 }
120
121 if err := next(e); err != nil {
122 e.Error(err)
123 }
124
125 return nil
126 }
127}
128
129func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
130 return func(e echo.Context) error {
131 authheader := e.Request().Header.Get("authorization")
132 if authheader == "" {
133 return e.JSON(401, map[string]string{"error": "Unauthorized"})
134 }
135
136 pts := strings.Split(authheader, " ")
137 if len(pts) != 2 {
138 return helpers.ServerError(e, nil)
139 }
140
141 tokenstr := pts[1]
142
143 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
144 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
145 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
146 }
147
148 return s.privateKey.Public(), nil
149 })
150 if err != nil {
151 s.logger.Error("error parsing jwt", "error", err)
152 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319
153 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
154 }
155
156 claims, ok := token.Claims.(jwt.MapClaims)
157 if !ok || !token.Valid {
158 return helpers.InputError(e, to.StringPtr("InvalidToken"))
159 }
160
161 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
162 scope := claims["scope"].(string)
163
164 if isRefresh && scope != "com.atproto.refresh" {
165 return helpers.InputError(e, to.StringPtr("InvalidToken"))
166 } else if !isRefresh && scope != "com.atproto.access" {
167 return helpers.InputError(e, to.StringPtr("InvalidToken"))
168 }
169
170 table := "tokens"
171 if isRefresh {
172 table = "refresh_tokens"
173 }
174
175 type Result struct {
176 Found bool
177 }
178 var result Result
179 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil {
180 if err == gorm.ErrRecordNotFound {
181 return helpers.InputError(e, to.StringPtr("InvalidToken"))
182 }
183
184 s.logger.Error("error getting token from db", "error", err)
185 return helpers.ServerError(e, nil)
186 }
187
188 if !result.Found {
189 return helpers.InputError(e, to.StringPtr("InvalidToken"))
190 }
191
192 exp, ok := claims["exp"].(float64)
193 if !ok {
194 s.logger.Error("error getting iat from token")
195 return helpers.ServerError(e, nil)
196 }
197
198 if exp < float64(time.Now().UTC().Unix()) {
199 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
200 }
201
202 repo, err := s.getRepoActorByDid(claims["sub"].(string))
203 if err != nil {
204 s.logger.Error("error fetching repo", "error", err)
205 return helpers.ServerError(e, nil)
206 }
207
208 e.Set("repo", repo)
209 e.Set("did", claims["sub"])
210 e.Set("token", tokenstr)
211
212 if err := next(e); err != nil {
213 e.Error(err)
214 }
215
216 return nil
217 }
218}
219
220func New(args *Args) (*Server, error) {
221 if args.Addr == "" {
222 return nil, fmt.Errorf("addr must be set")
223 }
224
225 if args.DbName == "" {
226 return nil, fmt.Errorf("db name must be set")
227 }
228
229 if args.Did == "" {
230 return nil, fmt.Errorf("cocoon did must be set")
231 }
232
233 if args.ContactEmail == "" {
234 return nil, fmt.Errorf("cocoon contact email is required")
235 }
236
237 if _, err := syntax.ParseDID(args.Did); err != nil {
238 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
239 }
240
241 if args.Hostname == "" {
242 return nil, fmt.Errorf("cocoon hostname must be set")
243 }
244
245 if args.AdminPassword == "" {
246 return nil, fmt.Errorf("admin password must be set")
247 }
248
249 if args.Logger == nil {
250 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
251 }
252
253 e := echo.New()
254
255 e.Pre(middleware.RemoveTrailingSlash())
256 e.Pre(slogecho.New(args.Logger))
257 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
258 AllowOrigins: []string{"*"},
259 AllowHeaders: []string{"*"},
260 AllowMethods: []string{"*"},
261 AllowCredentials: true,
262 MaxAge: 100_000_000,
263 }))
264
265 vdtor := validator.New()
266 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
267 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
268 return false
269 }
270 return true
271 })
272 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
273 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
274 return false
275 }
276 return true
277 })
278 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
279 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
280 return false
281 }
282 return true
283 })
284 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
285 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
286 return false
287 }
288 return true
289 })
290
291 e.Validator = &CustomValidator{validator: vdtor}
292
293 httpd := &http.Server{
294 Addr: args.Addr,
295 Handler: e,
296 }
297
298 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
299 if err != nil {
300 return nil, err
301 }
302
303 rkbytes, err := os.ReadFile(args.RotationKeyPath)
304 if err != nil {
305 return nil, err
306 }
307
308 h := util.RobustHTTPClient()
309
310 plcClient, err := plc.NewClient(&plc.ClientArgs{
311 H: h,
312 Service: "https://plc.directory",
313 PdsHostname: args.Hostname,
314 RotationKey: rkbytes,
315 })
316 if err != nil {
317 return nil, err
318 }
319
320 jwkbytes, err := os.ReadFile(args.JwkPath)
321 if err != nil {
322 return nil, err
323 }
324
325 key, err := jwk.ParseKey(jwkbytes)
326 if err != nil {
327 return nil, err
328 }
329
330 var pkey ecdsa.PrivateKey
331 if err := key.Raw(&pkey); err != nil {
332 return nil, err
333 }
334
335 s := &Server{
336 http: h,
337 httpd: httpd,
338 echo: e,
339 logger: args.Logger,
340 db: db,
341 plcClient: plcClient,
342 privateKey: &pkey,
343 config: &config{
344 Version: args.Version,
345 Did: args.Did,
346 Hostname: args.Hostname,
347 ContactEmail: args.ContactEmail,
348 EnforcePeering: false,
349 Relays: args.Relays,
350 AdminPassword: args.AdminPassword,
351 SmtpName: args.SmtpName,
352 SmtpEmail: args.SmtpEmail,
353 },
354 evtman: events.NewEventManager(events.NewMemPersister()),
355 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
356 }
357
358 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
359
360 // TODO: should validate these args
361 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
362 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
363 } else {
364 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
365 mail.From(s.config.SmtpEmail)
366 mail.FromName(s.config.SmtpName)
367
368 s.mail = mail
369 s.mailLk = &sync.Mutex{}
370 }
371
372 return s, nil
373}
374
375func (s *Server) addRoutes() {
376 // random stuff
377 s.echo.GET("/", s.handleRoot)
378 s.echo.GET("/xrpc/_health", s.handleHealth)
379 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
380 s.echo.GET("/robots.txt", s.handleRobots)
381
382 // public
383 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
384 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
385 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
386 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
387 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
388
389 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
390 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
391 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
392 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
393 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
394 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
395 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
396 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
397 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
398 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
399 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
400 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
401
402 // authed
403 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware)
404 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware)
405 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware)
406 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware)
407 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware)
408 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware)
409 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
410 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware)
411 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware)
412 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware)
413
414 // repo
415 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware)
416 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware)
417 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware)
418 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware)
419 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware)
420
421 // stupid silly endpoints
422 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware)
423 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware)
424
425 // are there any routes that we should be allowing without auth? i dont think so but idk
426 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
427 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
428
429 // admin routes
430 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
431 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
432}
433
434func (s *Server) Serve(ctx context.Context) error {
435 s.addRoutes()
436
437 s.logger.Info("migrating...")
438
439 s.db.AutoMigrate(
440 &models.Actor{},
441 &models.Repo{},
442 &models.InviteCode{},
443 &models.Token{},
444 &models.RefreshToken{},
445 &models.Block{},
446 &models.Record{},
447 &models.Blob{},
448 &models.BlobPart{},
449 )
450
451 s.logger.Info("starting cocoon")
452
453 go func() {
454 if err := s.httpd.ListenAndServe(); err != nil {
455 panic(err)
456 }
457 }()
458
459 for _, relay := range s.config.Relays {
460 cli := xrpc.Client{Host: relay}
461 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
462 Hostname: s.config.Hostname,
463 })
464 }
465
466 <-ctx.Done()
467
468 fmt.Println("shut down")
469
470 return nil
471}