forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request.
//
// The constants are passed directly as query arguments elsewhere in this
// file, so their numeric values (and therefore their order) must stay stable.
type PullState int

const (
	// PullClosed is the zero value: the pull was closed without merging.
	PullClosed PullState = iota
	// PullOpen marks a pull that is still under review.
	PullOpen
	// PullMerged marks a pull whose changes have been merged.
	PullMerged
	// PullDeleted marks a pull that has been deleted.
	PullDeleted
)

// String returns the lowercase name of the state. Unknown values are
// reported as "closed", the same as PullClosed.
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	if p == PullDeleted {
		return "deleted"
	}
	// PullClosed and every unrecognised value.
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool {
	return PullOpen == p
}

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool {
	return PullMerged == p
}

// IsClosed reports whether the state is exactly PullClosed.
func (p PullState) IsClosed() bool {
	return PullClosed == p
}

// IsDeleted reports whether the state is PullDeleted.
func (p PullState) IsDeleted() bool {
	return PullDeleted == p
}
54
55type Pull struct {
56 // ids
57 ID int
58 PullId int
59
60 // at ids
61 RepoAt syntax.ATURI
62 OwnerDid string
63 Rkey string
64
65 // content
66 Title string
67 Body string
68 TargetBranch string
69 State PullState
70 Submissions []*PullSubmission
71
72 // stacking
73 StackId string // nullable string
74 ChangeId string // nullable string
75 ParentChangeId string // nullable string
76
77 // meta
78 Created time.Time
79 PullSource *PullSource
80
81 // optionally, populate this when querying for reverse mappings
82 Repo *Repo
83}
84
85func (p Pull) AsRecord() tangled.RepoPull {
86 var source *tangled.RepoPull_Source
87 if p.PullSource != nil {
88 s := p.PullSource.AsRecord()
89 source = &s
90 source.Sha = p.LatestSha()
91 }
92
93 record := tangled.RepoPull{
94 Title: p.Title,
95 Body: &p.Body,
96 CreatedAt: p.Created.Format(time.RFC3339),
97 PullId: int64(p.PullId),
98 TargetRepo: p.RepoAt.String(),
99 TargetBranch: p.TargetBranch,
100 Patch: p.LatestPatch(),
101 Source: source,
102 }
103 return record
104}
105
106type PullSource struct {
107 Branch string
108 RepoAt *syntax.ATURI
109
110 // optionally populate this for reverse mappings
111 Repo *Repo
112}
113
114func (p PullSource) AsRecord() tangled.RepoPull_Source {
115 var repoAt *string
116 if p.RepoAt != nil {
117 s := p.RepoAt.String()
118 repoAt = &s
119 }
120 record := tangled.RepoPull_Source{
121 Branch: p.Branch,
122 Repo: repoAt,
123 }
124 return record
125}
126
127type PullSubmission struct {
128 // ids
129 ID int
130 PullId int
131
132 // at ids
133 RepoAt syntax.ATURI
134
135 // content
136 RoundNumber int
137 Patch string
138 Comments []PullComment
139 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs
140
141 // meta
142 Created time.Time
143}
144
// PullComment is a single comment attached to one submission of a pull.
type PullComment struct {
	// Identifiers: the comment's own id, the pull it belongs to and the
	// submission (round) it was left on.
	ID           int
	PullId       int
	SubmissionId int

	// AT-protocol identifiers, kept as plain strings.
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// Body is the comment text.
	Body string

	// Created is the creation time of the comment.
	Created time.Time
}
162
163func (p *Pull) LatestPatch() string {
164 latestSubmission := p.Submissions[p.LastRoundNumber()]
165 return latestSubmission.Patch
166}
167
168func (p *Pull) LatestSha() string {
169 latestSubmission := p.Submissions[p.LastRoundNumber()]
170 return latestSubmission.SourceRev
171}
172
173func (p *Pull) PullAt() syntax.ATURI {
174 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
175}
176
177func (p *Pull) LastRoundNumber() int {
178 return len(p.Submissions) - 1
179}
180
181func (p *Pull) IsPatchBased() bool {
182 return p.PullSource == nil
183}
184
185func (p *Pull) IsBranchBased() bool {
186 if p.PullSource != nil {
187 if p.PullSource.RepoAt != nil {
188 return p.PullSource.RepoAt == &p.RepoAt
189 } else {
190 // no repo specified
191 return true
192 }
193 }
194 return false
195}
196
197func (p *Pull) IsForkBased() bool {
198 if p.PullSource != nil {
199 if p.PullSource.RepoAt != nil {
200 // make sure repos are different
201 return p.PullSource.RepoAt != &p.RepoAt
202 }
203 }
204 return false
205}
206
207func (p *Pull) IsStacked() bool {
208 return p.StackId != ""
209}
210
211func (s PullSubmission) IsFormatPatch() bool {
212 return patchutil.IsFormatPatch(s.Patch)
213}
214
215func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
216 patches, err := patchutil.ExtractPatches(s.Patch)
217 if err != nil {
218 log.Println("error extracting patches from submission:", err)
219 return []types.FormatPatch{}
220 }
221
222 return patches
223}
224
225func NewPull(tx *sql.Tx, pull *Pull) error {
226 _, err := tx.Exec(`
227 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
228 values (?, 1)
229 `, pull.RepoAt)
230 if err != nil {
231 return err
232 }
233
234 var nextId int
235 err = tx.QueryRow(`
236 update repo_pull_seqs
237 set next_pull_id = next_pull_id + 1
238 where repo_at = ?
239 returning next_pull_id - 1
240 `, pull.RepoAt).Scan(&nextId)
241 if err != nil {
242 return err
243 }
244
245 pull.PullId = nextId
246 pull.State = PullOpen
247
248 var sourceBranch, sourceRepoAt *string
249 if pull.PullSource != nil {
250 sourceBranch = &pull.PullSource.Branch
251 if pull.PullSource.RepoAt != nil {
252 x := pull.PullSource.RepoAt.String()
253 sourceRepoAt = &x
254 }
255 }
256
257 var stackId, changeId, parentChangeId *string
258 if pull.StackId != "" {
259 stackId = &pull.StackId
260 }
261 if pull.ChangeId != "" {
262 changeId = &pull.ChangeId
263 }
264 if pull.ParentChangeId != "" {
265 parentChangeId = &pull.ParentChangeId
266 }
267
268 _, err = tx.Exec(
269 `
270 insert into pulls (
271 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
272 )
273 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
274 pull.RepoAt,
275 pull.OwnerDid,
276 pull.PullId,
277 pull.Title,
278 pull.TargetBranch,
279 pull.Body,
280 pull.Rkey,
281 pull.State,
282 sourceBranch,
283 sourceRepoAt,
284 stackId,
285 changeId,
286 parentChangeId,
287 )
288 if err != nil {
289 return err
290 }
291
292 _, err = tx.Exec(`
293 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
294 values (?, ?, ?, ?, ?)
295 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
296 return err
297}
298
299func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
300 pull, err := GetPull(e, repoAt, pullId)
301 if err != nil {
302 return "", err
303 }
304 return pull.PullAt(), err
305}
306
307func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
308 var pullId int
309 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
310 return pullId - 1, err
311}
312
313func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
314 pulls := make(map[int]*Pull)
315
316 var conditions []string
317 var args []any
318 for _, filter := range filters {
319 conditions = append(conditions, filter.Condition())
320 args = append(args, filter.Arg()...)
321 }
322
323 whereClause := ""
324 if conditions != nil {
325 whereClause = " where " + strings.Join(conditions, " and ")
326 }
327
328 query := fmt.Sprintf(`
329 select
330 owner_did,
331 repo_at,
332 pull_id,
333 created,
334 title,
335 state,
336 target_branch,
337 body,
338 rkey,
339 source_branch,
340 source_repo_at,
341 stack_id,
342 change_id,
343 parent_change_id
344 from
345 pulls
346 %s
347 `, whereClause)
348
349 rows, err := e.Query(query, args...)
350 if err != nil {
351 return nil, err
352 }
353 defer rows.Close()
354
355 for rows.Next() {
356 var pull Pull
357 var createdAt string
358 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
359 err := rows.Scan(
360 &pull.OwnerDid,
361 &pull.RepoAt,
362 &pull.PullId,
363 &createdAt,
364 &pull.Title,
365 &pull.State,
366 &pull.TargetBranch,
367 &pull.Body,
368 &pull.Rkey,
369 &sourceBranch,
370 &sourceRepoAt,
371 &stackId,
372 &changeId,
373 &parentChangeId,
374 )
375 if err != nil {
376 return nil, err
377 }
378
379 createdTime, err := time.Parse(time.RFC3339, createdAt)
380 if err != nil {
381 return nil, err
382 }
383 pull.Created = createdTime
384
385 if sourceBranch.Valid {
386 pull.PullSource = &PullSource{
387 Branch: sourceBranch.String,
388 }
389 if sourceRepoAt.Valid {
390 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
391 if err != nil {
392 return nil, err
393 }
394 pull.PullSource.RepoAt = &sourceRepoAtParsed
395 }
396 }
397
398 if stackId.Valid {
399 pull.StackId = stackId.String
400 }
401 if changeId.Valid {
402 pull.ChangeId = changeId.String
403 }
404 if parentChangeId.Valid {
405 pull.ParentChangeId = parentChangeId.String
406 }
407
408 pulls[pull.PullId] = &pull
409 }
410
411 // get latest round no. for each pull
412 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
413 submissionsQuery := fmt.Sprintf(`
414 select
415 id, pull_id, round_number, patch, source_rev
416 from
417 pull_submissions
418 where
419 repo_at in (%s) and pull_id in (%s)
420 `, inClause, inClause)
421
422 args = make([]any, len(pulls)*2)
423 idx := 0
424 for _, p := range pulls {
425 args[idx] = p.RepoAt
426 idx += 1
427 }
428 for _, p := range pulls {
429 args[idx] = p.PullId
430 idx += 1
431 }
432 submissionsRows, err := e.Query(submissionsQuery, args...)
433 if err != nil {
434 return nil, err
435 }
436 defer submissionsRows.Close()
437
438 for submissionsRows.Next() {
439 var s PullSubmission
440 var sourceRev sql.NullString
441 err := submissionsRows.Scan(
442 &s.ID,
443 &s.PullId,
444 &s.RoundNumber,
445 &s.Patch,
446 &sourceRev,
447 )
448 if err != nil {
449 return nil, err
450 }
451
452 if sourceRev.Valid {
453 s.SourceRev = sourceRev.String
454 }
455
456 if p, ok := pulls[s.PullId]; ok {
457 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
458 p.Submissions[s.RoundNumber] = &s
459 }
460 }
461 if err := rows.Err(); err != nil {
462 return nil, err
463 }
464
465 // get comment count on latest submission on each pull
466 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
467 commentsQuery := fmt.Sprintf(`
468 select
469 count(id), pull_id
470 from
471 pull_comments
472 where
473 submission_id in (%s)
474 group by
475 submission_id
476 `, inClause)
477
478 args = []any{}
479 for _, p := range pulls {
480 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
481 }
482 commentsRows, err := e.Query(commentsQuery, args...)
483 if err != nil {
484 return nil, err
485 }
486 defer commentsRows.Close()
487
488 for commentsRows.Next() {
489 var commentCount, pullId int
490 err := commentsRows.Scan(
491 &commentCount,
492 &pullId,
493 )
494 if err != nil {
495 return nil, err
496 }
497 if p, ok := pulls[pullId]; ok {
498 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
499 }
500 }
501 if err := rows.Err(); err != nil {
502 return nil, err
503 }
504
505 orderedByPullId := []*Pull{}
506 for _, p := range pulls {
507 orderedByPullId = append(orderedByPullId, p)
508 }
509 sort.Slice(orderedByPullId, func(i, j int) bool {
510 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
511 })
512
513 return orderedByPullId, nil
514}
515
516func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
517 query := `
518 select
519 owner_did,
520 pull_id,
521 created,
522 title,
523 state,
524 target_branch,
525 repo_at,
526 body,
527 rkey,
528 source_branch,
529 source_repo_at,
530 stack_id,
531 change_id,
532 parent_change_id
533 from
534 pulls
535 where
536 repo_at = ? and pull_id = ?
537 `
538 row := e.QueryRow(query, repoAt, pullId)
539
540 var pull Pull
541 var createdAt string
542 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
543 err := row.Scan(
544 &pull.OwnerDid,
545 &pull.PullId,
546 &createdAt,
547 &pull.Title,
548 &pull.State,
549 &pull.TargetBranch,
550 &pull.RepoAt,
551 &pull.Body,
552 &pull.Rkey,
553 &sourceBranch,
554 &sourceRepoAt,
555 &stackId,
556 &changeId,
557 &parentChangeId,
558 )
559 if err != nil {
560 return nil, err
561 }
562
563 createdTime, err := time.Parse(time.RFC3339, createdAt)
564 if err != nil {
565 return nil, err
566 }
567 pull.Created = createdTime
568
569 // populate source
570 if sourceBranch.Valid {
571 pull.PullSource = &PullSource{
572 Branch: sourceBranch.String,
573 }
574 if sourceRepoAt.Valid {
575 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
576 if err != nil {
577 return nil, err
578 }
579 pull.PullSource.RepoAt = &sourceRepoAtParsed
580 }
581 }
582
583 if stackId.Valid {
584 pull.StackId = stackId.String
585 }
586 if changeId.Valid {
587 pull.ChangeId = changeId.String
588 }
589 if parentChangeId.Valid {
590 pull.ParentChangeId = parentChangeId.String
591 }
592
593 submissionsQuery := `
594 select
595 id, pull_id, repo_at, round_number, patch, created, source_rev
596 from
597 pull_submissions
598 where
599 repo_at = ? and pull_id = ?
600 `
601 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
602 if err != nil {
603 return nil, err
604 }
605 defer submissionsRows.Close()
606
607 submissionsMap := make(map[int]*PullSubmission)
608
609 for submissionsRows.Next() {
610 var submission PullSubmission
611 var submissionCreatedStr string
612 var submissionSourceRev sql.NullString
613 err := submissionsRows.Scan(
614 &submission.ID,
615 &submission.PullId,
616 &submission.RepoAt,
617 &submission.RoundNumber,
618 &submission.Patch,
619 &submissionCreatedStr,
620 &submissionSourceRev,
621 )
622 if err != nil {
623 return nil, err
624 }
625
626 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
627 if err != nil {
628 return nil, err
629 }
630 submission.Created = submissionCreatedTime
631
632 if submissionSourceRev.Valid {
633 submission.SourceRev = submissionSourceRev.String
634 }
635
636 submissionsMap[submission.ID] = &submission
637 }
638 if err = submissionsRows.Close(); err != nil {
639 return nil, err
640 }
641 if len(submissionsMap) == 0 {
642 return &pull, nil
643 }
644
645 var args []any
646 for k := range submissionsMap {
647 args = append(args, k)
648 }
649 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
650 commentsQuery := fmt.Sprintf(`
651 select
652 id,
653 pull_id,
654 submission_id,
655 repo_at,
656 owner_did,
657 comment_at,
658 body,
659 created
660 from
661 pull_comments
662 where
663 submission_id IN (%s)
664 order by
665 created asc
666 `, inClause)
667 commentsRows, err := e.Query(commentsQuery, args...)
668 if err != nil {
669 return nil, err
670 }
671 defer commentsRows.Close()
672
673 for commentsRows.Next() {
674 var comment PullComment
675 var commentCreatedStr string
676 err := commentsRows.Scan(
677 &comment.ID,
678 &comment.PullId,
679 &comment.SubmissionId,
680 &comment.RepoAt,
681 &comment.OwnerDid,
682 &comment.CommentAt,
683 &comment.Body,
684 &commentCreatedStr,
685 )
686 if err != nil {
687 return nil, err
688 }
689
690 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
691 if err != nil {
692 return nil, err
693 }
694 comment.Created = commentCreatedTime
695
696 // Add the comment to its submission
697 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
698 submission.Comments = append(submission.Comments, comment)
699 }
700
701 }
702 if err = commentsRows.Err(); err != nil {
703 return nil, err
704 }
705
706 var pullSourceRepo *Repo
707 if pull.PullSource != nil {
708 if pull.PullSource.RepoAt != nil {
709 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
710 if err != nil {
711 log.Printf("failed to get repo by at uri: %v", err)
712 } else {
713 pull.PullSource.Repo = pullSourceRepo
714 }
715 }
716 }
717
718 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
719 for _, submission := range submissionsMap {
720 pull.Submissions[submission.RoundNumber] = submission
721 }
722
723 return &pull, nil
724}
725
726// timeframe here is directly passed into the sql query filter, and any
727// timeframe in the past should be negative; e.g.: "-3 months"
728func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
729 var pulls []Pull
730
731 rows, err := e.Query(`
732 select
733 p.owner_did,
734 p.repo_at,
735 p.pull_id,
736 p.created,
737 p.title,
738 p.state,
739 r.did,
740 r.name,
741 r.knot,
742 r.rkey,
743 r.created
744 from
745 pulls p
746 join
747 repos r on p.repo_at = r.at_uri
748 where
749 p.owner_did = ? and p.created >= date ('now', ?)
750 order by
751 p.created desc`, did, timeframe)
752 if err != nil {
753 return nil, err
754 }
755 defer rows.Close()
756
757 for rows.Next() {
758 var pull Pull
759 var repo Repo
760 var pullCreatedAt, repoCreatedAt string
761 err := rows.Scan(
762 &pull.OwnerDid,
763 &pull.RepoAt,
764 &pull.PullId,
765 &pullCreatedAt,
766 &pull.Title,
767 &pull.State,
768 &repo.Did,
769 &repo.Name,
770 &repo.Knot,
771 &repo.Rkey,
772 &repoCreatedAt,
773 )
774 if err != nil {
775 return nil, err
776 }
777
778 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
779 if err != nil {
780 return nil, err
781 }
782 pull.Created = pullCreatedTime
783
784 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
785 if err != nil {
786 return nil, err
787 }
788 repo.Created = repoCreatedTime
789
790 pull.Repo = &repo
791
792 pulls = append(pulls, pull)
793 }
794
795 if err := rows.Err(); err != nil {
796 return nil, err
797 }
798
799 return pulls, nil
800}
801
802func NewPullComment(e Execer, comment *PullComment) (int64, error) {
803 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
804 res, err := e.Exec(
805 query,
806 comment.OwnerDid,
807 comment.RepoAt,
808 comment.SubmissionId,
809 comment.CommentAt,
810 comment.PullId,
811 comment.Body,
812 )
813 if err != nil {
814 return 0, err
815 }
816
817 i, err := res.LastInsertId()
818 if err != nil {
819 return 0, err
820 }
821
822 return i, nil
823}
824
825func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
826 _, err := e.Exec(
827 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
828 pullState,
829 repoAt,
830 pullId,
831 PullDeleted, // only update state of non-deleted pulls
832 PullMerged, // only update state of non-merged pulls
833 )
834 return err
835}
836
837func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
838 err := SetPullState(e, repoAt, pullId, PullClosed)
839 return err
840}
841
842func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
843 err := SetPullState(e, repoAt, pullId, PullOpen)
844 return err
845}
846
847func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
848 err := SetPullState(e, repoAt, pullId, PullMerged)
849 return err
850}
851
852func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
853 err := SetPullState(e, repoAt, pullId, PullDeleted)
854 return err
855}
856
857func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
858 newRoundNumber := len(pull.Submissions)
859 _, err := e.Exec(`
860 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
861 values (?, ?, ?, ?, ?)
862 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
863
864 return err
865}
866
867func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
868 var conditions []string
869 var args []any
870
871 args = append(args, parentChangeId)
872
873 for _, filter := range filters {
874 conditions = append(conditions, filter.Condition())
875 args = append(args, filter.Arg()...)
876 }
877
878 whereClause := ""
879 if conditions != nil {
880 whereClause = " where " + strings.Join(conditions, " and ")
881 }
882
883 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
884 _, err := e.Exec(query, args...)
885
886 return err
887}
888
889// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
890// otherwise submissions are immutable
891func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
892 var conditions []string
893 var args []any
894
895 args = append(args, sourceRev)
896 args = append(args, newPatch)
897
898 for _, filter := range filters {
899 conditions = append(conditions, filter.Condition())
900 args = append(args, filter.Arg()...)
901 }
902
903 whereClause := ""
904 if conditions != nil {
905 whereClause = " where " + strings.Join(conditions, " and ")
906 }
907
908 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
909 _, err := e.Exec(query, args...)
910
911 return err
912}
913
// PullCount holds the number of pulls in each state.
type PullCount struct {
	// Open is the number of pulls in state PullOpen.
	Open int
	// Merged is the number of pulls in state PullMerged.
	Merged int
	// Closed is the number of pulls in state PullClosed.
	Closed int
	// Deleted is the number of pulls in state PullDeleted.
	Deleted int
}
920
921func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
922 row := e.QueryRow(`
923 select
924 count(case when state = ? then 1 end) as open_count,
925 count(case when state = ? then 1 end) as merged_count,
926 count(case when state = ? then 1 end) as closed_count,
927 count(case when state = ? then 1 end) as deleted_count
928 from pulls
929 where repo_at = ?`,
930 PullOpen,
931 PullMerged,
932 PullClosed,
933 PullDeleted,
934 repoAt,
935 )
936
937 var count PullCount
938 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
939 return PullCount{0, 0, 0, 0}, err
940 }
941
942 return count, nil
943}
944
945type Stack []*Pull
946
947// change-id parent-change-id
948//
949// 4 w ,-------- z (TOP)
950// 3 z <----',------- y
951// 2 y <-----',------ x
952// 1 x <------' nil (BOT)
953//
954// `w` is parent of none, so it is the top of the stack
955func GetStack(e Execer, stackId string) (Stack, error) {
956 unorderedPulls, err := GetPulls(
957 e,
958 FilterEq("stack_id", stackId),
959 FilterNotEq("state", PullDeleted),
960 )
961 if err != nil {
962 return nil, err
963 }
964 // map of parent-change-id to pull
965 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
966 parentMap := make(map[string]*Pull, len(unorderedPulls))
967 for _, p := range unorderedPulls {
968 changeIdMap[p.ChangeId] = p
969 if p.ParentChangeId != "" {
970 parentMap[p.ParentChangeId] = p
971 }
972 }
973
974 // the top of the stack is the pull that is not a parent of any pull
975 var topPull *Pull
976 for _, maybeTop := range unorderedPulls {
977 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
978 topPull = maybeTop
979 break
980 }
981 }
982
983 pulls := []*Pull{}
984 for {
985 pulls = append(pulls, topPull)
986 if topPull.ParentChangeId != "" {
987 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
988 topPull = next
989 } else {
990 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
991 }
992 } else {
993 break
994 }
995 }
996
997 return pulls, nil
998}
999
1000func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
1001 pulls, err := GetPulls(
1002 e,
1003 FilterEq("stack_id", stackId),
1004 FilterEq("state", PullDeleted),
1005 )
1006 if err != nil {
1007 return nil, err
1008 }
1009
1010 return pulls, nil
1011}
1012
1013// position of this pull in the stack
1014func (stack Stack) Position(pull *Pull) int {
1015 return slices.IndexFunc(stack, func(p *Pull) bool {
1016 return p.ChangeId == pull.ChangeId
1017 })
1018}
1019
1020// all pulls below this pull (including self) in this stack
1021//
1022// nil if this pull does not belong to this stack
1023func (stack Stack) Below(pull *Pull) Stack {
1024 position := stack.Position(pull)
1025
1026 if position < 0 {
1027 return nil
1028 }
1029
1030 return stack[position:]
1031}
1032
1033// all pulls below this pull (excluding self) in this stack
1034func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1035 below := stack.Below(pull)
1036
1037 if len(below) > 0 {
1038 return below[1:]
1039 }
1040
1041 return nil
1042}
1043
1044// all pulls above this pull (including self) in this stack
1045func (stack Stack) Above(pull *Pull) Stack {
1046 position := stack.Position(pull)
1047
1048 if position < 0 {
1049 return nil
1050 }
1051
1052 return stack[:position+1]
1053}
1054
1055// all pulls below this pull (excluding self) in this stack
1056func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1057 above := stack.Above(pull)
1058
1059 if len(above) > 0 {
1060 return above[:len(above)-1]
1061 }
1062
1063 return nil
1064}
1065
1066// the combined format-patches of all the newest submissions in this stack
1067func (stack Stack) CombinedPatch() string {
1068 // go in reverse order because the bottom of the stack is the last element in the slice
1069 var combined strings.Builder
1070 for idx := range stack {
1071 pull := stack[len(stack)-1-idx]
1072 combined.WriteString(pull.LatestPatch())
1073 combined.WriteString("\n")
1074 }
1075 return combined.String()
1076}
1077
1078// filter out PRs that are "active"
1079//
1080// PRs that are still open are active
1081func (stack Stack) Mergeable() Stack {
1082 var mergeable Stack
1083
1084 for _, p := range stack {
1085 // stop at the first merged PR
1086 if p.State == PullMerged || p.State == PullClosed {
1087 break
1088 }
1089
1090 // skip over deleted PRs
1091 if p.State != PullDeleted {
1092 mergeable = append(mergeable, p)
1093 }
1094 }
1095
1096 return mergeable
1097}