1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.sh/tangled.sh/core/types"
13)
14
// PullState is the lifecycle state of a pull request. Its integer value is
// what is written to the pulls.state column (see NewPull and SetPullState),
// so the numeric values below must stay fixed.
type PullState int

// Known pull states. The values are spelled out explicitly rather than via
// iota because they are persisted.
const (
	PullClosed PullState = 0
	PullOpen   PullState = 1
	PullMerged PullState = 2
)

// String returns the lower-case name of the state: "open", "merged" or
// "closed". Any value outside the known set is reported as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	default:
		// PullClosed, and any unrecognised value.
		return "closed"
	}
}

// IsOpen reports whether the state is exactly PullOpen.
func (p PullState) IsOpen() bool { return PullOpen == p }

// IsMerged reports whether the state is exactly PullMerged.
func (p PullState) IsMerged() bool { return PullMerged == p }

// IsClosed reports whether the state is exactly PullClosed. Unlike String,
// it returns false for unrecognised values.
func (p PullState) IsClosed() bool { return PullClosed == p }
45
// Pull is a pull request opened against a repository. The scalar fields
// mirror columns of the pulls table; Submissions, PullSource and Repo are
// only filled in by the queries that load them.
type Pull struct {
	// ids

	// ID is not scanned by any query in this file.
	// NOTE(review): presumably the pulls table's row id — confirm.
	ID int
	// PullId is the per-repository pull number, allocated by NewPull from the
	// repo_pull_seqs counter (the first pull of a repo gets 1).
	PullId int

	// at ids

	// RepoAt is the AT-URI of the repository the pull targets.
	RepoAt syntax.ATURI
	// OwnerDid is stored in the owner_did column.
	// NOTE(review): presumably the DID of the pull's author — confirm.
	OwnerDid string
	// Rkey is stored in the rkey column; presumably the record key of the
	// pull's record — confirm.
	Rkey string
	// PullAt is not written by NewPull; it is set afterwards via SetPullAt.
	PullAt syntax.ATURI

	// content
	Title        string
	Body         string
	TargetBranch string
	// State is the lifecycle state; NewPull always creates pulls as PullOpen.
	State PullState
	// Submissions holds one entry per round, indexed by round number, as
	// assembled by GetPull. GetPulls and GetPullsByOwnerDid leave it nil.
	Submissions []*PullSubmission

	// meta

	// Created is parsed from the created column as RFC 3339.
	Created time.Time
	// PullSource is nil for patch-based pulls (see IsPatchBased); it is set
	// when the source_branch column is non-NULL.
	PullSource *PullSource

	// optionally, populate this when querying for reverse mappings
	// (GetPullsByOwnerDid fills it with the target repository).
	Repo *Repo
}
71
// PullSource describes where a branch- or fork-based pull's changes come
// from (source_branch / source_repo_at columns of the pulls table).
type PullSource struct {
	// Branch is the source branch name.
	Branch string
	// RepoAt is the source repository's AT-URI. It is nil when no repository
	// was recorded, which IsBranchBased treats as a branch of the target repo.
	RepoAt *syntax.ATURI

	// optionally populate this for reverse mappings
	Repo *Repo
}
79
// PullSubmission is one round of a pull request: the patch submitted in that
// round and the comments left on it (pull_submissions table).
type PullSubmission struct {
	// ids
	ID     int
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content

	// RoundNumber is zero-based: NewPull inserts round 0 and ResubmitPull
	// uses len(pull.Submissions) as the next round.
	RoundNumber int
	// Patch is the raw diff text; AsNiceDiff parses it with gitdiff.
	Patch string
	// Comments is filled by GetPull, ordered oldest first.
	Comments  []PullComment
	SourceRev string // include the rev that was used to create this submission: only for branch PRs

	// meta
	Created time.Time
}
97
// PullComment is a comment left on one specific submission (round) of a
// pull (pull_comments table).
type PullComment struct {
	// ids
	ID     int
	PullId int
	// SubmissionId references PullSubmission.ID; GetPull uses it to attach
	// the comment to its submission.
	SubmissionId int

	// at ids
	// These are plain strings here, unlike the syntax.ATURI fields on Pull.
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
115
116func (p *Pull) LatestPatch() string {
117 latestSubmission := p.Submissions[p.LastRoundNumber()]
118 return latestSubmission.Patch
119}
120
121func (p *Pull) LastRoundNumber() int {
122 return len(p.Submissions) - 1
123}
124
125func (p *Pull) IsPatchBased() bool {
126 return p.PullSource == nil
127}
128
129func (p *Pull) IsBranchBased() bool {
130 if p.PullSource != nil {
131 if p.PullSource.RepoAt != nil {
132 return p.PullSource.RepoAt == &p.RepoAt
133 } else {
134 // no repo specified
135 return true
136 }
137 }
138 return false
139}
140
141func (p *Pull) IsForkBased() bool {
142 if p.PullSource != nil {
143 if p.PullSource.RepoAt != nil {
144 // make sure repos are different
145 return p.PullSource.RepoAt != &p.RepoAt
146 }
147 }
148 return false
149}
150
151func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
152 patch := s.Patch
153
154 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
155 if err != nil {
156 log.Println(err)
157 }
158
159 nd := types.NiceDiff{}
160 nd.Commit.Parent = targetBranch
161
162 for _, d := range diffs {
163 ndiff := types.Diff{}
164 ndiff.Name.New = d.NewName
165 ndiff.Name.Old = d.OldName
166 ndiff.IsBinary = d.IsBinary
167 ndiff.IsNew = d.IsNew
168 ndiff.IsDelete = d.IsDelete
169 ndiff.IsCopy = d.IsCopy
170 ndiff.IsRename = d.IsRename
171
172 for _, tf := range d.TextFragments {
173 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
174 for _, l := range tf.Lines {
175 switch l.Op {
176 case gitdiff.OpAdd:
177 nd.Stat.Insertions += 1
178 case gitdiff.OpDelete:
179 nd.Stat.Deletions += 1
180 }
181 }
182 }
183
184 nd.Diff = append(nd.Diff, ndiff)
185 }
186
187 nd.Stat.FilesChanged = len(diffs)
188
189 return nd
190}
191
192func NewPull(tx *sql.Tx, pull *Pull) error {
193 defer tx.Rollback()
194
195 _, err := tx.Exec(`
196 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
197 values (?, 1)
198 `, pull.RepoAt)
199 if err != nil {
200 return err
201 }
202
203 var nextId int
204 err = tx.QueryRow(`
205 update repo_pull_seqs
206 set next_pull_id = next_pull_id + 1
207 where repo_at = ?
208 returning next_pull_id - 1
209 `, pull.RepoAt).Scan(&nextId)
210 if err != nil {
211 return err
212 }
213
214 pull.PullId = nextId
215 pull.State = PullOpen
216
217 var sourceBranch, sourceRepoAt *string
218 if pull.PullSource != nil {
219 sourceBranch = &pull.PullSource.Branch
220 if pull.PullSource.RepoAt != nil {
221 x := pull.PullSource.RepoAt.String()
222 sourceRepoAt = &x
223 }
224 }
225
226 _, err = tx.Exec(
227 `
228 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
229 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
230 pull.RepoAt,
231 pull.OwnerDid,
232 pull.PullId,
233 pull.Title,
234 pull.TargetBranch,
235 pull.Body,
236 pull.Rkey,
237 pull.State,
238 sourceBranch,
239 sourceRepoAt,
240 )
241 if err != nil {
242 return err
243 }
244
245 _, err = tx.Exec(`
246 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
247 values (?, ?, ?, ?, ?)
248 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
249 if err != nil {
250 return err
251 }
252
253 if err := tx.Commit(); err != nil {
254 return err
255 }
256
257 return nil
258}
259
260func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
261 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
262 return err
263}
264
265func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
266 var pullAt string
267 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
268 return pullAt, err
269}
270
271func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
272 var pullId int
273 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
274 return pullId - 1, err
275}
276
277func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
278 var pulls []Pull
279
280 rows, err := e.Query(`
281 select
282 owner_did,
283 pull_id,
284 created,
285 title,
286 state,
287 target_branch,
288 pull_at,
289 body,
290 rkey,
291 source_branch,
292 source_repo_at
293 from
294 pulls
295 where
296 repo_at = ? and state = ?
297 order by
298 created desc`, repoAt, state)
299 if err != nil {
300 return nil, err
301 }
302 defer rows.Close()
303
304 for rows.Next() {
305 var pull Pull
306 var createdAt string
307 var sourceBranch, sourceRepoAt sql.NullString
308 err := rows.Scan(
309 &pull.OwnerDid,
310 &pull.PullId,
311 &createdAt,
312 &pull.Title,
313 &pull.State,
314 &pull.TargetBranch,
315 &pull.PullAt,
316 &pull.Body,
317 &pull.Rkey,
318 &sourceBranch,
319 &sourceRepoAt,
320 )
321 if err != nil {
322 return nil, err
323 }
324
325 createdTime, err := time.Parse(time.RFC3339, createdAt)
326 if err != nil {
327 return nil, err
328 }
329 pull.Created = createdTime
330
331 if sourceBranch.Valid {
332 pull.PullSource = &PullSource{
333 Branch: sourceBranch.String,
334 }
335 if sourceRepoAt.Valid {
336 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
337 if err != nil {
338 return nil, err
339 }
340 pull.PullSource.RepoAt = &sourceRepoAtParsed
341 }
342 }
343
344 pulls = append(pulls, pull)
345 }
346
347 if err := rows.Err(); err != nil {
348 return nil, err
349 }
350
351 return pulls, nil
352}
353
354func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
355 query := `
356 select
357 owner_did,
358 pull_id,
359 created,
360 title,
361 state,
362 target_branch,
363 pull_at,
364 repo_at,
365 body,
366 rkey,
367 source_branch,
368 source_repo_at
369 from
370 pulls
371 where
372 repo_at = ? and pull_id = ?
373 `
374 row := e.QueryRow(query, repoAt, pullId)
375
376 var pull Pull
377 var createdAt string
378 var sourceBranch, sourceRepoAt sql.NullString
379 err := row.Scan(
380 &pull.OwnerDid,
381 &pull.PullId,
382 &createdAt,
383 &pull.Title,
384 &pull.State,
385 &pull.TargetBranch,
386 &pull.PullAt,
387 &pull.RepoAt,
388 &pull.Body,
389 &pull.Rkey,
390 &sourceBranch,
391 &sourceRepoAt,
392 )
393 if err != nil {
394 return nil, err
395 }
396
397 createdTime, err := time.Parse(time.RFC3339, createdAt)
398 if err != nil {
399 return nil, err
400 }
401 pull.Created = createdTime
402
403 // populate source
404 if sourceBranch.Valid {
405 pull.PullSource = &PullSource{
406 Branch: sourceBranch.String,
407 }
408 if sourceRepoAt.Valid {
409 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
410 if err != nil {
411 return nil, err
412 }
413 pull.PullSource.RepoAt = &sourceRepoAtParsed
414 }
415 }
416
417 submissionsQuery := `
418 select
419 id, pull_id, repo_at, round_number, patch, created, source_rev
420 from
421 pull_submissions
422 where
423 repo_at = ? and pull_id = ?
424 `
425 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
426 if err != nil {
427 return nil, err
428 }
429 defer submissionsRows.Close()
430
431 submissionsMap := make(map[int]*PullSubmission)
432
433 for submissionsRows.Next() {
434 var submission PullSubmission
435 var submissionCreatedStr string
436 var submissionSourceRev sql.NullString
437 err := submissionsRows.Scan(
438 &submission.ID,
439 &submission.PullId,
440 &submission.RepoAt,
441 &submission.RoundNumber,
442 &submission.Patch,
443 &submissionCreatedStr,
444 &submissionSourceRev,
445 )
446 if err != nil {
447 return nil, err
448 }
449
450 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
451 if err != nil {
452 return nil, err
453 }
454 submission.Created = submissionCreatedTime
455
456 if submissionSourceRev.Valid {
457 submission.SourceRev = submissionSourceRev.String
458 }
459
460 submissionsMap[submission.ID] = &submission
461 }
462 if err = submissionsRows.Close(); err != nil {
463 return nil, err
464 }
465 if len(submissionsMap) == 0 {
466 return &pull, nil
467 }
468
469 var args []any
470 for k := range submissionsMap {
471 args = append(args, k)
472 }
473 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
474 commentsQuery := fmt.Sprintf(`
475 select
476 id,
477 pull_id,
478 submission_id,
479 repo_at,
480 owner_did,
481 comment_at,
482 body,
483 created
484 from
485 pull_comments
486 where
487 submission_id IN (%s)
488 order by
489 created asc
490 `, inClause)
491 commentsRows, err := e.Query(commentsQuery, args...)
492 if err != nil {
493 return nil, err
494 }
495 defer commentsRows.Close()
496
497 for commentsRows.Next() {
498 var comment PullComment
499 var commentCreatedStr string
500 err := commentsRows.Scan(
501 &comment.ID,
502 &comment.PullId,
503 &comment.SubmissionId,
504 &comment.RepoAt,
505 &comment.OwnerDid,
506 &comment.CommentAt,
507 &comment.Body,
508 &commentCreatedStr,
509 )
510 if err != nil {
511 return nil, err
512 }
513
514 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
515 if err != nil {
516 return nil, err
517 }
518 comment.Created = commentCreatedTime
519
520 // Add the comment to its submission
521 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
522 submission.Comments = append(submission.Comments, comment)
523 }
524
525 }
526 if err = commentsRows.Err(); err != nil {
527 return nil, err
528 }
529
530 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
531 for _, submission := range submissionsMap {
532 pull.Submissions[submission.RoundNumber] = submission
533 }
534
535 return &pull, nil
536}
537
538// timeframe here is directly passed into the sql query filter, and any
539// timeframe in the past should be negative; e.g.: "-3 months"
540func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
541 var pulls []Pull
542
543 rows, err := e.Query(`
544 select
545 p.owner_did,
546 p.repo_at,
547 p.pull_id,
548 p.created,
549 p.title,
550 p.state,
551 r.did,
552 r.name,
553 r.knot,
554 r.rkey,
555 r.created
556 from
557 pulls p
558 join
559 repos r on p.repo_at = r.at_uri
560 where
561 p.owner_did = ? and p.created >= date ('now', ?)
562 order by
563 p.created desc`, did, timeframe)
564 if err != nil {
565 return nil, err
566 }
567 defer rows.Close()
568
569 for rows.Next() {
570 var pull Pull
571 var repo Repo
572 var pullCreatedAt, repoCreatedAt string
573 err := rows.Scan(
574 &pull.OwnerDid,
575 &pull.RepoAt,
576 &pull.PullId,
577 &pullCreatedAt,
578 &pull.Title,
579 &pull.State,
580 &repo.Did,
581 &repo.Name,
582 &repo.Knot,
583 &repo.Rkey,
584 &repoCreatedAt,
585 )
586 if err != nil {
587 return nil, err
588 }
589
590 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
591 if err != nil {
592 return nil, err
593 }
594 pull.Created = pullCreatedTime
595
596 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
597 if err != nil {
598 return nil, err
599 }
600 repo.Created = repoCreatedTime
601
602 pull.Repo = &repo
603
604 pulls = append(pulls, pull)
605 }
606
607 if err := rows.Err(); err != nil {
608 return nil, err
609 }
610
611 return pulls, nil
612}
613
614func NewPullComment(e Execer, comment *PullComment) (int64, error) {
615 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
616 res, err := e.Exec(
617 query,
618 comment.OwnerDid,
619 comment.RepoAt,
620 comment.SubmissionId,
621 comment.CommentAt,
622 comment.PullId,
623 comment.Body,
624 )
625 if err != nil {
626 return 0, err
627 }
628
629 i, err := res.LastInsertId()
630 if err != nil {
631 return 0, err
632 }
633
634 return i, nil
635}
636
637func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
638 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
639 return err
640}
641
642func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
643 err := SetPullState(e, repoAt, pullId, PullClosed)
644 return err
645}
646
647func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
648 err := SetPullState(e, repoAt, pullId, PullOpen)
649 return err
650}
651
652func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
653 err := SetPullState(e, repoAt, pullId, PullMerged)
654 return err
655}
656
657func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
658 newRoundNumber := len(pull.Submissions)
659 _, err := e.Exec(`
660 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
661 values (?, ?, ?, ?, ?)
662 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
663
664 return err
665}
666
// PullCount holds the number of pulls in each state for one repository, as
// computed by GetPullCount. Field order is significant: it is constructed
// positionally (PullCount{0, 0, 0}).
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
672
673func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
674 row := e.QueryRow(`
675 select
676 count(case when state = ? then 1 end) as open_count,
677 count(case when state = ? then 1 end) as merged_count,
678 count(case when state = ? then 1 end) as closed_count
679 from pulls
680 where repo_at = ?`,
681 PullOpen,
682 PullMerged,
683 PullClosed,
684 repoAt,
685 )
686
687 var count PullCount
688 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
689 return PullCount{0, 0, 0}, err
690 }
691
692 return count, nil
693}