1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluekeyes/go-gitdiff/gitdiff"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "tangled.sh/tangled.sh/core/api/tangled"
15 "tangled.sh/tangled.sh/core/patchutil"
16 "tangled.sh/tangled.sh/core/types"
17)
18
// PullState is the lifecycle state of a pull request. The numeric values are
// written to and compared against the pulls.state column, so the order of the
// constants below must not change.
type PullState int

const (
	PullClosed PullState = iota
	PullOpen
	PullMerged
	PullDeleted
)

// String returns the lowercase name of the state. Any value outside the known
// range is reported as "closed".
func (s PullState) String() string {
	names := [...]string{
		PullClosed:  "closed",
		PullOpen:    "open",
		PullMerged:  "merged",
		PullDeleted: "deleted",
	}
	if s < 0 || int(s) >= len(names) {
		// unknown states fall back to the same label as PullClosed
		return "closed"
	}
	return names[s]
}

// IsOpen reports whether the state is PullOpen.
func (s PullState) IsOpen() bool {
	return PullOpen == s
}

// IsMerged reports whether the state is PullMerged.
func (s PullState) IsMerged() bool {
	return PullMerged == s
}

// IsClosed reports whether the state is PullClosed.
func (s PullState) IsClosed() bool {
	return PullClosed == s
}

