1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.sh/tangled.sh/core/types"
13)
14
// PullState is the lifecycle state of a pull request. Its integer value is
// what NewPull and SetPullState write to the pulls.state column (and what
// GetPullCount compares against), so the constant values must not change.
type PullState int

const (
	// PullClosed marks a pull that was closed without being merged.
	PullClosed PullState = 0
	// PullOpen marks a pull that is still open for review.
	PullOpen PullState = 1
	// PullMerged marks a pull whose changes were merged.
	PullMerged PullState = 2
)

// String returns the lowercase name of the state. PullClosed and any value
// outside the known constants both render as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether p is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether p is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether p is PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }
45
46type Pull struct {
47 // ids
48 ID int
49 PullId int
50
51 // at ids
52 RepoAt syntax.ATURI
53 OwnerDid string
54 Rkey string
55 PullAt syntax.ATURI
56
57 // content
58 Title string
59 Body string
60 TargetBranch string
61 State PullState
62 Submissions []*PullSubmission
63
64 // meta
65 Created time.Time
66 PullSource *PullSource
67}
68
// PullSource records where a pull request's changes come from. It is stored
// in the pulls.source_branch and pulls.source_repo_at columns.
type PullSource struct {
	// Branch is the name of the source branch.
	Branch string
	// Repo is the source repository. When nil, source_repo_at is stored as
	// NULL and IsSameRepoBranch reports true.
	Repo *syntax.ATURI
}
73
74type PullSubmission struct {
75 // ids
76 ID int
77 PullId int
78
79 // at ids
80 RepoAt syntax.ATURI
81
82 // content
83 RoundNumber int
84 Patch string
85 Comments []PullComment
86 SourceRev string // include the rev that was used to create this submission: only for branch PRs
87
88 // meta
89 Created time.Time
90}
91
// PullComment is a comment attached to a single submission (round) of a pull
// request, as stored in the pull_comments table.
type PullComment struct {
	// ids
	ID, PullId, SubmissionId int

	// at ids
	RepoAt, OwnerDid, CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
109
110func (p *Pull) LatestPatch() string {
111 latestSubmission := p.Submissions[p.LastRoundNumber()]
112 return latestSubmission.Patch
113}
114
115func (p *Pull) LastRoundNumber() int {
116 return len(p.Submissions) - 1
117}
118
119func (p *Pull) IsSameRepoBranch() bool {
120 if p.PullSource != nil {
121 if p.PullSource.Repo != nil {
122 return p.PullSource.Repo == &p.RepoAt
123 } else {
124 // no repo specified
125 return true
126 }
127 }
128 return false
129}
130
131func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
132 patch := s.Patch
133
134 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
135 if err != nil {
136 log.Println(err)
137 }
138
139 nd := types.NiceDiff{}
140 nd.Commit.Parent = targetBranch
141
142 for _, d := range diffs {
143 ndiff := types.Diff{}
144 ndiff.Name.New = d.NewName
145 ndiff.Name.Old = d.OldName
146 ndiff.IsBinary = d.IsBinary
147 ndiff.IsNew = d.IsNew
148 ndiff.IsDelete = d.IsDelete
149 ndiff.IsCopy = d.IsCopy
150 ndiff.IsRename = d.IsRename
151
152 for _, tf := range d.TextFragments {
153 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
154 for _, l := range tf.Lines {
155 switch l.Op {
156 case gitdiff.OpAdd:
157 nd.Stat.Insertions += 1
158 case gitdiff.OpDelete:
159 nd.Stat.Deletions += 1
160 }
161 }
162 }
163
164 nd.Diff = append(nd.Diff, ndiff)
165 }
166
167 nd.Stat.FilesChanged = len(diffs)
168
169 return nd
170}
171
172func NewPull(tx *sql.Tx, pull *Pull) error {
173 defer tx.Rollback()
174
175 _, err := tx.Exec(`
176 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
177 values (?, 1)
178 `, pull.RepoAt)
179 if err != nil {
180 return err
181 }
182
183 var nextId int
184 err = tx.QueryRow(`
185 update repo_pull_seqs
186 set next_pull_id = next_pull_id + 1
187 where repo_at = ?
188 returning next_pull_id - 1
189 `, pull.RepoAt).Scan(&nextId)
190 if err != nil {
191 return err
192 }
193
194 pull.PullId = nextId
195 pull.State = PullOpen
196
197 var sourceBranch, sourceRepoAt *string
198 if pull.PullSource != nil {
199 sourceBranch = &pull.PullSource.Branch
200 if pull.PullSource.Repo != nil {
201 x := pull.PullSource.Repo.String()
202 sourceRepoAt = &x
203 }
204 }
205
206 _, err = tx.Exec(
207 `
208 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
209 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
210 pull.RepoAt,
211 pull.OwnerDid,
212 pull.PullId,
213 pull.Title,
214 pull.TargetBranch,
215 pull.Body,
216 pull.Rkey,
217 pull.State,
218 sourceBranch,
219 sourceRepoAt,
220 )
221 if err != nil {
222 return err
223 }
224
225 _, err = tx.Exec(`
226 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
227 values (?, ?, ?, ?, ?)
228 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
229 if err != nil {
230 return err
231 }
232
233 if err := tx.Commit(); err != nil {
234 return err
235 }
236
237 return nil
238}
239
240func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
241 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
242 return err
243}
244
245func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
246 var pullAt string
247 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
248 return pullAt, err
249}
250
251func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
252 var pullId int
253 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
254 return pullId - 1, err
255}
256
257func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
258 var pulls []Pull
259
260 rows, err := e.Query(`
261 select
262 owner_did,
263 pull_id,
264 created,
265 title,
266 state,
267 target_branch,
268 pull_at,
269 body,
270 rkey,
271 source_branch,
272 source_repo_at
273 from
274 pulls
275 where
276 repo_at = ? and state = ?
277 order by
278 created desc`, repoAt, state)
279 if err != nil {
280 return nil, err
281 }
282 defer rows.Close()
283
284 for rows.Next() {
285 var pull Pull
286 var createdAt string
287 var sourceBranch, sourceRepoAt sql.NullString
288 err := rows.Scan(
289 &pull.OwnerDid,
290 &pull.PullId,
291 &createdAt,
292 &pull.Title,
293 &pull.State,
294 &pull.TargetBranch,
295 &pull.PullAt,
296 &pull.Body,
297 &pull.Rkey,
298 &sourceBranch,
299 &sourceRepoAt,
300 )
301 if err != nil {
302 return nil, err
303 }
304
305 createdTime, err := time.Parse(time.RFC3339, createdAt)
306 if err != nil {
307 return nil, err
308 }
309 pull.Created = createdTime
310
311 if sourceBranch.Valid {
312 pull.PullSource = &PullSource{
313 Branch: sourceBranch.String,
314 }
315 if sourceRepoAt.Valid {
316 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
317 if err != nil {
318 return nil, err
319 }
320 pull.PullSource.Repo = &sourceRepoAtParsed
321 }
322 }
323
324 pulls = append(pulls, pull)
325 }
326
327 if err := rows.Err(); err != nil {
328 return nil, err
329 }
330
331 return pulls, nil
332}
333
334func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
335 query := `
336 select
337 owner_did,
338 pull_id,
339 created,
340 title,
341 state,
342 target_branch,
343 pull_at,
344 repo_at,
345 body,
346 rkey,
347 source_branch,
348 source_repo_at
349 from
350 pulls
351 where
352 repo_at = ? and pull_id = ?
353 `
354 row := e.QueryRow(query, repoAt, pullId)
355
356 var pull Pull
357 var createdAt string
358 var sourceBranch, sourceRepoAt sql.NullString
359 err := row.Scan(
360 &pull.OwnerDid,
361 &pull.PullId,
362 &createdAt,
363 &pull.Title,
364 &pull.State,
365 &pull.TargetBranch,
366 &pull.PullAt,
367 &pull.RepoAt,
368 &pull.Body,
369 &pull.Rkey,
370 &sourceBranch,
371 &sourceRepoAt,
372 )
373 if err != nil {
374 return nil, err
375 }
376
377 createdTime, err := time.Parse(time.RFC3339, createdAt)
378 if err != nil {
379 return nil, err
380 }
381 pull.Created = createdTime
382
383 // populate source
384 if sourceBranch.Valid {
385 pull.PullSource = &PullSource{
386 Branch: sourceBranch.String,
387 }
388 if sourceRepoAt.Valid {
389 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
390 if err != nil {
391 return nil, err
392 }
393 pull.PullSource.Repo = &sourceRepoAtParsed
394 }
395 }
396
397 submissionsQuery := `
398 select
399 id, pull_id, repo_at, round_number, patch, created, source_rev
400 from
401 pull_submissions
402 where
403 repo_at = ? and pull_id = ?
404 `
405 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
406 if err != nil {
407 return nil, err
408 }
409 defer submissionsRows.Close()
410
411 submissionsMap := make(map[int]*PullSubmission)
412
413 for submissionsRows.Next() {
414 var submission PullSubmission
415 var submissionCreatedStr string
416 var submissionSourceRev sql.NullString
417 err := submissionsRows.Scan(
418 &submission.ID,
419 &submission.PullId,
420 &submission.RepoAt,
421 &submission.RoundNumber,
422 &submission.Patch,
423 &submissionCreatedStr,
424 &submissionSourceRev,
425 )
426 if err != nil {
427 return nil, err
428 }
429
430 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
431 if err != nil {
432 return nil, err
433 }
434 submission.Created = submissionCreatedTime
435
436 if submissionSourceRev.Valid {
437 submission.SourceRev = submissionSourceRev.String
438 }
439
440 submissionsMap[submission.ID] = &submission
441 }
442 if err = submissionsRows.Close(); err != nil {
443 return nil, err
444 }
445 if len(submissionsMap) == 0 {
446 return &pull, nil
447 }
448
449 var args []any
450 for k := range submissionsMap {
451 args = append(args, k)
452 }
453 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
454 commentsQuery := fmt.Sprintf(`
455 select
456 id,
457 pull_id,
458 submission_id,
459 repo_at,
460 owner_did,
461 comment_at,
462 body,
463 created
464 from
465 pull_comments
466 where
467 submission_id IN (%s)
468 order by
469 created asc
470 `, inClause)
471 commentsRows, err := e.Query(commentsQuery, args...)
472 if err != nil {
473 return nil, err
474 }
475 defer commentsRows.Close()
476
477 for commentsRows.Next() {
478 var comment PullComment
479 var commentCreatedStr string
480 err := commentsRows.Scan(
481 &comment.ID,
482 &comment.PullId,
483 &comment.SubmissionId,
484 &comment.RepoAt,
485 &comment.OwnerDid,
486 &comment.CommentAt,
487 &comment.Body,
488 &commentCreatedStr,
489 )
490 if err != nil {
491 return nil, err
492 }
493
494 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
495 if err != nil {
496 return nil, err
497 }
498 comment.Created = commentCreatedTime
499
500 // Add the comment to its submission
501 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
502 submission.Comments = append(submission.Comments, comment)
503 }
504
505 }
506 if err = commentsRows.Err(); err != nil {
507 return nil, err
508 }
509
510 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
511 for _, submission := range submissionsMap {
512 pull.Submissions[submission.RoundNumber] = submission
513 }
514
515 return &pull, nil
516}
517
518func GetPullsByOwnerDid(e Execer, did string) ([]Pull, error) {
519 var pulls []Pull
520
521 rows, err := e.Query(`
522 select
523 owner_did,
524 repo_at,
525 pull_id,
526 created,
527 title,
528 state
529 from
530 pulls
531 where
532 owner_did = ?
533 order by
534 created desc`, did)
535 if err != nil {
536 return nil, err
537 }
538 defer rows.Close()
539
540 for rows.Next() {
541 var pull Pull
542 var createdAt string
543 err := rows.Scan(
544 &pull.OwnerDid,
545 &pull.RepoAt,
546 &pull.PullId,
547 &createdAt,
548 &pull.Title,
549 &pull.State,
550 )
551 if err != nil {
552 return nil, err
553 }
554
555 createdTime, err := time.Parse(time.RFC3339, createdAt)
556 if err != nil {
557 return nil, err
558 }
559 pull.Created = createdTime
560
561 pulls = append(pulls, pull)
562 }
563
564 if err := rows.Err(); err != nil {
565 return nil, err
566 }
567
568 return pulls, nil
569}
570
571func NewPullComment(e Execer, comment *PullComment) (int64, error) {
572 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
573 res, err := e.Exec(
574 query,
575 comment.OwnerDid,
576 comment.RepoAt,
577 comment.SubmissionId,
578 comment.CommentAt,
579 comment.PullId,
580 comment.Body,
581 )
582 if err != nil {
583 return 0, err
584 }
585
586 i, err := res.LastInsertId()
587 if err != nil {
588 return 0, err
589 }
590
591 return i, nil
592}
593
594func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
595 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
596 return err
597}
598
599func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
600 err := SetPullState(e, repoAt, pullId, PullClosed)
601 return err
602}
603
604func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
605 err := SetPullState(e, repoAt, pullId, PullOpen)
606 return err
607}
608
609func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
610 err := SetPullState(e, repoAt, pullId, PullMerged)
611 return err
612}
613
614func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
615 newRoundNumber := len(pull.Submissions)
616 _, err := e.Exec(`
617 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
618 values (?, ?, ?, ?, ?)
619 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
620
621 return err
622}
623
// PullCount holds the number of pull requests of a repository in each
// state, as computed by GetPullCount.
type PullCount struct {
	Open, Merged, Closed int
}
629
630func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
631 row := e.QueryRow(`
632 select
633 count(case when state = ? then 1 end) as open_count,
634 count(case when state = ? then 1 end) as merged_count,
635 count(case when state = ? then 1 end) as closed_count
636 from pulls
637 where repo_at = ?`,
638 PullOpen,
639 PullMerged,
640 PullClosed,
641 repoAt,
642 )
643
644 var count PullCount
645 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
646 return PullCount{0, 0, 0}, err
647 }
648
649 return count, nil
650}