1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.sh/tangled.sh/core/types"
13)
14
// PullState enumerates the lifecycle states of a pull request. The numeric
// values are what gets bound into SQL statements for the pulls.state column,
// so the order of the constants must not change.
type PullState int

const (
	PullClosed PullState = iota
	PullOpen
	PullMerged
)

// String returns the lowercase name of the state. Any value that is not a
// known state is reported as "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	// PullClosed and every unrecognised value share the same label.
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return PullOpen == p }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return PullMerged == p }

// IsClosed reports whether the state is exactly PullClosed; unknown values
// are not considered closed here even though String labels them "closed".
func (p PullState) IsClosed() bool { return PullClosed == p }
45
// Pull is a pull request against a repository, as stored in the pulls table.
type Pull struct {
	// ids
	// ID is presumably the database row id; none of the queries in this
	// section populate it (TODO confirm).
	ID int
	// PullId is the per-repository sequence number handed out by NewPull.
	PullId int

	// at ids
	// RepoAt is the AT URI of the repository the pull targets (repo_at column).
	RepoAt syntax.ATURI
	// OwnerDid is the value of the owner_did column.
	OwnerDid string
	// Rkey is the value of the rkey column.
	Rkey string
	// PullAt is the value of the pull_at column; it is written separately by SetPullAt.
	PullAt syntax.ATURI

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	// Submissions holds the submission rounds; GetPull fills it so that the
	// slice index equals the submission's RoundNumber. The list queries
	// (GetPulls, GetPullsByOwnerDid) leave it nil.
	Submissions []*PullSubmission

	// meta
	// Created is parsed from the RFC 3339 text in the created column.
	Created time.Time
	// PullSource is nil for patch-only pulls (see IsPatch); otherwise it
	// names the branch, and optionally the repository, the pull comes from.
	PullSource *PullSource
}
68
// PullSource describes where a branch-based pull comes from. It is built from
// the source_branch and source_repo_at columns of the pulls table.
type PullSource struct {
	// Branch is the source branch name (source_branch column).
	Branch string
	// Repo is the parsed source_repo_at column; it is nil when that column is
	// NULL, which IsSameRepoBranch treats as "same repository as the target".
	Repo *syntax.ATURI
}
73
// PullSubmission is one round of a pull request: a patch plus the comments
// made on it. Rows live in the pull_submissions table.
type PullSubmission struct {
	// ids
	// ID is the pull_submissions.id value; pull comments reference it via
	// submission_id.
	ID int
	// PullId is the pull_id of the pull this submission belongs to.
	PullId int

	// at ids
	// RepoAt is the repo_at column of the submission row.
	RepoAt syntax.ATURI

	// content
	// RoundNumber is 0 for the submission inserted by NewPull; ResubmitPull
	// uses the current number of submissions as the next round number.
	RoundNumber int
	// Patch is the patch text; AsNiceDiff parses it with gitdiff.
	Patch string
	// Comments are the comments attached to this submission; GetPull loads
	// them ordered by creation time ascending.
	Comments  []PullComment
	SourceRev string // include the rev that was used to create this submission: only for branch PRs

	// meta
	// Created is parsed from the RFC 3339 text in the created column.
	Created time.Time
}
91
// PullComment is a comment on a specific submission of a pull request, stored
// in the pull_comments table.
type PullComment struct {
	// ids
	// ID is the pull_comments.id value.
	ID int
	// PullId is the pull_id column of the comment row.
	PullId int
	// SubmissionId is the id of the PullSubmission the comment is attached to.
	SubmissionId int

	// at ids
	// RepoAt, OwnerDid and CommentAt mirror the repo_at, owner_did and
	// comment_at columns; they are kept as plain strings here.
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	// Created is parsed from the RFC 3339 text in the created column.
	Created time.Time
}
109
110func (p *Pull) LatestPatch() string {
111 latestSubmission := p.Submissions[p.LastRoundNumber()]
112 return latestSubmission.Patch
113}
114
115func (p *Pull) LastRoundNumber() int {
116 return len(p.Submissions) - 1
117}
118
119func (p *Pull) IsSameRepoBranch() bool {
120 if p.PullSource != nil {
121 if p.PullSource.Repo != nil {
122 return p.PullSource.Repo == &p.RepoAt
123 } else {
124 // no repo specified
125 return true
126 }
127 }
128 return false
129}
130
131func (p *Pull) IsPatch() bool {
132 return p.PullSource == nil
133}
134
135func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
136 patch := s.Patch
137
138 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
139 if err != nil {
140 log.Println(err)
141 }
142
143 nd := types.NiceDiff{}
144 nd.Commit.Parent = targetBranch
145
146 for _, d := range diffs {
147 ndiff := types.Diff{}
148 ndiff.Name.New = d.NewName
149 ndiff.Name.Old = d.OldName
150 ndiff.IsBinary = d.IsBinary
151 ndiff.IsNew = d.IsNew
152 ndiff.IsDelete = d.IsDelete
153 ndiff.IsCopy = d.IsCopy
154 ndiff.IsRename = d.IsRename
155
156 for _, tf := range d.TextFragments {
157 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
158 for _, l := range tf.Lines {
159 switch l.Op {
160 case gitdiff.OpAdd:
161 nd.Stat.Insertions += 1
162 case gitdiff.OpDelete:
163 nd.Stat.Deletions += 1
164 }
165 }
166 }
167
168 nd.Diff = append(nd.Diff, ndiff)
169 }
170
171 nd.Stat.FilesChanged = len(diffs)
172
173 return nd
174}
175
176func NewPull(tx *sql.Tx, pull *Pull) error {
177 defer tx.Rollback()
178
179 _, err := tx.Exec(`
180 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
181 values (?, 1)
182 `, pull.RepoAt)
183 if err != nil {
184 return err
185 }
186
187 var nextId int
188 err = tx.QueryRow(`
189 update repo_pull_seqs
190 set next_pull_id = next_pull_id + 1
191 where repo_at = ?
192 returning next_pull_id - 1
193 `, pull.RepoAt).Scan(&nextId)
194 if err != nil {
195 return err
196 }
197
198 pull.PullId = nextId
199 pull.State = PullOpen
200
201 var sourceBranch, sourceRepoAt *string
202 if pull.PullSource != nil {
203 sourceBranch = &pull.PullSource.Branch
204 if pull.PullSource.Repo != nil {
205 x := pull.PullSource.Repo.String()
206 sourceRepoAt = &x
207 }
208 }
209
210 _, err = tx.Exec(
211 `
212 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
213 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
214 pull.RepoAt,
215 pull.OwnerDid,
216 pull.PullId,
217 pull.Title,
218 pull.TargetBranch,
219 pull.Body,
220 pull.Rkey,
221 pull.State,
222 sourceBranch,
223 sourceRepoAt,
224 )
225 if err != nil {
226 return err
227 }
228
229 _, err = tx.Exec(`
230 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
231 values (?, ?, ?, ?, ?)
232 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
233 if err != nil {
234 return err
235 }
236
237 if err := tx.Commit(); err != nil {
238 return err
239 }
240
241 return nil
242}
243
244func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
245 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
246 return err
247}
248
249func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
250 var pullAt string
251 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
252 return pullAt, err
253}
254
255func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
256 var pullId int
257 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
258 return pullId - 1, err
259}
260
261func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
262 var pulls []Pull
263
264 rows, err := e.Query(`
265 select
266 owner_did,
267 pull_id,
268 created,
269 title,
270 state,
271 target_branch,
272 pull_at,
273 body,
274 rkey,
275 source_branch,
276 source_repo_at
277 from
278 pulls
279 where
280 repo_at = ? and state = ?
281 order by
282 created desc`, repoAt, state)
283 if err != nil {
284 return nil, err
285 }
286 defer rows.Close()
287
288 for rows.Next() {
289 var pull Pull
290 var createdAt string
291 var sourceBranch, sourceRepoAt sql.NullString
292 err := rows.Scan(
293 &pull.OwnerDid,
294 &pull.PullId,
295 &createdAt,
296 &pull.Title,
297 &pull.State,
298 &pull.TargetBranch,
299 &pull.PullAt,
300 &pull.Body,
301 &pull.Rkey,
302 &sourceBranch,
303 &sourceRepoAt,
304 )
305 if err != nil {
306 return nil, err
307 }
308
309 createdTime, err := time.Parse(time.RFC3339, createdAt)
310 if err != nil {
311 return nil, err
312 }
313 pull.Created = createdTime
314
315 if sourceBranch.Valid {
316 pull.PullSource = &PullSource{
317 Branch: sourceBranch.String,
318 }
319 if sourceRepoAt.Valid {
320 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
321 if err != nil {
322 return nil, err
323 }
324 pull.PullSource.Repo = &sourceRepoAtParsed
325 }
326 }
327
328 pulls = append(pulls, pull)
329 }
330
331 if err := rows.Err(); err != nil {
332 return nil, err
333 }
334
335 return pulls, nil
336}
337
338func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
339 query := `
340 select
341 owner_did,
342 pull_id,
343 created,
344 title,
345 state,
346 target_branch,
347 pull_at,
348 repo_at,
349 body,
350 rkey,
351 source_branch,
352 source_repo_at
353 from
354 pulls
355 where
356 repo_at = ? and pull_id = ?
357 `
358 row := e.QueryRow(query, repoAt, pullId)
359
360 var pull Pull
361 var createdAt string
362 var sourceBranch, sourceRepoAt sql.NullString
363 err := row.Scan(
364 &pull.OwnerDid,
365 &pull.PullId,
366 &createdAt,
367 &pull.Title,
368 &pull.State,
369 &pull.TargetBranch,
370 &pull.PullAt,
371 &pull.RepoAt,
372 &pull.Body,
373 &pull.Rkey,
374 &sourceBranch,
375 &sourceRepoAt,
376 )
377 if err != nil {
378 return nil, err
379 }
380
381 createdTime, err := time.Parse(time.RFC3339, createdAt)
382 if err != nil {
383 return nil, err
384 }
385 pull.Created = createdTime
386
387 // populate source
388 if sourceBranch.Valid {
389 pull.PullSource = &PullSource{
390 Branch: sourceBranch.String,
391 }
392 if sourceRepoAt.Valid {
393 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
394 if err != nil {
395 return nil, err
396 }
397 pull.PullSource.Repo = &sourceRepoAtParsed
398 }
399 }
400
401 submissionsQuery := `
402 select
403 id, pull_id, repo_at, round_number, patch, created, source_rev
404 from
405 pull_submissions
406 where
407 repo_at = ? and pull_id = ?
408 `
409 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
410 if err != nil {
411 return nil, err
412 }
413 defer submissionsRows.Close()
414
415 submissionsMap := make(map[int]*PullSubmission)
416
417 for submissionsRows.Next() {
418 var submission PullSubmission
419 var submissionCreatedStr string
420 var submissionSourceRev sql.NullString
421 err := submissionsRows.Scan(
422 &submission.ID,
423 &submission.PullId,
424 &submission.RepoAt,
425 &submission.RoundNumber,
426 &submission.Patch,
427 &submissionCreatedStr,
428 &submissionSourceRev,
429 )
430 if err != nil {
431 return nil, err
432 }
433
434 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
435 if err != nil {
436 return nil, err
437 }
438 submission.Created = submissionCreatedTime
439
440 if submissionSourceRev.Valid {
441 submission.SourceRev = submissionSourceRev.String
442 }
443
444 submissionsMap[submission.ID] = &submission
445 }
446 if err = submissionsRows.Close(); err != nil {
447 return nil, err
448 }
449 if len(submissionsMap) == 0 {
450 return &pull, nil
451 }
452
453 var args []any
454 for k := range submissionsMap {
455 args = append(args, k)
456 }
457 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
458 commentsQuery := fmt.Sprintf(`
459 select
460 id,
461 pull_id,
462 submission_id,
463 repo_at,
464 owner_did,
465 comment_at,
466 body,
467 created
468 from
469 pull_comments
470 where
471 submission_id IN (%s)
472 order by
473 created asc
474 `, inClause)
475 commentsRows, err := e.Query(commentsQuery, args...)
476 if err != nil {
477 return nil, err
478 }
479 defer commentsRows.Close()
480
481 for commentsRows.Next() {
482 var comment PullComment
483 var commentCreatedStr string
484 err := commentsRows.Scan(
485 &comment.ID,
486 &comment.PullId,
487 &comment.SubmissionId,
488 &comment.RepoAt,
489 &comment.OwnerDid,
490 &comment.CommentAt,
491 &comment.Body,
492 &commentCreatedStr,
493 )
494 if err != nil {
495 return nil, err
496 }
497
498 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
499 if err != nil {
500 return nil, err
501 }
502 comment.Created = commentCreatedTime
503
504 // Add the comment to its submission
505 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
506 submission.Comments = append(submission.Comments, comment)
507 }
508
509 }
510 if err = commentsRows.Err(); err != nil {
511 return nil, err
512 }
513
514 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
515 for _, submission := range submissionsMap {
516 pull.Submissions[submission.RoundNumber] = submission
517 }
518
519 return &pull, nil
520}
521
522func GetPullsByOwnerDid(e Execer, did string) ([]Pull, error) {
523 var pulls []Pull
524
525 rows, err := e.Query(`
526 select
527 owner_did,
528 repo_at,
529 pull_id,
530 created,
531 title,
532 state
533 from
534 pulls
535 where
536 owner_did = ?
537 order by
538 created desc`, did)
539 if err != nil {
540 return nil, err
541 }
542 defer rows.Close()
543
544 for rows.Next() {
545 var pull Pull
546 var createdAt string
547 err := rows.Scan(
548 &pull.OwnerDid,
549 &pull.RepoAt,
550 &pull.PullId,
551 &createdAt,
552 &pull.Title,
553 &pull.State,
554 )
555 if err != nil {
556 return nil, err
557 }
558
559 createdTime, err := time.Parse(time.RFC3339, createdAt)
560 if err != nil {
561 return nil, err
562 }
563 pull.Created = createdTime
564
565 pulls = append(pulls, pull)
566 }
567
568 if err := rows.Err(); err != nil {
569 return nil, err
570 }
571
572 return pulls, nil
573}
574
575func NewPullComment(e Execer, comment *PullComment) (int64, error) {
576 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
577 res, err := e.Exec(
578 query,
579 comment.OwnerDid,
580 comment.RepoAt,
581 comment.SubmissionId,
582 comment.CommentAt,
583 comment.PullId,
584 comment.Body,
585 )
586 if err != nil {
587 return 0, err
588 }
589
590 i, err := res.LastInsertId()
591 if err != nil {
592 return 0, err
593 }
594
595 return i, nil
596}
597
598func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
599 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
600 return err
601}
602
603func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
604 err := SetPullState(e, repoAt, pullId, PullClosed)
605 return err
606}
607
608func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
609 err := SetPullState(e, repoAt, pullId, PullOpen)
610 return err
611}
612
613func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
614 err := SetPullState(e, repoAt, pullId, PullMerged)
615 return err
616}
617
618func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
619 newRoundNumber := len(pull.Submissions)
620 _, err := e.Exec(`
621 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
622 values (?, ?, ?, ?, ?)
623 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
624
625 return err
626}
627
// PullCount holds the number of pulls of one repository in each state, as
// computed by GetPullCount.
type PullCount struct {
	Open   int // pulls whose state is PullOpen
	Merged int // pulls whose state is PullMerged
	Closed int // pulls whose state is PullClosed
}
633
634func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
635 row := e.QueryRow(`
636 select
637 count(case when state = ? then 1 end) as open_count,
638 count(case when state = ? then 1 end) as merged_count,
639 count(case when state = ? then 1 end) as closed_count
640 from pulls
641 where repo_at = ?`,
642 PullOpen,
643 PullMerged,
644 PullClosed,
645 repoAt,
646 )
647
648 var count PullCount
649 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
650 return PullCount{0, 0, 0}, err
651 }
652
653 return count, nil
654}