1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull. The numeric values are what is
// written to the pulls.state column (see NewPull and GetPullCount), so they
// are spelled out explicitly and must never be renumbered.
type PullState int

const (
	PullClosed  PullState = 0
	PullOpen    PullState = 1
	PullMerged  PullState = 2
	PullDeleted PullState = 3
)

// String returns the lowercase name of the state. PullClosed and any value
// outside the known set are both reported as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	case PullDeleted:
		return "deleted"
	}
	return "closed"
}

// IsOpen reports whether p is exactly PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether p is exactly PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether p is exactly PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }

// IsDeleted reports whether p is exactly PullDeleted.
func (p PullState) IsDeleted() bool { return p == PullDeleted }
54
// Pull is a pull request as stored in the pulls table.
//
// Submissions is indexed by round number. GetPull populates every round;
// GetPulls populates only the last slot (the latest round) and leaves the
// earlier slots nil.
type Pull struct {
	// ids
	ID     int // NOTE(review): none of the visible queries select pulls.id, so this appears unset — confirm
	PullId int // per-repo number, allocated by NewPull from repo_pull_seqs

	// at ids
	RepoAt   syntax.ATURI // the repository the pull targets
	OwnerDid string       // authority of PullAt(); queried by GetPullsByOwnerDid
	Rkey     string       // record key of PullAt()

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	Submissions  []*PullSubmission

	// stacking
	StackId        string // nullable string; "" is stored as NULL by NewPull
	ChangeId       string // nullable string; "" is stored as NULL by NewPull
	ParentChangeId string // nullable string; "" is stored as NULL by NewPull

	// meta
	Created    time.Time
	PullSource *PullSource // nil for patch-based pulls (see IsPatchBased)

	// optionally, populate this when querying for reverse mappings
	// (GetPullsByOwnerDid fills it with the target repository)
	Repo *Repo
}
84
85func (p Pull) AsRecord() tangled.RepoPull {
86 var source *tangled.RepoPull_Source
87 if p.PullSource != nil {
88 s := p.PullSource.AsRecord()
89 source = &s
90 }
91
92 record := tangled.RepoPull{
93 Title: p.Title,
94 Body: &p.Body,
95 CreatedAt: p.Created.Format(time.RFC3339),
96 PullId: int64(p.PullId),
97 TargetRepo: p.RepoAt.String(),
98 TargetBranch: p.TargetBranch,
99 Patch: p.LatestPatch(),
100 Source: source,
101 }
102 return record
103}
104
// PullSource records where a branch- or fork-based pull's changes come from.
// A Pull with a nil PullSource is patch-based.
type PullSource struct {
	// Branch is the source branch name (the pulls.source_branch column).
	Branch string
	// RepoAt is the source repository (pulls.source_repo_at). It is nil when
	// no source repository was recorded; IsBranchBased treats that as a
	// branch of the target repository.
	RepoAt *syntax.ATURI

	// optionally populate this for reverse mappings
	// (GetPull fills it when RepoAt is set and the lookup succeeds)
	Repo *Repo
}
112
113func (p PullSource) AsRecord() tangled.RepoPull_Source {
114 var repoAt *string
115 if p.RepoAt != nil {
116 s := p.RepoAt.String()
117 repoAt = &s
118 }
119 record := tangled.RepoPull_Source{
120 Branch: p.Branch,
121 Repo: repoAt,
122 }
123 return record
124}
125
// PullSubmission is one round of a pull: the patch submitted in that round
// and the comments left on it.
type PullSubmission struct {
	// ids
	ID     int
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int    // 0 for the initial submission written by NewPull
	Patch       string
	// GetPull loads full comments here; GetPulls only fills zero-valued
	// entries, so there just len(Comments) (the comment count) is meaningful.
	Comments  []PullComment
	SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs

	// meta
	Created time.Time
}
143
// PullComment is a comment left on one submission (round) of a pull, as
// stored in the pull_comments table.
type PullComment struct {
	// ids
	ID           int
	PullId       int
	SubmissionId int // the PullSubmission.ID this comment belongs to

	// at ids
	RepoAt    string // note: a plain string here, unlike the syntax.ATURI used elsewhere
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
161
162func (p *Pull) LatestPatch() string {
163 latestSubmission := p.Submissions[p.LastRoundNumber()]
164 return latestSubmission.Patch
165}
166
167func (p *Pull) PullAt() syntax.ATURI {
168 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
169}
170
171func (p *Pull) LastRoundNumber() int {
172 return len(p.Submissions) - 1
173}
174
175func (p *Pull) IsPatchBased() bool {
176 return p.PullSource == nil
177}
178
179func (p *Pull) IsBranchBased() bool {
180 if p.PullSource != nil {
181 if p.PullSource.RepoAt != nil {
182 return p.PullSource.RepoAt == &p.RepoAt
183 } else {
184 // no repo specified
185 return true
186 }
187 }
188 return false
189}
190
191func (p *Pull) IsForkBased() bool {
192 if p.PullSource != nil {
193 if p.PullSource.RepoAt != nil {
194 // make sure repos are different
195 return p.PullSource.RepoAt != &p.RepoAt
196 }
197 }
198 return false
199}
200
201func (p *Pull) IsStacked() bool {
202 return p.StackId != ""
203}
204
205func (s PullSubmission) IsFormatPatch() bool {
206 return patchutil.IsFormatPatch(s.Patch)
207}
208
209func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
210 patches, err := patchutil.ExtractPatches(s.Patch)
211 if err != nil {
212 log.Println("error extracting patches from submission:", err)
213 return []types.FormatPatch{}
214 }
215
216 return patches
217}
218
219func NewPull(tx *sql.Tx, pull *Pull) error {
220 _, err := tx.Exec(`
221 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
222 values (?, 1)
223 `, pull.RepoAt)
224 if err != nil {
225 return err
226 }
227
228 var nextId int
229 err = tx.QueryRow(`
230 update repo_pull_seqs
231 set next_pull_id = next_pull_id + 1
232 where repo_at = ?
233 returning next_pull_id - 1
234 `, pull.RepoAt).Scan(&nextId)
235 if err != nil {
236 return err
237 }
238
239 pull.PullId = nextId
240 pull.State = PullOpen
241
242 var sourceBranch, sourceRepoAt *string
243 if pull.PullSource != nil {
244 sourceBranch = &pull.PullSource.Branch
245 if pull.PullSource.RepoAt != nil {
246 x := pull.PullSource.RepoAt.String()
247 sourceRepoAt = &x
248 }
249 }
250
251 var stackId, changeId, parentChangeId *string
252 if pull.StackId != "" {
253 stackId = &pull.StackId
254 }
255 if pull.ChangeId != "" {
256 changeId = &pull.ChangeId
257 }
258 if pull.ParentChangeId != "" {
259 parentChangeId = &pull.ParentChangeId
260 }
261
262 _, err = tx.Exec(
263 `
264 insert into pulls (
265 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
266 )
267 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
268 pull.RepoAt,
269 pull.OwnerDid,
270 pull.PullId,
271 pull.Title,
272 pull.TargetBranch,
273 pull.Body,
274 pull.Rkey,
275 pull.State,
276 sourceBranch,
277 sourceRepoAt,
278 stackId,
279 changeId,
280 parentChangeId,
281 )
282 if err != nil {
283 return err
284 }
285
286 _, err = tx.Exec(`
287 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
288 values (?, ?, ?, ?, ?)
289 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
290 return err
291}
292
293func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
294 pull, err := GetPull(e, repoAt, pullId)
295 if err != nil {
296 return "", err
297 }
298 return pull.PullAt(), err
299}
300
301func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
302 var pullId int
303 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
304 return pullId - 1, err
305}
306
307func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
308 pulls := make(map[int]*Pull)
309
310 var conditions []string
311 var args []any
312 for _, filter := range filters {
313 conditions = append(conditions, filter.Condition())
314 args = append(args, filter.arg)
315 }
316
317 whereClause := ""
318 if conditions != nil {
319 whereClause = " where " + strings.Join(conditions, " and ")
320 }
321
322 query := fmt.Sprintf(`
323 select
324 owner_did,
325 repo_at,
326 pull_id,
327 created,
328 title,
329 state,
330 target_branch,
331 body,
332 rkey,
333 source_branch,
334 source_repo_at,
335 stack_id,
336 change_id,
337 parent_change_id
338 from
339 pulls
340 %s
341 `, whereClause)
342
343 rows, err := e.Query(query, args...)
344 if err != nil {
345 return nil, err
346 }
347 defer rows.Close()
348
349 for rows.Next() {
350 var pull Pull
351 var createdAt string
352 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
353 err := rows.Scan(
354 &pull.OwnerDid,
355 &pull.RepoAt,
356 &pull.PullId,
357 &createdAt,
358 &pull.Title,
359 &pull.State,
360 &pull.TargetBranch,
361 &pull.Body,
362 &pull.Rkey,
363 &sourceBranch,
364 &sourceRepoAt,
365 &stackId,
366 &changeId,
367 &parentChangeId,
368 )
369 if err != nil {
370 return nil, err
371 }
372
373 createdTime, err := time.Parse(time.RFC3339, createdAt)
374 if err != nil {
375 return nil, err
376 }
377 pull.Created = createdTime
378
379 if sourceBranch.Valid {
380 pull.PullSource = &PullSource{
381 Branch: sourceBranch.String,
382 }
383 if sourceRepoAt.Valid {
384 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
385 if err != nil {
386 return nil, err
387 }
388 pull.PullSource.RepoAt = &sourceRepoAtParsed
389 }
390 }
391
392 if stackId.Valid {
393 pull.StackId = stackId.String
394 }
395 if changeId.Valid {
396 pull.ChangeId = changeId.String
397 }
398 if parentChangeId.Valid {
399 pull.ParentChangeId = parentChangeId.String
400 }
401
402 pulls[pull.PullId] = &pull
403 }
404
405 // get latest round no. for each pull
406 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
407 submissionsQuery := fmt.Sprintf(`
408 select
409 id, pull_id, round_number, patch, source_rev
410 from
411 pull_submissions
412 where
413 repo_at in (%s) and pull_id in (%s)
414 `, inClause, inClause)
415
416 args = make([]any, len(pulls)*2)
417 idx := 0
418 for _, p := range pulls {
419 args[idx] = p.RepoAt
420 idx += 1
421 }
422 for _, p := range pulls {
423 args[idx] = p.PullId
424 idx += 1
425 }
426 submissionsRows, err := e.Query(submissionsQuery, args...)
427 if err != nil {
428 return nil, err
429 }
430 defer submissionsRows.Close()
431
432 for submissionsRows.Next() {
433 var s PullSubmission
434 var sourceRev sql.NullString
435 err := submissionsRows.Scan(
436 &s.ID,
437 &s.PullId,
438 &s.RoundNumber,
439 &s.Patch,
440 &sourceRev,
441 )
442 if err != nil {
443 return nil, err
444 }
445
446 if sourceRev.Valid {
447 s.SourceRev = sourceRev.String
448 }
449
450 if p, ok := pulls[s.PullId]; ok {
451 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
452 p.Submissions[s.RoundNumber] = &s
453 }
454 }
455 if err := rows.Err(); err != nil {
456 return nil, err
457 }
458
459 // get comment count on latest submission on each pull
460 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
461 commentsQuery := fmt.Sprintf(`
462 select
463 count(id), pull_id
464 from
465 pull_comments
466 where
467 submission_id in (%s)
468 group by
469 submission_id
470 `, inClause)
471
472 args = []any{}
473 for _, p := range pulls {
474 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
475 }
476 commentsRows, err := e.Query(commentsQuery, args...)
477 if err != nil {
478 return nil, err
479 }
480 defer commentsRows.Close()
481
482 for commentsRows.Next() {
483 var commentCount, pullId int
484 err := commentsRows.Scan(
485 &commentCount,
486 &pullId,
487 )
488 if err != nil {
489 return nil, err
490 }
491 if p, ok := pulls[pullId]; ok {
492 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
493 }
494 }
495 if err := rows.Err(); err != nil {
496 return nil, err
497 }
498
499 orderedByPullId := []*Pull{}
500 for _, p := range pulls {
501 orderedByPullId = append(orderedByPullId, p)
502 }
503 sort.Slice(orderedByPullId, func(i, j int) bool {
504 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
505 })
506
507 return orderedByPullId, nil
508}
509
510func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
511 query := `
512 select
513 owner_did,
514 pull_id,
515 created,
516 title,
517 state,
518 target_branch,
519 repo_at,
520 body,
521 rkey,
522 source_branch,
523 source_repo_at,
524 stack_id,
525 change_id,
526 parent_change_id
527 from
528 pulls
529 where
530 repo_at = ? and pull_id = ?
531 `
532 row := e.QueryRow(query, repoAt, pullId)
533
534 var pull Pull
535 var createdAt string
536 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
537 err := row.Scan(
538 &pull.OwnerDid,
539 &pull.PullId,
540 &createdAt,
541 &pull.Title,
542 &pull.State,
543 &pull.TargetBranch,
544 &pull.RepoAt,
545 &pull.Body,
546 &pull.Rkey,
547 &sourceBranch,
548 &sourceRepoAt,
549 &stackId,
550 &changeId,
551 &parentChangeId,
552 )
553 if err != nil {
554 return nil, err
555 }
556
557 createdTime, err := time.Parse(time.RFC3339, createdAt)
558 if err != nil {
559 return nil, err
560 }
561 pull.Created = createdTime
562
563 // populate source
564 if sourceBranch.Valid {
565 pull.PullSource = &PullSource{
566 Branch: sourceBranch.String,
567 }
568 if sourceRepoAt.Valid {
569 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
570 if err != nil {
571 return nil, err
572 }
573 pull.PullSource.RepoAt = &sourceRepoAtParsed
574 }
575 }
576
577 if stackId.Valid {
578 pull.StackId = stackId.String
579 }
580 if changeId.Valid {
581 pull.ChangeId = changeId.String
582 }
583 if parentChangeId.Valid {
584 pull.ParentChangeId = parentChangeId.String
585 }
586
587 submissionsQuery := `
588 select
589 id, pull_id, repo_at, round_number, patch, created, source_rev
590 from
591 pull_submissions
592 where
593 repo_at = ? and pull_id = ?
594 `
595 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
596 if err != nil {
597 return nil, err
598 }
599 defer submissionsRows.Close()
600
601 submissionsMap := make(map[int]*PullSubmission)
602
603 for submissionsRows.Next() {
604 var submission PullSubmission
605 var submissionCreatedStr string
606 var submissionSourceRev sql.NullString
607 err := submissionsRows.Scan(
608 &submission.ID,
609 &submission.PullId,
610 &submission.RepoAt,
611 &submission.RoundNumber,
612 &submission.Patch,
613 &submissionCreatedStr,
614 &submissionSourceRev,
615 )
616 if err != nil {
617 return nil, err
618 }
619
620 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
621 if err != nil {
622 return nil, err
623 }
624 submission.Created = submissionCreatedTime
625
626 if submissionSourceRev.Valid {
627 submission.SourceRev = submissionSourceRev.String
628 }
629
630 submissionsMap[submission.ID] = &submission
631 }
632 if err = submissionsRows.Close(); err != nil {
633 return nil, err
634 }
635 if len(submissionsMap) == 0 {
636 return &pull, nil
637 }
638
639 var args []any
640 for k := range submissionsMap {
641 args = append(args, k)
642 }
643 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
644 commentsQuery := fmt.Sprintf(`
645 select
646 id,
647 pull_id,
648 submission_id,
649 repo_at,
650 owner_did,
651 comment_at,
652 body,
653 created
654 from
655 pull_comments
656 where
657 submission_id IN (%s)
658 order by
659 created asc
660 `, inClause)
661 commentsRows, err := e.Query(commentsQuery, args...)
662 if err != nil {
663 return nil, err
664 }
665 defer commentsRows.Close()
666
667 for commentsRows.Next() {
668 var comment PullComment
669 var commentCreatedStr string
670 err := commentsRows.Scan(
671 &comment.ID,
672 &comment.PullId,
673 &comment.SubmissionId,
674 &comment.RepoAt,
675 &comment.OwnerDid,
676 &comment.CommentAt,
677 &comment.Body,
678 &commentCreatedStr,
679 )
680 if err != nil {
681 return nil, err
682 }
683
684 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
685 if err != nil {
686 return nil, err
687 }
688 comment.Created = commentCreatedTime
689
690 // Add the comment to its submission
691 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
692 submission.Comments = append(submission.Comments, comment)
693 }
694
695 }
696 if err = commentsRows.Err(); err != nil {
697 return nil, err
698 }
699
700 var pullSourceRepo *Repo
701 if pull.PullSource != nil {
702 if pull.PullSource.RepoAt != nil {
703 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
704 if err != nil {
705 log.Printf("failed to get repo by at uri: %v", err)
706 } else {
707 pull.PullSource.Repo = pullSourceRepo
708 }
709 }
710 }
711
712 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
713 for _, submission := range submissionsMap {
714 pull.Submissions[submission.RoundNumber] = submission
715 }
716
717 return &pull, nil
718}
719
720// timeframe here is directly passed into the sql query filter, and any
721// timeframe in the past should be negative; e.g.: "-3 months"
722func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
723 var pulls []Pull
724
725 rows, err := e.Query(`
726 select
727 p.owner_did,
728 p.repo_at,
729 p.pull_id,
730 p.created,
731 p.title,
732 p.state,
733 r.did,
734 r.name,
735 r.knot,
736 r.rkey,
737 r.created
738 from
739 pulls p
740 join
741 repos r on p.repo_at = r.at_uri
742 where
743 p.owner_did = ? and p.created >= date ('now', ?)
744 order by
745 p.created desc`, did, timeframe)
746 if err != nil {
747 return nil, err
748 }
749 defer rows.Close()
750
751 for rows.Next() {
752 var pull Pull
753 var repo Repo
754 var pullCreatedAt, repoCreatedAt string
755 err := rows.Scan(
756 &pull.OwnerDid,
757 &pull.RepoAt,
758 &pull.PullId,
759 &pullCreatedAt,
760 &pull.Title,
761 &pull.State,
762 &repo.Did,
763 &repo.Name,
764 &repo.Knot,
765 &repo.Rkey,
766 &repoCreatedAt,
767 )
768 if err != nil {
769 return nil, err
770 }
771
772 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
773 if err != nil {
774 return nil, err
775 }
776 pull.Created = pullCreatedTime
777
778 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
779 if err != nil {
780 return nil, err
781 }
782 repo.Created = repoCreatedTime
783
784 pull.Repo = &repo
785
786 pulls = append(pulls, pull)
787 }
788
789 if err := rows.Err(); err != nil {
790 return nil, err
791 }
792
793 return pulls, nil
794}
795
796func NewPullComment(e Execer, comment *PullComment) (int64, error) {
797 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
798 res, err := e.Exec(
799 query,
800 comment.OwnerDid,
801 comment.RepoAt,
802 comment.SubmissionId,
803 comment.CommentAt,
804 comment.PullId,
805 comment.Body,
806 )
807 if err != nil {
808 return 0, err
809 }
810
811 i, err := res.LastInsertId()
812 if err != nil {
813 return 0, err
814 }
815
816 return i, nil
817}
818
819func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
820 _, err := e.Exec(
821 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
822 pullState,
823 repoAt,
824 pullId,
825 PullDeleted, // only update state of non-deleted pulls
826 PullMerged, // only update state of non-merged pulls
827 )
828 return err
829}
830
831func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
832 err := SetPullState(e, repoAt, pullId, PullClosed)
833 return err
834}
835
836func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
837 err := SetPullState(e, repoAt, pullId, PullOpen)
838 return err
839}
840
841func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
842 err := SetPullState(e, repoAt, pullId, PullMerged)
843 return err
844}
845
846func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
847 err := SetPullState(e, repoAt, pullId, PullDeleted)
848 return err
849}
850
851func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
852 newRoundNumber := len(pull.Submissions)
853 _, err := e.Exec(`
854 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
855 values (?, ?, ?, ?, ?)
856 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
857
858 return err
859}
860
861func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
862 var conditions []string
863 var args []any
864
865 args = append(args, parentChangeId)
866
867 for _, filter := range filters {
868 conditions = append(conditions, filter.Condition())
869 args = append(args, filter.arg)
870 }
871
872 whereClause := ""
873 if conditions != nil {
874 whereClause = " where " + strings.Join(conditions, " and ")
875 }
876
877 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
878 _, err := e.Exec(query, args...)
879
880 return err
881}
882
883// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
884// otherwise submissions are immutable
885func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
886 var conditions []string
887 var args []any
888
889 args = append(args, sourceRev)
890 args = append(args, newPatch)
891
892 for _, filter := range filters {
893 conditions = append(conditions, filter.Condition())
894 args = append(args, filter.arg)
895 }
896
897 whereClause := ""
898 if conditions != nil {
899 whereClause = " where " + strings.Join(conditions, " and ")
900 }
901
902 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
903 _, err := e.Exec(query, args...)
904
905 return err
906}
907
// PullCount holds the number of pulls in each state for one repository, as
// computed by GetPullCount.
type PullCount struct {
	Open    int
	Merged  int
	Closed  int
	Deleted int
}
914
915func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
916 row := e.QueryRow(`
917 select
918 count(case when state = ? then 1 end) as open_count,
919 count(case when state = ? then 1 end) as merged_count,
920 count(case when state = ? then 1 end) as closed_count,
921 count(case when state = ? then 1 end) as deleted_count
922 from pulls
923 where repo_at = ?`,
924 PullOpen,
925 PullMerged,
926 PullClosed,
927 PullDeleted,
928 repoAt,
929 )
930
931 var count PullCount
932 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
933 return PullCount{0, 0, 0, 0}, err
934 }
935
936 return count, nil
937}
938
// Stack is an ordered list of stacked pulls. As built by GetStack, index 0 is
// the top of the stack and the last element is the bottom.
type Stack []*Pull
940
941// change-id parent-change-id
942//
943// 4 w ,-------- z (TOP)
944// 3 z <----',------- y
945// 2 y <-----',------ x
946// 1 x <------' nil (BOT)
947//
948// `w` is parent of none, so it is the top of the stack
949func GetStack(e Execer, stackId string) (Stack, error) {
950 unorderedPulls, err := GetPulls(
951 e,
952 FilterEq("stack_id", stackId),
953 FilterNotEq("state", PullDeleted),
954 )
955 if err != nil {
956 return nil, err
957 }
958 // map of parent-change-id to pull
959 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
960 parentMap := make(map[string]*Pull, len(unorderedPulls))
961 for _, p := range unorderedPulls {
962 changeIdMap[p.ChangeId] = p
963 if p.ParentChangeId != "" {
964 parentMap[p.ParentChangeId] = p
965 }
966 }
967
968 // the top of the stack is the pull that is not a parent of any pull
969 var topPull *Pull
970 for _, maybeTop := range unorderedPulls {
971 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
972 topPull = maybeTop
973 break
974 }
975 }
976
977 pulls := []*Pull{}
978 for {
979 pulls = append(pulls, topPull)
980 if topPull.ParentChangeId != "" {
981 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
982 topPull = next
983 } else {
984 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
985 }
986 } else {
987 break
988 }
989 }
990
991 return pulls, nil
992}
993
994func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
995 pulls, err := GetPulls(
996 e,
997 FilterEq("stack_id", stackId),
998 FilterEq("state", PullDeleted),
999 )
1000 if err != nil {
1001 return nil, err
1002 }
1003
1004 return pulls, nil
1005}
1006
1007// position of this pull in the stack
1008func (stack Stack) Position(pull *Pull) int {
1009 return slices.IndexFunc(stack, func(p *Pull) bool {
1010 return p.ChangeId == pull.ChangeId
1011 })
1012}
1013
1014// all pulls below this pull (including self) in this stack
1015//
1016// nil if this pull does not belong to this stack
1017func (stack Stack) Below(pull *Pull) Stack {
1018 position := stack.Position(pull)
1019
1020 if position < 0 {
1021 return nil
1022 }
1023
1024 return stack[position:]
1025}
1026
1027// all pulls below this pull (excluding self) in this stack
1028func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1029 below := stack.Below(pull)
1030
1031 if len(below) > 0 {
1032 return below[1:]
1033 }
1034
1035 return nil
1036}
1037
1038// all pulls above this pull (including self) in this stack
1039func (stack Stack) Above(pull *Pull) Stack {
1040 position := stack.Position(pull)
1041
1042 if position < 0 {
1043 return nil
1044 }
1045
1046 return stack[:position+1]
1047}
1048
1049// all pulls below this pull (excluding self) in this stack
1050func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1051 above := stack.Above(pull)
1052
1053 if len(above) > 0 {
1054 return above[:len(above)-1]
1055 }
1056
1057 return nil
1058}
1059
1060// the combined format-patches of all the newest submissions in this stack
1061func (stack Stack) CombinedPatch() string {
1062 // go in reverse order because the bottom of the stack is the last element in the slice
1063 var combined strings.Builder
1064 for idx := range stack {
1065 pull := stack[len(stack)-1-idx]
1066 combined.WriteString(pull.LatestPatch())
1067 combined.WriteString("\n")
1068 }
1069 return combined.String()
1070}
1071
1072// filter out PRs that are "active"
1073//
1074// PRs that are still open are active
1075func (stack Stack) Mergeable() Stack {
1076 var mergeable Stack
1077
1078 for _, p := range stack {
1079 // stop at the first merged PR
1080 if p.State == PullMerged || p.State == PullClosed {
1081 break
1082 }
1083
1084 // skip over deleted PRs
1085 if p.State != PullDeleted {
1086 mergeable = append(mergeable, p)
1087 }
1088 }
1089
1090 return mergeable
1091}