forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 tangled "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. Values of this type are
// passed directly as the `state` parameter of the pulls table queries in this
// file, so the numeric values below must stay stable.
type PullState int

const (
	// PullClosed is the zero value of PullState.
	PullClosed PullState = 0
	// PullOpen is the state NewPull assigns to a freshly created pull.
	PullOpen PullState = 1
	// PullMerged marks a pull whose changes have been merged.
	PullMerged PullState = 2
)

// String returns the lowercase name of the state: "open", "merged" or
// "closed". Any value outside the declared constants also renders as
// "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	// PullClosed and every unrecognised value share the same label.
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is exactly PullClosed. Unlike String,
// it does not treat unknown values as closed.
func (p PullState) IsClosed() bool { return p == PullClosed }
48
49type Pull struct {
50 // ids
51 ID int
52 PullId int
53
54 // at ids
55 RepoAt syntax.ATURI
56 OwnerDid string
57 Rkey string
58
59 // content
60 Title string
61 Body string
62 TargetBranch string
63 State PullState
64 Submissions []*PullSubmission
65
66 // meta
67 Created time.Time
68 PullSource *PullSource
69
70 // optionally, populate this when querying for reverse mappings
71 Repo *Repo
72}
73
74type PullSource struct {
75 Branch string
76 RepoAt *syntax.ATURI
77
78 // optionally populate this for reverse mappings
79 Repo *Repo
80}
81
82type PullSubmission struct {
83 // ids
84 ID int
85 PullId int
86
87 // at ids
88 RepoAt syntax.ATURI
89
90 // content
91 RoundNumber int
92 Patch string
93 Comments []PullComment
94 SourceRev string // include the rev that was used to create this submission: only for branch PRs
95
96 // meta
97 Created time.Time
98}
99
// PullComment is a single comment left on one submission round of a pull.
type PullComment struct {
	// Identifiers: the comment's own id, the per-repository number of the
	// pull, and the id of the submission the comment belongs to.
	ID, PullId, SubmissionId int

	// AT identifiers, stored as plain strings.
	// NOTE(review): CommentAt is presumably the AT-URI of the comment
	// record itself — confirm against the writer of this field.
	RepoAt, OwnerDid, CommentAt string

	// Body is the comment text.
	Body string

	// Created is parsed from the RFC 3339 `created` column.
	Created time.Time
}
117
118func (p *Pull) LatestPatch() string {
119 latestSubmission := p.Submissions[p.LastRoundNumber()]
120 return latestSubmission.Patch
121}
122
123func (p *Pull) PullAt() syntax.ATURI {
124 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
125}
126
127func (p *Pull) LastRoundNumber() int {
128 return len(p.Submissions) - 1
129}
130
131func (p *Pull) IsPatchBased() bool {
132 return p.PullSource == nil
133}
134
135func (p *Pull) IsBranchBased() bool {
136 if p.PullSource != nil {
137 if p.PullSource.RepoAt != nil {
138 return p.PullSource.RepoAt == &p.RepoAt
139 } else {
140 // no repo specified
141 return true
142 }
143 }
144 return false
145}
146
147func (p *Pull) IsForkBased() bool {
148 if p.PullSource != nil {
149 if p.PullSource.RepoAt != nil {
150 // make sure repos are different
151 return p.PullSource.RepoAt != &p.RepoAt
152 }
153 }
154 return false
155}
156
157func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) {
158 patch := s.Patch
159
160 // if format-patch; then extract each patch
161 var diffs []*gitdiff.File
162 if patchutil.IsFormatPatch(patch) {
163 patches, err := patchutil.ExtractPatches(patch)
164 if err != nil {
165 return nil, err
166 }
167 var ps [][]*gitdiff.File
168 for _, p := range patches {
169 ps = append(ps, p.Files)
170 }
171
172 diffs = patchutil.CombineDiff(ps...)
173 } else {
174 d, _, err := gitdiff.Parse(strings.NewReader(patch))
175 if err != nil {
176 return nil, err
177 }
178 diffs = d
179 }
180
181 return diffs, nil
182}
183
184func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
185 diffs, err := s.AsDiff(targetBranch)
186 if err != nil {
187 log.Println(err)
188 }
189
190 nd := types.NiceDiff{}
191 nd.Commit.Parent = targetBranch
192
193 for _, d := range diffs {
194 ndiff := types.Diff{}
195 ndiff.Name.New = d.NewName
196 ndiff.Name.Old = d.OldName
197 ndiff.IsBinary = d.IsBinary
198 ndiff.IsNew = d.IsNew
199 ndiff.IsDelete = d.IsDelete
200 ndiff.IsCopy = d.IsCopy
201 ndiff.IsRename = d.IsRename
202
203 for _, tf := range d.TextFragments {
204 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
205 for _, l := range tf.Lines {
206 switch l.Op {
207 case gitdiff.OpAdd:
208 nd.Stat.Insertions += 1
209 case gitdiff.OpDelete:
210 nd.Stat.Deletions += 1
211 }
212 }
213 }
214
215 nd.Diff = append(nd.Diff, ndiff)
216 }
217
218 nd.Stat.FilesChanged = len(diffs)
219
220 return nd
221}
222
223func (s PullSubmission) IsFormatPatch() bool {
224 return patchutil.IsFormatPatch(s.Patch)
225}
226
227func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
228 patches, err := patchutil.ExtractPatches(s.Patch)
229 if err != nil {
230 log.Println("error extracting patches from submission:", err)
231 return []patchutil.FormatPatch{}
232 }
233
234 return patches
235}
236
237func NewPull(tx *sql.Tx, pull *Pull) error {
238 defer tx.Rollback()
239
240 _, err := tx.Exec(`
241 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
242 values (?, 1)
243 `, pull.RepoAt)
244 if err != nil {
245 return err
246 }
247
248 var nextId int
249 err = tx.QueryRow(`
250 update repo_pull_seqs
251 set next_pull_id = next_pull_id + 1
252 where repo_at = ?
253 returning next_pull_id - 1
254 `, pull.RepoAt).Scan(&nextId)
255 if err != nil {
256 return err
257 }
258
259 pull.PullId = nextId
260 pull.State = PullOpen
261
262 var sourceBranch, sourceRepoAt *string
263 if pull.PullSource != nil {
264 sourceBranch = &pull.PullSource.Branch
265 if pull.PullSource.RepoAt != nil {
266 x := pull.PullSource.RepoAt.String()
267 sourceRepoAt = &x
268 }
269 }
270
271 _, err = tx.Exec(
272 `
273 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
274 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
275 pull.RepoAt,
276 pull.OwnerDid,
277 pull.PullId,
278 pull.Title,
279 pull.TargetBranch,
280 pull.Body,
281 pull.Rkey,
282 pull.State,
283 sourceBranch,
284 sourceRepoAt,
285 )
286 if err != nil {
287 return err
288 }
289
290 _, err = tx.Exec(`
291 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
292 values (?, ?, ?, ?, ?)
293 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
294 if err != nil {
295 return err
296 }
297
298 if err := tx.Commit(); err != nil {
299 return err
300 }
301
302 return nil
303}
304
305func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
306 pull, err := GetPull(e, repoAt, pullId)
307 if err != nil {
308 return "", err
309 }
310 return pull.PullAt(), err
311}
312
313func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
314 var pullId int
315 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
316 return pullId - 1, err
317}
318
319func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
320 pulls := make(map[int]*Pull)
321
322 rows, err := e.Query(`
323 select
324 owner_did,
325 pull_id,
326 created,
327 title,
328 state,
329 target_branch,
330 body,
331 rkey,
332 source_branch,
333 source_repo_at
334 from
335 pulls
336 where
337 repo_at = ? and state = ?`, repoAt, state)
338 if err != nil {
339 return nil, err
340 }
341 defer rows.Close()
342
343 for rows.Next() {
344 var pull Pull
345 var createdAt string
346 var sourceBranch, sourceRepoAt sql.NullString
347 err := rows.Scan(
348 &pull.OwnerDid,
349 &pull.PullId,
350 &createdAt,
351 &pull.Title,
352 &pull.State,
353 &pull.TargetBranch,
354 &pull.Body,
355 &pull.Rkey,
356 &sourceBranch,
357 &sourceRepoAt,
358 )
359 if err != nil {
360 return nil, err
361 }
362
363 createdTime, err := time.Parse(time.RFC3339, createdAt)
364 if err != nil {
365 return nil, err
366 }
367 pull.Created = createdTime
368
369 if sourceBranch.Valid {
370 pull.PullSource = &PullSource{
371 Branch: sourceBranch.String,
372 }
373 if sourceRepoAt.Valid {
374 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
375 if err != nil {
376 return nil, err
377 }
378 pull.PullSource.RepoAt = &sourceRepoAtParsed
379 }
380 }
381
382 pulls[pull.PullId] = &pull
383 }
384
385 // get latest round no. for each pull
386 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
387 submissionsQuery := fmt.Sprintf(`
388 select
389 id, pull_id, round_number
390 from
391 pull_submissions
392 where
393 repo_at = ? and pull_id in (%s)
394 `, inClause)
395
396 args := make([]any, len(pulls)+1)
397 args[0] = repoAt.String()
398 idx := 1
399 for _, p := range pulls {
400 args[idx] = p.PullId
401 idx += 1
402 }
403 submissionsRows, err := e.Query(submissionsQuery, args...)
404 if err != nil {
405 return nil, err
406 }
407 defer submissionsRows.Close()
408
409 for submissionsRows.Next() {
410 var s PullSubmission
411 err := submissionsRows.Scan(
412 &s.ID,
413 &s.PullId,
414 &s.RoundNumber,
415 )
416 if err != nil {
417 return nil, err
418 }
419
420 if p, ok := pulls[s.PullId]; ok {
421 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
422 p.Submissions[s.RoundNumber] = &s
423 }
424 }
425 if err := rows.Err(); err != nil {
426 return nil, err
427 }
428
429 // get comment count on latest submission on each pull
430 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
431 commentsQuery := fmt.Sprintf(`
432 select
433 count(id), pull_id
434 from
435 pull_comments
436 where
437 submission_id in (%s)
438 group by
439 submission_id
440 `, inClause)
441
442 args = []any{}
443 for _, p := range pulls {
444 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
445 }
446 commentsRows, err := e.Query(commentsQuery, args...)
447 if err != nil {
448 return nil, err
449 }
450 defer commentsRows.Close()
451
452 for commentsRows.Next() {
453 var commentCount, pullId int
454 err := commentsRows.Scan(
455 &commentCount,
456 &pullId,
457 )
458 if err != nil {
459 return nil, err
460 }
461 if p, ok := pulls[pullId]; ok {
462 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
463 }
464 }
465 if err := rows.Err(); err != nil {
466 return nil, err
467 }
468
469 orderedByDate := []*Pull{}
470 for _, p := range pulls {
471 orderedByDate = append(orderedByDate, p)
472 }
473 sort.Slice(orderedByDate, func(i, j int) bool {
474 return orderedByDate[i].Created.After(orderedByDate[j].Created)
475 })
476
477 return orderedByDate, nil
478}
479
480func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
481 query := `
482 select
483 owner_did,
484 pull_id,
485 created,
486 title,
487 state,
488 target_branch,
489 repo_at,
490 body,
491 rkey,
492 source_branch,
493 source_repo_at
494 from
495 pulls
496 where
497 repo_at = ? and pull_id = ?
498 `
499 row := e.QueryRow(query, repoAt, pullId)
500
501 var pull Pull
502 var createdAt string
503 var sourceBranch, sourceRepoAt sql.NullString
504 err := row.Scan(
505 &pull.OwnerDid,
506 &pull.PullId,
507 &createdAt,
508 &pull.Title,
509 &pull.State,
510 &pull.TargetBranch,
511 &pull.RepoAt,
512 &pull.Body,
513 &pull.Rkey,
514 &sourceBranch,
515 &sourceRepoAt,
516 )
517 if err != nil {
518 return nil, err
519 }
520
521 createdTime, err := time.Parse(time.RFC3339, createdAt)
522 if err != nil {
523 return nil, err
524 }
525 pull.Created = createdTime
526
527 // populate source
528 if sourceBranch.Valid {
529 pull.PullSource = &PullSource{
530 Branch: sourceBranch.String,
531 }
532 if sourceRepoAt.Valid {
533 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
534 if err != nil {
535 return nil, err
536 }
537 pull.PullSource.RepoAt = &sourceRepoAtParsed
538 }
539 }
540
541 submissionsQuery := `
542 select
543 id, pull_id, repo_at, round_number, patch, created, source_rev
544 from
545 pull_submissions
546 where
547 repo_at = ? and pull_id = ?
548 `
549 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
550 if err != nil {
551 return nil, err
552 }
553 defer submissionsRows.Close()
554
555 submissionsMap := make(map[int]*PullSubmission)
556
557 for submissionsRows.Next() {
558 var submission PullSubmission
559 var submissionCreatedStr string
560 var submissionSourceRev sql.NullString
561 err := submissionsRows.Scan(
562 &submission.ID,
563 &submission.PullId,
564 &submission.RepoAt,
565 &submission.RoundNumber,
566 &submission.Patch,
567 &submissionCreatedStr,
568 &submissionSourceRev,
569 )
570 if err != nil {
571 return nil, err
572 }
573
574 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
575 if err != nil {
576 return nil, err
577 }
578 submission.Created = submissionCreatedTime
579
580 if submissionSourceRev.Valid {
581 submission.SourceRev = submissionSourceRev.String
582 }
583
584 submissionsMap[submission.ID] = &submission
585 }
586 if err = submissionsRows.Close(); err != nil {
587 return nil, err
588 }
589 if len(submissionsMap) == 0 {
590 return &pull, nil
591 }
592
593 var args []any
594 for k := range submissionsMap {
595 args = append(args, k)
596 }
597 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
598 commentsQuery := fmt.Sprintf(`
599 select
600 id,
601 pull_id,
602 submission_id,
603 repo_at,
604 owner_did,
605 comment_at,
606 body,
607 created
608 from
609 pull_comments
610 where
611 submission_id IN (%s)
612 order by
613 created asc
614 `, inClause)
615 commentsRows, err := e.Query(commentsQuery, args...)
616 if err != nil {
617 return nil, err
618 }
619 defer commentsRows.Close()
620
621 for commentsRows.Next() {
622 var comment PullComment
623 var commentCreatedStr string
624 err := commentsRows.Scan(
625 &comment.ID,
626 &comment.PullId,
627 &comment.SubmissionId,
628 &comment.RepoAt,
629 &comment.OwnerDid,
630 &comment.CommentAt,
631 &comment.Body,
632 &commentCreatedStr,
633 )
634 if err != nil {
635 return nil, err
636 }
637
638 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
639 if err != nil {
640 return nil, err
641 }
642 comment.Created = commentCreatedTime
643
644 // Add the comment to its submission
645 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
646 submission.Comments = append(submission.Comments, comment)
647 }
648
649 }
650 if err = commentsRows.Err(); err != nil {
651 return nil, err
652 }
653
654 var pullSourceRepo *Repo
655 if pull.PullSource != nil {
656 if pull.PullSource.RepoAt != nil {
657 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
658 if err != nil {
659 log.Printf("failed to get repo by at uri: %v", err)
660 } else {
661 pull.PullSource.Repo = pullSourceRepo
662 }
663 }
664 }
665
666 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
667 for _, submission := range submissionsMap {
668 pull.Submissions[submission.RoundNumber] = submission
669 }
670
671 return &pull, nil
672}
673
674// timeframe here is directly passed into the sql query filter, and any
675// timeframe in the past should be negative; e.g.: "-3 months"
676func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
677 var pulls []Pull
678
679 rows, err := e.Query(`
680 select
681 p.owner_did,
682 p.repo_at,
683 p.pull_id,
684 p.created,
685 p.title,
686 p.state,
687 r.did,
688 r.name,
689 r.knot,
690 r.rkey,
691 r.created
692 from
693 pulls p
694 join
695 repos r on p.repo_at = r.at_uri
696 where
697 p.owner_did = ? and p.created >= date ('now', ?)
698 order by
699 p.created desc`, did, timeframe)
700 if err != nil {
701 return nil, err
702 }
703 defer rows.Close()
704
705 for rows.Next() {
706 var pull Pull
707 var repo Repo
708 var pullCreatedAt, repoCreatedAt string
709 err := rows.Scan(
710 &pull.OwnerDid,
711 &pull.RepoAt,
712 &pull.PullId,
713 &pullCreatedAt,
714 &pull.Title,
715 &pull.State,
716 &repo.Did,
717 &repo.Name,
718 &repo.Knot,
719 &repo.Rkey,
720 &repoCreatedAt,
721 )
722 if err != nil {
723 return nil, err
724 }
725
726 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
727 if err != nil {
728 return nil, err
729 }
730 pull.Created = pullCreatedTime
731
732 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
733 if err != nil {
734 return nil, err
735 }
736 repo.Created = repoCreatedTime
737
738 pull.Repo = &repo
739
740 pulls = append(pulls, pull)
741 }
742
743 if err := rows.Err(); err != nil {
744 return nil, err
745 }
746
747 return pulls, nil
748}
749
750func NewPullComment(e Execer, comment *PullComment) (int64, error) {
751 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
752 res, err := e.Exec(
753 query,
754 comment.OwnerDid,
755 comment.RepoAt,
756 comment.SubmissionId,
757 comment.CommentAt,
758 comment.PullId,
759 comment.Body,
760 )
761 if err != nil {
762 return 0, err
763 }
764
765 i, err := res.LastInsertId()
766 if err != nil {
767 return 0, err
768 }
769
770 return i, nil
771}
772
773func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
774 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
775 return err
776}
777
778func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
779 err := SetPullState(e, repoAt, pullId, PullClosed)
780 return err
781}
782
783func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
784 err := SetPullState(e, repoAt, pullId, PullOpen)
785 return err
786}
787
788func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
789 err := SetPullState(e, repoAt, pullId, PullMerged)
790 return err
791}
792
793func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
794 newRoundNumber := len(pull.Submissions)
795 _, err := e.Exec(`
796 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
797 values (?, ?, ?, ?, ?)
798 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
799
800 return err
801}
802
// PullCount holds the number of pulls of one repository in each state, as
// computed by GetPullCount.
type PullCount struct {
	Open, Merged, Closed int
}
808
809func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
810 row := e.QueryRow(`
811 select
812 count(case when state = ? then 1 end) as open_count,
813 count(case when state = ? then 1 end) as merged_count,
814 count(case when state = ? then 1 end) as closed_count
815 from pulls
816 where repo_at = ?`,
817 PullOpen,
818 PullMerged,
819 PullClosed,
820 repoAt,
821 )
822
823 var count PullCount
824 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
825 return PullCount{0, 0, 0}, err
826 }
827
828 return count, nil
829}