1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.sh/tangled.sh/core/types"
13)
14
// PullState is the lifecycle state of a pull request. The numeric value is
// what gets written to and read back from the pulls.state column, so the
// explicit values below should not be renumbered.
type PullState int

const (
	// PullClosed marks a pull that is closed.
	PullClosed PullState = 0
	// PullOpen marks a pull that is still open.
	PullOpen PullState = 1
	// PullMerged marks a pull that has been merged.
	PullMerged PullState = 2
)

// String returns "open" for PullOpen and "merged" for PullMerged; every other
// value, PullClosed and unrecognised numbers alike, renders as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether p is exactly PullOpen.
func (p PullState) IsOpen() bool { return PullOpen == p }

// IsMerged reports whether p is exactly PullMerged.
func (p PullState) IsMerged() bool { return PullMerged == p }

// IsClosed reports whether p is exactly PullClosed.
func (p PullState) IsClosed() bool { return PullClosed == p }
45
// Pull is a pull request against the repository identified by RepoAt, as
// stored in the pulls table. Submissions and Repo are only filled in by the
// loaders that fetch them (GetPull and GetPullsByOwnerDid respectively).
type Pull struct {
	// ids
	ID     int // NOTE(review): not populated by any query in this file
	PullId int // per-repository number allocated from repo_pull_seqs by NewPull

	// at ids
	RepoAt   syntax.ATURI // repository the pull is filed against (pulls.repo_at)
	OwnerDid string       // value of pulls.owner_did
	Rkey     string
	PullAt   syntax.ATURI // pulls.pull_at; not inserted by NewPull, written later via SetPullAt

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	Submissions  []*PullSubmission // indexed by RoundNumber; round 0 is the initial submission

	// meta
	Created    time.Time
	PullSource *PullSource // nil for patch-based pulls (see IsPatch)

	// optionally, populate this when querying for reverse mappings
	Repo *Repo
}
71
// PullSource describes where a branch-based pull's changes come from. A Pull
// submitted as a plain patch has a nil PullSource.
type PullSource struct {
	Branch string        // pulls.source_branch
	RepoAt *syntax.ATURI // pulls.source_repo_at; nil when no source repository was recorded

	// optionally populate this for reverse mappings
	Repo *Repo
}
79
// PullSubmission is one round of a pull (a row of pull_submissions): the
// patch submitted in that round plus the comments left on it.
type PullSubmission struct {
	// ids
	ID     int // row id; PullComment.SubmissionId refers to it
	PullId int // per-repository pull number, the same value as Pull.PullId

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int           // 0 for the initial submission, len(Submissions) for each resubmission
	Patch       string
	Comments    []PullComment // oldest first when loaded by GetPull
	SourceRev   string        // include the rev that was used to create this submission: only for branch PRs

	// meta
	Created time.Time
}
97
// PullComment is a comment on one submission round of a pull (a row of
// pull_comments).
type PullComment struct {
	// ids
	ID           int
	PullId       int
	SubmissionId int // PullSubmission.ID of the round being commented on

	// at ids
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time // read back from pull_comments.created; NewPullComment does not write it
}
115
116func (p *Pull) LatestPatch() string {
117 latestSubmission := p.Submissions[p.LastRoundNumber()]
118 return latestSubmission.Patch
119}
120
121func (p *Pull) LastRoundNumber() int {
122 return len(p.Submissions) - 1
123}
124
125func (p *Pull) IsSameRepoBranch() bool {
126 if p.PullSource != nil {
127 if p.PullSource.RepoAt != nil {
128 return p.PullSource.RepoAt == &p.RepoAt
129 } else {
130 // no repo specified
131 return true
132 }
133 }
134 return false
135}
136
137func (p *Pull) IsPatch() bool {
138 return p.PullSource == nil
139}
140
141func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
142 patch := s.Patch
143
144 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
145 if err != nil {
146 log.Println(err)
147 }
148
149 nd := types.NiceDiff{}
150 nd.Commit.Parent = targetBranch
151
152 for _, d := range diffs {
153 ndiff := types.Diff{}
154 ndiff.Name.New = d.NewName
155 ndiff.Name.Old = d.OldName
156 ndiff.IsBinary = d.IsBinary
157 ndiff.IsNew = d.IsNew
158 ndiff.IsDelete = d.IsDelete
159 ndiff.IsCopy = d.IsCopy
160 ndiff.IsRename = d.IsRename
161
162 for _, tf := range d.TextFragments {
163 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
164 for _, l := range tf.Lines {
165 switch l.Op {
166 case gitdiff.OpAdd:
167 nd.Stat.Insertions += 1
168 case gitdiff.OpDelete:
169 nd.Stat.Deletions += 1
170 }
171 }
172 }
173
174 nd.Diff = append(nd.Diff, ndiff)
175 }
176
177 nd.Stat.FilesChanged = len(diffs)
178
179 return nd
180}
181
182func NewPull(tx *sql.Tx, pull *Pull) error {
183 defer tx.Rollback()
184
185 _, err := tx.Exec(`
186 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
187 values (?, 1)
188 `, pull.RepoAt)
189 if err != nil {
190 return err
191 }
192
193 var nextId int
194 err = tx.QueryRow(`
195 update repo_pull_seqs
196 set next_pull_id = next_pull_id + 1
197 where repo_at = ?
198 returning next_pull_id - 1
199 `, pull.RepoAt).Scan(&nextId)
200 if err != nil {
201 return err
202 }
203
204 pull.PullId = nextId
205 pull.State = PullOpen
206
207 var sourceBranch, sourceRepoAt *string
208 if pull.PullSource != nil {
209 sourceBranch = &pull.PullSource.Branch
210 if pull.PullSource.RepoAt != nil {
211 x := pull.PullSource.RepoAt.String()
212 sourceRepoAt = &x
213 }
214 }
215
216 _, err = tx.Exec(
217 `
218 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
219 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
220 pull.RepoAt,
221 pull.OwnerDid,
222 pull.PullId,
223 pull.Title,
224 pull.TargetBranch,
225 pull.Body,
226 pull.Rkey,
227 pull.State,
228 sourceBranch,
229 sourceRepoAt,
230 )
231 if err != nil {
232 return err
233 }
234
235 _, err = tx.Exec(`
236 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
237 values (?, ?, ?, ?, ?)
238 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
239 if err != nil {
240 return err
241 }
242
243 if err := tx.Commit(); err != nil {
244 return err
245 }
246
247 return nil
248}
249
250func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
251 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
252 return err
253}
254
255func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
256 var pullAt string
257 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
258 return pullAt, err
259}
260
261func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
262 var pullId int
263 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
264 return pullId - 1, err
265}
266
267func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
268 var pulls []Pull
269
270 rows, err := e.Query(`
271 select
272 owner_did,
273 pull_id,
274 created,
275 title,
276 state,
277 target_branch,
278 pull_at,
279 body,
280 rkey,
281 source_branch,
282 source_repo_at
283 from
284 pulls
285 where
286 repo_at = ? and state = ?
287 order by
288 created desc`, repoAt, state)
289 if err != nil {
290 return nil, err
291 }
292 defer rows.Close()
293
294 for rows.Next() {
295 var pull Pull
296 var createdAt string
297 var sourceBranch, sourceRepoAt sql.NullString
298 err := rows.Scan(
299 &pull.OwnerDid,
300 &pull.PullId,
301 &createdAt,
302 &pull.Title,
303 &pull.State,
304 &pull.TargetBranch,
305 &pull.PullAt,
306 &pull.Body,
307 &pull.Rkey,
308 &sourceBranch,
309 &sourceRepoAt,
310 )
311 if err != nil {
312 return nil, err
313 }
314
315 createdTime, err := time.Parse(time.RFC3339, createdAt)
316 if err != nil {
317 return nil, err
318 }
319 pull.Created = createdTime
320
321 if sourceBranch.Valid {
322 pull.PullSource = &PullSource{
323 Branch: sourceBranch.String,
324 }
325 if sourceRepoAt.Valid {
326 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
327 if err != nil {
328 return nil, err
329 }
330 pull.PullSource.RepoAt = &sourceRepoAtParsed
331 }
332 }
333
334 pulls = append(pulls, pull)
335 }
336
337 if err := rows.Err(); err != nil {
338 return nil, err
339 }
340
341 return pulls, nil
342}
343
344func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
345 query := `
346 select
347 owner_did,
348 pull_id,
349 created,
350 title,
351 state,
352 target_branch,
353 pull_at,
354 repo_at,
355 body,
356 rkey,
357 source_branch,
358 source_repo_at
359 from
360 pulls
361 where
362 repo_at = ? and pull_id = ?
363 `
364 row := e.QueryRow(query, repoAt, pullId)
365
366 var pull Pull
367 var createdAt string
368 var sourceBranch, sourceRepoAt sql.NullString
369 err := row.Scan(
370 &pull.OwnerDid,
371 &pull.PullId,
372 &createdAt,
373 &pull.Title,
374 &pull.State,
375 &pull.TargetBranch,
376 &pull.PullAt,
377 &pull.RepoAt,
378 &pull.Body,
379 &pull.Rkey,
380 &sourceBranch,
381 &sourceRepoAt,
382 )
383 if err != nil {
384 return nil, err
385 }
386
387 createdTime, err := time.Parse(time.RFC3339, createdAt)
388 if err != nil {
389 return nil, err
390 }
391 pull.Created = createdTime
392
393 // populate source
394 if sourceBranch.Valid {
395 pull.PullSource = &PullSource{
396 Branch: sourceBranch.String,
397 }
398 if sourceRepoAt.Valid {
399 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
400 if err != nil {
401 return nil, err
402 }
403 pull.PullSource.RepoAt = &sourceRepoAtParsed
404 }
405 }
406
407 submissionsQuery := `
408 select
409 id, pull_id, repo_at, round_number, patch, created, source_rev
410 from
411 pull_submissions
412 where
413 repo_at = ? and pull_id = ?
414 `
415 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
416 if err != nil {
417 return nil, err
418 }
419 defer submissionsRows.Close()
420
421 submissionsMap := make(map[int]*PullSubmission)
422
423 for submissionsRows.Next() {
424 var submission PullSubmission
425 var submissionCreatedStr string
426 var submissionSourceRev sql.NullString
427 err := submissionsRows.Scan(
428 &submission.ID,
429 &submission.PullId,
430 &submission.RepoAt,
431 &submission.RoundNumber,
432 &submission.Patch,
433 &submissionCreatedStr,
434 &submissionSourceRev,
435 )
436 if err != nil {
437 return nil, err
438 }
439
440 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
441 if err != nil {
442 return nil, err
443 }
444 submission.Created = submissionCreatedTime
445
446 if submissionSourceRev.Valid {
447 submission.SourceRev = submissionSourceRev.String
448 }
449
450 submissionsMap[submission.ID] = &submission
451 }
452 if err = submissionsRows.Close(); err != nil {
453 return nil, err
454 }
455 if len(submissionsMap) == 0 {
456 return &pull, nil
457 }
458
459 var args []any
460 for k := range submissionsMap {
461 args = append(args, k)
462 }
463 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
464 commentsQuery := fmt.Sprintf(`
465 select
466 id,
467 pull_id,
468 submission_id,
469 repo_at,
470 owner_did,
471 comment_at,
472 body,
473 created
474 from
475 pull_comments
476 where
477 submission_id IN (%s)
478 order by
479 created asc
480 `, inClause)
481 commentsRows, err := e.Query(commentsQuery, args...)
482 if err != nil {
483 return nil, err
484 }
485 defer commentsRows.Close()
486
487 for commentsRows.Next() {
488 var comment PullComment
489 var commentCreatedStr string
490 err := commentsRows.Scan(
491 &comment.ID,
492 &comment.PullId,
493 &comment.SubmissionId,
494 &comment.RepoAt,
495 &comment.OwnerDid,
496 &comment.CommentAt,
497 &comment.Body,
498 &commentCreatedStr,
499 )
500 if err != nil {
501 return nil, err
502 }
503
504 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
505 if err != nil {
506 return nil, err
507 }
508 comment.Created = commentCreatedTime
509
510 // Add the comment to its submission
511 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
512 submission.Comments = append(submission.Comments, comment)
513 }
514
515 }
516 if err = commentsRows.Err(); err != nil {
517 return nil, err
518 }
519
520 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
521 for _, submission := range submissionsMap {
522 pull.Submissions[submission.RoundNumber] = submission
523 }
524
525 return &pull, nil
526}
527
528// timeframe here is directly passed into the sql query filter, and any
529// timeframe in the past should be negative; e.g.: "-3 months"
530func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
531 var pulls []Pull
532
533 rows, err := e.Query(`
534 select
535 p.owner_did,
536 p.repo_at,
537 p.pull_id,
538 p.created,
539 p.title,
540 p.state,
541 r.did,
542 r.name,
543 r.knot,
544 r.rkey,
545 r.created
546 from
547 pulls p
548 join
549 repos r on p.repo_at = r.at_uri
550 where
551 p.owner_did = ? and p.created >= date ('now', ?)
552 order by
553 p.created desc`, did, timeframe)
554 if err != nil {
555 return nil, err
556 }
557 defer rows.Close()
558
559 for rows.Next() {
560 var pull Pull
561 var repo Repo
562 var pullCreatedAt, repoCreatedAt string
563 err := rows.Scan(
564 &pull.OwnerDid,
565 &pull.RepoAt,
566 &pull.PullId,
567 &pullCreatedAt,
568 &pull.Title,
569 &pull.State,
570 &repo.Did,
571 &repo.Name,
572 &repo.Knot,
573 &repo.Rkey,
574 &repoCreatedAt,
575 )
576 if err != nil {
577 return nil, err
578 }
579
580 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
581 if err != nil {
582 return nil, err
583 }
584 pull.Created = pullCreatedTime
585
586 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
587 if err != nil {
588 return nil, err
589 }
590 repo.Created = repoCreatedTime
591
592 pull.Repo = &repo
593
594 pulls = append(pulls, pull)
595 }
596
597 if err := rows.Err(); err != nil {
598 return nil, err
599 }
600
601 return pulls, nil
602}
603
604func NewPullComment(e Execer, comment *PullComment) (int64, error) {
605 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
606 res, err := e.Exec(
607 query,
608 comment.OwnerDid,
609 comment.RepoAt,
610 comment.SubmissionId,
611 comment.CommentAt,
612 comment.PullId,
613 comment.Body,
614 )
615 if err != nil {
616 return 0, err
617 }
618
619 i, err := res.LastInsertId()
620 if err != nil {
621 return 0, err
622 }
623
624 return i, nil
625}
626
627func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
628 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
629 return err
630}
631
632func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
633 err := SetPullState(e, repoAt, pullId, PullClosed)
634 return err
635}
636
637func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
638 err := SetPullState(e, repoAt, pullId, PullOpen)
639 return err
640}
641
642func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
643 err := SetPullState(e, repoAt, pullId, PullMerged)
644 return err
645}
646
647func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
648 newRoundNumber := len(pull.Submissions)
649 _, err := e.Exec(`
650 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
651 values (?, ?, ?, ?, ?)
652 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
653
654 return err
655}
656
// PullCount holds the number of pulls in each state for one repository, as
// computed by GetPullCount.
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
662
663func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
664 row := e.QueryRow(`
665 select
666 count(case when state = ? then 1 end) as open_count,
667 count(case when state = ? then 1 end) as merged_count,
668 count(case when state = ? then 1 end) as closed_count
669 from pulls
670 where repo_at = ?`,
671 PullOpen,
672 PullMerged,
673 PullClosed,
674 repoAt,
675 )
676
677 var count PullCount
678 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
679 return PullCount{0, 0, 0}, err
680 }
681
682 return count, nil
683}