forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. The numeric values are
// what gets written to the pulls.state column, so the order of the constants
// below must never change.
type PullState int

const (
	PullClosed  PullState = iota // 0
	PullOpen                     // 1
	PullMerged                   // 2
	PullDeleted                  // 3
)

// String returns the lowercase name of the state. Values outside the known
// set are reported as "closed".
func (p PullState) String() string {
	switch p {
	case PullClosed:
		return "closed"
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	case PullDeleted:
		return "deleted"
	}
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is exactly PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }

// IsDeleted reports whether the state is PullDeleted.
func (p PullState) IsDeleted() bool { return p == PullDeleted }
54
55type Pull struct {
56 // ids
57 ID int
58 PullId int
59
60 // at ids
61 RepoAt syntax.ATURI
62 OwnerDid string
63 Rkey string
64
65 // content
66 Title string
67 Body string
68 TargetBranch string
69 State PullState
70 Submissions []*PullSubmission
71
72 // stacking
73 StackId string // nullable string
74 ChangeId string // nullable string
75 ParentChangeId string // nullable string
76
77 // meta
78 Created time.Time
79 PullSource *PullSource
80
81 // optionally, populate this when querying for reverse mappings
82 Repo *Repo
83}
84
85func (p Pull) AsRecord() tangled.RepoPull {
86 var source *tangled.RepoPull_Source
87 if p.PullSource != nil {
88 s := p.PullSource.AsRecord()
89 source = &s
90 source.Sha = p.LatestSha()
91 }
92
93 record := tangled.RepoPull{
94 Title: p.Title,
95 Body: &p.Body,
96 CreatedAt: p.Created.Format(time.RFC3339),
97 PullId: int64(p.PullId),
98 TargetRepo: p.RepoAt.String(),
99 TargetBranch: p.TargetBranch,
100 Patch: p.LatestPatch(),
101 Source: source,
102 }
103 return record
104}
105
106type PullSource struct {
107 Branch string
108 RepoAt *syntax.ATURI
109
110 // optionally populate this for reverse mappings
111 Repo *Repo
112}
113
114func (p PullSource) AsRecord() tangled.RepoPull_Source {
115 var repoAt *string
116 if p.RepoAt != nil {
117 s := p.RepoAt.String()
118 repoAt = &s
119 }
120 record := tangled.RepoPull_Source{
121 Branch: p.Branch,
122 Repo: repoAt,
123 }
124 return record
125}
126
127type PullSubmission struct {
128 // ids
129 ID int
130 PullId int
131
132 // at ids
133 RepoAt syntax.ATURI
134
135 // content
136 RoundNumber int
137 Patch string
138 Comments []PullComment
139 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs
140
141 // meta
142 Created time.Time
143}
144
// PullComment is a comment left on one specific submission of a pull.
type PullComment struct {
	// Identifiers.
	ID           int // pull_comments row id
	PullId       int
	SubmissionId int // ID of the PullSubmission this comment belongs to

	// AT Protocol identifiers.
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// Content.
	Body string

	// Metadata.
	Created time.Time
}
162
163func (p *Pull) LatestPatch() string {
164 latestSubmission := p.Submissions[p.LastRoundNumber()]
165 return latestSubmission.Patch
166}
167
168func (p *Pull) LatestSha() string {
169 latestSubmission := p.Submissions[p.LastRoundNumber()]
170 return latestSubmission.SourceRev
171}
172
173func (p *Pull) PullAt() syntax.ATURI {
174 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
175}
176
177func (p *Pull) LastRoundNumber() int {
178 return len(p.Submissions) - 1
179}
180
181func (p *Pull) IsPatchBased() bool {
182 return p.PullSource == nil
183}
184
185func (p *Pull) IsBranchBased() bool {
186 if p.PullSource != nil {
187 if p.PullSource.RepoAt != nil {
188 return p.PullSource.RepoAt == &p.RepoAt
189 } else {
190 // no repo specified
191 return true
192 }
193 }
194 return false
195}
196
197func (p *Pull) IsForkBased() bool {
198 if p.PullSource != nil {
199 if p.PullSource.RepoAt != nil {
200 // make sure repos are different
201 return p.PullSource.RepoAt != &p.RepoAt
202 }
203 }
204 return false
205}
206
207func (p *Pull) IsStacked() bool {
208 return p.StackId != ""
209}
210
211func (s PullSubmission) IsFormatPatch() bool {
212 return patchutil.IsFormatPatch(s.Patch)
213}
214
215func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
216 patches, err := patchutil.ExtractPatches(s.Patch)
217 if err != nil {
218 log.Println("error extracting patches from submission:", err)
219 return []types.FormatPatch{}
220 }
221
222 return patches
223}
224
225func NewPull(tx *sql.Tx, pull *Pull) error {
226 _, err := tx.Exec(`
227 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
228 values (?, 1)
229 `, pull.RepoAt)
230 if err != nil {
231 return err
232 }
233
234 var nextId int
235 err = tx.QueryRow(`
236 update repo_pull_seqs
237 set next_pull_id = next_pull_id + 1
238 where repo_at = ?
239 returning next_pull_id - 1
240 `, pull.RepoAt).Scan(&nextId)
241 if err != nil {
242 return err
243 }
244
245 pull.PullId = nextId
246 pull.State = PullOpen
247
248 var sourceBranch, sourceRepoAt *string
249 if pull.PullSource != nil {
250 sourceBranch = &pull.PullSource.Branch
251 if pull.PullSource.RepoAt != nil {
252 x := pull.PullSource.RepoAt.String()
253 sourceRepoAt = &x
254 }
255 }
256
257 var stackId, changeId, parentChangeId *string
258 if pull.StackId != "" {
259 stackId = &pull.StackId
260 }
261 if pull.ChangeId != "" {
262 changeId = &pull.ChangeId
263 }
264 if pull.ParentChangeId != "" {
265 parentChangeId = &pull.ParentChangeId
266 }
267
268 _, err = tx.Exec(
269 `
270 insert into pulls (
271 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
272 )
273 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
274 pull.RepoAt,
275 pull.OwnerDid,
276 pull.PullId,
277 pull.Title,
278 pull.TargetBranch,
279 pull.Body,
280 pull.Rkey,
281 pull.State,
282 sourceBranch,
283 sourceRepoAt,
284 stackId,
285 changeId,
286 parentChangeId,
287 )
288 if err != nil {
289 return err
290 }
291
292 _, err = tx.Exec(`
293 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
294 values (?, ?, ?, ?, ?)
295 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
296 return err
297}
298
299func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
300 pull, err := GetPull(e, repoAt, pullId)
301 if err != nil {
302 return "", err
303 }
304 return pull.PullAt(), err
305}
306
307func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
308 var pullId int
309 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
310 return pullId - 1, err
311}
312
313func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) {
314 pulls := make(map[int]*Pull)
315
316 var conditions []string
317 var args []any
318 for _, filter := range filters {
319 conditions = append(conditions, filter.Condition())
320 args = append(args, filter.Arg()...)
321 }
322
323 whereClause := ""
324 if conditions != nil {
325 whereClause = " where " + strings.Join(conditions, " and ")
326 }
327 limitClause := ""
328 if limit != 0 {
329 limitClause = fmt.Sprintf(" limit %d ", limit)
330 }
331
332 query := fmt.Sprintf(`
333 select
334 owner_did,
335 repo_at,
336 pull_id,
337 created,
338 title,
339 state,
340 target_branch,
341 body,
342 rkey,
343 source_branch,
344 source_repo_at,
345 stack_id,
346 change_id,
347 parent_change_id
348 from
349 pulls
350 %s
351 order by
352 created desc
353 %s
354 `, whereClause, limitClause)
355
356 rows, err := e.Query(query, args...)
357 if err != nil {
358 return nil, err
359 }
360 defer rows.Close()
361
362 for rows.Next() {
363 var pull Pull
364 var createdAt string
365 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
366 err := rows.Scan(
367 &pull.OwnerDid,
368 &pull.RepoAt,
369 &pull.PullId,
370 &createdAt,
371 &pull.Title,
372 &pull.State,
373 &pull.TargetBranch,
374 &pull.Body,
375 &pull.Rkey,
376 &sourceBranch,
377 &sourceRepoAt,
378 &stackId,
379 &changeId,
380 &parentChangeId,
381 )
382 if err != nil {
383 return nil, err
384 }
385
386 createdTime, err := time.Parse(time.RFC3339, createdAt)
387 if err != nil {
388 return nil, err
389 }
390 pull.Created = createdTime
391
392 if sourceBranch.Valid {
393 pull.PullSource = &PullSource{
394 Branch: sourceBranch.String,
395 }
396 if sourceRepoAt.Valid {
397 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
398 if err != nil {
399 return nil, err
400 }
401 pull.PullSource.RepoAt = &sourceRepoAtParsed
402 }
403 }
404
405 if stackId.Valid {
406 pull.StackId = stackId.String
407 }
408 if changeId.Valid {
409 pull.ChangeId = changeId.String
410 }
411 if parentChangeId.Valid {
412 pull.ParentChangeId = parentChangeId.String
413 }
414
415 pulls[pull.PullId] = &pull
416 }
417
418 // get latest round no. for each pull
419 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
420 submissionsQuery := fmt.Sprintf(`
421 select
422 id, pull_id, round_number, patch, created, source_rev
423 from
424 pull_submissions
425 where
426 repo_at in (%s) and pull_id in (%s)
427 `, inClause, inClause)
428
429 args = make([]any, len(pulls)*2)
430 idx := 0
431 for _, p := range pulls {
432 args[idx] = p.RepoAt
433 idx += 1
434 }
435 for _, p := range pulls {
436 args[idx] = p.PullId
437 idx += 1
438 }
439 submissionsRows, err := e.Query(submissionsQuery, args...)
440 if err != nil {
441 return nil, err
442 }
443 defer submissionsRows.Close()
444
445 for submissionsRows.Next() {
446 var s PullSubmission
447 var sourceRev sql.NullString
448 var createdAt string
449 err := submissionsRows.Scan(
450 &s.ID,
451 &s.PullId,
452 &s.RoundNumber,
453 &s.Patch,
454 &createdAt,
455 &sourceRev,
456 )
457 if err != nil {
458 return nil, err
459 }
460
461 createdTime, err := time.Parse(time.RFC3339, createdAt)
462 if err != nil {
463 return nil, err
464 }
465 s.Created = createdTime
466
467 if sourceRev.Valid {
468 s.SourceRev = sourceRev.String
469 }
470
471 if p, ok := pulls[s.PullId]; ok {
472 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
473 p.Submissions[s.RoundNumber] = &s
474 }
475 }
476 if err := rows.Err(); err != nil {
477 return nil, err
478 }
479
480 // get comment count on latest submission on each pull
481 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
482 commentsQuery := fmt.Sprintf(`
483 select
484 count(id), pull_id
485 from
486 pull_comments
487 where
488 submission_id in (%s)
489 group by
490 submission_id
491 `, inClause)
492
493 args = []any{}
494 for _, p := range pulls {
495 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
496 }
497 commentsRows, err := e.Query(commentsQuery, args...)
498 if err != nil {
499 return nil, err
500 }
501 defer commentsRows.Close()
502
503 for commentsRows.Next() {
504 var commentCount, pullId int
505 err := commentsRows.Scan(
506 &commentCount,
507 &pullId,
508 )
509 if err != nil {
510 return nil, err
511 }
512 if p, ok := pulls[pullId]; ok {
513 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
514 }
515 }
516 if err := rows.Err(); err != nil {
517 return nil, err
518 }
519
520 orderedByPullId := []*Pull{}
521 for _, p := range pulls {
522 orderedByPullId = append(orderedByPullId, p)
523 }
524 sort.Slice(orderedByPullId, func(i, j int) bool {
525 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
526 })
527
528 return orderedByPullId, nil
529}
530
531func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
532 return GetPullsWithLimit(e, 0, filters...)
533}
534
535func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
536 query := `
537 select
538 owner_did,
539 pull_id,
540 created,
541 title,
542 state,
543 target_branch,
544 repo_at,
545 body,
546 rkey,
547 source_branch,
548 source_repo_at,
549 stack_id,
550 change_id,
551 parent_change_id
552 from
553 pulls
554 where
555 repo_at = ? and pull_id = ?
556 `
557 row := e.QueryRow(query, repoAt, pullId)
558
559 var pull Pull
560 var createdAt string
561 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
562 err := row.Scan(
563 &pull.OwnerDid,
564 &pull.PullId,
565 &createdAt,
566 &pull.Title,
567 &pull.State,
568 &pull.TargetBranch,
569 &pull.RepoAt,
570 &pull.Body,
571 &pull.Rkey,
572 &sourceBranch,
573 &sourceRepoAt,
574 &stackId,
575 &changeId,
576 &parentChangeId,
577 )
578 if err != nil {
579 return nil, err
580 }
581
582 createdTime, err := time.Parse(time.RFC3339, createdAt)
583 if err != nil {
584 return nil, err
585 }
586 pull.Created = createdTime
587
588 // populate source
589 if sourceBranch.Valid {
590 pull.PullSource = &PullSource{
591 Branch: sourceBranch.String,
592 }
593 if sourceRepoAt.Valid {
594 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
595 if err != nil {
596 return nil, err
597 }
598 pull.PullSource.RepoAt = &sourceRepoAtParsed
599 }
600 }
601
602 if stackId.Valid {
603 pull.StackId = stackId.String
604 }
605 if changeId.Valid {
606 pull.ChangeId = changeId.String
607 }
608 if parentChangeId.Valid {
609 pull.ParentChangeId = parentChangeId.String
610 }
611
612 submissionsQuery := `
613 select
614 id, pull_id, repo_at, round_number, patch, created, source_rev
615 from
616 pull_submissions
617 where
618 repo_at = ? and pull_id = ?
619 `
620 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
621 if err != nil {
622 return nil, err
623 }
624 defer submissionsRows.Close()
625
626 submissionsMap := make(map[int]*PullSubmission)
627
628 for submissionsRows.Next() {
629 var submission PullSubmission
630 var submissionCreatedStr string
631 var submissionSourceRev sql.NullString
632 err := submissionsRows.Scan(
633 &submission.ID,
634 &submission.PullId,
635 &submission.RepoAt,
636 &submission.RoundNumber,
637 &submission.Patch,
638 &submissionCreatedStr,
639 &submissionSourceRev,
640 )
641 if err != nil {
642 return nil, err
643 }
644
645 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
646 if err != nil {
647 return nil, err
648 }
649 submission.Created = submissionCreatedTime
650
651 if submissionSourceRev.Valid {
652 submission.SourceRev = submissionSourceRev.String
653 }
654
655 submissionsMap[submission.ID] = &submission
656 }
657 if err = submissionsRows.Close(); err != nil {
658 return nil, err
659 }
660 if len(submissionsMap) == 0 {
661 return &pull, nil
662 }
663
664 var args []any
665 for k := range submissionsMap {
666 args = append(args, k)
667 }
668 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
669 commentsQuery := fmt.Sprintf(`
670 select
671 id,
672 pull_id,
673 submission_id,
674 repo_at,
675 owner_did,
676 comment_at,
677 body,
678 created
679 from
680 pull_comments
681 where
682 submission_id IN (%s)
683 order by
684 created asc
685 `, inClause)
686 commentsRows, err := e.Query(commentsQuery, args...)
687 if err != nil {
688 return nil, err
689 }
690 defer commentsRows.Close()
691
692 for commentsRows.Next() {
693 var comment PullComment
694 var commentCreatedStr string
695 err := commentsRows.Scan(
696 &comment.ID,
697 &comment.PullId,
698 &comment.SubmissionId,
699 &comment.RepoAt,
700 &comment.OwnerDid,
701 &comment.CommentAt,
702 &comment.Body,
703 &commentCreatedStr,
704 )
705 if err != nil {
706 return nil, err
707 }
708
709 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
710 if err != nil {
711 return nil, err
712 }
713 comment.Created = commentCreatedTime
714
715 // Add the comment to its submission
716 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
717 submission.Comments = append(submission.Comments, comment)
718 }
719
720 }
721 if err = commentsRows.Err(); err != nil {
722 return nil, err
723 }
724
725 var pullSourceRepo *Repo
726 if pull.PullSource != nil {
727 if pull.PullSource.RepoAt != nil {
728 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
729 if err != nil {
730 log.Printf("failed to get repo by at uri: %v", err)
731 } else {
732 pull.PullSource.Repo = pullSourceRepo
733 }
734 }
735 }
736
737 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
738 for _, submission := range submissionsMap {
739 pull.Submissions[submission.RoundNumber] = submission
740 }
741
742 return &pull, nil
743}
744
745// timeframe here is directly passed into the sql query filter, and any
746// timeframe in the past should be negative; e.g.: "-3 months"
747func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
748 var pulls []Pull
749
750 rows, err := e.Query(`
751 select
752 p.owner_did,
753 p.repo_at,
754 p.pull_id,
755 p.created,
756 p.title,
757 p.state,
758 r.did,
759 r.name,
760 r.knot,
761 r.rkey,
762 r.created
763 from
764 pulls p
765 join
766 repos r on p.repo_at = r.at_uri
767 where
768 p.owner_did = ? and p.created >= date ('now', ?)
769 order by
770 p.created desc`, did, timeframe)
771 if err != nil {
772 return nil, err
773 }
774 defer rows.Close()
775
776 for rows.Next() {
777 var pull Pull
778 var repo Repo
779 var pullCreatedAt, repoCreatedAt string
780 err := rows.Scan(
781 &pull.OwnerDid,
782 &pull.RepoAt,
783 &pull.PullId,
784 &pullCreatedAt,
785 &pull.Title,
786 &pull.State,
787 &repo.Did,
788 &repo.Name,
789 &repo.Knot,
790 &repo.Rkey,
791 &repoCreatedAt,
792 )
793 if err != nil {
794 return nil, err
795 }
796
797 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
798 if err != nil {
799 return nil, err
800 }
801 pull.Created = pullCreatedTime
802
803 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
804 if err != nil {
805 return nil, err
806 }
807 repo.Created = repoCreatedTime
808
809 pull.Repo = &repo
810
811 pulls = append(pulls, pull)
812 }
813
814 if err := rows.Err(); err != nil {
815 return nil, err
816 }
817
818 return pulls, nil
819}
820
821func NewPullComment(e Execer, comment *PullComment) (int64, error) {
822 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
823 res, err := e.Exec(
824 query,
825 comment.OwnerDid,
826 comment.RepoAt,
827 comment.SubmissionId,
828 comment.CommentAt,
829 comment.PullId,
830 comment.Body,
831 )
832 if err != nil {
833 return 0, err
834 }
835
836 i, err := res.LastInsertId()
837 if err != nil {
838 return 0, err
839 }
840
841 return i, nil
842}
843
844func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
845 _, err := e.Exec(
846 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
847 pullState,
848 repoAt,
849 pullId,
850 PullDeleted, // only update state of non-deleted pulls
851 PullMerged, // only update state of non-merged pulls
852 )
853 return err
854}
855
856func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
857 err := SetPullState(e, repoAt, pullId, PullClosed)
858 return err
859}
860
861func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
862 err := SetPullState(e, repoAt, pullId, PullOpen)
863 return err
864}
865
866func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
867 err := SetPullState(e, repoAt, pullId, PullMerged)
868 return err
869}
870
871func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
872 err := SetPullState(e, repoAt, pullId, PullDeleted)
873 return err
874}
875
876func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
877 newRoundNumber := len(pull.Submissions)
878 _, err := e.Exec(`
879 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
880 values (?, ?, ?, ?, ?)
881 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
882
883 return err
884}
885
886func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
887 var conditions []string
888 var args []any
889
890 args = append(args, parentChangeId)
891
892 for _, filter := range filters {
893 conditions = append(conditions, filter.Condition())
894 args = append(args, filter.Arg()...)
895 }
896
897 whereClause := ""
898 if conditions != nil {
899 whereClause = " where " + strings.Join(conditions, " and ")
900 }
901
902 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
903 _, err := e.Exec(query, args...)
904
905 return err
906}
907
908// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
909// otherwise submissions are immutable
910func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
911 var conditions []string
912 var args []any
913
914 args = append(args, sourceRev)
915 args = append(args, newPatch)
916
917 for _, filter := range filters {
918 conditions = append(conditions, filter.Condition())
919 args = append(args, filter.Arg()...)
920 }
921
922 whereClause := ""
923 if conditions != nil {
924 whereClause = " where " + strings.Join(conditions, " and ")
925 }
926
927 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
928 _, err := e.Exec(query, args...)
929
930 return err
931}
932
// PullCount holds the number of pulls of a repository in each state. Field
// order matters: positional literals list Open, Merged, Closed, Deleted.
type PullCount struct {
	Open    int // pulls in PullOpen
	Merged  int // pulls in PullMerged
	Closed  int // pulls in PullClosed
	Deleted int // pulls in PullDeleted
}
939
940func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
941 row := e.QueryRow(`
942 select
943 count(case when state = ? then 1 end) as open_count,
944 count(case when state = ? then 1 end) as merged_count,
945 count(case when state = ? then 1 end) as closed_count,
946 count(case when state = ? then 1 end) as deleted_count
947 from pulls
948 where repo_at = ?`,
949 PullOpen,
950 PullMerged,
951 PullClosed,
952 PullDeleted,
953 repoAt,
954 )
955
956 var count PullCount
957 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
958 return PullCount{0, 0, 0, 0}, err
959 }
960
961 return count, nil
962}
963
964type Stack []*Pull
965
966// change-id parent-change-id
967//
968// 4 w ,-------- z (TOP)
969// 3 z <----',------- y
970// 2 y <-----',------ x
971// 1 x <------' nil (BOT)
972//
973// `w` is parent of none, so it is the top of the stack
974func GetStack(e Execer, stackId string) (Stack, error) {
975 unorderedPulls, err := GetPulls(
976 e,
977 FilterEq("stack_id", stackId),
978 FilterNotEq("state", PullDeleted),
979 )
980 if err != nil {
981 return nil, err
982 }
983 // map of parent-change-id to pull
984 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
985 parentMap := make(map[string]*Pull, len(unorderedPulls))
986 for _, p := range unorderedPulls {
987 changeIdMap[p.ChangeId] = p
988 if p.ParentChangeId != "" {
989 parentMap[p.ParentChangeId] = p
990 }
991 }
992
993 // the top of the stack is the pull that is not a parent of any pull
994 var topPull *Pull
995 for _, maybeTop := range unorderedPulls {
996 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
997 topPull = maybeTop
998 break
999 }
1000 }
1001
1002 pulls := []*Pull{}
1003 for {
1004 pulls = append(pulls, topPull)
1005 if topPull.ParentChangeId != "" {
1006 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
1007 topPull = next
1008 } else {
1009 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
1010 }
1011 } else {
1012 break
1013 }
1014 }
1015
1016 return pulls, nil
1017}
1018
1019func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
1020 pulls, err := GetPulls(
1021 e,
1022 FilterEq("stack_id", stackId),
1023 FilterEq("state", PullDeleted),
1024 )
1025 if err != nil {
1026 return nil, err
1027 }
1028
1029 return pulls, nil
1030}
1031
1032// position of this pull in the stack
1033func (stack Stack) Position(pull *Pull) int {
1034 return slices.IndexFunc(stack, func(p *Pull) bool {
1035 return p.ChangeId == pull.ChangeId
1036 })
1037}
1038
1039// all pulls below this pull (including self) in this stack
1040//
1041// nil if this pull does not belong to this stack
1042func (stack Stack) Below(pull *Pull) Stack {
1043 position := stack.Position(pull)
1044
1045 if position < 0 {
1046 return nil
1047 }
1048
1049 return stack[position:]
1050}
1051
1052// all pulls below this pull (excluding self) in this stack
1053func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1054 below := stack.Below(pull)
1055
1056 if len(below) > 0 {
1057 return below[1:]
1058 }
1059
1060 return nil
1061}
1062
1063// all pulls above this pull (including self) in this stack
1064func (stack Stack) Above(pull *Pull) Stack {
1065 position := stack.Position(pull)
1066
1067 if position < 0 {
1068 return nil
1069 }
1070
1071 return stack[:position+1]
1072}
1073
1074// all pulls below this pull (excluding self) in this stack
1075func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1076 above := stack.Above(pull)
1077
1078 if len(above) > 0 {
1079 return above[:len(above)-1]
1080 }
1081
1082 return nil
1083}
1084
1085// the combined format-patches of all the newest submissions in this stack
1086func (stack Stack) CombinedPatch() string {
1087 // go in reverse order because the bottom of the stack is the last element in the slice
1088 var combined strings.Builder
1089 for idx := range stack {
1090 pull := stack[len(stack)-1-idx]
1091 combined.WriteString(pull.LatestPatch())
1092 combined.WriteString("\n")
1093 }
1094 return combined.String()
1095}
1096
1097// filter out PRs that are "active"
1098//
1099// PRs that are still open are active
1100func (stack Stack) Mergeable() Stack {
1101 var mergeable Stack
1102
1103 for _, p := range stack {
1104 // stop at the first merged PR
1105 if p.State == PullMerged || p.State == PullClosed {
1106 break
1107 }
1108
1109 // skip over deleted PRs
1110 if p.State != PullDeleted {
1111 mergeable = append(mergeable, p)
1112 }
1113 }
1114
1115 return mergeable
1116}