forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package state 2 3import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "log" 9 "log/slog" 10 "net/http" 11 "strings" 12 "time" 13 14 comatproto "github.com/bluesky-social/indigo/api/atproto" 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 lexutil "github.com/bluesky-social/indigo/lex/util" 17 securejoin "github.com/cyphar/filepath-securejoin" 18 "github.com/go-chi/chi/v5" 19 "github.com/posthog/posthog-go" 20 "tangled.org/core/api/tangled" 21 "tangled.org/core/appview" 22 "tangled.org/core/appview/cache" 23 "tangled.org/core/appview/cache/session" 24 "tangled.org/core/appview/config" 25 "tangled.org/core/appview/db" 26 "tangled.org/core/appview/models" 27 "tangled.org/core/appview/notify" 28 "tangled.org/core/appview/oauth" 29 "tangled.org/core/appview/pages" 30 posthogService "tangled.org/core/appview/posthog" 31 "tangled.org/core/appview/reporesolver" 32 "tangled.org/core/appview/validator" 33 xrpcclient "tangled.org/core/appview/xrpcclient" 34 "tangled.org/core/eventconsumer" 35 "tangled.org/core/idresolver" 36 "tangled.org/core/jetstream" 37 tlog "tangled.org/core/log" 38 "tangled.org/core/rbac" 39 "tangled.org/core/tid" 40) 41 42type State struct { 43 db *db.DB 44 notifier notify.Notifier 45 oauth *oauth.OAuth 46 enforcer *rbac.Enforcer 47 pages *pages.Pages 48 sess *session.SessionStore 49 idResolver *idresolver.Resolver 50 posthog posthog.Client 51 jc *jetstream.JetstreamClient 52 config *config.Config 53 repoResolver *reporesolver.RepoResolver 54 knotstream *eventconsumer.Consumer 55 spindlestream *eventconsumer.Consumer 56 logger *slog.Logger 57 validator *validator.Validator 58} 59 60func Make(ctx context.Context, config *config.Config) (*State, error) { 61 d, err := db.Make(config.Core.DbPath) 62 if err != nil { 63 return nil, fmt.Errorf("failed to create db: %w", err) 64 } 65 66 enforcer, err := rbac.NewEnforcer(config.Core.DbPath) 67 if err != nil { 68 return nil, fmt.Errorf("failed to create enforcer: %w", err) 69 } 70 71 res, err := idresolver.RedisResolver(config.Redis.ToURL()) 72 if err != nil { 73 log.Printf("failed to create redis resolver: %v", err) 74 res = idresolver.DefaultResolver() 75 } 76 77 pgs := pages.NewPages(config, res) 78 cache := cache.New(config.Redis.Addr) 79 sess := session.New(cache) 80 oauth := oauth.NewOAuth(config, sess) 81 validator := validator.New(d, res) 82 83 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint}) 84 if err != nil { 85 return nil, fmt.Errorf("failed to create posthog client: %w", err) 86 } 87 88 repoResolver := reporesolver.New(config, enforcer, res, d) 89 90 wrapper := db.DbWrapper{Execer: d} 91 jc, err := jetstream.NewJetstreamClient( 92 config.Jetstream.Endpoint, 93 "appview", 94 []string{ 95 tangled.GraphFollowNSID, 96 tangled.FeedStarNSID, 97 tangled.PublicKeyNSID, 98 tangled.RepoArtifactNSID, 99 tangled.ActorProfileNSID, 100 tangled.SpindleMemberNSID, 101 tangled.SpindleNSID, 102 tangled.StringNSID, 103 tangled.RepoIssueNSID, 104 tangled.RepoIssueCommentNSID, 105 tangled.LabelDefinitionNSID, 106 tangled.LabelOpNSID, 107 }, 108 nil, 109 slog.Default(), 110 wrapper, 111 false, 112 113 // in-memory filter is inapplicalble to appview so 114 // we'll never log dids anyway. 115 false, 116 ) 117 if err != nil { 118 return nil, fmt.Errorf("failed to create jetstream client: %w", err) 119 } 120 121 if err := db.BackfillDefaultDefs(d, res); err != nil { 122 return nil, fmt.Errorf("failed to backfill default label defs: %w", err) 123 } 124 125 ingester := appview.Ingester{ 126 Db: wrapper, 127 Enforcer: enforcer, 128 IdResolver: res, 129 Config: config, 130 Logger: tlog.New("ingester"), 131 Validator: validator, 132 } 133 err = jc.StartJetstream(ctx, ingester.Ingest()) 134 if err != nil { 135 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err) 136 } 137 138 knotstream, err := Knotstream(ctx, config, d, enforcer, posthog) 139 if err != nil { 140 return nil, fmt.Errorf("failed to start knotstream consumer: %w", err) 141 } 142 knotstream.Start(ctx) 143 144 spindlestream, err := Spindlestream(ctx, config, d, enforcer) 145 if err != nil { 146 return nil, fmt.Errorf("failed to start spindlestream consumer: %w", err) 147 } 148 spindlestream.Start(ctx) 149 150 var notifiers []notify.Notifier 151 if !config.Core.Dev { 152 notifiers = append(notifiers, posthogService.NewPosthogNotifier(posthog)) 153 } 154 notifier := notify.NewMergedNotifier(notifiers...) 155 156 state := &State{ 157 d, 158 notifier, 159 oauth, 160 enforcer, 161 pgs, 162 sess, 163 res, 164 posthog, 165 jc, 166 config, 167 repoResolver, 168 knotstream, 169 spindlestream, 170 slog.Default(), 171 validator, 172 } 173 174 return state, nil 175} 176 177func (s *State) Close() error { 178 // other close up logic goes here 179 return s.db.Close() 180} 181 182func (s *State) Favicon(w http.ResponseWriter, r *http.Request) { 183 w.Header().Set("Content-Type", "image/svg+xml") 184 w.Header().Set("Cache-Control", "public, max-age=31536000") // one year 185 w.Header().Set("ETag", `"favicon-svg-v1"`) 186 187 if match := r.Header.Get("If-None-Match"); match == `"favicon-svg-v1"` { 188 w.WriteHeader(http.StatusNotModified) 189 return 190 } 191 192 s.pages.Favicon(w) 193} 194 195func (s *State) TermsOfService(w http.ResponseWriter, r *http.Request) { 196 user := s.oauth.GetUser(r) 197 s.pages.TermsOfService(w, pages.TermsOfServiceParams{ 198 LoggedInUser: user, 199 }) 200} 201 202func (s *State) PrivacyPolicy(w http.ResponseWriter, r *http.Request) { 203 user := s.oauth.GetUser(r) 204 s.pages.PrivacyPolicy(w, pages.PrivacyPolicyParams{ 205 LoggedInUser: user, 206 }) 207} 208 209func (s *State) HomeOrTimeline(w http.ResponseWriter, r *http.Request) { 210 if s.oauth.GetUser(r) != nil { 211 s.Timeline(w, r) 212 return 213 } 214 s.Home(w, r) 215} 216 217func (s *State) Timeline(w http.ResponseWriter, r *http.Request) { 218 user := s.oauth.GetUser(r) 219 220 var userDid string 221 if user != nil { 222 userDid = user.Did 223 } 224 timeline, err := db.MakeTimeline(s.db, 50, userDid) 225 if err != nil { 226 log.Println(err) 227 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.") 228 } 229 230 repos, err := db.GetTopStarredReposLastWeek(s.db) 231 if err != nil { 232 log.Println(err) 233 s.pages.Notice(w, "topstarredrepos", "Unable to load.") 234 return 235 } 236 237 s.pages.Timeline(w, pages.TimelineParams{ 238 LoggedInUser: user, 239 Timeline: timeline, 240 Repos: repos, 241 }) 242} 243 244func (s *State) UpgradeBanner(w http.ResponseWriter, r *http.Request) { 245 user := s.oauth.GetUser(r) 246 l := s.logger.With("handler", "UpgradeBanner") 247 l = l.With("did", user.Did) 248 l = l.With("handle", user.Handle) 249 250 regs, err := db.GetRegistrations( 251 s.db, 252 db.FilterEq("did", user.Did), 253 db.FilterEq("needs_upgrade", 1), 254 ) 255 if err != nil { 256 l.Error("non-fatal: failed to get registrations", "err", err) 257 } 258 259 spindles, err := db.GetSpindles( 260 s.db, 261 db.FilterEq("owner", user.Did), 262 db.FilterEq("needs_upgrade", 1), 263 ) 264 if err != nil { 265 l.Error("non-fatal: failed to get spindles", "err", err) 266 } 267 268 if regs == nil && spindles == nil { 269 return 270 } 271 272 s.pages.UpgradeBanner(w, pages.UpgradeBannerParams{ 273 Registrations: regs, 274 Spindles: spindles, 275 }) 276} 277 278func (s *State) Home(w http.ResponseWriter, r *http.Request) { 279 timeline, err := db.MakeTimeline(s.db, 5, "") 280 if err != nil { 281 log.Println(err) 282 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.") 283 return 284 } 285 286 repos, err := db.GetTopStarredReposLastWeek(s.db) 287 if err != nil { 288 log.Println(err) 289 s.pages.Notice(w, "topstarredrepos", "Unable to load.") 290 return 291 } 292 293 s.pages.Home(w, pages.TimelineParams{ 294 LoggedInUser: nil, 295 Timeline: timeline, 296 Repos: repos, 297 }) 298} 299 300func (s *State) Keys(w http.ResponseWriter, r *http.Request) { 301 user := chi.URLParam(r, "user") 302 user = strings.TrimPrefix(user, "@") 303 304 if user == "" { 305 w.WriteHeader(http.StatusBadRequest) 306 return 307 } 308 309 id, err := s.idResolver.ResolveIdent(r.Context(), user) 310 if err != nil { 311 w.WriteHeader(http.StatusInternalServerError) 312 return 313 } 314 315 pubKeys, err := db.GetPublicKeysForDid(s.db, id.DID.String()) 316 if err != nil { 317 w.WriteHeader(http.StatusNotFound) 318 return 319 } 320 321 if len(pubKeys) == 0 { 322 w.WriteHeader(http.StatusNotFound) 323 return 324 } 325 326 for _, k := range pubKeys { 327 key := strings.TrimRight(k.Key, "\n") 328 fmt.Fprintln(w, key) 329 } 330} 331 332func validateRepoName(name string) error { 333 // check for path traversal attempts 334 if name == "." || name == ".." || 335 strings.Contains(name, "/") || strings.Contains(name, "\\") { 336 return fmt.Errorf("Repository name contains invalid path characters") 337 } 338 339 // check for sequences that could be used for traversal when normalized 340 if strings.Contains(name, "./") || strings.Contains(name, "../") || 341 strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { 342 return fmt.Errorf("Repository name contains invalid path sequence") 343 } 344 345 // then continue with character validation 346 for _, char := range name { 347 if !((char >= 'a' && char <= 'z') || 348 (char >= 'A' && char <= 'Z') || 349 (char >= '0' && char <= '9') || 350 char == '-' || char == '_' || char == '.') { 351 return fmt.Errorf("Repository name can only contain alphanumeric characters, periods, hyphens, and underscores") 352 } 353 } 354 355 // additional check to prevent multiple sequential dots 356 if strings.Contains(name, "..") { 357 return fmt.Errorf("Repository name cannot contain sequential dots") 358 } 359 360 // if all checks pass 361 return nil 362} 363 364func stripGitExt(name string) string { 365 return strings.TrimSuffix(name, ".git") 366} 367 368func (s *State) NewRepo(w http.ResponseWriter, r *http.Request) { 369 switch r.Method { 370 case http.MethodGet: 371 user := s.oauth.GetUser(r) 372 knots, err := s.enforcer.GetKnotsForUser(user.Did) 373 if err != nil { 374 s.pages.Notice(w, "repo", "Invalid user account.") 375 return 376 } 377 378 s.pages.NewRepo(w, pages.NewRepoParams{ 379 LoggedInUser: user, 380 Knots: knots, 381 }) 382 383 case http.MethodPost: 384 l := s.logger.With("handler", "NewRepo") 385 386 user := s.oauth.GetUser(r) 387 l = l.With("did", user.Did) 388 l = l.With("handle", user.Handle) 389 390 // form validation 391 domain := r.FormValue("domain") 392 if domain == "" { 393 s.pages.Notice(w, "repo", "Invalid form submission&mdash;missing knot domain.") 394 return 395 } 396 l = l.With("knot", domain) 397 398 repoName := r.FormValue("name") 399 if repoName == "" { 400 s.pages.Notice(w, "repo", "Repository name cannot be empty.") 401 return 402 } 403 404 if err := validateRepoName(repoName); err != nil { 405 s.pages.Notice(w, "repo", err.Error()) 406 return 407 } 408 repoName = stripGitExt(repoName) 409 l = l.With("repoName", repoName) 410 411 defaultBranch := r.FormValue("branch") 412 if defaultBranch == "" { 413 defaultBranch = "main" 414 } 415 l = l.With("defaultBranch", defaultBranch) 416 417 description := r.FormValue("description") 418 419 // ACL validation 420 ok, err := s.enforcer.E.Enforce(user.Did, domain, domain, "repo:create") 421 if err != nil || !ok { 422 l.Info("unauthorized") 423 s.pages.Notice(w, "repo", "You do not have permission to create a repo in this knot.") 424 return 425 } 426 427 // Check for existing repos 428 existingRepo, err := db.GetRepo( 429 s.db, 430 db.FilterEq("did", user.Did), 431 db.FilterEq("name", repoName), 432 ) 433 if err == nil && existingRepo != nil { 434 l.Info("repo exists") 435 s.pages.Notice(w, "repo", fmt.Sprintf("You already have a repository by this name on %s", existingRepo.Knot)) 436 return 437 } 438 439 // create atproto record for this repo 440 rkey := tid.TID() 441 repo := &models.Repo{ 442 Did: user.Did, 443 Name: repoName, 444 Knot: domain, 445 Rkey: rkey, 446 Description: description, 447 Created: time.Now(), 448 Labels: models.DefaultLabelDefs(), 449 } 450 record := repo.AsRecord() 451 452 xrpcClient, err := s.oauth.AuthorizedClient(r) 453 if err != nil { 454 l.Info("PDS write failed", "err", err) 455 s.pages.Notice(w, "repo", "Failed to write record to PDS.") 456 return 457 } 458 459 atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{ 460 Collection: tangled.RepoNSID, 461 Repo: user.Did, 462 Rkey: rkey, 463 Record: &lexutil.LexiconTypeDecoder{ 464 Val: &record, 465 }, 466 }) 467 if err != nil { 468 l.Info("PDS write failed", "err", err) 469 s.pages.Notice(w, "repo", "Failed to announce repository creation.") 470 return 471 } 472 473 aturi := atresp.Uri 474 l = l.With("aturi", aturi) 475 l.Info("wrote to PDS") 476 477 tx, err := s.db.BeginTx(r.Context(), nil) 478 if err != nil { 479 l.Info("txn failed", "err", err) 480 s.pages.Notice(w, "repo", "Failed to save repository information.") 481 return 482 } 483 484 // The rollback function reverts a few things on failure: 485 // - the pending txn 486 // - the ACLs 487 // - the atproto record created 488 rollback := func() { 489 err1 := tx.Rollback() 490 err2 := s.enforcer.E.LoadPolicy() 491 err3 := rollbackRecord(context.Background(), aturi, xrpcClient) 492 493 // ignore txn complete errors, this is okay 494 if errors.Is(err1, sql.ErrTxDone) { 495 err1 = nil 496 } 497 498 if errs := errors.Join(err1, err2, err3); errs != nil { 499 l.Error("failed to rollback changes", "errs", errs) 500 return 501 } 502 } 503 defer rollback() 504 505 client, err := s.oauth.ServiceClient( 506 r, 507 oauth.WithService(domain), 508 oauth.WithLxm(tangled.RepoCreateNSID), 509 oauth.WithDev(s.config.Core.Dev), 510 ) 511 if err != nil { 512 l.Error("service auth failed", "err", err) 513 s.pages.Notice(w, "repo", "Failed to reach PDS.") 514 return 515 } 516 517 xe := tangled.RepoCreate( 518 r.Context(), 519 client, 520 &tangled.RepoCreate_Input{ 521 Rkey: rkey, 522 }, 523 ) 524 if err := xrpcclient.HandleXrpcErr(xe); err != nil { 525 l.Error("xrpc error", "xe", xe) 526 s.pages.Notice(w, "repo", err.Error()) 527 return 528 } 529 530 err = db.AddRepo(tx, repo) 531 if err != nil { 532 l.Error("db write failed", "err", err) 533 s.pages.Notice(w, "repo", "Failed to save repository information.") 534 return 535 } 536 537 // acls 538 p, _ := securejoin.SecureJoin(user.Did, repoName) 539 err = s.enforcer.AddRepo(user.Did, domain, p) 540 if err != nil { 541 l.Error("acl setup failed", "err", err) 542 s.pages.Notice(w, "repo", "Failed to set up repository permissions.") 543 return 544 } 545 546 err = tx.Commit() 547 if err != nil { 548 l.Error("txn commit failed", "err", err) 549 http.Error(w, err.Error(), http.StatusInternalServerError) 550 return 551 } 552 553 err = s.enforcer.E.SavePolicy() 554 if err != nil { 555 l.Error("acl save failed", "err", err) 556 http.Error(w, err.Error(), http.StatusInternalServerError) 557 return 558 } 559 560 // reset the ATURI because the transaction completed successfully 561 aturi = "" 562 563 s.notifier.NewRepo(r.Context(), repo) 564 s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName)) 565 } 566} 567 568// this is used to rollback changes made to the PDS 569// 570// it is a no-op if the provided ATURI is empty 571func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error { 572 if aturi == "" { 573 return nil 574 } 575 576 parsed := syntax.ATURI(aturi) 577 578 collection := parsed.Collection().String() 579 repo := parsed.Authority().String() 580 rkey := parsed.RecordKey().String() 581 582 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{ 583 Collection: collection, 584 Repo: repo, 585 Rkey: rkey, 586 }) 587 return err 588} 589 590func BackfillDefaultDefs(e db.Execer, r *idresolver.Resolver) error { 591 defaults := models.DefaultLabelDefs() 592 defaultLabels, err := db.GetLabelDefinitions(e, db.FilterIn("at_uri", defaults)) 593 if err != nil { 594 return err 595 } 596 // already present 597 if len(defaultLabels) == len(defaults) { 598 return nil 599 } 600 601 labelDefs, err := models.FetchDefaultDefs(r) 602 if err != nil { 603 return err 604 } 605 606 // Insert each label definition to the database 607 for _, labelDef := range labelDefs { 608 _, err = db.AddLabelDefinition(e, &labelDef) 609 if err != nil { 610 return fmt.Errorf("failed to add label definition %s: %v", labelDef.Name, err) 611 } 612 } 613 614 return nil 615}