forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16) 17 18func NewPull(tx *sql.Tx, pull *models.Pull) error { 19 _, err := tx.Exec(` 20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 21 values (?, 1) 22 `, pull.RepoAt) 23 if err != nil { 24 return err 25 } 26 27 var nextId int 28 err = tx.QueryRow(` 29 update repo_pull_seqs 30 set next_pull_id = next_pull_id + 1 31 where repo_at = ? 32 returning next_pull_id - 1 33 `, pull.RepoAt).Scan(&nextId) 34 if err != nil { 35 return err 36 } 37 38 pull.PullId = nextId 39 pull.State = models.PullOpen 40 41 var sourceBranch, sourceRepoAt *string 42 if pull.PullSource != nil { 43 sourceBranch = &pull.PullSource.Branch 44 if pull.PullSource.RepoAt != nil { 45 x := pull.PullSource.RepoAt.String() 46 sourceRepoAt = &x 47 } 48 } 49 50 var stackId, changeId, parentChangeId *string 51 if pull.StackId != "" { 52 stackId = &pull.StackId 53 } 54 if pull.ChangeId != "" { 55 changeId = &pull.ChangeId 56 } 57 if pull.ParentChangeId != "" { 58 parentChangeId = &pull.ParentChangeId 59 } 60 61 result, err := tx.Exec( 62 ` 63 insert into pulls ( 64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 65 ) 66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 67 pull.RepoAt, 68 pull.OwnerDid, 69 pull.PullId, 70 pull.Title, 71 pull.TargetBranch, 72 pull.Body, 73 pull.Rkey, 74 pull.State, 75 sourceBranch, 76 sourceRepoAt, 77 stackId, 78 changeId, 79 parentChangeId, 80 ) 81 if err != nil { 82 return err 83 } 84 85 // Set the database primary key ID 86 id, err := result.LastInsertId() 87 if err != nil { 88 return err 89 } 90 pull.ID = int(id) 91 92 _, err = tx.Exec(` 93 insert into pull_submissions (pull_at, round_number, patch, source_rev) 94 values (?, ?, ?, ?) 95 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 96 return err 97} 98 99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 100 pull, err := GetPull(e, repoAt, pullId) 101 if err != nil { 102 return "", err 103 } 104 return pull.PullAt(), err 105} 106 107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 108 var pullId int 109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 110 return pullId - 1, err 111} 112 113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 114 pulls := make(map[syntax.ATURI]*models.Pull) 115 116 var conditions []string 117 var args []any 118 for _, filter := range filters { 119 conditions = append(conditions, filter.Condition()) 120 args = append(args, filter.Arg()...) 121 } 122 123 whereClause := "" 124 if conditions != nil { 125 whereClause = " where " + strings.Join(conditions, " and ") 126 } 127 limitClause := "" 128 if limit != 0 { 129 limitClause = fmt.Sprintf(" limit %d ", limit) 130 } 131 132 query := fmt.Sprintf(` 133 select 134 id, 135 owner_did, 136 repo_at, 137 pull_id, 138 created, 139 title, 140 state, 141 target_branch, 142 body, 143 rkey, 144 source_branch, 145 source_repo_at, 146 stack_id, 147 change_id, 148 parent_change_id 149 from 150 pulls 151 %s 152 order by 153 created desc 154 %s 155 `, whereClause, limitClause) 156 157 rows, err := e.Query(query, args...) 158 if err != nil { 159 return nil, err 160 } 161 defer rows.Close() 162 163 for rows.Next() { 164 var pull models.Pull 165 var createdAt string 166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 167 err := rows.Scan( 168 &pull.ID, 169 &pull.OwnerDid, 170 &pull.RepoAt, 171 &pull.PullId, 172 &createdAt, 173 &pull.Title, 174 &pull.State, 175 &pull.TargetBranch, 176 &pull.Body, 177 &pull.Rkey, 178 &sourceBranch, 179 &sourceRepoAt, 180 &stackId, 181 &changeId, 182 &parentChangeId, 183 ) 184 if err != nil { 185 return nil, err 186 } 187 188 createdTime, err := time.Parse(time.RFC3339, createdAt) 189 if err != nil { 190 return nil, err 191 } 192 pull.Created = createdTime 193 194 if sourceBranch.Valid { 195 pull.PullSource = &models.PullSource{ 196 Branch: sourceBranch.String, 197 } 198 if sourceRepoAt.Valid { 199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 200 if err != nil { 201 return nil, err 202 } 203 pull.PullSource.RepoAt = &sourceRepoAtParsed 204 } 205 } 206 207 if stackId.Valid { 208 pull.StackId = stackId.String 209 } 210 if changeId.Valid { 211 pull.ChangeId = changeId.String 212 } 213 if parentChangeId.Valid { 214 pull.ParentChangeId = parentChangeId.String 215 } 216 217 pulls[pull.PullAt()] = &pull 218 } 219 220 var pullAts []syntax.ATURI 221 for _, p := range pulls { 222 pullAts = append(pullAts, p.PullAt()) 223 } 224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) 225 if err != nil { 226 return nil, fmt.Errorf("failed to get submissions: %w", err) 227 } 228 229 for pullAt, submissions := range submissionsMap { 230 if p, ok := pulls[pullAt]; ok { 231 p.Submissions = submissions 232 } 233 } 234 235 // collect allLabels for each issue 236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts)) 237 if err != nil { 238 return nil, fmt.Errorf("failed to query labels: %w", err) 239 } 240 for pullAt, labels := range allLabels { 241 if p, ok := pulls[pullAt]; ok { 242 p.Labels = labels 243 } 244 } 245 246 // collect pull source for all pulls that need it 247 var sourceAts []syntax.ATURI 248 for _, p := range pulls { 249 if p.PullSource != nil && p.PullSource.RepoAt != nil { 250 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 251 } 252 } 253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts)) 254 if err != nil && !errors.Is(err, sql.ErrNoRows) { 255 return nil, fmt.Errorf("failed to get source repos: %w", err) 256 } 257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 258 for _, r := range sourceRepos { 259 sourceRepoMap[r.RepoAt()] = &r 260 } 261 for _, p := range pulls { 262 if p.PullSource != nil && p.PullSource.RepoAt != nil { 263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 264 p.PullSource.Repo = sourceRepo 265 } 266 } 267 } 268 269 orderedByPullId := []*models.Pull{} 270 for _, p := range pulls { 271 orderedByPullId = append(orderedByPullId, p) 272 } 273 sort.Slice(orderedByPullId, func(i, j int) bool { 274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 275 }) 276 277 return orderedByPullId, nil 278} 279 280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 281 return GetPullsWithLimit(e, 0, filters...) 282} 283 284func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 285 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) 286 if err != nil { 287 return nil, err 288 } 289 if pulls == nil { 290 return nil, sql.ErrNoRows 291 } 292 293 return pulls[0], nil 294} 295 296// mapping from pull -> pull submissions 297func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 298 var conditions []string 299 var args []any 300 for _, filter := range filters { 301 conditions = append(conditions, filter.Condition()) 302 args = append(args, filter.Arg()...) 303 } 304 305 whereClause := "" 306 if conditions != nil { 307 whereClause = " where " + strings.Join(conditions, " and ") 308 } 309 310 query := fmt.Sprintf(` 311 select 312 id, 313 pull_at, 314 round_number, 315 patch, 316 created, 317 source_rev 318 from 319 pull_submissions 320 %s 321 order by 322 round_number asc 323 `, whereClause) 324 325 rows, err := e.Query(query, args...) 326 if err != nil { 327 return nil, err 328 } 329 defer rows.Close() 330 331 submissionMap := make(map[int]*models.PullSubmission) 332 333 for rows.Next() { 334 var submission models.PullSubmission 335 var createdAt string 336 var sourceRev sql.NullString 337 err := rows.Scan( 338 &submission.ID, 339 &submission.PullAt, 340 &submission.RoundNumber, 341 &submission.Patch, 342 &createdAt, 343 &sourceRev, 344 ) 345 if err != nil { 346 return nil, err 347 } 348 349 createdTime, err := time.Parse(time.RFC3339, createdAt) 350 if err != nil { 351 return nil, err 352 } 353 submission.Created = createdTime 354 355 if sourceRev.Valid { 356 submission.SourceRev = sourceRev.String 357 } 358 359 submissionMap[submission.ID] = &submission 360 } 361 362 if err := rows.Err(); err != nil { 363 return nil, err 364 } 365 366 // Get comments for all submissions using GetPullComments 367 submissionIds := slices.Collect(maps.Keys(submissionMap)) 368 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) 369 if err != nil { 370 return nil, err 371 } 372 for _, comment := range comments { 373 if submission, ok := submissionMap[comment.SubmissionId]; ok { 374 submission.Comments = append(submission.Comments, comment) 375 } 376 } 377 378 // group the submissions by pull_at 379 m := make(map[syntax.ATURI][]*models.PullSubmission) 380 for _, s := range submissionMap { 381 m[s.PullAt] = append(m[s.PullAt], s) 382 } 383 384 // sort each one by round number 385 for _, s := range m { 386 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 387 return cmp.Compare(a.RoundNumber, b.RoundNumber) 388 }) 389 } 390 391 return m, nil 392} 393 394func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { 395 var conditions []string 396 var args []any 397 for _, filter := range filters { 398 conditions = append(conditions, filter.Condition()) 399 args = append(args, filter.Arg()...) 400 } 401 402 whereClause := "" 403 if conditions != nil { 404 whereClause = " where " + strings.Join(conditions, " and ") 405 } 406 407 query := fmt.Sprintf(` 408 select 409 id, 410 pull_id, 411 submission_id, 412 repo_at, 413 owner_did, 414 comment_at, 415 body, 416 created 417 from 418 pull_comments 419 %s 420 order by 421 created asc 422 `, whereClause) 423 424 rows, err := e.Query(query, args...) 425 if err != nil { 426 return nil, err 427 } 428 defer rows.Close() 429 430 var comments []models.PullComment 431 for rows.Next() { 432 var comment models.PullComment 433 var createdAt string 434 err := rows.Scan( 435 &comment.ID, 436 &comment.PullId, 437 &comment.SubmissionId, 438 &comment.RepoAt, 439 &comment.OwnerDid, 440 &comment.CommentAt, 441 &comment.Body, 442 &createdAt, 443 ) 444 if err != nil { 445 return nil, err 446 } 447 448 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 449 comment.Created = t 450 } 451 452 comments = append(comments, comment) 453 } 454 455 if err := rows.Err(); err != nil { 456 return nil, err 457 } 458 459 return comments, nil 460} 461 462// timeframe here is directly passed into the sql query filter, and any 463// timeframe in the past should be negative; e.g.: "-3 months" 464func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 465 var pulls []models.Pull 466 467 rows, err := e.Query(` 468 select 469 p.owner_did, 470 p.repo_at, 471 p.pull_id, 472 p.created, 473 p.title, 474 p.state, 475 r.did, 476 r.name, 477 r.knot, 478 r.rkey, 479 r.created 480 from 481 pulls p 482 join 483 repos r on p.repo_at = r.at_uri 484 where 485 p.owner_did = ? and p.created >= date ('now', ?) 486 order by 487 p.created desc`, did, timeframe) 488 if err != nil { 489 return nil, err 490 } 491 defer rows.Close() 492 493 for rows.Next() { 494 var pull models.Pull 495 var repo models.Repo 496 var pullCreatedAt, repoCreatedAt string 497 err := rows.Scan( 498 &pull.OwnerDid, 499 &pull.RepoAt, 500 &pull.PullId, 501 &pullCreatedAt, 502 &pull.Title, 503 &pull.State, 504 &repo.Did, 505 &repo.Name, 506 &repo.Knot, 507 &repo.Rkey, 508 &repoCreatedAt, 509 ) 510 if err != nil { 511 return nil, err 512 } 513 514 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 515 if err != nil { 516 return nil, err 517 } 518 pull.Created = pullCreatedTime 519 520 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 521 if err != nil { 522 return nil, err 523 } 524 repo.Created = repoCreatedTime 525 526 pull.Repo = &repo 527 528 pulls = append(pulls, pull) 529 } 530 531 if err := rows.Err(); err != nil { 532 return nil, err 533 } 534 535 return pulls, nil 536} 537 538func NewPullComment(e Execer, comment *models.PullComment) (int64, error) { 539 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 540 res, err := e.Exec( 541 query, 542 comment.OwnerDid, 543 comment.RepoAt, 544 comment.SubmissionId, 545 comment.CommentAt, 546 comment.PullId, 547 comment.Body, 548 ) 549 if err != nil { 550 return 0, err 551 } 552 553 i, err := res.LastInsertId() 554 if err != nil { 555 return 0, err 556 } 557 558 return i, nil 559} 560 561func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 562 _, err := e.Exec( 563 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 564 pullState, 565 repoAt, 566 pullId, 567 models.PullDeleted, // only update state of non-deleted pulls 568 models.PullMerged, // only update state of non-merged pulls 569 ) 570 return err 571} 572 573func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 574 err := SetPullState(e, repoAt, pullId, models.PullClosed) 575 return err 576} 577 578func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 579 err := SetPullState(e, repoAt, pullId, models.PullOpen) 580 return err 581} 582 583func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 584 err := SetPullState(e, repoAt, pullId, models.PullMerged) 585 return err 586} 587 588func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 589 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 590 return err 591} 592 593func ResubmitPull(e Execer, pull *models.Pull) error { 594 newPatch := pull.LatestPatch() 595 newSourceRev := pull.LatestSha() 596 newRoundNumber := len(pull.Submissions) 597 _, err := e.Exec(` 598 insert into pull_submissions (pull_at, round_number, patch, source_rev) 599 values (?, ?, ?, ?) 600 `, pull.PullAt(), newRoundNumber, newPatch, newSourceRev) 601 602 return err 603} 604 605func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 606 var conditions []string 607 var args []any 608 609 args = append(args, parentChangeId) 610 611 for _, filter := range filters { 612 conditions = append(conditions, filter.Condition()) 613 args = append(args, filter.Arg()...) 614 } 615 616 whereClause := "" 617 if conditions != nil { 618 whereClause = " where " + strings.Join(conditions, " and ") 619 } 620 621 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 622 _, err := e.Exec(query, args...) 623 624 return err 625} 626 627// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 628// otherwise submissions are immutable 629func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 630 var conditions []string 631 var args []any 632 633 args = append(args, sourceRev) 634 args = append(args, newPatch) 635 636 for _, filter := range filters { 637 conditions = append(conditions, filter.Condition()) 638 args = append(args, filter.Arg()...) 639 } 640 641 whereClause := "" 642 if conditions != nil { 643 whereClause = " where " + strings.Join(conditions, " and ") 644 } 645 646 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 647 _, err := e.Exec(query, args...) 648 649 return err 650} 651 652func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 653 row := e.QueryRow(` 654 select 655 count(case when state = ? then 1 end) as open_count, 656 count(case when state = ? then 1 end) as merged_count, 657 count(case when state = ? then 1 end) as closed_count, 658 count(case when state = ? then 1 end) as deleted_count 659 from pulls 660 where repo_at = ?`, 661 models.PullOpen, 662 models.PullMerged, 663 models.PullClosed, 664 models.PullDeleted, 665 repoAt, 666 ) 667 668 var count models.PullCount 669 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 670 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 671 } 672 673 return count, nil 674} 675 676// change-id parent-change-id 677// 678// 4 w ,-------- z (TOP) 679// 3 z <----',------- y 680// 2 y <-----',------ x 681// 1 x <------' nil (BOT) 682// 683// `w` is parent of none, so it is the top of the stack 684func GetStack(e Execer, stackId string) (models.Stack, error) { 685 unorderedPulls, err := GetPulls( 686 e, 687 FilterEq("stack_id", stackId), 688 FilterNotEq("state", models.PullDeleted), 689 ) 690 if err != nil { 691 return nil, err 692 } 693 // map of parent-change-id to pull 694 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 695 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 696 for _, p := range unorderedPulls { 697 changeIdMap[p.ChangeId] = p 698 if p.ParentChangeId != "" { 699 parentMap[p.ParentChangeId] = p 700 } 701 } 702 703 // the top of the stack is the pull that is not a parent of any pull 704 var topPull *models.Pull 705 for _, maybeTop := range unorderedPulls { 706 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 707 topPull = maybeTop 708 break 709 } 710 } 711 712 pulls := []*models.Pull{} 713 for { 714 pulls = append(pulls, topPull) 715 if topPull.ParentChangeId != "" { 716 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 717 topPull = next 718 } else { 719 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 720 } 721 } else { 722 break 723 } 724 } 725 726 return pulls, nil 727} 728 729func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 730 pulls, err := GetPulls( 731 e, 732 FilterEq("stack_id", stackId), 733 FilterEq("state", models.PullDeleted), 734 ) 735 if err != nil { 736 return nil, err 737 } 738 739 return pulls, nil 740}