forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16)
17
18func NewPull(tx *sql.Tx, pull *models.Pull) error {
19 _, err := tx.Exec(`
20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
21 values (?, 1)
22 `, pull.RepoAt)
23 if err != nil {
24 return err
25 }
26
27 var nextId int
28 err = tx.QueryRow(`
29 update repo_pull_seqs
30 set next_pull_id = next_pull_id + 1
31 where repo_at = ?
32 returning next_pull_id - 1
33 `, pull.RepoAt).Scan(&nextId)
34 if err != nil {
35 return err
36 }
37
38 pull.PullId = nextId
39 pull.State = models.PullOpen
40
41 var sourceBranch, sourceRepoAt *string
42 if pull.PullSource != nil {
43 sourceBranch = &pull.PullSource.Branch
44 if pull.PullSource.RepoAt != nil {
45 x := pull.PullSource.RepoAt.String()
46 sourceRepoAt = &x
47 }
48 }
49
50 var stackId, changeId, parentChangeId *string
51 if pull.StackId != "" {
52 stackId = &pull.StackId
53 }
54 if pull.ChangeId != "" {
55 changeId = &pull.ChangeId
56 }
57 if pull.ParentChangeId != "" {
58 parentChangeId = &pull.ParentChangeId
59 }
60
61 result, err := tx.Exec(
62 `
63 insert into pulls (
64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
65 )
66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
67 pull.RepoAt,
68 pull.OwnerDid,
69 pull.PullId,
70 pull.Title,
71 pull.TargetBranch,
72 pull.Body,
73 pull.Rkey,
74 pull.State,
75 sourceBranch,
76 sourceRepoAt,
77 stackId,
78 changeId,
79 parentChangeId,
80 )
81 if err != nil {
82 return err
83 }
84
85 // Set the database primary key ID
86 id, err := result.LastInsertId()
87 if err != nil {
88 return err
89 }
90 pull.ID = int(id)
91
92 _, err = tx.Exec(`
93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
94 values (?, ?, ?, ?, ?)
95 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
96 return err
97}
98
99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
100 pull, err := GetPull(e, repoAt, pullId)
101 if err != nil {
102 return "", err
103 }
104 return pull.PullAt(), err
105}
106
107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
108 var pullId int
109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
110 return pullId - 1, err
111}
112
113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
114 pulls := make(map[syntax.ATURI]*models.Pull)
115
116 var conditions []string
117 var args []any
118 for _, filter := range filters {
119 conditions = append(conditions, filter.Condition())
120 args = append(args, filter.Arg()...)
121 }
122
123 whereClause := ""
124 if conditions != nil {
125 whereClause = " where " + strings.Join(conditions, " and ")
126 }
127 limitClause := ""
128 if limit != 0 {
129 limitClause = fmt.Sprintf(" limit %d ", limit)
130 }
131
132 query := fmt.Sprintf(`
133 select
134 id,
135 owner_did,
136 repo_at,
137 pull_id,
138 created,
139 title,
140 state,
141 target_branch,
142 body,
143 rkey,
144 source_branch,
145 source_repo_at,
146 stack_id,
147 change_id,
148 parent_change_id
149 from
150 pulls
151 %s
152 order by
153 created desc
154 %s
155 `, whereClause, limitClause)
156
157 rows, err := e.Query(query, args...)
158 if err != nil {
159 return nil, err
160 }
161 defer rows.Close()
162
163 for rows.Next() {
164 var pull models.Pull
165 var createdAt string
166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
167 err := rows.Scan(
168 &pull.ID,
169 &pull.OwnerDid,
170 &pull.RepoAt,
171 &pull.PullId,
172 &createdAt,
173 &pull.Title,
174 &pull.State,
175 &pull.TargetBranch,
176 &pull.Body,
177 &pull.Rkey,
178 &sourceBranch,
179 &sourceRepoAt,
180 &stackId,
181 &changeId,
182 &parentChangeId,
183 )
184 if err != nil {
185 return nil, err
186 }
187
188 createdTime, err := time.Parse(time.RFC3339, createdAt)
189 if err != nil {
190 return nil, err
191 }
192 pull.Created = createdTime
193
194 if sourceBranch.Valid {
195 pull.PullSource = &models.PullSource{
196 Branch: sourceBranch.String,
197 }
198 if sourceRepoAt.Valid {
199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
200 if err != nil {
201 return nil, err
202 }
203 pull.PullSource.RepoAt = &sourceRepoAtParsed
204 }
205 }
206
207 if stackId.Valid {
208 pull.StackId = stackId.String
209 }
210 if changeId.Valid {
211 pull.ChangeId = changeId.String
212 }
213 if parentChangeId.Valid {
214 pull.ParentChangeId = parentChangeId.String
215 }
216
217 pulls[pull.PullAt()] = &pull
218 }
219
220 var pullAts []syntax.ATURI
221 for _, p := range pulls {
222 pullAts = append(pullAts, p.PullAt())
223 }
224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
225 if err != nil {
226 return nil, fmt.Errorf("failed to get submissions: %w", err)
227 }
228
229 for pullAt, submissions := range submissionsMap {
230 if p, ok := pulls[pullAt]; ok {
231 p.Submissions = submissions
232 }
233 }
234
235 // collect allLabels for each issue
236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts))
237 if err != nil {
238 return nil, fmt.Errorf("failed to query labels: %w", err)
239 }
240 for pullAt, labels := range allLabels {
241 if p, ok := pulls[pullAt]; ok {
242 p.Labels = labels
243 }
244 }
245
246 // collect pull source for all pulls that need it
247 var sourceAts []syntax.ATURI
248 for _, p := range pulls {
249 if p.PullSource != nil && p.PullSource.RepoAt != nil {
250 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
251 }
252 }
253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts))
254 if err != nil && !errors.Is(err, sql.ErrNoRows) {
255 return nil, fmt.Errorf("failed to get source repos: %w", err)
256 }
257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
258 for _, r := range sourceRepos {
259 sourceRepoMap[r.RepoAt()] = &r
260 }
261 for _, p := range pulls {
262 if p.PullSource != nil && p.PullSource.RepoAt != nil {
263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
264 p.PullSource.Repo = sourceRepo
265 }
266 }
267 }
268
269 orderedByPullId := []*models.Pull{}
270 for _, p := range pulls {
271 orderedByPullId = append(orderedByPullId, p)
272 }
273 sort.Slice(orderedByPullId, func(i, j int) bool {
274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
275 })
276
277 return orderedByPullId, nil
278}
279
280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
281 return GetPullsWithLimit(e, 0, filters...)
282}
283
284func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
285 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
286 if err != nil {
287 return nil, err
288 }
289 if pulls == nil {
290 return nil, sql.ErrNoRows
291 }
292
293 return pulls[0], nil
294}
295
296// mapping from pull -> pull submissions
297func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
298 var conditions []string
299 var args []any
300 for _, filter := range filters {
301 conditions = append(conditions, filter.Condition())
302 args = append(args, filter.Arg()...)
303 }
304
305 whereClause := ""
306 if conditions != nil {
307 whereClause = " where " + strings.Join(conditions, " and ")
308 }
309
310 query := fmt.Sprintf(`
311 select
312 id,
313 pull_at,
314 round_number,
315 patch,
316 combined,
317 created,
318 source_rev
319 from
320 pull_submissions
321 %s
322 order by
323 round_number asc
324 `, whereClause)
325
326 rows, err := e.Query(query, args...)
327 if err != nil {
328 return nil, err
329 }
330 defer rows.Close()
331
332 submissionMap := make(map[int]*models.PullSubmission)
333
334 for rows.Next() {
335 var submission models.PullSubmission
336 var submissionCreatedStr string
337 var submissionSourceRev, submissionCombined sql.NullString
338 err := rows.Scan(
339 &submission.ID,
340 &submission.PullAt,
341 &submission.RoundNumber,
342 &submission.Patch,
343 &submissionCombined,
344 &submissionCreatedStr,
345 &submissionSourceRev,
346 )
347 if err != nil {
348 return nil, err
349 }
350
351 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
352 submission.Created = t
353 }
354
355 if submissionSourceRev.Valid {
356 submission.SourceRev = submissionSourceRev.String
357 }
358
359 if submissionCombined.Valid {
360 submission.Combined = submissionCombined.String
361 }
362
363 submissionMap[submission.ID] = &submission
364 }
365
366 if err := rows.Err(); err != nil {
367 return nil, err
368 }
369
370 // Get comments for all submissions using GetPullComments
371 submissionIds := slices.Collect(maps.Keys(submissionMap))
372 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
373 if err != nil {
374 return nil, err
375 }
376 for _, comment := range comments {
377 if submission, ok := submissionMap[comment.SubmissionId]; ok {
378 submission.Comments = append(submission.Comments, comment)
379 }
380 }
381
382 // group the submissions by pull_at
383 m := make(map[syntax.ATURI][]*models.PullSubmission)
384 for _, s := range submissionMap {
385 m[s.PullAt] = append(m[s.PullAt], s)
386 }
387
388 // sort each one by round number
389 for _, s := range m {
390 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
391 return cmp.Compare(a.RoundNumber, b.RoundNumber)
392 })
393 }
394
395 return m, nil
396}
397
398func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
399 var conditions []string
400 var args []any
401 for _, filter := range filters {
402 conditions = append(conditions, filter.Condition())
403 args = append(args, filter.Arg()...)
404 }
405
406 whereClause := ""
407 if conditions != nil {
408 whereClause = " where " + strings.Join(conditions, " and ")
409 }
410
411 query := fmt.Sprintf(`
412 select
413 id,
414 pull_id,
415 submission_id,
416 repo_at,
417 owner_did,
418 comment_at,
419 body,
420 created
421 from
422 pull_comments
423 %s
424 order by
425 created asc
426 `, whereClause)
427
428 rows, err := e.Query(query, args...)
429 if err != nil {
430 return nil, err
431 }
432 defer rows.Close()
433
434 var comments []models.PullComment
435 for rows.Next() {
436 var comment models.PullComment
437 var createdAt string
438 err := rows.Scan(
439 &comment.ID,
440 &comment.PullId,
441 &comment.SubmissionId,
442 &comment.RepoAt,
443 &comment.OwnerDid,
444 &comment.CommentAt,
445 &comment.Body,
446 &createdAt,
447 )
448 if err != nil {
449 return nil, err
450 }
451
452 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
453 comment.Created = t
454 }
455
456 comments = append(comments, comment)
457 }
458
459 if err := rows.Err(); err != nil {
460 return nil, err
461 }
462
463 return comments, nil
464}
465
466// timeframe here is directly passed into the sql query filter, and any
467// timeframe in the past should be negative; e.g.: "-3 months"
468func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
469 var pulls []models.Pull
470
471 rows, err := e.Query(`
472 select
473 p.owner_did,
474 p.repo_at,
475 p.pull_id,
476 p.created,
477 p.title,
478 p.state,
479 r.did,
480 r.name,
481 r.knot,
482 r.rkey,
483 r.created
484 from
485 pulls p
486 join
487 repos r on p.repo_at = r.at_uri
488 where
489 p.owner_did = ? and p.created >= date ('now', ?)
490 order by
491 p.created desc`, did, timeframe)
492 if err != nil {
493 return nil, err
494 }
495 defer rows.Close()
496
497 for rows.Next() {
498 var pull models.Pull
499 var repo models.Repo
500 var pullCreatedAt, repoCreatedAt string
501 err := rows.Scan(
502 &pull.OwnerDid,
503 &pull.RepoAt,
504 &pull.PullId,
505 &pullCreatedAt,
506 &pull.Title,
507 &pull.State,
508 &repo.Did,
509 &repo.Name,
510 &repo.Knot,
511 &repo.Rkey,
512 &repoCreatedAt,
513 )
514 if err != nil {
515 return nil, err
516 }
517
518 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
519 if err != nil {
520 return nil, err
521 }
522 pull.Created = pullCreatedTime
523
524 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
525 if err != nil {
526 return nil, err
527 }
528 repo.Created = repoCreatedTime
529
530 pull.Repo = &repo
531
532 pulls = append(pulls, pull)
533 }
534
535 if err := rows.Err(); err != nil {
536 return nil, err
537 }
538
539 return pulls, nil
540}
541
542func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
543 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
544 res, err := e.Exec(
545 query,
546 comment.OwnerDid,
547 comment.RepoAt,
548 comment.SubmissionId,
549 comment.CommentAt,
550 comment.PullId,
551 comment.Body,
552 )
553 if err != nil {
554 return 0, err
555 }
556
557 i, err := res.LastInsertId()
558 if err != nil {
559 return 0, err
560 }
561
562 return i, nil
563}
564
565func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
566 _, err := e.Exec(
567 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
568 pullState,
569 repoAt,
570 pullId,
571 models.PullDeleted, // only update state of non-deleted pulls
572 models.PullMerged, // only update state of non-merged pulls
573 )
574 return err
575}
576
577func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
578 err := SetPullState(e, repoAt, pullId, models.PullClosed)
579 return err
580}
581
582func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
583 err := SetPullState(e, repoAt, pullId, models.PullOpen)
584 return err
585}
586
587func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
588 err := SetPullState(e, repoAt, pullId, models.PullMerged)
589 return err
590}
591
592func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
593 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
594 return err
595}
596
597func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
598 _, err := e.Exec(`
599 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
600 values (?, ?, ?, ?, ?)
601 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
602
603 return err
604}
605
606func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
607 var conditions []string
608 var args []any
609
610 args = append(args, parentChangeId)
611
612 for _, filter := range filters {
613 conditions = append(conditions, filter.Condition())
614 args = append(args, filter.Arg()...)
615 }
616
617 whereClause := ""
618 if conditions != nil {
619 whereClause = " where " + strings.Join(conditions, " and ")
620 }
621
622 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
623 _, err := e.Exec(query, args...)
624
625 return err
626}
627
628// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
629// otherwise submissions are immutable
630func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
631 var conditions []string
632 var args []any
633
634 args = append(args, sourceRev)
635 args = append(args, newPatch)
636
637 for _, filter := range filters {
638 conditions = append(conditions, filter.Condition())
639 args = append(args, filter.Arg()...)
640 }
641
642 whereClause := ""
643 if conditions != nil {
644 whereClause = " where " + strings.Join(conditions, " and ")
645 }
646
647 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
648 _, err := e.Exec(query, args...)
649
650 return err
651}
652
653func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
654 row := e.QueryRow(`
655 select
656 count(case when state = ? then 1 end) as open_count,
657 count(case when state = ? then 1 end) as merged_count,
658 count(case when state = ? then 1 end) as closed_count,
659 count(case when state = ? then 1 end) as deleted_count
660 from pulls
661 where repo_at = ?`,
662 models.PullOpen,
663 models.PullMerged,
664 models.PullClosed,
665 models.PullDeleted,
666 repoAt,
667 )
668
669 var count models.PullCount
670 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
671 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
672 }
673
674 return count, nil
675}
676
677// change-id parent-change-id
678//
679// 4 w ,-------- z (TOP)
680// 3 z <----',------- y
681// 2 y <-----',------ x
682// 1 x <------' nil (BOT)
683//
684// `w` is parent of none, so it is the top of the stack
685func GetStack(e Execer, stackId string) (models.Stack, error) {
686 unorderedPulls, err := GetPulls(
687 e,
688 FilterEq("stack_id", stackId),
689 FilterNotEq("state", models.PullDeleted),
690 )
691 if err != nil {
692 return nil, err
693 }
694 // map of parent-change-id to pull
695 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
696 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
697 for _, p := range unorderedPulls {
698 changeIdMap[p.ChangeId] = p
699 if p.ParentChangeId != "" {
700 parentMap[p.ParentChangeId] = p
701 }
702 }
703
704 // the top of the stack is the pull that is not a parent of any pull
705 var topPull *models.Pull
706 for _, maybeTop := range unorderedPulls {
707 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
708 topPull = maybeTop
709 break
710 }
711 }
712
713 pulls := []*models.Pull{}
714 for {
715 pulls = append(pulls, topPull)
716 if topPull.ParentChangeId != "" {
717 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
718 topPull = next
719 } else {
720 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
721 }
722 } else {
723 break
724 }
725 }
726
727 return pulls, nil
728}
729
730func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
731 pulls, err := GetPulls(
732 e,
733 FilterEq("stack_id", stackId),
734 FilterEq("state", models.PullDeleted),
735 )
736 if err != nil {
737 return nil, err
738 }
739
740 return pulls, nil
741}