forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "slices" 8 "sort" 9 "strings" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/api/tangled" 14 "tangled.sh/tangled.sh/core/patchutil" 15 "tangled.sh/tangled.sh/core/types" 16) 17 18type PullState int 19 20const ( 21 PullClosed PullState = iota 22 PullOpen 23 PullMerged 24 PullDeleted 25) 26 27func (p PullState) String() string { 28 switch p { 29 case PullOpen: 30 return "open" 31 case PullMerged: 32 return "merged" 33 case PullClosed: 34 return "closed" 35 case PullDeleted: 36 return "deleted" 37 default: 38 return "closed" 39 } 40} 41 42func (p PullState) IsOpen() bool { 43 return p == PullOpen 44} 45func (p PullState) IsMerged() bool { 46 return p == PullMerged 47} 48func (p PullState) IsClosed() bool { 49 return p == PullClosed 50} 51func (p PullState) IsDeleted() bool { 52 return p == PullDeleted 53} 54 55type Pull struct { 56 // ids 57 ID int 58 PullId int 59 60 // at ids 61 RepoAt syntax.ATURI 62 OwnerDid string 63 Rkey string 64 65 // content 66 Title string 67 Body string 68 TargetBranch string 69 State PullState 70 Submissions []*PullSubmission 71 72 // stacking 73 StackId string // nullable string 74 ChangeId string // nullable string 75 ParentChangeId string // nullable string 76 77 // meta 78 Created time.Time 79 PullSource *PullSource 80 81 // optionally, populate this when querying for reverse mappings 82 Repo *Repo 83} 84 85func (p Pull) AsRecord() tangled.RepoPull { 86 var source *tangled.RepoPull_Source 87 if p.PullSource != nil { 88 s := p.PullSource.AsRecord() 89 source = &s 90 source.Sha = p.LatestSha() 91 } 92 93 record := tangled.RepoPull{ 94 Title: p.Title, 95 Body: &p.Body, 96 CreatedAt: p.Created.Format(time.RFC3339), 97 Target: &tangled.RepoPull_Target{ 98 Repo: p.RepoAt.String(), 99 Branch: p.TargetBranch, 100 }, 101 Patch: p.LatestPatch(), 102 Source: source, 103 } 104 return record 105} 106 107type PullSource struct { 108 Branch string 109 RepoAt *syntax.ATURI 110 111 // optionally populate this for reverse mappings 112 Repo *Repo 113} 114 115func (p PullSource) AsRecord() tangled.RepoPull_Source { 116 var repoAt *string 117 if p.RepoAt != nil { 118 s := p.RepoAt.String() 119 repoAt = &s 120 } 121 record := tangled.RepoPull_Source{ 122 Branch: p.Branch, 123 Repo: repoAt, 124 } 125 return record 126} 127 128type PullSubmission struct { 129 // ids 130 ID int 131 PullId int 132 133 // at ids 134 RepoAt syntax.ATURI 135 136 // content 137 RoundNumber int 138 Patch string 139 Comments []PullComment 140 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs 141 142 // meta 143 Created time.Time 144} 145 146type PullComment struct { 147 // ids 148 ID int 149 PullId int 150 SubmissionId int 151 152 // at ids 153 RepoAt string 154 OwnerDid string 155 CommentAt string 156 157 // content 158 Body string 159 160 // meta 161 Created time.Time 162} 163 164func (p *Pull) LatestPatch() string { 165 latestSubmission := p.Submissions[p.LastRoundNumber()] 166 return latestSubmission.Patch 167} 168 169func (p *Pull) LatestSha() string { 170 latestSubmission := p.Submissions[p.LastRoundNumber()] 171 return latestSubmission.SourceRev 172} 173 174func (p *Pull) PullAt() syntax.ATURI { 175 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 176} 177 178func (p *Pull) LastRoundNumber() int { 179 return len(p.Submissions) - 1 180} 181 182func (p *Pull) IsPatchBased() bool { 183 return p.PullSource == nil 184} 185 186func (p *Pull) IsBranchBased() bool { 187 if p.PullSource != nil { 188 if p.PullSource.RepoAt != nil { 189 return p.PullSource.RepoAt == &p.RepoAt 190 } else { 191 // no repo specified 192 return true 193 } 194 } 195 return false 196} 197 198func (p *Pull) IsForkBased() bool { 199 if p.PullSource != nil { 200 if p.PullSource.RepoAt != nil { 201 // make sure repos are different 202 return p.PullSource.RepoAt != &p.RepoAt 203 } 204 } 205 return false 206} 207 208func (p *Pull) IsStacked() bool { 209 return p.StackId != "" 210} 211 212func (s PullSubmission) IsFormatPatch() bool { 213 return patchutil.IsFormatPatch(s.Patch) 214} 215 216func (s PullSubmission) AsFormatPatch() []types.FormatPatch { 217 patches, err := patchutil.ExtractPatches(s.Patch) 218 if err != nil { 219 log.Println("error extracting patches from submission:", err) 220 return []types.FormatPatch{} 221 } 222 223 return patches 224} 225 226func NewPull(tx *sql.Tx, pull *Pull) error { 227 _, err := tx.Exec(` 228 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 229 values (?, 1) 230 `, pull.RepoAt) 231 if err != nil { 232 return err 233 } 234 235 var nextId int 236 err = tx.QueryRow(` 237 update repo_pull_seqs 238 set next_pull_id = next_pull_id + 1 239 where repo_at = ? 240 returning next_pull_id - 1 241 `, pull.RepoAt).Scan(&nextId) 242 if err != nil { 243 return err 244 } 245 246 pull.PullId = nextId 247 pull.State = PullOpen 248 249 var sourceBranch, sourceRepoAt *string 250 if pull.PullSource != nil { 251 sourceBranch = &pull.PullSource.Branch 252 if pull.PullSource.RepoAt != nil { 253 x := pull.PullSource.RepoAt.String() 254 sourceRepoAt = &x 255 } 256 } 257 258 var stackId, changeId, parentChangeId *string 259 if pull.StackId != "" { 260 stackId = &pull.StackId 261 } 262 if pull.ChangeId != "" { 263 changeId = &pull.ChangeId 264 } 265 if pull.ParentChangeId != "" { 266 parentChangeId = &pull.ParentChangeId 267 } 268 269 _, err = tx.Exec( 270 ` 271 insert into pulls ( 272 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 273 ) 274 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 275 pull.RepoAt, 276 pull.OwnerDid, 277 pull.PullId, 278 pull.Title, 279 pull.TargetBranch, 280 pull.Body, 281 pull.Rkey, 282 pull.State, 283 sourceBranch, 284 sourceRepoAt, 285 stackId, 286 changeId, 287 parentChangeId, 288 ) 289 if err != nil { 290 return err 291 } 292 293 _, err = tx.Exec(` 294 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 295 values (?, ?, ?, ?, ?) 296 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 297 return err 298} 299 300func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 301 pull, err := GetPull(e, repoAt, pullId) 302 if err != nil { 303 return "", err 304 } 305 return pull.PullAt(), err 306} 307 308func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 309 var pullId int 310 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 311 return pullId - 1, err 312} 313 314func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) { 315 pulls := make(map[int]*Pull) 316 317 var conditions []string 318 var args []any 319 for _, filter := range filters { 320 conditions = append(conditions, filter.Condition()) 321 args = append(args, filter.Arg()...) 322 } 323 324 whereClause := "" 325 if conditions != nil { 326 whereClause = " where " + strings.Join(conditions, " and ") 327 } 328 limitClause := "" 329 if limit != 0 { 330 limitClause = fmt.Sprintf(" limit %d ", limit) 331 } 332 333 query := fmt.Sprintf(` 334 select 335 owner_did, 336 repo_at, 337 pull_id, 338 created, 339 title, 340 state, 341 target_branch, 342 body, 343 rkey, 344 source_branch, 345 source_repo_at, 346 stack_id, 347 change_id, 348 parent_change_id 349 from 350 pulls 351 %s 352 order by 353 created desc 354 %s 355 `, whereClause, limitClause) 356 357 rows, err := e.Query(query, args...) 358 if err != nil { 359 return nil, err 360 } 361 defer rows.Close() 362 363 for rows.Next() { 364 var pull Pull 365 var createdAt string 366 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 367 err := rows.Scan( 368 &pull.OwnerDid, 369 &pull.RepoAt, 370 &pull.PullId, 371 &createdAt, 372 &pull.Title, 373 &pull.State, 374 &pull.TargetBranch, 375 &pull.Body, 376 &pull.Rkey, 377 &sourceBranch, 378 &sourceRepoAt, 379 &stackId, 380 &changeId, 381 &parentChangeId, 382 ) 383 if err != nil { 384 return nil, err 385 } 386 387 createdTime, err := time.Parse(time.RFC3339, createdAt) 388 if err != nil { 389 return nil, err 390 } 391 pull.Created = createdTime 392 393 if sourceBranch.Valid { 394 pull.PullSource = &PullSource{ 395 Branch: sourceBranch.String, 396 } 397 if sourceRepoAt.Valid { 398 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 399 if err != nil { 400 return nil, err 401 } 402 pull.PullSource.RepoAt = &sourceRepoAtParsed 403 } 404 } 405 406 if stackId.Valid { 407 pull.StackId = stackId.String 408 } 409 if changeId.Valid { 410 pull.ChangeId = changeId.String 411 } 412 if parentChangeId.Valid { 413 pull.ParentChangeId = parentChangeId.String 414 } 415 416 pulls[pull.PullId] = &pull 417 } 418 419 // get latest round no. for each pull 420 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 421 submissionsQuery := fmt.Sprintf(` 422 select 423 id, pull_id, round_number, patch, created, source_rev 424 from 425 pull_submissions 426 where 427 repo_at in (%s) and pull_id in (%s) 428 `, inClause, inClause) 429 430 args = make([]any, len(pulls)*2) 431 idx := 0 432 for _, p := range pulls { 433 args[idx] = p.RepoAt 434 idx += 1 435 } 436 for _, p := range pulls { 437 args[idx] = p.PullId 438 idx += 1 439 } 440 submissionsRows, err := e.Query(submissionsQuery, args...) 441 if err != nil { 442 return nil, err 443 } 444 defer submissionsRows.Close() 445 446 for submissionsRows.Next() { 447 var s PullSubmission 448 var sourceRev sql.NullString 449 var createdAt string 450 err := submissionsRows.Scan( 451 &s.ID, 452 &s.PullId, 453 &s.RoundNumber, 454 &s.Patch, 455 &createdAt, 456 &sourceRev, 457 ) 458 if err != nil { 459 return nil, err 460 } 461 462 createdTime, err := time.Parse(time.RFC3339, createdAt) 463 if err != nil { 464 return nil, err 465 } 466 s.Created = createdTime 467 468 if sourceRev.Valid { 469 s.SourceRev = sourceRev.String 470 } 471 472 if p, ok := pulls[s.PullId]; ok { 473 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 474 p.Submissions[s.RoundNumber] = &s 475 } 476 } 477 if err := rows.Err(); err != nil { 478 return nil, err 479 } 480 481 // get comment count on latest submission on each pull 482 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 483 commentsQuery := fmt.Sprintf(` 484 select 485 count(id), pull_id 486 from 487 pull_comments 488 where 489 submission_id in (%s) 490 group by 491 submission_id 492 `, inClause) 493 494 args = []any{} 495 for _, p := range pulls { 496 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 497 } 498 commentsRows, err := e.Query(commentsQuery, args...) 499 if err != nil { 500 return nil, err 501 } 502 defer commentsRows.Close() 503 504 for commentsRows.Next() { 505 var commentCount, pullId int 506 err := commentsRows.Scan( 507 &commentCount, 508 &pullId, 509 ) 510 if err != nil { 511 return nil, err 512 } 513 if p, ok := pulls[pullId]; ok { 514 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 515 } 516 } 517 if err := rows.Err(); err != nil { 518 return nil, err 519 } 520 521 orderedByPullId := []*Pull{} 522 for _, p := range pulls { 523 orderedByPullId = append(orderedByPullId, p) 524 } 525 sort.Slice(orderedByPullId, func(i, j int) bool { 526 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 527 }) 528 529 return orderedByPullId, nil 530} 531 532func GetPulls(e Execer, filters ...filter) ([]*Pull, error) { 533 return GetPullsWithLimit(e, 0, filters...) 534} 535 536func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 537 query := ` 538 select 539 owner_did, 540 pull_id, 541 created, 542 title, 543 state, 544 target_branch, 545 repo_at, 546 body, 547 rkey, 548 source_branch, 549 source_repo_at, 550 stack_id, 551 change_id, 552 parent_change_id 553 from 554 pulls 555 where 556 repo_at = ? and pull_id = ? 557 ` 558 row := e.QueryRow(query, repoAt, pullId) 559 560 var pull Pull 561 var createdAt string 562 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 563 err := row.Scan( 564 &pull.OwnerDid, 565 &pull.PullId, 566 &createdAt, 567 &pull.Title, 568 &pull.State, 569 &pull.TargetBranch, 570 &pull.RepoAt, 571 &pull.Body, 572 &pull.Rkey, 573 &sourceBranch, 574 &sourceRepoAt, 575 &stackId, 576 &changeId, 577 &parentChangeId, 578 ) 579 if err != nil { 580 return nil, err 581 } 582 583 createdTime, err := time.Parse(time.RFC3339, createdAt) 584 if err != nil { 585 return nil, err 586 } 587 pull.Created = createdTime 588 589 // populate source 590 if sourceBranch.Valid { 591 pull.PullSource = &PullSource{ 592 Branch: sourceBranch.String, 593 } 594 if sourceRepoAt.Valid { 595 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 596 if err != nil { 597 return nil, err 598 } 599 pull.PullSource.RepoAt = &sourceRepoAtParsed 600 } 601 } 602 603 if stackId.Valid { 604 pull.StackId = stackId.String 605 } 606 if changeId.Valid { 607 pull.ChangeId = changeId.String 608 } 609 if parentChangeId.Valid { 610 pull.ParentChangeId = parentChangeId.String 611 } 612 613 submissionsQuery := ` 614 select 615 id, pull_id, repo_at, round_number, patch, created, source_rev 616 from 617 pull_submissions 618 where 619 repo_at = ? and pull_id = ? 620 ` 621 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 622 if err != nil { 623 return nil, err 624 } 625 defer submissionsRows.Close() 626 627 submissionsMap := make(map[int]*PullSubmission) 628 629 for submissionsRows.Next() { 630 var submission PullSubmission 631 var submissionCreatedStr string 632 var submissionSourceRev sql.NullString 633 err := submissionsRows.Scan( 634 &submission.ID, 635 &submission.PullId, 636 &submission.RepoAt, 637 &submission.RoundNumber, 638 &submission.Patch, 639 &submissionCreatedStr, 640 &submissionSourceRev, 641 ) 642 if err != nil { 643 return nil, err 644 } 645 646 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 647 if err != nil { 648 return nil, err 649 } 650 submission.Created = submissionCreatedTime 651 652 if submissionSourceRev.Valid { 653 submission.SourceRev = submissionSourceRev.String 654 } 655 656 submissionsMap[submission.ID] = &submission 657 } 658 if err = submissionsRows.Close(); err != nil { 659 return nil, err 660 } 661 if len(submissionsMap) == 0 { 662 return &pull, nil 663 } 664 665 var args []any 666 for k := range submissionsMap { 667 args = append(args, k) 668 } 669 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 670 commentsQuery := fmt.Sprintf(` 671 select 672 id, 673 pull_id, 674 submission_id, 675 repo_at, 676 owner_did, 677 comment_at, 678 body, 679 created 680 from 681 pull_comments 682 where 683 submission_id IN (%s) 684 order by 685 created asc 686 `, inClause) 687 commentsRows, err := e.Query(commentsQuery, args...) 688 if err != nil { 689 return nil, err 690 } 691 defer commentsRows.Close() 692 693 for commentsRows.Next() { 694 var comment PullComment 695 var commentCreatedStr string 696 err := commentsRows.Scan( 697 &comment.ID, 698 &comment.PullId, 699 &comment.SubmissionId, 700 &comment.RepoAt, 701 &comment.OwnerDid, 702 &comment.CommentAt, 703 &comment.Body, 704 &commentCreatedStr, 705 ) 706 if err != nil { 707 return nil, err 708 } 709 710 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 711 if err != nil { 712 return nil, err 713 } 714 comment.Created = commentCreatedTime 715 716 // Add the comment to its submission 717 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 718 submission.Comments = append(submission.Comments, comment) 719 } 720 721 } 722 if err = commentsRows.Err(); err != nil { 723 return nil, err 724 } 725 726 var pullSourceRepo *Repo 727 if pull.PullSource != nil { 728 if pull.PullSource.RepoAt != nil { 729 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 730 if err != nil { 731 log.Printf("failed to get repo by at uri: %v", err) 732 } else { 733 pull.PullSource.Repo = pullSourceRepo 734 } 735 } 736 } 737 738 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 739 for _, submission := range submissionsMap { 740 pull.Submissions[submission.RoundNumber] = submission 741 } 742 743 return &pull, nil 744} 745 746// timeframe here is directly passed into the sql query filter, and any 747// timeframe in the past should be negative; e.g.: "-3 months" 748func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 749 var pulls []Pull 750 751 rows, err := e.Query(` 752 select 753 p.owner_did, 754 p.repo_at, 755 p.pull_id, 756 p.created, 757 p.title, 758 p.state, 759 r.did, 760 r.name, 761 r.knot, 762 r.rkey, 763 r.created 764 from 765 pulls p 766 join 767 repos r on p.repo_at = r.at_uri 768 where 769 p.owner_did = ? and p.created >= date ('now', ?) 770 order by 771 p.created desc`, did, timeframe) 772 if err != nil { 773 return nil, err 774 } 775 defer rows.Close() 776 777 for rows.Next() { 778 var pull Pull 779 var repo Repo 780 var pullCreatedAt, repoCreatedAt string 781 err := rows.Scan( 782 &pull.OwnerDid, 783 &pull.RepoAt, 784 &pull.PullId, 785 &pullCreatedAt, 786 &pull.Title, 787 &pull.State, 788 &repo.Did, 789 &repo.Name, 790 &repo.Knot, 791 &repo.Rkey, 792 &repoCreatedAt, 793 ) 794 if err != nil { 795 return nil, err 796 } 797 798 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 799 if err != nil { 800 return nil, err 801 } 802 pull.Created = pullCreatedTime 803 804 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 805 if err != nil { 806 return nil, err 807 } 808 repo.Created = repoCreatedTime 809 810 pull.Repo = &repo 811 812 pulls = append(pulls, pull) 813 } 814 815 if err := rows.Err(); err != nil { 816 return nil, err 817 } 818 819 return pulls, nil 820} 821 822func NewPullComment(e Execer, comment *PullComment) (int64, error) { 823 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 824 res, err := e.Exec( 825 query, 826 comment.OwnerDid, 827 comment.RepoAt, 828 comment.SubmissionId, 829 comment.CommentAt, 830 comment.PullId, 831 comment.Body, 832 ) 833 if err != nil { 834 return 0, err 835 } 836 837 i, err := res.LastInsertId() 838 if err != nil { 839 return 0, err 840 } 841 842 return i, nil 843} 844 845func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 846 _, err := e.Exec( 847 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 848 pullState, 849 repoAt, 850 pullId, 851 PullDeleted, // only update state of non-deleted pulls 852 PullMerged, // only update state of non-merged pulls 853 ) 854 return err 855} 856 857func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 858 err := SetPullState(e, repoAt, pullId, PullClosed) 859 return err 860} 861 862func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 863 err := SetPullState(e, repoAt, pullId, PullOpen) 864 return err 865} 866 867func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 868 err := SetPullState(e, repoAt, pullId, PullMerged) 869 return err 870} 871 872func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 873 err := SetPullState(e, repoAt, pullId, PullDeleted) 874 return err 875} 876 877func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 878 newRoundNumber := len(pull.Submissions) 879 _, err := e.Exec(` 880 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 881 values (?, ?, ?, ?, ?) 882 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 883 884 return err 885} 886 887func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 888 var conditions []string 889 var args []any 890 891 args = append(args, parentChangeId) 892 893 for _, filter := range filters { 894 conditions = append(conditions, filter.Condition()) 895 args = append(args, filter.Arg()...) 896 } 897 898 whereClause := "" 899 if conditions != nil { 900 whereClause = " where " + strings.Join(conditions, " and ") 901 } 902 903 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 904 _, err := e.Exec(query, args...) 905 906 return err 907} 908 909// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 910// otherwise submissions are immutable 911func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 912 var conditions []string 913 var args []any 914 915 args = append(args, sourceRev) 916 args = append(args, newPatch) 917 918 for _, filter := range filters { 919 conditions = append(conditions, filter.Condition()) 920 args = append(args, filter.Arg()...) 921 } 922 923 whereClause := "" 924 if conditions != nil { 925 whereClause = " where " + strings.Join(conditions, " and ") 926 } 927 928 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 929 _, err := e.Exec(query, args...) 930 931 return err 932} 933 934type PullCount struct { 935 Open int 936 Merged int 937 Closed int 938 Deleted int 939} 940 941func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 942 row := e.QueryRow(` 943 select 944 count(case when state = ? then 1 end) as open_count, 945 count(case when state = ? then 1 end) as merged_count, 946 count(case when state = ? then 1 end) as closed_count, 947 count(case when state = ? then 1 end) as deleted_count 948 from pulls 949 where repo_at = ?`, 950 PullOpen, 951 PullMerged, 952 PullClosed, 953 PullDeleted, 954 repoAt, 955 ) 956 957 var count PullCount 958 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 959 return PullCount{0, 0, 0, 0}, err 960 } 961 962 return count, nil 963} 964 965type Stack []*Pull 966 967// change-id parent-change-id 968// 969// 4 w ,-------- z (TOP) 970// 3 z <----',------- y 971// 2 y <-----',------ x 972// 1 x <------' nil (BOT) 973// 974// `w` is parent of none, so it is the top of the stack 975func GetStack(e Execer, stackId string) (Stack, error) { 976 unorderedPulls, err := GetPulls( 977 e, 978 FilterEq("stack_id", stackId), 979 FilterNotEq("state", PullDeleted), 980 ) 981 if err != nil { 982 return nil, err 983 } 984 // map of parent-change-id to pull 985 changeIdMap := make(map[string]*Pull, len(unorderedPulls)) 986 parentMap := make(map[string]*Pull, len(unorderedPulls)) 987 for _, p := range unorderedPulls { 988 changeIdMap[p.ChangeId] = p 989 if p.ParentChangeId != "" { 990 parentMap[p.ParentChangeId] = p 991 } 992 } 993 994 // the top of the stack is the pull that is not a parent of any pull 995 var topPull *Pull 996 for _, maybeTop := range unorderedPulls { 997 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 998 topPull = maybeTop 999 break 1000 } 1001 } 1002 1003 pulls := []*Pull{} 1004 for { 1005 pulls = append(pulls, topPull) 1006 if topPull.ParentChangeId != "" { 1007 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 1008 topPull = next 1009 } else { 1010 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 1011 } 1012 } else { 1013 break 1014 } 1015 } 1016 1017 return pulls, nil 1018} 1019 1020func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) { 1021 pulls, err := GetPulls( 1022 e, 1023 FilterEq("stack_id", stackId), 1024 FilterEq("state", PullDeleted), 1025 ) 1026 if err != nil { 1027 return nil, err 1028 } 1029 1030 return pulls, nil 1031} 1032 1033// position of this pull in the stack 1034func (stack Stack) Position(pull *Pull) int { 1035 return slices.IndexFunc(stack, func(p *Pull) bool { 1036 return p.ChangeId == pull.ChangeId 1037 }) 1038} 1039 1040// all pulls below this pull (including self) in this stack 1041// 1042// nil if this pull does not belong to this stack 1043func (stack Stack) Below(pull *Pull) Stack { 1044 position := stack.Position(pull) 1045 1046 if position < 0 { 1047 return nil 1048 } 1049 1050 return stack[position:] 1051} 1052 1053// all pulls below this pull (excluding self) in this stack 1054func (stack Stack) StrictlyBelow(pull *Pull) Stack { 1055 below := stack.Below(pull) 1056 1057 if len(below) > 0 { 1058 return below[1:] 1059 } 1060 1061 return nil 1062} 1063 1064// all pulls above this pull (including self) in this stack 1065func (stack Stack) Above(pull *Pull) Stack { 1066 position := stack.Position(pull) 1067 1068 if position < 0 { 1069 return nil 1070 } 1071 1072 return stack[:position+1] 1073} 1074 1075// all pulls below this pull (excluding self) in this stack 1076func (stack Stack) StrictlyAbove(pull *Pull) Stack { 1077 above := stack.Above(pull) 1078 1079 if len(above) > 0 { 1080 return above[:len(above)-1] 1081 } 1082 1083 return nil 1084} 1085 1086// the combined format-patches of all the newest submissions in this stack 1087func (stack Stack) CombinedPatch() string { 1088 // go in reverse order because the bottom of the stack is the last element in the slice 1089 var combined strings.Builder 1090 for idx := range stack { 1091 pull := stack[len(stack)-1-idx] 1092 combined.WriteString(pull.LatestPatch()) 1093 combined.WriteString("\n") 1094 } 1095 return combined.String() 1096} 1097 1098// filter out PRs that are "active" 1099// 1100// PRs that are still open are active 1101func (stack Stack) Mergeable() Stack { 1102 var mergeable Stack 1103 1104 for _, p := range stack { 1105 // stop at the first merged PR 1106 if p.State == PullMerged || p.State == PullClosed { 1107 break 1108 } 1109 1110 // skip over deleted PRs 1111 if p.State != PullDeleted { 1112 mergeable = append(mergeable, p) 1113 } 1114 } 1115 1116 return mergeable 1117}