1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.org/core/appview/models"
13)
14
15func NewPull(tx *sql.Tx, pull *models.Pull) error {
16 _, err := tx.Exec(`
17 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
18 values (?, 1)
19 `, pull.RepoAt)
20 if err != nil {
21 return err
22 }
23
24 var nextId int
25 err = tx.QueryRow(`
26 update repo_pull_seqs
27 set next_pull_id = next_pull_id + 1
28 where repo_at = ?
29 returning next_pull_id - 1
30 `, pull.RepoAt).Scan(&nextId)
31 if err != nil {
32 return err
33 }
34
35 pull.PullId = nextId
36 pull.State = models.PullOpen
37
38 var sourceBranch, sourceRepoAt *string
39 if pull.PullSource != nil {
40 sourceBranch = &pull.PullSource.Branch
41 if pull.PullSource.RepoAt != nil {
42 x := pull.PullSource.RepoAt.String()
43 sourceRepoAt = &x
44 }
45 }
46
47 var stackId, changeId, parentChangeId *string
48 if pull.StackId != "" {
49 stackId = &pull.StackId
50 }
51 if pull.ChangeId != "" {
52 changeId = &pull.ChangeId
53 }
54 if pull.ParentChangeId != "" {
55 parentChangeId = &pull.ParentChangeId
56 }
57
58 result, err := tx.Exec(
59 `
60 insert into pulls (
61 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
62 )
63 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
64 pull.RepoAt,
65 pull.OwnerDid,
66 pull.PullId,
67 pull.Title,
68 pull.TargetBranch,
69 pull.Body,
70 pull.Rkey,
71 pull.State,
72 sourceBranch,
73 sourceRepoAt,
74 stackId,
75 changeId,
76 parentChangeId,
77 )
78 if err != nil {
79 return err
80 }
81
82 // Set the database primary key ID
83 id, err := result.LastInsertId()
84 if err != nil {
85 return err
86 }
87 pull.ID = int(id)
88
89 _, err = tx.Exec(`
90 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
91 values (?, ?, ?, ?, ?)
92 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
93 return err
94}
95
96func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
97 pull, err := GetPull(e, repoAt, pullId)
98 if err != nil {
99 return "", err
100 }
101 return pull.PullAt(), err
102}
103
104func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
105 var pullId int
106 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
107 return pullId - 1, err
108}
109
110func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
111 pulls := make(map[int]*models.Pull)
112
113 var conditions []string
114 var args []any
115 for _, filter := range filters {
116 conditions = append(conditions, filter.Condition())
117 args = append(args, filter.Arg()...)
118 }
119
120 whereClause := ""
121 if conditions != nil {
122 whereClause = " where " + strings.Join(conditions, " and ")
123 }
124 limitClause := ""
125 if limit != 0 {
126 limitClause = fmt.Sprintf(" limit %d ", limit)
127 }
128
129 query := fmt.Sprintf(`
130 select
131 id,
132 owner_did,
133 repo_at,
134 pull_id,
135 created,
136 title,
137 state,
138 target_branch,
139 body,
140 rkey,
141 source_branch,
142 source_repo_at,
143 stack_id,
144 change_id,
145 parent_change_id
146 from
147 pulls
148 %s
149 order by
150 created desc
151 %s
152 `, whereClause, limitClause)
153
154 rows, err := e.Query(query, args...)
155 if err != nil {
156 return nil, err
157 }
158 defer rows.Close()
159
160 for rows.Next() {
161 var pull models.Pull
162 var createdAt string
163 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
164 err := rows.Scan(
165 &pull.ID,
166 &pull.OwnerDid,
167 &pull.RepoAt,
168 &pull.PullId,
169 &createdAt,
170 &pull.Title,
171 &pull.State,
172 &pull.TargetBranch,
173 &pull.Body,
174 &pull.Rkey,
175 &sourceBranch,
176 &sourceRepoAt,
177 &stackId,
178 &changeId,
179 &parentChangeId,
180 )
181 if err != nil {
182 return nil, err
183 }
184
185 createdTime, err := time.Parse(time.RFC3339, createdAt)
186 if err != nil {
187 return nil, err
188 }
189 pull.Created = createdTime
190
191 if sourceBranch.Valid {
192 pull.PullSource = &models.PullSource{
193 Branch: sourceBranch.String,
194 }
195 if sourceRepoAt.Valid {
196 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
197 if err != nil {
198 return nil, err
199 }
200 pull.PullSource.RepoAt = &sourceRepoAtParsed
201 }
202 }
203
204 if stackId.Valid {
205 pull.StackId = stackId.String
206 }
207 if changeId.Valid {
208 pull.ChangeId = changeId.String
209 }
210 if parentChangeId.Valid {
211 pull.ParentChangeId = parentChangeId.String
212 }
213
214 pulls[pull.PullId] = &pull
215 }
216
217 // get latest round no. for each pull
218 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
219 submissionsQuery := fmt.Sprintf(`
220 select
221 id, pull_id, round_number, patch, created, source_rev
222 from
223 pull_submissions
224 where
225 repo_at in (%s) and pull_id in (%s)
226 `, inClause, inClause)
227
228 args = make([]any, len(pulls)*2)
229 idx := 0
230 for _, p := range pulls {
231 args[idx] = p.RepoAt
232 idx += 1
233 }
234 for _, p := range pulls {
235 args[idx] = p.PullId
236 idx += 1
237 }
238 submissionsRows, err := e.Query(submissionsQuery, args...)
239 if err != nil {
240 return nil, err
241 }
242 defer submissionsRows.Close()
243
244 for submissionsRows.Next() {
245 var s models.PullSubmission
246 var sourceRev sql.NullString
247 var createdAt string
248 err := submissionsRows.Scan(
249 &s.ID,
250 &s.PullId,
251 &s.RoundNumber,
252 &s.Patch,
253 &createdAt,
254 &sourceRev,
255 )
256 if err != nil {
257 return nil, err
258 }
259
260 createdTime, err := time.Parse(time.RFC3339, createdAt)
261 if err != nil {
262 return nil, err
263 }
264 s.Created = createdTime
265
266 if sourceRev.Valid {
267 s.SourceRev = sourceRev.String
268 }
269
270 if p, ok := pulls[s.PullId]; ok {
271 p.Submissions = make([]*models.PullSubmission, s.RoundNumber+1)
272 p.Submissions[s.RoundNumber] = &s
273 }
274 }
275 if err := rows.Err(); err != nil {
276 return nil, err
277 }
278
279 // get comment count on latest submission on each pull
280 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
281 commentsQuery := fmt.Sprintf(`
282 select
283 count(id), pull_id
284 from
285 pull_comments
286 where
287 submission_id in (%s)
288 group by
289 submission_id
290 `, inClause)
291
292 args = []any{}
293 for _, p := range pulls {
294 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
295 }
296 commentsRows, err := e.Query(commentsQuery, args...)
297 if err != nil {
298 return nil, err
299 }
300 defer commentsRows.Close()
301
302 for commentsRows.Next() {
303 var commentCount, pullId int
304 err := commentsRows.Scan(
305 &commentCount,
306 &pullId,
307 )
308 if err != nil {
309 return nil, err
310 }
311 if p, ok := pulls[pullId]; ok {
312 p.Submissions[p.LastRoundNumber()].Comments = make([]models.PullComment, commentCount)
313 }
314 }
315 if err := rows.Err(); err != nil {
316 return nil, err
317 }
318
319 orderedByPullId := []*models.Pull{}
320 for _, p := range pulls {
321 orderedByPullId = append(orderedByPullId, p)
322 }
323 sort.Slice(orderedByPullId, func(i, j int) bool {
324 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
325 })
326
327 return orderedByPullId, nil
328}
329
330func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
331 return GetPullsWithLimit(e, 0, filters...)
332}
333
334func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
335 query := `
336 select
337 id,
338 owner_did,
339 pull_id,
340 created,
341 title,
342 state,
343 target_branch,
344 repo_at,
345 body,
346 rkey,
347 source_branch,
348 source_repo_at,
349 stack_id,
350 change_id,
351 parent_change_id
352 from
353 pulls
354 where
355 repo_at = ? and pull_id = ?
356 `
357 row := e.QueryRow(query, repoAt, pullId)
358
359 var pull models.Pull
360 var createdAt string
361 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
362 err := row.Scan(
363 &pull.ID,
364 &pull.OwnerDid,
365 &pull.PullId,
366 &createdAt,
367 &pull.Title,
368 &pull.State,
369 &pull.TargetBranch,
370 &pull.RepoAt,
371 &pull.Body,
372 &pull.Rkey,
373 &sourceBranch,
374 &sourceRepoAt,
375 &stackId,
376 &changeId,
377 &parentChangeId,
378 )
379 if err != nil {
380 return nil, err
381 }
382
383 createdTime, err := time.Parse(time.RFC3339, createdAt)
384 if err != nil {
385 return nil, err
386 }
387 pull.Created = createdTime
388
389 // populate source
390 if sourceBranch.Valid {
391 pull.PullSource = &models.PullSource{
392 Branch: sourceBranch.String,
393 }
394 if sourceRepoAt.Valid {
395 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
396 if err != nil {
397 return nil, err
398 }
399 pull.PullSource.RepoAt = &sourceRepoAtParsed
400 }
401 }
402
403 if stackId.Valid {
404 pull.StackId = stackId.String
405 }
406 if changeId.Valid {
407 pull.ChangeId = changeId.String
408 }
409 if parentChangeId.Valid {
410 pull.ParentChangeId = parentChangeId.String
411 }
412
413 submissionsQuery := `
414 select
415 id, pull_id, repo_at, round_number, patch, created, source_rev
416 from
417 pull_submissions
418 where
419 repo_at = ? and pull_id = ?
420 `
421 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
422 if err != nil {
423 return nil, err
424 }
425 defer submissionsRows.Close()
426
427 submissionsMap := make(map[int]*models.PullSubmission)
428
429 for submissionsRows.Next() {
430 var submission models.PullSubmission
431 var submissionCreatedStr string
432 var submissionSourceRev sql.NullString
433 err := submissionsRows.Scan(
434 &submission.ID,
435 &submission.PullId,
436 &submission.RepoAt,
437 &submission.RoundNumber,
438 &submission.Patch,
439 &submissionCreatedStr,
440 &submissionSourceRev,
441 )
442 if err != nil {
443 return nil, err
444 }
445
446 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
447 if err != nil {
448 return nil, err
449 }
450 submission.Created = submissionCreatedTime
451
452 if submissionSourceRev.Valid {
453 submission.SourceRev = submissionSourceRev.String
454 }
455
456 submissionsMap[submission.ID] = &submission
457 }
458 if err = submissionsRows.Close(); err != nil {
459 return nil, err
460 }
461 if len(submissionsMap) == 0 {
462 return &pull, nil
463 }
464
465 var args []any
466 for k := range submissionsMap {
467 args = append(args, k)
468 }
469 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
470 commentsQuery := fmt.Sprintf(`
471 select
472 id,
473 pull_id,
474 submission_id,
475 repo_at,
476 owner_did,
477 comment_at,
478 body,
479 created
480 from
481 pull_comments
482 where
483 submission_id IN (%s)
484 order by
485 created asc
486 `, inClause)
487 commentsRows, err := e.Query(commentsQuery, args...)
488 if err != nil {
489 return nil, err
490 }
491 defer commentsRows.Close()
492
493 for commentsRows.Next() {
494 var comment models.PullComment
495 var commentCreatedStr string
496 err := commentsRows.Scan(
497 &comment.ID,
498 &comment.PullId,
499 &comment.SubmissionId,
500 &comment.RepoAt,
501 &comment.OwnerDid,
502 &comment.CommentAt,
503 &comment.Body,
504 &commentCreatedStr,
505 )
506 if err != nil {
507 return nil, err
508 }
509
510 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
511 if err != nil {
512 return nil, err
513 }
514 comment.Created = commentCreatedTime
515
516 // Add the comment to its submission
517 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
518 submission.Comments = append(submission.Comments, comment)
519 }
520
521 }
522 if err = commentsRows.Err(); err != nil {
523 return nil, err
524 }
525
526 var pullSourceRepo *models.Repo
527 if pull.PullSource != nil {
528 if pull.PullSource.RepoAt != nil {
529 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
530 if err != nil {
531 log.Printf("failed to get repo by at uri: %v", err)
532 } else {
533 pull.PullSource.Repo = pullSourceRepo
534 }
535 }
536 }
537
538 pull.Submissions = make([]*models.PullSubmission, len(submissionsMap))
539 for _, submission := range submissionsMap {
540 pull.Submissions[submission.RoundNumber] = submission
541 }
542
543 return &pull, nil
544}
545
546// timeframe here is directly passed into the sql query filter, and any
547// timeframe in the past should be negative; e.g.: "-3 months"
548func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
549 var pulls []models.Pull
550
551 rows, err := e.Query(`
552 select
553 p.owner_did,
554 p.repo_at,
555 p.pull_id,
556 p.created,
557 p.title,
558 p.state,
559 r.did,
560 r.name,
561 r.knot,
562 r.rkey,
563 r.created
564 from
565 pulls p
566 join
567 repos r on p.repo_at = r.at_uri
568 where
569 p.owner_did = ? and p.created >= date ('now', ?)
570 order by
571 p.created desc`, did, timeframe)
572 if err != nil {
573 return nil, err
574 }
575 defer rows.Close()
576
577 for rows.Next() {
578 var pull models.Pull
579 var repo models.Repo
580 var pullCreatedAt, repoCreatedAt string
581 err := rows.Scan(
582 &pull.OwnerDid,
583 &pull.RepoAt,
584 &pull.PullId,
585 &pullCreatedAt,
586 &pull.Title,
587 &pull.State,
588 &repo.Did,
589 &repo.Name,
590 &repo.Knot,
591 &repo.Rkey,
592 &repoCreatedAt,
593 )
594 if err != nil {
595 return nil, err
596 }
597
598 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
599 if err != nil {
600 return nil, err
601 }
602 pull.Created = pullCreatedTime
603
604 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
605 if err != nil {
606 return nil, err
607 }
608 repo.Created = repoCreatedTime
609
610 pull.Repo = &repo
611
612 pulls = append(pulls, pull)
613 }
614
615 if err := rows.Err(); err != nil {
616 return nil, err
617 }
618
619 return pulls, nil
620}
621
622func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
623 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
624 res, err := e.Exec(
625 query,
626 comment.OwnerDid,
627 comment.RepoAt,
628 comment.SubmissionId,
629 comment.CommentAt,
630 comment.PullId,
631 comment.Body,
632 )
633 if err != nil {
634 return 0, err
635 }
636
637 i, err := res.LastInsertId()
638 if err != nil {
639 return 0, err
640 }
641
642 return i, nil
643}
644
645func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
646 _, err := e.Exec(
647 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
648 pullState,
649 repoAt,
650 pullId,
651 models.PullDeleted, // only update state of non-deleted pulls
652 models.PullMerged, // only update state of non-merged pulls
653 )
654 return err
655}
656
657func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
658 err := SetPullState(e, repoAt, pullId, models.PullClosed)
659 return err
660}
661
662func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
663 err := SetPullState(e, repoAt, pullId, models.PullOpen)
664 return err
665}
666
667func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
668 err := SetPullState(e, repoAt, pullId, models.PullMerged)
669 return err
670}
671
672func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
673 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
674 return err
675}
676
677func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error {
678 newRoundNumber := len(pull.Submissions)
679 _, err := e.Exec(`
680 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
681 values (?, ?, ?, ?, ?)
682 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
683
684 return err
685}
686
687func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
688 var conditions []string
689 var args []any
690
691 args = append(args, parentChangeId)
692
693 for _, filter := range filters {
694 conditions = append(conditions, filter.Condition())
695 args = append(args, filter.Arg()...)
696 }
697
698 whereClause := ""
699 if conditions != nil {
700 whereClause = " where " + strings.Join(conditions, " and ")
701 }
702
703 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
704 _, err := e.Exec(query, args...)
705
706 return err
707}
708
709// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
710// otherwise submissions are immutable
711func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
712 var conditions []string
713 var args []any
714
715 args = append(args, sourceRev)
716 args = append(args, newPatch)
717
718 for _, filter := range filters {
719 conditions = append(conditions, filter.Condition())
720 args = append(args, filter.Arg()...)
721 }
722
723 whereClause := ""
724 if conditions != nil {
725 whereClause = " where " + strings.Join(conditions, " and ")
726 }
727
728 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
729 _, err := e.Exec(query, args...)
730
731 return err
732}
733
734func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
735 row := e.QueryRow(`
736 select
737 count(case when state = ? then 1 end) as open_count,
738 count(case when state = ? then 1 end) as merged_count,
739 count(case when state = ? then 1 end) as closed_count,
740 count(case when state = ? then 1 end) as deleted_count
741 from pulls
742 where repo_at = ?`,
743 models.PullOpen,
744 models.PullMerged,
745 models.PullClosed,
746 models.PullDeleted,
747 repoAt,
748 )
749
750 var count models.PullCount
751 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
752 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
753 }
754
755 return count, nil
756}
757
758// change-id parent-change-id
759//
760// 4 w ,-------- z (TOP)
761// 3 z <----',------- y
762// 2 y <-----',------ x
763// 1 x <------' nil (BOT)
764//
765// `w` is parent of none, so it is the top of the stack
766func GetStack(e Execer, stackId string) (models.Stack, error) {
767 unorderedPulls, err := GetPulls(
768 e,
769 FilterEq("stack_id", stackId),
770 FilterNotEq("state", models.PullDeleted),
771 )
772 if err != nil {
773 return nil, err
774 }
775 // map of parent-change-id to pull
776 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
777 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
778 for _, p := range unorderedPulls {
779 changeIdMap[p.ChangeId] = p
780 if p.ParentChangeId != "" {
781 parentMap[p.ParentChangeId] = p
782 }
783 }
784
785 // the top of the stack is the pull that is not a parent of any pull
786 var topPull *models.Pull
787 for _, maybeTop := range unorderedPulls {
788 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
789 topPull = maybeTop
790 break
791 }
792 }
793
794 pulls := []*models.Pull{}
795 for {
796 pulls = append(pulls, topPull)
797 if topPull.ParentChangeId != "" {
798 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
799 topPull = next
800 } else {
801 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
802 }
803 } else {
804 break
805 }
806 }
807
808 return pulls, nil
809}
810
811func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
812 pulls, err := GetPulls(
813 e,
814 FilterEq("stack_id", stackId),
815 FilterEq("state", models.PullDeleted),
816 )
817 if err != nil {
818 return nil, err
819 }
820
821 return pulls, nil
822}