1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. Its integer value is
// what gets written to and compared against the pulls.state column, so the
// declaration order of the constants below must not change.
type PullState int

// Pull states, in their persisted numeric order.
const (
	PullClosed PullState = iota // 0
	PullOpen                    // 1
	PullMerged                  // 2
)

// String returns the lowercase name of the state.
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	default:
		// PullClosed and any unrecognised value share the same label.
		return "closed"
	}
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is PullClosed. A merged pull is
// PullMerged, not closed.
func (p PullState) IsClosed() bool { return p == PullClosed }
48
49type Pull struct {
50 // ids
51 ID int
52 PullId int
53
54 // at ids
55 RepoAt syntax.ATURI
56 OwnerDid string
57 Rkey string
58
59 // content
60 Title string
61 Body string
62 TargetBranch string
63 State PullState
64 Submissions []*PullSubmission
65
66 // meta
67 Created time.Time
68 PullSource *PullSource
69
70 // optionally, populate this when querying for reverse mappings
71 Repo *Repo
72}
73
74type PullSource struct {
75 Branch string
76 RepoAt *syntax.ATURI
77
78 // optionally populate this for reverse mappings
79 Repo *Repo
80}
81
82type PullSubmission struct {
83 // ids
84 ID int
85 PullId int
86
87 // at ids
88 RepoAt syntax.ATURI
89
90 // content
91 RoundNumber int
92 Patch string
93 Comments []PullComment
94 SourceRev string // include the rev that was used to create this submission: only for branch PRs
95
96 // meta
97 Created time.Time
98}
99
// PullComment is a comment on one submission round of a pull.
type PullComment struct {
	// ids: SubmissionId is the PullSubmission.ID the comment belongs to.
	ID, PullId, SubmissionId int

	// at ids: RepoAt is a plain string here, not a syntax.ATURI.
	// NOTE(review): CommentAt presumably holds the comment record's AT-URI;
	// confirm.
	RepoAt, OwnerDid, CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
117
118func (p *Pull) LatestPatch() string {
119 latestSubmission := p.Submissions[p.LastRoundNumber()]
120 return latestSubmission.Patch
121}
122
123func (p *Pull) PullAt() syntax.ATURI {
124 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
125}
126
127func (p *Pull) LastRoundNumber() int {
128 return len(p.Submissions) - 1
129}
130
131func (p *Pull) IsPatchBased() bool {
132 return p.PullSource == nil
133}
134
135func (p *Pull) IsBranchBased() bool {
136 if p.PullSource != nil {
137 if p.PullSource.RepoAt != nil {
138 return p.PullSource.RepoAt == &p.RepoAt
139 } else {
140 // no repo specified
141 return true
142 }
143 }
144 return false
145}
146
147func (p *Pull) IsForkBased() bool {
148 if p.PullSource != nil {
149 if p.PullSource.RepoAt != nil {
150 // make sure repos are different
151 return p.PullSource.RepoAt != &p.RepoAt
152 }
153 }
154 return false
155}
156
157func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) {
158 patch := s.Patch
159
160 // if format-patch; then extract each patch
161 var diffs []*gitdiff.File
162 if patchutil.IsFormatPatch(patch) {
163 patches, err := patchutil.ExtractPatches(patch)
164 if err != nil {
165 return nil, err
166 }
167 var ps [][]*gitdiff.File
168 for _, p := range patches {
169 ps = append(ps, p.Files)
170 }
171
172 diffs = patchutil.CombineDiff(ps...)
173 } else {
174 d, _, err := gitdiff.Parse(strings.NewReader(patch))
175 if err != nil {
176 return nil, err
177 }
178 diffs = d
179 }
180
181 return diffs, nil
182}
183
184func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
185 diffs, err := s.AsDiff(targetBranch)
186 if err != nil {
187 log.Println(err)
188 }
189
190 nd := types.NiceDiff{}
191 nd.Commit.Parent = targetBranch
192
193 for _, d := range diffs {
194 ndiff := types.Diff{}
195 ndiff.Name.New = d.NewName
196 ndiff.Name.Old = d.OldName
197 ndiff.IsBinary = d.IsBinary
198 ndiff.IsNew = d.IsNew
199 ndiff.IsDelete = d.IsDelete
200 ndiff.IsCopy = d.IsCopy
201 ndiff.IsRename = d.IsRename
202
203 for _, tf := range d.TextFragments {
204 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
205 for _, l := range tf.Lines {
206 switch l.Op {
207 case gitdiff.OpAdd:
208 nd.Stat.Insertions += 1
209 case gitdiff.OpDelete:
210 nd.Stat.Deletions += 1
211 }
212 }
213 }
214
215 nd.Diff = append(nd.Diff, ndiff)
216 }
217
218 nd.Stat.FilesChanged = len(diffs)
219
220 return nd
221}
222
223func (s PullSubmission) IsFormatPatch() bool {
224 return patchutil.IsFormatPatch(s.Patch)
225}
226
227func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
228 patches, err := patchutil.ExtractPatches(s.Patch)
229 if err != nil {
230 log.Println("error extracting patches from submission:", err)
231 return []patchutil.FormatPatch{}
232 }
233
234 return patches
235}
236
237func NewPull(tx *sql.Tx, pull *Pull) error {
238 _, err := tx.Exec(`
239 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
240 values (?, 1)
241 `, pull.RepoAt)
242 if err != nil {
243 return err
244 }
245
246 var nextId int
247 err = tx.QueryRow(`
248 update repo_pull_seqs
249 set next_pull_id = next_pull_id + 1
250 where repo_at = ?
251 returning next_pull_id - 1
252 `, pull.RepoAt).Scan(&nextId)
253 if err != nil {
254 return err
255 }
256
257 pull.PullId = nextId
258 pull.State = PullOpen
259
260 var sourceBranch, sourceRepoAt *string
261 if pull.PullSource != nil {
262 sourceBranch = &pull.PullSource.Branch
263 if pull.PullSource.RepoAt != nil {
264 x := pull.PullSource.RepoAt.String()
265 sourceRepoAt = &x
266 }
267 }
268
269 _, err = tx.Exec(
270 `
271 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
272 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
273 pull.RepoAt,
274 pull.OwnerDid,
275 pull.PullId,
276 pull.Title,
277 pull.TargetBranch,
278 pull.Body,
279 pull.Rkey,
280 pull.State,
281 sourceBranch,
282 sourceRepoAt,
283 )
284 if err != nil {
285 return err
286 }
287
288 _, err = tx.Exec(`
289 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
290 values (?, ?, ?, ?, ?)
291 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
292 return err
293}
294
295func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
296 pull, err := GetPull(e, repoAt, pullId)
297 if err != nil {
298 return "", err
299 }
300 return pull.PullAt(), err
301}
302
303func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
304 var pullId int
305 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
306 return pullId - 1, err
307}
308
309func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
310 pulls := make(map[int]*Pull)
311
312 rows, err := e.Query(`
313 select
314 owner_did,
315 pull_id,
316 created,
317 title,
318 state,
319 target_branch,
320 body,
321 rkey,
322 source_branch,
323 source_repo_at
324 from
325 pulls
326 where
327 repo_at = ? and state = ?`, repoAt, state)
328 if err != nil {
329 return nil, err
330 }
331 defer rows.Close()
332
333 for rows.Next() {
334 var pull Pull
335 var createdAt string
336 var sourceBranch, sourceRepoAt sql.NullString
337 err := rows.Scan(
338 &pull.OwnerDid,
339 &pull.PullId,
340 &createdAt,
341 &pull.Title,
342 &pull.State,
343 &pull.TargetBranch,
344 &pull.Body,
345 &pull.Rkey,
346 &sourceBranch,
347 &sourceRepoAt,
348 )
349 if err != nil {
350 return nil, err
351 }
352
353 createdTime, err := time.Parse(time.RFC3339, createdAt)
354 if err != nil {
355 return nil, err
356 }
357 pull.Created = createdTime
358
359 if sourceBranch.Valid {
360 pull.PullSource = &PullSource{
361 Branch: sourceBranch.String,
362 }
363 if sourceRepoAt.Valid {
364 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
365 if err != nil {
366 return nil, err
367 }
368 pull.PullSource.RepoAt = &sourceRepoAtParsed
369 }
370 }
371
372 pulls[pull.PullId] = &pull
373 }
374
375 // get latest round no. for each pull
376 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
377 submissionsQuery := fmt.Sprintf(`
378 select
379 id, pull_id, round_number
380 from
381 pull_submissions
382 where
383 repo_at = ? and pull_id in (%s)
384 `, inClause)
385
386 args := make([]any, len(pulls)+1)
387 args[0] = repoAt.String()
388 idx := 1
389 for _, p := range pulls {
390 args[idx] = p.PullId
391 idx += 1
392 }
393 submissionsRows, err := e.Query(submissionsQuery, args...)
394 if err != nil {
395 return nil, err
396 }
397 defer submissionsRows.Close()
398
399 for submissionsRows.Next() {
400 var s PullSubmission
401 err := submissionsRows.Scan(
402 &s.ID,
403 &s.PullId,
404 &s.RoundNumber,
405 )
406 if err != nil {
407 return nil, err
408 }
409
410 if p, ok := pulls[s.PullId]; ok {
411 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
412 p.Submissions[s.RoundNumber] = &s
413 }
414 }
415 if err := rows.Err(); err != nil {
416 return nil, err
417 }
418
419 // get comment count on latest submission on each pull
420 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
421 commentsQuery := fmt.Sprintf(`
422 select
423 count(id), pull_id
424 from
425 pull_comments
426 where
427 submission_id in (%s)
428 group by
429 submission_id
430 `, inClause)
431
432 args = []any{}
433 for _, p := range pulls {
434 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
435 }
436 commentsRows, err := e.Query(commentsQuery, args...)
437 if err != nil {
438 return nil, err
439 }
440 defer commentsRows.Close()
441
442 for commentsRows.Next() {
443 var commentCount, pullId int
444 err := commentsRows.Scan(
445 &commentCount,
446 &pullId,
447 )
448 if err != nil {
449 return nil, err
450 }
451 if p, ok := pulls[pullId]; ok {
452 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
453 }
454 }
455 if err := rows.Err(); err != nil {
456 return nil, err
457 }
458
459 orderedByDate := []*Pull{}
460 for _, p := range pulls {
461 orderedByDate = append(orderedByDate, p)
462 }
463 sort.Slice(orderedByDate, func(i, j int) bool {
464 return orderedByDate[i].Created.After(orderedByDate[j].Created)
465 })
466
467 return orderedByDate, nil
468}
469
470func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
471 query := `
472 select
473 owner_did,
474 pull_id,
475 created,
476 title,
477 state,
478 target_branch,
479 repo_at,
480 body,
481 rkey,
482 source_branch,
483 source_repo_at
484 from
485 pulls
486 where
487 repo_at = ? and pull_id = ?
488 `
489 row := e.QueryRow(query, repoAt, pullId)
490
491 var pull Pull
492 var createdAt string
493 var sourceBranch, sourceRepoAt sql.NullString
494 err := row.Scan(
495 &pull.OwnerDid,
496 &pull.PullId,
497 &createdAt,
498 &pull.Title,
499 &pull.State,
500 &pull.TargetBranch,
501 &pull.RepoAt,
502 &pull.Body,
503 &pull.Rkey,
504 &sourceBranch,
505 &sourceRepoAt,
506 )
507 if err != nil {
508 return nil, err
509 }
510
511 createdTime, err := time.Parse(time.RFC3339, createdAt)
512 if err != nil {
513 return nil, err
514 }
515 pull.Created = createdTime
516
517 // populate source
518 if sourceBranch.Valid {
519 pull.PullSource = &PullSource{
520 Branch: sourceBranch.String,
521 }
522 if sourceRepoAt.Valid {
523 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
524 if err != nil {
525 return nil, err
526 }
527 pull.PullSource.RepoAt = &sourceRepoAtParsed
528 }
529 }
530
531 submissionsQuery := `
532 select
533 id, pull_id, repo_at, round_number, patch, created, source_rev
534 from
535 pull_submissions
536 where
537 repo_at = ? and pull_id = ?
538 `
539 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
540 if err != nil {
541 return nil, err
542 }
543 defer submissionsRows.Close()
544
545 submissionsMap := make(map[int]*PullSubmission)
546
547 for submissionsRows.Next() {
548 var submission PullSubmission
549 var submissionCreatedStr string
550 var submissionSourceRev sql.NullString
551 err := submissionsRows.Scan(
552 &submission.ID,
553 &submission.PullId,
554 &submission.RepoAt,
555 &submission.RoundNumber,
556 &submission.Patch,
557 &submissionCreatedStr,
558 &submissionSourceRev,
559 )
560 if err != nil {
561 return nil, err
562 }
563
564 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
565 if err != nil {
566 return nil, err
567 }
568 submission.Created = submissionCreatedTime
569
570 if submissionSourceRev.Valid {
571 submission.SourceRev = submissionSourceRev.String
572 }
573
574 submissionsMap[submission.ID] = &submission
575 }
576 if err = submissionsRows.Close(); err != nil {
577 return nil, err
578 }
579 if len(submissionsMap) == 0 {
580 return &pull, nil
581 }
582
583 var args []any
584 for k := range submissionsMap {
585 args = append(args, k)
586 }
587 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
588 commentsQuery := fmt.Sprintf(`
589 select
590 id,
591 pull_id,
592 submission_id,
593 repo_at,
594 owner_did,
595 comment_at,
596 body,
597 created
598 from
599 pull_comments
600 where
601 submission_id IN (%s)
602 order by
603 created asc
604 `, inClause)
605 commentsRows, err := e.Query(commentsQuery, args...)
606 if err != nil {
607 return nil, err
608 }
609 defer commentsRows.Close()
610
611 for commentsRows.Next() {
612 var comment PullComment
613 var commentCreatedStr string
614 err := commentsRows.Scan(
615 &comment.ID,
616 &comment.PullId,
617 &comment.SubmissionId,
618 &comment.RepoAt,
619 &comment.OwnerDid,
620 &comment.CommentAt,
621 &comment.Body,
622 &commentCreatedStr,
623 )
624 if err != nil {
625 return nil, err
626 }
627
628 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
629 if err != nil {
630 return nil, err
631 }
632 comment.Created = commentCreatedTime
633
634 // Add the comment to its submission
635 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
636 submission.Comments = append(submission.Comments, comment)
637 }
638
639 }
640 if err = commentsRows.Err(); err != nil {
641 return nil, err
642 }
643
644 var pullSourceRepo *Repo
645 if pull.PullSource != nil {
646 if pull.PullSource.RepoAt != nil {
647 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
648 if err != nil {
649 log.Printf("failed to get repo by at uri: %v", err)
650 } else {
651 pull.PullSource.Repo = pullSourceRepo
652 }
653 }
654 }
655
656 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
657 for _, submission := range submissionsMap {
658 pull.Submissions[submission.RoundNumber] = submission
659 }
660
661 return &pull, nil
662}
663
664// timeframe here is directly passed into the sql query filter, and any
665// timeframe in the past should be negative; e.g.: "-3 months"
666func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
667 var pulls []Pull
668
669 rows, err := e.Query(`
670 select
671 p.owner_did,
672 p.repo_at,
673 p.pull_id,
674 p.created,
675 p.title,
676 p.state,
677 r.did,
678 r.name,
679 r.knot,
680 r.rkey,
681 r.created
682 from
683 pulls p
684 join
685 repos r on p.repo_at = r.at_uri
686 where
687 p.owner_did = ? and p.created >= date ('now', ?)
688 order by
689 p.created desc`, did, timeframe)
690 if err != nil {
691 return nil, err
692 }
693 defer rows.Close()
694
695 for rows.Next() {
696 var pull Pull
697 var repo Repo
698 var pullCreatedAt, repoCreatedAt string
699 err := rows.Scan(
700 &pull.OwnerDid,
701 &pull.RepoAt,
702 &pull.PullId,
703 &pullCreatedAt,
704 &pull.Title,
705 &pull.State,
706 &repo.Did,
707 &repo.Name,
708 &repo.Knot,
709 &repo.Rkey,
710 &repoCreatedAt,
711 )
712 if err != nil {
713 return nil, err
714 }
715
716 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
717 if err != nil {
718 return nil, err
719 }
720 pull.Created = pullCreatedTime
721
722 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
723 if err != nil {
724 return nil, err
725 }
726 repo.Created = repoCreatedTime
727
728 pull.Repo = &repo
729
730 pulls = append(pulls, pull)
731 }
732
733 if err := rows.Err(); err != nil {
734 return nil, err
735 }
736
737 return pulls, nil
738}
739
740func NewPullComment(e Execer, comment *PullComment) (int64, error) {
741 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
742 res, err := e.Exec(
743 query,
744 comment.OwnerDid,
745 comment.RepoAt,
746 comment.SubmissionId,
747 comment.CommentAt,
748 comment.PullId,
749 comment.Body,
750 )
751 if err != nil {
752 return 0, err
753 }
754
755 i, err := res.LastInsertId()
756 if err != nil {
757 return 0, err
758 }
759
760 return i, nil
761}
762
763func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
764 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
765 return err
766}
767
768func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
769 err := SetPullState(e, repoAt, pullId, PullClosed)
770 return err
771}
772
773func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
774 err := SetPullState(e, repoAt, pullId, PullOpen)
775 return err
776}
777
778func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
779 err := SetPullState(e, repoAt, pullId, PullMerged)
780 return err
781}
782
783func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
784 newRoundNumber := len(pull.Submissions)
785 _, err := e.Exec(`
786 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
787 values (?, ?, ?, ?, ?)
788 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
789
790 return err
791}
792
793type PullCount struct {
794 Open int
795 Merged int
796 Closed int
797}
798
799func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
800 row := e.QueryRow(`
801 select
802 count(case when state = ? then 1 end) as open_count,
803 count(case when state = ? then 1 end) as merged_count,
804 count(case when state = ? then 1 end) as closed_count
805 from pulls
806 where repo_at = ?`,
807 PullOpen,
808 PullMerged,
809 PullClosed,
810 repoAt,
811 )
812
813 var count PullCount
814 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
815 return PullCount{0, 0, 0}, err
816 }
817
818 return count, nil
819}