1package state
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log"
9 "log/slog"
10 "net/http"
11 "strings"
12 "time"
13
14 comatproto "github.com/bluesky-social/indigo/api/atproto"
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 lexutil "github.com/bluesky-social/indigo/lex/util"
17 securejoin "github.com/cyphar/filepath-securejoin"
18 "github.com/go-chi/chi/v5"
19 "github.com/posthog/posthog-go"
20 "tangled.org/core/api/tangled"
21 "tangled.org/core/appview"
22 "tangled.org/core/appview/cache"
23 "tangled.org/core/appview/cache/session"
24 "tangled.org/core/appview/config"
25 "tangled.org/core/appview/db"
26 "tangled.org/core/appview/models"
27 "tangled.org/core/appview/notify"
28 dbnotify "tangled.org/core/appview/notify/db"
29 phnotify "tangled.org/core/appview/notify/posthog"
30 "tangled.org/core/appview/oauth"
31 "tangled.org/core/appview/pages"
32 "tangled.org/core/appview/reporesolver"
33 "tangled.org/core/appview/validator"
34 xrpcclient "tangled.org/core/appview/xrpcclient"
35 "tangled.org/core/eventconsumer"
36 "tangled.org/core/idresolver"
37 "tangled.org/core/jetstream"
38 tlog "tangled.org/core/log"
39 "tangled.org/core/rbac"
40 "tangled.org/core/tid"
41)
42
// State holds the long-lived dependencies shared by the appview's HTTP
// handlers. It is constructed by Make and released by Close.
type State struct {
	// db is the appview database, opened from config.Core.DbPath.
	db *db.DB
	// notifier always includes the database notifier, plus PostHog outside
	// dev mode (see Make).
	notifier notify.Notifier
	// oauth looks up the signed-in user and builds authorized PDS and
	// service-auth clients.
	oauth *oauth.OAuth
	// enforcer performs RBAC/ACL checks such as "repo:create".
	enforcer *rbac.Enforcer
	// pages renders HTML pages, fragments and notices.
	pages *pages.Pages
	// sess is the session store built on the Redis-backed cache.
	sess *session.SessionStore
	// idResolver resolves identities; Redis-backed when available, otherwise
	// the default resolver.
	idResolver *idresolver.Resolver
	// posthog is the analytics client.
	posthog posthog.Client
	// jc is the jetstream client that ingests tangled records.
	jc *jetstream.JetstreamClient
	// config is the appview configuration.
	config *config.Config
	// repoResolver resolves repositories for repo-scoped routes.
	repoResolver *reporesolver.RepoResolver
	// knotstream consumes events from knots.
	knotstream *eventconsumer.Consumer
	// spindlestream consumes events from spindles.
	spindlestream *eventconsumer.Consumer
	// logger is the structured logger handlers derive from.
	logger *slog.Logger
	// validator validates records before they are stored.
	validator *validator.Validator
}
60
61func Make(ctx context.Context, config *config.Config) (*State, error) {
62 d, err := db.Make(config.Core.DbPath)
63 if err != nil {
64 return nil, fmt.Errorf("failed to create db: %w", err)
65 }
66
67 enforcer, err := rbac.NewEnforcer(config.Core.DbPath)
68 if err != nil {
69 return nil, fmt.Errorf("failed to create enforcer: %w", err)
70 }
71
72 res, err := idresolver.RedisResolver(config.Redis.ToURL())
73 if err != nil {
74 log.Printf("failed to create redis resolver: %v", err)
75 res = idresolver.DefaultResolver()
76 }
77
78 pgs := pages.NewPages(config, res)
79 cache := cache.New(config.Redis.Addr)
80 sess := session.New(cache)
81 oauth := oauth.NewOAuth(config, sess)
82 validator := validator.New(d, res, enforcer)
83
84 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint})
85 if err != nil {
86 return nil, fmt.Errorf("failed to create posthog client: %w", err)
87 }
88
89 repoResolver := reporesolver.New(config, enforcer, res, d)
90
91 wrapper := db.DbWrapper{Execer: d}
92 jc, err := jetstream.NewJetstreamClient(
93 config.Jetstream.Endpoint,
94 "appview",
95 []string{
96 tangled.GraphFollowNSID,
97 tangled.FeedStarNSID,
98 tangled.PublicKeyNSID,
99 tangled.RepoArtifactNSID,
100 tangled.ActorProfileNSID,
101 tangled.SpindleMemberNSID,
102 tangled.SpindleNSID,
103 tangled.StringNSID,
104 tangled.RepoIssueNSID,
105 tangled.RepoIssueCommentNSID,
106 tangled.LabelDefinitionNSID,
107 tangled.LabelOpNSID,
108 },
109 nil,
110 slog.Default(),
111 wrapper,
112 false,
113
114 // in-memory filter is inapplicalble to appview so
115 // we'll never log dids anyway.
116 false,
117 )
118 if err != nil {
119 return nil, fmt.Errorf("failed to create jetstream client: %w", err)
120 }
121
122 if err := BackfillDefaultDefs(d, res); err != nil {
123 return nil, fmt.Errorf("failed to backfill default label defs: %w", err)
124 }
125
126 ingester := appview.Ingester{
127 Db: wrapper,
128 Enforcer: enforcer,
129 IdResolver: res,
130 Config: config,
131 Logger: tlog.New("ingester"),
132 Validator: validator,
133 }
134 err = jc.StartJetstream(ctx, ingester.Ingest())
135 if err != nil {
136 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err)
137 }
138
139 knotstream, err := Knotstream(ctx, config, d, enforcer, posthog)
140 if err != nil {
141 return nil, fmt.Errorf("failed to start knotstream consumer: %w", err)
142 }
143 knotstream.Start(ctx)
144
145 spindlestream, err := Spindlestream(ctx, config, d, enforcer)
146 if err != nil {
147 return nil, fmt.Errorf("failed to start spindlestream consumer: %w", err)
148 }
149 spindlestream.Start(ctx)
150
151 var notifiers []notify.Notifier
152
153 // Always add the database notifier
154 notifiers = append(notifiers, dbnotify.NewDatabaseNotifier(d, res))
155
156 // Add other notifiers in production only
157 if !config.Core.Dev {
158 notifiers = append(notifiers, phnotify.NewPosthogNotifier(posthog))
159 }
160 notifier := notify.NewMergedNotifier(notifiers...)
161
162 state := &State{
163 d,
164 notifier,
165 oauth,
166 enforcer,
167 pgs,
168 sess,
169 res,
170 posthog,
171 jc,
172 config,
173 repoResolver,
174 knotstream,
175 spindlestream,
176 slog.Default(),
177 validator,
178 }
179
180 return state, nil
181}
182
183func (s *State) Close() error {
184 // other close up logic goes here
185 return s.db.Close()
186}
187
188func (s *State) Favicon(w http.ResponseWriter, r *http.Request) {
189 w.Header().Set("Content-Type", "image/svg+xml")
190 w.Header().Set("Cache-Control", "public, max-age=31536000") // one year
191 w.Header().Set("ETag", `"favicon-svg-v1"`)
192
193 if match := r.Header.Get("If-None-Match"); match == `"favicon-svg-v1"` {
194 w.WriteHeader(http.StatusNotModified)
195 return
196 }
197
198 s.pages.Favicon(w)
199}
200
201func (s *State) TermsOfService(w http.ResponseWriter, r *http.Request) {
202 user := s.oauth.GetUser(r)
203 s.pages.TermsOfService(w, pages.TermsOfServiceParams{
204 LoggedInUser: user,
205 })
206}
207
208func (s *State) PrivacyPolicy(w http.ResponseWriter, r *http.Request) {
209 user := s.oauth.GetUser(r)
210 s.pages.PrivacyPolicy(w, pages.PrivacyPolicyParams{
211 LoggedInUser: user,
212 })
213}
214
215func (s *State) Brand(w http.ResponseWriter, r *http.Request) {
216 user := s.oauth.GetUser(r)
217 s.pages.Brand(w, pages.BrandParams{
218 LoggedInUser: user,
219 })
220}
221
222func (s *State) HomeOrTimeline(w http.ResponseWriter, r *http.Request) {
223 if s.oauth.GetUser(r) != nil {
224 s.Timeline(w, r)
225 return
226 }
227 s.Home(w, r)
228}
229
230func (s *State) Timeline(w http.ResponseWriter, r *http.Request) {
231 user := s.oauth.GetUser(r)
232
233 var userDid string
234 if user != nil {
235 userDid = user.Did
236 }
237 timeline, err := db.MakeTimeline(s.db, 50, userDid)
238 if err != nil {
239 log.Println(err)
240 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
241 }
242
243 repos, err := db.GetTopStarredReposLastWeek(s.db)
244 if err != nil {
245 log.Println(err)
246 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
247 return
248 }
249
250 s.pages.Timeline(w, pages.TimelineParams{
251 LoggedInUser: user,
252 Timeline: timeline,
253 Repos: repos,
254 })
255}
256
257func (s *State) UpgradeBanner(w http.ResponseWriter, r *http.Request) {
258 user := s.oauth.GetUser(r)
259 if user == nil {
260 return
261 }
262
263 l := s.logger.With("handler", "UpgradeBanner")
264 l = l.With("did", user.Did)
265 l = l.With("handle", user.Handle)
266
267 regs, err := db.GetRegistrations(
268 s.db,
269 db.FilterEq("did", user.Did),
270 db.FilterEq("needs_upgrade", 1),
271 )
272 if err != nil {
273 l.Error("non-fatal: failed to get registrations", "err", err)
274 }
275
276 spindles, err := db.GetSpindles(
277 s.db,
278 db.FilterEq("owner", user.Did),
279 db.FilterEq("needs_upgrade", 1),
280 )
281 if err != nil {
282 l.Error("non-fatal: failed to get spindles", "err", err)
283 }
284
285 if regs == nil && spindles == nil {
286 return
287 }
288
289 s.pages.UpgradeBanner(w, pages.UpgradeBannerParams{
290 Registrations: regs,
291 Spindles: spindles,
292 })
293}
294
295func (s *State) Home(w http.ResponseWriter, r *http.Request) {
296 timeline, err := db.MakeTimeline(s.db, 5, "")
297 if err != nil {
298 log.Println(err)
299 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
300 return
301 }
302
303 repos, err := db.GetTopStarredReposLastWeek(s.db)
304 if err != nil {
305 log.Println(err)
306 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
307 return
308 }
309
310 s.pages.Home(w, pages.TimelineParams{
311 LoggedInUser: nil,
312 Timeline: timeline,
313 Repos: repos,
314 })
315}
316
317func (s *State) Keys(w http.ResponseWriter, r *http.Request) {
318 user := chi.URLParam(r, "user")
319 user = strings.TrimPrefix(user, "@")
320
321 if user == "" {
322 w.WriteHeader(http.StatusBadRequest)
323 return
324 }
325
326 id, err := s.idResolver.ResolveIdent(r.Context(), user)
327 if err != nil {
328 w.WriteHeader(http.StatusInternalServerError)
329 return
330 }
331
332 pubKeys, err := db.GetPublicKeysForDid(s.db, id.DID.String())
333 if err != nil {
334 w.WriteHeader(http.StatusNotFound)
335 return
336 }
337
338 if len(pubKeys) == 0 {
339 w.WriteHeader(http.StatusNotFound)
340 return
341 }
342
343 for _, k := range pubKeys {
344 key := strings.TrimRight(k.Key, "\n")
345 fmt.Fprintln(w, key)
346 }
347}
348
// validateRepoName checks that name is acceptable as a repository name.
//
// The checks run in a fixed order and the first failure is returned:
//  1. ".", "..", or any '/' or '\' is rejected as a path character.
//  2. A leading or trailing '.' is rejected as a path sequence.
//  3. Only ASCII letters, digits, '-', '_' and '.' are allowed.
//  4. Two consecutive dots are rejected.
//
// The empty string passes; callers check for it separately.
func validateRepoName(name string) error {
	// Traversal guards. Because any '/' is rejected here, fragments such as
	// "./" or "../" can never survive to the later checks.
	if name == "." || name == ".." || strings.ContainsAny(name, `/\`) {
		return fmt.Errorf("Repository name contains invalid path characters")
	}
	if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") {
		return fmt.Errorf("Repository name contains invalid path sequence")
	}

	allowed := func(c rune) bool {
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
			return true
		case c == '-', c == '_', c == '.':
			return true
		default:
			return false
		}
	}
	if strings.IndexFunc(name, func(c rune) bool { return !allowed(c) }) >= 0 {
		return fmt.Errorf("Repository name can only contain alphanumeric characters, periods, hyphens, and underscores")
	}

	if strings.Contains(name, "..") {
		return fmt.Errorf("Repository name cannot contain sequential dots")
	}
	return nil
}
380
// stripGitExt removes a single trailing ".git" from name, if present, so
// that "foo.git" and "foo" name the same repository.
func stripGitExt(name string) string {
	const gitExt = ".git"
	if !strings.HasSuffix(name, gitExt) {
		return name
	}
	return name[:len(name)-len(gitExt)]
}
384
// NewRepo handles the new-repository page.
//
// GET renders the creation form, listing the knots returned by
// GetKnotsForUser for the signed-in user. POST creates a repository; the
// steps are commented inline below. Any other method matches no case and
// writes no response.
//
// NOTE(review): both branches dereference the result of s.oauth.GetUser
// without a nil check, so this presumably relies on auth middleware on the
// route. Confirm.
func (s *State) NewRepo(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		user := s.oauth.GetUser(r)
		knots, err := s.enforcer.GetKnotsForUser(user.Did)
		if err != nil {
			s.pages.Notice(w, "repo", "Invalid user account.")
			return
		}

		s.pages.NewRepo(w, pages.NewRepoParams{
			LoggedInUser: user,
			Knots:        knots,
		})

	case http.MethodPost:
		l := s.logger.With("handler", "NewRepo")

		user := s.oauth.GetUser(r)
		l = l.With("did", user.Did)
		l = l.With("handle", user.Handle)

		// form validation
		domain := r.FormValue("domain")
		if domain == "" {
			s.pages.Notice(w, "repo", "Invalid form submission—missing knot domain.")
			return
		}
		l = l.With("knot", domain)

		repoName := r.FormValue("name")
		if repoName == "" {
			s.pages.Notice(w, "repo", "Repository name cannot be empty.")
			return
		}

		// The submitted name is validated first; any ".git" suffix is then
		// stripped, and the stripped name is used from here on.
		if err := validateRepoName(repoName); err != nil {
			s.pages.Notice(w, "repo", err.Error())
			return
		}
		repoName = stripGitExt(repoName)
		l = l.With("repoName", repoName)

		// NOTE(review): defaultBranch is only attached to the logger; it is
		// not sent in RepoCreate_Input or set on models.Repo below. Confirm
		// whether the knot should receive it.
		defaultBranch := r.FormValue("branch")
		if defaultBranch == "" {
			defaultBranch = "main"
		}
		l = l.With("defaultBranch", defaultBranch)

		description := r.FormValue("description")

		// ACL validation
		ok, err := s.enforcer.E.Enforce(user.Did, domain, domain, "repo:create")
		if err != nil || !ok {
			l.Info("unauthorized")
			s.pages.Notice(w, "repo", "You do not have permission to create a repo in this knot.")
			return
		}

		// Check for existing repos.
		// Only a successfully found repo blocks creation; any GetRepo error
		// lets creation proceed.
		existingRepo, err := db.GetRepo(
			s.db,
			db.FilterEq("did", user.Did),
			db.FilterEq("name", repoName),
		)
		if err == nil && existingRepo != nil {
			l.Info("repo exists")
			s.pages.Notice(w, "repo", fmt.Sprintf("You already have a repository by this name on %s", existingRepo.Knot))
			return
		}

		// create atproto record for this repo
		rkey := tid.TID()
		repo := &models.Repo{
			Did:         user.Did,
			Name:        repoName,
			Knot:        domain,
			Rkey:        rkey,
			Description: description,
			Created:     time.Now(),
			Labels:      models.DefaultLabelDefs(),
		}
		record := repo.AsRecord()

		xrpcClient, err := s.oauth.AuthorizedClient(r)
		if err != nil {
			l.Info("PDS write failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to write record to PDS.")
			return
		}

		// Write the record under tangled.RepoNSID, using the same rkey the
		// knot is later given in RepoCreate_Input.
		atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{
			Collection: tangled.RepoNSID,
			Repo:       user.Did,
			Rkey:       rkey,
			Record: &lexutil.LexiconTypeDecoder{
				Val: &record,
			},
		})
		if err != nil {
			l.Info("PDS write failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to announce repository creation.")
			return
		}

		aturi := atresp.Uri
		l = l.With("aturi", aturi)
		l.Info("wrote to PDS")

		// NOTE(review): the rollback below is not deferred yet, so if BeginTx
		// fails the PDS record written above is left in place.
		tx, err := s.db.BeginTx(r.Context(), nil)
		if err != nil {
			l.Info("txn failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to save repository information.")
			return
		}

		// The rollback function reverts a few things on failure:
		// - the pending txn
		// - the ACLs (via LoadPolicy, presumably discarding unsaved in-memory
		//   policy changes)
		// - the atproto record created
		//
		// It is deferred, so it also runs on the success path. There, the
		// committed tx makes Rollback return sql.ErrTxDone (ignored below).
		// aturi has been cleared, so rollbackRecord is a no-op; the closure
		// reads aturi when it runs, not when it is defined.
		//
		// NOTE(review): the repository created on the knot by RepoCreate is
		// not undone here.
		rollback := func() {
			err1 := tx.Rollback()
			err2 := s.enforcer.E.LoadPolicy()
			// Uses context.Background rather than r.Context(), presumably so
			// cleanup is not cut short by the request ending.
			err3 := rollbackRecord(context.Background(), aturi, xrpcClient)

			// ignore txn complete errors, this is okay
			if errors.Is(err1, sql.ErrTxDone) {
				err1 = nil
			}

			if errs := errors.Join(err1, err2, err3); errs != nil {
				l.Error("failed to rollback changes", "errs", errs)
				return
			}
		}
		defer rollback()

		// Build a service-auth client for calling the knot. The options
		// presumably scope the token to the knot (WithService) and to the
		// RepoCreate method (WithLxm). Confirm.
		client, err := s.oauth.ServiceClient(
			r,
			oauth.WithService(domain),
			oauth.WithLxm(tangled.RepoCreateNSID),
			oauth.WithDev(s.config.Core.Dev),
		)
		if err != nil {
			l.Error("service auth failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to reach PDS.")
			return
		}

		xe := tangled.RepoCreate(
			r.Context(),
			client,
			&tangled.RepoCreate_Input{
				Rkey: rkey,
			},
		)
		if err := xrpcclient.HandleXrpcErr(xe); err != nil {
			l.Error("xrpc error", "xe", xe)
			s.pages.Notice(w, "repo", err.Error())
			return
		}

		err = db.AddRepo(tx, repo)
		if err != nil {
			l.Error("db write failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to save repository information.")
			return
		}

		// acls
		// NOTE(review): the SecureJoin error is discarded. repoName has passed
		// validateRepoName, so presumably it cannot fail here.
		p, _ := securejoin.SecureJoin(user.Did, repoName)
		err = s.enforcer.AddRepo(user.Did, domain, p)
		if err != nil {
			l.Error("acl setup failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to set up repository permissions.")
			return
		}

		err = tx.Commit()
		if err != nil {
			l.Error("txn commit failed", "err", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// NOTE(review): the repo row is already committed at this point. If
		// SavePolicy fails, the deferred rollback still deletes the PDS record
		// (aturi is not yet cleared) and reloads the policy, leaving a
		// committed row with no record or ACL. Unlike earlier failures, this
		// path and the commit failure above reply with http.Error rather than
		// pages.Notice.
		err = s.enforcer.E.SavePolicy()
		if err != nil {
			l.Error("acl save failed", "err", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// reset the ATURI because the transaction completed successfully
		aturi = ""

		s.notifier.NewRepo(r.Context(), repo)
		s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName))
	}
}
584
585// this is used to rollback changes made to the PDS
586//
587// it is a no-op if the provided ATURI is empty
588func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error {
589 if aturi == "" {
590 return nil
591 }
592
593 parsed := syntax.ATURI(aturi)
594
595 collection := parsed.Collection().String()
596 repo := parsed.Authority().String()
597 rkey := parsed.RecordKey().String()
598
599 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{
600 Collection: collection,
601 Repo: repo,
602 Rkey: rkey,
603 })
604 return err
605}
606
607func BackfillDefaultDefs(e db.Execer, r *idresolver.Resolver) error {
608 defaults := models.DefaultLabelDefs()
609 defaultLabels, err := db.GetLabelDefinitions(e, db.FilterIn("at_uri", defaults))
610 if err != nil {
611 return err
612 }
613 // already present
614 if len(defaultLabels) == len(defaults) {
615 return nil
616 }
617
618 labelDefs, err := models.FetchDefaultDefs(r)
619 if err != nil {
620 return err
621 }
622
623 // Insert each label definition to the database
624 for _, labelDef := range labelDefs {
625 _, err = db.AddLabelDefinition(e, &labelDef)
626 if err != nil {
627 return fmt.Errorf("failed to add label definition %s: %v", labelDef.Name, err)
628 }
629 }
630
631 return nil
632}