forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "slices" 8 "sort" 9 "strings" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/api/tangled" 14 "tangled.sh/tangled.sh/core/patchutil" 15 "tangled.sh/tangled.sh/core/types" 16) 17 18type PullState int 19 20const ( 21 PullClosed PullState = iota 22 PullOpen 23 PullMerged 24 PullDeleted 25) 26 27func (p PullState) String() string { 28 switch p { 29 case PullOpen: 30 return "open" 31 case PullMerged: 32 return "merged" 33 case PullClosed: 34 return "closed" 35 case PullDeleted: 36 return "deleted" 37 default: 38 return "closed" 39 } 40} 41 42func (p PullState) IsOpen() bool { 43 return p == PullOpen 44} 45func (p PullState) IsMerged() bool { 46 return p == PullMerged 47} 48func (p PullState) IsClosed() bool { 49 return p == PullClosed 50} 51func (p PullState) IsDeleted() bool { 52 return p == PullDeleted 53} 54 55type Pull struct { 56 // ids 57 ID int 58 PullId int 59 60 // at ids 61 RepoAt syntax.ATURI 62 OwnerDid string 63 Rkey string 64 65 // content 66 Title string 67 Body string 68 TargetBranch string 69 State PullState 70 Submissions []*PullSubmission 71 72 // stacking 73 StackId string // nullable string 74 ChangeId string // nullable string 75 ParentChangeId string // nullable string 76 77 // meta 78 Created time.Time 79 PullSource *PullSource 80 81 // optionally, populate this when querying for reverse mappings 82 Repo *Repo 83} 84 85func (p Pull) AsRecord() tangled.RepoPull { 86 var source *tangled.RepoPull_Source 87 if p.PullSource != nil { 88 s := p.PullSource.AsRecord() 89 source = &s 90 } 91 92 record := tangled.RepoPull{ 93 Title: p.Title, 94 Body: &p.Body, 95 CreatedAt: p.Created.Format(time.RFC3339), 96 PullId: int64(p.PullId), 97 TargetRepo: p.RepoAt.String(), 98 TargetBranch: p.TargetBranch, 99 Patch: p.LatestPatch(), 100 Source: source, 101 } 102 return record 103} 104 105type PullSource struct { 106 Branch string 107 RepoAt *syntax.ATURI 108 109 // optionally populate this for reverse mappings 110 Repo *Repo 111} 112 113func (p PullSource) AsRecord() tangled.RepoPull_Source { 114 var repoAt *string 115 if p.RepoAt != nil { 116 s := p.RepoAt.String() 117 repoAt = &s 118 } 119 record := tangled.RepoPull_Source{ 120 Branch: p.Branch, 121 Repo: repoAt, 122 } 123 return record 124} 125 126type PullSubmission struct { 127 // ids 128 ID int 129 PullId int 130 131 // at ids 132 RepoAt syntax.ATURI 133 134 // content 135 RoundNumber int 136 Patch string 137 Comments []PullComment 138 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs 139 140 // meta 141 Created time.Time 142} 143 144type PullComment struct { 145 // ids 146 ID int 147 PullId int 148 SubmissionId int 149 150 // at ids 151 RepoAt string 152 OwnerDid string 153 CommentAt string 154 155 // content 156 Body string 157 158 // meta 159 Created time.Time 160} 161 162func (p *Pull) LatestPatch() string { 163 latestSubmission := p.Submissions[p.LastRoundNumber()] 164 return latestSubmission.Patch 165} 166 167func (p *Pull) PullAt() syntax.ATURI { 168 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 169} 170 171func (p *Pull) LastRoundNumber() int { 172 return len(p.Submissions) - 1 173} 174 175func (p *Pull) IsPatchBased() bool { 176 return p.PullSource == nil 177} 178 179func (p *Pull) IsBranchBased() bool { 180 if p.PullSource != nil { 181 if p.PullSource.RepoAt != nil { 182 return p.PullSource.RepoAt == &p.RepoAt 183 } else { 184 // no repo specified 185 return true 186 } 187 } 188 return false 189} 190 191func (p *Pull) IsForkBased() bool { 192 if p.PullSource != nil { 193 if p.PullSource.RepoAt != nil { 194 // make sure repos are different 195 return p.PullSource.RepoAt != &p.RepoAt 196 } 197 } 198 return false 199} 200 201func (p *Pull) IsStacked() bool { 202 return p.StackId != "" 203} 204 205func (s PullSubmission) IsFormatPatch() bool { 206 return patchutil.IsFormatPatch(s.Patch) 207} 208 209func (s PullSubmission) AsFormatPatch() []types.FormatPatch { 210 patches, err := patchutil.ExtractPatches(s.Patch) 211 if err != nil { 212 log.Println("error extracting patches from submission:", err) 213 return []types.FormatPatch{} 214 } 215 216 return patches 217} 218 219func NewPull(tx *sql.Tx, pull *Pull) error { 220 _, err := tx.Exec(` 221 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 222 values (?, 1) 223 `, pull.RepoAt) 224 if err != nil { 225 return err 226 } 227 228 var nextId int 229 err = tx.QueryRow(` 230 update repo_pull_seqs 231 set next_pull_id = next_pull_id + 1 232 where repo_at = ? 233 returning next_pull_id - 1 234 `, pull.RepoAt).Scan(&nextId) 235 if err != nil { 236 return err 237 } 238 239 pull.PullId = nextId 240 pull.State = PullOpen 241 242 var sourceBranch, sourceRepoAt *string 243 if pull.PullSource != nil { 244 sourceBranch = &pull.PullSource.Branch 245 if pull.PullSource.RepoAt != nil { 246 x := pull.PullSource.RepoAt.String() 247 sourceRepoAt = &x 248 } 249 } 250 251 var stackId, changeId, parentChangeId *string 252 if pull.StackId != "" { 253 stackId = &pull.StackId 254 } 255 if pull.ChangeId != "" { 256 changeId = &pull.ChangeId 257 } 258 if pull.ParentChangeId != "" { 259 parentChangeId = &pull.ParentChangeId 260 } 261 262 _, err = tx.Exec( 263 ` 264 insert into pulls ( 265 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 266 ) 267 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 268 pull.RepoAt, 269 pull.OwnerDid, 270 pull.PullId, 271 pull.Title, 272 pull.TargetBranch, 273 pull.Body, 274 pull.Rkey, 275 pull.State, 276 sourceBranch, 277 sourceRepoAt, 278 stackId, 279 changeId, 280 parentChangeId, 281 ) 282 if err != nil { 283 return err 284 } 285 286 _, err = tx.Exec(` 287 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 288 values (?, ?, ?, ?, ?) 289 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 290 return err 291} 292 293func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 294 pull, err := GetPull(e, repoAt, pullId) 295 if err != nil { 296 return "", err 297 } 298 return pull.PullAt(), err 299} 300 301func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 302 var pullId int 303 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 304 return pullId - 1, err 305} 306 307func GetPulls(e Execer, filters ...filter) ([]*Pull, error) { 308 pulls := make(map[int]*Pull) 309 310 var conditions []string 311 var args []any 312 for _, filter := range filters { 313 conditions = append(conditions, filter.Condition()) 314 args = append(args, filter.arg) 315 } 316 317 whereClause := "" 318 if conditions != nil { 319 whereClause = " where " + strings.Join(conditions, " and ") 320 } 321 322 query := fmt.Sprintf(` 323 select 324 owner_did, 325 repo_at, 326 pull_id, 327 created, 328 title, 329 state, 330 target_branch, 331 body, 332 rkey, 333 source_branch, 334 source_repo_at, 335 stack_id, 336 change_id, 337 parent_change_id 338 from 339 pulls 340 %s 341 `, whereClause) 342 343 rows, err := e.Query(query, args...) 344 if err != nil { 345 return nil, err 346 } 347 defer rows.Close() 348 349 for rows.Next() { 350 var pull Pull 351 var createdAt string 352 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 353 err := rows.Scan( 354 &pull.OwnerDid, 355 &pull.RepoAt, 356 &pull.PullId, 357 &createdAt, 358 &pull.Title, 359 &pull.State, 360 &pull.TargetBranch, 361 &pull.Body, 362 &pull.Rkey, 363 &sourceBranch, 364 &sourceRepoAt, 365 &stackId, 366 &changeId, 367 &parentChangeId, 368 ) 369 if err != nil { 370 return nil, err 371 } 372 373 createdTime, err := time.Parse(time.RFC3339, createdAt) 374 if err != nil { 375 return nil, err 376 } 377 pull.Created = createdTime 378 379 if sourceBranch.Valid { 380 pull.PullSource = &PullSource{ 381 Branch: sourceBranch.String, 382 } 383 if sourceRepoAt.Valid { 384 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 385 if err != nil { 386 return nil, err 387 } 388 pull.PullSource.RepoAt = &sourceRepoAtParsed 389 } 390 } 391 392 if stackId.Valid { 393 pull.StackId = stackId.String 394 } 395 if changeId.Valid { 396 pull.ChangeId = changeId.String 397 } 398 if parentChangeId.Valid { 399 pull.ParentChangeId = parentChangeId.String 400 } 401 402 pulls[pull.PullId] = &pull 403 } 404 405 // get latest round no. for each pull 406 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 407 submissionsQuery := fmt.Sprintf(` 408 select 409 id, pull_id, round_number, patch, source_rev 410 from 411 pull_submissions 412 where 413 repo_at in (%s) and pull_id in (%s) 414 `, inClause, inClause) 415 416 args = make([]any, len(pulls)*2) 417 idx := 0 418 for _, p := range pulls { 419 args[idx] = p.RepoAt 420 idx += 1 421 } 422 for _, p := range pulls { 423 args[idx] = p.PullId 424 idx += 1 425 } 426 submissionsRows, err := e.Query(submissionsQuery, args...) 427 if err != nil { 428 return nil, err 429 } 430 defer submissionsRows.Close() 431 432 for submissionsRows.Next() { 433 var s PullSubmission 434 var sourceRev sql.NullString 435 err := submissionsRows.Scan( 436 &s.ID, 437 &s.PullId, 438 &s.RoundNumber, 439 &s.Patch, 440 &sourceRev, 441 ) 442 if err != nil { 443 return nil, err 444 } 445 446 if sourceRev.Valid { 447 s.SourceRev = sourceRev.String 448 } 449 450 if p, ok := pulls[s.PullId]; ok { 451 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 452 p.Submissions[s.RoundNumber] = &s 453 } 454 } 455 if err := rows.Err(); err != nil { 456 return nil, err 457 } 458 459 // get comment count on latest submission on each pull 460 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 461 commentsQuery := fmt.Sprintf(` 462 select 463 count(id), pull_id 464 from 465 pull_comments 466 where 467 submission_id in (%s) 468 group by 469 submission_id 470 `, inClause) 471 472 args = []any{} 473 for _, p := range pulls { 474 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 475 } 476 commentsRows, err := e.Query(commentsQuery, args...) 477 if err != nil { 478 return nil, err 479 } 480 defer commentsRows.Close() 481 482 for commentsRows.Next() { 483 var commentCount, pullId int 484 err := commentsRows.Scan( 485 &commentCount, 486 &pullId, 487 ) 488 if err != nil { 489 return nil, err 490 } 491 if p, ok := pulls[pullId]; ok { 492 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 493 } 494 } 495 if err := rows.Err(); err != nil { 496 return nil, err 497 } 498 499 orderedByPullId := []*Pull{} 500 for _, p := range pulls { 501 orderedByPullId = append(orderedByPullId, p) 502 } 503 sort.Slice(orderedByPullId, func(i, j int) bool { 504 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 505 }) 506 507 return orderedByPullId, nil 508} 509 510func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 511 query := ` 512 select 513 owner_did, 514 pull_id, 515 created, 516 title, 517 state, 518 target_branch, 519 repo_at, 520 body, 521 rkey, 522 source_branch, 523 source_repo_at, 524 stack_id, 525 change_id, 526 parent_change_id 527 from 528 pulls 529 where 530 repo_at = ? and pull_id = ? 531 ` 532 row := e.QueryRow(query, repoAt, pullId) 533 534 var pull Pull 535 var createdAt string 536 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 537 err := row.Scan( 538 &pull.OwnerDid, 539 &pull.PullId, 540 &createdAt, 541 &pull.Title, 542 &pull.State, 543 &pull.TargetBranch, 544 &pull.RepoAt, 545 &pull.Body, 546 &pull.Rkey, 547 &sourceBranch, 548 &sourceRepoAt, 549 &stackId, 550 &changeId, 551 &parentChangeId, 552 ) 553 if err != nil { 554 return nil, err 555 } 556 557 createdTime, err := time.Parse(time.RFC3339, createdAt) 558 if err != nil { 559 return nil, err 560 } 561 pull.Created = createdTime 562 563 // populate source 564 if sourceBranch.Valid { 565 pull.PullSource = &PullSource{ 566 Branch: sourceBranch.String, 567 } 568 if sourceRepoAt.Valid { 569 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 570 if err != nil { 571 return nil, err 572 } 573 pull.PullSource.RepoAt = &sourceRepoAtParsed 574 } 575 } 576 577 if stackId.Valid { 578 pull.StackId = stackId.String 579 } 580 if changeId.Valid { 581 pull.ChangeId = changeId.String 582 } 583 if parentChangeId.Valid { 584 pull.ParentChangeId = parentChangeId.String 585 } 586 587 submissionsQuery := ` 588 select 589 id, pull_id, repo_at, round_number, patch, created, source_rev 590 from 591 pull_submissions 592 where 593 repo_at = ? and pull_id = ? 594 ` 595 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 596 if err != nil { 597 return nil, err 598 } 599 defer submissionsRows.Close() 600 601 submissionsMap := make(map[int]*PullSubmission) 602 603 for submissionsRows.Next() { 604 var submission PullSubmission 605 var submissionCreatedStr string 606 var submissionSourceRev sql.NullString 607 err := submissionsRows.Scan( 608 &submission.ID, 609 &submission.PullId, 610 &submission.RepoAt, 611 &submission.RoundNumber, 612 &submission.Patch, 613 &submissionCreatedStr, 614 &submissionSourceRev, 615 ) 616 if err != nil { 617 return nil, err 618 } 619 620 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 621 if err != nil { 622 return nil, err 623 } 624 submission.Created = submissionCreatedTime 625 626 if submissionSourceRev.Valid { 627 submission.SourceRev = submissionSourceRev.String 628 } 629 630 submissionsMap[submission.ID] = &submission 631 } 632 if err = submissionsRows.Close(); err != nil { 633 return nil, err 634 } 635 if len(submissionsMap) == 0 { 636 return &pull, nil 637 } 638 639 var args []any 640 for k := range submissionsMap { 641 args = append(args, k) 642 } 643 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 644 commentsQuery := fmt.Sprintf(` 645 select 646 id, 647 pull_id, 648 submission_id, 649 repo_at, 650 owner_did, 651 comment_at, 652 body, 653 created 654 from 655 pull_comments 656 where 657 submission_id IN (%s) 658 order by 659 created asc 660 `, inClause) 661 commentsRows, err := e.Query(commentsQuery, args...) 662 if err != nil { 663 return nil, err 664 } 665 defer commentsRows.Close() 666 667 for commentsRows.Next() { 668 var comment PullComment 669 var commentCreatedStr string 670 err := commentsRows.Scan( 671 &comment.ID, 672 &comment.PullId, 673 &comment.SubmissionId, 674 &comment.RepoAt, 675 &comment.OwnerDid, 676 &comment.CommentAt, 677 &comment.Body, 678 &commentCreatedStr, 679 ) 680 if err != nil { 681 return nil, err 682 } 683 684 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 685 if err != nil { 686 return nil, err 687 } 688 comment.Created = commentCreatedTime 689 690 // Add the comment to its submission 691 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 692 submission.Comments = append(submission.Comments, comment) 693 } 694 695 } 696 if err = commentsRows.Err(); err != nil { 697 return nil, err 698 } 699 700 var pullSourceRepo *Repo 701 if pull.PullSource != nil { 702 if pull.PullSource.RepoAt != nil { 703 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 704 if err != nil { 705 log.Printf("failed to get repo by at uri: %v", err) 706 } else { 707 pull.PullSource.Repo = pullSourceRepo 708 } 709 } 710 } 711 712 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 713 for _, submission := range submissionsMap { 714 pull.Submissions[submission.RoundNumber] = submission 715 } 716 717 return &pull, nil 718} 719 720// timeframe here is directly passed into the sql query filter, and any 721// timeframe in the past should be negative; e.g.: "-3 months" 722func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 723 var pulls []Pull 724 725 rows, err := e.Query(` 726 select 727 p.owner_did, 728 p.repo_at, 729 p.pull_id, 730 p.created, 731 p.title, 732 p.state, 733 r.did, 734 r.name, 735 r.knot, 736 r.rkey, 737 r.created 738 from 739 pulls p 740 join 741 repos r on p.repo_at = r.at_uri 742 where 743 p.owner_did = ? and p.created >= date ('now', ?) 744 order by 745 p.created desc`, did, timeframe) 746 if err != nil { 747 return nil, err 748 } 749 defer rows.Close() 750 751 for rows.Next() { 752 var pull Pull 753 var repo Repo 754 var pullCreatedAt, repoCreatedAt string 755 err := rows.Scan( 756 &pull.OwnerDid, 757 &pull.RepoAt, 758 &pull.PullId, 759 &pullCreatedAt, 760 &pull.Title, 761 &pull.State, 762 &repo.Did, 763 &repo.Name, 764 &repo.Knot, 765 &repo.Rkey, 766 &repoCreatedAt, 767 ) 768 if err != nil { 769 return nil, err 770 } 771 772 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 773 if err != nil { 774 return nil, err 775 } 776 pull.Created = pullCreatedTime 777 778 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 779 if err != nil { 780 return nil, err 781 } 782 repo.Created = repoCreatedTime 783 784 pull.Repo = &repo 785 786 pulls = append(pulls, pull) 787 } 788 789 if err := rows.Err(); err != nil { 790 return nil, err 791 } 792 793 return pulls, nil 794} 795 796func NewPullComment(e Execer, comment *PullComment) (int64, error) { 797 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 798 res, err := e.Exec( 799 query, 800 comment.OwnerDid, 801 comment.RepoAt, 802 comment.SubmissionId, 803 comment.CommentAt, 804 comment.PullId, 805 comment.Body, 806 ) 807 if err != nil { 808 return 0, err 809 } 810 811 i, err := res.LastInsertId() 812 if err != nil { 813 return 0, err 814 } 815 816 return i, nil 817} 818 819func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 820 _, err := e.Exec( 821 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 822 pullState, 823 repoAt, 824 pullId, 825 PullDeleted, // only update state of non-deleted pulls 826 PullMerged, // only update state of non-merged pulls 827 ) 828 return err 829} 830 831func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 832 err := SetPullState(e, repoAt, pullId, PullClosed) 833 return err 834} 835 836func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 837 err := SetPullState(e, repoAt, pullId, PullOpen) 838 return err 839} 840 841func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 842 err := SetPullState(e, repoAt, pullId, PullMerged) 843 return err 844} 845 846func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 847 err := SetPullState(e, repoAt, pullId, PullDeleted) 848 return err 849} 850 851func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 852 newRoundNumber := len(pull.Submissions) 853 _, err := e.Exec(` 854 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 855 values (?, ?, ?, ?, ?) 856 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 857 858 return err 859} 860 861func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 862 var conditions []string 863 var args []any 864 865 args = append(args, parentChangeId) 866 867 for _, filter := range filters { 868 conditions = append(conditions, filter.Condition()) 869 args = append(args, filter.arg) 870 } 871 872 whereClause := "" 873 if conditions != nil { 874 whereClause = " where " + strings.Join(conditions, " and ") 875 } 876 877 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 878 _, err := e.Exec(query, args...) 879 880 return err 881} 882 883// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 884// otherwise submissions are immutable 885func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 886 var conditions []string 887 var args []any 888 889 args = append(args, sourceRev) 890 args = append(args, newPatch) 891 892 for _, filter := range filters { 893 conditions = append(conditions, filter.Condition()) 894 args = append(args, filter.arg) 895 } 896 897 whereClause := "" 898 if conditions != nil { 899 whereClause = " where " + strings.Join(conditions, " and ") 900 } 901 902 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 903 _, err := e.Exec(query, args...) 904 905 return err 906} 907 908type PullCount struct { 909 Open int 910 Merged int 911 Closed int 912 Deleted int 913} 914 915func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 916 row := e.QueryRow(` 917 select 918 count(case when state = ? then 1 end) as open_count, 919 count(case when state = ? then 1 end) as merged_count, 920 count(case when state = ? then 1 end) as closed_count, 921 count(case when state = ? then 1 end) as deleted_count 922 from pulls 923 where repo_at = ?`, 924 PullOpen, 925 PullMerged, 926 PullClosed, 927 PullDeleted, 928 repoAt, 929 ) 930 931 var count PullCount 932 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 933 return PullCount{0, 0, 0, 0}, err 934 } 935 936 return count, nil 937} 938 939type Stack []*Pull 940 941// change-id parent-change-id 942// 943// 4 w ,-------- z (TOP) 944// 3 z <----',------- y 945// 2 y <-----',------ x 946// 1 x <------' nil (BOT) 947// 948// `w` is parent of none, so it is the top of the stack 949func GetStack(e Execer, stackId string) (Stack, error) { 950 unorderedPulls, err := GetPulls( 951 e, 952 FilterEq("stack_id", stackId), 953 FilterNotEq("state", PullDeleted), 954 ) 955 if err != nil { 956 return nil, err 957 } 958 // map of parent-change-id to pull 959 changeIdMap := make(map[string]*Pull, len(unorderedPulls)) 960 parentMap := make(map[string]*Pull, len(unorderedPulls)) 961 for _, p := range unorderedPulls { 962 changeIdMap[p.ChangeId] = p 963 if p.ParentChangeId != "" { 964 parentMap[p.ParentChangeId] = p 965 } 966 } 967 968 // the top of the stack is the pull that is not a parent of any pull 969 var topPull *Pull 970 for _, maybeTop := range unorderedPulls { 971 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 972 topPull = maybeTop 973 break 974 } 975 } 976 977 pulls := []*Pull{} 978 for { 979 pulls = append(pulls, topPull) 980 if topPull.ParentChangeId != "" { 981 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 982 topPull = next 983 } else { 984 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 985 } 986 } else { 987 break 988 } 989 } 990 991 return pulls, nil 992} 993 994func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) { 995 pulls, err := GetPulls( 996 e, 997 FilterEq("stack_id", stackId), 998 FilterEq("state", PullDeleted), 999 ) 1000 if err != nil { 1001 return nil, err 1002 } 1003 1004 return pulls, nil 1005} 1006 1007// position of this pull in the stack 1008func (stack Stack) Position(pull *Pull) int { 1009 return slices.IndexFunc(stack, func(p *Pull) bool { 1010 return p.ChangeId == pull.ChangeId 1011 }) 1012} 1013 1014// all pulls below this pull (including self) in this stack 1015// 1016// nil if this pull does not belong to this stack 1017func (stack Stack) Below(pull *Pull) Stack { 1018 position := stack.Position(pull) 1019 1020 if position < 0 { 1021 return nil 1022 } 1023 1024 return stack[position:] 1025} 1026 1027// all pulls below this pull (excluding self) in this stack 1028func (stack Stack) StrictlyBelow(pull *Pull) Stack { 1029 below := stack.Below(pull) 1030 1031 if len(below) > 0 { 1032 return below[1:] 1033 } 1034 1035 return nil 1036} 1037 1038// all pulls above this pull (including self) in this stack 1039func (stack Stack) Above(pull *Pull) Stack { 1040 position := stack.Position(pull) 1041 1042 if position < 0 { 1043 return nil 1044 } 1045 1046 return stack[:position+1] 1047} 1048 1049// all pulls below this pull (excluding self) in this stack 1050func (stack Stack) StrictlyAbove(pull *Pull) Stack { 1051 above := stack.Above(pull) 1052 1053 if len(above) > 0 { 1054 return above[:len(above)-1] 1055 } 1056 1057 return nil 1058} 1059 1060// the combined format-patches of all the newest submissions in this stack 1061func (stack Stack) CombinedPatch() string { 1062 // go in reverse order because the bottom of the stack is the last element in the slice 1063 var combined strings.Builder 1064 for idx := range stack { 1065 pull := stack[len(stack)-1-idx] 1066 combined.WriteString(pull.LatestPatch()) 1067 combined.WriteString("\n") 1068 } 1069 return combined.String() 1070} 1071 1072// filter out PRs that are "active" 1073// 1074// PRs that are still open are active 1075func (stack Stack) Mergeable() Stack { 1076 var mergeable Stack 1077 1078 for _, p := range stack { 1079 // stop at the first merged PR 1080 if p.State == PullMerged || p.State == PullClosed { 1081 break 1082 } 1083 1084 // skip over deleted PRs 1085 if p.State != PullDeleted { 1086 mergeable = append(mergeable, p) 1087 } 1088 } 1089 1090 return mergeable 1091}