1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.sh/tangled.sh/core/types"
13)
14
// PullState is the lifecycle state of a pull request. The integer value is
// what NewPull and SetPullState write to the pulls.state column (and what the
// queries in this file filter on), so existing values must not be renumbered.
type PullState int

const (
	// PullClosed is the closed state; it is also the zero value.
	PullClosed PullState = 0
	// PullOpen is the state NewPull assigns to every new pull.
	PullOpen PullState = 1
	// PullMerged is the state set by MergePull.
	PullMerged PullState = 2
)

// String returns "open" or "merged" for those states; PullClosed and any
// value not listed above both render as "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether p is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether p is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether p is exactly PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }
45
// Pull is a pull request stored in the pulls table, together with its
// submission rounds.
type Pull struct {
	// ids

	// ID is not selected by GetPull, GetPulls or GetPullsByOwnerDid, so those
	// helpers leave it at zero.
	// NOTE(review): confirm whether anything relies on it being populated.
	ID int
	// PullId is the per-repository pull number that NewPull allocates from
	// repo_pull_seqs.
	PullId int

	// at ids

	// RepoAt is the AT URI of the repository the pull belongs to (the
	// repo_at column).
	RepoAt syntax.ATURI
	// OwnerDid is the owner_did column, presumably the author's DID;
	// GetPullsByOwnerDid filters on it.
	OwnerDid string
	// Rkey is the rkey column (presumably the record key of the pull record —
	// confirm).
	Rkey string
	// PullAt is the pull_at column. NewPull does not write it; it is set
	// separately with SetPullAt.
	PullAt syntax.ATURI

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	// Submissions holds one entry per submission round. LastRoundNumber and
	// LatestPatch treat the final element as the latest round.
	Submissions []*PullSubmission

	// meta

	// Created is parsed from the created column as RFC 3339.
	Created time.Time
	// PullSource is nil when no source branch is recorded for the pull.
	PullSource *PullSource
}
68
// PullSource describes where a pull's changes come from (the source_branch
// and source_repo_at columns of pulls).
type PullSource struct {
	// Branch is the source branch name.
	Branch string
	// Repo is the AT URI of the source repository, or nil when none was
	// recorded; IsSameRepoBranch treats nil as "same repository".
	Repo *syntax.ATURI
}
73
// PullSubmission is one round of a pull: a submitted patch plus the comments
// left on that round (a row of pull_submissions).
type PullSubmission struct {
	// ids

	// ID is the pull_submissions row id; PullComment.SubmissionId refers to it.
	ID int
	// PullId is the per-repository number of the pull this round belongs to.
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content

	// RoundNumber is 0 for the submission written by NewPull; ResubmitPull
	// uses len(pull.Submissions) for each later round.
	RoundNumber int
	// Patch is the patch text; AsNiceDiff parses it with gitdiff.
	Patch string
	// Comments are attached by GetPull in created-ascending order.
	Comments []PullComment

	// meta
	Created time.Time
}
90
// PullComment is a comment on a specific submission round of a pull (a row
// of pull_comments).
type PullComment struct {
	// ids
	ID     int
	PullId int
	// SubmissionId is the ID of the PullSubmission the comment belongs to.
	SubmissionId int

	// at ids

	// RepoAt is a plain string here, unlike the syntax.ATURI used by Pull
	// and PullSubmission.
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
108
109func (p *Pull) LatestPatch() string {
110 latestSubmission := p.Submissions[p.LastRoundNumber()]
111 return latestSubmission.Patch
112}
113
114func (p *Pull) LastRoundNumber() int {
115 return len(p.Submissions) - 1
116}
117
118func (p *Pull) IsSameRepoBranch() bool {
119 if p.PullSource != nil {
120 if p.PullSource.Repo != nil {
121 return p.PullSource.Repo == &p.RepoAt
122 } else {
123 // no repo specified
124 return true
125 }
126 }
127 return false
128}
129
130func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
131 patch := s.Patch
132
133 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
134 if err != nil {
135 log.Println(err)
136 }
137
138 nd := types.NiceDiff{}
139 nd.Commit.Parent = targetBranch
140
141 for _, d := range diffs {
142 ndiff := types.Diff{}
143 ndiff.Name.New = d.NewName
144 ndiff.Name.Old = d.OldName
145 ndiff.IsBinary = d.IsBinary
146 ndiff.IsNew = d.IsNew
147 ndiff.IsDelete = d.IsDelete
148 ndiff.IsCopy = d.IsCopy
149 ndiff.IsRename = d.IsRename
150
151 for _, tf := range d.TextFragments {
152 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
153 for _, l := range tf.Lines {
154 switch l.Op {
155 case gitdiff.OpAdd:
156 nd.Stat.Insertions += 1
157 case gitdiff.OpDelete:
158 nd.Stat.Deletions += 1
159 }
160 }
161 }
162
163 nd.Diff = append(nd.Diff, ndiff)
164 }
165
166 nd.Stat.FilesChanged = len(diffs)
167
168 return nd
169}
170
// NewPull inserts pull and its first submission inside tx, then commits tx.
//
// It allocates the next per-repository pull number from repo_pull_seqs and
// stores it in pull.PullId. It sets pull.State to PullOpen and inserts the
// pulls row, including the optional source branch and source repository.
// Finally it inserts pull.Submissions[0].Patch as round 0 in pull_submissions.
//
// NewPull takes ownership of tx: it commits tx on success and rolls it back
// on any error. pull.Submissions must contain a non-nil first submission;
// otherwise an error is returned before anything is written.
func NewPull(tx *sql.Tx, pull *Pull) error {
	// Rollback after a successful Commit is a harmless no-op (it returns
	// sql.ErrTxDone), so this only matters on the error paths.
	defer tx.Rollback()

	// Validate up front: the round-0 insert below needs the first submission,
	// and indexing an empty slice would panic.
	if len(pull.Submissions) == 0 || pull.Submissions[0] == nil {
		return fmt.Errorf("new pull: an initial submission is required")
	}

	// Ensure a sequence row exists for this repository, starting at 1.
	_, err := tx.Exec(`
		insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
		values (?, 1)
	`, pull.RepoAt)
	if err != nil {
		return err
	}

	// RETURNING sees the incremented counter, so next_pull_id - 1 is the id
	// being allocated to this pull.
	var nextId int
	err = tx.QueryRow(`
		update repo_pull_seqs
		set next_pull_id = next_pull_id + 1
		where repo_at = ?
		returning next_pull_id - 1
	`, pull.RepoAt).Scan(&nextId)
	if err != nil {
		return err
	}

	pull.PullId = nextId
	pull.State = PullOpen

	// nil pointers are stored as NULL when there is no source / source repo.
	var sourceBranch, sourceRepoAt *string
	if pull.PullSource != nil {
		sourceBranch = &pull.PullSource.Branch
		if pull.PullSource.Repo != nil {
			x := pull.PullSource.Repo.String()
			sourceRepoAt = &x
		}
	}

	_, err = tx.Exec(
		`
		insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
		values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		pull.RepoAt,
		pull.OwnerDid,
		pull.PullId,
		pull.Title,
		pull.TargetBranch,
		pull.Body,
		pull.Rkey,
		pull.State,
		sourceBranch,
		sourceRepoAt,
	)
	if err != nil {
		return err
	}

	_, err = tx.Exec(`
		insert into pull_submissions (pull_id, repo_at, round_number, patch)
		values (?, ?, ?, ?)
	`, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch)
	if err != nil {
		return err
	}

	return tx.Commit()
}
238
239func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
240 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
241 return err
242}
243
244func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
245 var pullAt string
246 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
247 return pullAt, err
248}
249
250func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
251 var pullId int
252 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
253 return pullId - 1, err
254}
255
// GetPulls lists the pulls of repository repoAt that are in the given state,
// newest first.
//
// Each returned Pull has RepoAt set to repoAt. PullSource is populated only
// when a source branch is recorded. Submissions are not loaded; use GetPull
// for a single pull with its rounds and comments.
func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
	var pulls []Pull

	rows, err := e.Query(`
		select
			owner_did,
			pull_id,
			created,
			title,
			state,
			target_branch,
			pull_at,
			body,
			rkey,
			source_branch,
			source_repo_at
		from
			pulls
		where
			repo_at = ? and state = ?
		order by
			created desc`, repoAt, state)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		// The query filters on repo_at, so every row belongs to repoAt. Record
		// it so callers (e.g. IsSameRepoBranch) see a complete Pull.
		pull := Pull{RepoAt: repoAt}
		var createdAt string
		var sourceBranch, sourceRepoAt sql.NullString
		err := rows.Scan(
			&pull.OwnerDid,
			&pull.PullId,
			&createdAt,
			&pull.Title,
			&pull.State,
			&pull.TargetBranch,
			&pull.PullAt,
			&pull.Body,
			&pull.Rkey,
			&sourceBranch,
			&sourceRepoAt,
		)
		if err != nil {
			return nil, err
		}

		createdTime, err := time.Parse(time.RFC3339, createdAt)
		if err != nil {
			return nil, err
		}
		pull.Created = createdTime

		if sourceBranch.Valid {
			pull.PullSource = &PullSource{
				Branch: sourceBranch.String,
			}
			if sourceRepoAt.Valid {
				sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
				if err != nil {
					return nil, err
				}
				pull.PullSource.Repo = &sourceRepoAtParsed
			}
		}

		pulls = append(pulls, pull)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return pulls, nil
}
332
// GetPull loads the pull identified by repoAt and pullId, together with all
// of its submission rounds and the comments on each round.
//
// Submissions are ordered by round number. For the usual contiguous rounds,
// Submissions[i].RoundNumber == i. Each submission's Comments are ordered by
// creation time. If the pull has no submissions, Submissions is nil. Query
// errors are returned unwrapped, including sql.ErrNoRows when the pull does
// not exist.
func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
	query := `
		select
			owner_did,
			pull_id,
			created,
			title,
			state,
			target_branch,
			pull_at,
			repo_at,
			body,
			rkey,
			source_branch,
			source_repo_at
		from
			pulls
		where
			repo_at = ? and pull_id = ?
		`
	row := e.QueryRow(query, repoAt, pullId)

	var pull Pull
	var createdAt string
	var sourceBranch, sourceRepoAt sql.NullString
	err := row.Scan(
		&pull.OwnerDid,
		&pull.PullId,
		&createdAt,
		&pull.Title,
		&pull.State,
		&pull.TargetBranch,
		&pull.PullAt,
		&pull.RepoAt,
		&pull.Body,
		&pull.Rkey,
		&sourceBranch,
		&sourceRepoAt,
	)
	if err != nil {
		return nil, err
	}

	createdTime, err := time.Parse(time.RFC3339, createdAt)
	if err != nil {
		return nil, err
	}
	pull.Created = createdTime

	// populate source
	if sourceBranch.Valid {
		pull.PullSource = &PullSource{
			Branch: sourceBranch.String,
		}
		if sourceRepoAt.Valid {
			sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
			if err != nil {
				return nil, err
			}
			pull.PullSource.Repo = &sourceRepoAtParsed
		}
	}

	// Ordering in SQL lets us append in round order. Placing each row at
	// index RoundNumber instead would panic on a gap in the round numbers
	// and leave a nil slot on a duplicate.
	submissionsQuery := `
		select
			id, pull_id, repo_at, round_number, patch, created
		from
			pull_submissions
		where
			repo_at = ? and pull_id = ?
		order by
			round_number asc
		`
	submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
	if err != nil {
		return nil, err
	}
	defer submissionsRows.Close()

	// submissions keeps round order; submissionsByID lets comments be
	// attached to their submission with a single lookup.
	var submissions []*PullSubmission
	submissionsByID := make(map[int]*PullSubmission)

	for submissionsRows.Next() {
		var submission PullSubmission
		var submissionCreatedStr string
		err := submissionsRows.Scan(
			&submission.ID,
			&submission.PullId,
			&submission.RepoAt,
			&submission.RoundNumber,
			&submission.Patch,
			&submissionCreatedStr,
		)
		if err != nil {
			return nil, err
		}

		submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
		if err != nil {
			return nil, err
		}
		submission.Created = submissionCreatedTime

		submissions = append(submissions, &submission)
		submissionsByID[submission.ID] = &submission
	}
	// Without this check, an error part-way through iteration would silently
	// yield a truncated list of submissions.
	if err := submissionsRows.Err(); err != nil {
		return nil, err
	}
	// Close explicitly before issuing the next query; the defer above only
	// covers early returns. Presumably this matters when e is bound to a
	// single connection (e.g. a *sql.Tx) — confirm against Execer's users.
	if err := submissionsRows.Close(); err != nil {
		return nil, err
	}
	if len(submissions) == 0 {
		return &pull, nil
	}

	// Build the IN (...) argument list in round order so the query text and
	// arguments are deterministic.
	args := make([]any, 0, len(submissions))
	for _, s := range submissions {
		args = append(args, s.ID)
	}
	inClause := strings.TrimSuffix(strings.Repeat("?, ", len(args)), ", ")
	commentsQuery := fmt.Sprintf(`
		select
			id,
			pull_id,
			submission_id,
			repo_at,
			owner_did,
			comment_at,
			body,
			created
		from
			pull_comments
		where
			submission_id IN (%s)
		order by
			created asc
		`, inClause)
	commentsRows, err := e.Query(commentsQuery, args...)
	if err != nil {
		return nil, err
	}
	defer commentsRows.Close()

	for commentsRows.Next() {
		var comment PullComment
		var commentCreatedStr string
		err := commentsRows.Scan(
			&comment.ID,
			&comment.PullId,
			&comment.SubmissionId,
			&comment.RepoAt,
			&comment.OwnerDid,
			&comment.CommentAt,
			&comment.Body,
			&commentCreatedStr,
		)
		if err != nil {
			return nil, err
		}

		commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
		if err != nil {
			return nil, err
		}
		comment.Created = commentCreatedTime

		// Add the comment to its submission
		if submission, ok := submissionsByID[comment.SubmissionId]; ok {
			submission.Comments = append(submission.Comments, comment)
		}
	}
	if err := commentsRows.Err(); err != nil {
		return nil, err
	}

	pull.Submissions = submissions

	return &pull, nil
}
510
// GetPullsByOwnerDid lists every pull whose owner_did is did, across all
// repositories, newest first.
//
// Only OwnerDid, RepoAt, PullId, Created, Title and State are populated; all
// other fields are left at their zero values. When no pulls match, the
// returned slice is nil.
func GetPullsByOwnerDid(e Execer, did string) ([]Pull, error) {
	const query = `
		select
			owner_did,
			repo_at,
			pull_id,
			created,
			title,
			state
		from
			pulls
		where
			owner_did = ?
		order by
			created desc`

	rows, err := e.Query(query, did)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Pull
	for rows.Next() {
		var (
			p       Pull
			created string
		)
		if err := rows.Scan(&p.OwnerDid, &p.RepoAt, &p.PullId, &created, &p.Title, &p.State); err != nil {
			return nil, err
		}
		if p.Created, err = time.Parse(time.RFC3339, created); err != nil {
			return nil, err
		}
		result = append(result, p)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
563
564func NewPullComment(e Execer, comment *PullComment) (int64, error) {
565 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
566 res, err := e.Exec(
567 query,
568 comment.OwnerDid,
569 comment.RepoAt,
570 comment.SubmissionId,
571 comment.CommentAt,
572 comment.PullId,
573 comment.Body,
574 )
575 if err != nil {
576 return 0, err
577 }
578
579 i, err := res.LastInsertId()
580 if err != nil {
581 return 0, err
582 }
583
584 return i, nil
585}
586
587func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
588 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
589 return err
590}
591
592func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
593 err := SetPullState(e, repoAt, pullId, PullClosed)
594 return err
595}
596
597func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
598 err := SetPullState(e, repoAt, pullId, PullOpen)
599 return err
600}
601
602func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
603 err := SetPullState(e, repoAt, pullId, PullMerged)
604 return err
605}
606
607func ResubmitPull(e Execer, pull *Pull, newPatch string) error {
608 newRoundNumber := len(pull.Submissions)
609 _, err := e.Exec(`
610 insert into pull_submissions (pull_id, repo_at, round_number, patch)
611 values (?, ?, ?, ?)
612 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch)
613
614 return err
615}
616
// PullCount holds the number of pulls in each state for one repository, as
// computed by GetPullCount.
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
622
623func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
624 row := e.QueryRow(`
625 select
626 count(case when state = ? then 1 end) as open_count,
627 count(case when state = ? then 1 end) as merged_count,
628 count(case when state = ? then 1 end) as closed_count
629 from pulls
630 where repo_at = ?`,
631 PullOpen,
632 PullMerged,
633 PullClosed,
634 repoAt,
635 )
636
637 var count PullCount
638 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
639 return PullCount{0, 0, 0}, err
640 }
641
642 return count, nil
643}