1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16)
17
18func NewPull(tx *sql.Tx, pull *models.Pull) error {
19 _, err := tx.Exec(`
20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
21 values (?, 1)
22 `, pull.RepoAt)
23 if err != nil {
24 return err
25 }
26
27 var nextId int
28 err = tx.QueryRow(`
29 update repo_pull_seqs
30 set next_pull_id = next_pull_id + 1
31 where repo_at = ?
32 returning next_pull_id - 1
33 `, pull.RepoAt).Scan(&nextId)
34 if err != nil {
35 return err
36 }
37
38 pull.PullId = nextId
39 pull.State = models.PullOpen
40
41 var sourceBranch, sourceRepoAt *string
42 if pull.PullSource != nil {
43 sourceBranch = &pull.PullSource.Branch
44 if pull.PullSource.RepoAt != nil {
45 x := pull.PullSource.RepoAt.String()
46 sourceRepoAt = &x
47 }
48 }
49
50 var stackId, changeId, parentChangeId *string
51 if pull.StackId != "" {
52 stackId = &pull.StackId
53 }
54 if pull.ChangeId != "" {
55 changeId = &pull.ChangeId
56 }
57 if pull.ParentChangeId != "" {
58 parentChangeId = &pull.ParentChangeId
59 }
60
61 result, err := tx.Exec(
62 `
63 insert into pulls (
64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
65 )
66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
67 pull.RepoAt,
68 pull.OwnerDid,
69 pull.PullId,
70 pull.Title,
71 pull.TargetBranch,
72 pull.Body,
73 pull.Rkey,
74 pull.State,
75 sourceBranch,
76 sourceRepoAt,
77 stackId,
78 changeId,
79 parentChangeId,
80 )
81 if err != nil {
82 return err
83 }
84
85 // Set the database primary key ID
86 id, err := result.LastInsertId()
87 if err != nil {
88 return err
89 }
90 pull.ID = int(id)
91
92 _, err = tx.Exec(`
93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
94 values (?, ?, ?, ?, ?)
95 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
96 return err
97}
98
99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
100 pull, err := GetPull(e, repoAt, pullId)
101 if err != nil {
102 return "", err
103 }
104 return pull.AtUri(), err
105}
106
107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
108 var pullId int
109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
110 return pullId - 1, err
111}
112
113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
114 pulls := make(map[syntax.ATURI]*models.Pull)
115
116 var conditions []string
117 var args []any
118 for _, filter := range filters {
119 conditions = append(conditions, filter.Condition())
120 args = append(args, filter.Arg()...)
121 }
122
123 whereClause := ""
124 if conditions != nil {
125 whereClause = " where " + strings.Join(conditions, " and ")
126 }
127 limitClause := ""
128 if limit != 0 {
129 limitClause = fmt.Sprintf(" limit %d ", limit)
130 }
131
132 query := fmt.Sprintf(`
133 select
134 id,
135 owner_did,
136 repo_at,
137 pull_id,
138 created,
139 title,
140 state,
141 target_branch,
142 body,
143 rkey,
144 source_branch,
145 source_repo_at,
146 stack_id,
147 change_id,
148 parent_change_id
149 from
150 pulls
151 %s
152 order by
153 created desc
154 %s
155 `, whereClause, limitClause)
156
157 rows, err := e.Query(query, args...)
158 if err != nil {
159 return nil, err
160 }
161 defer rows.Close()
162
163 for rows.Next() {
164 var pull models.Pull
165 var createdAt string
166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
167 err := rows.Scan(
168 &pull.ID,
169 &pull.OwnerDid,
170 &pull.RepoAt,
171 &pull.PullId,
172 &createdAt,
173 &pull.Title,
174 &pull.State,
175 &pull.TargetBranch,
176 &pull.Body,
177 &pull.Rkey,
178 &sourceBranch,
179 &sourceRepoAt,
180 &stackId,
181 &changeId,
182 &parentChangeId,
183 )
184 if err != nil {
185 return nil, err
186 }
187
188 createdTime, err := time.Parse(time.RFC3339, createdAt)
189 if err != nil {
190 return nil, err
191 }
192 pull.Created = createdTime
193
194 if sourceBranch.Valid {
195 pull.PullSource = &models.PullSource{
196 Branch: sourceBranch.String,
197 }
198 if sourceRepoAt.Valid {
199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
200 if err != nil {
201 return nil, err
202 }
203 pull.PullSource.RepoAt = &sourceRepoAtParsed
204 }
205 }
206
207 if stackId.Valid {
208 pull.StackId = stackId.String
209 }
210 if changeId.Valid {
211 pull.ChangeId = changeId.String
212 }
213 if parentChangeId.Valid {
214 pull.ParentChangeId = parentChangeId.String
215 }
216
217 pulls[pull.AtUri()] = &pull
218 }
219
220 var pullAts []syntax.ATURI
221 for _, p := range pulls {
222 pullAts = append(pullAts, p.AtUri())
223 }
224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
225 if err != nil {
226 return nil, fmt.Errorf("failed to get submissions: %w", err)
227 }
228
229 for pullAt, submissions := range submissionsMap {
230 if p, ok := pulls[pullAt]; ok {
231 p.Submissions = submissions
232 }
233 }
234
235 // collect allLabels for each issue
236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts))
237 if err != nil {
238 return nil, fmt.Errorf("failed to query labels: %w", err)
239 }
240 for pullAt, labels := range allLabels {
241 if p, ok := pulls[pullAt]; ok {
242 p.Labels = labels
243 }
244 }
245
246 // collect pull source for all pulls that need it
247 var sourceAts []syntax.ATURI
248 for _, p := range pulls {
249 if p.PullSource != nil && p.PullSource.RepoAt != nil {
250 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
251 }
252 }
253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts))
254 if err != nil && !errors.Is(err, sql.ErrNoRows) {
255 return nil, fmt.Errorf("failed to get source repos: %w", err)
256 }
257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
258 for _, r := range sourceRepos {
259 sourceRepoMap[r.RepoAt()] = &r
260 }
261 for _, p := range pulls {
262 if p.PullSource != nil && p.PullSource.RepoAt != nil {
263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
264 p.PullSource.Repo = sourceRepo
265 }
266 }
267 }
268
269 orderedByPullId := []*models.Pull{}
270 for _, p := range pulls {
271 orderedByPullId = append(orderedByPullId, p)
272 }
273 sort.Slice(orderedByPullId, func(i, j int) bool {
274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
275 })
276
277 return orderedByPullId, nil
278}
279
280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
281 return GetPullsWithLimit(e, 0, filters...)
282}
283
284func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) {
285 var ids []int64
286
287 var filters []filter
288 filters = append(filters, FilterEq("state", opts.State))
289 if opts.RepoAt != "" {
290 filters = append(filters, FilterEq("repo_at", opts.RepoAt))
291 }
292
293 var conditions []string
294 var args []any
295
296 for _, filter := range filters {
297 conditions = append(conditions, filter.Condition())
298 args = append(args, filter.Arg()...)
299 }
300
301 whereClause := ""
302 if conditions != nil {
303 whereClause = " where " + strings.Join(conditions, " and ")
304 }
305 pageClause := ""
306 if opts.Page.Limit != 0 {
307 pageClause = fmt.Sprintf(
308 " limit %d offset %d ",
309 opts.Page.Limit,
310 opts.Page.Offset,
311 )
312 }
313
314 query := fmt.Sprintf(
315 `
316 select
317 id
318 from
319 pulls
320 %s
321 %s`,
322 whereClause,
323 pageClause,
324 )
325 args = append(args, opts.Page.Limit, opts.Page.Offset)
326 rows, err := e.Query(query, args...)
327 if err != nil {
328 return nil, err
329 }
330 defer rows.Close()
331
332 for rows.Next() {
333 var id int64
334 err := rows.Scan(&id)
335 if err != nil {
336 return nil, err
337 }
338
339 ids = append(ids, id)
340 }
341
342 return ids, nil
343}
344
345func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
346 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
347 if err != nil {
348 return nil, err
349 }
350 if len(pulls) == 0 {
351 return nil, sql.ErrNoRows
352 }
353
354 return pulls[0], nil
355}
356
357// mapping from pull -> pull submissions
358func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
359 var conditions []string
360 var args []any
361 for _, filter := range filters {
362 conditions = append(conditions, filter.Condition())
363 args = append(args, filter.Arg()...)
364 }
365
366 whereClause := ""
367 if conditions != nil {
368 whereClause = " where " + strings.Join(conditions, " and ")
369 }
370
371 query := fmt.Sprintf(`
372 select
373 id,
374 pull_at,
375 round_number,
376 patch,
377 combined,
378 created,
379 source_rev
380 from
381 pull_submissions
382 %s
383 order by
384 round_number asc
385 `, whereClause)
386
387 rows, err := e.Query(query, args...)
388 if err != nil {
389 return nil, err
390 }
391 defer rows.Close()
392
393 submissionMap := make(map[int]*models.PullSubmission)
394
395 for rows.Next() {
396 var submission models.PullSubmission
397 var submissionCreatedStr string
398 var submissionSourceRev, submissionCombined sql.NullString
399 err := rows.Scan(
400 &submission.ID,
401 &submission.PullAt,
402 &submission.RoundNumber,
403 &submission.Patch,
404 &submissionCombined,
405 &submissionCreatedStr,
406 &submissionSourceRev,
407 )
408 if err != nil {
409 return nil, err
410 }
411
412 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
413 submission.Created = t
414 }
415
416 if submissionSourceRev.Valid {
417 submission.SourceRev = submissionSourceRev.String
418 }
419
420 if submissionCombined.Valid {
421 submission.Combined = submissionCombined.String
422 }
423
424 submissionMap[submission.ID] = &submission
425 }
426
427 if err := rows.Err(); err != nil {
428 return nil, err
429 }
430
431 // Get comments for all submissions using GetPullComments
432 submissionIds := slices.Collect(maps.Keys(submissionMap))
433 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
434 if err != nil {
435 return nil, err
436 }
437 for _, comment := range comments {
438 if submission, ok := submissionMap[comment.SubmissionId]; ok {
439 submission.Comments = append(submission.Comments, comment)
440 }
441 }
442
443 // group the submissions by pull_at
444 m := make(map[syntax.ATURI][]*models.PullSubmission)
445 for _, s := range submissionMap {
446 m[s.PullAt] = append(m[s.PullAt], s)
447 }
448
449 // sort each one by round number
450 for _, s := range m {
451 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
452 return cmp.Compare(a.RoundNumber, b.RoundNumber)
453 })
454 }
455
456 return m, nil
457}
458
459func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
460 var conditions []string
461 var args []any
462 for _, filter := range filters {
463 conditions = append(conditions, filter.Condition())
464 args = append(args, filter.Arg()...)
465 }
466
467 whereClause := ""
468 if conditions != nil {
469 whereClause = " where " + strings.Join(conditions, " and ")
470 }
471
472 query := fmt.Sprintf(`
473 select
474 id,
475 pull_id,
476 submission_id,
477 repo_at,
478 owner_did,
479 comment_at,
480 body,
481 created
482 from
483 pull_comments
484 %s
485 order by
486 created asc
487 `, whereClause)
488
489 rows, err := e.Query(query, args...)
490 if err != nil {
491 return nil, err
492 }
493 defer rows.Close()
494
495 var comments []models.PullComment
496 for rows.Next() {
497 var comment models.PullComment
498 var createdAt string
499 err := rows.Scan(
500 &comment.ID,
501 &comment.PullId,
502 &comment.SubmissionId,
503 &comment.RepoAt,
504 &comment.OwnerDid,
505 &comment.CommentAt,
506 &comment.Body,
507 &createdAt,
508 )
509 if err != nil {
510 return nil, err
511 }
512
513 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
514 comment.Created = t
515 }
516
517 comments = append(comments, comment)
518 }
519
520 if err := rows.Err(); err != nil {
521 return nil, err
522 }
523
524 return comments, nil
525}
526
527// timeframe here is directly passed into the sql query filter, and any
528// timeframe in the past should be negative; e.g.: "-3 months"
529func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
530 var pulls []models.Pull
531
532 rows, err := e.Query(`
533 select
534 p.owner_did,
535 p.repo_at,
536 p.pull_id,
537 p.created,
538 p.title,
539 p.state,
540 r.did,
541 r.name,
542 r.knot,
543 r.rkey,
544 r.created
545 from
546 pulls p
547 join
548 repos r on p.repo_at = r.at_uri
549 where
550 p.owner_did = ? and p.created >= date ('now', ?)
551 order by
552 p.created desc`, did, timeframe)
553 if err != nil {
554 return nil, err
555 }
556 defer rows.Close()
557
558 for rows.Next() {
559 var pull models.Pull
560 var repo models.Repo
561 var pullCreatedAt, repoCreatedAt string
562 err := rows.Scan(
563 &pull.OwnerDid,
564 &pull.RepoAt,
565 &pull.PullId,
566 &pullCreatedAt,
567 &pull.Title,
568 &pull.State,
569 &repo.Did,
570 &repo.Name,
571 &repo.Knot,
572 &repo.Rkey,
573 &repoCreatedAt,
574 )
575 if err != nil {
576 return nil, err
577 }
578
579 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
580 if err != nil {
581 return nil, err
582 }
583 pull.Created = pullCreatedTime
584
585 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
586 if err != nil {
587 return nil, err
588 }
589 repo.Created = repoCreatedTime
590
591 pull.Repo = &repo
592
593 pulls = append(pulls, pull)
594 }
595
596 if err := rows.Err(); err != nil {
597 return nil, err
598 }
599
600 return pulls, nil
601}
602
603func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
604 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
605 res, err := e.Exec(
606 query,
607 comment.OwnerDid,
608 comment.RepoAt,
609 comment.SubmissionId,
610 comment.CommentAt,
611 comment.PullId,
612 comment.Body,
613 )
614 if err != nil {
615 return 0, err
616 }
617
618 i, err := res.LastInsertId()
619 if err != nil {
620 return 0, err
621 }
622
623 return i, nil
624}
625
626func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
627 _, err := e.Exec(
628 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
629 pullState,
630 repoAt,
631 pullId,
632 models.PullDeleted, // only update state of non-deleted pulls
633 models.PullMerged, // only update state of non-merged pulls
634 )
635 return err
636}
637
638func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
639 err := SetPullState(e, repoAt, pullId, models.PullClosed)
640 return err
641}
642
643func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
644 err := SetPullState(e, repoAt, pullId, models.PullOpen)
645 return err
646}
647
648func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
649 err := SetPullState(e, repoAt, pullId, models.PullMerged)
650 return err
651}
652
653func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
654 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
655 return err
656}
657
658func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
659 _, err := e.Exec(`
660 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
661 values (?, ?, ?, ?, ?)
662 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
663
664 return err
665}
666
667func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
668 var conditions []string
669 var args []any
670
671 args = append(args, parentChangeId)
672
673 for _, filter := range filters {
674 conditions = append(conditions, filter.Condition())
675 args = append(args, filter.Arg()...)
676 }
677
678 whereClause := ""
679 if conditions != nil {
680 whereClause = " where " + strings.Join(conditions, " and ")
681 }
682
683 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
684 _, err := e.Exec(query, args...)
685
686 return err
687}
688
689// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
690// otherwise submissions are immutable
691func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
692 var conditions []string
693 var args []any
694
695 args = append(args, sourceRev)
696 args = append(args, newPatch)
697
698 for _, filter := range filters {
699 conditions = append(conditions, filter.Condition())
700 args = append(args, filter.Arg()...)
701 }
702
703 whereClause := ""
704 if conditions != nil {
705 whereClause = " where " + strings.Join(conditions, " and ")
706 }
707
708 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
709 _, err := e.Exec(query, args...)
710
711 return err
712}
713
714func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
715 row := e.QueryRow(`
716 select
717 count(case when state = ? then 1 end) as open_count,
718 count(case when state = ? then 1 end) as merged_count,
719 count(case when state = ? then 1 end) as closed_count,
720 count(case when state = ? then 1 end) as deleted_count
721 from pulls
722 where repo_at = ?`,
723 models.PullOpen,
724 models.PullMerged,
725 models.PullClosed,
726 models.PullDeleted,
727 repoAt,
728 )
729
730 var count models.PullCount
731 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
732 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
733 }
734
735 return count, nil
736}
737
738// change-id parent-change-id
739//
740// 4 w ,-------- z (TOP)
741// 3 z <----',------- y
742// 2 y <-----',------ x
743// 1 x <------' nil (BOT)
744//
745// `w` is parent of none, so it is the top of the stack
746func GetStack(e Execer, stackId string) (models.Stack, error) {
747 unorderedPulls, err := GetPulls(
748 e,
749 FilterEq("stack_id", stackId),
750 FilterNotEq("state", models.PullDeleted),
751 )
752 if err != nil {
753 return nil, err
754 }
755 // map of parent-change-id to pull
756 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
757 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
758 for _, p := range unorderedPulls {
759 changeIdMap[p.ChangeId] = p
760 if p.ParentChangeId != "" {
761 parentMap[p.ParentChangeId] = p
762 }
763 }
764
765 // the top of the stack is the pull that is not a parent of any pull
766 var topPull *models.Pull
767 for _, maybeTop := range unorderedPulls {
768 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
769 topPull = maybeTop
770 break
771 }
772 }
773
774 pulls := []*models.Pull{}
775 for {
776 pulls = append(pulls, topPull)
777 if topPull.ParentChangeId != "" {
778 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
779 topPull = next
780 } else {
781 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
782 }
783 } else {
784 break
785 }
786 }
787
788 return pulls, nil
789}
790
791func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
792 pulls, err := GetPulls(
793 e,
794 FilterEq("stack_id", stackId),
795 FilterEq("state", models.PullDeleted),
796 )
797 if err != nil {
798 return nil, err
799 }
800
801 return pulls, nil
802}