forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package state 2 3import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "log" 9 "log/slog" 10 "net/http" 11 "strings" 12 "time" 13 14 comatproto "github.com/bluesky-social/indigo/api/atproto" 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 lexutil "github.com/bluesky-social/indigo/lex/util" 17 securejoin "github.com/cyphar/filepath-securejoin" 18 "github.com/go-chi/chi/v5" 19 "github.com/posthog/posthog-go" 20 "tangled.org/core/api/tangled" 21 "tangled.org/core/appview" 22 "tangled.org/core/appview/cache" 23 "tangled.org/core/appview/cache/session" 24 "tangled.org/core/appview/config" 25 "tangled.org/core/appview/db" 26 "tangled.org/core/appview/models" 27 "tangled.org/core/appview/notify" 28 dbnotify "tangled.org/core/appview/notify/db" 29 phnotify "tangled.org/core/appview/notify/posthog" 30 "tangled.org/core/appview/oauth" 31 "tangled.org/core/appview/pages" 32 "tangled.org/core/appview/reporesolver" 33 "tangled.org/core/appview/validator" 34 xrpcclient "tangled.org/core/appview/xrpcclient" 35 "tangled.org/core/eventconsumer" 36 "tangled.org/core/idresolver" 37 "tangled.org/core/jetstream" 38 tlog "tangled.org/core/log" 39 "tangled.org/core/rbac" 40 "tangled.org/core/tid" 41) 42 43type State struct { 44 db *db.DB 45 notifier notify.Notifier 46 oauth *oauth.OAuth 47 enforcer *rbac.Enforcer 48 pages *pages.Pages 49 sess *session.SessionStore 50 idResolver *idresolver.Resolver 51 posthog posthog.Client 52 jc *jetstream.JetstreamClient 53 config *config.Config 54 repoResolver *reporesolver.RepoResolver 55 knotstream *eventconsumer.Consumer 56 spindlestream *eventconsumer.Consumer 57 logger *slog.Logger 58 validator *validator.Validator 59} 60 61func Make(ctx context.Context, config *config.Config) (*State, error) { 62 d, err := db.Make(config.Core.DbPath) 63 if err != nil { 64 return nil, fmt.Errorf("failed to create db: %w", err) 65 } 66 67 enforcer, err := rbac.NewEnforcer(config.Core.DbPath) 68 if err != nil { 69 return nil, fmt.Errorf("failed to create enforcer: %w", err) 70 } 71 72 res, err := idresolver.RedisResolver(config.Redis.ToURL()) 73 if err != nil { 74 log.Printf("failed to create redis resolver: %v", err) 75 res = idresolver.DefaultResolver() 76 } 77 78 pgs := pages.NewPages(config, res) 79 cache := cache.New(config.Redis.Addr) 80 sess := session.New(cache) 81 oauth := oauth.NewOAuth(config, sess) 82 validator := validator.New(d, res, enforcer) 83 84 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint}) 85 if err != nil { 86 return nil, fmt.Errorf("failed to create posthog client: %w", err) 87 } 88 89 repoResolver := reporesolver.New(config, enforcer, res, d) 90 91 wrapper := db.DbWrapper{Execer: d} 92 jc, err := jetstream.NewJetstreamClient( 93 config.Jetstream.Endpoint, 94 "appview", 95 []string{ 96 tangled.GraphFollowNSID, 97 tangled.FeedStarNSID, 98 tangled.PublicKeyNSID, 99 tangled.RepoArtifactNSID, 100 tangled.ActorProfileNSID, 101 tangled.SpindleMemberNSID, 102 tangled.SpindleNSID, 103 tangled.StringNSID, 104 tangled.RepoIssueNSID, 105 tangled.RepoIssueCommentNSID, 106 tangled.LabelDefinitionNSID, 107 tangled.LabelOpNSID, 108 }, 109 nil, 110 slog.Default(), 111 wrapper, 112 false, 113 114 // in-memory filter is inapplicalble to appview so 115 // we'll never log dids anyway. 116 false, 117 ) 118 if err != nil { 119 return nil, fmt.Errorf("failed to create jetstream client: %w", err) 120 } 121 122 if err := BackfillDefaultDefs(d, res); err != nil { 123 return nil, fmt.Errorf("failed to backfill default label defs: %w", err) 124 } 125 126 ingester := appview.Ingester{ 127 Db: wrapper, 128 Enforcer: enforcer, 129 IdResolver: res, 130 Config: config, 131 Logger: tlog.New("ingester"), 132 Validator: validator, 133 } 134 err = jc.StartJetstream(ctx, ingester.Ingest()) 135 if err != nil { 136 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err) 137 } 138 139 knotstream, err := Knotstream(ctx, config, d, enforcer, posthog) 140 if err != nil { 141 return nil, fmt.Errorf("failed to start knotstream consumer: %w", err) 142 } 143 knotstream.Start(ctx) 144 145 spindlestream, err := Spindlestream(ctx, config, d, enforcer) 146 if err != nil { 147 return nil, fmt.Errorf("failed to start spindlestream consumer: %w", err) 148 } 149 spindlestream.Start(ctx) 150 151 var notifiers []notify.Notifier 152 153 // Always add the database notifier 154 notifiers = append(notifiers, dbnotify.NewDatabaseNotifier(d, res)) 155 156 // Add other notifiers in production only 157 if !config.Core.Dev { 158 notifiers = append(notifiers, phnotify.NewPosthogNotifier(posthog)) 159 } 160 notifier := notify.NewMergedNotifier(notifiers...) 161 162 state := &State{ 163 d, 164 notifier, 165 oauth, 166 enforcer, 167 pgs, 168 sess, 169 res, 170 posthog, 171 jc, 172 config, 173 repoResolver, 174 knotstream, 175 spindlestream, 176 slog.Default(), 177 validator, 178 } 179 180 return state, nil 181} 182 183func (s *State) Close() error { 184 // other close up logic goes here 185 return s.db.Close() 186} 187 188func (s *State) Favicon(w http.ResponseWriter, r *http.Request) { 189 w.Header().Set("Content-Type", "image/svg+xml") 190 w.Header().Set("Cache-Control", "public, max-age=31536000") // one year 191 w.Header().Set("ETag", `"favicon-svg-v1"`) 192 193 if match := r.Header.Get("If-None-Match"); match == `"favicon-svg-v1"` { 194 w.WriteHeader(http.StatusNotModified) 195 return 196 } 197 198 s.pages.Favicon(w) 199} 200 201func (s *State) TermsOfService(w http.ResponseWriter, r *http.Request) { 202 user := s.oauth.GetUser(r) 203 s.pages.TermsOfService(w, pages.TermsOfServiceParams{ 204 LoggedInUser: user, 205 }) 206} 207 208func (s *State) PrivacyPolicy(w http.ResponseWriter, r *http.Request) { 209 user := s.oauth.GetUser(r) 210 s.pages.PrivacyPolicy(w, pages.PrivacyPolicyParams{ 211 LoggedInUser: user, 212 }) 213} 214 215func (s *State) HomeOrTimeline(w http.ResponseWriter, r *http.Request) { 216 if s.oauth.GetUser(r) != nil { 217 s.Timeline(w, r) 218 return 219 } 220 s.Home(w, r) 221} 222 223func (s *State) Timeline(w http.ResponseWriter, r *http.Request) { 224 user := s.oauth.GetUser(r) 225 226 var userDid string 227 if user != nil { 228 userDid = user.Did 229 } 230 timeline, err := db.MakeTimeline(s.db, 50, userDid) 231 if err != nil { 232 log.Println(err) 233 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.") 234 } 235 236 repos, err := db.GetTopStarredReposLastWeek(s.db) 237 if err != nil { 238 log.Println(err) 239 s.pages.Notice(w, "topstarredrepos", "Unable to load.") 240 return 241 } 242 243 s.pages.Timeline(w, pages.TimelineParams{ 244 LoggedInUser: user, 245 Timeline: timeline, 246 Repos: repos, 247 }) 248} 249 250func (s *State) UpgradeBanner(w http.ResponseWriter, r *http.Request) { 251 user := s.oauth.GetUser(r) 252 if user == nil { 253 return 254 } 255 256 l := s.logger.With("handler", "UpgradeBanner") 257 l = l.With("did", user.Did) 258 l = l.With("handle", user.Handle) 259 260 regs, err := db.GetRegistrations( 261 s.db, 262 db.FilterEq("did", user.Did), 263 db.FilterEq("needs_upgrade", 1), 264 ) 265 if err != nil { 266 l.Error("non-fatal: failed to get registrations", "err", err) 267 } 268 269 spindles, err := db.GetSpindles( 270 s.db, 271 db.FilterEq("owner", user.Did), 272 db.FilterEq("needs_upgrade", 1), 273 ) 274 if err != nil { 275 l.Error("non-fatal: failed to get spindles", "err", err) 276 } 277 278 if regs == nil && spindles == nil { 279 return 280 } 281 282 s.pages.UpgradeBanner(w, pages.UpgradeBannerParams{ 283 Registrations: regs, 284 Spindles: spindles, 285 }) 286} 287 288func (s *State) Home(w http.ResponseWriter, r *http.Request) { 289 timeline, err := db.MakeTimeline(s.db, 5, "") 290 if err != nil { 291 log.Println(err) 292 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.") 293 return 294 } 295 296 repos, err := db.GetTopStarredReposLastWeek(s.db) 297 if err != nil { 298 log.Println(err) 299 s.pages.Notice(w, "topstarredrepos", "Unable to load.") 300 return 301 } 302 303 s.pages.Home(w, pages.TimelineParams{ 304 LoggedInUser: nil, 305 Timeline: timeline, 306 Repos: repos, 307 }) 308} 309 310func (s *State) Keys(w http.ResponseWriter, r *http.Request) { 311 user := chi.URLParam(r, "user") 312 user = strings.TrimPrefix(user, "@") 313 314 if user == "" { 315 w.WriteHeader(http.StatusBadRequest) 316 return 317 } 318 319 id, err := s.idResolver.ResolveIdent(r.Context(), user) 320 if err != nil { 321 w.WriteHeader(http.StatusInternalServerError) 322 return 323 } 324 325 pubKeys, err := db.GetPublicKeysForDid(s.db, id.DID.String()) 326 if err != nil { 327 w.WriteHeader(http.StatusNotFound) 328 return 329 } 330 331 if len(pubKeys) == 0 { 332 w.WriteHeader(http.StatusNotFound) 333 return 334 } 335 336 for _, k := range pubKeys { 337 key := strings.TrimRight(k.Key, "\n") 338 fmt.Fprintln(w, key) 339 } 340} 341 342func validateRepoName(name string) error { 343 // check for path traversal attempts 344 if name == "." || name == ".." || 345 strings.Contains(name, "/") || strings.Contains(name, "\\") { 346 return fmt.Errorf("Repository name contains invalid path characters") 347 } 348 349 // check for sequences that could be used for traversal when normalized 350 if strings.Contains(name, "./") || strings.Contains(name, "../") || 351 strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { 352 return fmt.Errorf("Repository name contains invalid path sequence") 353 } 354 355 // then continue with character validation 356 for _, char := range name { 357 if !((char >= 'a' && char <= 'z') || 358 (char >= 'A' && char <= 'Z') || 359 (char >= '0' && char <= '9') || 360 char == '-' || char == '_' || char == '.') { 361 return fmt.Errorf("Repository name can only contain alphanumeric characters, periods, hyphens, and underscores") 362 } 363 } 364 365 // additional check to prevent multiple sequential dots 366 if strings.Contains(name, "..") { 367 return fmt.Errorf("Repository name cannot contain sequential dots") 368 } 369 370 // if all checks pass 371 return nil 372} 373 374func stripGitExt(name string) string { 375 return strings.TrimSuffix(name, ".git") 376} 377 378func (s *State) NewRepo(w http.ResponseWriter, r *http.Request) { 379 switch r.Method { 380 case http.MethodGet: 381 user := s.oauth.GetUser(r) 382 knots, err := s.enforcer.GetKnotsForUser(user.Did) 383 if err != nil { 384 s.pages.Notice(w, "repo", "Invalid user account.") 385 return 386 } 387 388 s.pages.NewRepo(w, pages.NewRepoParams{ 389 LoggedInUser: user, 390 Knots: knots, 391 }) 392 393 case http.MethodPost: 394 l := s.logger.With("handler", "NewRepo") 395 396 user := s.oauth.GetUser(r) 397 l = l.With("did", user.Did) 398 l = l.With("handle", user.Handle) 399 400 // form validation 401 domain := r.FormValue("domain") 402 if domain == "" { 403 s.pages.Notice(w, "repo", "Invalid form submission&mdash;missing knot domain.") 404 return 405 } 406 l = l.With("knot", domain) 407 408 repoName := r.FormValue("name") 409 if repoName == "" { 410 s.pages.Notice(w, "repo", "Repository name cannot be empty.") 411 return 412 } 413 414 if err := validateRepoName(repoName); err != nil { 415 s.pages.Notice(w, "repo", err.Error()) 416 return 417 } 418 repoName = stripGitExt(repoName) 419 l = l.With("repoName", repoName) 420 421 defaultBranch := r.FormValue("branch") 422 if defaultBranch == "" { 423 defaultBranch = "main" 424 } 425 l = l.With("defaultBranch", defaultBranch) 426 427 description := r.FormValue("description") 428 429 // ACL validation 430 ok, err := s.enforcer.E.Enforce(user.Did, domain, domain, "repo:create") 431 if err != nil || !ok { 432 l.Info("unauthorized") 433 s.pages.Notice(w, "repo", "You do not have permission to create a repo in this knot.") 434 return 435 } 436 437 // Check for existing repos 438 existingRepo, err := db.GetRepo( 439 s.db, 440 db.FilterEq("did", user.Did), 441 db.FilterEq("name", repoName), 442 ) 443 if err == nil && existingRepo != nil { 444 l.Info("repo exists") 445 s.pages.Notice(w, "repo", fmt.Sprintf("You already have a repository by this name on %s", existingRepo.Knot)) 446 return 447 } 448 449 // create atproto record for this repo 450 rkey := tid.TID() 451 repo := &models.Repo{ 452 Did: user.Did, 453 Name: repoName, 454 Knot: domain, 455 Rkey: rkey, 456 Description: description, 457 Created: time.Now(), 458 Labels: models.DefaultLabelDefs(), 459 } 460 record := repo.AsRecord() 461 462 xrpcClient, err := s.oauth.AuthorizedClient(r) 463 if err != nil { 464 l.Info("PDS write failed", "err", err) 465 s.pages.Notice(w, "repo", "Failed to write record to PDS.") 466 return 467 } 468 469 atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 470 Collection: tangled.RepoNSID, 471 Repo: user.Did, 472 Rkey: rkey, 473 Record: &lexutil.LexiconTypeDecoder{ 474 Val: &record, 475 }, 476 }) 477 if err != nil { 478 l.Info("PDS write failed", "err", err) 479 s.pages.Notice(w, "repo", "Failed to announce repository creation.") 480 return 481 } 482 483 aturi := atresp.Uri 484 l = l.With("aturi", aturi) 485 l.Info("wrote to PDS") 486 487 tx, err := s.db.BeginTx(r.Context(), nil) 488 if err != nil { 489 l.Info("txn failed", "err", err) 490 s.pages.Notice(w, "repo", "Failed to save repository information.") 491 return 492 } 493 494 // The rollback function reverts a few things on failure: 495 // - the pending txn 496 // - the ACLs 497 // - the atproto record created 498 rollback := func() { 499 err1 := tx.Rollback() 500 err2 := s.enforcer.E.LoadPolicy() 501 err3 := rollbackRecord(context.Background(), aturi, xrpcClient) 502 503 // ignore txn complete errors, this is okay 504 if errors.Is(err1, sql.ErrTxDone) { 505 err1 = nil 506 } 507 508 if errs := errors.Join(err1, err2, err3); errs != nil { 509 l.Error("failed to rollback changes", "errs", errs) 510 return 511 } 512 } 513 defer rollback() 514 515 client, err := s.oauth.ServiceClient( 516 r, 517 oauth.WithService(domain), 518 oauth.WithLxm(tangled.RepoCreateNSID), 519 oauth.WithDev(s.config.Core.Dev), 520 ) 521 if err != nil { 522 l.Error("service auth failed", "err", err) 523 s.pages.Notice(w, "repo", "Failed to reach PDS.") 524 return 525 } 526 527 xe := tangled.RepoCreate( 528 r.Context(), 529 client, 530 &tangled.RepoCreate_Input{ 531 Rkey: rkey, 532 }, 533 ) 534 if err := xrpcclient.HandleXrpcErr(xe); err != nil { 535 l.Error("xrpc error", "xe", xe) 536 s.pages.Notice(w, "repo", err.Error()) 537 return 538 } 539 540 err = db.AddRepo(tx, repo) 541 if err != nil { 542 l.Error("db write failed", "err", err) 543 s.pages.Notice(w, "repo", "Failed to save repository information.") 544 return 545 } 546 547 // acls 548 p, _ := securejoin.SecureJoin(user.Did, repoName) 549 err = s.enforcer.AddRepo(user.Did, domain, p) 550 if err != nil { 551 l.Error("acl setup failed", "err", err) 552 s.pages.Notice(w, "repo", "Failed to set up repository permissions.") 553 return 554 } 555 556 err = tx.Commit() 557 if err != nil { 558 l.Error("txn commit failed", "err", err) 559 http.Error(w, err.Error(), http.StatusInternalServerError) 560 return 561 } 562 563 err = s.enforcer.E.SavePolicy() 564 if err != nil { 565 l.Error("acl save failed", "err", err) 566 http.Error(w, err.Error(), http.StatusInternalServerError) 567 return 568 } 569 570 // reset the ATURI because the transaction completed successfully 571 aturi = "" 572 573 s.notifier.NewRepo(r.Context(), repo) 574 s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName)) 575 } 576} 577 578// this is used to rollback changes made to the PDS 579// 580// it is a no-op if the provided ATURI is empty 581func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error { 582 if aturi == "" { 583 return nil 584 } 585 586 parsed := syntax.ATURI(aturi) 587 588 collection := parsed.Collection().String() 589 repo := parsed.Authority().String() 590 rkey := parsed.RecordKey().String() 591 592 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{ 593 Collection: collection, 594 Repo: repo, 595 Rkey: rkey, 596 }) 597 return err 598} 599 600func BackfillDefaultDefs(e db.Execer, r *idresolver.Resolver) error { 601 defaults := models.DefaultLabelDefs() 602 defaultLabels, err := db.GetLabelDefinitions(e, db.FilterIn("at_uri", defaults)) 603 if err != nil { 604 return err 605 } 606 // already present 607 if len(defaultLabels) == len(defaults) { 608 return nil 609 } 610 611 labelDefs, err := models.FetchDefaultDefs(r) 612 if err != nil { 613 return err 614 } 615 616 // Insert each label definition to the database 617 for _, labelDef := range labelDefs { 618 _, err = db.AddLabelDefinition(e, &labelDef) 619 if err != nil { 620 return fmt.Errorf("failed to add label definition %s: %v", labelDef.Name, err) 621 } 622 } 623 624 return nil 625}