1package state
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log"
9 "log/slog"
10 "net/http"
11 "strings"
12 "time"
13
14 comatproto "github.com/bluesky-social/indigo/api/atproto"
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 lexutil "github.com/bluesky-social/indigo/lex/util"
17 securejoin "github.com/cyphar/filepath-securejoin"
18 "github.com/go-chi/chi/v5"
19 "github.com/posthog/posthog-go"
20 "tangled.org/core/api/tangled"
21 "tangled.org/core/appview"
22 "tangled.org/core/appview/cache"
23 "tangled.org/core/appview/cache/session"
24 "tangled.org/core/appview/config"
25 "tangled.org/core/appview/db"
26 "tangled.org/core/appview/models"
27 "tangled.org/core/appview/notify"
28 dbnotify "tangled.org/core/appview/notify/db"
29 phnotify "tangled.org/core/appview/notify/posthog"
30 "tangled.org/core/appview/oauth"
31 "tangled.org/core/appview/pages"
32 "tangled.org/core/appview/reporesolver"
33 "tangled.org/core/appview/validator"
34 xrpcclient "tangled.org/core/appview/xrpcclient"
35 "tangled.org/core/eventconsumer"
36 "tangled.org/core/idresolver"
37 "tangled.org/core/jetstream"
38 tlog "tangled.org/core/log"
39 "tangled.org/core/rbac"
40 "tangled.org/core/tid"
41)
42
// State bundles the appview's long-lived dependencies and is the receiver
// for its HTTP handlers. It is built by Make and released by Close.
//
// NOTE(review): if State is constructed anywhere with an unkeyed (positional)
// composite literal, the field order below is part of that contract —
// reordering fields would then mis-assign values or fail to compile.
type State struct {
	// primary database handle; closed by Close
	db *db.DB
	// receives domain events (e.g. NewRepo after a repository is created)
	notifier notify.Notifier
	// current-user lookup plus authorized PDS and service clients
	oauth *oauth.OAuth
	// ACL checks (Enforce) and policy persistence (SavePolicy/LoadPolicy)
	enforcer *rbac.Enforcer
	// HTML page rendering
	pages *pages.Pages
	// session store handed to the oauth layer
	sess *session.SessionStore
	// resolves handles/DIDs to identities
	idResolver *idresolver.Resolver
	// PostHog client; also fed to the posthog notifier outside dev mode
	posthog posthog.Client
	// jetstream subscription client
	jc *jetstream.JetstreamClient
	config *config.Config
	repoResolver *reporesolver.RepoResolver
	// event consumer started for knots
	knotstream *eventconsumer.Consumer
	// event consumer started for spindles
	spindlestream *eventconsumer.Consumer
	logger *slog.Logger
	validator *validator.Validator
}
60
61func Make(ctx context.Context, config *config.Config) (*State, error) {
62 d, err := db.Make(config.Core.DbPath)
63 if err != nil {
64 return nil, fmt.Errorf("failed to create db: %w", err)
65 }
66
67 enforcer, err := rbac.NewEnforcer(config.Core.DbPath)
68 if err != nil {
69 return nil, fmt.Errorf("failed to create enforcer: %w", err)
70 }
71
72 res, err := idresolver.RedisResolver(config.Redis.ToURL())
73 if err != nil {
74 log.Printf("failed to create redis resolver: %v", err)
75 res = idresolver.DefaultResolver()
76 }
77
78 pgs := pages.NewPages(config, res)
79 cache := cache.New(config.Redis.Addr)
80 sess := session.New(cache)
81 oauth := oauth.NewOAuth(config, sess)
82 validator := validator.New(d, res, enforcer)
83
84 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint})
85 if err != nil {
86 return nil, fmt.Errorf("failed to create posthog client: %w", err)
87 }
88
89 repoResolver := reporesolver.New(config, enforcer, res, d)
90
91 wrapper := db.DbWrapper{Execer: d}
92 jc, err := jetstream.NewJetstreamClient(
93 config.Jetstream.Endpoint,
94 "appview",
95 []string{
96 tangled.GraphFollowNSID,
97 tangled.FeedStarNSID,
98 tangled.PublicKeyNSID,
99 tangled.RepoArtifactNSID,
100 tangled.ActorProfileNSID,
101 tangled.SpindleMemberNSID,
102 tangled.SpindleNSID,
103 tangled.StringNSID,
104 tangled.RepoIssueNSID,
105 tangled.RepoIssueCommentNSID,
106 tangled.LabelDefinitionNSID,
107 tangled.LabelOpNSID,
108 },
109 nil,
110 slog.Default(),
111 wrapper,
112 false,
113
114 // in-memory filter is inapplicalble to appview so
115 // we'll never log dids anyway.
116 false,
117 )
118 if err != nil {
119 return nil, fmt.Errorf("failed to create jetstream client: %w", err)
120 }
121
122 if err := BackfillDefaultDefs(d, res); err != nil {
123 return nil, fmt.Errorf("failed to backfill default label defs: %w", err)
124 }
125
126 ingester := appview.Ingester{
127 Db: wrapper,
128 Enforcer: enforcer,
129 IdResolver: res,
130 Config: config,
131 Logger: tlog.New("ingester"),
132 Validator: validator,
133 }
134 err = jc.StartJetstream(ctx, ingester.Ingest())
135 if err != nil {
136 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err)
137 }
138
139 knotstream, err := Knotstream(ctx, config, d, enforcer, posthog)
140 if err != nil {
141 return nil, fmt.Errorf("failed to start knotstream consumer: %w", err)
142 }
143 knotstream.Start(ctx)
144
145 spindlestream, err := Spindlestream(ctx, config, d, enforcer)
146 if err != nil {
147 return nil, fmt.Errorf("failed to start spindlestream consumer: %w", err)
148 }
149 spindlestream.Start(ctx)
150
151 var notifiers []notify.Notifier
152
153 // Always add the database notifier
154 notifiers = append(notifiers, dbnotify.NewDatabaseNotifier(d, res))
155
156 // Add other notifiers in production only
157 if !config.Core.Dev {
158 notifiers = append(notifiers, phnotify.NewPosthogNotifier(posthog))
159 }
160 notifier := notify.NewMergedNotifier(notifiers...)
161
162 state := &State{
163 d,
164 notifier,
165 oauth,
166 enforcer,
167 pgs,
168 sess,
169 res,
170 posthog,
171 jc,
172 config,
173 repoResolver,
174 knotstream,
175 spindlestream,
176 slog.Default(),
177 validator,
178 }
179
180 return state, nil
181}
182
183func (s *State) Close() error {
184 // other close up logic goes here
185 return s.db.Close()
186}
187
188func (s *State) Favicon(w http.ResponseWriter, r *http.Request) {
189 w.Header().Set("Content-Type", "image/svg+xml")
190 w.Header().Set("Cache-Control", "public, max-age=31536000") // one year
191 w.Header().Set("ETag", `"favicon-svg-v1"`)
192
193 if match := r.Header.Get("If-None-Match"); match == `"favicon-svg-v1"` {
194 w.WriteHeader(http.StatusNotModified)
195 return
196 }
197
198 s.pages.Favicon(w)
199}
200
201func (s *State) TermsOfService(w http.ResponseWriter, r *http.Request) {
202 user := s.oauth.GetUser(r)
203 s.pages.TermsOfService(w, pages.TermsOfServiceParams{
204 LoggedInUser: user,
205 })
206}
207
208func (s *State) PrivacyPolicy(w http.ResponseWriter, r *http.Request) {
209 user := s.oauth.GetUser(r)
210 s.pages.PrivacyPolicy(w, pages.PrivacyPolicyParams{
211 LoggedInUser: user,
212 })
213}
214
215func (s *State) HomeOrTimeline(w http.ResponseWriter, r *http.Request) {
216 if s.oauth.GetUser(r) != nil {
217 s.Timeline(w, r)
218 return
219 }
220 s.Home(w, r)
221}
222
223func (s *State) Timeline(w http.ResponseWriter, r *http.Request) {
224 user := s.oauth.GetUser(r)
225
226 var userDid string
227 if user != nil {
228 userDid = user.Did
229 }
230 timeline, err := db.MakeTimeline(s.db, 50, userDid)
231 if err != nil {
232 log.Println(err)
233 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
234 }
235
236 repos, err := db.GetTopStarredReposLastWeek(s.db)
237 if err != nil {
238 log.Println(err)
239 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
240 return
241 }
242
243 s.pages.Timeline(w, pages.TimelineParams{
244 LoggedInUser: user,
245 Timeline: timeline,
246 Repos: repos,
247 })
248}
249
250func (s *State) UpgradeBanner(w http.ResponseWriter, r *http.Request) {
251 user := s.oauth.GetUser(r)
252 if user == nil {
253 return
254 }
255
256 l := s.logger.With("handler", "UpgradeBanner")
257 l = l.With("did", user.Did)
258 l = l.With("handle", user.Handle)
259
260 regs, err := db.GetRegistrations(
261 s.db,
262 db.FilterEq("did", user.Did),
263 db.FilterEq("needs_upgrade", 1),
264 )
265 if err != nil {
266 l.Error("non-fatal: failed to get registrations", "err", err)
267 }
268
269 spindles, err := db.GetSpindles(
270 s.db,
271 db.FilterEq("owner", user.Did),
272 db.FilterEq("needs_upgrade", 1),
273 )
274 if err != nil {
275 l.Error("non-fatal: failed to get spindles", "err", err)
276 }
277
278 if regs == nil && spindles == nil {
279 return
280 }
281
282 s.pages.UpgradeBanner(w, pages.UpgradeBannerParams{
283 Registrations: regs,
284 Spindles: spindles,
285 })
286}
287
288func (s *State) Home(w http.ResponseWriter, r *http.Request) {
289 timeline, err := db.MakeTimeline(s.db, 5, "")
290 if err != nil {
291 log.Println(err)
292 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
293 return
294 }
295
296 repos, err := db.GetTopStarredReposLastWeek(s.db)
297 if err != nil {
298 log.Println(err)
299 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
300 return
301 }
302
303 s.pages.Home(w, pages.TimelineParams{
304 LoggedInUser: nil,
305 Timeline: timeline,
306 Repos: repos,
307 })
308}
309
310func (s *State) Keys(w http.ResponseWriter, r *http.Request) {
311 user := chi.URLParam(r, "user")
312 user = strings.TrimPrefix(user, "@")
313
314 if user == "" {
315 w.WriteHeader(http.StatusBadRequest)
316 return
317 }
318
319 id, err := s.idResolver.ResolveIdent(r.Context(), user)
320 if err != nil {
321 w.WriteHeader(http.StatusInternalServerError)
322 return
323 }
324
325 pubKeys, err := db.GetPublicKeysForDid(s.db, id.DID.String())
326 if err != nil {
327 w.WriteHeader(http.StatusNotFound)
328 return
329 }
330
331 if len(pubKeys) == 0 {
332 w.WriteHeader(http.StatusNotFound)
333 return
334 }
335
336 for _, k := range pubKeys {
337 key := strings.TrimRight(k.Key, "\n")
338 fmt.Fprintln(w, key)
339 }
340}
341
// validateRepoName checks that name is usable as a repository name.
//
// A valid name consists only of ASCII letters, digits, '-', '_' and '.'.
// It must not be "." or "..", contain '/' or '\\', start or end with '.',
// or contain two consecutive dots. The checks run in a fixed order, and the
// first one that fails decides the returned message, which is shown to the
// user as-is. An empty name passes; callers reject it separately.
func validateRepoName(name string) error {
	allowed := func(c rune) bool {
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
			return true
		case c == '-', c == '_', c == '.':
			return true
		}
		return false
	}
	disallowed := func(c rune) bool { return !allowed(c) }

	switch {
	// path separators and bare relative-directory names
	case name == ".", name == "..", strings.ContainsAny(name, `/\`):
		return errors.New("Repository name contains invalid path characters")

	// sequences that could turn into traversal once normalized
	case strings.Contains(name, "./"), strings.Contains(name, "../"),
		strings.HasPrefix(name, "."), strings.HasSuffix(name, "."):
		return errors.New("Repository name contains invalid path sequence")

	case strings.IndexFunc(name, disallowed) >= 0:
		return errors.New("Repository name can only contain alphanumeric characters, periods, hyphens, and underscores")

	case strings.Contains(name, ".."):
		return errors.New("Repository name cannot contain sequential dots")
	}

	return nil
}
373
// stripGitExt removes one trailing ".git" from name when present and
// returns name unchanged otherwise.
func stripGitExt(name string) string {
	const ext = ".git"
	if !strings.HasSuffix(name, ext) {
		return name
	}
	return name[:len(name)-len(ext)]
}
377
// NewRepo serves the repository-creation flow.
//
// GET renders the new-repo form with the knots returned by
// s.enforcer.GetKnotsForUser for the logged-in user.
//
// POST reads the form fields "domain" (knot), "name", "branch" (defaults to
// "main") and "description". It validates the name, checks the "repo:create"
// permission on the knot, and refuses names the user already has a repo for.
// It then:
//  1. writes a tangled.RepoNSID record to the user's PDS,
//  2. asks the knot to create the repo via tangled.RepoCreate,
//  3. inserts the repo into the db transaction and adds its ACLs,
//  4. commits the transaction and persists the ACL policy.
//
// On success it emits a NewRepo notification and redirects (HX-Location) to
// the new repo. Any failure after step 1 triggers the deferred rollback,
// which rolls back the transaction, reloads the ACL policy and deletes the
// PDS record.
//
// NOTE(review): the rollback does not undo step 2. If the db write, ACL
// setup or commit fails after the knot accepted RepoCreate, the knot-side
// repo appears to be left behind — confirm whether the knot cleans it up.
//
// NOTE(review): user is dereferenced without a nil check in both branches,
// presumably because an auth middleware guarantees a logged-in user — verify.
//
// Methods other than GET and POST fall through the switch without writing
// an explicit response.
func (s *State) NewRepo(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		user := s.oauth.GetUser(r)
		knots, err := s.enforcer.GetKnotsForUser(user.Did)
		if err != nil {
			s.pages.Notice(w, "repo", "Invalid user account.")
			return
		}

		s.pages.NewRepo(w, pages.NewRepoParams{
			LoggedInUser: user,
			Knots:        knots,
		})

	case http.MethodPost:
		l := s.logger.With("handler", "NewRepo")

		user := s.oauth.GetUser(r)
		l = l.With("did", user.Did)
		l = l.With("handle", user.Handle)

		// form validation
		domain := r.FormValue("domain")
		if domain == "" {
			s.pages.Notice(w, "repo", "Invalid form submission—missing knot domain.")
			return
		}
		l = l.With("knot", domain)

		repoName := r.FormValue("name")
		if repoName == "" {
			s.pages.Notice(w, "repo", "Repository name cannot be empty.")
			return
		}

		if err := validateRepoName(repoName); err != nil {
			s.pages.Notice(w, "repo", err.Error())
			return
		}
		// validation runs on the raw name; a trailing ".git" is dropped afterwards
		repoName = stripGitExt(repoName)
		l = l.With("repoName", repoName)

		// NOTE(review): defaultBranch is only attached to the logger below; it
		// is not sent to the knot (RepoCreate_Input carries only Rkey) nor
		// stored on repo — confirm this is intended.
		defaultBranch := r.FormValue("branch")
		if defaultBranch == "" {
			defaultBranch = "main"
		}
		l = l.With("defaultBranch", defaultBranch)

		description := r.FormValue("description")

		// ACL validation
		ok, err := s.enforcer.E.Enforce(user.Did, domain, domain, "repo:create")
		if err != nil || !ok {
			l.Info("unauthorized")
			s.pages.Notice(w, "repo", "You do not have permission to create a repo in this knot.")
			return
		}

		// Check for existing repos
		// NOTE(review): any lookup error is treated the same as "no existing
		// repo", so a transient db failure lets creation proceed.
		existingRepo, err := db.GetRepo(
			s.db,
			db.FilterEq("did", user.Did),
			db.FilterEq("name", repoName),
		)
		if err == nil && existingRepo != nil {
			l.Info("repo exists")
			s.pages.Notice(w, "repo", fmt.Sprintf("You already have a repository by this name on %s", existingRepo.Knot))
			return
		}

		// create atproto record for this repo
		rkey := tid.TID()
		repo := &models.Repo{
			Did:         user.Did,
			Name:        repoName,
			Knot:        domain,
			Rkey:        rkey,
			Description: description,
			Created:     time.Now(),
			Labels:      models.DefaultLabelDefs(),
		}
		record := repo.AsRecord()

		xrpcClient, err := s.oauth.AuthorizedClient(r)
		if err != nil {
			l.Info("PDS write failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to write record to PDS.")
			return
		}

		atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{
			Collection: tangled.RepoNSID,
			Repo:       user.Did,
			Rkey:       rkey,
			Record: &lexutil.LexiconTypeDecoder{
				Val: &record,
			},
		})
		if err != nil {
			l.Info("PDS write failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to announce repository creation.")
			return
		}

		// aturi doubles as the rollback flag: while non-empty, the deferred
		// rollback deletes this PDS record.
		aturi := atresp.Uri
		l = l.With("aturi", aturi)
		l.Info("wrote to PDS")

		tx, err := s.db.BeginTx(r.Context(), nil)
		if err != nil {
			l.Info("txn failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to save repository information.")
			return
		}

		// The rollback function reverts a few things on failure:
		// - the pending txn
		// - the ACLs
		// - the atproto record created
		//
		// It is deferred, so it also runs after a successful request. There,
		// tx.Rollback reports sql.ErrTxDone (ignored below), LoadPolicy reloads
		// the policy that was just saved, and rollbackRecord is a no-op
		// because aturi has been cleared.
		rollback := func() {
			err1 := tx.Rollback()
			err2 := s.enforcer.E.LoadPolicy()
			// context.Background rather than r.Context(): the request context
			// may already be done by the time the deferred rollback runs.
			err3 := rollbackRecord(context.Background(), aturi, xrpcClient)

			// ignore txn complete errors, this is okay
			if errors.Is(err1, sql.ErrTxDone) {
				err1 = nil
			}

			if errs := errors.Join(err1, err2, err3); errs != nil {
				l.Error("failed to rollback changes", "errs", errs)
				return
			}
		}
		defer rollback()

		client, err := s.oauth.ServiceClient(
			r,
			oauth.WithService(domain),
			oauth.WithLxm(tangled.RepoCreateNSID),
			oauth.WithDev(s.config.Core.Dev),
		)
		if err != nil {
			l.Error("service auth failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to reach PDS.")
			return
		}

		xe := tangled.RepoCreate(
			r.Context(),
			client,
			&tangled.RepoCreate_Input{
				Rkey: rkey,
			},
		)
		if err := xrpcclient.HandleXrpcErr(xe); err != nil {
			l.Error("xrpc error", "xe", xe)
			s.pages.Notice(w, "repo", err.Error())
			return
		}

		err = db.AddRepo(tx, repo)
		if err != nil {
			l.Error("db write failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to save repository information.")
			return
		}

		// acls
		// NOTE(review): the SecureJoin error is discarded; repoName was already
		// validated above, which presumably makes a failure here impossible.
		p, _ := securejoin.SecureJoin(user.Did, repoName)
		err = s.enforcer.AddRepo(user.Did, domain, p)
		if err != nil {
			l.Error("acl setup failed", "err", err)
			s.pages.Notice(w, "repo", "Failed to set up repository permissions.")
			return
		}

		// NOTE(review): the two failures below reply with http.Error and the
		// raw error text, unlike the s.pages.Notice used everywhere else here.
		err = tx.Commit()
		if err != nil {
			l.Error("txn commit failed", "err", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		err = s.enforcer.E.SavePolicy()
		if err != nil {
			l.Error("acl save failed", "err", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// reset the ATURI because the transaction completed successfully
		aturi = ""

		s.notifier.NewRepo(r.Context(), repo)
		s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName))
	}
}
577
578// this is used to rollback changes made to the PDS
579//
580// it is a no-op if the provided ATURI is empty
581func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error {
582 if aturi == "" {
583 return nil
584 }
585
586 parsed := syntax.ATURI(aturi)
587
588 collection := parsed.Collection().String()
589 repo := parsed.Authority().String()
590 rkey := parsed.RecordKey().String()
591
592 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{
593 Collection: collection,
594 Repo: repo,
595 Rkey: rkey,
596 })
597 return err
598}
599
600func BackfillDefaultDefs(e db.Execer, r *idresolver.Resolver) error {
601 defaults := models.DefaultLabelDefs()
602 defaultLabels, err := db.GetLabelDefinitions(e, db.FilterIn("at_uri", defaults))
603 if err != nil {
604 return err
605 }
606 // already present
607 if len(defaultLabels) == len(defaults) {
608 return nil
609 }
610
611 labelDefs, err := models.FetchDefaultDefs(r)
612 if err != nil {
613 return err
614 }
615
616 // Insert each label definition to the database
617 for _, labelDef := range labelDefs {
618 _, err = db.AddLabelDefinition(e, &labelDef)
619 if err != nil {
620 return fmt.Errorf("failed to add label definition %s: %v", labelDef.Name, err)
621 }
622 }
623
624 return nil
625}