forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16 "tangled.org/core/orm"
17)
18
19func NewPull(tx *sql.Tx, pull *models.Pull) error {
20 _, err := tx.Exec(`
21 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
22 values (?, 1)
23 `, pull.RepoAt)
24 if err != nil {
25 return err
26 }
27
28 var nextId int
29 err = tx.QueryRow(`
30 update repo_pull_seqs
31 set next_pull_id = next_pull_id + 1
32 where repo_at = ?
33 returning next_pull_id - 1
34 `, pull.RepoAt).Scan(&nextId)
35 if err != nil {
36 return err
37 }
38
39 pull.PullId = nextId
40 pull.State = models.PullOpen
41
42 var sourceBranch, sourceRepoAt *string
43 if pull.PullSource != nil {
44 sourceBranch = &pull.PullSource.Branch
45 if pull.PullSource.RepoAt != nil {
46 x := pull.PullSource.RepoAt.String()
47 sourceRepoAt = &x
48 }
49 }
50
51 var stackId, changeId, parentChangeId *string
52 if pull.StackId != "" {
53 stackId = &pull.StackId
54 }
55 if pull.ChangeId != "" {
56 changeId = &pull.ChangeId
57 }
58 if pull.ParentChangeId != "" {
59 parentChangeId = &pull.ParentChangeId
60 }
61
62 result, err := tx.Exec(
63 `
64 insert into pulls (
65 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
66 )
67 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
68 pull.RepoAt,
69 pull.OwnerDid,
70 pull.PullId,
71 pull.Title,
72 pull.TargetBranch,
73 pull.Body,
74 pull.Rkey,
75 pull.State,
76 sourceBranch,
77 sourceRepoAt,
78 stackId,
79 changeId,
80 parentChangeId,
81 )
82 if err != nil {
83 return err
84 }
85
86 // Set the database primary key ID
87 id, err := result.LastInsertId()
88 if err != nil {
89 return err
90 }
91 pull.ID = int(id)
92
93 _, err = tx.Exec(`
94 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
95 values (?, ?, ?, ?, ?)
96 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
97 if err != nil {
98 return err
99 }
100
101 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
102 return fmt.Errorf("put reference_links: %w", err)
103 }
104
105 return nil
106}
107
108func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
109 pull, err := GetPull(e, repoAt, pullId)
110 if err != nil {
111 return "", err
112 }
113 return pull.AtUri(), err
114}
115
116func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
117 var pullId int
118 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
119 return pullId - 1, err
120}
121
122func GetPullsWithLimit(e Execer, limit int, filters ...orm.Filter) ([]*models.Pull, error) {
123 pulls := make(map[syntax.ATURI]*models.Pull)
124
125 var conditions []string
126 var args []any
127 for _, filter := range filters {
128 conditions = append(conditions, filter.Condition())
129 args = append(args, filter.Arg()...)
130 }
131
132 whereClause := ""
133 if conditions != nil {
134 whereClause = " where " + strings.Join(conditions, " and ")
135 }
136 limitClause := ""
137 if limit != 0 {
138 limitClause = fmt.Sprintf(" limit %d ", limit)
139 }
140
141 query := fmt.Sprintf(`
142 select
143 id,
144 owner_did,
145 repo_at,
146 pull_id,
147 created,
148 title,
149 state,
150 target_branch,
151 body,
152 rkey,
153 source_branch,
154 source_repo_at,
155 stack_id,
156 change_id,
157 parent_change_id
158 from
159 pulls
160 %s
161 order by
162 created desc
163 %s
164 `, whereClause, limitClause)
165
166 rows, err := e.Query(query, args...)
167 if err != nil {
168 return nil, err
169 }
170 defer rows.Close()
171
172 for rows.Next() {
173 var pull models.Pull
174 var createdAt string
175 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
176 err := rows.Scan(
177 &pull.ID,
178 &pull.OwnerDid,
179 &pull.RepoAt,
180 &pull.PullId,
181 &createdAt,
182 &pull.Title,
183 &pull.State,
184 &pull.TargetBranch,
185 &pull.Body,
186 &pull.Rkey,
187 &sourceBranch,
188 &sourceRepoAt,
189 &stackId,
190 &changeId,
191 &parentChangeId,
192 )
193 if err != nil {
194 return nil, err
195 }
196
197 createdTime, err := time.Parse(time.RFC3339, createdAt)
198 if err != nil {
199 return nil, err
200 }
201 pull.Created = createdTime
202
203 if sourceBranch.Valid {
204 pull.PullSource = &models.PullSource{
205 Branch: sourceBranch.String,
206 }
207 if sourceRepoAt.Valid {
208 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
209 if err != nil {
210 return nil, err
211 }
212 pull.PullSource.RepoAt = &sourceRepoAtParsed
213 }
214 }
215
216 if stackId.Valid {
217 pull.StackId = stackId.String
218 }
219 if changeId.Valid {
220 pull.ChangeId = changeId.String
221 }
222 if parentChangeId.Valid {
223 pull.ParentChangeId = parentChangeId.String
224 }
225
226 pulls[pull.AtUri()] = &pull
227 }
228
229 var pullAts []syntax.ATURI
230 for _, p := range pulls {
231 pullAts = append(pullAts, p.AtUri())
232 }
233 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
234 if err != nil {
235 return nil, fmt.Errorf("failed to get submissions: %w", err)
236 }
237
238 for pullAt, submissions := range submissionsMap {
239 if p, ok := pulls[pullAt]; ok {
240 p.Submissions = submissions
241 }
242 }
243
244 // collect allLabels for each issue
245 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
246 if err != nil {
247 return nil, fmt.Errorf("failed to query labels: %w", err)
248 }
249 for pullAt, labels := range allLabels {
250 if p, ok := pulls[pullAt]; ok {
251 p.Labels = labels
252 }
253 }
254
255 // collect pull source for all pulls that need it
256 var sourceAts []syntax.ATURI
257 for _, p := range pulls {
258 if p.PullSource != nil && p.PullSource.RepoAt != nil {
259 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
260 }
261 }
262 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts))
263 if err != nil && !errors.Is(err, sql.ErrNoRows) {
264 return nil, fmt.Errorf("failed to get source repos: %w", err)
265 }
266 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
267 for _, r := range sourceRepos {
268 sourceRepoMap[r.RepoAt()] = &r
269 }
270 for _, p := range pulls {
271 if p.PullSource != nil && p.PullSource.RepoAt != nil {
272 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
273 p.PullSource.Repo = sourceRepo
274 }
275 }
276 }
277
278 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
279 if err != nil {
280 return nil, fmt.Errorf("failed to query reference_links: %w", err)
281 }
282 for pullAt, references := range allReferences {
283 if pull, ok := pulls[pullAt]; ok {
284 pull.References = references
285 }
286 }
287
288 orderedByPullId := []*models.Pull{}
289 for _, p := range pulls {
290 orderedByPullId = append(orderedByPullId, p)
291 }
292 sort.Slice(orderedByPullId, func(i, j int) bool {
293 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
294 })
295
296 return orderedByPullId, nil
297}
298
299func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
300 return GetPullsWithLimit(e, 0, filters...)
301}
302
303func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) {
304 var ids []int64
305
306 var filters []orm.Filter
307 filters = append(filters, orm.FilterEq("state", opts.State))
308 if opts.RepoAt != "" {
309 filters = append(filters, orm.FilterEq("repo_at", opts.RepoAt))
310 }
311
312 var conditions []string
313 var args []any
314
315 for _, filter := range filters {
316 conditions = append(conditions, filter.Condition())
317 args = append(args, filter.Arg()...)
318 }
319
320 whereClause := ""
321 if conditions != nil {
322 whereClause = " where " + strings.Join(conditions, " and ")
323 }
324 pageClause := ""
325 if opts.Page.Limit != 0 {
326 pageClause = fmt.Sprintf(
327 " limit %d offset %d ",
328 opts.Page.Limit,
329 opts.Page.Offset,
330 )
331 }
332
333 query := fmt.Sprintf(
334 `
335 select
336 id
337 from
338 pulls
339 %s
340 %s`,
341 whereClause,
342 pageClause,
343 )
344 args = append(args, opts.Page.Limit, opts.Page.Offset)
345 rows, err := e.Query(query, args...)
346 if err != nil {
347 return nil, err
348 }
349 defer rows.Close()
350
351 for rows.Next() {
352 var id int64
353 err := rows.Scan(&id)
354 if err != nil {
355 return nil, err
356 }
357
358 ids = append(ids, id)
359 }
360
361 return ids, nil
362}
363
364func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
365 pulls, err := GetPullsWithLimit(e, 1, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId))
366 if err != nil {
367 return nil, err
368 }
369 if len(pulls) == 0 {
370 return nil, sql.ErrNoRows
371 }
372
373 return pulls[0], nil
374}
375
376// mapping from pull -> pull submissions
377func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
378 var conditions []string
379 var args []any
380 for _, filter := range filters {
381 conditions = append(conditions, filter.Condition())
382 args = append(args, filter.Arg()...)
383 }
384
385 whereClause := ""
386 if conditions != nil {
387 whereClause = " where " + strings.Join(conditions, " and ")
388 }
389
390 query := fmt.Sprintf(`
391 select
392 id,
393 pull_at,
394 round_number,
395 patch,
396 combined,
397 created,
398 source_rev
399 from
400 pull_submissions
401 %s
402 order by
403 round_number asc
404 `, whereClause)
405
406 rows, err := e.Query(query, args...)
407 if err != nil {
408 return nil, err
409 }
410 defer rows.Close()
411
412 submissionMap := make(map[int]*models.PullSubmission)
413
414 for rows.Next() {
415 var submission models.PullSubmission
416 var submissionCreatedStr string
417 var submissionSourceRev, submissionCombined sql.NullString
418 err := rows.Scan(
419 &submission.ID,
420 &submission.PullAt,
421 &submission.RoundNumber,
422 &submission.Patch,
423 &submissionCombined,
424 &submissionCreatedStr,
425 &submissionSourceRev,
426 )
427 if err != nil {
428 return nil, err
429 }
430
431 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
432 submission.Created = t
433 }
434
435 if submissionSourceRev.Valid {
436 submission.SourceRev = submissionSourceRev.String
437 }
438
439 if submissionCombined.Valid {
440 submission.Combined = submissionCombined.String
441 }
442
443 submissionMap[submission.ID] = &submission
444 }
445
446 if err := rows.Err(); err != nil {
447 return nil, err
448 }
449
450 // Get comments for all submissions using GetPullComments
451 submissionIds := slices.Collect(maps.Keys(submissionMap))
452 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds))
453 if err != nil {
454 return nil, fmt.Errorf("failed to get pull comments: %w", err)
455 }
456 for _, comment := range comments {
457 if submission, ok := submissionMap[comment.SubmissionId]; ok {
458 submission.Comments = append(submission.Comments, comment)
459 }
460 }
461
462 // group the submissions by pull_at
463 m := make(map[syntax.ATURI][]*models.PullSubmission)
464 for _, s := range submissionMap {
465 m[s.PullAt] = append(m[s.PullAt], s)
466 }
467
468 // sort each one by round number
469 for _, s := range m {
470 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
471 return cmp.Compare(a.RoundNumber, b.RoundNumber)
472 })
473 }
474
475 return m, nil
476}
477
478func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) {
479 var conditions []string
480 var args []any
481 for _, filter := range filters {
482 conditions = append(conditions, filter.Condition())
483 args = append(args, filter.Arg()...)
484 }
485
486 whereClause := ""
487 if conditions != nil {
488 whereClause = " where " + strings.Join(conditions, " and ")
489 }
490
491 query := fmt.Sprintf(`
492 select
493 id,
494 pull_id,
495 submission_id,
496 repo_at,
497 owner_did,
498 comment_at,
499 body,
500 created
501 from
502 pull_comments
503 %s
504 order by
505 created asc
506 `, whereClause)
507
508 rows, err := e.Query(query, args...)
509 if err != nil {
510 return nil, err
511 }
512 defer rows.Close()
513
514 commentMap := make(map[string]*models.PullComment)
515 for rows.Next() {
516 var comment models.PullComment
517 var createdAt string
518 err := rows.Scan(
519 &comment.ID,
520 &comment.PullId,
521 &comment.SubmissionId,
522 &comment.RepoAt,
523 &comment.OwnerDid,
524 &comment.CommentAt,
525 &comment.Body,
526 &createdAt,
527 )
528 if err != nil {
529 return nil, err
530 }
531
532 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
533 comment.Created = t
534 }
535
536 atUri := comment.AtUri().String()
537 commentMap[atUri] = &comment
538 }
539
540 if err := rows.Err(); err != nil {
541 return nil, err
542 }
543
544 // collect references for each comments
545 commentAts := slices.Collect(maps.Keys(commentMap))
546 allReferencs, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts))
547 if err != nil {
548 return nil, fmt.Errorf("failed to query reference_links: %w", err)
549 }
550 for commentAt, references := range allReferencs {
551 if comment, ok := commentMap[commentAt.String()]; ok {
552 comment.References = references
553 }
554 }
555
556 var comments []models.PullComment
557 for _, c := range commentMap {
558 comments = append(comments, *c)
559 }
560
561 sort.Slice(comments, func(i, j int) bool {
562 return comments[i].Created.Before(comments[j].Created)
563 })
564
565 return comments, nil
566}
567
568// timeframe here is directly passed into the sql query filter, and any
569// timeframe in the past should be negative; e.g.: "-3 months"
570func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
571 var pulls []models.Pull
572
573 rows, err := e.Query(`
574 select
575 p.owner_did,
576 p.repo_at,
577 p.pull_id,
578 p.created,
579 p.title,
580 p.state,
581 r.did,
582 r.name,
583 r.knot,
584 r.rkey,
585 r.created
586 from
587 pulls p
588 join
589 repos r on p.repo_at = r.at_uri
590 where
591 p.owner_did = ? and p.created >= date ('now', ?)
592 order by
593 p.created desc`, did, timeframe)
594 if err != nil {
595 return nil, err
596 }
597 defer rows.Close()
598
599 for rows.Next() {
600 var pull models.Pull
601 var repo models.Repo
602 var pullCreatedAt, repoCreatedAt string
603 err := rows.Scan(
604 &pull.OwnerDid,
605 &pull.RepoAt,
606 &pull.PullId,
607 &pullCreatedAt,
608 &pull.Title,
609 &pull.State,
610 &repo.Did,
611 &repo.Name,
612 &repo.Knot,
613 &repo.Rkey,
614 &repoCreatedAt,
615 )
616 if err != nil {
617 return nil, err
618 }
619
620 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
621 if err != nil {
622 return nil, err
623 }
624 pull.Created = pullCreatedTime
625
626 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
627 if err != nil {
628 return nil, err
629 }
630 repo.Created = repoCreatedTime
631
632 pull.Repo = &repo
633
634 pulls = append(pulls, pull)
635 }
636
637 if err := rows.Err(); err != nil {
638 return nil, err
639 }
640
641 return pulls, nil
642}
643
644func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
645 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
646 res, err := tx.Exec(
647 query,
648 comment.OwnerDid,
649 comment.RepoAt,
650 comment.SubmissionId,
651 comment.CommentAt,
652 comment.PullId,
653 comment.Body,
654 )
655 if err != nil {
656 return 0, err
657 }
658
659 i, err := res.LastInsertId()
660 if err != nil {
661 return 0, err
662 }
663
664 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
665 return 0, fmt.Errorf("put reference_links: %w", err)
666 }
667
668 return i, nil
669}
670
671func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
672 _, err := e.Exec(
673 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
674 pullState,
675 repoAt,
676 pullId,
677 models.PullDeleted, // only update state of non-deleted pulls
678 models.PullMerged, // only update state of non-merged pulls
679 )
680 return err
681}
682
683func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
684 err := SetPullState(e, repoAt, pullId, models.PullClosed)
685 return err
686}
687
688func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
689 err := SetPullState(e, repoAt, pullId, models.PullOpen)
690 return err
691}
692
693func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
694 err := SetPullState(e, repoAt, pullId, models.PullMerged)
695 return err
696}
697
698func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
699 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
700 return err
701}
702
703func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
704 _, err := e.Exec(`
705 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
706 values (?, ?, ?, ?, ?)
707 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
708
709 return err
710}
711
712func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error {
713 var conditions []string
714 var args []any
715
716 args = append(args, parentChangeId)
717
718 for _, filter := range filters {
719 conditions = append(conditions, filter.Condition())
720 args = append(args, filter.Arg()...)
721 }
722
723 whereClause := ""
724 if conditions != nil {
725 whereClause = " where " + strings.Join(conditions, " and ")
726 }
727
728 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
729 _, err := e.Exec(query, args...)
730
731 return err
732}
733
734// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
735// otherwise submissions are immutable
736func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error {
737 var conditions []string
738 var args []any
739
740 args = append(args, sourceRev)
741 args = append(args, newPatch)
742
743 for _, filter := range filters {
744 conditions = append(conditions, filter.Condition())
745 args = append(args, filter.Arg()...)
746 }
747
748 whereClause := ""
749 if conditions != nil {
750 whereClause = " where " + strings.Join(conditions, " and ")
751 }
752
753 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
754 _, err := e.Exec(query, args...)
755
756 return err
757}
758
759func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
760 row := e.QueryRow(`
761 select
762 count(case when state = ? then 1 end) as open_count,
763 count(case when state = ? then 1 end) as merged_count,
764 count(case when state = ? then 1 end) as closed_count,
765 count(case when state = ? then 1 end) as deleted_count
766 from pulls
767 where repo_at = ?`,
768 models.PullOpen,
769 models.PullMerged,
770 models.PullClosed,
771 models.PullDeleted,
772 repoAt,
773 )
774
775 var count models.PullCount
776 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
777 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
778 }
779
780 return count, nil
781}
782
783// change-id parent-change-id
784//
785// 4 w ,-------- z (TOP)
786// 3 z <----',------- y
787// 2 y <-----',------ x
788// 1 x <------' nil (BOT)
789//
790// `w` is parent of none, so it is the top of the stack
791func GetStack(e Execer, stackId string) (models.Stack, error) {
792 unorderedPulls, err := GetPulls(
793 e,
794 orm.FilterEq("stack_id", stackId),
795 orm.FilterNotEq("state", models.PullDeleted),
796 )
797 if err != nil {
798 return nil, err
799 }
800 // map of parent-change-id to pull
801 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
802 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
803 for _, p := range unorderedPulls {
804 changeIdMap[p.ChangeId] = p
805 if p.ParentChangeId != "" {
806 parentMap[p.ParentChangeId] = p
807 }
808 }
809
810 // the top of the stack is the pull that is not a parent of any pull
811 var topPull *models.Pull
812 for _, maybeTop := range unorderedPulls {
813 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
814 topPull = maybeTop
815 break
816 }
817 }
818
819 pulls := []*models.Pull{}
820 for {
821 pulls = append(pulls, topPull)
822 if topPull.ParentChangeId != "" {
823 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
824 topPull = next
825 } else {
826 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
827 }
828 } else {
829 break
830 }
831 }
832
833 return pulls, nil
834}
835
836func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
837 pulls, err := GetPulls(
838 e,
839 orm.FilterEq("stack_id", stackId),
840 orm.FilterEq("state", models.PullDeleted),
841 )
842 if err != nil {
843 return nil, err
844 }
845
846 return pulls, nil
847}