1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/types"
14)
15
// PullState is the lifecycle state of a pull request. Its integer value is
// what NewPull and SetPullState write into the pulls.state column, so the
// numbering below must never change.
type PullState int

// Pull states. PullClosed is the zero value.
const (
	PullClosed PullState = 0
	PullOpen   PullState = 1
	PullMerged PullState = 2
)

// String returns the lowercase name of the state. Any value that is not
// PullOpen or PullMerged (including unknown ones) is reported as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	default:
		// PullClosed and any unrecognised value.
		return "closed"
	}
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool {
	open := p == PullOpen
	return open
}

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool {
	merged := p == PullMerged
	return merged
}

// IsClosed reports whether the state is exactly PullClosed.
func (p PullState) IsClosed() bool {
	closed := p == PullClosed
	return closed
}
46
// Pull is a pull request opened against a repository. It mirrors a row of the
// pulls table; Submissions and Repo are filled in by the query helpers in this
// file only when they load that data.
type Pull struct {
	// ids
	// NOTE(review): ID is not populated by GetPull/GetPulls — confirm it is still used.
	ID int
	// PullId is the per-repository pull number allocated from repo_pull_seqs by NewPull.
	PullId int

	// at ids
	RepoAt   syntax.ATURI // repository the pull targets
	OwnerDid string       // owner DID; the filter key of GetPullsByOwnerDid
	Rkey     string
	PullAt   syntax.ATURI // stored separately after creation via SetPullAt

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	// Submissions holds the submission rounds indexed by round number, so the
	// last element is the latest round (see LastRoundNumber).
	Submissions []*PullSubmission

	// meta
	Created time.Time
	// PullSource is nil for patch-based pulls (see IsPatchBased); otherwise it
	// names the source branch and, optionally, the source repository.
	PullSource *PullSource

	// optionally, populate this when querying for reverse mappings
	// (GetPullsByOwnerDid fills it with the target repository).
	Repo *Repo
}
72
// PullSource describes where a non-patch pull's changes come from: a branch,
// and optionally the repository holding that branch. A nil RepoAt means the
// branch lives in the pull's own target repository.
type PullSource struct {
	Branch string
	// RepoAt is the source repository; nil when it is the target repository.
	RepoAt *syntax.ATURI

	// optionally populate this for reverse mappings
	Repo *Repo
}
80
// PullSubmission is one round of a pull request: the patch submitted in that
// round plus the comments left on it. NewPull writes round 0 and each
// ResubmitPull appends round len(Submissions).
type PullSubmission struct {
	// ids
	ID     int // pull_submissions row id; comments reference it as SubmissionId
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int
	Patch       string
	Comments    []PullComment
	SourceRev   string // source revision this round was created from; only set for branch-based pulls

	// meta
	Created time.Time
}
98
// PullComment is a comment left on one submission round of a pull request.
type PullComment struct {
	// ids
	ID           int
	PullId       int
	SubmissionId int // ID of the PullSubmission this comment belongs to

	// at ids
	// Unlike Pull, these AT identifiers are kept as plain strings.
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
116
117func (p *Pull) LatestPatch() string {
118 latestSubmission := p.Submissions[p.LastRoundNumber()]
119 return latestSubmission.Patch
120}
121
122func (p *Pull) LastRoundNumber() int {
123 return len(p.Submissions) - 1
124}
125
126func (p *Pull) IsPatchBased() bool {
127 return p.PullSource == nil
128}
129
130func (p *Pull) IsBranchBased() bool {
131 if p.PullSource != nil {
132 if p.PullSource.RepoAt != nil {
133 return p.PullSource.RepoAt == &p.RepoAt
134 } else {
135 // no repo specified
136 return true
137 }
138 }
139 return false
140}
141
142func (p *Pull) IsForkBased() bool {
143 if p.PullSource != nil {
144 if p.PullSource.RepoAt != nil {
145 // make sure repos are different
146 return p.PullSource.RepoAt != &p.RepoAt
147 }
148 }
149 return false
150}
151
152func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
153 patch := s.Patch
154
155 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
156 if err != nil {
157 log.Println(err)
158 }
159
160 nd := types.NiceDiff{}
161 nd.Commit.Parent = targetBranch
162
163 for _, d := range diffs {
164 ndiff := types.Diff{}
165 ndiff.Name.New = d.NewName
166 ndiff.Name.Old = d.OldName
167 ndiff.IsBinary = d.IsBinary
168 ndiff.IsNew = d.IsNew
169 ndiff.IsDelete = d.IsDelete
170 ndiff.IsCopy = d.IsCopy
171 ndiff.IsRename = d.IsRename
172
173 for _, tf := range d.TextFragments {
174 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
175 for _, l := range tf.Lines {
176 switch l.Op {
177 case gitdiff.OpAdd:
178 nd.Stat.Insertions += 1
179 case gitdiff.OpDelete:
180 nd.Stat.Deletions += 1
181 }
182 }
183 }
184
185 nd.Diff = append(nd.Diff, ndiff)
186 }
187
188 nd.Stat.FilesChanged = len(diffs)
189
190 return nd
191}
192
193func NewPull(tx *sql.Tx, pull *Pull) error {
194 defer tx.Rollback()
195
196 _, err := tx.Exec(`
197 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
198 values (?, 1)
199 `, pull.RepoAt)
200 if err != nil {
201 return err
202 }
203
204 var nextId int
205 err = tx.QueryRow(`
206 update repo_pull_seqs
207 set next_pull_id = next_pull_id + 1
208 where repo_at = ?
209 returning next_pull_id - 1
210 `, pull.RepoAt).Scan(&nextId)
211 if err != nil {
212 return err
213 }
214
215 pull.PullId = nextId
216 pull.State = PullOpen
217
218 var sourceBranch, sourceRepoAt *string
219 if pull.PullSource != nil {
220 sourceBranch = &pull.PullSource.Branch
221 if pull.PullSource.RepoAt != nil {
222 x := pull.PullSource.RepoAt.String()
223 sourceRepoAt = &x
224 }
225 }
226
227 _, err = tx.Exec(
228 `
229 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
230 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
231 pull.RepoAt,
232 pull.OwnerDid,
233 pull.PullId,
234 pull.Title,
235 pull.TargetBranch,
236 pull.Body,
237 pull.Rkey,
238 pull.State,
239 sourceBranch,
240 sourceRepoAt,
241 )
242 if err != nil {
243 return err
244 }
245
246 _, err = tx.Exec(`
247 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
248 values (?, ?, ?, ?, ?)
249 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
250 if err != nil {
251 return err
252 }
253
254 if err := tx.Commit(); err != nil {
255 return err
256 }
257
258 return nil
259}
260
261func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
262 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
263 return err
264}
265
266func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
267 var pullAt string
268 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
269 return pullAt, err
270}
271
272func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
273 var pullId int
274 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
275 return pullId - 1, err
276}
277
278func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
279 pulls := make(map[int]*Pull)
280
281 rows, err := e.Query(`
282 select
283 owner_did,
284 pull_id,
285 created,
286 title,
287 state,
288 target_branch,
289 pull_at,
290 body,
291 rkey,
292 source_branch,
293 source_repo_at
294 from
295 pulls
296 where
297 repo_at = ? and state = ?`, repoAt, state)
298 if err != nil {
299 return nil, err
300 }
301 defer rows.Close()
302
303 for rows.Next() {
304 var pull Pull
305 var createdAt string
306 var sourceBranch, sourceRepoAt sql.NullString
307 err := rows.Scan(
308 &pull.OwnerDid,
309 &pull.PullId,
310 &createdAt,
311 &pull.Title,
312 &pull.State,
313 &pull.TargetBranch,
314 &pull.PullAt,
315 &pull.Body,
316 &pull.Rkey,
317 &sourceBranch,
318 &sourceRepoAt,
319 )
320 if err != nil {
321 return nil, err
322 }
323
324 createdTime, err := time.Parse(time.RFC3339, createdAt)
325 if err != nil {
326 return nil, err
327 }
328 pull.Created = createdTime
329
330 if sourceBranch.Valid {
331 pull.PullSource = &PullSource{
332 Branch: sourceBranch.String,
333 }
334 if sourceRepoAt.Valid {
335 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
336 if err != nil {
337 return nil, err
338 }
339 pull.PullSource.RepoAt = &sourceRepoAtParsed
340 }
341 }
342
343 pulls[pull.PullId] = &pull
344 }
345
346 // get latest round no. for each pull
347 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
348 submissionsQuery := fmt.Sprintf(`
349 select
350 id, pull_id, round_number
351 from
352 pull_submissions
353 where
354 repo_at = ? and pull_id in (%s)
355 `, inClause)
356
357 args := make([]any, len(pulls)+1)
358 args[0] = repoAt.String()
359 idx := 1
360 for _, p := range pulls {
361 args[idx] = p.PullId
362 idx += 1
363 }
364 submissionsRows, err := e.Query(submissionsQuery, args...)
365 if err != nil {
366 return nil, err
367 }
368 defer submissionsRows.Close()
369
370 for submissionsRows.Next() {
371 var s PullSubmission
372 err := submissionsRows.Scan(
373 &s.ID,
374 &s.PullId,
375 &s.RoundNumber,
376 )
377 if err != nil {
378 return nil, err
379 }
380
381 if p, ok := pulls[s.PullId]; ok {
382 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
383 p.Submissions[s.RoundNumber] = &s
384 }
385 }
386 if err := rows.Err(); err != nil {
387 return nil, err
388 }
389
390 // get comment count on latest submission on each pull
391 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
392 commentsQuery := fmt.Sprintf(`
393 select
394 count(id), pull_id
395 from
396 pull_comments
397 where
398 submission_id in (%s)
399 group by
400 submission_id
401 `, inClause)
402
403 args = make([]any, len(pulls))
404 idx = 0
405 for _, p := range pulls {
406 args[idx] = p.Submissions[p.LastRoundNumber()].ID
407 idx += 1
408 }
409 commentsRows, err := e.Query(commentsQuery, args...)
410 if err != nil {
411 return nil, err
412 }
413 defer commentsRows.Close()
414
415 for commentsRows.Next() {
416 var commentCount, pullId int
417 err := commentsRows.Scan(
418 &commentCount,
419 &pullId,
420 )
421 if err != nil {
422 return nil, err
423 }
424 if p, ok := pulls[pullId]; ok {
425 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
426 }
427 }
428 if err := rows.Err(); err != nil {
429 return nil, err
430 }
431
432 orderedByDate := make([]*Pull, len(pulls))
433 idx = 0
434 for _, p := range pulls {
435 orderedByDate[idx] = p
436 idx += 1
437 }
438 sort.Slice(orderedByDate, func(i, j int) bool {
439 return orderedByDate[i].Created.After(orderedByDate[j].Created)
440 })
441
442 return orderedByDate, nil
443}
444
445func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
446 query := `
447 select
448 owner_did,
449 pull_id,
450 created,
451 title,
452 state,
453 target_branch,
454 pull_at,
455 repo_at,
456 body,
457 rkey,
458 source_branch,
459 source_repo_at
460 from
461 pulls
462 where
463 repo_at = ? and pull_id = ?
464 `
465 row := e.QueryRow(query, repoAt, pullId)
466
467 var pull Pull
468 var createdAt string
469 var sourceBranch, sourceRepoAt sql.NullString
470 err := row.Scan(
471 &pull.OwnerDid,
472 &pull.PullId,
473 &createdAt,
474 &pull.Title,
475 &pull.State,
476 &pull.TargetBranch,
477 &pull.PullAt,
478 &pull.RepoAt,
479 &pull.Body,
480 &pull.Rkey,
481 &sourceBranch,
482 &sourceRepoAt,
483 )
484 if err != nil {
485 return nil, err
486 }
487
488 createdTime, err := time.Parse(time.RFC3339, createdAt)
489 if err != nil {
490 return nil, err
491 }
492 pull.Created = createdTime
493
494 // populate source
495 if sourceBranch.Valid {
496 pull.PullSource = &PullSource{
497 Branch: sourceBranch.String,
498 }
499 if sourceRepoAt.Valid {
500 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
501 if err != nil {
502 return nil, err
503 }
504 pull.PullSource.RepoAt = &sourceRepoAtParsed
505 }
506 }
507
508 submissionsQuery := `
509 select
510 id, pull_id, repo_at, round_number, patch, created, source_rev
511 from
512 pull_submissions
513 where
514 repo_at = ? and pull_id = ?
515 `
516 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
517 if err != nil {
518 return nil, err
519 }
520 defer submissionsRows.Close()
521
522 submissionsMap := make(map[int]*PullSubmission)
523
524 for submissionsRows.Next() {
525 var submission PullSubmission
526 var submissionCreatedStr string
527 var submissionSourceRev sql.NullString
528 err := submissionsRows.Scan(
529 &submission.ID,
530 &submission.PullId,
531 &submission.RepoAt,
532 &submission.RoundNumber,
533 &submission.Patch,
534 &submissionCreatedStr,
535 &submissionSourceRev,
536 )
537 if err != nil {
538 return nil, err
539 }
540
541 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
542 if err != nil {
543 return nil, err
544 }
545 submission.Created = submissionCreatedTime
546
547 if submissionSourceRev.Valid {
548 submission.SourceRev = submissionSourceRev.String
549 }
550
551 submissionsMap[submission.ID] = &submission
552 }
553 if err = submissionsRows.Close(); err != nil {
554 return nil, err
555 }
556 if len(submissionsMap) == 0 {
557 return &pull, nil
558 }
559
560 var args []any
561 for k := range submissionsMap {
562 args = append(args, k)
563 }
564 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
565 commentsQuery := fmt.Sprintf(`
566 select
567 id,
568 pull_id,
569 submission_id,
570 repo_at,
571 owner_did,
572 comment_at,
573 body,
574 created
575 from
576 pull_comments
577 where
578 submission_id IN (%s)
579 order by
580 created asc
581 `, inClause)
582 commentsRows, err := e.Query(commentsQuery, args...)
583 if err != nil {
584 return nil, err
585 }
586 defer commentsRows.Close()
587
588 for commentsRows.Next() {
589 var comment PullComment
590 var commentCreatedStr string
591 err := commentsRows.Scan(
592 &comment.ID,
593 &comment.PullId,
594 &comment.SubmissionId,
595 &comment.RepoAt,
596 &comment.OwnerDid,
597 &comment.CommentAt,
598 &comment.Body,
599 &commentCreatedStr,
600 )
601 if err != nil {
602 return nil, err
603 }
604
605 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
606 if err != nil {
607 return nil, err
608 }
609 comment.Created = commentCreatedTime
610
611 // Add the comment to its submission
612 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
613 submission.Comments = append(submission.Comments, comment)
614 }
615
616 }
617 if err = commentsRows.Err(); err != nil {
618 return nil, err
619 }
620
621 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
622 for _, submission := range submissionsMap {
623 pull.Submissions[submission.RoundNumber] = submission
624 }
625
626 return &pull, nil
627}
628
629// timeframe here is directly passed into the sql query filter, and any
630// timeframe in the past should be negative; e.g.: "-3 months"
631func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
632 var pulls []Pull
633
634 rows, err := e.Query(`
635 select
636 p.owner_did,
637 p.repo_at,
638 p.pull_id,
639 p.created,
640 p.title,
641 p.state,
642 r.did,
643 r.name,
644 r.knot,
645 r.rkey,
646 r.created
647 from
648 pulls p
649 join
650 repos r on p.repo_at = r.at_uri
651 where
652 p.owner_did = ? and p.created >= date ('now', ?)
653 order by
654 p.created desc`, did, timeframe)
655 if err != nil {
656 return nil, err
657 }
658 defer rows.Close()
659
660 for rows.Next() {
661 var pull Pull
662 var repo Repo
663 var pullCreatedAt, repoCreatedAt string
664 err := rows.Scan(
665 &pull.OwnerDid,
666 &pull.RepoAt,
667 &pull.PullId,
668 &pullCreatedAt,
669 &pull.Title,
670 &pull.State,
671 &repo.Did,
672 &repo.Name,
673 &repo.Knot,
674 &repo.Rkey,
675 &repoCreatedAt,
676 )
677 if err != nil {
678 return nil, err
679 }
680
681 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
682 if err != nil {
683 return nil, err
684 }
685 pull.Created = pullCreatedTime
686
687 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
688 if err != nil {
689 return nil, err
690 }
691 repo.Created = repoCreatedTime
692
693 pull.Repo = &repo
694
695 pulls = append(pulls, pull)
696 }
697
698 if err := rows.Err(); err != nil {
699 return nil, err
700 }
701
702 return pulls, nil
703}
704
705func NewPullComment(e Execer, comment *PullComment) (int64, error) {
706 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
707 res, err := e.Exec(
708 query,
709 comment.OwnerDid,
710 comment.RepoAt,
711 comment.SubmissionId,
712 comment.CommentAt,
713 comment.PullId,
714 comment.Body,
715 )
716 if err != nil {
717 return 0, err
718 }
719
720 i, err := res.LastInsertId()
721 if err != nil {
722 return 0, err
723 }
724
725 return i, nil
726}
727
728func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
729 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
730 return err
731}
732
733func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
734 err := SetPullState(e, repoAt, pullId, PullClosed)
735 return err
736}
737
738func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
739 err := SetPullState(e, repoAt, pullId, PullOpen)
740 return err
741}
742
743func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
744 err := SetPullState(e, repoAt, pullId, PullMerged)
745 return err
746}
747
748func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
749 newRoundNumber := len(pull.Submissions)
750 _, err := e.Exec(`
751 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
752 values (?, ?, ?, ?, ?)
753 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
754
755 return err
756}
757
// PullCount holds the number of pulls in each state for one repository, as
// computed by GetPullCount.
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
763
764func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
765 row := e.QueryRow(`
766 select
767 count(case when state = ? then 1 end) as open_count,
768 count(case when state = ? then 1 end) as merged_count,
769 count(case when state = ? then 1 end) as closed_count
770 from pulls
771 where repo_at = ?`,
772 PullOpen,
773 PullMerged,
774 PullClosed,
775 repoAt,
776 )
777
778 var count PullCount
779 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
780 return PullCount{0, 0, 0}, err
781 }
782
783 return count, nil
784}