forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "slices" 8 "sort" 9 "strings" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/api/tangled" 14 "tangled.sh/tangled.sh/core/patchutil" 15 "tangled.sh/tangled.sh/core/types" 16) 17 18type PullState int 19 20const ( 21 PullClosed PullState = iota 22 PullOpen 23 PullMerged 24 PullDeleted 25) 26 27func (p PullState) String() string { 28 switch p { 29 case PullOpen: 30 return "open" 31 case PullMerged: 32 return "merged" 33 case PullClosed: 34 return "closed" 35 case PullDeleted: 36 return "deleted" 37 default: 38 return "closed" 39 } 40} 41 42func (p PullState) IsOpen() bool { 43 return p == PullOpen 44} 45func (p PullState) IsMerged() bool { 46 return p == PullMerged 47} 48func (p PullState) IsClosed() bool { 49 return p == PullClosed 50} 51func (p PullState) IsDeleted() bool { 52 return p == PullDeleted 53} 54 55type Pull struct { 56 // ids 57 ID int 58 PullId int 59 60 // at ids 61 RepoAt syntax.ATURI 62 OwnerDid string 63 Rkey string 64 65 // content 66 Title string 67 Body string 68 TargetBranch string 69 State PullState 70 Submissions []*PullSubmission 71 72 // stacking 73 StackId string // nullable string 74 ChangeId string // nullable string 75 ParentChangeId string // nullable string 76 77 // meta 78 Created time.Time 79 PullSource *PullSource 80 81 // optionally, populate this when querying for reverse mappings 82 Repo *Repo 83} 84 85func (p Pull) AsRecord() tangled.RepoPull { 86 var source *tangled.RepoPull_Source 87 if p.PullSource != nil { 88 s := p.PullSource.AsRecord() 89 source = &s 90 source.Sha = p.LatestSha() 91 } 92 93 record := tangled.RepoPull{ 94 Title: p.Title, 95 Body: &p.Body, 96 CreatedAt: p.Created.Format(time.RFC3339), 97 PullId: int64(p.PullId), 98 TargetRepo: p.RepoAt.String(), 99 TargetBranch: p.TargetBranch, 100 Patch: p.LatestPatch(), 101 Source: source, 102 } 103 return record 104} 105 106type PullSource struct { 107 Branch string 108 RepoAt *syntax.ATURI 109 110 // optionally populate this for reverse mappings 111 Repo *Repo 112} 113 114func (p PullSource) AsRecord() tangled.RepoPull_Source { 115 var repoAt *string 116 if p.RepoAt != nil { 117 s := p.RepoAt.String() 118 repoAt = &s 119 } 120 record := tangled.RepoPull_Source{ 121 Branch: p.Branch, 122 Repo: repoAt, 123 } 124 return record 125} 126 127type PullSubmission struct { 128 // ids 129 ID int 130 PullId int 131 132 // at ids 133 RepoAt syntax.ATURI 134 135 // content 136 RoundNumber int 137 Patch string 138 Comments []PullComment 139 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs 140 141 // meta 142 Created time.Time 143} 144 145type PullComment struct { 146 // ids 147 ID int 148 PullId int 149 SubmissionId int 150 151 // at ids 152 RepoAt string 153 OwnerDid string 154 CommentAt string 155 156 // content 157 Body string 158 159 // meta 160 Created time.Time 161} 162 163func (p *Pull) LatestPatch() string { 164 latestSubmission := p.Submissions[p.LastRoundNumber()] 165 return latestSubmission.Patch 166} 167 168func (p *Pull) LatestSha() string { 169 latestSubmission := p.Submissions[p.LastRoundNumber()] 170 return latestSubmission.SourceRev 171} 172 173func (p *Pull) PullAt() syntax.ATURI { 174 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 175} 176 177func (p *Pull) LastRoundNumber() int { 178 return len(p.Submissions) - 1 179} 180 181func (p *Pull) IsPatchBased() bool { 182 return p.PullSource == nil 183} 184 185func (p *Pull) IsBranchBased() bool { 186 if p.PullSource != nil { 187 if p.PullSource.RepoAt != nil { 188 return p.PullSource.RepoAt == &p.RepoAt 189 } else { 190 // no repo specified 191 return true 192 } 193 } 194 return false 195} 196 197func (p *Pull) IsForkBased() bool { 198 if p.PullSource != nil { 199 if p.PullSource.RepoAt != nil { 200 // make sure repos are different 201 return p.PullSource.RepoAt != &p.RepoAt 202 } 203 } 204 return false 205} 206 207func (p *Pull) IsStacked() bool { 208 return p.StackId != "" 209} 210 211func (s PullSubmission) IsFormatPatch() bool { 212 return patchutil.IsFormatPatch(s.Patch) 213} 214 215func (s PullSubmission) AsFormatPatch() []types.FormatPatch { 216 patches, err := patchutil.ExtractPatches(s.Patch) 217 if err != nil { 218 log.Println("error extracting patches from submission:", err) 219 return []types.FormatPatch{} 220 } 221 222 return patches 223} 224 225func NewPull(tx *sql.Tx, pull *Pull) error { 226 _, err := tx.Exec(` 227 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 228 values (?, 1) 229 `, pull.RepoAt) 230 if err != nil { 231 return err 232 } 233 234 var nextId int 235 err = tx.QueryRow(` 236 update repo_pull_seqs 237 set next_pull_id = next_pull_id + 1 238 where repo_at = ? 239 returning next_pull_id - 1 240 `, pull.RepoAt).Scan(&nextId) 241 if err != nil { 242 return err 243 } 244 245 pull.PullId = nextId 246 pull.State = PullOpen 247 248 var sourceBranch, sourceRepoAt *string 249 if pull.PullSource != nil { 250 sourceBranch = &pull.PullSource.Branch 251 if pull.PullSource.RepoAt != nil { 252 x := pull.PullSource.RepoAt.String() 253 sourceRepoAt = &x 254 } 255 } 256 257 var stackId, changeId, parentChangeId *string 258 if pull.StackId != "" { 259 stackId = &pull.StackId 260 } 261 if pull.ChangeId != "" { 262 changeId = &pull.ChangeId 263 } 264 if pull.ParentChangeId != "" { 265 parentChangeId = &pull.ParentChangeId 266 } 267 268 _, err = tx.Exec( 269 ` 270 insert into pulls ( 271 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 272 ) 273 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 274 pull.RepoAt, 275 pull.OwnerDid, 276 pull.PullId, 277 pull.Title, 278 pull.TargetBranch, 279 pull.Body, 280 pull.Rkey, 281 pull.State, 282 sourceBranch, 283 sourceRepoAt, 284 stackId, 285 changeId, 286 parentChangeId, 287 ) 288 if err != nil { 289 return err 290 } 291 292 _, err = tx.Exec(` 293 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 294 values (?, ?, ?, ?, ?) 295 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 296 return err 297} 298 299func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 300 pull, err := GetPull(e, repoAt, pullId) 301 if err != nil { 302 return "", err 303 } 304 return pull.PullAt(), err 305} 306 307func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 308 var pullId int 309 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 310 return pullId - 1, err 311} 312 313func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) { 314 pulls := make(map[int]*Pull) 315 316 var conditions []string 317 var args []any 318 for _, filter := range filters { 319 conditions = append(conditions, filter.Condition()) 320 args = append(args, filter.Arg()...) 321 } 322 323 whereClause := "" 324 if conditions != nil { 325 whereClause = " where " + strings.Join(conditions, " and ") 326 } 327 limitClause := "" 328 if limit != 0 { 329 limitClause = fmt.Sprintf(" limit %d ", limit) 330 } 331 332 query := fmt.Sprintf(` 333 select 334 owner_did, 335 repo_at, 336 pull_id, 337 created, 338 title, 339 state, 340 target_branch, 341 body, 342 rkey, 343 source_branch, 344 source_repo_at, 345 stack_id, 346 change_id, 347 parent_change_id 348 from 349 pulls 350 %s 351 order by 352 created desc 353 %s 354 `, whereClause, limitClause) 355 356 rows, err := e.Query(query, args...) 357 if err != nil { 358 return nil, err 359 } 360 defer rows.Close() 361 362 for rows.Next() { 363 var pull Pull 364 var createdAt string 365 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 366 err := rows.Scan( 367 &pull.OwnerDid, 368 &pull.RepoAt, 369 &pull.PullId, 370 &createdAt, 371 &pull.Title, 372 &pull.State, 373 &pull.TargetBranch, 374 &pull.Body, 375 &pull.Rkey, 376 &sourceBranch, 377 &sourceRepoAt, 378 &stackId, 379 &changeId, 380 &parentChangeId, 381 ) 382 if err != nil { 383 return nil, err 384 } 385 386 createdTime, err := time.Parse(time.RFC3339, createdAt) 387 if err != nil { 388 return nil, err 389 } 390 pull.Created = createdTime 391 392 if sourceBranch.Valid { 393 pull.PullSource = &PullSource{ 394 Branch: sourceBranch.String, 395 } 396 if sourceRepoAt.Valid { 397 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 398 if err != nil { 399 return nil, err 400 } 401 pull.PullSource.RepoAt = &sourceRepoAtParsed 402 } 403 } 404 405 if stackId.Valid { 406 pull.StackId = stackId.String 407 } 408 if changeId.Valid { 409 pull.ChangeId = changeId.String 410 } 411 if parentChangeId.Valid { 412 pull.ParentChangeId = parentChangeId.String 413 } 414 415 pulls[pull.PullId] = &pull 416 } 417 418 // get latest round no. for each pull 419 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 420 submissionsQuery := fmt.Sprintf(` 421 select 422 id, pull_id, round_number, patch, created, source_rev 423 from 424 pull_submissions 425 where 426 repo_at in (%s) and pull_id in (%s) 427 `, inClause, inClause) 428 429 args = make([]any, len(pulls)*2) 430 idx := 0 431 for _, p := range pulls { 432 args[idx] = p.RepoAt 433 idx += 1 434 } 435 for _, p := range pulls { 436 args[idx] = p.PullId 437 idx += 1 438 } 439 submissionsRows, err := e.Query(submissionsQuery, args...) 440 if err != nil { 441 return nil, err 442 } 443 defer submissionsRows.Close() 444 445 for submissionsRows.Next() { 446 var s PullSubmission 447 var sourceRev sql.NullString 448 var createdAt string 449 err := submissionsRows.Scan( 450 &s.ID, 451 &s.PullId, 452 &s.RoundNumber, 453 &s.Patch, 454 &createdAt, 455 &sourceRev, 456 ) 457 if err != nil { 458 return nil, err 459 } 460 461 createdTime, err := time.Parse(time.RFC3339, createdAt) 462 if err != nil { 463 return nil, err 464 } 465 s.Created = createdTime 466 467 if sourceRev.Valid { 468 s.SourceRev = sourceRev.String 469 } 470 471 if p, ok := pulls[s.PullId]; ok { 472 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 473 p.Submissions[s.RoundNumber] = &s 474 } 475 } 476 if err := rows.Err(); err != nil { 477 return nil, err 478 } 479 480 // get comment count on latest submission on each pull 481 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 482 commentsQuery := fmt.Sprintf(` 483 select 484 count(id), pull_id 485 from 486 pull_comments 487 where 488 submission_id in (%s) 489 group by 490 submission_id 491 `, inClause) 492 493 args = []any{} 494 for _, p := range pulls { 495 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 496 } 497 commentsRows, err := e.Query(commentsQuery, args...) 498 if err != nil { 499 return nil, err 500 } 501 defer commentsRows.Close() 502 503 for commentsRows.Next() { 504 var commentCount, pullId int 505 err := commentsRows.Scan( 506 &commentCount, 507 &pullId, 508 ) 509 if err != nil { 510 return nil, err 511 } 512 if p, ok := pulls[pullId]; ok { 513 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 514 } 515 } 516 if err := rows.Err(); err != nil { 517 return nil, err 518 } 519 520 orderedByPullId := []*Pull{} 521 for _, p := range pulls { 522 orderedByPullId = append(orderedByPullId, p) 523 } 524 sort.Slice(orderedByPullId, func(i, j int) bool { 525 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 526 }) 527 528 return orderedByPullId, nil 529} 530 531func GetPulls(e Execer, filters ...filter) ([]*Pull, error) { 532 return GetPullsWithLimit(e, 0, filters...) 533} 534 535func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 536 query := ` 537 select 538 owner_did, 539 pull_id, 540 created, 541 title, 542 state, 543 target_branch, 544 repo_at, 545 body, 546 rkey, 547 source_branch, 548 source_repo_at, 549 stack_id, 550 change_id, 551 parent_change_id 552 from 553 pulls 554 where 555 repo_at = ? and pull_id = ? 556 ` 557 row := e.QueryRow(query, repoAt, pullId) 558 559 var pull Pull 560 var createdAt string 561 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 562 err := row.Scan( 563 &pull.OwnerDid, 564 &pull.PullId, 565 &createdAt, 566 &pull.Title, 567 &pull.State, 568 &pull.TargetBranch, 569 &pull.RepoAt, 570 &pull.Body, 571 &pull.Rkey, 572 &sourceBranch, 573 &sourceRepoAt, 574 &stackId, 575 &changeId, 576 &parentChangeId, 577 ) 578 if err != nil { 579 return nil, err 580 } 581 582 createdTime, err := time.Parse(time.RFC3339, createdAt) 583 if err != nil { 584 return nil, err 585 } 586 pull.Created = createdTime 587 588 // populate source 589 if sourceBranch.Valid { 590 pull.PullSource = &PullSource{ 591 Branch: sourceBranch.String, 592 } 593 if sourceRepoAt.Valid { 594 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 595 if err != nil { 596 return nil, err 597 } 598 pull.PullSource.RepoAt = &sourceRepoAtParsed 599 } 600 } 601 602 if stackId.Valid { 603 pull.StackId = stackId.String 604 } 605 if changeId.Valid { 606 pull.ChangeId = changeId.String 607 } 608 if parentChangeId.Valid { 609 pull.ParentChangeId = parentChangeId.String 610 } 611 612 submissionsQuery := ` 613 select 614 id, pull_id, repo_at, round_number, patch, created, source_rev 615 from 616 pull_submissions 617 where 618 repo_at = ? and pull_id = ? 619 ` 620 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 621 if err != nil { 622 return nil, err 623 } 624 defer submissionsRows.Close() 625 626 submissionsMap := make(map[int]*PullSubmission) 627 628 for submissionsRows.Next() { 629 var submission PullSubmission 630 var submissionCreatedStr string 631 var submissionSourceRev sql.NullString 632 err := submissionsRows.Scan( 633 &submission.ID, 634 &submission.PullId, 635 &submission.RepoAt, 636 &submission.RoundNumber, 637 &submission.Patch, 638 &submissionCreatedStr, 639 &submissionSourceRev, 640 ) 641 if err != nil { 642 return nil, err 643 } 644 645 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 646 if err != nil { 647 return nil, err 648 } 649 submission.Created = submissionCreatedTime 650 651 if submissionSourceRev.Valid { 652 submission.SourceRev = submissionSourceRev.String 653 } 654 655 submissionsMap[submission.ID] = &submission 656 } 657 if err = submissionsRows.Close(); err != nil { 658 return nil, err 659 } 660 if len(submissionsMap) == 0 { 661 return &pull, nil 662 } 663 664 var args []any 665 for k := range submissionsMap { 666 args = append(args, k) 667 } 668 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 669 commentsQuery := fmt.Sprintf(` 670 select 671 id, 672 pull_id, 673 submission_id, 674 repo_at, 675 owner_did, 676 comment_at, 677 body, 678 created 679 from 680 pull_comments 681 where 682 submission_id IN (%s) 683 order by 684 created asc 685 `, inClause) 686 commentsRows, err := e.Query(commentsQuery, args...) 687 if err != nil { 688 return nil, err 689 } 690 defer commentsRows.Close() 691 692 for commentsRows.Next() { 693 var comment PullComment 694 var commentCreatedStr string 695 err := commentsRows.Scan( 696 &comment.ID, 697 &comment.PullId, 698 &comment.SubmissionId, 699 &comment.RepoAt, 700 &comment.OwnerDid, 701 &comment.CommentAt, 702 &comment.Body, 703 &commentCreatedStr, 704 ) 705 if err != nil { 706 return nil, err 707 } 708 709 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 710 if err != nil { 711 return nil, err 712 } 713 comment.Created = commentCreatedTime 714 715 // Add the comment to its submission 716 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 717 submission.Comments = append(submission.Comments, comment) 718 } 719 720 } 721 if err = commentsRows.Err(); err != nil { 722 return nil, err 723 } 724 725 var pullSourceRepo *Repo 726 if pull.PullSource != nil { 727 if pull.PullSource.RepoAt != nil { 728 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 729 if err != nil { 730 log.Printf("failed to get repo by at uri: %v", err) 731 } else { 732 pull.PullSource.Repo = pullSourceRepo 733 } 734 } 735 } 736 737 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 738 for _, submission := range submissionsMap { 739 pull.Submissions[submission.RoundNumber] = submission 740 } 741 742 return &pull, nil 743} 744 745// timeframe here is directly passed into the sql query filter, and any 746// timeframe in the past should be negative; e.g.: "-3 months" 747func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 748 var pulls []Pull 749 750 rows, err := e.Query(` 751 select 752 p.owner_did, 753 p.repo_at, 754 p.pull_id, 755 p.created, 756 p.title, 757 p.state, 758 r.did, 759 r.name, 760 r.knot, 761 r.rkey, 762 r.created 763 from 764 pulls p 765 join 766 repos r on p.repo_at = r.at_uri 767 where 768 p.owner_did = ? and p.created >= date ('now', ?) 769 order by 770 p.created desc`, did, timeframe) 771 if err != nil { 772 return nil, err 773 } 774 defer rows.Close() 775 776 for rows.Next() { 777 var pull Pull 778 var repo Repo 779 var pullCreatedAt, repoCreatedAt string 780 err := rows.Scan( 781 &pull.OwnerDid, 782 &pull.RepoAt, 783 &pull.PullId, 784 &pullCreatedAt, 785 &pull.Title, 786 &pull.State, 787 &repo.Did, 788 &repo.Name, 789 &repo.Knot, 790 &repo.Rkey, 791 &repoCreatedAt, 792 ) 793 if err != nil { 794 return nil, err 795 } 796 797 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 798 if err != nil { 799 return nil, err 800 } 801 pull.Created = pullCreatedTime 802 803 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 804 if err != nil { 805 return nil, err 806 } 807 repo.Created = repoCreatedTime 808 809 pull.Repo = &repo 810 811 pulls = append(pulls, pull) 812 } 813 814 if err := rows.Err(); err != nil { 815 return nil, err 816 } 817 818 return pulls, nil 819} 820 821func NewPullComment(e Execer, comment *PullComment) (int64, error) { 822 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 823 res, err := e.Exec( 824 query, 825 comment.OwnerDid, 826 comment.RepoAt, 827 comment.SubmissionId, 828 comment.CommentAt, 829 comment.PullId, 830 comment.Body, 831 ) 832 if err != nil { 833 return 0, err 834 } 835 836 i, err := res.LastInsertId() 837 if err != nil { 838 return 0, err 839 } 840 841 return i, nil 842} 843 844func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 845 _, err := e.Exec( 846 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 847 pullState, 848 repoAt, 849 pullId, 850 PullDeleted, // only update state of non-deleted pulls 851 PullMerged, // only update state of non-merged pulls 852 ) 853 return err 854} 855 856func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 857 err := SetPullState(e, repoAt, pullId, PullClosed) 858 return err 859} 860 861func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 862 err := SetPullState(e, repoAt, pullId, PullOpen) 863 return err 864} 865 866func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 867 err := SetPullState(e, repoAt, pullId, PullMerged) 868 return err 869} 870 871func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 872 err := SetPullState(e, repoAt, pullId, PullDeleted) 873 return err 874} 875 876func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 877 newRoundNumber := len(pull.Submissions) 878 _, err := e.Exec(` 879 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 880 values (?, ?, ?, ?, ?) 881 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 882 883 return err 884} 885 886func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 887 var conditions []string 888 var args []any 889 890 args = append(args, parentChangeId) 891 892 for _, filter := range filters { 893 conditions = append(conditions, filter.Condition()) 894 args = append(args, filter.Arg()...) 895 } 896 897 whereClause := "" 898 if conditions != nil { 899 whereClause = " where " + strings.Join(conditions, " and ") 900 } 901 902 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 903 _, err := e.Exec(query, args...) 904 905 return err 906} 907 908// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 909// otherwise submissions are immutable 910func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 911 var conditions []string 912 var args []any 913 914 args = append(args, sourceRev) 915 args = append(args, newPatch) 916 917 for _, filter := range filters { 918 conditions = append(conditions, filter.Condition()) 919 args = append(args, filter.Arg()...) 920 } 921 922 whereClause := "" 923 if conditions != nil { 924 whereClause = " where " + strings.Join(conditions, " and ") 925 } 926 927 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 928 _, err := e.Exec(query, args...) 929 930 return err 931} 932 933type PullCount struct { 934 Open int 935 Merged int 936 Closed int 937 Deleted int 938} 939 940func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 941 row := e.QueryRow(` 942 select 943 count(case when state = ? then 1 end) as open_count, 944 count(case when state = ? then 1 end) as merged_count, 945 count(case when state = ? then 1 end) as closed_count, 946 count(case when state = ? then 1 end) as deleted_count 947 from pulls 948 where repo_at = ?`, 949 PullOpen, 950 PullMerged, 951 PullClosed, 952 PullDeleted, 953 repoAt, 954 ) 955 956 var count PullCount 957 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 958 return PullCount{0, 0, 0, 0}, err 959 } 960 961 return count, nil 962} 963 964type Stack []*Pull 965 966// change-id parent-change-id 967// 968// 4 w ,-------- z (TOP) 969// 3 z <----',------- y 970// 2 y <-----',------ x 971// 1 x <------' nil (BOT) 972// 973// `w` is parent of none, so it is the top of the stack 974func GetStack(e Execer, stackId string) (Stack, error) { 975 unorderedPulls, err := GetPulls( 976 e, 977 FilterEq("stack_id", stackId), 978 FilterNotEq("state", PullDeleted), 979 ) 980 if err != nil { 981 return nil, err 982 } 983 // map of parent-change-id to pull 984 changeIdMap := make(map[string]*Pull, len(unorderedPulls)) 985 parentMap := make(map[string]*Pull, len(unorderedPulls)) 986 for _, p := range unorderedPulls { 987 changeIdMap[p.ChangeId] = p 988 if p.ParentChangeId != "" { 989 parentMap[p.ParentChangeId] = p 990 } 991 } 992 993 // the top of the stack is the pull that is not a parent of any pull 994 var topPull *Pull 995 for _, maybeTop := range unorderedPulls { 996 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 997 topPull = maybeTop 998 break 999 } 1000 } 1001 1002 pulls := []*Pull{} 1003 for { 1004 pulls = append(pulls, topPull) 1005 if topPull.ParentChangeId != "" { 1006 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 1007 topPull = next 1008 } else { 1009 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 1010 } 1011 } else { 1012 break 1013 } 1014 } 1015 1016 return pulls, nil 1017} 1018 1019func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) { 1020 pulls, err := GetPulls( 1021 e, 1022 FilterEq("stack_id", stackId), 1023 FilterEq("state", PullDeleted), 1024 ) 1025 if err != nil { 1026 return nil, err 1027 } 1028 1029 return pulls, nil 1030} 1031 1032// position of this pull in the stack 1033func (stack Stack) Position(pull *Pull) int { 1034 return slices.IndexFunc(stack, func(p *Pull) bool { 1035 return p.ChangeId == pull.ChangeId 1036 }) 1037} 1038 1039// all pulls below this pull (including self) in this stack 1040// 1041// nil if this pull does not belong to this stack 1042func (stack Stack) Below(pull *Pull) Stack { 1043 position := stack.Position(pull) 1044 1045 if position < 0 { 1046 return nil 1047 } 1048 1049 return stack[position:] 1050} 1051 1052// all pulls below this pull (excluding self) in this stack 1053func (stack Stack) StrictlyBelow(pull *Pull) Stack { 1054 below := stack.Below(pull) 1055 1056 if len(below) > 0 { 1057 return below[1:] 1058 } 1059 1060 return nil 1061} 1062 1063// all pulls above this pull (including self) in this stack 1064func (stack Stack) Above(pull *Pull) Stack { 1065 position := stack.Position(pull) 1066 1067 if position < 0 { 1068 return nil 1069 } 1070 1071 return stack[:position+1] 1072} 1073 1074// all pulls below this pull (excluding self) in this stack 1075func (stack Stack) StrictlyAbove(pull *Pull) Stack { 1076 above := stack.Above(pull) 1077 1078 if len(above) > 0 { 1079 return above[:len(above)-1] 1080 } 1081 1082 return nil 1083} 1084 1085// the combined format-patches of all the newest submissions in this stack 1086func (stack Stack) CombinedPatch() string { 1087 // go in reverse order because the bottom of the stack is the last element in the slice 1088 var combined strings.Builder 1089 for idx := range stack { 1090 pull := stack[len(stack)-1-idx] 1091 combined.WriteString(pull.LatestPatch()) 1092 combined.WriteString("\n") 1093 } 1094 return combined.String() 1095} 1096 1097// filter out PRs that are "active" 1098// 1099// PRs that are still open are active 1100func (stack Stack) Mergeable() Stack { 1101 var mergeable Stack 1102 1103 for _, p := range stack { 1104 // stop at the first merged PR 1105 if p.State == PullMerged || p.State == PullClosed { 1106 break 1107 } 1108 1109 // skip over deleted PRs 1110 if p.State != PullDeleted { 1111 mergeable = append(mergeable, p) 1112 } 1113 } 1114 1115 return mergeable 1116}