forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16)
17
18func NewPull(tx *sql.Tx, pull *models.Pull) error {
19 _, err := tx.Exec(`
20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
21 values (?, 1)
22 `, pull.RepoAt)
23 if err != nil {
24 return err
25 }
26
27 var nextId int
28 err = tx.QueryRow(`
29 update repo_pull_seqs
30 set next_pull_id = next_pull_id + 1
31 where repo_at = ?
32 returning next_pull_id - 1
33 `, pull.RepoAt).Scan(&nextId)
34 if err != nil {
35 return err
36 }
37
38 pull.PullId = nextId
39 pull.State = models.PullOpen
40
41 var sourceBranch, sourceRepoAt *string
42 if pull.PullSource != nil {
43 sourceBranch = &pull.PullSource.Branch
44 if pull.PullSource.RepoAt != nil {
45 x := pull.PullSource.RepoAt.String()
46 sourceRepoAt = &x
47 }
48 }
49
50 var stackId, changeId, parentChangeId *string
51 if pull.StackId != "" {
52 stackId = &pull.StackId
53 }
54 if pull.ChangeId != "" {
55 changeId = &pull.ChangeId
56 }
57 if pull.ParentChangeId != "" {
58 parentChangeId = &pull.ParentChangeId
59 }
60
61 result, err := tx.Exec(
62 `
63 insert into pulls (
64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
65 )
66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
67 pull.RepoAt,
68 pull.OwnerDid,
69 pull.PullId,
70 pull.Title,
71 pull.TargetBranch,
72 pull.Body,
73 pull.Rkey,
74 pull.State,
75 sourceBranch,
76 sourceRepoAt,
77 stackId,
78 changeId,
79 parentChangeId,
80 )
81 if err != nil {
82 return err
83 }
84
85 // Set the database primary key ID
86 id, err := result.LastInsertId()
87 if err != nil {
88 return err
89 }
90 pull.ID = int(id)
91
92 _, err = tx.Exec(`
93 insert into pull_submissions (pull_at, round_number, patch, source_rev)
94 values (?, ?, ?, ?)
95 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
96 return err
97}
98
99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
100 pull, err := GetPull(e, repoAt, pullId)
101 if err != nil {
102 return "", err
103 }
104 return pull.PullAt(), err
105}
106
107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
108 var pullId int
109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
110 return pullId - 1, err
111}
112
113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
114 pulls := make(map[syntax.ATURI]*models.Pull)
115
116 var conditions []string
117 var args []any
118 for _, filter := range filters {
119 conditions = append(conditions, filter.Condition())
120 args = append(args, filter.Arg()...)
121 }
122
123 whereClause := ""
124 if conditions != nil {
125 whereClause = " where " + strings.Join(conditions, " and ")
126 }
127 limitClause := ""
128 if limit != 0 {
129 limitClause = fmt.Sprintf(" limit %d ", limit)
130 }
131
132 query := fmt.Sprintf(`
133 select
134 id,
135 owner_did,
136 repo_at,
137 pull_id,
138 created,
139 title,
140 state,
141 target_branch,
142 body,
143 rkey,
144 source_branch,
145 source_repo_at,
146 stack_id,
147 change_id,
148 parent_change_id
149 from
150 pulls
151 %s
152 order by
153 created desc
154 %s
155 `, whereClause, limitClause)
156
157 rows, err := e.Query(query, args...)
158 if err != nil {
159 return nil, err
160 }
161 defer rows.Close()
162
163 for rows.Next() {
164 var pull models.Pull
165 var createdAt string
166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
167 err := rows.Scan(
168 &pull.ID,
169 &pull.OwnerDid,
170 &pull.RepoAt,
171 &pull.PullId,
172 &createdAt,
173 &pull.Title,
174 &pull.State,
175 &pull.TargetBranch,
176 &pull.Body,
177 &pull.Rkey,
178 &sourceBranch,
179 &sourceRepoAt,
180 &stackId,
181 &changeId,
182 &parentChangeId,
183 )
184 if err != nil {
185 return nil, err
186 }
187
188 createdTime, err := time.Parse(time.RFC3339, createdAt)
189 if err != nil {
190 return nil, err
191 }
192 pull.Created = createdTime
193
194 if sourceBranch.Valid {
195 pull.PullSource = &models.PullSource{
196 Branch: sourceBranch.String,
197 }
198 if sourceRepoAt.Valid {
199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
200 if err != nil {
201 return nil, err
202 }
203 pull.PullSource.RepoAt = &sourceRepoAtParsed
204 }
205 }
206
207 if stackId.Valid {
208 pull.StackId = stackId.String
209 }
210 if changeId.Valid {
211 pull.ChangeId = changeId.String
212 }
213 if parentChangeId.Valid {
214 pull.ParentChangeId = parentChangeId.String
215 }
216
217 pulls[pull.PullAt()] = &pull
218 }
219
220 var pullAts []syntax.ATURI
221 for _, p := range pulls {
222 pullAts = append(pullAts, p.PullAt())
223 }
224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
225 if err != nil {
226 return nil, fmt.Errorf("failed to get submissions: %w", err)
227 }
228
229 for pullAt, submissions := range submissionsMap {
230 if p, ok := pulls[pullAt]; ok {
231 p.Submissions = submissions
232 }
233 }
234
235 // collect allLabels for each issue
236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts))
237 if err != nil {
238 return nil, fmt.Errorf("failed to query labels: %w", err)
239 }
240 for pullAt, labels := range allLabels {
241 if p, ok := pulls[pullAt]; ok {
242 p.Labels = labels
243 }
244 }
245
246 // collect pull source for all pulls that need it
247 var sourceAts []syntax.ATURI
248 for _, p := range pulls {
249 if p.PullSource != nil && p.PullSource.RepoAt != nil {
250 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
251 }
252 }
253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts))
254 if err != nil && !errors.Is(err, sql.ErrNoRows) {
255 return nil, fmt.Errorf("failed to get source repos: %w", err)
256 }
257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
258 for _, r := range sourceRepos {
259 sourceRepoMap[r.RepoAt()] = &r
260 }
261 for _, p := range pulls {
262 if p.PullSource != nil && p.PullSource.RepoAt != nil {
263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
264 p.PullSource.Repo = sourceRepo
265 }
266 }
267 }
268
269 orderedByPullId := []*models.Pull{}
270 for _, p := range pulls {
271 orderedByPullId = append(orderedByPullId, p)
272 }
273 sort.Slice(orderedByPullId, func(i, j int) bool {
274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
275 })
276
277 return orderedByPullId, nil
278}
279
280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
281 return GetPullsWithLimit(e, 0, filters...)
282}
283
284func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
285 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
286 if err != nil {
287 return nil, err
288 }
289 if pulls == nil {
290 return nil, sql.ErrNoRows
291 }
292
293 return pulls[0], nil
294}
295
296// mapping from pull -> pull submissions
297func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
298 var conditions []string
299 var args []any
300 for _, filter := range filters {
301 conditions = append(conditions, filter.Condition())
302 args = append(args, filter.Arg()...)
303 }
304
305 whereClause := ""
306 if conditions != nil {
307 whereClause = " where " + strings.Join(conditions, " and ")
308 }
309
310 query := fmt.Sprintf(`
311 select
312 id,
313 pull_at,
314 round_number,
315 patch,
316 created,
317 source_rev
318 from
319 pull_submissions
320 %s
321 order by
322 round_number asc
323 `, whereClause)
324
325 rows, err := e.Query(query, args...)
326 if err != nil {
327 return nil, err
328 }
329 defer rows.Close()
330
331 submissionMap := make(map[int]*models.PullSubmission)
332
333 for rows.Next() {
334 var submission models.PullSubmission
335 var createdAt string
336 var sourceRev sql.NullString
337 err := rows.Scan(
338 &submission.ID,
339 &submission.PullAt,
340 &submission.RoundNumber,
341 &submission.Patch,
342 &createdAt,
343 &sourceRev,
344 )
345 if err != nil {
346 return nil, err
347 }
348
349 createdTime, err := time.Parse(time.RFC3339, createdAt)
350 if err != nil {
351 return nil, err
352 }
353 submission.Created = createdTime
354
355 if sourceRev.Valid {
356 submission.SourceRev = sourceRev.String
357 }
358
359 submissionMap[submission.ID] = &submission
360 }
361
362 if err := rows.Err(); err != nil {
363 return nil, err
364 }
365
366 // Get comments for all submissions using GetPullComments
367 submissionIds := slices.Collect(maps.Keys(submissionMap))
368 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
369 if err != nil {
370 return nil, err
371 }
372 for _, comment := range comments {
373 if submission, ok := submissionMap[comment.SubmissionId]; ok {
374 submission.Comments = append(submission.Comments, comment)
375 }
376 }
377
378 // group the submissions by pull_at
379 m := make(map[syntax.ATURI][]*models.PullSubmission)
380 for _, s := range submissionMap {
381 m[s.PullAt] = append(m[s.PullAt], s)
382 }
383
384 // sort each one by round number
385 for _, s := range m {
386 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
387 return cmp.Compare(a.RoundNumber, b.RoundNumber)
388 })
389 }
390
391 return m, nil
392}
393
394func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
395 var conditions []string
396 var args []any
397 for _, filter := range filters {
398 conditions = append(conditions, filter.Condition())
399 args = append(args, filter.Arg()...)
400 }
401
402 whereClause := ""
403 if conditions != nil {
404 whereClause = " where " + strings.Join(conditions, " and ")
405 }
406
407 query := fmt.Sprintf(`
408 select
409 id,
410 pull_id,
411 submission_id,
412 repo_at,
413 owner_did,
414 comment_at,
415 body,
416 created
417 from
418 pull_comments
419 %s
420 order by
421 created asc
422 `, whereClause)
423
424 rows, err := e.Query(query, args...)
425 if err != nil {
426 return nil, err
427 }
428 defer rows.Close()
429
430 var comments []models.PullComment
431 for rows.Next() {
432 var comment models.PullComment
433 var createdAt string
434 err := rows.Scan(
435 &comment.ID,
436 &comment.PullId,
437 &comment.SubmissionId,
438 &comment.RepoAt,
439 &comment.OwnerDid,
440 &comment.CommentAt,
441 &comment.Body,
442 &createdAt,
443 )
444 if err != nil {
445 return nil, err
446 }
447
448 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
449 comment.Created = t
450 }
451
452 comments = append(comments, comment)
453 }
454
455 if err := rows.Err(); err != nil {
456 return nil, err
457 }
458
459 return comments, nil
460}
461
462// timeframe here is directly passed into the sql query filter, and any
463// timeframe in the past should be negative; e.g.: "-3 months"
464func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
465 var pulls []models.Pull
466
467 rows, err := e.Query(`
468 select
469 p.owner_did,
470 p.repo_at,
471 p.pull_id,
472 p.created,
473 p.title,
474 p.state,
475 r.did,
476 r.name,
477 r.knot,
478 r.rkey,
479 r.created
480 from
481 pulls p
482 join
483 repos r on p.repo_at = r.at_uri
484 where
485 p.owner_did = ? and p.created >= date ('now', ?)
486 order by
487 p.created desc`, did, timeframe)
488 if err != nil {
489 return nil, err
490 }
491 defer rows.Close()
492
493 for rows.Next() {
494 var pull models.Pull
495 var repo models.Repo
496 var pullCreatedAt, repoCreatedAt string
497 err := rows.Scan(
498 &pull.OwnerDid,
499 &pull.RepoAt,
500 &pull.PullId,
501 &pullCreatedAt,
502 &pull.Title,
503 &pull.State,
504 &repo.Did,
505 &repo.Name,
506 &repo.Knot,
507 &repo.Rkey,
508 &repoCreatedAt,
509 )
510 if err != nil {
511 return nil, err
512 }
513
514 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
515 if err != nil {
516 return nil, err
517 }
518 pull.Created = pullCreatedTime
519
520 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
521 if err != nil {
522 return nil, err
523 }
524 repo.Created = repoCreatedTime
525
526 pull.Repo = &repo
527
528 pulls = append(pulls, pull)
529 }
530
531 if err := rows.Err(); err != nil {
532 return nil, err
533 }
534
535 return pulls, nil
536}
537
538func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
539 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
540 res, err := e.Exec(
541 query,
542 comment.OwnerDid,
543 comment.RepoAt,
544 comment.SubmissionId,
545 comment.CommentAt,
546 comment.PullId,
547 comment.Body,
548 )
549 if err != nil {
550 return 0, err
551 }
552
553 i, err := res.LastInsertId()
554 if err != nil {
555 return 0, err
556 }
557
558 return i, nil
559}
560
561func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
562 _, err := e.Exec(
563 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
564 pullState,
565 repoAt,
566 pullId,
567 models.PullDeleted, // only update state of non-deleted pulls
568 models.PullMerged, // only update state of non-merged pulls
569 )
570 return err
571}
572
573func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
574 err := SetPullState(e, repoAt, pullId, models.PullClosed)
575 return err
576}
577
578func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
579 err := SetPullState(e, repoAt, pullId, models.PullOpen)
580 return err
581}
582
583func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
584 err := SetPullState(e, repoAt, pullId, models.PullMerged)
585 return err
586}
587
588func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
589 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
590 return err
591}
592
593func ResubmitPull(e Execer, pull *models.Pull) error {
594 newPatch := pull.LatestPatch()
595 newSourceRev := pull.LatestSha()
596 newRoundNumber := len(pull.Submissions)
597 _, err := e.Exec(`
598 insert into pull_submissions (pull_at, round_number, patch, source_rev)
599 values (?, ?, ?, ?)
600 `, pull.PullAt(), newRoundNumber, newPatch, newSourceRev)
601
602 return err
603}
604
605func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
606 var conditions []string
607 var args []any
608
609 args = append(args, parentChangeId)
610
611 for _, filter := range filters {
612 conditions = append(conditions, filter.Condition())
613 args = append(args, filter.Arg()...)
614 }
615
616 whereClause := ""
617 if conditions != nil {
618 whereClause = " where " + strings.Join(conditions, " and ")
619 }
620
621 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
622 _, err := e.Exec(query, args...)
623
624 return err
625}
626
627// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
628// otherwise submissions are immutable
629func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
630 var conditions []string
631 var args []any
632
633 args = append(args, sourceRev)
634 args = append(args, newPatch)
635
636 for _, filter := range filters {
637 conditions = append(conditions, filter.Condition())
638 args = append(args, filter.Arg()...)
639 }
640
641 whereClause := ""
642 if conditions != nil {
643 whereClause = " where " + strings.Join(conditions, " and ")
644 }
645
646 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
647 _, err := e.Exec(query, args...)
648
649 return err
650}
651
652func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
653 row := e.QueryRow(`
654 select
655 count(case when state = ? then 1 end) as open_count,
656 count(case when state = ? then 1 end) as merged_count,
657 count(case when state = ? then 1 end) as closed_count,
658 count(case when state = ? then 1 end) as deleted_count
659 from pulls
660 where repo_at = ?`,
661 models.PullOpen,
662 models.PullMerged,
663 models.PullClosed,
664 models.PullDeleted,
665 repoAt,
666 )
667
668 var count models.PullCount
669 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
670 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
671 }
672
673 return count, nil
674}
675
676// change-id parent-change-id
677//
678// 4 w ,-------- z (TOP)
679// 3 z <----',------- y
680// 2 y <-----',------ x
681// 1 x <------' nil (BOT)
682//
683// `w` is parent of none, so it is the top of the stack
684func GetStack(e Execer, stackId string) (models.Stack, error) {
685 unorderedPulls, err := GetPulls(
686 e,
687 FilterEq("stack_id", stackId),
688 FilterNotEq("state", models.PullDeleted),
689 )
690 if err != nil {
691 return nil, err
692 }
693 // map of parent-change-id to pull
694 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
695 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
696 for _, p := range unorderedPulls {
697 changeIdMap[p.ChangeId] = p
698 if p.ParentChangeId != "" {
699 parentMap[p.ParentChangeId] = p
700 }
701 }
702
703 // the top of the stack is the pull that is not a parent of any pull
704 var topPull *models.Pull
705 for _, maybeTop := range unorderedPulls {
706 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
707 topPull = maybeTop
708 break
709 }
710 }
711
712 pulls := []*models.Pull{}
713 for {
714 pulls = append(pulls, topPull)
715 if topPull.ParentChangeId != "" {
716 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
717 topPull = next
718 } else {
719 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
720 }
721 } else {
722 break
723 }
724 }
725
726 return pulls, nil
727}
728
729func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
730 pulls, err := GetPulls(
731 e,
732 FilterEq("stack_id", stackId),
733 FilterEq("state", models.PullDeleted),
734 )
735 if err != nil {
736 return nil, err
737 }
738
739 return pulls, nil
740}