forked from tangled.org/core
this repo has no description
1package state 2 3import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "log" 9 "log/slog" 10 "net/http" 11 "strings" 12 "time" 13 14 comatproto "github.com/bluesky-social/indigo/api/atproto" 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 lexutil "github.com/bluesky-social/indigo/lex/util" 17 securejoin "github.com/cyphar/filepath-securejoin" 18 "github.com/go-chi/chi/v5" 19 "github.com/posthog/posthog-go" 20 "tangled.org/core/api/tangled" 21 "tangled.org/core/appview" 22 "tangled.org/core/appview/cache" 23 "tangled.org/core/appview/cache/session" 24 "tangled.org/core/appview/config" 25 "tangled.org/core/appview/db" 26 "tangled.org/core/appview/models" 27 "tangled.org/core/appview/notify" 28 dbnotify "tangled.org/core/appview/notify/db" 29 phnotify "tangled.org/core/appview/notify/posthog" 30 "tangled.org/core/appview/oauth" 31 "tangled.org/core/appview/pages" 32 "tangled.org/core/appview/reporesolver" 33 "tangled.org/core/appview/validator" 34 xrpcclient "tangled.org/core/appview/xrpcclient" 35 "tangled.org/core/eventconsumer" 36 "tangled.org/core/idresolver" 37 "tangled.org/core/jetstream" 38 tlog "tangled.org/core/log" 39 "tangled.org/core/rbac" 40 "tangled.org/core/tid" 41) 42 43type State struct { 44 db *db.DB 45 notifier notify.Notifier 46 oauth *oauth.OAuth 47 enforcer *rbac.Enforcer 48 pages *pages.Pages 49 sess *session.SessionStore 50 idResolver *idresolver.Resolver 51 posthog posthog.Client 52 jc *jetstream.JetstreamClient 53 config *config.Config 54 repoResolver *reporesolver.RepoResolver 55 knotstream *eventconsumer.Consumer 56 spindlestream *eventconsumer.Consumer 57 logger *slog.Logger 58 validator *validator.Validator 59} 60 61func Make(ctx context.Context, config *config.Config) (*State, error) { 62 d, err := db.Make(config.Core.DbPath) 63 if err != nil { 64 return nil, fmt.Errorf("failed to create db: %w", err) 65 } 66 67 enforcer, err := rbac.NewEnforcer(config.Core.DbPath) 68 if err != nil { 69 return nil, fmt.Errorf("failed to create enforcer: %w", err) 70 } 71 72 res, err := idresolver.RedisResolver(config.Redis.ToURL()) 73 if err != nil { 74 log.Printf("failed to create redis resolver: %v", err) 75 res = idresolver.DefaultResolver() 76 } 77 78 pgs := pages.NewPages(config, res) 79 cache := cache.New(config.Redis.Addr) 80 sess := session.New(cache) 81 oauth := oauth.NewOAuth(config, sess) 82 validator := validator.New(d, res, enforcer) 83 84 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint}) 85 if err != nil { 86 return nil, fmt.Errorf("failed to create posthog client: %w", err) 87 } 88 89 repoResolver := reporesolver.New(config, enforcer, res, d) 90 91 wrapper := db.DbWrapper{Execer: d} 92 jc, err := jetstream.NewJetstreamClient( 93 config.Jetstream.Endpoint, 94 "appview", 95 []string{ 96 tangled.GraphFollowNSID, 97 tangled.FeedStarNSID, 98 tangled.PublicKeyNSID, 99 tangled.RepoArtifactNSID, 100 tangled.ActorProfileNSID, 101 tangled.SpindleMemberNSID, 102 tangled.SpindleNSID, 103 tangled.StringNSID, 104 tangled.RepoIssueNSID, 105 tangled.RepoIssueCommentNSID, 106 tangled.LabelDefinitionNSID, 107 tangled.LabelOpNSID, 108 }, 109 nil, 110 slog.Default(), 111 wrapper, 112 false, 113 114 // in-memory filter is inapplicalble to appview so 115 // we'll never log dids anyway. 116 false, 117 ) 118 if err != nil { 119 return nil, fmt.Errorf("failed to create jetstream client: %w", err) 120 } 121 122 if err := BackfillDefaultDefs(d, res); err != nil { 123 return nil, fmt.Errorf("failed to backfill default label defs: %w", err) 124 } 125 126 ingester := appview.Ingester{ 127 Db: wrapper, 128 Enforcer: enforcer, 129 IdResolver: res, 130 Config: config, 131 Logger: tlog.New("ingester"), 132 Validator: validator, 133 } 134 err = jc.StartJetstream(ctx, ingester.Ingest()) 135 if err != nil { 136 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err) 137 } 138 139 knotstream, err := Knotstream(ctx, config, d, enforcer, posthog) 140 if err != nil { 141 return nil, fmt.Errorf("failed to start knotstream consumer: %w", err) 142 } 143 knotstream.Start(ctx) 144 145 spindlestream, err := Spindlestream(ctx, config, d, enforcer) 146 if err != nil { 147 return nil, fmt.Errorf("failed to start spindlestream consumer: %w", err) 148 } 149 spindlestream.Start(ctx) 150 151 var notifiers []notify.Notifier 152 153 // Always add the database notifier 154 notifiers = append(notifiers, dbnotify.NewDatabaseNotifier(d, res)) 155 156 // Add other notifiers in production only 157 if !config.Core.Dev { 158 notifiers = append(notifiers, phnotify.NewPosthogNotifier(posthog)) 159 } 160 notifier := notify.NewMergedNotifier(notifiers...) 161 162 state := &State{ 163 d, 164 notifier, 165 oauth, 166 enforcer, 167 pgs, 168 sess, 169 res, 170 posthog, 171 jc, 172 config, 173 repoResolver, 174 knotstream, 175 spindlestream, 176 slog.Default(), 177 validator, 178 } 179 180 return state, nil 181} 182 183func (s *State) Close() error { 184 // other close up logic goes here 185 return s.db.Close() 186} 187 188func (s *State) Favicon(w http.ResponseWriter, r *http.Request) { 189 w.Header().Set("Content-Type", "image/svg+xml") 190 w.Header().Set("Cache-Control", "public, max-age=31536000") // one year 191 w.Header().Set("ETag", `"favicon-svg-v1"`) 192 193 if match := r.Header.Get("If-None-Match"); match == `"favicon-svg-v1"` { 194 w.WriteHeader(http.StatusNotModified) 195 return 196 } 197 198 s.pages.Favicon(w) 199} 200 201func (s *State) TermsOfService(w http.ResponseWriter, r *http.Request) { 202 user := s.oauth.GetUser(r) 203 s.pages.TermsOfService(w, pages.TermsOfServiceParams{ 204 LoggedInUser: user, 205 }) 206} 207 208func (s *State) PrivacyPolicy(w http.ResponseWriter, r *http.Request) { 209 user := s.oauth.GetUser(r) 210 s.pages.PrivacyPolicy(w, pages.PrivacyPolicyParams{ 211 LoggedInUser: user, 212 }) 213} 214 215func (s *State) Brand(w http.ResponseWriter, r *http.Request) { 216 user := s.oauth.GetUser(r) 217 s.pages.Brand(w, pages.BrandParams{ 218 LoggedInUser: user, 219 }) 220} 221 222func (s *State) HomeOrTimeline(w http.ResponseWriter, r *http.Request) { 223 if s.oauth.GetUser(r) != nil { 224 s.Timeline(w, r) 225 return 226 } 227 s.Home(w, r) 228} 229 230func (s *State) Timeline(w http.ResponseWriter, r *http.Request) { 231 user := s.oauth.GetUser(r) 232 233 var userDid string 234 if user != nil { 235 userDid = user.Did 236 } 237 timeline, err := db.MakeTimeline(s.db, 50, userDid) 238 if err != nil { 239 log.Println(err) 240 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.") 241 } 242 243 repos, err := db.GetTopStarredReposLastWeek(s.db) 244 if err != nil { 245 log.Println(err) 246 s.pages.Notice(w, "topstarredrepos", "Unable to load.") 247 return 248 } 249 250 s.pages.Timeline(w, pages.TimelineParams{ 251 LoggedInUser: user, 252 Timeline: timeline, 253 Repos: repos, 254 }) 255} 256 257func (s *State) UpgradeBanner(w http.ResponseWriter, r *http.Request) { 258 user := s.oauth.GetUser(r) 259 if user == nil { 260 return 261 } 262 263 l := s.logger.With("handler", "UpgradeBanner") 264 l = l.With("did", user.Did) 265 l = l.With("handle", user.Handle) 266 267 regs, err := db.GetRegistrations( 268 s.db, 269 db.FilterEq("did", user.Did), 270 db.FilterEq("needs_upgrade", 1), 271 ) 272 if err != nil { 273 l.Error("non-fatal: failed to get registrations", "err", err) 274 } 275 276 spindles, err := db.GetSpindles( 277 s.db, 278 db.FilterEq("owner", user.Did), 279 db.FilterEq("needs_upgrade", 1), 280 ) 281 if err != nil { 282 l.Error("non-fatal: failed to get spindles", "err", err) 283 } 284 285 if regs == nil && spindles == nil { 286 return 287 } 288 289 s.pages.UpgradeBanner(w, pages.UpgradeBannerParams{ 290 Registrations: regs, 291 Spindles: spindles, 292 }) 293} 294 295func (s *State) Home(w http.ResponseWriter, r *http.Request) { 296 timeline, err := db.MakeTimeline(s.db, 5, "") 297 if err != nil { 298 log.Println(err) 299 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.") 300 return 301 } 302 303 repos, err := db.GetTopStarredReposLastWeek(s.db) 304 if err != nil { 305 log.Println(err) 306 s.pages.Notice(w, "topstarredrepos", "Unable to load.") 307 return 308 } 309 310 s.pages.Home(w, pages.TimelineParams{ 311 LoggedInUser: nil, 312 Timeline: timeline, 313 Repos: repos, 314 }) 315} 316 317func (s *State) Keys(w http.ResponseWriter, r *http.Request) { 318 user := chi.URLParam(r, "user") 319 user = strings.TrimPrefix(user, "@") 320 321 if user == "" { 322 w.WriteHeader(http.StatusBadRequest) 323 return 324 } 325 326 id, err := s.idResolver.ResolveIdent(r.Context(), user) 327 if err != nil { 328 w.WriteHeader(http.StatusInternalServerError) 329 return 330 } 331 332 pubKeys, err := db.GetPublicKeysForDid(s.db, id.DID.String()) 333 if err != nil { 334 w.WriteHeader(http.StatusNotFound) 335 return 336 } 337 338 if len(pubKeys) == 0 { 339 w.WriteHeader(http.StatusNotFound) 340 return 341 } 342 343 for _, k := range pubKeys { 344 key := strings.TrimRight(k.Key, "\n") 345 fmt.Fprintln(w, key) 346 } 347} 348 349func validateRepoName(name string) error { 350 // check for path traversal attempts 351 if name == "." || name == ".." || 352 strings.Contains(name, "/") || strings.Contains(name, "\\") { 353 return fmt.Errorf("Repository name contains invalid path characters") 354 } 355 356 // check for sequences that could be used for traversal when normalized 357 if strings.Contains(name, "./") || strings.Contains(name, "../") || 358 strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { 359 return fmt.Errorf("Repository name contains invalid path sequence") 360 } 361 362 // then continue with character validation 363 for _, char := range name { 364 if !((char >= 'a' && char <= 'z') || 365 (char >= 'A' && char <= 'Z') || 366 (char >= '0' && char <= '9') || 367 char == '-' || char == '_' || char == '.') { 368 return fmt.Errorf("Repository name can only contain alphanumeric characters, periods, hyphens, and underscores") 369 } 370 } 371 372 // additional check to prevent multiple sequential dots 373 if strings.Contains(name, "..") { 374 return fmt.Errorf("Repository name cannot contain sequential dots") 375 } 376 377 // if all checks pass 378 return nil 379} 380 381func stripGitExt(name string) string { 382 return strings.TrimSuffix(name, ".git") 383} 384 385func (s *State) NewRepo(w http.ResponseWriter, r *http.Request) { 386 switch r.Method { 387 case http.MethodGet: 388 user := s.oauth.GetUser(r) 389 knots, err := s.enforcer.GetKnotsForUser(user.Did) 390 if err != nil { 391 s.pages.Notice(w, "repo", "Invalid user account.") 392 return 393 } 394 395 s.pages.NewRepo(w, pages.NewRepoParams{ 396 LoggedInUser: user, 397 Knots: knots, 398 }) 399 400 case http.MethodPost: 401 l := s.logger.With("handler", "NewRepo") 402 403 user := s.oauth.GetUser(r) 404 l = l.With("did", user.Did) 405 l = l.With("handle", user.Handle) 406 407 // form validation 408 domain := r.FormValue("domain") 409 if domain == "" { 410 s.pages.Notice(w, "repo", "Invalid form submission&mdash;missing knot domain.") 411 return 412 } 413 l = l.With("knot", domain) 414 415 repoName := r.FormValue("name") 416 if repoName == "" { 417 s.pages.Notice(w, "repo", "Repository name cannot be empty.") 418 return 419 } 420 421 if err := validateRepoName(repoName); err != nil { 422 s.pages.Notice(w, "repo", err.Error()) 423 return 424 } 425 repoName = stripGitExt(repoName) 426 l = l.With("repoName", repoName) 427 428 defaultBranch := r.FormValue("branch") 429 if defaultBranch == "" { 430 defaultBranch = "main" 431 } 432 l = l.With("defaultBranch", defaultBranch) 433 434 description := r.FormValue("description") 435 436 // ACL validation 437 ok, err := s.enforcer.E.Enforce(user.Did, domain, domain, "repo:create") 438 if err != nil || !ok { 439 l.Info("unauthorized") 440 s.pages.Notice(w, "repo", "You do not have permission to create a repo in this knot.") 441 return 442 } 443 444 // Check for existing repos 445 existingRepo, err := db.GetRepo( 446 s.db, 447 db.FilterEq("did", user.Did), 448 db.FilterEq("name", repoName), 449 ) 450 if err == nil && existingRepo != nil { 451 l.Info("repo exists") 452 s.pages.Notice(w, "repo", fmt.Sprintf("You already have a repository by this name on %s", existingRepo.Knot)) 453 return 454 } 455 456 // create atproto record for this repo 457 rkey := tid.TID() 458 repo := &models.Repo{ 459 Did: user.Did, 460 Name: repoName, 461 Knot: domain, 462 Rkey: rkey, 463 Description: description, 464 Created: time.Now(), 465 Labels: models.DefaultLabelDefs(), 466 } 467 record := repo.AsRecord() 468 469 xrpcClient, err := s.oauth.AuthorizedClient(r) 470 if err != nil { 471 l.Info("PDS write failed", "err", err) 472 s.pages.Notice(w, "repo", "Failed to write record to PDS.") 473 return 474 } 475 476 atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 477 Collection: tangled.RepoNSID, 478 Repo: user.Did, 479 Rkey: rkey, 480 Record: &lexutil.LexiconTypeDecoder{ 481 Val: &record, 482 }, 483 }) 484 if err != nil { 485 l.Info("PDS write failed", "err", err) 486 s.pages.Notice(w, "repo", "Failed to announce repository creation.") 487 return 488 } 489 490 aturi := atresp.Uri 491 l = l.With("aturi", aturi) 492 l.Info("wrote to PDS") 493 494 tx, err := s.db.BeginTx(r.Context(), nil) 495 if err != nil { 496 l.Info("txn failed", "err", err) 497 s.pages.Notice(w, "repo", "Failed to save repository information.") 498 return 499 } 500 501 // The rollback function reverts a few things on failure: 502 // - the pending txn 503 // - the ACLs 504 // - the atproto record created 505 rollback := func() { 506 err1 := tx.Rollback() 507 err2 := s.enforcer.E.LoadPolicy() 508 err3 := rollbackRecord(context.Background(), aturi, xrpcClient) 509 510 // ignore txn complete errors, this is okay 511 if errors.Is(err1, sql.ErrTxDone) { 512 err1 = nil 513 } 514 515 if errs := errors.Join(err1, err2, err3); errs != nil { 516 l.Error("failed to rollback changes", "errs", errs) 517 return 518 } 519 } 520 defer rollback() 521 522 client, err := s.oauth.ServiceClient( 523 r, 524 oauth.WithService(domain), 525 oauth.WithLxm(tangled.RepoCreateNSID), 526 oauth.WithDev(s.config.Core.Dev), 527 ) 528 if err != nil { 529 l.Error("service auth failed", "err", err) 530 s.pages.Notice(w, "repo", "Failed to reach PDS.") 531 return 532 } 533 534 xe := tangled.RepoCreate( 535 r.Context(), 536 client, 537 &tangled.RepoCreate_Input{ 538 Rkey: rkey, 539 }, 540 ) 541 if err := xrpcclient.HandleXrpcErr(xe); err != nil { 542 l.Error("xrpc error", "xe", xe) 543 s.pages.Notice(w, "repo", err.Error()) 544 return 545 } 546 547 err = db.AddRepo(tx, repo) 548 if err != nil { 549 l.Error("db write failed", "err", err) 550 s.pages.Notice(w, "repo", "Failed to save repository information.") 551 return 552 } 553 554 // acls 555 p, _ := securejoin.SecureJoin(user.Did, repoName) 556 err = s.enforcer.AddRepo(user.Did, domain, p) 557 if err != nil { 558 l.Error("acl setup failed", "err", err) 559 s.pages.Notice(w, "repo", "Failed to set up repository permissions.") 560 return 561 } 562 563 err = tx.Commit() 564 if err != nil { 565 l.Error("txn commit failed", "err", err) 566 http.Error(w, err.Error(), http.StatusInternalServerError) 567 return 568 } 569 570 err = s.enforcer.E.SavePolicy() 571 if err != nil { 572 l.Error("acl save failed", "err", err) 573 http.Error(w, err.Error(), http.StatusInternalServerError) 574 return 575 } 576 577 // reset the ATURI because the transaction completed successfully 578 aturi = "" 579 580 s.notifier.NewRepo(r.Context(), repo) 581 s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName)) 582 } 583} 584 585// this is used to rollback changes made to the PDS 586// 587// it is a no-op if the provided ATURI is empty 588func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error { 589 if aturi == "" { 590 return nil 591 } 592 593 parsed := syntax.ATURI(aturi) 594 595 collection := parsed.Collection().String() 596 repo := parsed.Authority().String() 597 rkey := parsed.RecordKey().String() 598 599 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{ 600 Collection: collection, 601 Repo: repo, 602 Rkey: rkey, 603 }) 604 return err 605} 606 607func BackfillDefaultDefs(e db.Execer, r *idresolver.Resolver) error { 608 defaults := models.DefaultLabelDefs() 609 defaultLabels, err := db.GetLabelDefinitions(e, db.FilterIn("at_uri", defaults)) 610 if err != nil { 611 return err 612 } 613 // already present 614 if len(defaultLabels) == len(defaults) { 615 return nil 616 } 617 618 labelDefs, err := models.FetchDefaultDefs(r) 619 if err != nil { 620 return err 621 } 622 623 // Insert each label definition to the database 624 for _, labelDef := range labelDefs { 625 _, err = db.AddLabelDefinition(e, &labelDef) 626 if err != nil { 627 return fmt.Errorf("failed to add label definition %s: %v", labelDef.Name, err) 628 } 629 } 630 631 return nil 632}