forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
at master 19 kB view raw
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16 "tangled.org/core/orm" 17) 18 19func NewPull(tx *sql.Tx, pull *models.Pull) error { 20 _, err := tx.Exec(` 21 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 22 values (?, 1) 23 `, pull.RepoAt) 24 if err != nil { 25 return err 26 } 27 28 var nextId int 29 err = tx.QueryRow(` 30 update repo_pull_seqs 31 set next_pull_id = next_pull_id + 1 32 where repo_at = ? 33 returning next_pull_id - 1 34 `, pull.RepoAt).Scan(&nextId) 35 if err != nil { 36 return err 37 } 38 39 pull.PullId = nextId 40 pull.State = models.PullOpen 41 42 var sourceBranch, sourceRepoAt *string 43 if pull.PullSource != nil { 44 sourceBranch = &pull.PullSource.Branch 45 if pull.PullSource.RepoAt != nil { 46 x := pull.PullSource.RepoAt.String() 47 sourceRepoAt = &x 48 } 49 } 50 51 var stackId, changeId, parentChangeId *string 52 if pull.StackId != "" { 53 stackId = &pull.StackId 54 } 55 if pull.ChangeId != "" { 56 changeId = &pull.ChangeId 57 } 58 if pull.ParentChangeId != "" { 59 parentChangeId = &pull.ParentChangeId 60 } 61 62 result, err := tx.Exec( 63 ` 64 insert into pulls ( 65 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 66 ) 67 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 68 pull.RepoAt, 69 pull.OwnerDid, 70 pull.PullId, 71 pull.Title, 72 pull.TargetBranch, 73 pull.Body, 74 pull.Rkey, 75 pull.State, 76 sourceBranch, 77 sourceRepoAt, 78 stackId, 79 changeId, 80 parentChangeId, 81 ) 82 if err != nil { 83 return err 84 } 85 86 // Set the database primary key ID 87 id, err := result.LastInsertId() 88 if err != nil { 89 return err 90 } 91 pull.ID = int(id) 92 93 _, err = tx.Exec(` 94 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 95 values (?, ?, ?, ?, ?) 96 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 97 if err != nil { 98 return err 99 } 100 101 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil { 102 return fmt.Errorf("put reference_links: %w", err) 103 } 104 105 return nil 106} 107 108func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 109 pull, err := GetPull(e, repoAt, pullId) 110 if err != nil { 111 return "", err 112 } 113 return pull.AtUri(), err 114} 115 116func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 117 var pullId int 118 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 119 return pullId - 1, err 120} 121 122func GetPullsWithLimit(e Execer, limit int, filters ...orm.Filter) ([]*models.Pull, error) { 123 pulls := make(map[syntax.ATURI]*models.Pull) 124 125 var conditions []string 126 var args []any 127 for _, filter := range filters { 128 conditions = append(conditions, filter.Condition()) 129 args = append(args, filter.Arg()...) 130 } 131 132 whereClause := "" 133 if conditions != nil { 134 whereClause = " where " + strings.Join(conditions, " and ") 135 } 136 limitClause := "" 137 if limit != 0 { 138 limitClause = fmt.Sprintf(" limit %d ", limit) 139 } 140 141 query := fmt.Sprintf(` 142 select 143 id, 144 owner_did, 145 repo_at, 146 pull_id, 147 created, 148 title, 149 state, 150 target_branch, 151 body, 152 rkey, 153 source_branch, 154 source_repo_at, 155 stack_id, 156 change_id, 157 parent_change_id 158 from 159 pulls 160 %s 161 order by 162 created desc 163 %s 164 `, whereClause, limitClause) 165 166 rows, err := e.Query(query, args...) 167 if err != nil { 168 return nil, err 169 } 170 defer rows.Close() 171 172 for rows.Next() { 173 var pull models.Pull 174 var createdAt string 175 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 176 err := rows.Scan( 177 &pull.ID, 178 &pull.OwnerDid, 179 &pull.RepoAt, 180 &pull.PullId, 181 &createdAt, 182 &pull.Title, 183 &pull.State, 184 &pull.TargetBranch, 185 &pull.Body, 186 &pull.Rkey, 187 &sourceBranch, 188 &sourceRepoAt, 189 &stackId, 190 &changeId, 191 &parentChangeId, 192 ) 193 if err != nil { 194 return nil, err 195 } 196 197 createdTime, err := time.Parse(time.RFC3339, createdAt) 198 if err != nil { 199 return nil, err 200 } 201 pull.Created = createdTime 202 203 if sourceBranch.Valid { 204 pull.PullSource = &models.PullSource{ 205 Branch: sourceBranch.String, 206 } 207 if sourceRepoAt.Valid { 208 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 209 if err != nil { 210 return nil, err 211 } 212 pull.PullSource.RepoAt = &sourceRepoAtParsed 213 } 214 } 215 216 if stackId.Valid { 217 pull.StackId = stackId.String 218 } 219 if changeId.Valid { 220 pull.ChangeId = changeId.String 221 } 222 if parentChangeId.Valid { 223 pull.ParentChangeId = parentChangeId.String 224 } 225 226 pulls[pull.AtUri()] = &pull 227 } 228 229 var pullAts []syntax.ATURI 230 for _, p := range pulls { 231 pullAts = append(pullAts, p.AtUri()) 232 } 233 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts)) 234 if err != nil { 235 return nil, fmt.Errorf("failed to get submissions: %w", err) 236 } 237 238 for pullAt, submissions := range submissionsMap { 239 if p, ok := pulls[pullAt]; ok { 240 p.Submissions = submissions 241 } 242 } 243 244 // collect allLabels for each issue 245 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts)) 246 if err != nil { 247 return nil, fmt.Errorf("failed to query labels: %w", err) 248 } 249 for pullAt, labels := range allLabels { 250 if p, ok := pulls[pullAt]; ok { 251 p.Labels = labels 252 } 253 } 254 255 // collect pull source for all pulls that need it 256 var sourceAts []syntax.ATURI 257 for _, p := range pulls { 258 if p.PullSource != nil && p.PullSource.RepoAt != nil { 259 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 260 } 261 } 262 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts)) 263 if err != nil && !errors.Is(err, sql.ErrNoRows) { 264 return nil, fmt.Errorf("failed to get source repos: %w", err) 265 } 266 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 267 for _, r := range sourceRepos { 268 sourceRepoMap[r.RepoAt()] = &r 269 } 270 for _, p := range pulls { 271 if p.PullSource != nil && p.PullSource.RepoAt != nil { 272 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 273 p.PullSource.Repo = sourceRepo 274 } 275 } 276 } 277 278 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts)) 279 if err != nil { 280 return nil, fmt.Errorf("failed to query reference_links: %w", err) 281 } 282 for pullAt, references := range allReferences { 283 if pull, ok := pulls[pullAt]; ok { 284 pull.References = references 285 } 286 } 287 288 orderedByPullId := []*models.Pull{} 289 for _, p := range pulls { 290 orderedByPullId = append(orderedByPullId, p) 291 } 292 sort.Slice(orderedByPullId, func(i, j int) bool { 293 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 294 }) 295 296 return orderedByPullId, nil 297} 298 299func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) { 300 return GetPullsWithLimit(e, 0, filters...) 301} 302 303func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) { 304 var ids []int64 305 306 var filters []orm.Filter 307 filters = append(filters, orm.FilterEq("state", opts.State)) 308 if opts.RepoAt != "" { 309 filters = append(filters, orm.FilterEq("repo_at", opts.RepoAt)) 310 } 311 312 var conditions []string 313 var args []any 314 315 for _, filter := range filters { 316 conditions = append(conditions, filter.Condition()) 317 args = append(args, filter.Arg()...) 318 } 319 320 whereClause := "" 321 if conditions != nil { 322 whereClause = " where " + strings.Join(conditions, " and ") 323 } 324 pageClause := "" 325 if opts.Page.Limit != 0 { 326 pageClause = fmt.Sprintf( 327 " limit %d offset %d ", 328 opts.Page.Limit, 329 opts.Page.Offset, 330 ) 331 } 332 333 query := fmt.Sprintf( 334 ` 335 select 336 id 337 from 338 pulls 339 %s 340 %s`, 341 whereClause, 342 pageClause, 343 ) 344 args = append(args, opts.Page.Limit, opts.Page.Offset) 345 rows, err := e.Query(query, args...) 346 if err != nil { 347 return nil, err 348 } 349 defer rows.Close() 350 351 for rows.Next() { 352 var id int64 353 err := rows.Scan(&id) 354 if err != nil { 355 return nil, err 356 } 357 358 ids = append(ids, id) 359 } 360 361 return ids, nil 362} 363 364func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 365 pulls, err := GetPullsWithLimit(e, 1, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId)) 366 if err != nil { 367 return nil, err 368 } 369 if len(pulls) == 0 { 370 return nil, sql.ErrNoRows 371 } 372 373 return pulls[0], nil 374} 375 376// mapping from pull -> pull submissions 377func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 378 var conditions []string 379 var args []any 380 for _, filter := range filters { 381 conditions = append(conditions, filter.Condition()) 382 args = append(args, filter.Arg()...) 383 } 384 385 whereClause := "" 386 if conditions != nil { 387 whereClause = " where " + strings.Join(conditions, " and ") 388 } 389 390 query := fmt.Sprintf(` 391 select 392 id, 393 pull_at, 394 round_number, 395 patch, 396 combined, 397 created, 398 source_rev 399 from 400 pull_submissions 401 %s 402 order by 403 round_number asc 404 `, whereClause) 405 406 rows, err := e.Query(query, args...) 407 if err != nil { 408 return nil, err 409 } 410 defer rows.Close() 411 412 submissionMap := make(map[int]*models.PullSubmission) 413 414 for rows.Next() { 415 var submission models.PullSubmission 416 var submissionCreatedStr string 417 var submissionSourceRev, submissionCombined sql.NullString 418 err := rows.Scan( 419 &submission.ID, 420 &submission.PullAt, 421 &submission.RoundNumber, 422 &submission.Patch, 423 &submissionCombined, 424 &submissionCreatedStr, 425 &submissionSourceRev, 426 ) 427 if err != nil { 428 return nil, err 429 } 430 431 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 432 submission.Created = t 433 } 434 435 if submissionSourceRev.Valid { 436 submission.SourceRev = submissionSourceRev.String 437 } 438 439 if submissionCombined.Valid { 440 submission.Combined = submissionCombined.String 441 } 442 443 submissionMap[submission.ID] = &submission 444 } 445 446 if err := rows.Err(); err != nil { 447 return nil, err 448 } 449 450 // Get comments for all submissions using GetPullComments 451 submissionIds := slices.Collect(maps.Keys(submissionMap)) 452 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds)) 453 if err != nil { 454 return nil, fmt.Errorf("failed to get pull comments: %w", err) 455 } 456 for _, comment := range comments { 457 if submission, ok := submissionMap[comment.SubmissionId]; ok { 458 submission.Comments = append(submission.Comments, comment) 459 } 460 } 461 462 // group the submissions by pull_at 463 m := make(map[syntax.ATURI][]*models.PullSubmission) 464 for _, s := range submissionMap { 465 m[s.PullAt] = append(m[s.PullAt], s) 466 } 467 468 // sort each one by round number 469 for _, s := range m { 470 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 471 return cmp.Compare(a.RoundNumber, b.RoundNumber) 472 }) 473 } 474 475 return m, nil 476} 477 478func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) { 479 var conditions []string 480 var args []any 481 for _, filter := range filters { 482 conditions = append(conditions, filter.Condition()) 483 args = append(args, filter.Arg()...) 484 } 485 486 whereClause := "" 487 if conditions != nil { 488 whereClause = " where " + strings.Join(conditions, " and ") 489 } 490 491 query := fmt.Sprintf(` 492 select 493 id, 494 pull_id, 495 submission_id, 496 repo_at, 497 owner_did, 498 comment_at, 499 body, 500 created 501 from 502 pull_comments 503 %s 504 order by 505 created asc 506 `, whereClause) 507 508 rows, err := e.Query(query, args...) 509 if err != nil { 510 return nil, err 511 } 512 defer rows.Close() 513 514 commentMap := make(map[string]*models.PullComment) 515 for rows.Next() { 516 var comment models.PullComment 517 var createdAt string 518 err := rows.Scan( 519 &comment.ID, 520 &comment.PullId, 521 &comment.SubmissionId, 522 &comment.RepoAt, 523 &comment.OwnerDid, 524 &comment.CommentAt, 525 &comment.Body, 526 &createdAt, 527 ) 528 if err != nil { 529 return nil, err 530 } 531 532 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 533 comment.Created = t 534 } 535 536 atUri := comment.AtUri().String() 537 commentMap[atUri] = &comment 538 } 539 540 if err := rows.Err(); err != nil { 541 return nil, err 542 } 543 544 // collect references for each comments 545 commentAts := slices.Collect(maps.Keys(commentMap)) 546 allReferencs, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts)) 547 if err != nil { 548 return nil, fmt.Errorf("failed to query reference_links: %w", err) 549 } 550 for commentAt, references := range allReferencs { 551 if comment, ok := commentMap[commentAt.String()]; ok { 552 comment.References = references 553 } 554 } 555 556 var comments []models.PullComment 557 for _, c := range commentMap { 558 comments = append(comments, *c) 559 } 560 561 sort.Slice(comments, func(i, j int) bool { 562 return comments[i].Created.Before(comments[j].Created) 563 }) 564 565 return comments, nil 566} 567 568// timeframe here is directly passed into the sql query filter, and any 569// timeframe in the past should be negative; e.g.: "-3 months" 570func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 571 var pulls []models.Pull 572 573 rows, err := e.Query(` 574 select 575 p.owner_did, 576 p.repo_at, 577 p.pull_id, 578 p.created, 579 p.title, 580 p.state, 581 r.did, 582 r.name, 583 r.knot, 584 r.rkey, 585 r.created 586 from 587 pulls p 588 join 589 repos r on p.repo_at = r.at_uri 590 where 591 p.owner_did = ? and p.created >= date ('now', ?) 592 order by 593 p.created desc`, did, timeframe) 594 if err != nil { 595 return nil, err 596 } 597 defer rows.Close() 598 599 for rows.Next() { 600 var pull models.Pull 601 var repo models.Repo 602 var pullCreatedAt, repoCreatedAt string 603 err := rows.Scan( 604 &pull.OwnerDid, 605 &pull.RepoAt, 606 &pull.PullId, 607 &pullCreatedAt, 608 &pull.Title, 609 &pull.State, 610 &repo.Did, 611 &repo.Name, 612 &repo.Knot, 613 &repo.Rkey, 614 &repoCreatedAt, 615 ) 616 if err != nil { 617 return nil, err 618 } 619 620 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 621 if err != nil { 622 return nil, err 623 } 624 pull.Created = pullCreatedTime 625 626 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 627 if err != nil { 628 return nil, err 629 } 630 repo.Created = repoCreatedTime 631 632 pull.Repo = &repo 633 634 pulls = append(pulls, pull) 635 } 636 637 if err := rows.Err(); err != nil { 638 return nil, err 639 } 640 641 return pulls, nil 642} 643 644func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) { 645 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 646 res, err := tx.Exec( 647 query, 648 comment.OwnerDid, 649 comment.RepoAt, 650 comment.SubmissionId, 651 comment.CommentAt, 652 comment.PullId, 653 comment.Body, 654 ) 655 if err != nil { 656 return 0, err 657 } 658 659 i, err := res.LastInsertId() 660 if err != nil { 661 return 0, err 662 } 663 664 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil { 665 return 0, fmt.Errorf("put reference_links: %w", err) 666 } 667 668 return i, nil 669} 670 671func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 672 _, err := e.Exec( 673 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 674 pullState, 675 repoAt, 676 pullId, 677 models.PullDeleted, // only update state of non-deleted pulls 678 models.PullMerged, // only update state of non-merged pulls 679 ) 680 return err 681} 682 683func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 684 err := SetPullState(e, repoAt, pullId, models.PullClosed) 685 return err 686} 687 688func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 689 err := SetPullState(e, repoAt, pullId, models.PullOpen) 690 return err 691} 692 693func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 694 err := SetPullState(e, repoAt, pullId, models.PullMerged) 695 return err 696} 697 698func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 699 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 700 return err 701} 702 703func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 704 _, err := e.Exec(` 705 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 706 values (?, ?, ?, ?, ?) 707 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 708 709 return err 710} 711 712func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error { 713 var conditions []string 714 var args []any 715 716 args = append(args, parentChangeId) 717 718 for _, filter := range filters { 719 conditions = append(conditions, filter.Condition()) 720 args = append(args, filter.Arg()...) 721 } 722 723 whereClause := "" 724 if conditions != nil { 725 whereClause = " where " + strings.Join(conditions, " and ") 726 } 727 728 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 729 _, err := e.Exec(query, args...) 730 731 return err 732} 733 734// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 735// otherwise submissions are immutable 736func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error { 737 var conditions []string 738 var args []any 739 740 args = append(args, sourceRev) 741 args = append(args, newPatch) 742 743 for _, filter := range filters { 744 conditions = append(conditions, filter.Condition()) 745 args = append(args, filter.Arg()...) 746 } 747 748 whereClause := "" 749 if conditions != nil { 750 whereClause = " where " + strings.Join(conditions, " and ") 751 } 752 753 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 754 _, err := e.Exec(query, args...) 755 756 return err 757} 758 759func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 760 row := e.QueryRow(` 761 select 762 count(case when state = ? then 1 end) as open_count, 763 count(case when state = ? then 1 end) as merged_count, 764 count(case when state = ? then 1 end) as closed_count, 765 count(case when state = ? then 1 end) as deleted_count 766 from pulls 767 where repo_at = ?`, 768 models.PullOpen, 769 models.PullMerged, 770 models.PullClosed, 771 models.PullDeleted, 772 repoAt, 773 ) 774 775 var count models.PullCount 776 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 777 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 778 } 779 780 return count, nil 781} 782 783// change-id parent-change-id 784// 785// 4 w ,-------- z (TOP) 786// 3 z <----',------- y 787// 2 y <-----',------ x 788// 1 x <------' nil (BOT) 789// 790// `w` is parent of none, so it is the top of the stack 791func GetStack(e Execer, stackId string) (models.Stack, error) { 792 unorderedPulls, err := GetPulls( 793 e, 794 orm.FilterEq("stack_id", stackId), 795 orm.FilterNotEq("state", models.PullDeleted), 796 ) 797 if err != nil { 798 return nil, err 799 } 800 // map of parent-change-id to pull 801 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 802 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 803 for _, p := range unorderedPulls { 804 changeIdMap[p.ChangeId] = p 805 if p.ParentChangeId != "" { 806 parentMap[p.ParentChangeId] = p 807 } 808 } 809 810 // the top of the stack is the pull that is not a parent of any pull 811 var topPull *models.Pull 812 for _, maybeTop := range unorderedPulls { 813 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 814 topPull = maybeTop 815 break 816 } 817 } 818 819 pulls := []*models.Pull{} 820 for { 821 pulls = append(pulls, topPull) 822 if topPull.ParentChangeId != "" { 823 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 824 topPull = next 825 } else { 826 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 827 } 828 } else { 829 break 830 } 831 } 832 833 return pulls, nil 834} 835 836func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 837 pulls, err := GetPulls( 838 e, 839 orm.FilterEq("stack_id", stackId), 840 orm.FilterEq("state", models.PullDeleted), 841 ) 842 if err != nil { 843 return nil, err 844 } 845 846 return pulls, nil 847}