1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/types"
14)
15
// PullState is the lifecycle state of a pull request. Its integer value is
// what NewPull writes into, and the queries below match against, the
// pulls.state column, so the existing numbering must stay stable.
type PullState int

const (
	PullClosed PullState = 0
	PullOpen   PullState = 1
	PullMerged PullState = 2
)

// String returns the lowercase name of the state. Any value outside the
// known set is reported as "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether p is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether p is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether p is PullClosed; merged pulls are not closed.
func (p PullState) IsClosed() bool { return p == PullClosed }
46
// Pull is a pull request as stored in the pulls table, optionally with its
// submissions, comments, source and repository attached by the loaders in
// this file.
type Pull struct {
	// ids
	// ID is not selected by any query in this file, so GetPull, GetPulls and
	// GetPullsByOwnerDid leave it zero.
	ID int
	// PullId is the per-repository pull number, allocated from
	// repo_pull_seqs by NewPull.
	PullId int

	// at ids
	// RepoAt is the AT URI of the repository the pull is filed against.
	RepoAt   syntax.ATURI
	OwnerDid string
	// Rkey is stored verbatim; presumably the record key of the pull record.
	Rkey string
	// PullAt is not written by NewPull; it is set afterwards via SetPullAt.
	PullAt syntax.ATURI

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	// Submissions is indexed by round number; GetPulls only fills in the
	// latest entry.
	Submissions []*PullSubmission

	// meta
	Created time.Time
	// PullSource is nil for patch-based pulls and set for branch- and
	// fork-based ones.
	PullSource *PullSource

	// optionally, populate this when querying for reverse mappings
	// (within this file only GetPullsByOwnerDid fills it in)
	Repo *Repo
}
72
// PullSource describes where the changes of a branch- or fork-based pull
// come from.
type PullSource struct {
	// Branch is the name of the source branch.
	Branch string
	// RepoAt is the source repository. A nil value means no source repository
	// was recorded, which IsBranchBased treats as a branch-based pull.
	RepoAt *syntax.ATURI

	// optionally populate this for reverse mappings
	Repo *Repo
}
80
// PullSubmission is one round of a pull: the patch submitted in that round
// and the comments left on it.
type PullSubmission struct {
	// ids
	// ID is the pull_submissions row id; comments reference it via
	// PullComment.SubmissionId.
	ID int
	// PullId is the per-repository pull number (Pull.PullId), not Pull.ID.
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content
	// RoundNumber is 0 for the submission created by NewPull and
	// len(pull.Submissions) for each one added by ResubmitPull.
	RoundNumber int
	Patch       string
	Comments    []PullComment
	SourceRev   string // include the rev that was used to create this submission: only for branch PRs

	// meta
	Created time.Time
}
98
// PullComment is a comment attached to a single submission round of a pull.
type PullComment struct {
	// ids
	ID     int
	PullId int
	// SubmissionId is the ID of the PullSubmission this comment belongs to.
	SubmissionId int

	// at ids
	// RepoAt is a plain string here, unlike the syntax.ATURI used by Pull and
	// PullSubmission.
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
116
117func (p *Pull) LatestPatch() string {
118 latestSubmission := p.Submissions[p.LastRoundNumber()]
119 return latestSubmission.Patch
120}
121
122func (p *Pull) LastRoundNumber() int {
123 return len(p.Submissions) - 1
124}
125
126func (p *Pull) IsPatchBased() bool {
127 return p.PullSource == nil
128}
129
130func (p *Pull) IsBranchBased() bool {
131 if p.PullSource != nil {
132 if p.PullSource.RepoAt != nil {
133 return p.PullSource.RepoAt == &p.RepoAt
134 } else {
135 // no repo specified
136 return true
137 }
138 }
139 return false
140}
141
142func (p *Pull) IsForkBased() bool {
143 if p.PullSource != nil {
144 if p.PullSource.RepoAt != nil {
145 // make sure repos are different
146 return p.PullSource.RepoAt != &p.RepoAt
147 }
148 }
149 return false
150}
151
152func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
153 patch := s.Patch
154
155 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
156 if err != nil {
157 log.Println(err)
158 }
159
160 nd := types.NiceDiff{}
161 nd.Commit.Parent = targetBranch
162
163 for _, d := range diffs {
164 ndiff := types.Diff{}
165 ndiff.Name.New = d.NewName
166 ndiff.Name.Old = d.OldName
167 ndiff.IsBinary = d.IsBinary
168 ndiff.IsNew = d.IsNew
169 ndiff.IsDelete = d.IsDelete
170 ndiff.IsCopy = d.IsCopy
171 ndiff.IsRename = d.IsRename
172
173 for _, tf := range d.TextFragments {
174 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
175 for _, l := range tf.Lines {
176 switch l.Op {
177 case gitdiff.OpAdd:
178 nd.Stat.Insertions += 1
179 case gitdiff.OpDelete:
180 nd.Stat.Deletions += 1
181 }
182 }
183 }
184
185 nd.Diff = append(nd.Diff, ndiff)
186 }
187
188 nd.Stat.FilesChanged = len(diffs)
189
190 return nd
191}
192
193func NewPull(tx *sql.Tx, pull *Pull) error {
194 defer tx.Rollback()
195
196 _, err := tx.Exec(`
197 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
198 values (?, 1)
199 `, pull.RepoAt)
200 if err != nil {
201 return err
202 }
203
204 var nextId int
205 err = tx.QueryRow(`
206 update repo_pull_seqs
207 set next_pull_id = next_pull_id + 1
208 where repo_at = ?
209 returning next_pull_id - 1
210 `, pull.RepoAt).Scan(&nextId)
211 if err != nil {
212 return err
213 }
214
215 pull.PullId = nextId
216 pull.State = PullOpen
217
218 var sourceBranch, sourceRepoAt *string
219 if pull.PullSource != nil {
220 sourceBranch = &pull.PullSource.Branch
221 if pull.PullSource.RepoAt != nil {
222 x := pull.PullSource.RepoAt.String()
223 sourceRepoAt = &x
224 }
225 }
226
227 _, err = tx.Exec(
228 `
229 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
230 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
231 pull.RepoAt,
232 pull.OwnerDid,
233 pull.PullId,
234 pull.Title,
235 pull.TargetBranch,
236 pull.Body,
237 pull.Rkey,
238 pull.State,
239 sourceBranch,
240 sourceRepoAt,
241 )
242 if err != nil {
243 return err
244 }
245
246 _, err = tx.Exec(`
247 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
248 values (?, ?, ?, ?, ?)
249 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
250 if err != nil {
251 return err
252 }
253
254 if err := tx.Commit(); err != nil {
255 return err
256 }
257
258 return nil
259}
260
261func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
262 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
263 return err
264}
265
266func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
267 var pullAt string
268 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
269 return pullAt, err
270}
271
272func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
273 var pullId int
274 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
275 return pullId - 1, err
276}
277
278func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
279 pulls := make(map[int]*Pull)
280
281 rows, err := e.Query(`
282 select
283 owner_did,
284 pull_id,
285 created,
286 title,
287 state,
288 target_branch,
289 pull_at,
290 body,
291 rkey,
292 source_branch,
293 source_repo_at
294 from
295 pulls
296 where
297 repo_at = ? and state = ?`, repoAt, state)
298 if err != nil {
299 return nil, err
300 }
301 defer rows.Close()
302
303 for rows.Next() {
304 var pull Pull
305 var createdAt string
306 var sourceBranch, sourceRepoAt sql.NullString
307 err := rows.Scan(
308 &pull.OwnerDid,
309 &pull.PullId,
310 &createdAt,
311 &pull.Title,
312 &pull.State,
313 &pull.TargetBranch,
314 &pull.PullAt,
315 &pull.Body,
316 &pull.Rkey,
317 &sourceBranch,
318 &sourceRepoAt,
319 )
320 if err != nil {
321 return nil, err
322 }
323
324 createdTime, err := time.Parse(time.RFC3339, createdAt)
325 if err != nil {
326 return nil, err
327 }
328 pull.Created = createdTime
329
330 if sourceBranch.Valid {
331 pull.PullSource = &PullSource{
332 Branch: sourceBranch.String,
333 }
334 if sourceRepoAt.Valid {
335 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
336 if err != nil {
337 return nil, err
338 }
339 pull.PullSource.RepoAt = &sourceRepoAtParsed
340 }
341 }
342
343 pulls[pull.PullId] = &pull
344 }
345
346 // get latest round no. for each pull
347 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
348 submissionsQuery := fmt.Sprintf(`
349 select
350 id, pull_id, round_number
351 from
352 pull_submissions
353 where
354 repo_at = ? and pull_id in (%s)
355 `, inClause)
356
357 args := make([]any, len(pulls)+1)
358 args[0] = repoAt.String()
359 idx := 1
360 for _, p := range pulls {
361 args[idx] = p.PullId
362 idx += 1
363 }
364 submissionsRows, err := e.Query(submissionsQuery, args...)
365 if err != nil {
366 return nil, err
367 }
368 defer submissionsRows.Close()
369
370 for submissionsRows.Next() {
371 var s PullSubmission
372 err := submissionsRows.Scan(
373 &s.ID,
374 &s.PullId,
375 &s.RoundNumber,
376 )
377 if err != nil {
378 return nil, err
379 }
380
381 if p, ok := pulls[s.PullId]; ok {
382 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
383 p.Submissions[s.RoundNumber] = &s
384 }
385 }
386 if err := rows.Err(); err != nil {
387 return nil, err
388 }
389
390 // get comment count on latest submission on each pull
391 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
392 commentsQuery := fmt.Sprintf(`
393 select
394 count(id), pull_id
395 from
396 pull_comments
397 where
398 submission_id in (%s)
399 group by
400 submission_id
401 `, inClause)
402
403 args = []any{}
404 for _, p := range pulls {
405 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
406 }
407 commentsRows, err := e.Query(commentsQuery, args...)
408 if err != nil {
409 return nil, err
410 }
411 defer commentsRows.Close()
412
413 for commentsRows.Next() {
414 var commentCount, pullId int
415 err := commentsRows.Scan(
416 &commentCount,
417 &pullId,
418 )
419 if err != nil {
420 return nil, err
421 }
422 if p, ok := pulls[pullId]; ok {
423 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
424 }
425 }
426 if err := rows.Err(); err != nil {
427 return nil, err
428 }
429
430 orderedByDate := []*Pull{}
431 for _, p := range pulls {
432 orderedByDate = append(orderedByDate, p)
433 }
434 sort.Slice(orderedByDate, func(i, j int) bool {
435 return orderedByDate[i].Created.After(orderedByDate[j].Created)
436 })
437
438 return orderedByDate, nil
439}
440
441func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
442 query := `
443 select
444 owner_did,
445 pull_id,
446 created,
447 title,
448 state,
449 target_branch,
450 pull_at,
451 repo_at,
452 body,
453 rkey,
454 source_branch,
455 source_repo_at
456 from
457 pulls
458 where
459 repo_at = ? and pull_id = ?
460 `
461 row := e.QueryRow(query, repoAt, pullId)
462
463 var pull Pull
464 var createdAt string
465 var sourceBranch, sourceRepoAt sql.NullString
466 err := row.Scan(
467 &pull.OwnerDid,
468 &pull.PullId,
469 &createdAt,
470 &pull.Title,
471 &pull.State,
472 &pull.TargetBranch,
473 &pull.PullAt,
474 &pull.RepoAt,
475 &pull.Body,
476 &pull.Rkey,
477 &sourceBranch,
478 &sourceRepoAt,
479 )
480 if err != nil {
481 return nil, err
482 }
483
484 createdTime, err := time.Parse(time.RFC3339, createdAt)
485 if err != nil {
486 return nil, err
487 }
488 pull.Created = createdTime
489
490 // populate source
491 if sourceBranch.Valid {
492 pull.PullSource = &PullSource{
493 Branch: sourceBranch.String,
494 }
495 if sourceRepoAt.Valid {
496 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
497 if err != nil {
498 return nil, err
499 }
500 pull.PullSource.RepoAt = &sourceRepoAtParsed
501 }
502 }
503
504 submissionsQuery := `
505 select
506 id, pull_id, repo_at, round_number, patch, created, source_rev
507 from
508 pull_submissions
509 where
510 repo_at = ? and pull_id = ?
511 `
512 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
513 if err != nil {
514 return nil, err
515 }
516 defer submissionsRows.Close()
517
518 submissionsMap := make(map[int]*PullSubmission)
519
520 for submissionsRows.Next() {
521 var submission PullSubmission
522 var submissionCreatedStr string
523 var submissionSourceRev sql.NullString
524 err := submissionsRows.Scan(
525 &submission.ID,
526 &submission.PullId,
527 &submission.RepoAt,
528 &submission.RoundNumber,
529 &submission.Patch,
530 &submissionCreatedStr,
531 &submissionSourceRev,
532 )
533 if err != nil {
534 return nil, err
535 }
536
537 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
538 if err != nil {
539 return nil, err
540 }
541 submission.Created = submissionCreatedTime
542
543 if submissionSourceRev.Valid {
544 submission.SourceRev = submissionSourceRev.String
545 }
546
547 submissionsMap[submission.ID] = &submission
548 }
549 if err = submissionsRows.Close(); err != nil {
550 return nil, err
551 }
552 if len(submissionsMap) == 0 {
553 return &pull, nil
554 }
555
556 var args []any
557 for k := range submissionsMap {
558 args = append(args, k)
559 }
560 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
561 commentsQuery := fmt.Sprintf(`
562 select
563 id,
564 pull_id,
565 submission_id,
566 repo_at,
567 owner_did,
568 comment_at,
569 body,
570 created
571 from
572 pull_comments
573 where
574 submission_id IN (%s)
575 order by
576 created asc
577 `, inClause)
578 commentsRows, err := e.Query(commentsQuery, args...)
579 if err != nil {
580 return nil, err
581 }
582 defer commentsRows.Close()
583
584 for commentsRows.Next() {
585 var comment PullComment
586 var commentCreatedStr string
587 err := commentsRows.Scan(
588 &comment.ID,
589 &comment.PullId,
590 &comment.SubmissionId,
591 &comment.RepoAt,
592 &comment.OwnerDid,
593 &comment.CommentAt,
594 &comment.Body,
595 &commentCreatedStr,
596 )
597 if err != nil {
598 return nil, err
599 }
600
601 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
602 if err != nil {
603 return nil, err
604 }
605 comment.Created = commentCreatedTime
606
607 // Add the comment to its submission
608 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
609 submission.Comments = append(submission.Comments, comment)
610 }
611
612 }
613 if err = commentsRows.Err(); err != nil {
614 return nil, err
615 }
616
617 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
618 for _, submission := range submissionsMap {
619 pull.Submissions[submission.RoundNumber] = submission
620 }
621
622 return &pull, nil
623}
624
625// timeframe here is directly passed into the sql query filter, and any
626// timeframe in the past should be negative; e.g.: "-3 months"
627func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
628 var pulls []Pull
629
630 rows, err := e.Query(`
631 select
632 p.owner_did,
633 p.repo_at,
634 p.pull_id,
635 p.created,
636 p.title,
637 p.state,
638 r.did,
639 r.name,
640 r.knot,
641 r.rkey,
642 r.created
643 from
644 pulls p
645 join
646 repos r on p.repo_at = r.at_uri
647 where
648 p.owner_did = ? and p.created >= date ('now', ?)
649 order by
650 p.created desc`, did, timeframe)
651 if err != nil {
652 return nil, err
653 }
654 defer rows.Close()
655
656 for rows.Next() {
657 var pull Pull
658 var repo Repo
659 var pullCreatedAt, repoCreatedAt string
660 err := rows.Scan(
661 &pull.OwnerDid,
662 &pull.RepoAt,
663 &pull.PullId,
664 &pullCreatedAt,
665 &pull.Title,
666 &pull.State,
667 &repo.Did,
668 &repo.Name,
669 &repo.Knot,
670 &repo.Rkey,
671 &repoCreatedAt,
672 )
673 if err != nil {
674 return nil, err
675 }
676
677 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
678 if err != nil {
679 return nil, err
680 }
681 pull.Created = pullCreatedTime
682
683 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
684 if err != nil {
685 return nil, err
686 }
687 repo.Created = repoCreatedTime
688
689 pull.Repo = &repo
690
691 pulls = append(pulls, pull)
692 }
693
694 if err := rows.Err(); err != nil {
695 return nil, err
696 }
697
698 return pulls, nil
699}
700
701func NewPullComment(e Execer, comment *PullComment) (int64, error) {
702 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
703 res, err := e.Exec(
704 query,
705 comment.OwnerDid,
706 comment.RepoAt,
707 comment.SubmissionId,
708 comment.CommentAt,
709 comment.PullId,
710 comment.Body,
711 )
712 if err != nil {
713 return 0, err
714 }
715
716 i, err := res.LastInsertId()
717 if err != nil {
718 return 0, err
719 }
720
721 return i, nil
722}
723
724func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
725 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
726 return err
727}
728
729func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
730 err := SetPullState(e, repoAt, pullId, PullClosed)
731 return err
732}
733
734func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
735 err := SetPullState(e, repoAt, pullId, PullOpen)
736 return err
737}
738
739func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
740 err := SetPullState(e, repoAt, pullId, PullMerged)
741 return err
742}
743
744func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
745 newRoundNumber := len(pull.Submissions)
746 _, err := e.Exec(`
747 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
748 values (?, ?, ?, ?, ?)
749 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
750
751 return err
752}
753
// PullCount holds the number of pulls in a repository per state, as computed
// by GetPullCount.
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
759
760func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
761 row := e.QueryRow(`
762 select
763 count(case when state = ? then 1 end) as open_count,
764 count(case when state = ? then 1 end) as merged_count,
765 count(case when state = ? then 1 end) as closed_count
766 from pulls
767 where repo_at = ?`,
768 PullOpen,
769 PullMerged,
770 PullClosed,
771 repoAt,
772 )
773
774 var count PullCount
775 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
776 return PullCount{0, 0, 0}, err
777 }
778
779 return count, nil
780}