1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.org/core/appview/models"
13)
14
15func NewPull(tx *sql.Tx, pull *models.Pull) error {
16 _, err := tx.Exec(`
17 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
18 values (?, 1)
19 `, pull.RepoAt)
20 if err != nil {
21 return err
22 }
23
24 var nextId int
25 err = tx.QueryRow(`
26 update repo_pull_seqs
27 set next_pull_id = next_pull_id + 1
28 where repo_at = ?
29 returning next_pull_id - 1
30 `, pull.RepoAt).Scan(&nextId)
31 if err != nil {
32 return err
33 }
34
35 pull.PullId = nextId
36 pull.State = models.PullOpen
37
38 var sourceBranch, sourceRepoAt *string
39 if pull.PullSource != nil {
40 sourceBranch = &pull.PullSource.Branch
41 if pull.PullSource.RepoAt != nil {
42 x := pull.PullSource.RepoAt.String()
43 sourceRepoAt = &x
44 }
45 }
46
47 var stackId, changeId, parentChangeId *string
48 if pull.StackId != "" {
49 stackId = &pull.StackId
50 }
51 if pull.ChangeId != "" {
52 changeId = &pull.ChangeId
53 }
54 if pull.ParentChangeId != "" {
55 parentChangeId = &pull.ParentChangeId
56 }
57
58 _, err = tx.Exec(
59 `
60 insert into pulls (
61 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
62 )
63 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
64 pull.RepoAt,
65 pull.OwnerDid,
66 pull.PullId,
67 pull.Title,
68 pull.TargetBranch,
69 pull.Body,
70 pull.Rkey,
71 pull.State,
72 sourceBranch,
73 sourceRepoAt,
74 stackId,
75 changeId,
76 parentChangeId,
77 )
78 if err != nil {
79 return err
80 }
81
82 _, err = tx.Exec(`
83 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
84 values (?, ?, ?, ?, ?)
85 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
86 return err
87}
88
89func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
90 pull, err := GetPull(e, repoAt, pullId)
91 if err != nil {
92 return "", err
93 }
94 return pull.PullAt(), err
95}
96
97func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
98 var pullId int
99 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
100 return pullId - 1, err
101}
102
103func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
104 pulls := make(map[int]*models.Pull)
105
106 var conditions []string
107 var args []any
108 for _, filter := range filters {
109 conditions = append(conditions, filter.Condition())
110 args = append(args, filter.Arg()...)
111 }
112
113 whereClause := ""
114 if conditions != nil {
115 whereClause = " where " + strings.Join(conditions, " and ")
116 }
117 limitClause := ""
118 if limit != 0 {
119 limitClause = fmt.Sprintf(" limit %d ", limit)
120 }
121
122 query := fmt.Sprintf(`
123 select
124 owner_did,
125 repo_at,
126 pull_id,
127 created,
128 title,
129 state,
130 target_branch,
131 body,
132 rkey,
133 source_branch,
134 source_repo_at,
135 stack_id,
136 change_id,
137 parent_change_id
138 from
139 pulls
140 %s
141 order by
142 created desc
143 %s
144 `, whereClause, limitClause)
145
146 rows, err := e.Query(query, args...)
147 if err != nil {
148 return nil, err
149 }
150 defer rows.Close()
151
152 for rows.Next() {
153 var pull models.Pull
154 var createdAt string
155 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
156 err := rows.Scan(
157 &pull.OwnerDid,
158 &pull.RepoAt,
159 &pull.PullId,
160 &createdAt,
161 &pull.Title,
162 &pull.State,
163 &pull.TargetBranch,
164 &pull.Body,
165 &pull.Rkey,
166 &sourceBranch,
167 &sourceRepoAt,
168 &stackId,
169 &changeId,
170 &parentChangeId,
171 )
172 if err != nil {
173 return nil, err
174 }
175
176 createdTime, err := time.Parse(time.RFC3339, createdAt)
177 if err != nil {
178 return nil, err
179 }
180 pull.Created = createdTime
181
182 if sourceBranch.Valid {
183 pull.PullSource = &models.PullSource{
184 Branch: sourceBranch.String,
185 }
186 if sourceRepoAt.Valid {
187 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
188 if err != nil {
189 return nil, err
190 }
191 pull.PullSource.RepoAt = &sourceRepoAtParsed
192 }
193 }
194
195 if stackId.Valid {
196 pull.StackId = stackId.String
197 }
198 if changeId.Valid {
199 pull.ChangeId = changeId.String
200 }
201 if parentChangeId.Valid {
202 pull.ParentChangeId = parentChangeId.String
203 }
204
205 pulls[pull.PullId] = &pull
206 }
207
208 // get latest round no. for each pull
209 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
210 submissionsQuery := fmt.Sprintf(`
211 select
212 id, pull_id, round_number, patch, created, source_rev
213 from
214 pull_submissions
215 where
216 repo_at in (%s) and pull_id in (%s)
217 `, inClause, inClause)
218
219 args = make([]any, len(pulls)*2)
220 idx := 0
221 for _, p := range pulls {
222 args[idx] = p.RepoAt
223 idx += 1
224 }
225 for _, p := range pulls {
226 args[idx] = p.PullId
227 idx += 1
228 }
229 submissionsRows, err := e.Query(submissionsQuery, args...)
230 if err != nil {
231 return nil, err
232 }
233 defer submissionsRows.Close()
234
235 for submissionsRows.Next() {
236 var s models.PullSubmission
237 var sourceRev sql.NullString
238 var createdAt string
239 err := submissionsRows.Scan(
240 &s.ID,
241 &s.PullId,
242 &s.RoundNumber,
243 &s.Patch,
244 &createdAt,
245 &sourceRev,
246 )
247 if err != nil {
248 return nil, err
249 }
250
251 createdTime, err := time.Parse(time.RFC3339, createdAt)
252 if err != nil {
253 return nil, err
254 }
255 s.Created = createdTime
256
257 if sourceRev.Valid {
258 s.SourceRev = sourceRev.String
259 }
260
261 if p, ok := pulls[s.PullId]; ok {
262 p.Submissions = make([]*models.PullSubmission, s.RoundNumber+1)
263 p.Submissions[s.RoundNumber] = &s
264 }
265 }
266 if err := rows.Err(); err != nil {
267 return nil, err
268 }
269
270 // get comment count on latest submission on each pull
271 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
272 commentsQuery := fmt.Sprintf(`
273 select
274 count(id), pull_id
275 from
276 pull_comments
277 where
278 submission_id in (%s)
279 group by
280 submission_id
281 `, inClause)
282
283 args = []any{}
284 for _, p := range pulls {
285 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
286 }
287 commentsRows, err := e.Query(commentsQuery, args...)
288 if err != nil {
289 return nil, err
290 }
291 defer commentsRows.Close()
292
293 for commentsRows.Next() {
294 var commentCount, pullId int
295 err := commentsRows.Scan(
296 &commentCount,
297 &pullId,
298 )
299 if err != nil {
300 return nil, err
301 }
302 if p, ok := pulls[pullId]; ok {
303 p.Submissions[p.LastRoundNumber()].Comments = make([]models.PullComment, commentCount)
304 }
305 }
306 if err := rows.Err(); err != nil {
307 return nil, err
308 }
309
310 orderedByPullId := []*models.Pull{}
311 for _, p := range pulls {
312 orderedByPullId = append(orderedByPullId, p)
313 }
314 sort.Slice(orderedByPullId, func(i, j int) bool {
315 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
316 })
317
318 return orderedByPullId, nil
319}
320
321func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
322 return GetPullsWithLimit(e, 0, filters...)
323}
324
325func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
326 query := `
327 select
328 owner_did,
329 pull_id,
330 created,
331 title,
332 state,
333 target_branch,
334 repo_at,
335 body,
336 rkey,
337 source_branch,
338 source_repo_at,
339 stack_id,
340 change_id,
341 parent_change_id
342 from
343 pulls
344 where
345 repo_at = ? and pull_id = ?
346 `
347 row := e.QueryRow(query, repoAt, pullId)
348
349 var pull models.Pull
350 var createdAt string
351 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
352 err := row.Scan(
353 &pull.OwnerDid,
354 &pull.PullId,
355 &createdAt,
356 &pull.Title,
357 &pull.State,
358 &pull.TargetBranch,
359 &pull.RepoAt,
360 &pull.Body,
361 &pull.Rkey,
362 &sourceBranch,
363 &sourceRepoAt,
364 &stackId,
365 &changeId,
366 &parentChangeId,
367 )
368 if err != nil {
369 return nil, err
370 }
371
372 createdTime, err := time.Parse(time.RFC3339, createdAt)
373 if err != nil {
374 return nil, err
375 }
376 pull.Created = createdTime
377
378 // populate source
379 if sourceBranch.Valid {
380 pull.PullSource = &models.PullSource{
381 Branch: sourceBranch.String,
382 }
383 if sourceRepoAt.Valid {
384 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
385 if err != nil {
386 return nil, err
387 }
388 pull.PullSource.RepoAt = &sourceRepoAtParsed
389 }
390 }
391
392 if stackId.Valid {
393 pull.StackId = stackId.String
394 }
395 if changeId.Valid {
396 pull.ChangeId = changeId.String
397 }
398 if parentChangeId.Valid {
399 pull.ParentChangeId = parentChangeId.String
400 }
401
402 submissionsQuery := `
403 select
404 id, pull_id, repo_at, round_number, patch, created, source_rev
405 from
406 pull_submissions
407 where
408 repo_at = ? and pull_id = ?
409 `
410 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
411 if err != nil {
412 return nil, err
413 }
414 defer submissionsRows.Close()
415
416 submissionsMap := make(map[int]*models.PullSubmission)
417
418 for submissionsRows.Next() {
419 var submission models.PullSubmission
420 var submissionCreatedStr string
421 var submissionSourceRev sql.NullString
422 err := submissionsRows.Scan(
423 &submission.ID,
424 &submission.PullId,
425 &submission.RepoAt,
426 &submission.RoundNumber,
427 &submission.Patch,
428 &submissionCreatedStr,
429 &submissionSourceRev,
430 )
431 if err != nil {
432 return nil, err
433 }
434
435 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
436 if err != nil {
437 return nil, err
438 }
439 submission.Created = submissionCreatedTime
440
441 if submissionSourceRev.Valid {
442 submission.SourceRev = submissionSourceRev.String
443 }
444
445 submissionsMap[submission.ID] = &submission
446 }
447 if err = submissionsRows.Close(); err != nil {
448 return nil, err
449 }
450 if len(submissionsMap) == 0 {
451 return &pull, nil
452 }
453
454 var args []any
455 for k := range submissionsMap {
456 args = append(args, k)
457 }
458 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
459 commentsQuery := fmt.Sprintf(`
460 select
461 id,
462 pull_id,
463 submission_id,
464 repo_at,
465 owner_did,
466 comment_at,
467 body,
468 created
469 from
470 pull_comments
471 where
472 submission_id IN (%s)
473 order by
474 created asc
475 `, inClause)
476 commentsRows, err := e.Query(commentsQuery, args...)
477 if err != nil {
478 return nil, err
479 }
480 defer commentsRows.Close()
481
482 for commentsRows.Next() {
483 var comment models.PullComment
484 var commentCreatedStr string
485 err := commentsRows.Scan(
486 &comment.ID,
487 &comment.PullId,
488 &comment.SubmissionId,
489 &comment.RepoAt,
490 &comment.OwnerDid,
491 &comment.CommentAt,
492 &comment.Body,
493 &commentCreatedStr,
494 )
495 if err != nil {
496 return nil, err
497 }
498
499 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
500 if err != nil {
501 return nil, err
502 }
503 comment.Created = commentCreatedTime
504
505 // Add the comment to its submission
506 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
507 submission.Comments = append(submission.Comments, comment)
508 }
509
510 }
511 if err = commentsRows.Err(); err != nil {
512 return nil, err
513 }
514
515 var pullSourceRepo *models.Repo
516 if pull.PullSource != nil {
517 if pull.PullSource.RepoAt != nil {
518 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
519 if err != nil {
520 log.Printf("failed to get repo by at uri: %v", err)
521 } else {
522 pull.PullSource.Repo = pullSourceRepo
523 }
524 }
525 }
526
527 pull.Submissions = make([]*models.PullSubmission, len(submissionsMap))
528 for _, submission := range submissionsMap {
529 pull.Submissions[submission.RoundNumber] = submission
530 }
531
532 return &pull, nil
533}
534
535// timeframe here is directly passed into the sql query filter, and any
536// timeframe in the past should be negative; e.g.: "-3 months"
537func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
538 var pulls []models.Pull
539
540 rows, err := e.Query(`
541 select
542 p.owner_did,
543 p.repo_at,
544 p.pull_id,
545 p.created,
546 p.title,
547 p.state,
548 r.did,
549 r.name,
550 r.knot,
551 r.rkey,
552 r.created
553 from
554 pulls p
555 join
556 repos r on p.repo_at = r.at_uri
557 where
558 p.owner_did = ? and p.created >= date ('now', ?)
559 order by
560 p.created desc`, did, timeframe)
561 if err != nil {
562 return nil, err
563 }
564 defer rows.Close()
565
566 for rows.Next() {
567 var pull models.Pull
568 var repo models.Repo
569 var pullCreatedAt, repoCreatedAt string
570 err := rows.Scan(
571 &pull.OwnerDid,
572 &pull.RepoAt,
573 &pull.PullId,
574 &pullCreatedAt,
575 &pull.Title,
576 &pull.State,
577 &repo.Did,
578 &repo.Name,
579 &repo.Knot,
580 &repo.Rkey,
581 &repoCreatedAt,
582 )
583 if err != nil {
584 return nil, err
585 }
586
587 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
588 if err != nil {
589 return nil, err
590 }
591 pull.Created = pullCreatedTime
592
593 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
594 if err != nil {
595 return nil, err
596 }
597 repo.Created = repoCreatedTime
598
599 pull.Repo = &repo
600
601 pulls = append(pulls, pull)
602 }
603
604 if err := rows.Err(); err != nil {
605 return nil, err
606 }
607
608 return pulls, nil
609}
610
611func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
612 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
613 res, err := e.Exec(
614 query,
615 comment.OwnerDid,
616 comment.RepoAt,
617 comment.SubmissionId,
618 comment.CommentAt,
619 comment.PullId,
620 comment.Body,
621 )
622 if err != nil {
623 return 0, err
624 }
625
626 i, err := res.LastInsertId()
627 if err != nil {
628 return 0, err
629 }
630
631 return i, nil
632}
633
634func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
635 _, err := e.Exec(
636 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
637 pullState,
638 repoAt,
639 pullId,
640 models.PullDeleted, // only update state of non-deleted pulls
641 models.PullMerged, // only update state of non-merged pulls
642 )
643 return err
644}
645
646func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
647 err := SetPullState(e, repoAt, pullId, models.PullClosed)
648 return err
649}
650
651func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
652 err := SetPullState(e, repoAt, pullId, models.PullOpen)
653 return err
654}
655
656func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
657 err := SetPullState(e, repoAt, pullId, models.PullMerged)
658 return err
659}
660
661func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
662 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
663 return err
664}
665
666func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error {
667 newRoundNumber := len(pull.Submissions)
668 _, err := e.Exec(`
669 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
670 values (?, ?, ?, ?, ?)
671 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
672
673 return err
674}
675
676func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
677 var conditions []string
678 var args []any
679
680 args = append(args, parentChangeId)
681
682 for _, filter := range filters {
683 conditions = append(conditions, filter.Condition())
684 args = append(args, filter.Arg()...)
685 }
686
687 whereClause := ""
688 if conditions != nil {
689 whereClause = " where " + strings.Join(conditions, " and ")
690 }
691
692 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
693 _, err := e.Exec(query, args...)
694
695 return err
696}
697
698// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
699// otherwise submissions are immutable
700func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
701 var conditions []string
702 var args []any
703
704 args = append(args, sourceRev)
705 args = append(args, newPatch)
706
707 for _, filter := range filters {
708 conditions = append(conditions, filter.Condition())
709 args = append(args, filter.Arg()...)
710 }
711
712 whereClause := ""
713 if conditions != nil {
714 whereClause = " where " + strings.Join(conditions, " and ")
715 }
716
717 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
718 _, err := e.Exec(query, args...)
719
720 return err
721}
722
723func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
724 row := e.QueryRow(`
725 select
726 count(case when state = ? then 1 end) as open_count,
727 count(case when state = ? then 1 end) as merged_count,
728 count(case when state = ? then 1 end) as closed_count,
729 count(case when state = ? then 1 end) as deleted_count
730 from pulls
731 where repo_at = ?`,
732 models.PullOpen,
733 models.PullMerged,
734 models.PullClosed,
735 models.PullDeleted,
736 repoAt,
737 )
738
739 var count models.PullCount
740 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
741 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
742 }
743
744 return count, nil
745}
746
747// change-id parent-change-id
748//
749// 4 w ,-------- z (TOP)
750// 3 z <----',------- y
751// 2 y <-----',------ x
752// 1 x <------' nil (BOT)
753//
754// `w` is parent of none, so it is the top of the stack
755func GetStack(e Execer, stackId string) (models.Stack, error) {
756 unorderedPulls, err := GetPulls(
757 e,
758 FilterEq("stack_id", stackId),
759 FilterNotEq("state", models.PullDeleted),
760 )
761 if err != nil {
762 return nil, err
763 }
764 // map of parent-change-id to pull
765 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
766 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
767 for _, p := range unorderedPulls {
768 changeIdMap[p.ChangeId] = p
769 if p.ParentChangeId != "" {
770 parentMap[p.ParentChangeId] = p
771 }
772 }
773
774 // the top of the stack is the pull that is not a parent of any pull
775 var topPull *models.Pull
776 for _, maybeTop := range unorderedPulls {
777 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
778 topPull = maybeTop
779 break
780 }
781 }
782
783 pulls := []*models.Pull{}
784 for {
785 pulls = append(pulls, topPull)
786 if topPull.ParentChangeId != "" {
787 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
788 topPull = next
789 } else {
790 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
791 }
792 } else {
793 break
794 }
795 }
796
797 return pulls, nil
798}
799
800func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
801 pulls, err := GetPulls(
802 e,
803 FilterEq("stack_id", stackId),
804 FilterEq("state", models.PullDeleted),
805 )
806 if err != nil {
807 return nil, err
808 }
809
810 return pulls, nil
811}