1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "fmt"
7 "maps"
8 "slices"
9 "sort"
10 "strings"
11 "time"
12
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "tangled.org/core/appview/models"
15)
16
17func NewPull(tx *sql.Tx, pull *models.Pull) error {
18 _, err := tx.Exec(`
19 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
20 values (?, 1)
21 `, pull.RepoAt)
22 if err != nil {
23 return err
24 }
25
26 var nextId int
27 err = tx.QueryRow(`
28 update repo_pull_seqs
29 set next_pull_id = next_pull_id + 1
30 where repo_at = ?
31 returning next_pull_id - 1
32 `, pull.RepoAt).Scan(&nextId)
33 if err != nil {
34 return err
35 }
36
37 pull.PullId = nextId
38 pull.State = models.PullOpen
39
40 var sourceBranch, sourceRepoAt *string
41 if pull.PullSource != nil {
42 sourceBranch = &pull.PullSource.Branch
43 if pull.PullSource.RepoAt != nil {
44 x := pull.PullSource.RepoAt.String()
45 sourceRepoAt = &x
46 }
47 }
48
49 var stackId, changeId, parentChangeId *string
50 if pull.StackId != "" {
51 stackId = &pull.StackId
52 }
53 if pull.ChangeId != "" {
54 changeId = &pull.ChangeId
55 }
56 if pull.ParentChangeId != "" {
57 parentChangeId = &pull.ParentChangeId
58 }
59
60 result, err := tx.Exec(
61 `
62 insert into pulls (
63 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
64 )
65 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
66 pull.RepoAt,
67 pull.OwnerDid,
68 pull.PullId,
69 pull.Title,
70 pull.TargetBranch,
71 pull.Body,
72 pull.Rkey,
73 pull.State,
74 sourceBranch,
75 sourceRepoAt,
76 stackId,
77 changeId,
78 parentChangeId,
79 )
80 if err != nil {
81 return err
82 }
83
84 // Set the database primary key ID
85 id, err := result.LastInsertId()
86 if err != nil {
87 return err
88 }
89 pull.ID = int(id)
90
91 _, err = tx.Exec(`
92 insert into pull_submissions (pull_at, round_number, patch, source_rev)
93 values (?, ?, ?, ?)
94 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
95 return err
96}
97
98func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
99 pull, err := GetPull(e, repoAt, pullId)
100 if err != nil {
101 return "", err
102 }
103 return pull.PullAt(), err
104}
105
106func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
107 var pullId int
108 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
109 return pullId - 1, err
110}
111
112func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
113 pulls := make(map[syntax.ATURI]*models.Pull)
114
115 var conditions []string
116 var args []any
117 for _, filter := range filters {
118 conditions = append(conditions, filter.Condition())
119 args = append(args, filter.Arg()...)
120 }
121
122 whereClause := ""
123 if conditions != nil {
124 whereClause = " where " + strings.Join(conditions, " and ")
125 }
126 limitClause := ""
127 if limit != 0 {
128 limitClause = fmt.Sprintf(" limit %d ", limit)
129 }
130
131 query := fmt.Sprintf(`
132 select
133 id,
134 owner_did,
135 repo_at,
136 pull_id,
137 created,
138 title,
139 state,
140 target_branch,
141 body,
142 rkey,
143 source_branch,
144 source_repo_at,
145 stack_id,
146 change_id,
147 parent_change_id
148 from
149 pulls
150 %s
151 order by
152 created desc
153 %s
154 `, whereClause, limitClause)
155
156 rows, err := e.Query(query, args...)
157 if err != nil {
158 return nil, err
159 }
160 defer rows.Close()
161
162 for rows.Next() {
163 var pull models.Pull
164 var createdAt string
165 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
166 err := rows.Scan(
167 &pull.ID,
168 &pull.OwnerDid,
169 &pull.RepoAt,
170 &pull.PullId,
171 &createdAt,
172 &pull.Title,
173 &pull.State,
174 &pull.TargetBranch,
175 &pull.Body,
176 &pull.Rkey,
177 &sourceBranch,
178 &sourceRepoAt,
179 &stackId,
180 &changeId,
181 &parentChangeId,
182 )
183 if err != nil {
184 return nil, err
185 }
186
187 createdTime, err := time.Parse(time.RFC3339, createdAt)
188 if err != nil {
189 return nil, err
190 }
191 pull.Created = createdTime
192
193 if sourceBranch.Valid {
194 pull.PullSource = &models.PullSource{
195 Branch: sourceBranch.String,
196 }
197 if sourceRepoAt.Valid {
198 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
199 if err != nil {
200 return nil, err
201 }
202 pull.PullSource.RepoAt = &sourceRepoAtParsed
203 }
204 }
205
206 if stackId.Valid {
207 pull.StackId = stackId.String
208 }
209 if changeId.Valid {
210 pull.ChangeId = changeId.String
211 }
212 if parentChangeId.Valid {
213 pull.ParentChangeId = parentChangeId.String
214 }
215
216 pulls[pull.PullAt()] = &pull
217 }
218
219 var pullAts []syntax.ATURI
220 for _, p := range pulls {
221 pullAts = append(pullAts, p.PullAt())
222 }
223 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
224 if err != nil {
225 return nil, fmt.Errorf("failed to get submissions: %w", err)
226 }
227
228 for pullAt, submissions := range submissionsMap {
229 if p, ok := pulls[pullAt]; ok {
230 p.Submissions = submissions
231 }
232 }
233 // collect allLabels for each issue
234 allLabels, err := GetLabels(e, FilterIn("subject", pullAts))
235 if err != nil {
236 return nil, fmt.Errorf("failed to query labels: %w", err)
237 }
238 for pullAt, labels := range allLabels {
239 if p, ok := pulls[pullAt]; ok {
240 p.Labels = labels
241 }
242 }
243
244 orderedByPullId := []*models.Pull{}
245 for _, p := range pulls {
246 orderedByPullId = append(orderedByPullId, p)
247 }
248 sort.Slice(orderedByPullId, func(i, j int) bool {
249 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
250 })
251
252 return orderedByPullId, nil
253}
254
255func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
256 return GetPullsWithLimit(e, 0, filters...)
257}
258
259func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
260 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
261 if err != nil {
262 return nil, err
263 }
264 if pulls == nil {
265 return nil, sql.ErrNoRows
266 }
267
268 return pulls[0], nil
269}
270
271// mapping from pull -> pull submissions
272func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
273 var conditions []string
274 var args []any
275 for _, filter := range filters {
276 conditions = append(conditions, filter.Condition())
277 args = append(args, filter.Arg()...)
278 }
279
280 whereClause := ""
281 if conditions != nil {
282 whereClause = " where " + strings.Join(conditions, " and ")
283 }
284
285 query := fmt.Sprintf(`
286 select
287 id,
288 pull_at,
289 round_number,
290 patch,
291 created,
292 source_rev
293 from
294 pull_submissions
295 %s
296 order by
297 round_number asc
298 `, whereClause)
299
300 rows, err := e.Query(query, args...)
301 if err != nil {
302 return nil, err
303 }
304 defer rows.Close()
305
306 submissionMap := make(map[int]*models.PullSubmission)
307
308 for rows.Next() {
309 var submission models.PullSubmission
310 var createdAt string
311 var sourceRev sql.NullString
312 err := rows.Scan(
313 &submission.ID,
314 &submission.PullAt,
315 &submission.RoundNumber,
316 &submission.Patch,
317 &createdAt,
318 &sourceRev,
319 )
320 if err != nil {
321 return nil, err
322 }
323
324 createdTime, err := time.Parse(time.RFC3339, createdAt)
325 if err != nil {
326 return nil, err
327 }
328 submission.Created = createdTime
329
330 if sourceRev.Valid {
331 submission.SourceRev = sourceRev.String
332 }
333
334 submissionMap[submission.ID] = &submission
335 }
336
337 if err := rows.Err(); err != nil {
338 return nil, err
339 }
340
341 // Get comments for all submissions using GetPullComments
342 submissionIds := slices.Collect(maps.Keys(submissionMap))
343 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
344 if err != nil {
345 return nil, err
346 }
347 for _, comment := range comments {
348 if submission, ok := submissionMap[comment.SubmissionId]; ok {
349 submission.Comments = append(submission.Comments, comment)
350 }
351 }
352
353 // group the submissions by pull_at
354 m := make(map[syntax.ATURI][]*models.PullSubmission)
355 for _, s := range submissionMap {
356 m[s.PullAt] = append(m[s.PullAt], s)
357 }
358
359 // sort each one by round number
360 for _, s := range m {
361 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
362 return cmp.Compare(a.RoundNumber, b.RoundNumber)
363 })
364 }
365
366 return m, nil
367}
368
369func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
370 var conditions []string
371 var args []any
372 for _, filter := range filters {
373 conditions = append(conditions, filter.Condition())
374 args = append(args, filter.Arg()...)
375 }
376
377 whereClause := ""
378 if conditions != nil {
379 whereClause = " where " + strings.Join(conditions, " and ")
380 }
381
382 query := fmt.Sprintf(`
383 select
384 id,
385 pull_id,
386 submission_id,
387 repo_at,
388 owner_did,
389 comment_at,
390 body,
391 created
392 from
393 pull_comments
394 %s
395 order by
396 created asc
397 `, whereClause)
398
399 rows, err := e.Query(query, args...)
400 if err != nil {
401 return nil, err
402 }
403 defer rows.Close()
404
405 var comments []models.PullComment
406 for rows.Next() {
407 var comment models.PullComment
408 var createdAt string
409 err := rows.Scan(
410 &comment.ID,
411 &comment.PullId,
412 &comment.SubmissionId,
413 &comment.RepoAt,
414 &comment.OwnerDid,
415 &comment.CommentAt,
416 &comment.Body,
417 &createdAt,
418 )
419 if err != nil {
420 return nil, err
421 }
422
423 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
424 comment.Created = t
425 }
426
427 comments = append(comments, comment)
428 }
429
430 if err := rows.Err(); err != nil {
431 return nil, err
432 }
433
434 return comments, nil
435}
436
437// timeframe here is directly passed into the sql query filter, and any
438// timeframe in the past should be negative; e.g.: "-3 months"
439func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
440 var pulls []models.Pull
441
442 rows, err := e.Query(`
443 select
444 p.owner_did,
445 p.repo_at,
446 p.pull_id,
447 p.created,
448 p.title,
449 p.state,
450 r.did,
451 r.name,
452 r.knot,
453 r.rkey,
454 r.created
455 from
456 pulls p
457 join
458 repos r on p.repo_at = r.at_uri
459 where
460 p.owner_did = ? and p.created >= date ('now', ?)
461 order by
462 p.created desc`, did, timeframe)
463 if err != nil {
464 return nil, err
465 }
466 defer rows.Close()
467
468 for rows.Next() {
469 var pull models.Pull
470 var repo models.Repo
471 var pullCreatedAt, repoCreatedAt string
472 err := rows.Scan(
473 &pull.OwnerDid,
474 &pull.RepoAt,
475 &pull.PullId,
476 &pullCreatedAt,
477 &pull.Title,
478 &pull.State,
479 &repo.Did,
480 &repo.Name,
481 &repo.Knot,
482 &repo.Rkey,
483 &repoCreatedAt,
484 )
485 if err != nil {
486 return nil, err
487 }
488
489 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
490 if err != nil {
491 return nil, err
492 }
493 pull.Created = pullCreatedTime
494
495 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
496 if err != nil {
497 return nil, err
498 }
499 repo.Created = repoCreatedTime
500
501 pull.Repo = &repo
502
503 pulls = append(pulls, pull)
504 }
505
506 if err := rows.Err(); err != nil {
507 return nil, err
508 }
509
510 return pulls, nil
511}
512
513func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
514 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
515 res, err := e.Exec(
516 query,
517 comment.OwnerDid,
518 comment.RepoAt,
519 comment.SubmissionId,
520 comment.CommentAt,
521 comment.PullId,
522 comment.Body,
523 )
524 if err != nil {
525 return 0, err
526 }
527
528 i, err := res.LastInsertId()
529 if err != nil {
530 return 0, err
531 }
532
533 return i, nil
534}
535
536func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
537 _, err := e.Exec(
538 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
539 pullState,
540 repoAt,
541 pullId,
542 models.PullDeleted, // only update state of non-deleted pulls
543 models.PullMerged, // only update state of non-merged pulls
544 )
545 return err
546}
547
548func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
549 err := SetPullState(e, repoAt, pullId, models.PullClosed)
550 return err
551}
552
553func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
554 err := SetPullState(e, repoAt, pullId, models.PullOpen)
555 return err
556}
557
558func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
559 err := SetPullState(e, repoAt, pullId, models.PullMerged)
560 return err
561}
562
563func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
564 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
565 return err
566}
567
568func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error {
569 newRoundNumber := len(pull.Submissions)
570 _, err := e.Exec(`
571 insert into pull_submissions (pull_at, round_number, patch, source_rev)
572 values (?, ?, ?, ?)
573 `, pull.PullAt(), newRoundNumber, newPatch, sourceRev)
574
575 return err
576}
577
578func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
579 var conditions []string
580 var args []any
581
582 args = append(args, parentChangeId)
583
584 for _, filter := range filters {
585 conditions = append(conditions, filter.Condition())
586 args = append(args, filter.Arg()...)
587 }
588
589 whereClause := ""
590 if conditions != nil {
591 whereClause = " where " + strings.Join(conditions, " and ")
592 }
593
594 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
595 _, err := e.Exec(query, args...)
596
597 return err
598}
599
600// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
601// otherwise submissions are immutable
602func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
603 var conditions []string
604 var args []any
605
606 args = append(args, sourceRev)
607 args = append(args, newPatch)
608
609 for _, filter := range filters {
610 conditions = append(conditions, filter.Condition())
611 args = append(args, filter.Arg()...)
612 }
613
614 whereClause := ""
615 if conditions != nil {
616 whereClause = " where " + strings.Join(conditions, " and ")
617 }
618
619 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
620 _, err := e.Exec(query, args...)
621
622 return err
623}
624
625func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
626 row := e.QueryRow(`
627 select
628 count(case when state = ? then 1 end) as open_count,
629 count(case when state = ? then 1 end) as merged_count,
630 count(case when state = ? then 1 end) as closed_count,
631 count(case when state = ? then 1 end) as deleted_count
632 from pulls
633 where repo_at = ?`,
634 models.PullOpen,
635 models.PullMerged,
636 models.PullClosed,
637 models.PullDeleted,
638 repoAt,
639 )
640
641 var count models.PullCount
642 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
643 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
644 }
645
646 return count, nil
647}
648
649// change-id parent-change-id
650//
651// 4 w ,-------- z (TOP)
652// 3 z <----',------- y
653// 2 y <-----',------ x
654// 1 x <------' nil (BOT)
655//
656// `w` is parent of none, so it is the top of the stack
657func GetStack(e Execer, stackId string) (models.Stack, error) {
658 unorderedPulls, err := GetPulls(
659 e,
660 FilterEq("stack_id", stackId),
661 FilterNotEq("state", models.PullDeleted),
662 )
663 if err != nil {
664 return nil, err
665 }
666 // map of parent-change-id to pull
667 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
668 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
669 for _, p := range unorderedPulls {
670 changeIdMap[p.ChangeId] = p
671 if p.ParentChangeId != "" {
672 parentMap[p.ParentChangeId] = p
673 }
674 }
675
676 // the top of the stack is the pull that is not a parent of any pull
677 var topPull *models.Pull
678 for _, maybeTop := range unorderedPulls {
679 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
680 topPull = maybeTop
681 break
682 }
683 }
684
685 pulls := []*models.Pull{}
686 for {
687 pulls = append(pulls, topPull)
688 if topPull.ParentChangeId != "" {
689 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
690 topPull = next
691 } else {
692 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
693 }
694 } else {
695 break
696 }
697 }
698
699 return pulls, nil
700}
701
702func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
703 pulls, err := GetPulls(
704 e,
705 FilterEq("stack_id", stackId),
706 FilterEq("state", models.PullDeleted),
707 )
708 if err != nil {
709 return nil, err
710 }
711
712 return pulls, nil
713}