forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16) 17 18func NewPull(tx *sql.Tx, pull *models.Pull) error { 19 _, err := tx.Exec(` 20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 21 values (?, 1) 22 `, pull.RepoAt) 23 if err != nil { 24 return err 25 } 26 27 var nextId int 28 err = tx.QueryRow(` 29 update repo_pull_seqs 30 set next_pull_id = next_pull_id + 1 31 where repo_at = ? 32 returning next_pull_id - 1 33 `, pull.RepoAt).Scan(&nextId) 34 if err != nil { 35 return err 36 } 37 38 pull.PullId = nextId 39 pull.State = models.PullOpen 40 41 var sourceBranch, sourceRepoAt *string 42 if pull.PullSource != nil { 43 sourceBranch = &pull.PullSource.Branch 44 if pull.PullSource.RepoAt != nil { 45 x := pull.PullSource.RepoAt.String() 46 sourceRepoAt = &x 47 } 48 } 49 50 var stackId, changeId, parentChangeId *string 51 if pull.StackId != "" { 52 stackId = &pull.StackId 53 } 54 if pull.ChangeId != "" { 55 changeId = &pull.ChangeId 56 } 57 if pull.ParentChangeId != "" { 58 parentChangeId = &pull.ParentChangeId 59 } 60 61 result, err := tx.Exec( 62 ` 63 insert into pulls ( 64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 65 ) 66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 67 pull.RepoAt, 68 pull.OwnerDid, 69 pull.PullId, 70 pull.Title, 71 pull.TargetBranch, 72 pull.Body, 73 pull.Rkey, 74 pull.State, 75 sourceBranch, 76 sourceRepoAt, 77 stackId, 78 changeId, 79 parentChangeId, 80 ) 81 if err != nil { 82 return err 83 } 84 85 // Set the database primary key ID 86 id, err := result.LastInsertId() 87 if err != nil { 88 return err 89 } 90 pull.ID = int(id) 91 92 _, err = tx.Exec(` 93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 94 values (?, ?, ?, ?, ?) 95 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 96 if err != nil { 97 return err 98 } 99 100 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil { 101 return fmt.Errorf("put reference_links: %w", err) 102 } 103 104 return nil 105} 106 107func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 108 pull, err := GetPull(e, repoAt, pullId) 109 if err != nil { 110 return "", err 111 } 112 return pull.AtUri(), err 113} 114 115func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 116 var pullId int 117 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 118 return pullId - 1, err 119} 120 121func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 122 pulls := make(map[syntax.ATURI]*models.Pull) 123 124 var conditions []string 125 var args []any 126 for _, filter := range filters { 127 conditions = append(conditions, filter.Condition()) 128 args = append(args, filter.Arg()...) 129 } 130 131 whereClause := "" 132 if conditions != nil { 133 whereClause = " where " + strings.Join(conditions, " and ") 134 } 135 limitClause := "" 136 if limit != 0 { 137 limitClause = fmt.Sprintf(" limit %d ", limit) 138 } 139 140 query := fmt.Sprintf(` 141 select 142 id, 143 owner_did, 144 repo_at, 145 pull_id, 146 created, 147 title, 148 state, 149 target_branch, 150 body, 151 rkey, 152 source_branch, 153 source_repo_at, 154 stack_id, 155 change_id, 156 parent_change_id 157 from 158 pulls 159 %s 160 order by 161 created desc 162 %s 163 `, whereClause, limitClause) 164 165 rows, err := e.Query(query, args...) 166 if err != nil { 167 return nil, err 168 } 169 defer rows.Close() 170 171 for rows.Next() { 172 var pull models.Pull 173 var createdAt string 174 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 175 err := rows.Scan( 176 &pull.ID, 177 &pull.OwnerDid, 178 &pull.RepoAt, 179 &pull.PullId, 180 &createdAt, 181 &pull.Title, 182 &pull.State, 183 &pull.TargetBranch, 184 &pull.Body, 185 &pull.Rkey, 186 &sourceBranch, 187 &sourceRepoAt, 188 &stackId, 189 &changeId, 190 &parentChangeId, 191 ) 192 if err != nil { 193 return nil, err 194 } 195 196 createdTime, err := time.Parse(time.RFC3339, createdAt) 197 if err != nil { 198 return nil, err 199 } 200 pull.Created = createdTime 201 202 if sourceBranch.Valid { 203 pull.PullSource = &models.PullSource{ 204 Branch: sourceBranch.String, 205 } 206 if sourceRepoAt.Valid { 207 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 208 if err != nil { 209 return nil, err 210 } 211 pull.PullSource.RepoAt = &sourceRepoAtParsed 212 } 213 } 214 215 if stackId.Valid { 216 pull.StackId = stackId.String 217 } 218 if changeId.Valid { 219 pull.ChangeId = changeId.String 220 } 221 if parentChangeId.Valid { 222 pull.ParentChangeId = parentChangeId.String 223 } 224 225 pulls[pull.AtUri()] = &pull 226 } 227 228 var pullAts []syntax.ATURI 229 for _, p := range pulls { 230 pullAts = append(pullAts, p.AtUri()) 231 } 232 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) 233 if err != nil { 234 return nil, fmt.Errorf("failed to get submissions: %w", err) 235 } 236 237 for pullAt, submissions := range submissionsMap { 238 if p, ok := pulls[pullAt]; ok { 239 p.Submissions = submissions 240 } 241 } 242 243 // collect allLabels for each issue 244 allLabels, err := GetLabels(e, FilterIn("subject", pullAts)) 245 if err != nil { 246 return nil, fmt.Errorf("failed to query labels: %w", err) 247 } 248 for pullAt, labels := range allLabels { 249 if p, ok := pulls[pullAt]; ok { 250 p.Labels = labels 251 } 252 } 253 254 // collect pull source for all pulls that need it 255 var sourceAts []syntax.ATURI 256 for _, p := range pulls { 257 if p.PullSource != nil && p.PullSource.RepoAt != nil { 258 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 259 } 260 } 261 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts)) 262 if err != nil && !errors.Is(err, sql.ErrNoRows) { 263 return nil, fmt.Errorf("failed to get source repos: %w", err) 264 } 265 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 266 for _, r := range sourceRepos { 267 sourceRepoMap[r.RepoAt()] = &r 268 } 269 for _, p := range pulls { 270 if p.PullSource != nil && p.PullSource.RepoAt != nil { 271 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 272 p.PullSource.Repo = sourceRepo 273 } 274 } 275 } 276 277 allReferences, err := GetReferencesAll(e, FilterIn("from_at", pullAts)) 278 if err != nil { 279 return nil, fmt.Errorf("failed to query reference_links: %w", err) 280 } 281 for pullAt, references := range allReferences { 282 if pull, ok := pulls[pullAt]; ok { 283 pull.References = references 284 } 285 } 286 287 orderedByPullId := []*models.Pull{} 288 for _, p := range pulls { 289 orderedByPullId = append(orderedByPullId, p) 290 } 291 sort.Slice(orderedByPullId, func(i, j int) bool { 292 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 293 }) 294 295 return orderedByPullId, nil 296} 297 298func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 299 return GetPullsWithLimit(e, 0, filters...) 300} 301 302func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) { 303 var ids []int64 304 305 var filters []filter 306 filters = append(filters, FilterEq("state", opts.State)) 307 if opts.RepoAt != "" { 308 filters = append(filters, FilterEq("repo_at", opts.RepoAt)) 309 } 310 311 var conditions []string 312 var args []any 313 314 for _, filter := range filters { 315 conditions = append(conditions, filter.Condition()) 316 args = append(args, filter.Arg()...) 317 } 318 319 whereClause := "" 320 if conditions != nil { 321 whereClause = " where " + strings.Join(conditions, " and ") 322 } 323 pageClause := "" 324 if opts.Page.Limit != 0 { 325 pageClause = fmt.Sprintf( 326 " limit %d offset %d ", 327 opts.Page.Limit, 328 opts.Page.Offset, 329 ) 330 } 331 332 query := fmt.Sprintf( 333 ` 334 select 335 id 336 from 337 pulls 338 %s 339 %s`, 340 whereClause, 341 pageClause, 342 ) 343 args = append(args, opts.Page.Limit, opts.Page.Offset) 344 rows, err := e.Query(query, args...) 345 if err != nil { 346 return nil, err 347 } 348 defer rows.Close() 349 350 for rows.Next() { 351 var id int64 352 err := rows.Scan(&id) 353 if err != nil { 354 return nil, err 355 } 356 357 ids = append(ids, id) 358 } 359 360 return ids, nil 361} 362 363func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 364 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) 365 if err != nil { 366 return nil, err 367 } 368 if len(pulls) == 0 { 369 return nil, sql.ErrNoRows 370 } 371 372 return pulls[0], nil 373} 374 375// mapping from pull -> pull submissions 376func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 377 var conditions []string 378 var args []any 379 for _, filter := range filters { 380 conditions = append(conditions, filter.Condition()) 381 args = append(args, filter.Arg()...) 382 } 383 384 whereClause := "" 385 if conditions != nil { 386 whereClause = " where " + strings.Join(conditions, " and ") 387 } 388 389 query := fmt.Sprintf(` 390 select 391 id, 392 pull_at, 393 round_number, 394 patch, 395 combined, 396 created, 397 source_rev 398 from 399 pull_submissions 400 %s 401 order by 402 round_number asc 403 `, whereClause) 404 405 rows, err := e.Query(query, args...) 406 if err != nil { 407 return nil, err 408 } 409 defer rows.Close() 410 411 submissionMap := make(map[int]*models.PullSubmission) 412 413 for rows.Next() { 414 var submission models.PullSubmission 415 var submissionCreatedStr string 416 var submissionSourceRev, submissionCombined sql.NullString 417 err := rows.Scan( 418 &submission.ID, 419 &submission.PullAt, 420 &submission.RoundNumber, 421 &submission.Patch, 422 &submissionCombined, 423 &submissionCreatedStr, 424 &submissionSourceRev, 425 ) 426 if err != nil { 427 return nil, err 428 } 429 430 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 431 submission.Created = t 432 } 433 434 if submissionSourceRev.Valid { 435 submission.SourceRev = submissionSourceRev.String 436 } 437 438 if submissionCombined.Valid { 439 submission.Combined = submissionCombined.String 440 } 441 442 submissionMap[submission.ID] = &submission 443 } 444 445 if err := rows.Err(); err != nil { 446 return nil, err 447 } 448 449 // Get comments for all submissions using GetPullComments 450 submissionIds := slices.Collect(maps.Keys(submissionMap)) 451 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) 452 if err != nil { 453 return nil, fmt.Errorf("failed to get pull comments: %w", err) 454 } 455 for _, comment := range comments { 456 if submission, ok := submissionMap[comment.SubmissionId]; ok { 457 submission.Comments = append(submission.Comments, comment) 458 } 459 } 460 461 // group the submissions by pull_at 462 m := make(map[syntax.ATURI][]*models.PullSubmission) 463 for _, s := range submissionMap { 464 m[s.PullAt] = append(m[s.PullAt], s) 465 } 466 467 // sort each one by round number 468 for _, s := range m { 469 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 470 return cmp.Compare(a.RoundNumber, b.RoundNumber) 471 }) 472 } 473 474 return m, nil 475} 476 477func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { 478 var conditions []string 479 var args []any 480 for _, filter := range filters { 481 conditions = append(conditions, filter.Condition()) 482 args = append(args, filter.Arg()...) 483 } 484 485 whereClause := "" 486 if conditions != nil { 487 whereClause = " where " + strings.Join(conditions, " and ") 488 } 489 490 query := fmt.Sprintf(` 491 select 492 id, 493 pull_id, 494 submission_id, 495 repo_at, 496 owner_did, 497 comment_at, 498 body, 499 created 500 from 501 pull_comments 502 %s 503 order by 504 created asc 505 `, whereClause) 506 507 rows, err := e.Query(query, args...) 508 if err != nil { 509 return nil, err 510 } 511 defer rows.Close() 512 513 commentMap := make(map[string]*models.PullComment) 514 for rows.Next() { 515 var comment models.PullComment 516 var createdAt string 517 err := rows.Scan( 518 &comment.ID, 519 &comment.PullId, 520 &comment.SubmissionId, 521 &comment.RepoAt, 522 &comment.OwnerDid, 523 &comment.CommentAt, 524 &comment.Body, 525 &createdAt, 526 ) 527 if err != nil { 528 return nil, err 529 } 530 531 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 532 comment.Created = t 533 } 534 535 atUri := comment.AtUri().String() 536 commentMap[atUri] = &comment 537 } 538 539 if err := rows.Err(); err != nil { 540 return nil, err 541 } 542 543 // collect references for each comments 544 commentAts := slices.Collect(maps.Keys(commentMap)) 545 allReferencs, err := GetReferencesAll(e, FilterIn("from_at", commentAts)) 546 if err != nil { 547 return nil, fmt.Errorf("failed to query reference_links: %w", err) 548 } 549 for commentAt, references := range allReferencs { 550 if comment, ok := commentMap[commentAt.String()]; ok { 551 comment.References = references 552 } 553 } 554 555 var comments []models.PullComment 556 for _, c := range commentMap { 557 comments = append(comments, *c) 558 } 559 560 sort.Slice(comments, func(i, j int) bool { 561 return comments[i].Created.Before(comments[j].Created) 562 }) 563 564 return comments, nil 565} 566 567// timeframe here is directly passed into the sql query filter, and any 568// timeframe in the past should be negative; e.g.: "-3 months" 569func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 570 var pulls []models.Pull 571 572 rows, err := e.Query(` 573 select 574 p.owner_did, 575 p.repo_at, 576 p.pull_id, 577 p.created, 578 p.title, 579 p.state, 580 r.did, 581 r.name, 582 r.knot, 583 r.rkey, 584 r.created 585 from 586 pulls p 587 join 588 repos r on p.repo_at = r.at_uri 589 where 590 p.owner_did = ? and p.created >= date ('now', ?) 591 order by 592 p.created desc`, did, timeframe) 593 if err != nil { 594 return nil, err 595 } 596 defer rows.Close() 597 598 for rows.Next() { 599 var pull models.Pull 600 var repo models.Repo 601 var pullCreatedAt, repoCreatedAt string 602 err := rows.Scan( 603 &pull.OwnerDid, 604 &pull.RepoAt, 605 &pull.PullId, 606 &pullCreatedAt, 607 &pull.Title, 608 &pull.State, 609 &repo.Did, 610 &repo.Name, 611 &repo.Knot, 612 &repo.Rkey, 613 &repoCreatedAt, 614 ) 615 if err != nil { 616 return nil, err 617 } 618 619 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 620 if err != nil { 621 return nil, err 622 } 623 pull.Created = pullCreatedTime 624 625 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 626 if err != nil { 627 return nil, err 628 } 629 repo.Created = repoCreatedTime 630 631 pull.Repo = &repo 632 633 pulls = append(pulls, pull) 634 } 635 636 if err := rows.Err(); err != nil { 637 return nil, err 638 } 639 640 return pulls, nil 641} 642 643func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) { 644 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 645 res, err := tx.Exec( 646 query, 647 comment.OwnerDid, 648 comment.RepoAt, 649 comment.SubmissionId, 650 comment.CommentAt, 651 comment.PullId, 652 comment.Body, 653 ) 654 if err != nil { 655 return 0, err 656 } 657 658 i, err := res.LastInsertId() 659 if err != nil { 660 return 0, err 661 } 662 663 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil { 664 return 0, fmt.Errorf("put reference_links: %w", err) 665 } 666 667 return i, nil 668} 669 670func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 671 _, err := e.Exec( 672 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 673 pullState, 674 repoAt, 675 pullId, 676 models.PullDeleted, // only update state of non-deleted pulls 677 models.PullMerged, // only update state of non-merged pulls 678 ) 679 return err 680} 681 682func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 683 err := SetPullState(e, repoAt, pullId, models.PullClosed) 684 return err 685} 686 687func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 688 err := SetPullState(e, repoAt, pullId, models.PullOpen) 689 return err 690} 691 692func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 693 err := SetPullState(e, repoAt, pullId, models.PullMerged) 694 return err 695} 696 697func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 698 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 699 return err 700} 701 702func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 703 _, err := e.Exec(` 704 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 705 values (?, ?, ?, ?, ?) 706 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 707 708 return err 709} 710 711func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 712 var conditions []string 713 var args []any 714 715 args = append(args, parentChangeId) 716 717 for _, filter := range filters { 718 conditions = append(conditions, filter.Condition()) 719 args = append(args, filter.Arg()...) 720 } 721 722 whereClause := "" 723 if conditions != nil { 724 whereClause = " where " + strings.Join(conditions, " and ") 725 } 726 727 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 728 _, err := e.Exec(query, args...) 729 730 return err 731} 732 733// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 734// otherwise submissions are immutable 735func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 736 var conditions []string 737 var args []any 738 739 args = append(args, sourceRev) 740 args = append(args, newPatch) 741 742 for _, filter := range filters { 743 conditions = append(conditions, filter.Condition()) 744 args = append(args, filter.Arg()...) 745 } 746 747 whereClause := "" 748 if conditions != nil { 749 whereClause = " where " + strings.Join(conditions, " and ") 750 } 751 752 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 753 _, err := e.Exec(query, args...) 754 755 return err 756} 757 758func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 759 row := e.QueryRow(` 760 select 761 count(case when state = ? then 1 end) as open_count, 762 count(case when state = ? then 1 end) as merged_count, 763 count(case when state = ? then 1 end) as closed_count, 764 count(case when state = ? then 1 end) as deleted_count 765 from pulls 766 where repo_at = ?`, 767 models.PullOpen, 768 models.PullMerged, 769 models.PullClosed, 770 models.PullDeleted, 771 repoAt, 772 ) 773 774 var count models.PullCount 775 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 776 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 777 } 778 779 return count, nil 780} 781 782// change-id parent-change-id 783// 784// 4 w ,-------- z (TOP) 785// 3 z <----',------- y 786// 2 y <-----',------ x 787// 1 x <------' nil (BOT) 788// 789// `w` is parent of none, so it is the top of the stack 790func GetStack(e Execer, stackId string) (models.Stack, error) { 791 unorderedPulls, err := GetPulls( 792 e, 793 FilterEq("stack_id", stackId), 794 FilterNotEq("state", models.PullDeleted), 795 ) 796 if err != nil { 797 return nil, err 798 } 799 // map of parent-change-id to pull 800 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 801 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 802 for _, p := range unorderedPulls { 803 changeIdMap[p.ChangeId] = p 804 if p.ParentChangeId != "" { 805 parentMap[p.ParentChangeId] = p 806 } 807 } 808 809 // the top of the stack is the pull that is not a parent of any pull 810 var topPull *models.Pull 811 for _, maybeTop := range unorderedPulls { 812 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 813 topPull = maybeTop 814 break 815 } 816 } 817 818 pulls := []*models.Pull{} 819 for { 820 pulls = append(pulls, topPull) 821 if topPull.ParentChangeId != "" { 822 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 823 topPull = next 824 } else { 825 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 826 } 827 } else { 828 break 829 } 830 } 831 832 return pulls, nil 833} 834 835func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 836 pulls, err := GetPulls( 837 e, 838 FilterEq("stack_id", stackId), 839 FilterEq("state", models.PullDeleted), 840 ) 841 if err != nil { 842 return nil, err 843 } 844 845 return pulls, nil 846}