1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/patchutil"
14 "tangled.sh/tangled.sh/core/types"
15)
16
// PullState is the lifecycle state of a pull request. Its integer value is
// what gets written to the pulls.state column, so the numbering below must
// stay stable.
type PullState int

const (
	PullClosed PullState = 0
	PullOpen   PullState = 1
	PullMerged PullState = 2
)

// String returns the lowercase name of the state. Any value other than
// PullOpen or PullMerged (including PullClosed) is reported as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	default:
		return "closed"
	}
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }
47
48type Pull struct {
49 // ids
50 ID int
51 PullId int
52
53 // at ids
54 RepoAt syntax.ATURI
55 OwnerDid string
56 Rkey string
57 PullAt syntax.ATURI
58
59 // content
60 Title string
61 Body string
62 TargetBranch string
63 State PullState
64 Submissions []*PullSubmission
65
66 // meta
67 Created time.Time
68 PullSource *PullSource
69
70 // optionally, populate this when querying for reverse mappings
71 Repo *Repo
72}
73
74type PullSource struct {
75 Branch string
76 RepoAt *syntax.ATURI
77
78 // optionally populate this for reverse mappings
79 Repo *Repo
80}
81
82type PullSubmission struct {
83 // ids
84 ID int
85 PullId int
86
87 // at ids
88 RepoAt syntax.ATURI
89
90 // content
91 RoundNumber int
92 Patch string
93 Comments []PullComment
94 SourceRev string // include the rev that was used to create this submission: only for branch PRs
95
96 // meta
97 Created time.Time
98}
99
100type PullComment struct {
101 // ids
102 ID int
103 PullId int
104 SubmissionId int
105
106 // at ids
107 RepoAt string
108 OwnerDid string
109 CommentAt string
110
111 // content
112 Body string
113
114 // meta
115 Created time.Time
116}
117
118func (p *Pull) LatestPatch() string {
119 latestSubmission := p.Submissions[p.LastRoundNumber()]
120 return latestSubmission.Patch
121}
122
123func (p *Pull) LastRoundNumber() int {
124 return len(p.Submissions) - 1
125}
126
127func (p *Pull) IsPatchBased() bool {
128 return p.PullSource == nil
129}
130
131func (p *Pull) IsBranchBased() bool {
132 if p.PullSource != nil {
133 if p.PullSource.RepoAt != nil {
134 return p.PullSource.RepoAt == &p.RepoAt
135 } else {
136 // no repo specified
137 return true
138 }
139 }
140 return false
141}
142
143func (p *Pull) IsForkBased() bool {
144 if p.PullSource != nil {
145 if p.PullSource.RepoAt != nil {
146 // make sure repos are different
147 return p.PullSource.RepoAt != &p.RepoAt
148 }
149 }
150 return false
151}
152
// AsDiff parses the submission's patch into per-file diffs.
//
// A git-format-patch series is split into its individual patches, whose
// file lists are merged with patchutil.CombineDiff. Any other patch is
// parsed directly with gitdiff.Parse. targetBranch is currently unused.
func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) {
	if !patchutil.IsFormatPatch(s.Patch) {
		files, _, err := gitdiff.Parse(strings.NewReader(s.Patch))
		if err != nil {
			return nil, err
		}
		return files, nil
	}

	patches, err := patchutil.ExtractPatches(s.Patch)
	if err != nil {
		return nil, err
	}

	var perPatch [][]*gitdiff.File
	for _, fp := range patches {
		perPatch = append(perPatch, fp.Files)
	}
	return patchutil.CombineDiff(perPatch...), nil
}
179
180func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
181 diffs, err := s.AsDiff(targetBranch)
182 if err != nil {
183 log.Println(err)
184 }
185
186 nd := types.NiceDiff{}
187 nd.Commit.Parent = targetBranch
188
189 for _, d := range diffs {
190 ndiff := types.Diff{}
191 ndiff.Name.New = d.NewName
192 ndiff.Name.Old = d.OldName
193 ndiff.IsBinary = d.IsBinary
194 ndiff.IsNew = d.IsNew
195 ndiff.IsDelete = d.IsDelete
196 ndiff.IsCopy = d.IsCopy
197 ndiff.IsRename = d.IsRename
198
199 for _, tf := range d.TextFragments {
200 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
201 for _, l := range tf.Lines {
202 switch l.Op {
203 case gitdiff.OpAdd:
204 nd.Stat.Insertions += 1
205 case gitdiff.OpDelete:
206 nd.Stat.Deletions += 1
207 }
208 }
209 }
210
211 nd.Diff = append(nd.Diff, ndiff)
212 }
213
214 nd.Stat.FilesChanged = len(diffs)
215
216 return nd
217}
218
219func (s PullSubmission) IsFormatPatch() bool {
220 return patchutil.IsFormatPatch(s.Patch)
221}
222
// AsFormatPatch splits the submission's patch into its individual
// format-patch entries. On an extraction error it logs the error and
// returns an empty (non-nil) slice.
func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
	patches, err := patchutil.ExtractPatches(s.Patch)
	if err == nil {
		return patches
	}
	log.Println("error extracting patches from submission:", err)
	return []patchutil.FormatPatch{}
}
232
// NewPull inserts pull and its first submission (round 0) inside tx, then
// commits tx.
//
// The per-repository pull id comes from repo_pull_seqs. A sequence row is
// created for the repository on first use, starting at 1. The current value
// is taken as this pull's id while the stored counter is incremented. Before
// the rows are inserted, pull.PullId is set to the allocated id and
// pull.State to PullOpen.
//
// NewPull takes ownership of tx: it is committed on success and rolled back
// on any error. pull.Submissions must hold at least one non-nil submission;
// its Patch and SourceRev become round 0. A nil pull.PullSource
// (patch-based pull) is stored as NULL source columns.
func NewPull(tx *sql.Tx, pull *Pull) error {
	// Rollback after a successful Commit only returns sql.ErrTxDone, so the
	// deferred Rollback is a safe catch-all for every early return.
	defer tx.Rollback()

	// Round 0 is read from Submissions[0]. Reject a pull without it before
	// touching the database instead of panicking halfway through.
	if len(pull.Submissions) == 0 || pull.Submissions[0] == nil {
		return fmt.Errorf("new pull has no initial submission")
	}

	// ensure the repository has a sequence row; the first id handed out is 1
	_, err := tx.Exec(`
		insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
		values (?, 1)
	`, pull.RepoAt)
	if err != nil {
		return err
	}

	// take the current sequence value as this pull's id and advance it
	var nextId int
	err = tx.QueryRow(`
		update repo_pull_seqs
		set next_pull_id = next_pull_id + 1
		where repo_at = ?
		returning next_pull_id - 1
	`, pull.RepoAt).Scan(&nextId)
	if err != nil {
		return err
	}

	pull.PullId = nextId
	pull.State = PullOpen

	var sourceBranch, sourceRepoAt *string
	if pull.PullSource != nil {
		sourceBranch = &pull.PullSource.Branch
		if pull.PullSource.RepoAt != nil {
			x := pull.PullSource.RepoAt.String()
			sourceRepoAt = &x
		}
	}

	_, err = tx.Exec(
		`
		insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
		values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		pull.RepoAt,
		pull.OwnerDid,
		pull.PullId,
		pull.Title,
		pull.TargetBranch,
		pull.Body,
		pull.Rkey,
		pull.State,
		sourceBranch,
		sourceRepoAt,
	)
	if err != nil {
		return err
	}

	// the initial submission is always round 0
	first := pull.Submissions[0]
	_, err = tx.Exec(`
		insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
		values (?, ?, ?, ?, ?)
	`, pull.PullId, pull.RepoAt, 0, first.Patch, first.SourceRev)
	if err != nil {
		return err
	}

	return tx.Commit()
}
300
301func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
302 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
303 return err
304}
305
306func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
307 var pullAt string
308 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
309 return pullAt, err
310}
311
312func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
313 var pullId int
314 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
315 return pullId - 1, err
316}
317
// GetPulls returns the pull requests on repository repoAt that are in the
// given state, ordered newest first. Pulls created at the same instant are
// ordered by descending PullId.
//
// The pulls are only partially loaded:
//   - RepoAt is set to repoAt.
//   - Submissions is indexed by round number and holds every round found,
//     with only ID, PullId and RoundNumber populated (no patch text). Its
//     last element is the latest round.
//   - The latest submission's Comments slice holds one zero-valued
//     PullComment per stored comment, so its length is the comment count;
//     comment contents are not loaded.
//
// An empty, non-nil slice is returned when nothing matches.
func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
	pulls := make(map[int]*Pull)

	rows, err := e.Query(`
		select
			owner_did,
			pull_id,
			created,
			title,
			state,
			target_branch,
			pull_at,
			body,
			rkey,
			source_branch,
			source_repo_at
		from
			pulls
		where
			repo_at = ? and state = ?`, repoAt, state)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var pull Pull
		var createdAt string
		var sourceBranch, sourceRepoAt sql.NullString
		err := rows.Scan(
			&pull.OwnerDid,
			&pull.PullId,
			&createdAt,
			&pull.Title,
			&pull.State,
			&pull.TargetBranch,
			&pull.PullAt,
			&pull.Body,
			&pull.Rkey,
			&sourceBranch,
			&sourceRepoAt,
		)
		if err != nil {
			return nil, err
		}

		createdTime, err := time.Parse(time.RFC3339, createdAt)
		if err != nil {
			return nil, err
		}
		pull.Created = createdTime
		// every row matched repo_at = repoAt, so record it on the pull
		pull.RepoAt = repoAt

		if sourceBranch.Valid {
			pull.PullSource = &PullSource{
				Branch: sourceBranch.String,
			}
			if sourceRepoAt.Valid {
				sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
				if err != nil {
					return nil, err
				}
				pull.PullSource.RepoAt = &sourceRepoAtParsed
			}
		}

		pulls[pull.PullId] = &pull
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// nothing matched: skip the follow-up queries and their empty IN lists
	if len(pulls) == 0 {
		return []*Pull{}, nil
	}

	// load the submission rounds of every pull found above
	inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
	submissionsQuery := fmt.Sprintf(`
		select
			id, pull_id, round_number
		from
			pull_submissions
		where
			repo_at = ? and pull_id in (%s)
	`, inClause)

	args := make([]any, 0, len(pulls)+1)
	args = append(args, repoAt.String())
	for _, p := range pulls {
		args = append(args, p.PullId)
	}
	submissionsRows, err := e.Query(submissionsQuery, args...)
	if err != nil {
		return nil, err
	}
	defer submissionsRows.Close()

	for submissionsRows.Next() {
		var s PullSubmission
		if err := submissionsRows.Scan(&s.ID, &s.PullId, &s.RoundNumber); err != nil {
			return nil, err
		}

		p, ok := pulls[s.PullId]
		if !ok || s.RoundNumber < 0 {
			continue
		}
		// Rows arrive in no particular order. Grow the slice to fit the
		// highest round seen so far and keep the rounds already placed, so
		// the last element is always the latest round.
		if s.RoundNumber >= len(p.Submissions) {
			grown := make([]*PullSubmission, s.RoundNumber+1)
			copy(grown, p.Submissions)
			p.Submissions = grown
		}
		p.Submissions[s.RoundNumber] = &s
	}
	if err := submissionsRows.Err(); err != nil {
		return nil, err
	}

	// count comments on the latest submission of each pull; pulls without
	// any submission are skipped, and the IN list matches the args exactly
	var commentArgs []any
	for _, p := range pulls {
		if n := len(p.Submissions); n > 0 {
			commentArgs = append(commentArgs, p.Submissions[n-1].ID)
		}
	}
	if len(commentArgs) > 0 {
		inClause = strings.TrimSuffix(strings.Repeat("?, ", len(commentArgs)), ", ")
		commentsQuery := fmt.Sprintf(`
			select
				count(id), pull_id
			from
				pull_comments
			where
				submission_id in (%s)
			group by
				submission_id
		`, inClause)

		commentsRows, err := e.Query(commentsQuery, commentArgs...)
		if err != nil {
			return nil, err
		}
		defer commentsRows.Close()

		for commentsRows.Next() {
			var commentCount, pullId int
			if err := commentsRows.Scan(&commentCount, &pullId); err != nil {
				return nil, err
			}
			if p, ok := pulls[pullId]; ok && len(p.Submissions) > 0 {
				p.Submissions[len(p.Submissions)-1].Comments = make([]PullComment, commentCount)
			}
		}
		if err := commentsRows.Err(); err != nil {
			return nil, err
		}
	}

	orderedByDate := make([]*Pull, 0, len(pulls))
	for _, p := range pulls {
		orderedByDate = append(orderedByDate, p)
	}
	sort.Slice(orderedByDate, func(i, j int) bool {
		a, b := orderedByDate[i], orderedByDate[j]
		if !a.Created.Equal(b.Created) {
			return a.Created.After(b.Created)
		}
		// deterministic tie-break, independent of map iteration order
		return a.PullId > b.PullId
	})

	return orderedByDate, nil
}
480
// GetPull loads a single pull request, identified by its repository and
// per-repository pull id, together with all of its submission rounds and
// their comments.
//
// Submissions is indexed by round number and sized to the highest round
// found, so a missing round leaves a nil slot rather than causing a panic.
// Submissions is left nil when the pull has none. Each submission's Comments
// are ordered by creation time.
//
// When the pull records a source repository, PullSource.Repo is looked up on
// a best-effort basis: a failed lookup is logged and leaves it nil. The Scan
// error is returned when no pull row matches.
func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
	query := `
		select
			owner_did,
			pull_id,
			created,
			title,
			state,
			target_branch,
			pull_at,
			repo_at,
			body,
			rkey,
			source_branch,
			source_repo_at
		from
			pulls
		where
			repo_at = ? and pull_id = ?
	`
	row := e.QueryRow(query, repoAt, pullId)

	var pull Pull
	var createdAt string
	var sourceBranch, sourceRepoAt sql.NullString
	err := row.Scan(
		&pull.OwnerDid,
		&pull.PullId,
		&createdAt,
		&pull.Title,
		&pull.State,
		&pull.TargetBranch,
		&pull.PullAt,
		&pull.RepoAt,
		&pull.Body,
		&pull.Rkey,
		&sourceBranch,
		&sourceRepoAt,
	)
	if err != nil {
		return nil, err
	}

	createdTime, err := time.Parse(time.RFC3339, createdAt)
	if err != nil {
		return nil, err
	}
	pull.Created = createdTime

	// populate source
	if sourceBranch.Valid {
		pull.PullSource = &PullSource{
			Branch: sourceBranch.String,
		}
		if sourceRepoAt.Valid {
			sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
			if err != nil {
				return nil, err
			}
			pull.PullSource.RepoAt = &sourceRepoAtParsed
		}
	}

	submissionsQuery := `
		select
			id, pull_id, repo_at, round_number, patch, created, source_rev
		from
			pull_submissions
		where
			repo_at = ? and pull_id = ?
	`
	submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
	if err != nil {
		return nil, err
	}
	defer submissionsRows.Close()

	submissionsMap := make(map[int]*PullSubmission)
	maxRound := -1

	for submissionsRows.Next() {
		var submission PullSubmission
		var submissionCreatedStr string
		var submissionSourceRev sql.NullString
		err := submissionsRows.Scan(
			&submission.ID,
			&submission.PullId,
			&submission.RepoAt,
			&submission.RoundNumber,
			&submission.Patch,
			&submissionCreatedStr,
			&submissionSourceRev,
		)
		if err != nil {
			return nil, err
		}

		submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
		if err != nil {
			return nil, err
		}
		submission.Created = submissionCreatedTime

		if submissionSourceRev.Valid {
			submission.SourceRev = submissionSourceRev.String
		}

		submissionsMap[submission.ID] = &submission
		if submission.RoundNumber > maxRound {
			maxRound = submission.RoundNumber
		}
	}
	// Next returning false may signal an error; Close alone would hide it.
	if err = submissionsRows.Err(); err != nil {
		return nil, err
	}
	// close explicitly so the connection is free for the comments query
	if err = submissionsRows.Close(); err != nil {
		return nil, err
	}

	if len(submissionsMap) > 0 {
		var args []any
		for k := range submissionsMap {
			args = append(args, k)
		}
		inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
		commentsQuery := fmt.Sprintf(`
			select
				id,
				pull_id,
				submission_id,
				repo_at,
				owner_did,
				comment_at,
				body,
				created
			from
				pull_comments
			where
				submission_id IN (%s)
			order by
				created asc
		`, inClause)
		commentsRows, err := e.Query(commentsQuery, args...)
		if err != nil {
			return nil, err
		}
		defer commentsRows.Close()

		for commentsRows.Next() {
			var comment PullComment
			var commentCreatedStr string
			err := commentsRows.Scan(
				&comment.ID,
				&comment.PullId,
				&comment.SubmissionId,
				&comment.RepoAt,
				&comment.OwnerDid,
				&comment.CommentAt,
				&comment.Body,
				&commentCreatedStr,
			)
			if err != nil {
				return nil, err
			}

			commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
			if err != nil {
				return nil, err
			}
			comment.Created = commentCreatedTime

			// add the comment to its submission
			if submission, ok := submissionsMap[comment.SubmissionId]; ok {
				submission.Comments = append(submission.Comments, comment)
			}
		}
		if err := commentsRows.Err(); err != nil {
			return nil, err
		}
	}

	// resolve the source repository whether or not the pull has submissions
	if pull.PullSource != nil && pull.PullSource.RepoAt != nil {
		sourceRepo, err := GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
		if err != nil {
			// best-effort: the pull is still usable without its source repo
			log.Printf("failed to get repo by at uri: %v", err)
		} else {
			pull.PullSource.Repo = sourceRepo
		}
	}

	if len(submissionsMap) > 0 {
		// Index by round number and size by the highest round, so a gap
		// leaves a nil slot instead of an out-of-range panic.
		pull.Submissions = make([]*PullSubmission, maxRound+1)
		for _, submission := range submissionsMap {
			r := submission.RoundNumber
			if r < 0 {
				continue
			}
			// On a duplicated round number keep the higher ID (presumably
			// the later insert), so the result does not depend on map order.
			if existing := pull.Submissions[r]; existing == nil || submission.ID > existing.ID {
				pull.Submissions[r] = submission
			}
		}
	}

	return &pull, nil
}
676
// GetPullsByOwnerDid returns the pulls whose owner_did is did and whose
// creation time is on or after SQLite's date('now', timeframe), newest
// first.
//
// For each pull only OwnerDid, RepoAt, PullId, Created, Title and State are
// loaded, and Repo is filled from the joined repos row (Did, Name, Knot,
// Rkey, Created).
//
// timeframe is passed straight into the SQL date() modifier, so any
// timeframe in the past must be negative, e.g. "-3 months".
func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
	const query = `
		select
			p.owner_did,
			p.repo_at,
			p.pull_id,
			p.created,
			p.title,
			p.state,
			r.did,
			r.name,
			r.knot,
			r.rkey,
			r.created
		from
			pulls p
		join
			repos r on p.repo_at = r.at_uri
		where
			p.owner_did = ? and p.created >= date ('now', ?)
		order by
			p.created desc`

	rows, err := e.Query(query, did, timeframe)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Pull
	for rows.Next() {
		var (
			pull        Pull
			repo        Repo
			pullCreated string
			repoCreated string
		)
		if err := rows.Scan(
			&pull.OwnerDid,
			&pull.RepoAt,
			&pull.PullId,
			&pullCreated,
			&pull.Title,
			&pull.State,
			&repo.Did,
			&repo.Name,
			&repo.Knot,
			&repo.Rkey,
			&repoCreated,
		); err != nil {
			return nil, err
		}

		if pull.Created, err = time.Parse(time.RFC3339, pullCreated); err != nil {
			return nil, err
		}
		if repo.Created, err = time.Parse(time.RFC3339, repoCreated); err != nil {
			return nil, err
		}

		pull.Repo = &repo
		result = append(result, pull)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
752
753func NewPullComment(e Execer, comment *PullComment) (int64, error) {
754 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
755 res, err := e.Exec(
756 query,
757 comment.OwnerDid,
758 comment.RepoAt,
759 comment.SubmissionId,
760 comment.CommentAt,
761 comment.PullId,
762 comment.Body,
763 )
764 if err != nil {
765 return 0, err
766 }
767
768 i, err := res.LastInsertId()
769 if err != nil {
770 return 0, err
771 }
772
773 return i, nil
774}
775
776func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
777 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
778 return err
779}
780
781func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
782 err := SetPullState(e, repoAt, pullId, PullClosed)
783 return err
784}
785
786func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
787 err := SetPullState(e, repoAt, pullId, PullOpen)
788 return err
789}
790
791func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
792 err := SetPullState(e, repoAt, pullId, PullMerged)
793 return err
794}
795
796func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
797 newRoundNumber := len(pull.Submissions)
798 _, err := e.Exec(`
799 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
800 values (?, ?, ?, ?, ?)
801 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
802
803 return err
804}
805
// PullCount tallies a repository's pulls by state.
type PullCount struct {
	Open, Merged, Closed int
}
811
812func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
813 row := e.QueryRow(`
814 select
815 count(case when state = ? then 1 end) as open_count,
816 count(case when state = ? then 1 end) as merged_count,
817 count(case when state = ? then 1 end) as closed_count
818 from pulls
819 where repo_at = ?`,
820 PullOpen,
821 PullMerged,
822 PullClosed,
823 repoAt,
824 )
825
826 var count PullCount
827 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
828 return PullCount{0, 0, 0}, err
829 }
830
831 return count, nil
832}