1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/patchutil"
14 "tangled.sh/tangled.sh/core/types"
15)
16
// PullState is the lifecycle state of a pull request.
//
// The numeric values are written to and filtered on in the pulls.state
// column, so the order of the constants below must never change.
type PullState int

const (
	PullClosed PullState = iota // 0
	PullOpen                    // 1
	PullMerged                  // 2
)

// String returns the lowercase name of the state. PullClosed and any
// unrecognised value both render as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether p is exactly PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether p is exactly PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether p is exactly PullClosed (unknown values are not
// considered closed here, even though String renders them as "closed").
func (p PullState) IsClosed() bool { return p == PullClosed }
47
// Pull is a pull request filed against the repository identified by RepoAt.
//
// How completely a Pull is populated depends on which getter produced it:
// GetPull loads every submission and comment, GetPulls loads only a summary
// of the latest submission, and GetPullsByOwnerDid fills in Repo but no
// submissions.
type Pull struct {
	// ids
	// ID is not populated by the queries visible in this file.
	// NOTE(review): confirm whether anything relies on it.
	ID int
	// PullId is the per-repository pull number; NewPull allocates it from
	// the repo_pull_seqs table.
	PullId int

	// at ids
	// RepoAt is the AT-URI of the repository the pull is filed against.
	RepoAt   syntax.ATURI
	OwnerDid string
	Rkey     string
	// PullAt is the AT-URI of the pull record. NewPull does not write it;
	// SetPullAt does.
	PullAt syntax.ATURI

	// content
	Title        string
	Body         string
	TargetBranch string
	// State is stored as its integer value in pulls.state.
	State PullState
	// Submissions holds one entry per round, indexed by RoundNumber.
	Submissions []*PullSubmission

	// meta
	// Created is parsed from the RFC 3339 text in the created column.
	Created time.Time
	// PullSource is nil for patch-based pulls (see IsPatchBased) and set for
	// branch- and fork-based ones.
	PullSource *PullSource

	// optionally, populate this when querying for reverse mappings
	// (GetPullsByOwnerDid fills it in).
	Repo *Repo
}
73
// PullSource describes where the changes of a branch- or fork-based pull
// come from. Patch-based pulls have no PullSource.
type PullSource struct {
	// Branch is the name of the source branch.
	Branch string
	// RepoAt is the AT-URI of the repository holding Branch. A nil RepoAt
	// is treated by IsBranchBased as a branch of the pull's own repository.
	RepoAt *syntax.ATURI

	// optionally populate this for reverse mappings
	Repo *Repo
}
81
// PullSubmission is one round of a pull: the patch as submitted in that
// round, plus the comments left on it.
type PullSubmission struct {
	// ids
	ID int
	// PullId is the per-repository pull number (Pull.PullId), not Pull.ID.
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content
	// RoundNumber is 0 for the submission written by NewPull; ResubmitPull
	// appends rounds numbered len(Pull.Submissions).
	RoundNumber int
	Patch       string
	// Comments is fully loaded by GetPull. GetPulls only sizes it to the
	// comment count and leaves the entries zero-valued.
	Comments  []PullComment
	SourceRev string // include the rev that was used to create this submission: only for branch PRs

	// meta
	Created time.Time
}
99
// PullComment is a comment left on one specific submission (round) of a pull.
type PullComment struct {
	// ids
	ID int
	// PullId is the per-repository pull number (Pull.PullId).
	PullId int
	// SubmissionId is the PullSubmission.ID this comment belongs to.
	SubmissionId int

	// at ids
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
117
118func (p *Pull) LatestPatch() string {
119 latestSubmission := p.Submissions[p.LastRoundNumber()]
120 return latestSubmission.Patch
121}
122
123func (p *Pull) LastRoundNumber() int {
124 return len(p.Submissions) - 1
125}
126
127func (p *Pull) IsPatchBased() bool {
128 return p.PullSource == nil
129}
130
131func (p *Pull) IsBranchBased() bool {
132 if p.PullSource != nil {
133 if p.PullSource.RepoAt != nil {
134 return p.PullSource.RepoAt == &p.RepoAt
135 } else {
136 // no repo specified
137 return true
138 }
139 }
140 return false
141}
142
143func (p *Pull) IsForkBased() bool {
144 if p.PullSource != nil {
145 if p.PullSource.RepoAt != nil {
146 // make sure repos are different
147 return p.PullSource.RepoAt != &p.RepoAt
148 }
149 }
150 return false
151}
152
153func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
154 patch := s.Patch
155
156 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
157 if err != nil {
158 log.Println(err)
159 }
160
161 nd := types.NiceDiff{}
162 nd.Commit.Parent = targetBranch
163
164 for _, d := range diffs {
165 ndiff := types.Diff{}
166 ndiff.Name.New = d.NewName
167 ndiff.Name.Old = d.OldName
168 ndiff.IsBinary = d.IsBinary
169 ndiff.IsNew = d.IsNew
170 ndiff.IsDelete = d.IsDelete
171 ndiff.IsCopy = d.IsCopy
172 ndiff.IsRename = d.IsRename
173
174 for _, tf := range d.TextFragments {
175 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
176 for _, l := range tf.Lines {
177 switch l.Op {
178 case gitdiff.OpAdd:
179 nd.Stat.Insertions += 1
180 case gitdiff.OpDelete:
181 nd.Stat.Deletions += 1
182 }
183 }
184 }
185
186 nd.Diff = append(nd.Diff, ndiff)
187 }
188
189 nd.Stat.FilesChanged = len(diffs)
190
191 return nd
192}
193
194func (s PullSubmission) IsFormatPatch() bool {
195 return patchutil.IsFormatPatch(s.Patch)
196}
197
198func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
199 patches, err := patchutil.ExtractPatches(s.Patch)
200 if err != nil {
201 log.Println("error extracting patches from submission:", err)
202 return []patchutil.FormatPatch{}
203 }
204
205 return patches
206}
207
208func NewPull(tx *sql.Tx, pull *Pull) error {
209 defer tx.Rollback()
210
211 _, err := tx.Exec(`
212 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
213 values (?, 1)
214 `, pull.RepoAt)
215 if err != nil {
216 return err
217 }
218
219 var nextId int
220 err = tx.QueryRow(`
221 update repo_pull_seqs
222 set next_pull_id = next_pull_id + 1
223 where repo_at = ?
224 returning next_pull_id - 1
225 `, pull.RepoAt).Scan(&nextId)
226 if err != nil {
227 return err
228 }
229
230 pull.PullId = nextId
231 pull.State = PullOpen
232
233 var sourceBranch, sourceRepoAt *string
234 if pull.PullSource != nil {
235 sourceBranch = &pull.PullSource.Branch
236 if pull.PullSource.RepoAt != nil {
237 x := pull.PullSource.RepoAt.String()
238 sourceRepoAt = &x
239 }
240 }
241
242 _, err = tx.Exec(
243 `
244 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
245 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
246 pull.RepoAt,
247 pull.OwnerDid,
248 pull.PullId,
249 pull.Title,
250 pull.TargetBranch,
251 pull.Body,
252 pull.Rkey,
253 pull.State,
254 sourceBranch,
255 sourceRepoAt,
256 )
257 if err != nil {
258 return err
259 }
260
261 _, err = tx.Exec(`
262 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
263 values (?, ?, ?, ?, ?)
264 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
265 if err != nil {
266 return err
267 }
268
269 if err := tx.Commit(); err != nil {
270 return err
271 }
272
273 return nil
274}
275
276func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
277 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
278 return err
279}
280
281func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
282 var pullAt string
283 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
284 return pullAt, err
285}
286
287func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
288 var pullId int
289 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
290 return pullId - 1, err
291}
292
293func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
294 pulls := make(map[int]*Pull)
295
296 rows, err := e.Query(`
297 select
298 owner_did,
299 pull_id,
300 created,
301 title,
302 state,
303 target_branch,
304 pull_at,
305 body,
306 rkey,
307 source_branch,
308 source_repo_at
309 from
310 pulls
311 where
312 repo_at = ? and state = ?`, repoAt, state)
313 if err != nil {
314 return nil, err
315 }
316 defer rows.Close()
317
318 for rows.Next() {
319 var pull Pull
320 var createdAt string
321 var sourceBranch, sourceRepoAt sql.NullString
322 err := rows.Scan(
323 &pull.OwnerDid,
324 &pull.PullId,
325 &createdAt,
326 &pull.Title,
327 &pull.State,
328 &pull.TargetBranch,
329 &pull.PullAt,
330 &pull.Body,
331 &pull.Rkey,
332 &sourceBranch,
333 &sourceRepoAt,
334 )
335 if err != nil {
336 return nil, err
337 }
338
339 createdTime, err := time.Parse(time.RFC3339, createdAt)
340 if err != nil {
341 return nil, err
342 }
343 pull.Created = createdTime
344
345 if sourceBranch.Valid {
346 pull.PullSource = &PullSource{
347 Branch: sourceBranch.String,
348 }
349 if sourceRepoAt.Valid {
350 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
351 if err != nil {
352 return nil, err
353 }
354 pull.PullSource.RepoAt = &sourceRepoAtParsed
355 }
356 }
357
358 pulls[pull.PullId] = &pull
359 }
360
361 // get latest round no. for each pull
362 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
363 submissionsQuery := fmt.Sprintf(`
364 select
365 id, pull_id, round_number
366 from
367 pull_submissions
368 where
369 repo_at = ? and pull_id in (%s)
370 `, inClause)
371
372 args := make([]any, len(pulls)+1)
373 args[0] = repoAt.String()
374 idx := 1
375 for _, p := range pulls {
376 args[idx] = p.PullId
377 idx += 1
378 }
379 submissionsRows, err := e.Query(submissionsQuery, args...)
380 if err != nil {
381 return nil, err
382 }
383 defer submissionsRows.Close()
384
385 for submissionsRows.Next() {
386 var s PullSubmission
387 err := submissionsRows.Scan(
388 &s.ID,
389 &s.PullId,
390 &s.RoundNumber,
391 )
392 if err != nil {
393 return nil, err
394 }
395
396 if p, ok := pulls[s.PullId]; ok {
397 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
398 p.Submissions[s.RoundNumber] = &s
399 }
400 }
401 if err := rows.Err(); err != nil {
402 return nil, err
403 }
404
405 // get comment count on latest submission on each pull
406 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
407 commentsQuery := fmt.Sprintf(`
408 select
409 count(id), pull_id
410 from
411 pull_comments
412 where
413 submission_id in (%s)
414 group by
415 submission_id
416 `, inClause)
417
418 args = []any{}
419 for _, p := range pulls {
420 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
421 }
422 commentsRows, err := e.Query(commentsQuery, args...)
423 if err != nil {
424 return nil, err
425 }
426 defer commentsRows.Close()
427
428 for commentsRows.Next() {
429 var commentCount, pullId int
430 err := commentsRows.Scan(
431 &commentCount,
432 &pullId,
433 )
434 if err != nil {
435 return nil, err
436 }
437 if p, ok := pulls[pullId]; ok {
438 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
439 }
440 }
441 if err := rows.Err(); err != nil {
442 return nil, err
443 }
444
445 orderedByDate := []*Pull{}
446 for _, p := range pulls {
447 orderedByDate = append(orderedByDate, p)
448 }
449 sort.Slice(orderedByDate, func(i, j int) bool {
450 return orderedByDate[i].Created.After(orderedByDate[j].Created)
451 })
452
453 return orderedByDate, nil
454}
455
456func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
457 query := `
458 select
459 owner_did,
460 pull_id,
461 created,
462 title,
463 state,
464 target_branch,
465 pull_at,
466 repo_at,
467 body,
468 rkey,
469 source_branch,
470 source_repo_at
471 from
472 pulls
473 where
474 repo_at = ? and pull_id = ?
475 `
476 row := e.QueryRow(query, repoAt, pullId)
477
478 var pull Pull
479 var createdAt string
480 var sourceBranch, sourceRepoAt sql.NullString
481 err := row.Scan(
482 &pull.OwnerDid,
483 &pull.PullId,
484 &createdAt,
485 &pull.Title,
486 &pull.State,
487 &pull.TargetBranch,
488 &pull.PullAt,
489 &pull.RepoAt,
490 &pull.Body,
491 &pull.Rkey,
492 &sourceBranch,
493 &sourceRepoAt,
494 )
495 if err != nil {
496 return nil, err
497 }
498
499 createdTime, err := time.Parse(time.RFC3339, createdAt)
500 if err != nil {
501 return nil, err
502 }
503 pull.Created = createdTime
504
505 // populate source
506 if sourceBranch.Valid {
507 pull.PullSource = &PullSource{
508 Branch: sourceBranch.String,
509 }
510 if sourceRepoAt.Valid {
511 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
512 if err != nil {
513 return nil, err
514 }
515 pull.PullSource.RepoAt = &sourceRepoAtParsed
516 }
517 }
518
519 submissionsQuery := `
520 select
521 id, pull_id, repo_at, round_number, patch, created, source_rev
522 from
523 pull_submissions
524 where
525 repo_at = ? and pull_id = ?
526 `
527 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
528 if err != nil {
529 return nil, err
530 }
531 defer submissionsRows.Close()
532
533 submissionsMap := make(map[int]*PullSubmission)
534
535 for submissionsRows.Next() {
536 var submission PullSubmission
537 var submissionCreatedStr string
538 var submissionSourceRev sql.NullString
539 err := submissionsRows.Scan(
540 &submission.ID,
541 &submission.PullId,
542 &submission.RepoAt,
543 &submission.RoundNumber,
544 &submission.Patch,
545 &submissionCreatedStr,
546 &submissionSourceRev,
547 )
548 if err != nil {
549 return nil, err
550 }
551
552 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
553 if err != nil {
554 return nil, err
555 }
556 submission.Created = submissionCreatedTime
557
558 if submissionSourceRev.Valid {
559 submission.SourceRev = submissionSourceRev.String
560 }
561
562 submissionsMap[submission.ID] = &submission
563 }
564 if err = submissionsRows.Close(); err != nil {
565 return nil, err
566 }
567 if len(submissionsMap) == 0 {
568 return &pull, nil
569 }
570
571 var args []any
572 for k := range submissionsMap {
573 args = append(args, k)
574 }
575 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
576 commentsQuery := fmt.Sprintf(`
577 select
578 id,
579 pull_id,
580 submission_id,
581 repo_at,
582 owner_did,
583 comment_at,
584 body,
585 created
586 from
587 pull_comments
588 where
589 submission_id IN (%s)
590 order by
591 created asc
592 `, inClause)
593 commentsRows, err := e.Query(commentsQuery, args...)
594 if err != nil {
595 return nil, err
596 }
597 defer commentsRows.Close()
598
599 for commentsRows.Next() {
600 var comment PullComment
601 var commentCreatedStr string
602 err := commentsRows.Scan(
603 &comment.ID,
604 &comment.PullId,
605 &comment.SubmissionId,
606 &comment.RepoAt,
607 &comment.OwnerDid,
608 &comment.CommentAt,
609 &comment.Body,
610 &commentCreatedStr,
611 )
612 if err != nil {
613 return nil, err
614 }
615
616 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
617 if err != nil {
618 return nil, err
619 }
620 comment.Created = commentCreatedTime
621
622 // Add the comment to its submission
623 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
624 submission.Comments = append(submission.Comments, comment)
625 }
626
627 }
628 if err = commentsRows.Err(); err != nil {
629 return nil, err
630 }
631
632 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
633 for _, submission := range submissionsMap {
634 pull.Submissions[submission.RoundNumber] = submission
635 }
636
637 return &pull, nil
638}
639
640// timeframe here is directly passed into the sql query filter, and any
641// timeframe in the past should be negative; e.g.: "-3 months"
642func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
643 var pulls []Pull
644
645 rows, err := e.Query(`
646 select
647 p.owner_did,
648 p.repo_at,
649 p.pull_id,
650 p.created,
651 p.title,
652 p.state,
653 r.did,
654 r.name,
655 r.knot,
656 r.rkey,
657 r.created
658 from
659 pulls p
660 join
661 repos r on p.repo_at = r.at_uri
662 where
663 p.owner_did = ? and p.created >= date ('now', ?)
664 order by
665 p.created desc`, did, timeframe)
666 if err != nil {
667 return nil, err
668 }
669 defer rows.Close()
670
671 for rows.Next() {
672 var pull Pull
673 var repo Repo
674 var pullCreatedAt, repoCreatedAt string
675 err := rows.Scan(
676 &pull.OwnerDid,
677 &pull.RepoAt,
678 &pull.PullId,
679 &pullCreatedAt,
680 &pull.Title,
681 &pull.State,
682 &repo.Did,
683 &repo.Name,
684 &repo.Knot,
685 &repo.Rkey,
686 &repoCreatedAt,
687 )
688 if err != nil {
689 return nil, err
690 }
691
692 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
693 if err != nil {
694 return nil, err
695 }
696 pull.Created = pullCreatedTime
697
698 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
699 if err != nil {
700 return nil, err
701 }
702 repo.Created = repoCreatedTime
703
704 pull.Repo = &repo
705
706 pulls = append(pulls, pull)
707 }
708
709 if err := rows.Err(); err != nil {
710 return nil, err
711 }
712
713 return pulls, nil
714}
715
716func NewPullComment(e Execer, comment *PullComment) (int64, error) {
717 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
718 res, err := e.Exec(
719 query,
720 comment.OwnerDid,
721 comment.RepoAt,
722 comment.SubmissionId,
723 comment.CommentAt,
724 comment.PullId,
725 comment.Body,
726 )
727 if err != nil {
728 return 0, err
729 }
730
731 i, err := res.LastInsertId()
732 if err != nil {
733 return 0, err
734 }
735
736 return i, nil
737}
738
739func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
740 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
741 return err
742}
743
744func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
745 err := SetPullState(e, repoAt, pullId, PullClosed)
746 return err
747}
748
749func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
750 err := SetPullState(e, repoAt, pullId, PullOpen)
751 return err
752}
753
754func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
755 err := SetPullState(e, repoAt, pullId, PullMerged)
756 return err
757}
758
759func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
760 newRoundNumber := len(pull.Submissions)
761 _, err := e.Exec(`
762 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
763 values (?, ?, ?, ?, ?)
764 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
765
766 return err
767}
768
// PullCount holds the number of pulls in a repository in each state, as
// computed by GetPullCount.
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
774
775func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
776 row := e.QueryRow(`
777 select
778 count(case when state = ? then 1 end) as open_count,
779 count(case when state = ? then 1 end) as merged_count,
780 count(case when state = ? then 1 end) as closed_count
781 from pulls
782 where repo_at = ?`,
783 PullOpen,
784 PullMerged,
785 PullClosed,
786 repoAt,
787 )
788
789 var count PullCount
790 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
791 return PullCount{0, 0, 0}, err
792 }
793
794 return count, nil
795}