1package server
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "errors"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "net/smtp"
11 "os"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/Azure/go-autorest/autorest/to"
17 "github.com/bluesky-social/indigo/api/atproto"
18 "github.com/bluesky-social/indigo/atproto/syntax"
19 "github.com/bluesky-social/indigo/events"
20 "github.com/bluesky-social/indigo/util"
21 "github.com/bluesky-social/indigo/xrpc"
22 "github.com/domodwyer/mailyak/v3"
23 "github.com/go-playground/validator"
24 "github.com/golang-jwt/jwt/v4"
25 "github.com/haileyok/cocoon/identity"
26 "github.com/haileyok/cocoon/internal/helpers"
27 "github.com/haileyok/cocoon/models"
28 "github.com/haileyok/cocoon/plc"
29 "github.com/labstack/echo/v4"
30 "github.com/labstack/echo/v4/middleware"
31 "github.com/lestrrat-go/jwx/v2/jwk"
32 slogecho "github.com/samber/slog-echo"
33 "gorm.io/driver/sqlite"
34 "gorm.io/gorm"
35)
36
37type Server struct {
38 http *http.Client
39 httpd *http.Server
40 mail *mailyak.MailYak
41 mailLk *sync.Mutex
42 echo *echo.Echo
43 db *gorm.DB
44 plcClient *plc.Client
45 logger *slog.Logger
46 config *config
47 privateKey *ecdsa.PrivateKey
48 repoman *RepoMan
49 evtman *events.EventManager
50 passport *identity.Passport
51}
52
// Args configures New.
//
// Addr, DbName, Did (which must parse as a DID), Hostname and ContactEmail
// are required. Logger defaults to a text logger on stdout when nil. Mail is
// enabled only when every Smtp* field is non-empty.
type Args struct {
	Addr, DbName string
	Logger       *slog.Logger

	// Service identity and metadata.
	Version, Did, Hostname string

	// Paths to key material read from disk by New.
	RotationKeyPath, JwkPath string

	ContactEmail string
	// Relays are hosts that Serve asks to crawl this server at startup.
	Relays []string

	// Outbound SMTP settings.
	SmtpUser, SmtpPass, SmtpHost, SmtpPort string
	SmtpEmail, SmtpName                    string
}
72
// config is the subset of Args (plus derived flags) that the server keeps at
// runtime after New returns.
type config struct {
	Version, Did, Hostname, ContactEmail string

	EnforcePeering bool
	Relays         []string

	// Sender identity used for outbound mail.
	SmtpEmail, SmtpName string
}
83
84type CustomValidator struct {
85 validator *validator.Validate
86}
87
88type ValidationError struct {
89 error
90 Field string
91 Tag string
92}
93
94func (cv *CustomValidator) Validate(i any) error {
95 if err := cv.validator.Struct(i); err != nil {
96 var validateErrors validator.ValidationErrors
97 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
98 first := validateErrors[0]
99 return ValidationError{
100 error: err,
101 Field: first.Field(),
102 Tag: first.Tag(),
103 }
104 }
105
106 return err
107 }
108
109 return nil
110}
111
112func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
113 return func(e echo.Context) error {
114 authheader := e.Request().Header.Get("authorization")
115 if authheader == "" {
116 return e.JSON(401, map[string]string{"error": "Unauthorized"})
117 }
118
119 pts := strings.Split(authheader, " ")
120 if len(pts) != 2 {
121 return helpers.ServerError(e, nil)
122 }
123
124 tokenstr := pts[1]
125
126 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
127 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
128 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
129 }
130
131 return s.privateKey.Public(), nil
132 })
133 if err != nil {
134 s.logger.Error("error parsing jwt", "error", err)
135 return helpers.InputError(e, to.StringPtr("InvalidToken"))
136 }
137
138 claims, ok := token.Claims.(jwt.MapClaims)
139 if !ok || !token.Valid {
140 return helpers.InputError(e, to.StringPtr("InvalidToken"))
141 }
142
143 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
144 scope := claims["scope"].(string)
145
146 if isRefresh && scope != "com.atproto.refresh" {
147 return helpers.InputError(e, to.StringPtr("InvalidToken"))
148 } else if !isRefresh && scope != "com.atproto.access" {
149 return helpers.InputError(e, to.StringPtr("InvalidToken"))
150 }
151
152 table := "tokens"
153 if isRefresh {
154 table = "refresh_tokens"
155 }
156
157 type Result struct {
158 Found bool
159 }
160 var result Result
161 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil {
162 if err == gorm.ErrRecordNotFound {
163 return helpers.InputError(e, to.StringPtr("InvalidToken"))
164 }
165
166 s.logger.Error("error getting token from db", "error", err)
167 return helpers.ServerError(e, nil)
168 }
169
170 if !result.Found {
171 return helpers.InputError(e, to.StringPtr("InvalidToken"))
172 }
173
174 exp, ok := claims["exp"].(float64)
175 if !ok {
176 s.logger.Error("error getting iat from token")
177 return helpers.ServerError(e, nil)
178 }
179
180 if exp < float64(time.Now().UTC().Unix()) {
181 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
182 }
183
184 repo, err := s.getRepoActorByDid(claims["sub"].(string))
185 if err != nil {
186 s.logger.Error("error fetching repo", "error", err)
187 return helpers.ServerError(e, nil)
188 }
189
190 e.Set("repo", repo)
191 e.Set("did", claims["sub"])
192 e.Set("token", tokenstr)
193
194 if err := next(e); err != nil {
195 e.Error(err)
196 }
197
198 return nil
199 }
200}
201
202func New(args *Args) (*Server, error) {
203 if args.Addr == "" {
204 return nil, fmt.Errorf("addr must be set")
205 }
206
207 if args.DbName == "" {
208 return nil, fmt.Errorf("db name must be set")
209 }
210
211 if args.Did == "" {
212 return nil, fmt.Errorf("cocoon did must be set")
213 }
214
215 if args.ContactEmail == "" {
216 return nil, fmt.Errorf("cocoon contact email is required")
217 }
218
219 if _, err := syntax.ParseDID(args.Did); err != nil {
220 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
221 }
222
223 if args.Hostname == "" {
224 return nil, fmt.Errorf("cocoon hostname must be set")
225 }
226
227 if args.Logger == nil {
228 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
229 }
230
231 e := echo.New()
232
233 e.Pre(middleware.RemoveTrailingSlash())
234 e.Pre(slogecho.New(args.Logger))
235 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
236 AllowOrigins: []string{"*"},
237 AllowHeaders: []string{"*"},
238 AllowMethods: []string{"*"},
239 AllowCredentials: true,
240 MaxAge: 100_000_000,
241 }))
242
243 vdtor := validator.New()
244 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
245 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
246 return false
247 }
248 return true
249 })
250 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
251 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
252 return false
253 }
254 return true
255 })
256 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
257 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
258 return false
259 }
260 return true
261 })
262 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
263 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
264 return false
265 }
266 return true
267 })
268
269 e.Validator = &CustomValidator{validator: vdtor}
270
271 httpd := &http.Server{
272 Addr: args.Addr,
273 Handler: e,
274 }
275
276 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
277 if err != nil {
278 return nil, err
279 }
280
281 rkbytes, err := os.ReadFile(args.RotationKeyPath)
282 if err != nil {
283 return nil, err
284 }
285
286 h := util.RobustHTTPClient()
287
288 plcClient, err := plc.NewClient(&plc.ClientArgs{
289 H: h,
290 Service: "https://plc.directory",
291 PdsHostname: args.Hostname,
292 RotationKey: rkbytes,
293 })
294 if err != nil {
295 return nil, err
296 }
297
298 jwkbytes, err := os.ReadFile(args.JwkPath)
299 if err != nil {
300 return nil, err
301 }
302
303 key, err := jwk.ParseKey(jwkbytes)
304 if err != nil {
305 return nil, err
306 }
307
308 var pkey ecdsa.PrivateKey
309 if err := key.Raw(&pkey); err != nil {
310 return nil, err
311 }
312
313 s := &Server{
314 http: h,
315 httpd: httpd,
316 echo: e,
317 logger: args.Logger,
318 db: db,
319 plcClient: plcClient,
320 privateKey: &pkey,
321 config: &config{
322 Version: args.Version,
323 Did: args.Did,
324 Hostname: args.Hostname,
325 ContactEmail: args.ContactEmail,
326 EnforcePeering: false,
327 Relays: args.Relays,
328 SmtpName: args.SmtpName,
329 SmtpEmail: args.SmtpEmail,
330 },
331 evtman: events.NewEventManager(events.NewMemPersister()),
332 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
333 }
334
335 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
336
337 // TODO: should validate these args
338 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
339 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
340 } else {
341 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
342 mail.From(s.config.SmtpEmail)
343 mail.FromName(s.config.SmtpName)
344
345 s.mail = mail
346 s.mailLk = &sync.Mutex{}
347 }
348
349 return s, nil
350}
351
352func (s *Server) addRoutes() {
353 // random stuff
354 s.echo.GET("/", s.handleRoot)
355 s.echo.GET("/xrpc/_health", s.handleHealth)
356 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
357 s.echo.GET("/robots.txt", s.handleRobots)
358
359 // public
360 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
361 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
362 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
363 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
364 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
365
366 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
367 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
368 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
369 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
370 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
371 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
372 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
373 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
374 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
375 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
376 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
377 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
378
379 // authed
380 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware)
381 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware)
382 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware)
383 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware)
384 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware)
385 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware)
386 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset, s.handleSessionMiddleware)
387 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware)
388 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware)
389 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware)
390
391 // repo
392 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware)
393 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware)
394 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware)
395 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware)
396 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware)
397
398 // stupid silly endpoints
399 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware)
400 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware)
401
402 // are there any routes that we should be allowing without auth? i dont think so but idk
403 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
404 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
405}
406
407func (s *Server) Serve(ctx context.Context) error {
408 s.addRoutes()
409
410 s.logger.Info("migrating...")
411
412 s.db.AutoMigrate(
413 &models.Actor{},
414 &models.Repo{},
415 &models.InviteCode{},
416 &models.Token{},
417 &models.RefreshToken{},
418 &models.Block{},
419 &models.Record{},
420 &models.Blob{},
421 &models.BlobPart{},
422 )
423
424 s.logger.Info("starting cocoon")
425
426 go func() {
427 if err := s.httpd.ListenAndServe(); err != nil {
428 panic(err)
429 }
430 }()
431
432 for _, relay := range s.config.Relays {
433 cli := xrpc.Client{Host: relay}
434 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
435 Hostname: s.config.Hostname,
436 })
437 }
438
439 <-ctx.Done()
440
441 fmt.Println("shut down")
442
443 return nil
444}