1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.sh/tangled.sh/core/types"
13)
14
// PullState is the lifecycle state of a pull request. It is written to and
// compared against the pulls.state column as its integer value (see
// SetPullState and GetPullCount), so the order of the constants below must
// not change.
type PullState int

const (
	// PullClosed (0) is the state set by ClosePull.
	PullClosed PullState = iota
	// PullOpen (1) is the state NewPull creates pulls in, and the state
	// ReopenPull restores.
	PullOpen
	// PullMerged (2) is the state set by MergePull.
	PullMerged
)
22
23func (p PullState) String() string {
24 switch p {
25 case PullOpen:
26 return "open"
27 case PullMerged:
28 return "merged"
29 case PullClosed:
30 return "closed"
31 default:
32 return "closed"
33 }
34}
35
36func (p PullState) IsOpen() bool {
37 return p == PullOpen
38}
39func (p PullState) IsMerged() bool {
40 return p == PullMerged
41}
42func (p PullState) IsClosed() bool {
43 return p == PullClosed
44}
45
// Pull is a pull request against a repository, as stored in the pulls
// table, optionally with its submission rounds attached.
type Pull struct {
	// ids

	// ID is not read by GetPull, GetPulls or GetPullsByOwnerDid.
	ID int
	// PullId is the per-repository sequence number allocated by NewPull
	// from repo_pull_seqs; the first pull in a repository gets 1.
	PullId int

	// at ids

	// RepoAt is the repository the pull targets.
	RepoAt syntax.ATURI
	// OwnerDid is the DID of the pull's owner.
	OwnerDid string
	// Rkey is presumably the record key of the pull's AT record — TODO confirm.
	Rkey string
	// PullAt is the AT-URI of the pull record; it is written separately via
	// SetPullAt rather than by NewPull.
	PullAt syntax.ATURI

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	// Submissions holds the submission rounds; round 0 is created by
	// NewPull and later rounds by ResubmitPull. GetPull fills it in
	// round-number order; GetPulls and GetPullsByOwnerDid leave it nil.
	Submissions []*PullSubmission

	// meta

	// Created is parsed from the RFC 3339 `created` column.
	Created time.Time
	// PullSource is nil for patch-based pulls (see IsPatch) and set for
	// branch-based pulls.
	PullSource *PullSource
}
68
// PullSource describes where a branch-based pull's changes come from.
type PullSource struct {
	// Branch is the source branch name (stored in pulls.source_branch).
	Branch string
	// RepoAt is the source repository (pulls.source_repo_at). nil means no
	// source repository was recorded, which IsSameRepoBranch treats as the
	// target repository itself.
	RepoAt *syntax.ATURI

	// optionally populate this for reverse mappings
	Repo *Repo
}
76
// PullSubmission is one round of a pull request: the patch submitted in
// that round plus the comments made on it.
type PullSubmission struct {
	// ids
	ID int
	// PullId is the owning pull's per-repository id (Pull.PullId, not
	// Pull.ID); NewPull and ResubmitPull write pull.PullId here.
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content

	// RoundNumber is 0 for the submission created by NewPull and
	// len(pull.Submissions) for each one added by ResubmitPull.
	RoundNumber int
	Patch       string
	// Comments is populated by GetPull, ordered oldest first.
	Comments []PullComment
	// SourceRev is the rev that was used to create this submission; only
	// set for branch PRs (stored as NULL/empty otherwise).
	SourceRev string

	// meta
	Created time.Time
}
94
// PullComment is a comment left on a specific submission round of a pull.
type PullComment struct {
	// ids
	ID int
	// PullId is presumably the pull's per-repository id (Pull.PullId) —
	// TODO confirm against callers of NewPullComment.
	PullId int
	// SubmissionId is the ID of the PullSubmission this comment belongs to;
	// GetPull uses it to attach the comment to its round.
	SubmissionId int

	// at ids
	RepoAt   string
	OwnerDid string
	// CommentAt is presumably the AT-URI of the comment record — TODO confirm.
	CommentAt string

	// content
	Body string

	// meta

	// Created is read back from the database by GetPull; NewPullComment
	// does not write it.
	Created time.Time
}
112
113func (p *Pull) LatestPatch() string {
114 latestSubmission := p.Submissions[p.LastRoundNumber()]
115 return latestSubmission.Patch
116}
117
118func (p *Pull) LastRoundNumber() int {
119 return len(p.Submissions) - 1
120}
121
122func (p *Pull) IsSameRepoBranch() bool {
123 if p.PullSource != nil {
124 if p.PullSource.RepoAt != nil {
125 return p.PullSource.RepoAt == &p.RepoAt
126 } else {
127 // no repo specified
128 return true
129 }
130 }
131 return false
132}
133
134func (p *Pull) IsPatch() bool {
135 return p.PullSource == nil
136}
137
138func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
139 patch := s.Patch
140
141 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
142 if err != nil {
143 log.Println(err)
144 }
145
146 nd := types.NiceDiff{}
147 nd.Commit.Parent = targetBranch
148
149 for _, d := range diffs {
150 ndiff := types.Diff{}
151 ndiff.Name.New = d.NewName
152 ndiff.Name.Old = d.OldName
153 ndiff.IsBinary = d.IsBinary
154 ndiff.IsNew = d.IsNew
155 ndiff.IsDelete = d.IsDelete
156 ndiff.IsCopy = d.IsCopy
157 ndiff.IsRename = d.IsRename
158
159 for _, tf := range d.TextFragments {
160 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
161 for _, l := range tf.Lines {
162 switch l.Op {
163 case gitdiff.OpAdd:
164 nd.Stat.Insertions += 1
165 case gitdiff.OpDelete:
166 nd.Stat.Deletions += 1
167 }
168 }
169 }
170
171 nd.Diff = append(nd.Diff, ndiff)
172 }
173
174 nd.Stat.FilesChanged = len(diffs)
175
176 return nd
177}
178
// NewPull creates a pull request and its initial submission (round 0)
// inside tx, then commits tx.
//
// NewPull takes ownership of tx: it is always committed or rolled back
// before NewPull returns. A per-repository id is allocated from
// repo_pull_seqs (the first pull in a repository gets 1) and stored in
// pull.PullId. pull.State is set to PullOpen.
//
// pull.Submissions must hold at least one non-nil entry. Its first
// element's Patch and SourceRev become round 0. If pull.PullSource is set,
// its branch and optional repository are stored as the pull's source.
func NewPull(tx *sql.Tx, pull *Pull) error {
	// After a successful Commit this Rollback returns sql.ErrTxDone, which is
	// deliberately ignored; on every error path it discards the transaction.
	defer tx.Rollback()

	// Validate before touching the database. Previously a missing initial
	// submission panicked on pull.Submissions[0] after the id sequence had
	// already been bumped.
	if len(pull.Submissions) == 0 || pull.Submissions[0] == nil {
		return fmt.Errorf("new pull: missing initial submission")
	}
	initial := pull.Submissions[0]

	// Ensure the repository has a sequence row (sqlite-style "insert or ignore").
	if _, err := tx.Exec(`
		insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
		values (?, 1)
	`, pull.RepoAt); err != nil {
		return fmt.Errorf("new pull: init pull id sequence: %w", err)
	}

	// Reserve an id: bump the counter and read back its pre-increment value.
	var nextId int
	if err := tx.QueryRow(`
		update repo_pull_seqs
		set next_pull_id = next_pull_id + 1
		where repo_at = ?
		returning next_pull_id - 1
	`, pull.RepoAt).Scan(&nextId); err != nil {
		return fmt.Errorf("new pull: allocate pull id: %w", err)
	}

	pull.PullId = nextId
	pull.State = PullOpen

	// NULL source columns mark a patch-based pull.
	var sourceBranch, sourceRepoAt *string
	if pull.PullSource != nil {
		sourceBranch = &pull.PullSource.Branch
		if pull.PullSource.RepoAt != nil {
			x := pull.PullSource.RepoAt.String()
			sourceRepoAt = &x
		}
	}

	if _, err := tx.Exec(
		`
		insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
		values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		pull.RepoAt,
		pull.OwnerDid,
		pull.PullId,
		pull.Title,
		pull.TargetBranch,
		pull.Body,
		pull.Rkey,
		pull.State,
		sourceBranch,
		sourceRepoAt,
	); err != nil {
		return fmt.Errorf("new pull: insert pull: %w", err)
	}

	if _, err := tx.Exec(`
		insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
		values (?, ?, ?, ?, ?)
	`, pull.PullId, pull.RepoAt, 0, initial.Patch, initial.SourceRev); err != nil {
		return fmt.Errorf("new pull: insert initial submission: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("new pull: commit: %w", err)
	}

	return nil
}
246
247func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
248 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
249 return err
250}
251
252func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
253 var pullAt string
254 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
255 return pullAt, err
256}
257
258func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
259 var pullId int
260 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
261 return pullId - 1, err
262}
263
// GetPulls lists the pulls of repository repoAt that are in the given state,
// newest first.
//
// Each returned Pull has its row columns and PullSource populated, and its
// RepoAt set to repoAt. Submissions are not loaded; use GetPull for that.
func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
	var pulls []Pull

	rows, err := e.Query(`
		select
			owner_did,
			pull_id,
			created,
			title,
			state,
			target_branch,
			pull_at,
			body,
			rkey,
			source_branch,
			source_repo_at
		from
			pulls
		where
			repo_at = ? and state = ?
		order by
			created desc`, repoAt, state)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var pull Pull
		var createdAt string
		var sourceBranch, sourceRepoAt sql.NullString
		err := rows.Scan(
			&pull.OwnerDid,
			&pull.PullId,
			&createdAt,
			&pull.Title,
			&pull.State,
			&pull.TargetBranch,
			&pull.PullAt,
			&pull.Body,
			&pull.Rkey,
			&sourceBranch,
			&sourceRepoAt,
		)
		if err != nil {
			return nil, err
		}

		// repo_at is the filter column, so every row belongs to repoAt.
		// Previously RepoAt was left zero here (GetPull does populate it),
		// which broke callers such as IsSameRepoBranch on listed pulls.
		pull.RepoAt = repoAt

		createdTime, err := time.Parse(time.RFC3339, createdAt)
		if err != nil {
			return nil, err
		}
		pull.Created = createdTime

		// A NULL source_branch marks a patch-based pull (PullSource stays nil).
		if sourceBranch.Valid {
			pull.PullSource = &PullSource{
				Branch: sourceBranch.String,
			}
			if sourceRepoAt.Valid {
				sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
				if err != nil {
					return nil, err
				}
				pull.PullSource.RepoAt = &sourceRepoAtParsed
			}
		}

		pulls = append(pulls, pull)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return pulls, nil
}
340
// GetPull loads the pull identified by (repoAt, pullId) together with its
// source (if any), all of its submission rounds in ascending round order,
// and the comments on each submission (oldest first).
//
// If the pull row does not exist, the Scan error is returned (presumably
// sql.ErrNoRows). A pull without submission rows is returned with a nil
// Submissions slice.
func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
	query := `
		select
			owner_did,
			pull_id,
			created,
			title,
			state,
			target_branch,
			pull_at,
			repo_at,
			body,
			rkey,
			source_branch,
			source_repo_at
		from
			pulls
		where
			repo_at = ? and pull_id = ?
	`
	row := e.QueryRow(query, repoAt, pullId)

	var pull Pull
	var createdAt string
	var sourceBranch, sourceRepoAt sql.NullString
	err := row.Scan(
		&pull.OwnerDid,
		&pull.PullId,
		&createdAt,
		&pull.Title,
		&pull.State,
		&pull.TargetBranch,
		&pull.PullAt,
		&pull.RepoAt,
		&pull.Body,
		&pull.Rkey,
		&sourceBranch,
		&sourceRepoAt,
	)
	if err != nil {
		return nil, err
	}

	createdTime, err := time.Parse(time.RFC3339, createdAt)
	if err != nil {
		return nil, err
	}
	pull.Created = createdTime

	// populate source; a NULL source_branch marks a patch-based pull
	if sourceBranch.Valid {
		pull.PullSource = &PullSource{
			Branch: sourceBranch.String,
		}
		if sourceRepoAt.Valid {
			sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
			if err != nil {
				return nil, err
			}
			pull.PullSource.RepoAt = &sourceRepoAtParsed
		}
	}

	// Ordered by round_number so Submissions can be built by appending.
	// Indexing a pre-sized slice by RoundNumber (the previous approach)
	// panicked, or left nil holes, whenever the rounds were not exactly 0..n-1.
	submissionsQuery := `
		select
			id, pull_id, repo_at, round_number, patch, created, source_rev
		from
			pull_submissions
		where
			repo_at = ? and pull_id = ?
		order by
			round_number asc
	`
	submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
	if err != nil {
		return nil, err
	}
	defer submissionsRows.Close()

	var submissions []*PullSubmission
	// submission id -> submission, used below to attach comments
	submissionsMap := make(map[int]*PullSubmission)

	for submissionsRows.Next() {
		var submission PullSubmission
		var submissionCreatedStr string
		var submissionSourceRev sql.NullString
		err := submissionsRows.Scan(
			&submission.ID,
			&submission.PullId,
			&submission.RepoAt,
			&submission.RoundNumber,
			&submission.Patch,
			&submissionCreatedStr,
			&submissionSourceRev,
		)
		if err != nil {
			return nil, err
		}

		submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
		if err != nil {
			return nil, err
		}
		submission.Created = submissionCreatedTime

		if submissionSourceRev.Valid {
			submission.SourceRev = submissionSourceRev.String
		}

		submissions = append(submissions, &submission)
		submissionsMap[submission.ID] = &submission
	}
	// Next returning false means either end of rows or an iteration error;
	// only Err tells them apart. This was previously unchecked.
	if err = submissionsRows.Err(); err != nil {
		return nil, err
	}
	// Close explicitly (the deferred Close then becomes a no-op) so this
	// result set is released before the comments query runs.
	if err = submissionsRows.Close(); err != nil {
		return nil, err
	}
	if len(submissions) == 0 {
		return &pull, nil
	}

	args := make([]any, 0, len(submissions))
	for _, submission := range submissions {
		args = append(args, submission.ID)
	}
	// Only "?" placeholders are formatted into the SQL; the ids are bound.
	inClause := strings.TrimSuffix(strings.Repeat("?, ", len(args)), ", ")
	commentsQuery := fmt.Sprintf(`
		select
			id,
			pull_id,
			submission_id,
			repo_at,
			owner_did,
			comment_at,
			body,
			created
		from
			pull_comments
		where
			submission_id IN (%s)
		order by
			created asc
	`, inClause)
	commentsRows, err := e.Query(commentsQuery, args...)
	if err != nil {
		return nil, err
	}
	defer commentsRows.Close()

	for commentsRows.Next() {
		var comment PullComment
		var commentCreatedStr string
		err := commentsRows.Scan(
			&comment.ID,
			&comment.PullId,
			&comment.SubmissionId,
			&comment.RepoAt,
			&comment.OwnerDid,
			&comment.CommentAt,
			&comment.Body,
			&commentCreatedStr,
		)
		if err != nil {
			return nil, err
		}

		commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
		if err != nil {
			return nil, err
		}
		comment.Created = commentCreatedTime

		// Add the comment to its submission
		if submission, ok := submissionsMap[comment.SubmissionId]; ok {
			submission.Comments = append(submission.Comments, comment)
		}
	}
	if err = commentsRows.Err(); err != nil {
		return nil, err
	}

	pull.Submissions = submissions

	return &pull, nil
}
524
// GetPullsByOwnerDid lists every pull owned by did across all repositories,
// newest first.
//
// Only summary columns are loaded: owner, repository, per-repo id, creation
// time, title and state. Body, branches, source and submissions stay unset.
func GetPullsByOwnerDid(e Execer, did string) ([]Pull, error) {
	const query = `
		select
			owner_did,
			repo_at,
			pull_id,
			created,
			title,
			state
		from
			pulls
		where
			owner_did = ?
		order by
			created desc`

	rows, err := e.Query(query, did)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Pull
	for rows.Next() {
		var (
			p       Pull
			created string
		)
		if err := rows.Scan(&p.OwnerDid, &p.RepoAt, &p.PullId, &created, &p.Title, &p.State); err != nil {
			return nil, err
		}

		ts, err := time.Parse(time.RFC3339, created)
		if err != nil {
			return nil, err
		}
		p.Created = ts

		result = append(result, p)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
577
578func NewPullComment(e Execer, comment *PullComment) (int64, error) {
579 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
580 res, err := e.Exec(
581 query,
582 comment.OwnerDid,
583 comment.RepoAt,
584 comment.SubmissionId,
585 comment.CommentAt,
586 comment.PullId,
587 comment.Body,
588 )
589 if err != nil {
590 return 0, err
591 }
592
593 i, err := res.LastInsertId()
594 if err != nil {
595 return 0, err
596 }
597
598 return i, nil
599}
600
601func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
602 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
603 return err
604}
605
606func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
607 err := SetPullState(e, repoAt, pullId, PullClosed)
608 return err
609}
610
611func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
612 err := SetPullState(e, repoAt, pullId, PullOpen)
613 return err
614}
615
616func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
617 err := SetPullState(e, repoAt, pullId, PullMerged)
618 return err
619}
620
621func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
622 newRoundNumber := len(pull.Submissions)
623 _, err := e.Exec(`
624 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
625 values (?, ?, ?, ?, ?)
626 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
627
628 return err
629}
630
// PullCount holds the number of pulls in each state for one repository, as
// computed by GetPullCount. Field order matters: it is built positionally
// (PullCount{0, 0, 0}) elsewhere in this file.
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
636
637func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
638 row := e.QueryRow(`
639 select
640 count(case when state = ? then 1 end) as open_count,
641 count(case when state = ? then 1 end) as merged_count,
642 count(case when state = ? then 1 end) as closed_count
643 from pulls
644 where repo_at = ?`,
645 PullOpen,
646 PullMerged,
647 PullClosed,
648 repoAt,
649 )
650
651 var count PullCount
652 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
653 return PullCount{0, 0, 0}, err
654 }
655
656 return count, nil
657}