1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16)
17
18func NewPull(tx *sql.Tx, pull *models.Pull) error {
19 _, err := tx.Exec(`
20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
21 values (?, 1)
22 `, pull.RepoAt)
23 if err != nil {
24 return err
25 }
26
27 var nextId int
28 err = tx.QueryRow(`
29 update repo_pull_seqs
30 set next_pull_id = next_pull_id + 1
31 where repo_at = ?
32 returning next_pull_id - 1
33 `, pull.RepoAt).Scan(&nextId)
34 if err != nil {
35 return err
36 }
37
38 pull.PullId = nextId
39 pull.State = models.PullOpen
40
41 var sourceBranch, sourceRepoAt *string
42 if pull.PullSource != nil {
43 sourceBranch = &pull.PullSource.Branch
44 if pull.PullSource.RepoAt != nil {
45 x := pull.PullSource.RepoAt.String()
46 sourceRepoAt = &x
47 }
48 }
49
50 var stackId, changeId, parentChangeId *string
51 if pull.StackId != "" {
52 stackId = &pull.StackId
53 }
54 if pull.ChangeId != "" {
55 changeId = &pull.ChangeId
56 }
57 if pull.ParentChangeId != "" {
58 parentChangeId = &pull.ParentChangeId
59 }
60
61 result, err := tx.Exec(
62 `
63 insert into pulls (
64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
65 )
66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
67 pull.RepoAt,
68 pull.OwnerDid,
69 pull.PullId,
70 pull.Title,
71 pull.TargetBranch,
72 pull.Body,
73 pull.Rkey,
74 pull.State,
75 sourceBranch,
76 sourceRepoAt,
77 stackId,
78 changeId,
79 parentChangeId,
80 )
81 if err != nil {
82 return err
83 }
84
85 // Set the database primary key ID
86 id, err := result.LastInsertId()
87 if err != nil {
88 return err
89 }
90 pull.ID = int(id)
91
92 _, err = tx.Exec(`
93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
94 values (?, ?, ?, ?, ?)
95 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
96 if err != nil {
97 return err
98 }
99
100 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
101 return fmt.Errorf("put reference_links: %w", err)
102 }
103
104 return nil
105}
106
107func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
108 pull, err := GetPull(e, repoAt, pullId)
109 if err != nil {
110 return "", err
111 }
112 return pull.AtUri(), err
113}
114
115func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
116 var pullId int
117 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
118 return pullId - 1, err
119}
120
121func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
122 pulls := make(map[syntax.ATURI]*models.Pull)
123
124 var conditions []string
125 var args []any
126 for _, filter := range filters {
127 conditions = append(conditions, filter.Condition())
128 args = append(args, filter.Arg()...)
129 }
130
131 whereClause := ""
132 if conditions != nil {
133 whereClause = " where " + strings.Join(conditions, " and ")
134 }
135 limitClause := ""
136 if limit != 0 {
137 limitClause = fmt.Sprintf(" limit %d ", limit)
138 }
139
140 query := fmt.Sprintf(`
141 select
142 id,
143 owner_did,
144 repo_at,
145 pull_id,
146 created,
147 title,
148 state,
149 target_branch,
150 body,
151 rkey,
152 source_branch,
153 source_repo_at,
154 stack_id,
155 change_id,
156 parent_change_id
157 from
158 pulls
159 %s
160 order by
161 created desc
162 %s
163 `, whereClause, limitClause)
164
165 rows, err := e.Query(query, args...)
166 if err != nil {
167 return nil, err
168 }
169 defer rows.Close()
170
171 for rows.Next() {
172 var pull models.Pull
173 var createdAt string
174 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
175 err := rows.Scan(
176 &pull.ID,
177 &pull.OwnerDid,
178 &pull.RepoAt,
179 &pull.PullId,
180 &createdAt,
181 &pull.Title,
182 &pull.State,
183 &pull.TargetBranch,
184 &pull.Body,
185 &pull.Rkey,
186 &sourceBranch,
187 &sourceRepoAt,
188 &stackId,
189 &changeId,
190 &parentChangeId,
191 )
192 if err != nil {
193 return nil, err
194 }
195
196 createdTime, err := time.Parse(time.RFC3339, createdAt)
197 if err != nil {
198 return nil, err
199 }
200 pull.Created = createdTime
201
202 if sourceBranch.Valid {
203 pull.PullSource = &models.PullSource{
204 Branch: sourceBranch.String,
205 }
206 if sourceRepoAt.Valid {
207 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
208 if err != nil {
209 return nil, err
210 }
211 pull.PullSource.RepoAt = &sourceRepoAtParsed
212 }
213 }
214
215 if stackId.Valid {
216 pull.StackId = stackId.String
217 }
218 if changeId.Valid {
219 pull.ChangeId = changeId.String
220 }
221 if parentChangeId.Valid {
222 pull.ParentChangeId = parentChangeId.String
223 }
224
225 pulls[pull.AtUri()] = &pull
226 }
227
228 var pullAts []syntax.ATURI
229 for _, p := range pulls {
230 pullAts = append(pullAts, p.AtUri())
231 }
232 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
233 if err != nil {
234 return nil, fmt.Errorf("failed to get submissions: %w", err)
235 }
236
237 for pullAt, submissions := range submissionsMap {
238 if p, ok := pulls[pullAt]; ok {
239 p.Submissions = submissions
240 }
241 }
242
243 // collect allLabels for each issue
244 allLabels, err := GetLabels(e, FilterIn("subject", pullAts))
245 if err != nil {
246 return nil, fmt.Errorf("failed to query labels: %w", err)
247 }
248 for pullAt, labels := range allLabels {
249 if p, ok := pulls[pullAt]; ok {
250 p.Labels = labels
251 }
252 }
253
254 // collect pull source for all pulls that need it
255 var sourceAts []syntax.ATURI
256 for _, p := range pulls {
257 if p.PullSource != nil && p.PullSource.RepoAt != nil {
258 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
259 }
260 }
261 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts))
262 if err != nil && !errors.Is(err, sql.ErrNoRows) {
263 return nil, fmt.Errorf("failed to get source repos: %w", err)
264 }
265 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
266 for _, r := range sourceRepos {
267 sourceRepoMap[r.RepoAt()] = &r
268 }
269 for _, p := range pulls {
270 if p.PullSource != nil && p.PullSource.RepoAt != nil {
271 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
272 p.PullSource.Repo = sourceRepo
273 }
274 }
275 }
276
277 allReferences, err := GetReferencesAll(e, FilterIn("from_at", pullAts))
278 if err != nil {
279 return nil, fmt.Errorf("failed to query reference_links: %w", err)
280 }
281 for pullAt, references := range allReferences {
282 if pull, ok := pulls[pullAt]; ok {
283 pull.References = references
284 }
285 }
286
287 orderedByPullId := []*models.Pull{}
288 for _, p := range pulls {
289 orderedByPullId = append(orderedByPullId, p)
290 }
291 sort.Slice(orderedByPullId, func(i, j int) bool {
292 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
293 })
294
295 return orderedByPullId, nil
296}
297
298func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
299 return GetPullsWithLimit(e, 0, filters...)
300}
301
302func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) {
303 var ids []int64
304
305 var filters []filter
306 filters = append(filters, FilterEq("state", opts.State))
307 if opts.RepoAt != "" {
308 filters = append(filters, FilterEq("repo_at", opts.RepoAt))
309 }
310
311 var conditions []string
312 var args []any
313
314 for _, filter := range filters {
315 conditions = append(conditions, filter.Condition())
316 args = append(args, filter.Arg()...)
317 }
318
319 whereClause := ""
320 if conditions != nil {
321 whereClause = " where " + strings.Join(conditions, " and ")
322 }
323 pageClause := ""
324 if opts.Page.Limit != 0 {
325 pageClause = fmt.Sprintf(
326 " limit %d offset %d ",
327 opts.Page.Limit,
328 opts.Page.Offset,
329 )
330 }
331
332 query := fmt.Sprintf(
333 `
334 select
335 id
336 from
337 pulls
338 %s
339 %s`,
340 whereClause,
341 pageClause,
342 )
343 args = append(args, opts.Page.Limit, opts.Page.Offset)
344 rows, err := e.Query(query, args...)
345 if err != nil {
346 return nil, err
347 }
348 defer rows.Close()
349
350 for rows.Next() {
351 var id int64
352 err := rows.Scan(&id)
353 if err != nil {
354 return nil, err
355 }
356
357 ids = append(ids, id)
358 }
359
360 return ids, nil
361}
362
363func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
364 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
365 if err != nil {
366 return nil, err
367 }
368 if len(pulls) == 0 {
369 return nil, sql.ErrNoRows
370 }
371
372 return pulls[0], nil
373}
374
375// mapping from pull -> pull submissions
376func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
377 var conditions []string
378 var args []any
379 for _, filter := range filters {
380 conditions = append(conditions, filter.Condition())
381 args = append(args, filter.Arg()...)
382 }
383
384 whereClause := ""
385 if conditions != nil {
386 whereClause = " where " + strings.Join(conditions, " and ")
387 }
388
389 query := fmt.Sprintf(`
390 select
391 id,
392 pull_at,
393 round_number,
394 patch,
395 combined,
396 created,
397 source_rev
398 from
399 pull_submissions
400 %s
401 order by
402 round_number asc
403 `, whereClause)
404
405 rows, err := e.Query(query, args...)
406 if err != nil {
407 return nil, err
408 }
409 defer rows.Close()
410
411 submissionMap := make(map[int]*models.PullSubmission)
412
413 for rows.Next() {
414 var submission models.PullSubmission
415 var submissionCreatedStr string
416 var submissionSourceRev, submissionCombined sql.NullString
417 err := rows.Scan(
418 &submission.ID,
419 &submission.PullAt,
420 &submission.RoundNumber,
421 &submission.Patch,
422 &submissionCombined,
423 &submissionCreatedStr,
424 &submissionSourceRev,
425 )
426 if err != nil {
427 return nil, err
428 }
429
430 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
431 submission.Created = t
432 }
433
434 if submissionSourceRev.Valid {
435 submission.SourceRev = submissionSourceRev.String
436 }
437
438 if submissionCombined.Valid {
439 submission.Combined = submissionCombined.String
440 }
441
442 submissionMap[submission.ID] = &submission
443 }
444
445 if err := rows.Err(); err != nil {
446 return nil, err
447 }
448
449 // Get comments for all submissions using GetPullComments
450 submissionIds := slices.Collect(maps.Keys(submissionMap))
451 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
452 if err != nil {
453 return nil, fmt.Errorf("failed to get pull comments: %w", err)
454 }
455 for _, comment := range comments {
456 if submission, ok := submissionMap[comment.SubmissionId]; ok {
457 submission.Comments = append(submission.Comments, comment)
458 }
459 }
460
461 // group the submissions by pull_at
462 m := make(map[syntax.ATURI][]*models.PullSubmission)
463 for _, s := range submissionMap {
464 m[s.PullAt] = append(m[s.PullAt], s)
465 }
466
467 // sort each one by round number
468 for _, s := range m {
469 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
470 return cmp.Compare(a.RoundNumber, b.RoundNumber)
471 })
472 }
473
474 return m, nil
475}
476
477func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
478 var conditions []string
479 var args []any
480 for _, filter := range filters {
481 conditions = append(conditions, filter.Condition())
482 args = append(args, filter.Arg()...)
483 }
484
485 whereClause := ""
486 if conditions != nil {
487 whereClause = " where " + strings.Join(conditions, " and ")
488 }
489
490 query := fmt.Sprintf(`
491 select
492 id,
493 pull_id,
494 submission_id,
495 repo_at,
496 owner_did,
497 comment_at,
498 body,
499 created
500 from
501 pull_comments
502 %s
503 order by
504 created asc
505 `, whereClause)
506
507 rows, err := e.Query(query, args...)
508 if err != nil {
509 return nil, err
510 }
511 defer rows.Close()
512
513 commentMap := make(map[string]*models.PullComment)
514 for rows.Next() {
515 var comment models.PullComment
516 var createdAt string
517 err := rows.Scan(
518 &comment.ID,
519 &comment.PullId,
520 &comment.SubmissionId,
521 &comment.RepoAt,
522 &comment.OwnerDid,
523 &comment.CommentAt,
524 &comment.Body,
525 &createdAt,
526 )
527 if err != nil {
528 return nil, err
529 }
530
531 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
532 comment.Created = t
533 }
534
535 atUri := comment.AtUri().String()
536 commentMap[atUri] = &comment
537 }
538
539 if err := rows.Err(); err != nil {
540 return nil, err
541 }
542
543 // collect references for each comments
544 commentAts := slices.Collect(maps.Keys(commentMap))
545 allReferencs, err := GetReferencesAll(e, FilterIn("from_at", commentAts))
546 if err != nil {
547 return nil, fmt.Errorf("failed to query reference_links: %w", err)
548 }
549 for commentAt, references := range allReferencs {
550 if comment, ok := commentMap[commentAt.String()]; ok {
551 comment.References = references
552 }
553 }
554
555 var comments []models.PullComment
556 for _, c := range commentMap {
557 comments = append(comments, *c)
558 }
559
560 sort.Slice(comments, func(i, j int) bool {
561 return comments[i].Created.Before(comments[j].Created)
562 })
563
564 return comments, nil
565}
566
567// timeframe here is directly passed into the sql query filter, and any
568// timeframe in the past should be negative; e.g.: "-3 months"
569func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
570 var pulls []models.Pull
571
572 rows, err := e.Query(`
573 select
574 p.owner_did,
575 p.repo_at,
576 p.pull_id,
577 p.created,
578 p.title,
579 p.state,
580 r.did,
581 r.name,
582 r.knot,
583 r.rkey,
584 r.created
585 from
586 pulls p
587 join
588 repos r on p.repo_at = r.at_uri
589 where
590 p.owner_did = ? and p.created >= date ('now', ?)
591 order by
592 p.created desc`, did, timeframe)
593 if err != nil {
594 return nil, err
595 }
596 defer rows.Close()
597
598 for rows.Next() {
599 var pull models.Pull
600 var repo models.Repo
601 var pullCreatedAt, repoCreatedAt string
602 err := rows.Scan(
603 &pull.OwnerDid,
604 &pull.RepoAt,
605 &pull.PullId,
606 &pullCreatedAt,
607 &pull.Title,
608 &pull.State,
609 &repo.Did,
610 &repo.Name,
611 &repo.Knot,
612 &repo.Rkey,
613 &repoCreatedAt,
614 )
615 if err != nil {
616 return nil, err
617 }
618
619 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
620 if err != nil {
621 return nil, err
622 }
623 pull.Created = pullCreatedTime
624
625 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
626 if err != nil {
627 return nil, err
628 }
629 repo.Created = repoCreatedTime
630
631 pull.Repo = &repo
632
633 pulls = append(pulls, pull)
634 }
635
636 if err := rows.Err(); err != nil {
637 return nil, err
638 }
639
640 return pulls, nil
641}
642
643func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
644 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
645 res, err := tx.Exec(
646 query,
647 comment.OwnerDid,
648 comment.RepoAt,
649 comment.SubmissionId,
650 comment.CommentAt,
651 comment.PullId,
652 comment.Body,
653 )
654 if err != nil {
655 return 0, err
656 }
657
658 i, err := res.LastInsertId()
659 if err != nil {
660 return 0, err
661 }
662
663 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
664 return 0, fmt.Errorf("put reference_links: %w", err)
665 }
666
667 return i, nil
668}
669
670func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
671 _, err := e.Exec(
672 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
673 pullState,
674 repoAt,
675 pullId,
676 models.PullDeleted, // only update state of non-deleted pulls
677 models.PullMerged, // only update state of non-merged pulls
678 )
679 return err
680}
681
682func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
683 err := SetPullState(e, repoAt, pullId, models.PullClosed)
684 return err
685}
686
687func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
688 err := SetPullState(e, repoAt, pullId, models.PullOpen)
689 return err
690}
691
692func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
693 err := SetPullState(e, repoAt, pullId, models.PullMerged)
694 return err
695}
696
697func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
698 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
699 return err
700}
701
702func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
703 _, err := e.Exec(`
704 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
705 values (?, ?, ?, ?, ?)
706 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
707
708 return err
709}
710
711func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
712 var conditions []string
713 var args []any
714
715 args = append(args, parentChangeId)
716
717 for _, filter := range filters {
718 conditions = append(conditions, filter.Condition())
719 args = append(args, filter.Arg()...)
720 }
721
722 whereClause := ""
723 if conditions != nil {
724 whereClause = " where " + strings.Join(conditions, " and ")
725 }
726
727 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
728 _, err := e.Exec(query, args...)
729
730 return err
731}
732
733// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
734// otherwise submissions are immutable
735func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
736 var conditions []string
737 var args []any
738
739 args = append(args, sourceRev)
740 args = append(args, newPatch)
741
742 for _, filter := range filters {
743 conditions = append(conditions, filter.Condition())
744 args = append(args, filter.Arg()...)
745 }
746
747 whereClause := ""
748 if conditions != nil {
749 whereClause = " where " + strings.Join(conditions, " and ")
750 }
751
752 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
753 _, err := e.Exec(query, args...)
754
755 return err
756}
757
758func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
759 row := e.QueryRow(`
760 select
761 count(case when state = ? then 1 end) as open_count,
762 count(case when state = ? then 1 end) as merged_count,
763 count(case when state = ? then 1 end) as closed_count,
764 count(case when state = ? then 1 end) as deleted_count
765 from pulls
766 where repo_at = ?`,
767 models.PullOpen,
768 models.PullMerged,
769 models.PullClosed,
770 models.PullDeleted,
771 repoAt,
772 )
773
774 var count models.PullCount
775 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
776 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
777 }
778
779 return count, nil
780}
781
782// change-id parent-change-id
783//
784// 4 w ,-------- z (TOP)
785// 3 z <----',------- y
786// 2 y <-----',------ x
787// 1 x <------' nil (BOT)
788//
789// `w` is parent of none, so it is the top of the stack
790func GetStack(e Execer, stackId string) (models.Stack, error) {
791 unorderedPulls, err := GetPulls(
792 e,
793 FilterEq("stack_id", stackId),
794 FilterNotEq("state", models.PullDeleted),
795 )
796 if err != nil {
797 return nil, err
798 }
799 // map of parent-change-id to pull
800 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
801 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
802 for _, p := range unorderedPulls {
803 changeIdMap[p.ChangeId] = p
804 if p.ParentChangeId != "" {
805 parentMap[p.ParentChangeId] = p
806 }
807 }
808
809 // the top of the stack is the pull that is not a parent of any pull
810 var topPull *models.Pull
811 for _, maybeTop := range unorderedPulls {
812 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
813 topPull = maybeTop
814 break
815 }
816 }
817
818 pulls := []*models.Pull{}
819 for {
820 pulls = append(pulls, topPull)
821 if topPull.ParentChangeId != "" {
822 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
823 topPull = next
824 } else {
825 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
826 }
827 } else {
828 break
829 }
830 }
831
832 return pulls, nil
833}
834
835func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
836 pulls, err := GetPulls(
837 e,
838 FilterEq("stack_id", stackId),
839 FilterEq("state", models.PullDeleted),
840 )
841 if err != nil {
842 return nil, err
843 }
844
845 return pulls, nil
846}