1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request.
//
// The numeric values are written to and compared against the pulls.state
// column, so existing constants must not be reordered; append new states at
// the end.
type PullState int

const (
	PullClosed PullState = iota
	PullOpen
	PullMerged
	PullDeleted
)
26
27func (p PullState) String() string {
28 switch p {
29 case PullOpen:
30 return "open"
31 case PullMerged:
32 return "merged"
33 case PullClosed:
34 return "closed"
35 case PullDeleted:
36 return "deleted"
37 default:
38 return "closed"
39 }
40}
41
42func (p PullState) IsOpen() bool {
43 return p == PullOpen
44}
45func (p PullState) IsMerged() bool {
46 return p == PullMerged
47}
48func (p PullState) IsClosed() bool {
49 return p == PullClosed
50}
51func (p PullState) IsDeleted() bool {
52 return p == PullDeleted
53}
54
// Pull is a pull request against a repository, together with the submissions
// (rounds) that have been made to it.
type Pull struct {
	// ids
	ID     int // row id; not selected by the pull queries in this file
	PullId int // per-repository pull number, allocated from repo_pull_seqs by NewPull

	// at ids
	RepoAt   syntax.ATURI // target repository
	OwnerDid string       // owner of the pull record; authority segment of PullAt
	Rkey     string       // record key; last path segment of PullAt

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	Submissions  []*PullSubmission // indexed by round number; the last entry is the newest round

	// stacking
	StackId        string // nullable string; "" is stored as NULL
	ChangeId       string // nullable string; "" is stored as NULL
	ParentChangeId string // nullable string; "" is stored as NULL; change id of the pull directly below this one in its stack

	// meta
	Created    time.Time
	PullSource *PullSource // nil for patch-based pulls; set for branch- and fork-based pulls

	// optionally, populate this when querying for reverse mappings
	Repo *Repo
}
84
85func (p Pull) AsRecord() tangled.RepoPull {
86 var source *tangled.RepoPull_Source
87 if p.PullSource != nil {
88 s := p.PullSource.AsRecord()
89 source = &s
90 source.Sha = p.LatestSha()
91 }
92
93 record := tangled.RepoPull{
94 Title: p.Title,
95 Body: &p.Body,
96 CreatedAt: p.Created.Format(time.RFC3339),
97 Target: &tangled.RepoPull_Target{
98 Repo: p.RepoAt.String(),
99 Branch: p.TargetBranch,
100 },
101 Patch: p.LatestPatch(),
102 Source: source,
103 }
104 return record
105}
106
// PullSource describes where the changes of a branch- or fork-based pull come
// from.
type PullSource struct {
	Branch string        // source branch name
	RepoAt *syntax.ATURI // source repository; nil means the branch lives in the target repository

	// optionally populate this for reverse mappings
	Repo *Repo
}
114
115func (p PullSource) AsRecord() tangled.RepoPull_Source {
116 var repoAt *string
117 if p.RepoAt != nil {
118 s := p.RepoAt.String()
119 repoAt = &s
120 }
121 record := tangled.RepoPull_Source{
122 Branch: p.Branch,
123 Repo: repoAt,
124 }
125 return record
126}
127
// PullSubmission is one round of a pull: the patch submitted at that round.
type PullSubmission struct {
	// ids
	ID     int // row id in pull_submissions; comments point at it via PullComment.SubmissionId
	PullId int // per-repository pull number of the owning pull

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int    // 0 for the submission created by NewPull; ResubmitPull uses len(Submissions)
	Patch       string
	// Comments on this round. GetPullsWithLimit only fills in a count for the
	// latest round: len(Comments) equals the number of comments, but the
	// entries are zero-valued.
	Comments  []PullComment
	SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs

	// meta
	Created time.Time
}
145
// PullComment is a comment attached to a specific submission (round) of a pull.
type PullComment struct {
	// ids
	ID           int
	PullId       int
	SubmissionId int // ID of the PullSubmission this comment belongs to

	// at ids
	RepoAt    string
	OwnerDid  string
	CommentAt string // presumably the AT URI of the comment record — TODO confirm

	// content
	Body string

	// meta
	Created time.Time
}
163
164func (p *Pull) LatestPatch() string {
165 latestSubmission := p.Submissions[p.LastRoundNumber()]
166 return latestSubmission.Patch
167}
168
169func (p *Pull) LatestSha() string {
170 latestSubmission := p.Submissions[p.LastRoundNumber()]
171 return latestSubmission.SourceRev
172}
173
// PullAt returns the AT URI of the pull record, built as
// at://<OwnerDid>/<tangled.RepoPullNSID>/<Rkey>.
func (p *Pull) PullAt() syntax.ATURI {
	uri := fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)
	return syntax.ATURI(uri)
}
177
178func (p *Pull) LastRoundNumber() int {
179 return len(p.Submissions) - 1
180}
181
182func (p *Pull) IsPatchBased() bool {
183 return p.PullSource == nil
184}
185
186func (p *Pull) IsBranchBased() bool {
187 if p.PullSource != nil {
188 if p.PullSource.RepoAt != nil {
189 return p.PullSource.RepoAt == &p.RepoAt
190 } else {
191 // no repo specified
192 return true
193 }
194 }
195 return false
196}
197
198func (p *Pull) IsForkBased() bool {
199 if p.PullSource != nil {
200 if p.PullSource.RepoAt != nil {
201 // make sure repos are different
202 return p.PullSource.RepoAt != &p.RepoAt
203 }
204 }
205 return false
206}
207
208func (p *Pull) IsStacked() bool {
209 return p.StackId != ""
210}
211
212func (s PullSubmission) IsFormatPatch() bool {
213 return patchutil.IsFormatPatch(s.Patch)
214}
215
// AsFormatPatch splits the submission's patch into individual format-patches.
// Extraction errors are logged and reported as an empty (non-nil) slice.
func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
	patches, err := patchutil.ExtractPatches(s.Patch)
	if err == nil {
		return patches
	}

	log.Println("error extracting patches from submission:", err)
	return []types.FormatPatch{}
}
225
// NewPull inserts pull and its initial submission (round 0) using tx.
//
// NewPull allocates a per-repository pull number from repo_pull_seqs, creating
// the sequence row on first use, and writes it back to pull.PullId. It also
// sets pull.State to PullOpen.
//
// The patch and source rev of round 0 come from pull.Submissions[0], so pull
// must carry at least one submission; otherwise an error is returned before
// anything is written. Optional columns (source branch/repo and the
// stack/change/parent-change ids) are stored as NULL when the corresponding
// field is nil or empty.
func NewPull(tx *sql.Tx, pull *Pull) error {
	// validate up front so a bad call neither consumes a pull id nor panics
	// on an empty Submissions slice
	if len(pull.Submissions) == 0 || pull.Submissions[0] == nil {
		return fmt.Errorf("cannot create pull without an initial submission")
	}

	_, err := tx.Exec(`
		insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
		values (?, 1)
	`, pull.RepoAt)
	if err != nil {
		return err
	}

	// bump the sequence and take the value it held before the bump
	var nextId int
	err = tx.QueryRow(`
		update repo_pull_seqs
		set next_pull_id = next_pull_id + 1
		where repo_at = ?
		returning next_pull_id - 1
	`, pull.RepoAt).Scan(&nextId)
	if err != nil {
		return err
	}

	pull.PullId = nextId
	pull.State = PullOpen

	var sourceBranch, sourceRepoAt *string
	if pull.PullSource != nil {
		sourceBranch = &pull.PullSource.Branch
		if pull.PullSource.RepoAt != nil {
			x := pull.PullSource.RepoAt.String()
			sourceRepoAt = &x
		}
	}

	// empty strings are stored as NULL
	var stackId, changeId, parentChangeId *string
	if pull.StackId != "" {
		stackId = &pull.StackId
	}
	if pull.ChangeId != "" {
		changeId = &pull.ChangeId
	}
	if pull.ParentChangeId != "" {
		parentChangeId = &pull.ParentChangeId
	}

	_, err = tx.Exec(
		`
		insert into pulls (
			repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
		)
		values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		pull.RepoAt,
		pull.OwnerDid,
		pull.PullId,
		pull.Title,
		pull.TargetBranch,
		pull.Body,
		pull.Rkey,
		pull.State,
		sourceBranch,
		sourceRepoAt,
		stackId,
		changeId,
		parentChangeId,
	)
	if err != nil {
		return err
	}

	first := pull.Submissions[0]
	_, err = tx.Exec(`
		insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
		values (?, ?, ?, ?, ?)
	`, pull.PullId, pull.RepoAt, 0, first.Patch, first.SourceRev)
	return err
}
299
300func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
301 pull, err := GetPull(e, repoAt, pullId)
302 if err != nil {
303 return "", err
304 }
305 return pull.PullAt(), err
306}
307
308func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
309 var pullId int
310 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
311 return pullId - 1, err
312}
313
// GetPullsWithLimit returns the pulls matching all filters, sorted by PullId
// descending. At most limit pulls are returned; a limit of 0 means no limit.
//
// Every submission of each pull is loaded into pull.Submissions, indexed by
// round number. Comments are not loaded. Instead, the latest submission's
// Comments slice holds zero-valued entries whose length equals the number of
// comments on that submission.
func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) {
	// PullId is only unique within a repository, so pulls are keyed by
	// (repo, id); keying by PullId alone let pulls of different repos collide.
	type pullKey struct {
		repoAt syntax.ATURI
		pullId int
	}

	var conditions []string
	var args []any
	for _, filter := range filters {
		conditions = append(conditions, filter.Condition())
		args = append(args, filter.Arg()...)
	}

	whereClause := ""
	if conditions != nil {
		whereClause = " where " + strings.Join(conditions, " and ")
	}
	limitClause := ""
	if limit != 0 {
		limitClause = fmt.Sprintf(" limit %d ", limit)
	}

	query := fmt.Sprintf(`
		select
			owner_did,
			repo_at,
			pull_id,
			created,
			title,
			state,
			target_branch,
			body,
			rkey,
			source_branch,
			source_repo_at,
			stack_id,
			change_id,
			parent_change_id
		from
			pulls
		%s
		order by
			created desc
		%s
	`, whereClause, limitClause)

	rows, err := e.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	pulls := make(map[pullKey]*Pull)
	// pulls in query order (created desc); the final stable sort relies on it
	var ordered []*Pull
	for rows.Next() {
		var pull Pull
		var createdAt string
		var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
		err := rows.Scan(
			&pull.OwnerDid,
			&pull.RepoAt,
			&pull.PullId,
			&createdAt,
			&pull.Title,
			&pull.State,
			&pull.TargetBranch,
			&pull.Body,
			&pull.Rkey,
			&sourceBranch,
			&sourceRepoAt,
			&stackId,
			&changeId,
			&parentChangeId,
		)
		if err != nil {
			return nil, err
		}

		createdTime, err := time.Parse(time.RFC3339, createdAt)
		if err != nil {
			return nil, err
		}
		pull.Created = createdTime

		if sourceBranch.Valid {
			pull.PullSource = &PullSource{
				Branch: sourceBranch.String,
			}
			if sourceRepoAt.Valid {
				sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
				if err != nil {
					return nil, err
				}
				pull.PullSource.RepoAt = &sourceRepoAtParsed
			}
		}

		if stackId.Valid {
			pull.StackId = stackId.String
		}
		if changeId.Valid {
			pull.ChangeId = changeId.String
		}
		if parentChangeId.Valid {
			pull.ParentChangeId = parentChangeId.String
		}

		key := pullKey{repoAt: pull.RepoAt, pullId: pull.PullId}
		if _, seen := pulls[key]; seen {
			continue
		}
		pulls[key] = &pull
		ordered = append(ordered, &pull)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if len(ordered) == 0 {
		return []*Pull{}, nil
	}

	// Load submissions. The two IN lists over-select across repos, so each
	// row is matched back to its pull through the (repo, id) key.
	inClause := strings.TrimSuffix(strings.Repeat("?, ", len(ordered)), ", ")
	submissionsQuery := fmt.Sprintf(`
		select
			id, pull_id, repo_at, round_number, patch, created, source_rev
		from
			pull_submissions
		where
			repo_at in (%s) and pull_id in (%s)
	`, inClause, inClause)

	args = make([]any, 0, 2*len(ordered))
	for _, p := range ordered {
		args = append(args, p.RepoAt)
	}
	for _, p := range ordered {
		args = append(args, p.PullId)
	}
	submissionsRows, err := e.Query(submissionsQuery, args...)
	if err != nil {
		return nil, err
	}
	defer submissionsRows.Close()

	for submissionsRows.Next() {
		var s PullSubmission
		var sourceRev sql.NullString
		var createdAt string
		err := submissionsRows.Scan(
			&s.ID,
			&s.PullId,
			&s.RepoAt,
			&s.RoundNumber,
			&s.Patch,
			&createdAt,
			&sourceRev,
		)
		if err != nil {
			return nil, err
		}

		createdTime, err := time.Parse(time.RFC3339, createdAt)
		if err != nil {
			return nil, err
		}
		s.Created = createdTime

		if sourceRev.Valid {
			s.SourceRev = sourceRev.String
		}

		p, ok := pulls[pullKey{repoAt: s.RepoAt, pullId: s.PullId}]
		if !ok || s.RoundNumber < 0 {
			continue
		}
		// Rows arrive in no particular order. Grow the slice as needed
		// instead of re-allocating it, so Submissions[i] is always round i.
		if s.RoundNumber >= len(p.Submissions) {
			grow := s.RoundNumber + 1 - len(p.Submissions)
			p.Submissions = append(p.Submissions, make([]*PullSubmission, grow)...)
		}
		p.Submissions[s.RoundNumber] = &s
	}
	if err := submissionsRows.Err(); err != nil {
		return nil, err
	}

	// count comments on the latest submission of each pull
	latestBySubmissionId := make(map[int]*PullSubmission, len(ordered))
	commentArgs := make([]any, 0, len(ordered))
	for _, p := range ordered {
		n := len(p.Submissions)
		if n == 0 || p.Submissions[n-1] == nil {
			// no submissions: nothing to count, and no index to look at
			continue
		}
		latest := p.Submissions[n-1]
		latestBySubmissionId[latest.ID] = latest
		commentArgs = append(commentArgs, latest.ID)
	}

	if len(commentArgs) > 0 {
		commentsInClause := strings.TrimSuffix(strings.Repeat("?, ", len(commentArgs)), ", ")
		commentsQuery := fmt.Sprintf(`
			select
				submission_id, count(id)
			from
				pull_comments
			where
				submission_id in (%s)
			group by
				submission_id
		`, commentsInClause)

		commentsRows, err := e.Query(commentsQuery, commentArgs...)
		if err != nil {
			return nil, err
		}
		defer commentsRows.Close()

		for commentsRows.Next() {
			var submissionId, commentCount int
			if err := commentsRows.Scan(&submissionId, &commentCount); err != nil {
				return nil, err
			}
			if s, ok := latestBySubmissionId[submissionId]; ok {
				s.Comments = make([]PullComment, commentCount)
			}
		}
		if err := commentsRows.Err(); err != nil {
			return nil, err
		}
	}

	// Callers expect the newest PullId first. The stable sort keeps the
	// query's created-desc order among equal ids from different repos.
	sort.SliceStable(ordered, func(i, j int) bool {
		return ordered[i].PullId > ordered[j].PullId
	})

	return ordered, nil
}
531
532func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
533 return GetPullsWithLimit(e, 0, filters...)
534}
535
// GetPull loads the pull identified by its target repository and
// per-repository pull id. The result includes all of the pull's submissions
// and, for each submission, its comments ordered oldest first.
//
// pull.Submissions[i] holds round i. If the pull has a source repository, it
// is resolved into PullSource.Repo on a best-effort basis: a lookup failure is
// logged and leaves Repo nil. When no pull matches, the Scan error is returned
// (presumably sql.ErrNoRows — confirm against the Execer implementation).
func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
	query := `
		select
			owner_did,
			pull_id,
			created,
			title,
			state,
			target_branch,
			repo_at,
			body,
			rkey,
			source_branch,
			source_repo_at,
			stack_id,
			change_id,
			parent_change_id
		from
			pulls
		where
			repo_at = ? and pull_id = ?
	`
	row := e.QueryRow(query, repoAt, pullId)

	var pull Pull
	var createdAt string
	var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
	err := row.Scan(
		&pull.OwnerDid,
		&pull.PullId,
		&createdAt,
		&pull.Title,
		&pull.State,
		&pull.TargetBranch,
		&pull.RepoAt,
		&pull.Body,
		&pull.Rkey,
		&sourceBranch,
		&sourceRepoAt,
		&stackId,
		&changeId,
		&parentChangeId,
	)
	if err != nil {
		return nil, err
	}

	createdTime, err := time.Parse(time.RFC3339, createdAt)
	if err != nil {
		return nil, err
	}
	pull.Created = createdTime

	// populate source
	if sourceBranch.Valid {
		pull.PullSource = &PullSource{
			Branch: sourceBranch.String,
		}
		if sourceRepoAt.Valid {
			sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
			if err != nil {
				return nil, err
			}
			pull.PullSource.RepoAt = &sourceRepoAtParsed
		}
	}

	if stackId.Valid {
		pull.StackId = stackId.String
	}
	if changeId.Valid {
		pull.ChangeId = changeId.String
	}
	if parentChangeId.Valid {
		pull.ParentChangeId = parentChangeId.String
	}

	submissionsQuery := `
		select
			id, pull_id, repo_at, round_number, patch, created, source_rev
		from
			pull_submissions
		where
			repo_at = ? and pull_id = ?
	`
	submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
	if err != nil {
		return nil, err
	}
	defer submissionsRows.Close()

	submissionsMap := make(map[int]*PullSubmission)
	// highest round seen; sizes pull.Submissions so gaps cannot cause an
	// out-of-range index
	maxRound := -1

	for submissionsRows.Next() {
		var submission PullSubmission
		var submissionCreatedStr string
		var submissionSourceRev sql.NullString
		err := submissionsRows.Scan(
			&submission.ID,
			&submission.PullId,
			&submission.RepoAt,
			&submission.RoundNumber,
			&submission.Patch,
			&submissionCreatedStr,
			&submissionSourceRev,
		)
		if err != nil {
			return nil, err
		}

		submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
		if err != nil {
			return nil, err
		}
		submission.Created = submissionCreatedTime

		if submissionSourceRev.Valid {
			submission.SourceRev = submissionSourceRev.String
		}

		submissionsMap[submission.ID] = &submission
		maxRound = max(maxRound, submission.RoundNumber)
	}
	// iteration errors surface through Err, not Close
	if err = submissionsRows.Err(); err != nil {
		return nil, err
	}
	// release the result set before issuing the comments query
	if err = submissionsRows.Close(); err != nil {
		return nil, err
	}

	if len(submissionsMap) > 0 {
		var args []any
		for k := range submissionsMap {
			args = append(args, k)
		}
		inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
		commentsQuery := fmt.Sprintf(`
			select
				id,
				pull_id,
				submission_id,
				repo_at,
				owner_did,
				comment_at,
				body,
				created
			from
				pull_comments
			where
				submission_id IN (%s)
			order by
				created asc
		`, inClause)
		commentsRows, err := e.Query(commentsQuery, args...)
		if err != nil {
			return nil, err
		}
		defer commentsRows.Close()

		for commentsRows.Next() {
			var comment PullComment
			var commentCreatedStr string
			err := commentsRows.Scan(
				&comment.ID,
				&comment.PullId,
				&comment.SubmissionId,
				&comment.RepoAt,
				&comment.OwnerDid,
				&comment.CommentAt,
				&comment.Body,
				&commentCreatedStr,
			)
			if err != nil {
				return nil, err
			}

			commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
			if err != nil {
				return nil, err
			}
			comment.Created = commentCreatedTime

			// Add the comment to its submission
			if submission, ok := submissionsMap[comment.SubmissionId]; ok {
				submission.Comments = append(submission.Comments, comment)
			}
		}
		if err = commentsRows.Err(); err != nil {
			return nil, err
		}
	}

	// Resolve the source repository even when the pull has no submissions
	// (the old early return skipped this). Best-effort: failures are only logged.
	if pull.PullSource != nil && pull.PullSource.RepoAt != nil {
		pullSourceRepo, err := GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
		if err != nil {
			log.Printf("failed to get repo by at uri: %v", err)
		} else {
			pull.PullSource.Repo = pullSourceRepo
		}
	}

	if maxRound >= 0 {
		pull.Submissions = make([]*PullSubmission, maxRound+1)
		for _, submission := range submissionsMap {
			if submission.RoundNumber >= 0 {
				pull.Submissions[submission.RoundNumber] = submission
			}
		}
	}

	return &pull, nil
}
745
// GetPullsByOwnerDid returns the pulls whose owner_did is did and that were
// created within timeframe, newest first. Each result has its target Repo
// populated. Only the selected columns are filled in: submissions, body and
// source are not loaded.
//
// timeframe is bound as the modifier argument of SQLite's date('now', ?), so
// a timeframe in the past must be negative, e.g. "-3 months".
func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
	rows, err := e.Query(`
		select
			p.owner_did,
			p.repo_at,
			p.pull_id,
			p.created,
			p.title,
			p.state,
			r.did,
			r.name,
			r.knot,
			r.rkey,
			r.created
		from
			pulls p
		join
			repos r on p.repo_at = r.at_uri
		where
			p.owner_did = ? and p.created >= date ('now', ?)
		order by
			p.created desc`, did, timeframe)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var pulls []Pull
	for rows.Next() {
		var (
			pull                     Pull
			repo                     Repo
			pullCreated, repoCreated string
		)
		if err := rows.Scan(
			&pull.OwnerDid,
			&pull.RepoAt,
			&pull.PullId,
			&pullCreated,
			&pull.Title,
			&pull.State,
			&repo.Did,
			&repo.Name,
			&repo.Knot,
			&repo.Rkey,
			&repoCreated,
		); err != nil {
			return nil, err
		}

		pullTime, err := time.Parse(time.RFC3339, pullCreated)
		if err != nil {
			return nil, err
		}
		repoTime, err := time.Parse(time.RFC3339, repoCreated)
		if err != nil {
			return nil, err
		}
		pull.Created = pullTime
		repo.Created = repoTime
		pull.Repo = &repo

		pulls = append(pulls, pull)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}
	return pulls, nil
}
821
822func NewPullComment(e Execer, comment *PullComment) (int64, error) {
823 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
824 res, err := e.Exec(
825 query,
826 comment.OwnerDid,
827 comment.RepoAt,
828 comment.SubmissionId,
829 comment.CommentAt,
830 comment.PullId,
831 comment.Body,
832 )
833 if err != nil {
834 return 0, err
835 }
836
837 i, err := res.LastInsertId()
838 if err != nil {
839 return 0, err
840 }
841
842 return i, nil
843}
844
845func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
846 _, err := e.Exec(
847 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
848 pullState,
849 repoAt,
850 pullId,
851 PullDeleted, // only update state of non-deleted pulls
852 PullMerged, // only update state of non-merged pulls
853 )
854 return err
855}
856
857func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
858 err := SetPullState(e, repoAt, pullId, PullClosed)
859 return err
860}
861
862func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
863 err := SetPullState(e, repoAt, pullId, PullOpen)
864 return err
865}
866
867func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
868 err := SetPullState(e, repoAt, pullId, PullMerged)
869 return err
870}
871
872func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
873 err := SetPullState(e, repoAt, pullId, PullDeleted)
874 return err
875}
876
877func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
878 newRoundNumber := len(pull.Submissions)
879 _, err := e.Exec(`
880 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
881 values (?, ?, ?, ?, ?)
882 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
883
884 return err
885}
886
// SetPullParentChangeId sets parent_change_id to parentChangeId on every pull
// matching all of filters. With no filters, every row of pulls is updated.
func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
	args := []any{parentChangeId}
	conditions := make([]string, 0, len(filters))
	for _, f := range filters {
		conditions = append(conditions, f.Condition())
		args = append(args, f.Arg()...)
	}

	var whereClause string
	if len(conditions) > 0 {
		whereClause = " where " + strings.Join(conditions, " and ")
	}

	query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
	_, err := e.Exec(query, args...)
	return err
}
908
// UpdatePull overwrites source_rev and patch on every pull submission that
// matches all of filters. With no filters, every row of pull_submissions is
// updated.
//
// Only used when stacking, to update contents after a rebase (the interdiff
// should be empty); otherwise submissions are immutable.
func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
	args := []any{sourceRev, newPatch}
	conditions := make([]string, 0, len(filters))
	for _, f := range filters {
		conditions = append(conditions, f.Condition())
		args = append(args, f.Arg()...)
	}

	var whereClause string
	if len(conditions) > 0 {
		whereClause = " where " + strings.Join(conditions, " and ")
	}

	query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
	_, err := e.Exec(query, args...)
	return err
}
933
// PullCount holds the number of pulls in each state for a single repository,
// as returned by GetPullCount.
type PullCount struct {
	Open    int
	Merged  int
	Closed  int
	Deleted int
}
940
941func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
942 row := e.QueryRow(`
943 select
944 count(case when state = ? then 1 end) as open_count,
945 count(case when state = ? then 1 end) as merged_count,
946 count(case when state = ? then 1 end) as closed_count,
947 count(case when state = ? then 1 end) as deleted_count
948 from pulls
949 where repo_at = ?`,
950 PullOpen,
951 PullMerged,
952 PullClosed,
953 PullDeleted,
954 repoAt,
955 )
956
957 var count PullCount
958 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
959 return PullCount{0, 0, 0, 0}, err
960 }
961
962 return count, nil
963}
964
// Stack is a list of stacked pulls ordered from the top of the stack to the
// bottom (as built by GetStack), so the bottom-most pull is the last element.
type Stack []*Pull
966
// GetStack returns the non-deleted pulls sharing stackId, ordered from the top
// of the stack to the bottom. Each pull points at the pull directly below it
// through its ParentChangeId:
//
//	   change-id   parent-change-id
//
//	4  w ,-------- z          (TOP)
//	3  z <----',------- y
//	2  y <-----',------ x
//	1  x <------'       nil   (BOT)
//
// `w` is the parent of no pull, so it is the top of the stack.
//
// An empty stack yields an empty Stack and a nil error. Broken or cyclic
// parent links yield an error instead of a panic or an endless loop.
func GetStack(e Execer, stackId string) (Stack, error) {
	unorderedPulls, err := GetPulls(
		e,
		FilterEq("stack_id", stackId),
		FilterNotEq("state", PullDeleted),
	)
	if err != nil {
		return nil, err
	}
	if len(unorderedPulls) == 0 {
		// nothing to order; previously this dereferenced a nil top pull
		return Stack{}, nil
	}

	// change-id -> pull, for walking from a pull down to its parent
	changeIdMap := make(map[string]*Pull, len(unorderedPulls))
	// parent-change-id -> the pull stacked directly on top of that parent
	parentMap := make(map[string]*Pull, len(unorderedPulls))
	for _, p := range unorderedPulls {
		changeIdMap[p.ChangeId] = p
		if p.ParentChangeId != "" {
			parentMap[p.ParentChangeId] = p
		}
	}

	// the top of the stack is the pull that is not a parent of any pull
	var topPull *Pull
	for _, maybeTop := range unorderedPulls {
		if _, ok := parentMap[maybeTop.ChangeId]; !ok {
			topPull = maybeTop
			break
		}
	}
	if topPull == nil {
		// every pull is some other pull's parent, which implies a cycle
		return nil, fmt.Errorf("failed to find top of stack, stack is malformed")
	}

	pulls := make(Stack, 0, len(unorderedPulls))
	visited := make(map[*Pull]bool, len(unorderedPulls))
	current := topPull
	for {
		if visited[current] {
			return nil, fmt.Errorf("cycle in parent change ids, stack is malformed")
		}
		visited[current] = true
		pulls = append(pulls, current)

		if current.ParentChangeId == "" {
			break
		}
		next, ok := changeIdMap[current.ParentChangeId]
		if !ok {
			return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
		}
		current = next
	}

	return pulls, nil
}
1019
1020func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
1021 pulls, err := GetPulls(
1022 e,
1023 FilterEq("stack_id", stackId),
1024 FilterEq("state", PullDeleted),
1025 )
1026 if err != nil {
1027 return nil, err
1028 }
1029
1030 return pulls, nil
1031}
1032
// Position returns the index of pull within the stack (0 is the top). Pulls
// are matched by ChangeId. It returns -1 when no pull in the stack has that
// change id.
func (stack Stack) Position(pull *Pull) int {
	sameChange := func(candidate *Pull) bool {
		return candidate.ChangeId == pull.ChangeId
	}
	return slices.IndexFunc(stack, sameChange)
}
1039
1040// all pulls below this pull (including self) in this stack
1041//
1042// nil if this pull does not belong to this stack
1043func (stack Stack) Below(pull *Pull) Stack {
1044 position := stack.Position(pull)
1045
1046 if position < 0 {
1047 return nil
1048 }
1049
1050 return stack[position:]
1051}
1052
1053// all pulls below this pull (excluding self) in this stack
1054func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1055 below := stack.Below(pull)
1056
1057 if len(below) > 0 {
1058 return below[1:]
1059 }
1060
1061 return nil
1062}
1063
1064// all pulls above this pull (including self) in this stack
1065func (stack Stack) Above(pull *Pull) Stack {
1066 position := stack.Position(pull)
1067
1068 if position < 0 {
1069 return nil
1070 }
1071
1072 return stack[:position+1]
1073}
1074
1075// all pulls below this pull (excluding self) in this stack
1076func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1077 above := stack.Above(pull)
1078
1079 if len(above) > 0 {
1080 return above[:len(above)-1]
1081 }
1082
1083 return nil
1084}
1085
// CombinedPatch concatenates the latest patch of every pull in the stack, from
// the bottom of the stack to the top, with a newline after each patch.
func (stack Stack) CombinedPatch() string {
	var combined strings.Builder
	// the bottom of the stack is the last element, so walk backwards
	for i := len(stack) - 1; i >= 0; i-- {
		combined.WriteString(stack[i].LatestPatch())
		combined.WriteString("\n")
	}
	return combined.String()
}
1097
1098// filter out PRs that are "active"
1099//
1100// PRs that are still open are active
1101func (stack Stack) Mergeable() Stack {
1102 var mergeable Stack
1103
1104 for _, p := range stack {
1105 // stop at the first merged PR
1106 if p.State == PullMerged || p.State == PullClosed {
1107 break
1108 }
1109
1110 // skip over deleted PRs
1111 if p.State != PullDeleted {
1112 mergeable = append(mergeable, p)
1113 }
1114 }
1115
1116 return mergeable
1117}