1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "maps"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.org/core/appview/models"
14)
15
16func NewPull(tx *sql.Tx, pull *models.Pull) error {
17 _, err := tx.Exec(`
18 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
19 values (?, 1)
20 `, pull.RepoAt)
21 if err != nil {
22 return err
23 }
24
25 var nextId int
26 err = tx.QueryRow(`
27 update repo_pull_seqs
28 set next_pull_id = next_pull_id + 1
29 where repo_at = ?
30 returning next_pull_id - 1
31 `, pull.RepoAt).Scan(&nextId)
32 if err != nil {
33 return err
34 }
35
36 pull.PullId = nextId
37 pull.State = models.PullOpen
38
39 var sourceBranch, sourceRepoAt *string
40 if pull.PullSource != nil {
41 sourceBranch = &pull.PullSource.Branch
42 if pull.PullSource.RepoAt != nil {
43 x := pull.PullSource.RepoAt.String()
44 sourceRepoAt = &x
45 }
46 }
47
48 var stackId, changeId, parentChangeId *string
49 if pull.StackId != "" {
50 stackId = &pull.StackId
51 }
52 if pull.ChangeId != "" {
53 changeId = &pull.ChangeId
54 }
55 if pull.ParentChangeId != "" {
56 parentChangeId = &pull.ParentChangeId
57 }
58
59 result, err := tx.Exec(
60 `
61 insert into pulls (
62 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
63 )
64 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
65 pull.RepoAt,
66 pull.OwnerDid,
67 pull.PullId,
68 pull.Title,
69 pull.TargetBranch,
70 pull.Body,
71 pull.Rkey,
72 pull.State,
73 sourceBranch,
74 sourceRepoAt,
75 stackId,
76 changeId,
77 parentChangeId,
78 )
79 if err != nil {
80 return err
81 }
82
83 // Set the database primary key ID
84 id, err := result.LastInsertId()
85 if err != nil {
86 return err
87 }
88 pull.ID = int(id)
89
90 _, err = tx.Exec(`
91 insert into pull_submissions (pull_at, round_number, patch, source_rev)
92 values (?, ?, ?, ?)
93 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
94 return err
95}
96
97func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
98 pull, err := GetPull(e, repoAt, pullId)
99 if err != nil {
100 return "", err
101 }
102 return pull.PullAt(), err
103}
104
105func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
106 var pullId int
107 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
108 return pullId - 1, err
109}
110
111func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
112 pulls := make(map[syntax.ATURI]*models.Pull)
113
114 var conditions []string
115 var args []any
116 for _, filter := range filters {
117 conditions = append(conditions, filter.Condition())
118 args = append(args, filter.Arg()...)
119 }
120
121 whereClause := ""
122 if conditions != nil {
123 whereClause = " where " + strings.Join(conditions, " and ")
124 }
125 limitClause := ""
126 if limit != 0 {
127 limitClause = fmt.Sprintf(" limit %d ", limit)
128 }
129
130 query := fmt.Sprintf(`
131 select
132 id,
133 owner_did,
134 repo_at,
135 pull_id,
136 created,
137 title,
138 state,
139 target_branch,
140 body,
141 rkey,
142 source_branch,
143 source_repo_at,
144 stack_id,
145 change_id,
146 parent_change_id
147 from
148 pulls
149 %s
150 order by
151 created desc
152 %s
153 `, whereClause, limitClause)
154
155 rows, err := e.Query(query, args...)
156 if err != nil {
157 return nil, err
158 }
159 defer rows.Close()
160
161 for rows.Next() {
162 var pull models.Pull
163 var createdAt string
164 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
165 err := rows.Scan(
166 &pull.ID,
167 &pull.OwnerDid,
168 &pull.RepoAt,
169 &pull.PullId,
170 &createdAt,
171 &pull.Title,
172 &pull.State,
173 &pull.TargetBranch,
174 &pull.Body,
175 &pull.Rkey,
176 &sourceBranch,
177 &sourceRepoAt,
178 &stackId,
179 &changeId,
180 &parentChangeId,
181 )
182 if err != nil {
183 return nil, err
184 }
185
186 createdTime, err := time.Parse(time.RFC3339, createdAt)
187 if err != nil {
188 return nil, err
189 }
190 pull.Created = createdTime
191
192 if sourceBranch.Valid {
193 pull.PullSource = &models.PullSource{
194 Branch: sourceBranch.String,
195 }
196 if sourceRepoAt.Valid {
197 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
198 if err != nil {
199 return nil, err
200 }
201 pull.PullSource.RepoAt = &sourceRepoAtParsed
202 }
203 }
204
205 if stackId.Valid {
206 pull.StackId = stackId.String
207 }
208 if changeId.Valid {
209 pull.ChangeId = changeId.String
210 }
211 if parentChangeId.Valid {
212 pull.ParentChangeId = parentChangeId.String
213 }
214
215 pulls[pull.PullAt()] = &pull
216 }
217
218 var pullAts []syntax.ATURI
219 for _, p := range pulls {
220 pullAts = append(pullAts, p.PullAt())
221 }
222 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
223 if err != nil {
224 return nil, fmt.Errorf("failed to get submissions: %w", err)
225 }
226
227 for pullAt, submissions := range submissionsMap {
228 if p, ok := pulls[pullAt]; ok {
229 p.Submissions = submissions
230 }
231 }
232
233 orderedByPullId := []*models.Pull{}
234 for _, p := range pulls {
235 orderedByPullId = append(orderedByPullId, p)
236 }
237 sort.Slice(orderedByPullId, func(i, j int) bool {
238 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
239 })
240
241 return orderedByPullId, nil
242}
243
244func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
245 return GetPullsWithLimit(e, 0, filters...)
246}
247
248func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
249 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
250 if err != nil {
251 return nil, err
252 }
253 if pulls == nil {
254 return nil, sql.ErrNoRows
255 }
256
257 return pulls[0], nil
258}
259
260// mapping from pull -> pull submissions
261func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
262 var conditions []string
263 var args []any
264 for _, filter := range filters {
265 conditions = append(conditions, filter.Condition())
266 args = append(args, filter.Arg()...)
267 }
268
269 whereClause := ""
270 if conditions != nil {
271 whereClause = " where " + strings.Join(conditions, " and ")
272 }
273
274 query := fmt.Sprintf(`
275 select
276 id,
277 pull_at,
278 round_number,
279 patch,
280 created,
281 source_rev
282 from
283 pull_submissions
284 %s
285 order by
286 round_number asc
287 `, whereClause)
288
289 rows, err := e.Query(query, args...)
290 if err != nil {
291 return nil, err
292 }
293 defer rows.Close()
294
295 submissionMap := make(map[int]*models.PullSubmission)
296
297 for rows.Next() {
298 var submission models.PullSubmission
299 var createdAt string
300 var sourceRev sql.NullString
301 err := rows.Scan(
302 &submission.ID,
303 &submission.PullAt,
304 &submission.RoundNumber,
305 &submission.Patch,
306 &createdAt,
307 &sourceRev,
308 )
309 if err != nil {
310 return nil, err
311 }
312
313 createdTime, err := time.Parse(time.RFC3339, createdAt)
314 if err != nil {
315 return nil, err
316 }
317 submission.Created = createdTime
318
319 if sourceRev.Valid {
320 submission.SourceRev = sourceRev.String
321 }
322
323 submissionMap[submission.ID] = &submission
324 }
325
326 if err := rows.Err(); err != nil {
327 return nil, err
328 }
329
330 // Get comments for all submissions using GetPullComments
331 submissionIds := slices.Collect(maps.Keys(submissionMap))
332 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
333 if err != nil {
334 return nil, err
335 }
336 for _, comment := range comments {
337 if submission, ok := submissionMap[comment.SubmissionId]; ok {
338 submission.Comments = append(submission.Comments, comment)
339 }
340 }
341
342 // order the submissions by pull_at
343 m := make(map[syntax.ATURI][]*models.PullSubmission)
344 for _, s := range submissionMap {
345 m[s.PullAt] = append(m[s.PullAt], s)
346 }
347
348 return m, nil
349}
350
351func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
352 var conditions []string
353 var args []any
354 for _, filter := range filters {
355 conditions = append(conditions, filter.Condition())
356 args = append(args, filter.Arg()...)
357 }
358
359 whereClause := ""
360 if conditions != nil {
361 whereClause = " where " + strings.Join(conditions, " and ")
362 }
363
364 query := fmt.Sprintf(`
365 select
366 id,
367 pull_id,
368 submission_id,
369 repo_at,
370 owner_did,
371 comment_at,
372 body,
373 created
374 from
375 pull_comments
376 %s
377 order by
378 created asc
379 `, whereClause)
380
381 rows, err := e.Query(query, args...)
382 if err != nil {
383 return nil, err
384 }
385 defer rows.Close()
386
387 var comments []models.PullComment
388 for rows.Next() {
389 var comment models.PullComment
390 var createdAt string
391 err := rows.Scan(
392 &comment.ID,
393 &comment.PullId,
394 &comment.SubmissionId,
395 &comment.RepoAt,
396 &comment.OwnerDid,
397 &comment.CommentAt,
398 &comment.Body,
399 &createdAt,
400 )
401 if err != nil {
402 return nil, err
403 }
404
405 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
406 comment.Created = t
407 }
408
409 comments = append(comments, comment)
410 }
411
412 if err := rows.Err(); err != nil {
413 return nil, err
414 }
415
416 return comments, nil
417}
418
419// timeframe here is directly passed into the sql query filter, and any
420// timeframe in the past should be negative; e.g.: "-3 months"
421func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
422 var pulls []models.Pull
423
424 rows, err := e.Query(`
425 select
426 p.owner_did,
427 p.repo_at,
428 p.pull_id,
429 p.created,
430 p.title,
431 p.state,
432 r.did,
433 r.name,
434 r.knot,
435 r.rkey,
436 r.created
437 from
438 pulls p
439 join
440 repos r on p.repo_at = r.at_uri
441 where
442 p.owner_did = ? and p.created >= date ('now', ?)
443 order by
444 p.created desc`, did, timeframe)
445 if err != nil {
446 return nil, err
447 }
448 defer rows.Close()
449
450 for rows.Next() {
451 var pull models.Pull
452 var repo models.Repo
453 var pullCreatedAt, repoCreatedAt string
454 err := rows.Scan(
455 &pull.OwnerDid,
456 &pull.RepoAt,
457 &pull.PullId,
458 &pullCreatedAt,
459 &pull.Title,
460 &pull.State,
461 &repo.Did,
462 &repo.Name,
463 &repo.Knot,
464 &repo.Rkey,
465 &repoCreatedAt,
466 )
467 if err != nil {
468 return nil, err
469 }
470
471 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
472 if err != nil {
473 return nil, err
474 }
475 pull.Created = pullCreatedTime
476
477 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
478 if err != nil {
479 return nil, err
480 }
481 repo.Created = repoCreatedTime
482
483 pull.Repo = &repo
484
485 pulls = append(pulls, pull)
486 }
487
488 if err := rows.Err(); err != nil {
489 return nil, err
490 }
491
492 return pulls, nil
493}
494
495func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
496 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
497 res, err := e.Exec(
498 query,
499 comment.OwnerDid,
500 comment.RepoAt,
501 comment.SubmissionId,
502 comment.CommentAt,
503 comment.PullId,
504 comment.Body,
505 )
506 if err != nil {
507 return 0, err
508 }
509
510 i, err := res.LastInsertId()
511 if err != nil {
512 return 0, err
513 }
514
515 return i, nil
516}
517
518func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
519 _, err := e.Exec(
520 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
521 pullState,
522 repoAt,
523 pullId,
524 models.PullDeleted, // only update state of non-deleted pulls
525 models.PullMerged, // only update state of non-merged pulls
526 )
527 return err
528}
529
530func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
531 err := SetPullState(e, repoAt, pullId, models.PullClosed)
532 return err
533}
534
535func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
536 err := SetPullState(e, repoAt, pullId, models.PullOpen)
537 return err
538}
539
540func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
541 err := SetPullState(e, repoAt, pullId, models.PullMerged)
542 return err
543}
544
545func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
546 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
547 return err
548}
549
550func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error {
551 newRoundNumber := len(pull.Submissions)
552 _, err := e.Exec(`
553 insert into pull_submissions (pull_at, round_number, patch, source_rev)
554 values (?, ?, ?, ?)
555 `, pull.PullAt(), newRoundNumber, newPatch, sourceRev)
556
557 return err
558}
559
560func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
561 var conditions []string
562 var args []any
563
564 args = append(args, parentChangeId)
565
566 for _, filter := range filters {
567 conditions = append(conditions, filter.Condition())
568 args = append(args, filter.Arg()...)
569 }
570
571 whereClause := ""
572 if conditions != nil {
573 whereClause = " where " + strings.Join(conditions, " and ")
574 }
575
576 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
577 _, err := e.Exec(query, args...)
578
579 return err
580}
581
582// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
583// otherwise submissions are immutable
584func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
585 var conditions []string
586 var args []any
587
588 args = append(args, sourceRev)
589 args = append(args, newPatch)
590
591 for _, filter := range filters {
592 conditions = append(conditions, filter.Condition())
593 args = append(args, filter.Arg()...)
594 }
595
596 whereClause := ""
597 if conditions != nil {
598 whereClause = " where " + strings.Join(conditions, " and ")
599 }
600
601 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
602 _, err := e.Exec(query, args...)
603
604 return err
605}
606
607func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
608 row := e.QueryRow(`
609 select
610 count(case when state = ? then 1 end) as open_count,
611 count(case when state = ? then 1 end) as merged_count,
612 count(case when state = ? then 1 end) as closed_count,
613 count(case when state = ? then 1 end) as deleted_count
614 from pulls
615 where repo_at = ?`,
616 models.PullOpen,
617 models.PullMerged,
618 models.PullClosed,
619 models.PullDeleted,
620 repoAt,
621 )
622
623 var count models.PullCount
624 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
625 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
626 }
627
628 return count, nil
629}
630
631// change-id parent-change-id
632//
633// 4 w ,-------- z (TOP)
634// 3 z <----',------- y
635// 2 y <-----',------ x
636// 1 x <------' nil (BOT)
637//
638// `w` is parent of none, so it is the top of the stack
639func GetStack(e Execer, stackId string) (models.Stack, error) {
640 unorderedPulls, err := GetPulls(
641 e,
642 FilterEq("stack_id", stackId),
643 FilterNotEq("state", models.PullDeleted),
644 )
645 if err != nil {
646 return nil, err
647 }
648 // map of parent-change-id to pull
649 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
650 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
651 for _, p := range unorderedPulls {
652 changeIdMap[p.ChangeId] = p
653 if p.ParentChangeId != "" {
654 parentMap[p.ParentChangeId] = p
655 }
656 }
657
658 // the top of the stack is the pull that is not a parent of any pull
659 var topPull *models.Pull
660 for _, maybeTop := range unorderedPulls {
661 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
662 topPull = maybeTop
663 break
664 }
665 }
666
667 pulls := []*models.Pull{}
668 for {
669 pulls = append(pulls, topPull)
670 if topPull.ParentChangeId != "" {
671 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
672 topPull = next
673 } else {
674 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
675 }
676 } else {
677 break
678 }
679 }
680
681 return pulls, nil
682}
683
684func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
685 pulls, err := GetPulls(
686 e,
687 FilterEq("stack_id", stackId),
688 FilterEq("state", models.PullDeleted),
689 )
690 if err != nil {
691 return nil, err
692 }
693
694 return pulls, nil
695}