// IsDeleted reports whether the state is PullDeleted.
func (s PullState) IsDeleted() bool {
	return PullDeleted == s
}
55
// Pull is a pull request against a repository: a row of the pulls table plus
// data the loaders in this file attach from related tables.
type Pull struct {
	// ids
	ID     int // NOTE(review): the queries in this section never select pulls.id, so this appears to stay zero here — confirm where it is set
	PullId int // per-repo sequential number; NewPull allocates it from repo_pull_seqs

	// at ids
	RepoAt   syntax.ATURI // AT-URI of the target repo (written as TargetRepo by AsRecord)
	OwnerDid string       // DID used as the authority of this pull's record URI (see PullAt)
	Rkey     string       // record key of this pull's record (see PullAt)

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	Submissions  []*PullSubmission // GetPull stores round i at index i

	// stacking
	StackId        string // nullable string: NewPull stores "" as NULL and the loaders read NULL back as ""
	ChangeId       string // nullable string
	ParentChangeId string // nullable string; ChangeId of the pull directly below this one in its stack (see GetStack)

	// meta
	Created    time.Time
	PullSource *PullSource // nil for patch-based pulls (see IsPatchBased)

	// optionally, populate this when querying for reverse mappings
	// (GetPullsByOwnerDid fills it with the target repo)
	Repo *Repo
}
85
86func (p Pull) AsRecord() tangled.RepoPull {
87 var source *tangled.RepoPull_Source
88 if p.PullSource != nil {
89 s := p.PullSource.AsRecord()
90 source = &s
91 }
92
93 record := tangled.RepoPull{
94 Title: p.Title,
95 Body: &p.Body,
96 CreatedAt: p.Created.Format(time.RFC3339),
97 PullId: int64(p.PullId),
98 TargetRepo: p.RepoAt.String(),
99 TargetBranch: p.TargetBranch,
100 Patch: p.LatestPatch(),
101 Source: source,
102 }
103 return record
104}
105
// PullSource describes where a branch- or fork-based pull's changes come from.
type PullSource struct {
	Branch string        // source branch name
	RepoAt *syntax.ATURI // source repo; nil when none was recorded (IsBranchBased treats that as the target repo)

	// optionally populate this for reverse mappings
	// (GetPull resolves it from RepoAt on a best-effort basis)
	Repo *Repo
}
113
114func (p PullSource) AsRecord() tangled.RepoPull_Source {
115 var repoAt *string
116 if p.RepoAt != nil {
117 s := p.RepoAt.String()
118 repoAt = &s
119 }
120 record := tangled.RepoPull_Source{
121 Branch: p.Branch,
122 Repo: repoAt,
123 }
124 return record
125}
126
// PullSubmission is one round of a pull request: the patch submitted in that
// round plus the comments made on it.
type PullSubmission struct {
	// ids
	ID     int
	PullId int // per-repo pull number this submission belongs to

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int           // zero-based: NewPull inserts round 0, ResubmitPull uses len(Submissions)
	Patch       string        // plain diff or git format-patch output (AsDiff handles both)
	Comments    []PullComment // NOTE(review): GetPulls may fill this with zero-value placeholders whose length is only a comment count
	SourceRev   string        // include the rev that was used to create this submission: only for branch/fork PRs

	// meta
	Created time.Time
}
144
// PullComment is a comment attached to a single pull submission, as stored in
// the pull_comments table.
type PullComment struct {
	// ids
	ID           int
	PullId       int
	SubmissionId int // ID of the PullSubmission the comment was made on

	// at ids
	RepoAt    string
	OwnerDid  string // DID of the comment's owner
	CommentAt string // AT-URI of the comment record, stored as-is

	// content
	Body string

	// meta
	Created time.Time
}
162
163func (p *Pull) LatestPatch() string {
164 latestSubmission := p.Submissions[p.LastRoundNumber()]
165 return latestSubmission.Patch
166}
167
168func (p *Pull) PullAt() syntax.ATURI {
169 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
170}
171
172func (p *Pull) LastRoundNumber() int {
173 return len(p.Submissions) - 1
174}
175
176func (p *Pull) IsPatchBased() bool {
177 return p.PullSource == nil
178}
179
180func (p *Pull) IsBranchBased() bool {
181 if p.PullSource != nil {
182 if p.PullSource.RepoAt != nil {
183 return p.PullSource.RepoAt == &p.RepoAt
184 } else {
185 // no repo specified
186 return true
187 }
188 }
189 return false
190}
191
192func (p *Pull) IsForkBased() bool {
193 if p.PullSource != nil {
194 if p.PullSource.RepoAt != nil {
195 // make sure repos are different
196 return p.PullSource.RepoAt != &p.RepoAt
197 }
198 }
199 return false
200}
201
202func (p *Pull) IsStacked() bool {
203 return p.StackId != ""
204}
205
206func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) {
207 patch := s.Patch
208
209 // if format-patch; then extract each patch
210 var diffs []*gitdiff.File
211 if patchutil.IsFormatPatch(patch) {
212 patches, err := patchutil.ExtractPatches(patch)
213 if err != nil {
214 return nil, err
215 }
216 var ps [][]*gitdiff.File
217 for _, p := range patches {
218 ps = append(ps, p.Files)
219 }
220
221 diffs = patchutil.CombineDiff(ps...)
222 } else {
223 d, _, err := gitdiff.Parse(strings.NewReader(patch))
224 if err != nil {
225 return nil, err
226 }
227 diffs = d
228 }
229
230 return diffs, nil
231}
232
233func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
234 diffs, err := s.AsDiff(targetBranch)
235 if err != nil {
236 log.Println(err)
237 }
238
239 nd := types.NiceDiff{}
240 nd.Commit.Parent = targetBranch
241
242 for _, d := range diffs {
243 ndiff := types.Diff{}
244 ndiff.Name.New = d.NewName
245 ndiff.Name.Old = d.OldName
246 ndiff.IsBinary = d.IsBinary
247 ndiff.IsNew = d.IsNew
248 ndiff.IsDelete = d.IsDelete
249 ndiff.IsCopy = d.IsCopy
250 ndiff.IsRename = d.IsRename
251
252 for _, tf := range d.TextFragments {
253 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
254 for _, l := range tf.Lines {
255 switch l.Op {
256 case gitdiff.OpAdd:
257 nd.Stat.Insertions += 1
258 case gitdiff.OpDelete:
259 nd.Stat.Deletions += 1
260 }
261 }
262 }
263
264 nd.Diff = append(nd.Diff, ndiff)
265 }
266
267 nd.Stat.FilesChanged = len(diffs)
268
269 return nd
270}
271
272func (s PullSubmission) IsFormatPatch() bool {
273 return patchutil.IsFormatPatch(s.Patch)
274}
275
276func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
277 patches, err := patchutil.ExtractPatches(s.Patch)
278 if err != nil {
279 log.Println("error extracting patches from submission:", err)
280 return []patchutil.FormatPatch{}
281 }
282
283 return patches
284}
285
286func NewPull(tx *sql.Tx, pull *Pull) error {
287 _, err := tx.Exec(`
288 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
289 values (?, 1)
290 `, pull.RepoAt)
291 if err != nil {
292 return err
293 }
294
295 var nextId int
296 err = tx.QueryRow(`
297 update repo_pull_seqs
298 set next_pull_id = next_pull_id + 1
299 where repo_at = ?
300 returning next_pull_id - 1
301 `, pull.RepoAt).Scan(&nextId)
302 if err != nil {
303 return err
304 }
305
306 pull.PullId = nextId
307 pull.State = PullOpen
308
309 var sourceBranch, sourceRepoAt *string
310 if pull.PullSource != nil {
311 sourceBranch = &pull.PullSource.Branch
312 if pull.PullSource.RepoAt != nil {
313 x := pull.PullSource.RepoAt.String()
314 sourceRepoAt = &x
315 }
316 }
317
318 var stackId, changeId, parentChangeId *string
319 if pull.StackId != "" {
320 stackId = &pull.StackId
321 }
322 if pull.ChangeId != "" {
323 changeId = &pull.ChangeId
324 }
325 if pull.ParentChangeId != "" {
326 parentChangeId = &pull.ParentChangeId
327 }
328
329 _, err = tx.Exec(
330 `
331 insert into pulls (
332 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
333 )
334 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
335 pull.RepoAt,
336 pull.OwnerDid,
337 pull.PullId,
338 pull.Title,
339 pull.TargetBranch,
340 pull.Body,
341 pull.Rkey,
342 pull.State,
343 sourceBranch,
344 sourceRepoAt,
345 stackId,
346 changeId,
347 parentChangeId,
348 )
349 if err != nil {
350 return err
351 }
352
353 _, err = tx.Exec(`
354 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
355 values (?, ?, ?, ?, ?)
356 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
357 return err
358}
359
360func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
361 pull, err := GetPull(e, repoAt, pullId)
362 if err != nil {
363 return "", err
364 }
365 return pull.PullAt(), err
366}
367
368func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
369 var pullId int
370 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
371 return pullId - 1, err
372}
373
374func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
375 pulls := make(map[int]*Pull)
376
377 var conditions []string
378 var args []any
379 for _, filter := range filters {
380 conditions = append(conditions, filter.Condition())
381 args = append(args, filter.arg)
382 }
383
384 whereClause := ""
385 if conditions != nil {
386 whereClause = " where " + strings.Join(conditions, " and ")
387 }
388
389 query := fmt.Sprintf(`
390 select
391 owner_did,
392 repo_at,
393 pull_id,
394 created,
395 title,
396 state,
397 target_branch,
398 body,
399 rkey,
400 source_branch,
401 source_repo_at,
402 stack_id,
403 change_id,
404 parent_change_id
405 from
406 pulls
407 %s
408 `, whereClause)
409
410 rows, err := e.Query(query, args...)
411 if err != nil {
412 return nil, err
413 }
414 defer rows.Close()
415
416 for rows.Next() {
417 var pull Pull
418 var createdAt string
419 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
420 err := rows.Scan(
421 &pull.OwnerDid,
422 &pull.RepoAt,
423 &pull.PullId,
424 &createdAt,
425 &pull.Title,
426 &pull.State,
427 &pull.TargetBranch,
428 &pull.Body,
429 &pull.Rkey,
430 &sourceBranch,
431 &sourceRepoAt,
432 &stackId,
433 &changeId,
434 &parentChangeId,
435 )
436 if err != nil {
437 return nil, err
438 }
439
440 createdTime, err := time.Parse(time.RFC3339, createdAt)
441 if err != nil {
442 return nil, err
443 }
444 pull.Created = createdTime
445
446 if sourceBranch.Valid {
447 pull.PullSource = &PullSource{
448 Branch: sourceBranch.String,
449 }
450 if sourceRepoAt.Valid {
451 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
452 if err != nil {
453 return nil, err
454 }
455 pull.PullSource.RepoAt = &sourceRepoAtParsed
456 }
457 }
458
459 if stackId.Valid {
460 pull.StackId = stackId.String
461 }
462 if changeId.Valid {
463 pull.ChangeId = changeId.String
464 }
465 if parentChangeId.Valid {
466 pull.ParentChangeId = parentChangeId.String
467 }
468
469 pulls[pull.PullId] = &pull
470 }
471
472 // get latest round no. for each pull
473 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
474 submissionsQuery := fmt.Sprintf(`
475 select
476 id, pull_id, round_number, patch, source_rev
477 from
478 pull_submissions
479 where
480 repo_at in (%s) and pull_id in (%s)
481 `, inClause, inClause)
482
483 args = make([]any, len(pulls)*2)
484 idx := 0
485 for _, p := range pulls {
486 args[idx] = p.RepoAt
487 idx += 1
488 }
489 for _, p := range pulls {
490 args[idx] = p.PullId
491 idx += 1
492 }
493 submissionsRows, err := e.Query(submissionsQuery, args...)
494 if err != nil {
495 return nil, err
496 }
497 defer submissionsRows.Close()
498
499 for submissionsRows.Next() {
500 var s PullSubmission
501 var sourceRev sql.NullString
502 err := submissionsRows.Scan(
503 &s.ID,
504 &s.PullId,
505 &s.RoundNumber,
506 &s.Patch,
507 &sourceRev,
508 )
509 if err != nil {
510 return nil, err
511 }
512
513 if sourceRev.Valid {
514 s.SourceRev = sourceRev.String
515 }
516
517 if p, ok := pulls[s.PullId]; ok {
518 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
519 p.Submissions[s.RoundNumber] = &s
520 }
521 }
522 if err := rows.Err(); err != nil {
523 return nil, err
524 }
525
526 // get comment count on latest submission on each pull
527 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
528 commentsQuery := fmt.Sprintf(`
529 select
530 count(id), pull_id
531 from
532 pull_comments
533 where
534 submission_id in (%s)
535 group by
536 submission_id
537 `, inClause)
538
539 args = []any{}
540 for _, p := range pulls {
541 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
542 }
543 commentsRows, err := e.Query(commentsQuery, args...)
544 if err != nil {
545 return nil, err
546 }
547 defer commentsRows.Close()
548
549 for commentsRows.Next() {
550 var commentCount, pullId int
551 err := commentsRows.Scan(
552 &commentCount,
553 &pullId,
554 )
555 if err != nil {
556 return nil, err
557 }
558 if p, ok := pulls[pullId]; ok {
559 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
560 }
561 }
562 if err := rows.Err(); err != nil {
563 return nil, err
564 }
565
566 orderedByPullId := []*Pull{}
567 for _, p := range pulls {
568 orderedByPullId = append(orderedByPullId, p)
569 }
570 sort.Slice(orderedByPullId, func(i, j int) bool {
571 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
572 })
573
574 return orderedByPullId, nil
575}
576
577func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
578 query := `
579 select
580 owner_did,
581 pull_id,
582 created,
583 title,
584 state,
585 target_branch,
586 repo_at,
587 body,
588 rkey,
589 source_branch,
590 source_repo_at,
591 stack_id,
592 change_id,
593 parent_change_id
594 from
595 pulls
596 where
597 repo_at = ? and pull_id = ?
598 `
599 row := e.QueryRow(query, repoAt, pullId)
600
601 var pull Pull
602 var createdAt string
603 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
604 err := row.Scan(
605 &pull.OwnerDid,
606 &pull.PullId,
607 &createdAt,
608 &pull.Title,
609 &pull.State,
610 &pull.TargetBranch,
611 &pull.RepoAt,
612 &pull.Body,
613 &pull.Rkey,
614 &sourceBranch,
615 &sourceRepoAt,
616 &stackId,
617 &changeId,
618 &parentChangeId,
619 )
620 if err != nil {
621 return nil, err
622 }
623
624 createdTime, err := time.Parse(time.RFC3339, createdAt)
625 if err != nil {
626 return nil, err
627 }
628 pull.Created = createdTime
629
630 // populate source
631 if sourceBranch.Valid {
632 pull.PullSource = &PullSource{
633 Branch: sourceBranch.String,
634 }
635 if sourceRepoAt.Valid {
636 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
637 if err != nil {
638 return nil, err
639 }
640 pull.PullSource.RepoAt = &sourceRepoAtParsed
641 }
642 }
643
644 if stackId.Valid {
645 pull.StackId = stackId.String
646 }
647 if changeId.Valid {
648 pull.ChangeId = changeId.String
649 }
650 if parentChangeId.Valid {
651 pull.ParentChangeId = parentChangeId.String
652 }
653
654 submissionsQuery := `
655 select
656 id, pull_id, repo_at, round_number, patch, created, source_rev
657 from
658 pull_submissions
659 where
660 repo_at = ? and pull_id = ?
661 `
662 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
663 if err != nil {
664 return nil, err
665 }
666 defer submissionsRows.Close()
667
668 submissionsMap := make(map[int]*PullSubmission)
669
670 for submissionsRows.Next() {
671 var submission PullSubmission
672 var submissionCreatedStr string
673 var submissionSourceRev sql.NullString
674 err := submissionsRows.Scan(
675 &submission.ID,
676 &submission.PullId,
677 &submission.RepoAt,
678 &submission.RoundNumber,
679 &submission.Patch,
680 &submissionCreatedStr,
681 &submissionSourceRev,
682 )
683 if err != nil {
684 return nil, err
685 }
686
687 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
688 if err != nil {
689 return nil, err
690 }
691 submission.Created = submissionCreatedTime
692
693 if submissionSourceRev.Valid {
694 submission.SourceRev = submissionSourceRev.String
695 }
696
697 submissionsMap[submission.ID] = &submission
698 }
699 if err = submissionsRows.Close(); err != nil {
700 return nil, err
701 }
702 if len(submissionsMap) == 0 {
703 return &pull, nil
704 }
705
706 var args []any
707 for k := range submissionsMap {
708 args = append(args, k)
709 }
710 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
711 commentsQuery := fmt.Sprintf(`
712 select
713 id,
714 pull_id,
715 submission_id,
716 repo_at,
717 owner_did,
718 comment_at,
719 body,
720 created
721 from
722 pull_comments
723 where
724 submission_id IN (%s)
725 order by
726 created asc
727 `, inClause)
728 commentsRows, err := e.Query(commentsQuery, args...)
729 if err != nil {
730 return nil, err
731 }
732 defer commentsRows.Close()
733
734 for commentsRows.Next() {
735 var comment PullComment
736 var commentCreatedStr string
737 err := commentsRows.Scan(
738 &comment.ID,
739 &comment.PullId,
740 &comment.SubmissionId,
741 &comment.RepoAt,
742 &comment.OwnerDid,
743 &comment.CommentAt,
744 &comment.Body,
745 &commentCreatedStr,
746 )
747 if err != nil {
748 return nil, err
749 }
750
751 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
752 if err != nil {
753 return nil, err
754 }
755 comment.Created = commentCreatedTime
756
757 // Add the comment to its submission
758 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
759 submission.Comments = append(submission.Comments, comment)
760 }
761
762 }
763 if err = commentsRows.Err(); err != nil {
764 return nil, err
765 }
766
767 var pullSourceRepo *Repo
768 if pull.PullSource != nil {
769 if pull.PullSource.RepoAt != nil {
770 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
771 if err != nil {
772 log.Printf("failed to get repo by at uri: %v", err)
773 } else {
774 pull.PullSource.Repo = pullSourceRepo
775 }
776 }
777 }
778
779 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
780 for _, submission := range submissionsMap {
781 pull.Submissions[submission.RoundNumber] = submission
782 }
783
784 return &pull, nil
785}
786
787// timeframe here is directly passed into the sql query filter, and any
788// timeframe in the past should be negative; e.g.: "-3 months"
789func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
790 var pulls []Pull
791
792 rows, err := e.Query(`
793 select
794 p.owner_did,
795 p.repo_at,
796 p.pull_id,
797 p.created,
798 p.title,
799 p.state,
800 r.did,
801 r.name,
802 r.knot,
803 r.rkey,
804 r.created
805 from
806 pulls p
807 join
808 repos r on p.repo_at = r.at_uri
809 where
810 p.owner_did = ? and p.created >= date ('now', ?)
811 order by
812 p.created desc`, did, timeframe)
813 if err != nil {
814 return nil, err
815 }
816 defer rows.Close()
817
818 for rows.Next() {
819 var pull Pull
820 var repo Repo
821 var pullCreatedAt, repoCreatedAt string
822 err := rows.Scan(
823 &pull.OwnerDid,
824 &pull.RepoAt,
825 &pull.PullId,
826 &pullCreatedAt,
827 &pull.Title,
828 &pull.State,
829 &repo.Did,
830 &repo.Name,
831 &repo.Knot,
832 &repo.Rkey,
833 &repoCreatedAt,
834 )
835 if err != nil {
836 return nil, err
837 }
838
839 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
840 if err != nil {
841 return nil, err
842 }
843 pull.Created = pullCreatedTime
844
845 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
846 if err != nil {
847 return nil, err
848 }
849 repo.Created = repoCreatedTime
850
851 pull.Repo = &repo
852
853 pulls = append(pulls, pull)
854 }
855
856 if err := rows.Err(); err != nil {
857 return nil, err
858 }
859
860 return pulls, nil
861}
862
863func NewPullComment(e Execer, comment *PullComment) (int64, error) {
864 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
865 res, err := e.Exec(
866 query,
867 comment.OwnerDid,
868 comment.RepoAt,
869 comment.SubmissionId,
870 comment.CommentAt,
871 comment.PullId,
872 comment.Body,
873 )
874 if err != nil {
875 return 0, err
876 }
877
878 i, err := res.LastInsertId()
879 if err != nil {
880 return 0, err
881 }
882
883 return i, nil
884}
885
886func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
887 _, err := e.Exec(
888 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
889 pullState,
890 repoAt,
891 pullId,
892 PullDeleted, // only update state of non-deleted pulls
893 PullMerged, // only update state of non-merged pulls
894 )
895 return err
896}
897
898func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
899 err := SetPullState(e, repoAt, pullId, PullClosed)
900 return err
901}
902
903func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
904 err := SetPullState(e, repoAt, pullId, PullOpen)
905 return err
906}
907
908func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
909 err := SetPullState(e, repoAt, pullId, PullMerged)
910 return err
911}
912
913func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
914 err := SetPullState(e, repoAt, pullId, PullDeleted)
915 return err
916}
917
918func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
919 newRoundNumber := len(pull.Submissions)
920 _, err := e.Exec(`
921 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
922 values (?, ?, ?, ?, ?)
923 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
924
925 return err
926}
927
928func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
929 var conditions []string
930 var args []any
931
932 args = append(args, parentChangeId)
933
934 for _, filter := range filters {
935 conditions = append(conditions, filter.Condition())
936 args = append(args, filter.arg)
937 }
938
939 whereClause := ""
940 if conditions != nil {
941 whereClause = " where " + strings.Join(conditions, " and ")
942 }
943
944 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
945 _, err := e.Exec(query, args...)
946
947 return err
948}
949
// PullCount holds the number of pulls of a repo in each PullState, as
// computed by GetPullCount.
type PullCount struct {
	Open    int
	Merged  int
	Closed  int
	Deleted int
}
956
957func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
958 row := e.QueryRow(`
959 select
960 count(case when state = ? then 1 end) as open_count,
961 count(case when state = ? then 1 end) as merged_count,
962 count(case when state = ? then 1 end) as closed_count,
963 count(case when state = ? then 1 end) as deleted_count
964 from pulls
965 where repo_at = ?`,
966 PullOpen,
967 PullMerged,
968 PullClosed,
969 PullDeleted,
970 repoAt,
971 )
972
973 var count PullCount
974 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
975 return PullCount{0, 0, 0, 0}, err
976 }
977
978 return count, nil
979}
980
// Stack is an ordered list of pulls sharing a stack id. GetStack builds it
// top-first: index 0 is the top of the stack and the last element the bottom.
type Stack []*Pull
982
983// change-id parent-change-id
984//
985// 4 w ,-------- z (TOP)
986// 3 z <----',------- y
987// 2 y <-----',------ x
988// 1 x <------' nil (BOT)
989//
990// `w` is parent of none, so it is the top of the stack
991func GetStack(e Execer, stackId string) (Stack, error) {
992 unorderedPulls, err := GetPulls(
993 e,
994 FilterEq("stack_id", stackId),
995 FilterNotEq("state", PullDeleted),
996 )
997 if err != nil {
998 return nil, err
999 }
1000 // map of parent-change-id to pull
1001 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
1002 parentMap := make(map[string]*Pull, len(unorderedPulls))
1003 for _, p := range unorderedPulls {
1004 changeIdMap[p.ChangeId] = p
1005 if p.ParentChangeId != "" {
1006 parentMap[p.ParentChangeId] = p
1007 }
1008 }
1009
1010 // the top of the stack is the pull that is not a parent of any pull
1011 var topPull *Pull
1012 for _, maybeTop := range unorderedPulls {
1013 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
1014 topPull = maybeTop
1015 break
1016 }
1017 }
1018
1019 pulls := []*Pull{}
1020 for {
1021 pulls = append(pulls, topPull)
1022 if topPull.ParentChangeId != "" {
1023 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
1024 topPull = next
1025 } else {
1026 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
1027 }
1028 } else {
1029 break
1030 }
1031 }
1032
1033 return pulls, nil
1034}
1035
1036func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
1037 pulls, err := GetPulls(
1038 e,
1039 FilterEq("stack_id", stackId),
1040 FilterEq("state", PullDeleted),
1041 )
1042 if err != nil {
1043 return nil, err
1044 }
1045
1046 return pulls, nil
1047}
1048
1049// position of this pull in the stack
1050func (stack Stack) Position(pull *Pull) int {
1051 return slices.IndexFunc(stack, func(p *Pull) bool {
1052 return p.ChangeId == pull.ChangeId
1053 })
1054}
1055
1056// all pulls below this pull (including self) in this stack
1057//
1058// nil if this pull does not belong to this stack
1059func (stack Stack) Below(pull *Pull) Stack {
1060 position := stack.Position(pull)
1061
1062 if position < 0 {
1063 return nil
1064 }
1065
1066 return stack[position:]
1067}
1068
1069// all pulls below this pull (excluding self) in this stack
1070func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1071 below := stack.Below(pull)
1072
1073 if len(below) > 0 {
1074 return below[1:]
1075 }
1076
1077 return nil
1078}
1079
1080// all pulls above this pull (including self) in this stack
1081func (stack Stack) Above(pull *Pull) Stack {
1082 position := stack.Position(pull)
1083
1084 if position < 0 {
1085 return nil
1086 }
1087
1088 return stack[:position+1]
1089}
1090
1091// all pulls below this pull (excluding self) in this stack
1092func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1093 above := stack.Above(pull)
1094
1095 if len(above) > 0 {
1096 return above[:len(above)-1]
1097 }
1098
1099 return nil
1100}
1101
1102// the combined format-patches of all the newest submissions in this stack
1103func (stack Stack) CombinedPatch() string {
1104 // go in reverse order because the bottom of the stack is the last element in the slice
1105 var combined strings.Builder
1106 for idx := range stack {
1107 pull := stack[len(stack)-1-idx]
1108 combined.WriteString(pull.LatestPatch())
1109 combined.WriteString("\n")
1110 }
1111 return combined.String()
1112}
1113
1114// filter out PRs that are "active"
1115//
1116// PRs that are still open are active
1117func (stack Stack) Mergeable() Stack {
1118 var mergeable Stack
1119
1120 for _, p := range stack {
1121 // stop at the first merged PR
1122 if p.State == PullMerged || p.State == PullClosed {
1123 break
1124 }
1125
1126 // skip over deleted PRs
1127 if p.State != PullDeleted {
1128 mergeable = append(mergeable, p)
1129 }
1130 }
1131
1132 return mergeable
1133}