forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "slices" 8 "sort" 9 "strings" 10 "time" 11 12 "github.com/bluekeyes/go-gitdiff/gitdiff" 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 "tangled.sh/tangled.sh/core/api/tangled" 15 "tangled.sh/tangled.sh/core/patchutil" 16 "tangled.sh/tangled.sh/core/types" 17) 18 19type PullState int 20 21const ( 22 PullClosed PullState = iota 23 PullOpen 24 PullMerged 25 PullDeleted 26) 27 28func (p PullState) String() string { 29 switch p { 30 case PullOpen: 31 return "open" 32 case PullMerged: 33 return "merged" 34 case PullClosed: 35 return "closed" 36 case PullDeleted: 37 return "deleted" 38 default: 39 return "closed" 40 } 41} 42 43func (p PullState) IsOpen() bool { 44 return p == PullOpen 45} 46func (p PullState) IsMerged() bool { 47 return p == PullMerged 48} 49func (p PullState) IsClosed() bool { 50 return p == PullClosed 51} 52func (p PullState) IsDeleted() bool { 53 return p == PullDeleted 54} 55 56type Pull struct { 57 // ids 58 ID int 59 PullId int 60 61 // at ids 62 RepoAt syntax.ATURI 63 OwnerDid string 64 Rkey string 65 66 // content 67 Title string 68 Body string 69 TargetBranch string 70 State PullState 71 Submissions []*PullSubmission 72 73 // stacking 74 StackId string // nullable string 75 ChangeId string // nullable string 76 ParentChangeId string // nullable string 77 78 // meta 79 Created time.Time 80 PullSource *PullSource 81 82 // optionally, populate this when querying for reverse mappings 83 Repo *Repo 84} 85 86func (p Pull) AsRecord() tangled.RepoPull { 87 var source *tangled.RepoPull_Source 88 if p.PullSource != nil { 89 s := p.PullSource.AsRecord() 90 source = &s 91 } 92 93 record := tangled.RepoPull{ 94 Title: p.Title, 95 Body: &p.Body, 96 CreatedAt: p.Created.Format(time.RFC3339), 97 PullId: int64(p.PullId), 98 TargetRepo: p.RepoAt.String(), 99 TargetBranch: p.TargetBranch, 100 Patch: p.LatestPatch(), 101 Source: source, 102 } 103 return record 104} 105 106type PullSource struct { 107 Branch string 108 RepoAt *syntax.ATURI 109 110 // optionally populate this for reverse mappings 111 Repo *Repo 112} 113 114func (p PullSource) AsRecord() tangled.RepoPull_Source { 115 var repoAt *string 116 if p.RepoAt != nil { 117 s := p.RepoAt.String() 118 repoAt = &s 119 } 120 record := tangled.RepoPull_Source{ 121 Branch: p.Branch, 122 Repo: repoAt, 123 } 124 return record 125} 126 127type PullSubmission struct { 128 // ids 129 ID int 130 PullId int 131 132 // at ids 133 RepoAt syntax.ATURI 134 135 // content 136 RoundNumber int 137 Patch string 138 Comments []PullComment 139 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs 140 141 // meta 142 Created time.Time 143} 144 145type PullComment struct { 146 // ids 147 ID int 148 PullId int 149 SubmissionId int 150 151 // at ids 152 RepoAt string 153 OwnerDid string 154 CommentAt string 155 156 // content 157 Body string 158 159 // meta 160 Created time.Time 161} 162 163func (p *Pull) LatestPatch() string { 164 latestSubmission := p.Submissions[p.LastRoundNumber()] 165 return latestSubmission.Patch 166} 167 168func (p *Pull) PullAt() syntax.ATURI { 169 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 170} 171 172func (p *Pull) LastRoundNumber() int { 173 return len(p.Submissions) - 1 174} 175 176func (p *Pull) IsPatchBased() bool { 177 return p.PullSource == nil 178} 179 180func (p *Pull) IsBranchBased() bool { 181 if p.PullSource != nil { 182 if p.PullSource.RepoAt != nil { 183 return p.PullSource.RepoAt == &p.RepoAt 184 } else { 185 // no repo specified 186 return true 187 } 188 } 189 return false 190} 191 192func (p *Pull) IsForkBased() bool { 193 if p.PullSource != nil { 194 if p.PullSource.RepoAt != nil { 195 // make sure repos are different 196 return p.PullSource.RepoAt != &p.RepoAt 197 } 198 } 199 return false 200} 201 202func (p *Pull) IsStacked() bool { 203 return p.StackId != "" 204} 205 206func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) { 207 patch := s.Patch 208 209 // if format-patch; then extract each patch 210 var diffs []*gitdiff.File 211 if patchutil.IsFormatPatch(patch) { 212 patches, err := patchutil.ExtractPatches(patch) 213 if err != nil { 214 return nil, err 215 } 216 var ps [][]*gitdiff.File 217 for _, p := range patches { 218 ps = append(ps, p.Files) 219 } 220 221 diffs = patchutil.CombineDiff(ps...) 222 } else { 223 d, _, err := gitdiff.Parse(strings.NewReader(patch)) 224 if err != nil { 225 return nil, err 226 } 227 diffs = d 228 } 229 230 return diffs, nil 231} 232 233func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff { 234 diffs, err := s.AsDiff(targetBranch) 235 if err != nil { 236 log.Println(err) 237 } 238 239 nd := types.NiceDiff{} 240 nd.Commit.Parent = targetBranch 241 242 for _, d := range diffs { 243 ndiff := types.Diff{} 244 ndiff.Name.New = d.NewName 245 ndiff.Name.Old = d.OldName 246 ndiff.IsBinary = d.IsBinary 247 ndiff.IsNew = d.IsNew 248 ndiff.IsDelete = d.IsDelete 249 ndiff.IsCopy = d.IsCopy 250 ndiff.IsRename = d.IsRename 251 252 for _, tf := range d.TextFragments { 253 ndiff.TextFragments = append(ndiff.TextFragments, *tf) 254 for _, l := range tf.Lines { 255 switch l.Op { 256 case gitdiff.OpAdd: 257 nd.Stat.Insertions += 1 258 case gitdiff.OpDelete: 259 nd.Stat.Deletions += 1 260 } 261 } 262 } 263 264 nd.Diff = append(nd.Diff, ndiff) 265 } 266 267 nd.Stat.FilesChanged = len(diffs) 268 269 return nd 270} 271 272func (s PullSubmission) IsFormatPatch() bool { 273 return patchutil.IsFormatPatch(s.Patch) 274} 275 276func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch { 277 patches, err := patchutil.ExtractPatches(s.Patch) 278 if err != nil { 279 log.Println("error extracting patches from submission:", err) 280 return []patchutil.FormatPatch{} 281 } 282 283 return patches 284} 285 286func NewPull(tx *sql.Tx, pull *Pull) error { 287 _, err := tx.Exec(` 288 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 289 values (?, 1) 290 `, pull.RepoAt) 291 if err != nil { 292 return err 293 } 294 295 var nextId int 296 err = tx.QueryRow(` 297 update repo_pull_seqs 298 set next_pull_id = next_pull_id + 1 299 where repo_at = ? 300 returning next_pull_id - 1 301 `, pull.RepoAt).Scan(&nextId) 302 if err != nil { 303 return err 304 } 305 306 pull.PullId = nextId 307 pull.State = PullOpen 308 309 var sourceBranch, sourceRepoAt *string 310 if pull.PullSource != nil { 311 sourceBranch = &pull.PullSource.Branch 312 if pull.PullSource.RepoAt != nil { 313 x := pull.PullSource.RepoAt.String() 314 sourceRepoAt = &x 315 } 316 } 317 318 var stackId, changeId, parentChangeId *string 319 if pull.StackId != "" { 320 stackId = &pull.StackId 321 } 322 if pull.ChangeId != "" { 323 changeId = &pull.ChangeId 324 } 325 if pull.ParentChangeId != "" { 326 parentChangeId = &pull.ParentChangeId 327 } 328 329 _, err = tx.Exec( 330 ` 331 insert into pulls ( 332 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 333 ) 334 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 335 pull.RepoAt, 336 pull.OwnerDid, 337 pull.PullId, 338 pull.Title, 339 pull.TargetBranch, 340 pull.Body, 341 pull.Rkey, 342 pull.State, 343 sourceBranch, 344 sourceRepoAt, 345 stackId, 346 changeId, 347 parentChangeId, 348 ) 349 if err != nil { 350 return err 351 } 352 353 _, err = tx.Exec(` 354 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 355 values (?, ?, ?, ?, ?) 356 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 357 return err 358} 359 360func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 361 pull, err := GetPull(e, repoAt, pullId) 362 if err != nil { 363 return "", err 364 } 365 return pull.PullAt(), err 366} 367 368func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 369 var pullId int 370 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 371 return pullId - 1, err 372} 373 374func GetPulls(e Execer, filters ...filter) ([]*Pull, error) { 375 pulls := make(map[int]*Pull) 376 377 var conditions []string 378 var args []any 379 for _, filter := range filters { 380 conditions = append(conditions, filter.Condition()) 381 args = append(args, filter.arg) 382 } 383 384 whereClause := "" 385 if conditions != nil { 386 whereClause = " where " + strings.Join(conditions, " and ") 387 } 388 389 query := fmt.Sprintf(` 390 select 391 owner_did, 392 repo_at, 393 pull_id, 394 created, 395 title, 396 state, 397 target_branch, 398 body, 399 rkey, 400 source_branch, 401 source_repo_at, 402 stack_id, 403 change_id, 404 parent_change_id 405 from 406 pulls 407 %s 408 `, whereClause) 409 410 rows, err := e.Query(query, args...) 411 if err != nil { 412 return nil, err 413 } 414 defer rows.Close() 415 416 for rows.Next() { 417 var pull Pull 418 var createdAt string 419 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 420 err := rows.Scan( 421 &pull.OwnerDid, 422 &pull.RepoAt, 423 &pull.PullId, 424 &createdAt, 425 &pull.Title, 426 &pull.State, 427 &pull.TargetBranch, 428 &pull.Body, 429 &pull.Rkey, 430 &sourceBranch, 431 &sourceRepoAt, 432 &stackId, 433 &changeId, 434 &parentChangeId, 435 ) 436 if err != nil { 437 return nil, err 438 } 439 440 createdTime, err := time.Parse(time.RFC3339, createdAt) 441 if err != nil { 442 return nil, err 443 } 444 pull.Created = createdTime 445 446 if sourceBranch.Valid { 447 pull.PullSource = &PullSource{ 448 Branch: sourceBranch.String, 449 } 450 if sourceRepoAt.Valid { 451 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 452 if err != nil { 453 return nil, err 454 } 455 pull.PullSource.RepoAt = &sourceRepoAtParsed 456 } 457 } 458 459 if stackId.Valid { 460 pull.StackId = stackId.String 461 } 462 if changeId.Valid { 463 pull.ChangeId = changeId.String 464 } 465 if parentChangeId.Valid { 466 pull.ParentChangeId = parentChangeId.String 467 } 468 469 pulls[pull.PullId] = &pull 470 } 471 472 // get latest round no. for each pull 473 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 474 submissionsQuery := fmt.Sprintf(` 475 select 476 id, pull_id, round_number, patch, source_rev 477 from 478 pull_submissions 479 where 480 repo_at in (%s) and pull_id in (%s) 481 `, inClause, inClause) 482 483 args = make([]any, len(pulls)*2) 484 idx := 0 485 for _, p := range pulls { 486 args[idx] = p.RepoAt 487 idx += 1 488 } 489 for _, p := range pulls { 490 args[idx] = p.PullId 491 idx += 1 492 } 493 submissionsRows, err := e.Query(submissionsQuery, args...) 494 if err != nil { 495 return nil, err 496 } 497 defer submissionsRows.Close() 498 499 for submissionsRows.Next() { 500 var s PullSubmission 501 var sourceRev sql.NullString 502 err := submissionsRows.Scan( 503 &s.ID, 504 &s.PullId, 505 &s.RoundNumber, 506 &s.Patch, 507 &sourceRev, 508 ) 509 if err != nil { 510 return nil, err 511 } 512 513 if sourceRev.Valid { 514 s.SourceRev = sourceRev.String 515 } 516 517 if p, ok := pulls[s.PullId]; ok { 518 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 519 p.Submissions[s.RoundNumber] = &s 520 } 521 } 522 if err := rows.Err(); err != nil { 523 return nil, err 524 } 525 526 // get comment count on latest submission on each pull 527 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 528 commentsQuery := fmt.Sprintf(` 529 select 530 count(id), pull_id 531 from 532 pull_comments 533 where 534 submission_id in (%s) 535 group by 536 submission_id 537 `, inClause) 538 539 args = []any{} 540 for _, p := range pulls { 541 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 542 } 543 commentsRows, err := e.Query(commentsQuery, args...) 544 if err != nil { 545 return nil, err 546 } 547 defer commentsRows.Close() 548 549 for commentsRows.Next() { 550 var commentCount, pullId int 551 err := commentsRows.Scan( 552 &commentCount, 553 &pullId, 554 ) 555 if err != nil { 556 return nil, err 557 } 558 if p, ok := pulls[pullId]; ok { 559 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 560 } 561 } 562 if err := rows.Err(); err != nil { 563 return nil, err 564 } 565 566 orderedByPullId := []*Pull{} 567 for _, p := range pulls { 568 orderedByPullId = append(orderedByPullId, p) 569 } 570 sort.Slice(orderedByPullId, func(i, j int) bool { 571 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 572 }) 573 574 return orderedByPullId, nil 575} 576 577func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 578 query := ` 579 select 580 owner_did, 581 pull_id, 582 created, 583 title, 584 state, 585 target_branch, 586 repo_at, 587 body, 588 rkey, 589 source_branch, 590 source_repo_at, 591 stack_id, 592 change_id, 593 parent_change_id 594 from 595 pulls 596 where 597 repo_at = ? and pull_id = ? 598 ` 599 row := e.QueryRow(query, repoAt, pullId) 600 601 var pull Pull 602 var createdAt string 603 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 604 err := row.Scan( 605 &pull.OwnerDid, 606 &pull.PullId, 607 &createdAt, 608 &pull.Title, 609 &pull.State, 610 &pull.TargetBranch, 611 &pull.RepoAt, 612 &pull.Body, 613 &pull.Rkey, 614 &sourceBranch, 615 &sourceRepoAt, 616 &stackId, 617 &changeId, 618 &parentChangeId, 619 ) 620 if err != nil { 621 return nil, err 622 } 623 624 createdTime, err := time.Parse(time.RFC3339, createdAt) 625 if err != nil { 626 return nil, err 627 } 628 pull.Created = createdTime 629 630 // populate source 631 if sourceBranch.Valid { 632 pull.PullSource = &PullSource{ 633 Branch: sourceBranch.String, 634 } 635 if sourceRepoAt.Valid { 636 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 637 if err != nil { 638 return nil, err 639 } 640 pull.PullSource.RepoAt = &sourceRepoAtParsed 641 } 642 } 643 644 if stackId.Valid { 645 pull.StackId = stackId.String 646 } 647 if changeId.Valid { 648 pull.ChangeId = changeId.String 649 } 650 if parentChangeId.Valid { 651 pull.ParentChangeId = parentChangeId.String 652 } 653 654 submissionsQuery := ` 655 select 656 id, pull_id, repo_at, round_number, patch, created, source_rev 657 from 658 pull_submissions 659 where 660 repo_at = ? and pull_id = ? 661 ` 662 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 663 if err != nil { 664 return nil, err 665 } 666 defer submissionsRows.Close() 667 668 submissionsMap := make(map[int]*PullSubmission) 669 670 for submissionsRows.Next() { 671 var submission PullSubmission 672 var submissionCreatedStr string 673 var submissionSourceRev sql.NullString 674 err := submissionsRows.Scan( 675 &submission.ID, 676 &submission.PullId, 677 &submission.RepoAt, 678 &submission.RoundNumber, 679 &submission.Patch, 680 &submissionCreatedStr, 681 &submissionSourceRev, 682 ) 683 if err != nil { 684 return nil, err 685 } 686 687 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 688 if err != nil { 689 return nil, err 690 } 691 submission.Created = submissionCreatedTime 692 693 if submissionSourceRev.Valid { 694 submission.SourceRev = submissionSourceRev.String 695 } 696 697 submissionsMap[submission.ID] = &submission 698 } 699 if err = submissionsRows.Close(); err != nil { 700 return nil, err 701 } 702 if len(submissionsMap) == 0 { 703 return &pull, nil 704 } 705 706 var args []any 707 for k := range submissionsMap { 708 args = append(args, k) 709 } 710 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 711 commentsQuery := fmt.Sprintf(` 712 select 713 id, 714 pull_id, 715 submission_id, 716 repo_at, 717 owner_did, 718 comment_at, 719 body, 720 created 721 from 722 pull_comments 723 where 724 submission_id IN (%s) 725 order by 726 created asc 727 `, inClause) 728 commentsRows, err := e.Query(commentsQuery, args...) 729 if err != nil { 730 return nil, err 731 } 732 defer commentsRows.Close() 733 734 for commentsRows.Next() { 735 var comment PullComment 736 var commentCreatedStr string 737 err := commentsRows.Scan( 738 &comment.ID, 739 &comment.PullId, 740 &comment.SubmissionId, 741 &comment.RepoAt, 742 &comment.OwnerDid, 743 &comment.CommentAt, 744 &comment.Body, 745 &commentCreatedStr, 746 ) 747 if err != nil { 748 return nil, err 749 } 750 751 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 752 if err != nil { 753 return nil, err 754 } 755 comment.Created = commentCreatedTime 756 757 // Add the comment to its submission 758 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 759 submission.Comments = append(submission.Comments, comment) 760 } 761 762 } 763 if err = commentsRows.Err(); err != nil { 764 return nil, err 765 } 766 767 var pullSourceRepo *Repo 768 if pull.PullSource != nil { 769 if pull.PullSource.RepoAt != nil { 770 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 771 if err != nil { 772 log.Printf("failed to get repo by at uri: %v", err) 773 } else { 774 pull.PullSource.Repo = pullSourceRepo 775 } 776 } 777 } 778 779 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 780 for _, submission := range submissionsMap { 781 pull.Submissions[submission.RoundNumber] = submission 782 } 783 784 return &pull, nil 785} 786 787// timeframe here is directly passed into the sql query filter, and any 788// timeframe in the past should be negative; e.g.: "-3 months" 789func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 790 var pulls []Pull 791 792 rows, err := e.Query(` 793 select 794 p.owner_did, 795 p.repo_at, 796 p.pull_id, 797 p.created, 798 p.title, 799 p.state, 800 r.did, 801 r.name, 802 r.knot, 803 r.rkey, 804 r.created 805 from 806 pulls p 807 join 808 repos r on p.repo_at = r.at_uri 809 where 810 p.owner_did = ? and p.created >= date ('now', ?) 811 order by 812 p.created desc`, did, timeframe) 813 if err != nil { 814 return nil, err 815 } 816 defer rows.Close() 817 818 for rows.Next() { 819 var pull Pull 820 var repo Repo 821 var pullCreatedAt, repoCreatedAt string 822 err := rows.Scan( 823 &pull.OwnerDid, 824 &pull.RepoAt, 825 &pull.PullId, 826 &pullCreatedAt, 827 &pull.Title, 828 &pull.State, 829 &repo.Did, 830 &repo.Name, 831 &repo.Knot, 832 &repo.Rkey, 833 &repoCreatedAt, 834 ) 835 if err != nil { 836 return nil, err 837 } 838 839 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 840 if err != nil { 841 return nil, err 842 } 843 pull.Created = pullCreatedTime 844 845 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 846 if err != nil { 847 return nil, err 848 } 849 repo.Created = repoCreatedTime 850 851 pull.Repo = &repo 852 853 pulls = append(pulls, pull) 854 } 855 856 if err := rows.Err(); err != nil { 857 return nil, err 858 } 859 860 return pulls, nil 861} 862 863func NewPullComment(e Execer, comment *PullComment) (int64, error) { 864 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 865 res, err := e.Exec( 866 query, 867 comment.OwnerDid, 868 comment.RepoAt, 869 comment.SubmissionId, 870 comment.CommentAt, 871 comment.PullId, 872 comment.Body, 873 ) 874 if err != nil { 875 return 0, err 876 } 877 878 i, err := res.LastInsertId() 879 if err != nil { 880 return 0, err 881 } 882 883 return i, nil 884} 885 886func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 887 _, err := e.Exec( 888 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 889 pullState, 890 repoAt, 891 pullId, 892 PullDeleted, // only update state of non-deleted pulls 893 PullMerged, // only update state of non-merged pulls 894 ) 895 return err 896} 897 898func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 899 err := SetPullState(e, repoAt, pullId, PullClosed) 900 return err 901} 902 903func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 904 err := SetPullState(e, repoAt, pullId, PullOpen) 905 return err 906} 907 908func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 909 err := SetPullState(e, repoAt, pullId, PullMerged) 910 return err 911} 912 913func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 914 err := SetPullState(e, repoAt, pullId, PullDeleted) 915 return err 916} 917 918func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 919 newRoundNumber := len(pull.Submissions) 920 _, err := e.Exec(` 921 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 922 values (?, ?, ?, ?, ?) 923 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 924 925 return err 926} 927 928func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 929 var conditions []string 930 var args []any 931 932 args = append(args, parentChangeId) 933 934 for _, filter := range filters { 935 conditions = append(conditions, filter.Condition()) 936 args = append(args, filter.arg) 937 } 938 939 whereClause := "" 940 if conditions != nil { 941 whereClause = " where " + strings.Join(conditions, " and ") 942 } 943 944 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 945 _, err := e.Exec(query, args...) 946 947 return err 948} 949 950type PullCount struct { 951 Open int 952 Merged int 953 Closed int 954 Deleted int 955} 956 957func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 958 row := e.QueryRow(` 959 select 960 count(case when state = ? then 1 end) as open_count, 961 count(case when state = ? then 1 end) as merged_count, 962 count(case when state = ? then 1 end) as closed_count, 963 count(case when state = ? then 1 end) as deleted_count 964 from pulls 965 where repo_at = ?`, 966 PullOpen, 967 PullMerged, 968 PullClosed, 969 PullDeleted, 970 repoAt, 971 ) 972 973 var count PullCount 974 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 975 return PullCount{0, 0, 0, 0}, err 976 } 977 978 return count, nil 979} 980 981type Stack []*Pull 982 983// change-id parent-change-id 984// 985// 4 w ,-------- z (TOP) 986// 3 z <----',------- y 987// 2 y <-----',------ x 988// 1 x <------' nil (BOT) 989// 990// `w` is parent of none, so it is the top of the stack 991func GetStack(e Execer, stackId string) (Stack, error) { 992 unorderedPulls, err := GetPulls( 993 e, 994 FilterEq("stack_id", stackId), 995 FilterNotEq("state", PullDeleted), 996 ) 997 if err != nil { 998 return nil, err 999 } 1000 // map of parent-change-id to pull 1001 changeIdMap := make(map[string]*Pull, len(unorderedPulls)) 1002 parentMap := make(map[string]*Pull, len(unorderedPulls)) 1003 for _, p := range unorderedPulls { 1004 changeIdMap[p.ChangeId] = p 1005 if p.ParentChangeId != "" { 1006 parentMap[p.ParentChangeId] = p 1007 } 1008 } 1009 1010 // the top of the stack is the pull that is not a parent of any pull 1011 var topPull *Pull 1012 for _, maybeTop := range unorderedPulls { 1013 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 1014 topPull = maybeTop 1015 break 1016 } 1017 } 1018 1019 pulls := []*Pull{} 1020 for { 1021 pulls = append(pulls, topPull) 1022 if topPull.ParentChangeId != "" { 1023 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 1024 topPull = next 1025 } else { 1026 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 1027 } 1028 } else { 1029 break 1030 } 1031 } 1032 1033 return pulls, nil 1034} 1035 1036func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) { 1037 pulls, err := GetPulls( 1038 e, 1039 FilterEq("stack_id", stackId), 1040 FilterEq("state", PullDeleted), 1041 ) 1042 if err != nil { 1043 return nil, err 1044 } 1045 1046 return pulls, nil 1047} 1048 1049// position of this pull in the stack 1050func (stack Stack) Position(pull *Pull) int { 1051 return slices.IndexFunc(stack, func(p *Pull) bool { 1052 return p.ChangeId == pull.ChangeId 1053 }) 1054} 1055 1056// all pulls below this pull (including self) in this stack 1057// 1058// nil if this pull does not belong to this stack 1059func (stack Stack) Below(pull *Pull) Stack { 1060 position := stack.Position(pull) 1061 1062 if position < 0 { 1063 return nil 1064 } 1065 1066 return stack[position:] 1067} 1068 1069// all pulls below this pull (excluding self) in this stack 1070func (stack Stack) StrictlyBelow(pull *Pull) Stack { 1071 below := stack.Below(pull) 1072 1073 if len(below) > 0 { 1074 return below[1:] 1075 } 1076 1077 return nil 1078} 1079 1080// all pulls above this pull (including self) in this stack 1081func (stack Stack) Above(pull *Pull) Stack { 1082 position := stack.Position(pull) 1083 1084 if position < 0 { 1085 return nil 1086 } 1087 1088 return stack[:position+1] 1089} 1090 1091// all pulls below this pull (excluding self) in this stack 1092func (stack Stack) StrictlyAbove(pull *Pull) Stack { 1093 above := stack.Above(pull) 1094 1095 if len(above) > 0 { 1096 return above[:len(above)-1] 1097 } 1098 1099 return nil 1100} 1101 1102// the combined format-patches of all the newest submissions in this stack 1103func (stack Stack) CombinedPatch() string { 1104 // go in reverse order because the bottom of the stack is the last element in the slice 1105 var combined strings.Builder 1106 for idx := range stack { 1107 pull := stack[len(stack)-1-idx] 1108 combined.WriteString(pull.LatestPatch()) 1109 combined.WriteString("\n") 1110 } 1111 return combined.String() 1112} 1113 1114// filter out PRs that are "active" 1115// 1116// PRs that are still open are active 1117func (stack Stack) Mergeable() Stack { 1118 var mergeable Stack 1119 1120 for _, p := range stack { 1121 // stop at the first merged PR 1122 if p.State == PullMerged || p.State == PullClosed { 1123 break 1124 } 1125 1126 // skip over deleted PRs 1127 if p.State != PullDeleted { 1128 mergeable = append(mergeable, p) 1129 } 1130 } 1131 1132 return mergeable 1133}