forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16) 17 18func NewPull(tx *sql.Tx, pull *models.Pull) error { 19 _, err := tx.Exec(` 20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 21 values (?, 1) 22 `, pull.RepoAt) 23 if err != nil { 24 return err 25 } 26 27 var nextId int 28 err = tx.QueryRow(` 29 update repo_pull_seqs 30 set next_pull_id = next_pull_id + 1 31 where repo_at = ? 32 returning next_pull_id - 1 33 `, pull.RepoAt).Scan(&nextId) 34 if err != nil { 35 return err 36 } 37 38 pull.PullId = nextId 39 pull.State = models.PullOpen 40 41 var sourceBranch, sourceRepoAt *string 42 if pull.PullSource != nil { 43 sourceBranch = &pull.PullSource.Branch 44 if pull.PullSource.RepoAt != nil { 45 x := pull.PullSource.RepoAt.String() 46 sourceRepoAt = &x 47 } 48 } 49 50 var stackId, changeId, parentChangeId *string 51 if pull.StackId != "" { 52 stackId = &pull.StackId 53 } 54 if pull.ChangeId != "" { 55 changeId = &pull.ChangeId 56 } 57 if pull.ParentChangeId != "" { 58 parentChangeId = &pull.ParentChangeId 59 } 60 61 result, err := tx.Exec( 62 ` 63 insert into pulls ( 64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 65 ) 66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 67 pull.RepoAt, 68 pull.OwnerDid, 69 pull.PullId, 70 pull.Title, 71 pull.TargetBranch, 72 pull.Body, 73 pull.Rkey, 74 pull.State, 75 sourceBranch, 76 sourceRepoAt, 77 stackId, 78 changeId, 79 parentChangeId, 80 ) 81 if err != nil { 82 return err 83 } 84 85 // Set the database primary key ID 86 id, err := result.LastInsertId() 87 if err != nil { 88 return err 89 } 90 pull.ID = int(id) 91 92 _, err = tx.Exec(` 93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 94 values (?, ?, ?, ?, ?) 95 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 96 return err 97} 98 99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 100 pull, err := GetPull(e, repoAt, pullId) 101 if err != nil { 102 return "", err 103 } 104 return pull.AtUri(), err 105} 106 107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 108 var pullId int 109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 110 return pullId - 1, err 111} 112 113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 114 pulls := make(map[syntax.ATURI]*models.Pull) 115 116 var conditions []string 117 var args []any 118 for _, filter := range filters { 119 conditions = append(conditions, filter.Condition()) 120 args = append(args, filter.Arg()...) 121 } 122 123 whereClause := "" 124 if conditions != nil { 125 whereClause = " where " + strings.Join(conditions, " and ") 126 } 127 limitClause := "" 128 if limit != 0 { 129 limitClause = fmt.Sprintf(" limit %d ", limit) 130 } 131 132 query := fmt.Sprintf(` 133 select 134 id, 135 owner_did, 136 repo_at, 137 pull_id, 138 created, 139 title, 140 state, 141 target_branch, 142 body, 143 rkey, 144 source_branch, 145 source_repo_at, 146 stack_id, 147 change_id, 148 parent_change_id 149 from 150 pulls 151 %s 152 order by 153 created desc 154 %s 155 `, whereClause, limitClause) 156 157 rows, err := e.Query(query, args...) 158 if err != nil { 159 return nil, err 160 } 161 defer rows.Close() 162 163 for rows.Next() { 164 var pull models.Pull 165 var createdAt string 166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 167 err := rows.Scan( 168 &pull.ID, 169 &pull.OwnerDid, 170 &pull.RepoAt, 171 &pull.PullId, 172 &createdAt, 173 &pull.Title, 174 &pull.State, 175 &pull.TargetBranch, 176 &pull.Body, 177 &pull.Rkey, 178 &sourceBranch, 179 &sourceRepoAt, 180 &stackId, 181 &changeId, 182 &parentChangeId, 183 ) 184 if err != nil { 185 return nil, err 186 } 187 188 createdTime, err := time.Parse(time.RFC3339, createdAt) 189 if err != nil { 190 return nil, err 191 } 192 pull.Created = createdTime 193 194 if sourceBranch.Valid { 195 pull.PullSource = &models.PullSource{ 196 Branch: sourceBranch.String, 197 } 198 if sourceRepoAt.Valid { 199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 200 if err != nil { 201 return nil, err 202 } 203 pull.PullSource.RepoAt = &sourceRepoAtParsed 204 } 205 } 206 207 if stackId.Valid { 208 pull.StackId = stackId.String 209 } 210 if changeId.Valid { 211 pull.ChangeId = changeId.String 212 } 213 if parentChangeId.Valid { 214 pull.ParentChangeId = parentChangeId.String 215 } 216 217 pulls[pull.AtUri()] = &pull 218 } 219 220 var pullAts []syntax.ATURI 221 for _, p := range pulls { 222 pullAts = append(pullAts, p.AtUri()) 223 } 224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) 225 if err != nil { 226 return nil, fmt.Errorf("failed to get submissions: %w", err) 227 } 228 229 for pullAt, submissions := range submissionsMap { 230 if p, ok := pulls[pullAt]; ok { 231 p.Submissions = submissions 232 } 233 } 234 235 // collect allLabels for each issue 236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts)) 237 if err != nil { 238 return nil, fmt.Errorf("failed to query labels: %w", err) 239 } 240 for pullAt, labels := range allLabels { 241 if p, ok := pulls[pullAt]; ok { 242 p.Labels = labels 243 } 244 } 245 246 // collect pull source for all pulls that need it 247 var sourceAts []syntax.ATURI 248 for _, p := range pulls { 249 if p.PullSource != nil && p.PullSource.RepoAt != nil { 250 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 251 } 252 } 253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts)) 254 if err != nil && !errors.Is(err, sql.ErrNoRows) { 255 return nil, fmt.Errorf("failed to get source repos: %w", err) 256 } 257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 258 for _, r := range sourceRepos { 259 sourceRepoMap[r.RepoAt()] = &r 260 } 261 for _, p := range pulls { 262 if p.PullSource != nil && p.PullSource.RepoAt != nil { 263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 264 p.PullSource.Repo = sourceRepo 265 } 266 } 267 } 268 269 orderedByPullId := []*models.Pull{} 270 for _, p := range pulls { 271 orderedByPullId = append(orderedByPullId, p) 272 } 273 sort.Slice(orderedByPullId, func(i, j int) bool { 274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 275 }) 276 277 return orderedByPullId, nil 278} 279 280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 281 return GetPullsWithLimit(e, 0, filters...) 282} 283 284func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) { 285 var ids []int64 286 287 var filters []filter 288 filters = append(filters, FilterEq("state", opts.State)) 289 if opts.RepoAt != "" { 290 filters = append(filters, FilterEq("repo_at", opts.RepoAt)) 291 } 292 293 var conditions []string 294 var args []any 295 296 for _, filter := range filters { 297 conditions = append(conditions, filter.Condition()) 298 args = append(args, filter.Arg()...) 299 } 300 301 whereClause := "" 302 if conditions != nil { 303 whereClause = " where " + strings.Join(conditions, " and ") 304 } 305 pageClause := "" 306 if opts.Page.Limit != 0 { 307 pageClause = fmt.Sprintf( 308 " limit %d offset %d ", 309 opts.Page.Limit, 310 opts.Page.Offset, 311 ) 312 } 313 314 query := fmt.Sprintf( 315 ` 316 select 317 id 318 from 319 pulls 320 %s 321 %s`, 322 whereClause, 323 pageClause, 324 ) 325 args = append(args, opts.Page.Limit, opts.Page.Offset) 326 rows, err := e.Query(query, args...) 327 if err != nil { 328 return nil, err 329 } 330 defer rows.Close() 331 332 for rows.Next() { 333 var id int64 334 err := rows.Scan(&id) 335 if err != nil { 336 return nil, err 337 } 338 339 ids = append(ids, id) 340 } 341 342 return ids, nil 343} 344 345func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 346 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) 347 if err != nil { 348 return nil, err 349 } 350 if len(pulls) == 0 { 351 return nil, sql.ErrNoRows 352 } 353 354 return pulls[0], nil 355} 356 357// mapping from pull -> pull submissions 358func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 359 var conditions []string 360 var args []any 361 for _, filter := range filters { 362 conditions = append(conditions, filter.Condition()) 363 args = append(args, filter.Arg()...) 364 } 365 366 whereClause := "" 367 if conditions != nil { 368 whereClause = " where " + strings.Join(conditions, " and ") 369 } 370 371 query := fmt.Sprintf(` 372 select 373 id, 374 pull_at, 375 round_number, 376 patch, 377 combined, 378 created, 379 source_rev 380 from 381 pull_submissions 382 %s 383 order by 384 round_number asc 385 `, whereClause) 386 387 rows, err := e.Query(query, args...) 388 if err != nil { 389 return nil, err 390 } 391 defer rows.Close() 392 393 submissionMap := make(map[int]*models.PullSubmission) 394 395 for rows.Next() { 396 var submission models.PullSubmission 397 var submissionCreatedStr string 398 var submissionSourceRev, submissionCombined sql.NullString 399 err := rows.Scan( 400 &submission.ID, 401 &submission.PullAt, 402 &submission.RoundNumber, 403 &submission.Patch, 404 &submissionCombined, 405 &submissionCreatedStr, 406 &submissionSourceRev, 407 ) 408 if err != nil { 409 return nil, err 410 } 411 412 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 413 submission.Created = t 414 } 415 416 if submissionSourceRev.Valid { 417 submission.SourceRev = submissionSourceRev.String 418 } 419 420 if submissionCombined.Valid { 421 submission.Combined = submissionCombined.String 422 } 423 424 submissionMap[submission.ID] = &submission 425 } 426 427 if err := rows.Err(); err != nil { 428 return nil, err 429 } 430 431 // Get comments for all submissions using GetPullComments 432 submissionIds := slices.Collect(maps.Keys(submissionMap)) 433 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) 434 if err != nil { 435 return nil, err 436 } 437 for _, comment := range comments { 438 if submission, ok := submissionMap[comment.SubmissionId]; ok { 439 submission.Comments = append(submission.Comments, comment) 440 } 441 } 442 443 // group the submissions by pull_at 444 m := make(map[syntax.ATURI][]*models.PullSubmission) 445 for _, s := range submissionMap { 446 m[s.PullAt] = append(m[s.PullAt], s) 447 } 448 449 // sort each one by round number 450 for _, s := range m { 451 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 452 return cmp.Compare(a.RoundNumber, b.RoundNumber) 453 }) 454 } 455 456 return m, nil 457} 458 459func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { 460 var conditions []string 461 var args []any 462 for _, filter := range filters { 463 conditions = append(conditions, filter.Condition()) 464 args = append(args, filter.Arg()...) 465 } 466 467 whereClause := "" 468 if conditions != nil { 469 whereClause = " where " + strings.Join(conditions, " and ") 470 } 471 472 query := fmt.Sprintf(` 473 select 474 id, 475 pull_id, 476 submission_id, 477 repo_at, 478 owner_did, 479 comment_at, 480 body, 481 created 482 from 483 pull_comments 484 %s 485 order by 486 created asc 487 `, whereClause) 488 489 rows, err := e.Query(query, args...) 490 if err != nil { 491 return nil, err 492 } 493 defer rows.Close() 494 495 var comments []models.PullComment 496 for rows.Next() { 497 var comment models.PullComment 498 var createdAt string 499 err := rows.Scan( 500 &comment.ID, 501 &comment.PullId, 502 &comment.SubmissionId, 503 &comment.RepoAt, 504 &comment.OwnerDid, 505 &comment.CommentAt, 506 &comment.Body, 507 &createdAt, 508 ) 509 if err != nil { 510 return nil, err 511 } 512 513 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 514 comment.Created = t 515 } 516 517 comments = append(comments, comment) 518 } 519 520 if err := rows.Err(); err != nil { 521 return nil, err 522 } 523 524 return comments, nil 525} 526 527// timeframe here is directly passed into the sql query filter, and any 528// timeframe in the past should be negative; e.g.: "-3 months" 529func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 530 var pulls []models.Pull 531 532 rows, err := e.Query(` 533 select 534 p.owner_did, 535 p.repo_at, 536 p.pull_id, 537 p.created, 538 p.title, 539 p.state, 540 r.did, 541 r.name, 542 r.knot, 543 r.rkey, 544 r.created 545 from 546 pulls p 547 join 548 repos r on p.repo_at = r.at_uri 549 where 550 p.owner_did = ? and p.created >= date ('now', ?) 551 order by 552 p.created desc`, did, timeframe) 553 if err != nil { 554 return nil, err 555 } 556 defer rows.Close() 557 558 for rows.Next() { 559 var pull models.Pull 560 var repo models.Repo 561 var pullCreatedAt, repoCreatedAt string 562 err := rows.Scan( 563 &pull.OwnerDid, 564 &pull.RepoAt, 565 &pull.PullId, 566 &pullCreatedAt, 567 &pull.Title, 568 &pull.State, 569 &repo.Did, 570 &repo.Name, 571 &repo.Knot, 572 &repo.Rkey, 573 &repoCreatedAt, 574 ) 575 if err != nil { 576 return nil, err 577 } 578 579 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 580 if err != nil { 581 return nil, err 582 } 583 pull.Created = pullCreatedTime 584 585 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 586 if err != nil { 587 return nil, err 588 } 589 repo.Created = repoCreatedTime 590 591 pull.Repo = &repo 592 593 pulls = append(pulls, pull) 594 } 595 596 if err := rows.Err(); err != nil { 597 return nil, err 598 } 599 600 return pulls, nil 601} 602 603func NewPullComment(e Execer, comment *models.PullComment) (int64, error) { 604 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 605 res, err := e.Exec( 606 query, 607 comment.OwnerDid, 608 comment.RepoAt, 609 comment.SubmissionId, 610 comment.CommentAt, 611 comment.PullId, 612 comment.Body, 613 ) 614 if err != nil { 615 return 0, err 616 } 617 618 i, err := res.LastInsertId() 619 if err != nil { 620 return 0, err 621 } 622 623 return i, nil 624} 625 626func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 627 _, err := e.Exec( 628 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 629 pullState, 630 repoAt, 631 pullId, 632 models.PullDeleted, // only update state of non-deleted pulls 633 models.PullMerged, // only update state of non-merged pulls 634 ) 635 return err 636} 637 638func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 639 err := SetPullState(e, repoAt, pullId, models.PullClosed) 640 return err 641} 642 643func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 644 err := SetPullState(e, repoAt, pullId, models.PullOpen) 645 return err 646} 647 648func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 649 err := SetPullState(e, repoAt, pullId, models.PullMerged) 650 return err 651} 652 653func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 654 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 655 return err 656} 657 658func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 659 _, err := e.Exec(` 660 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 661 values (?, ?, ?, ?, ?) 662 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 663 664 return err 665} 666 667func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 668 var conditions []string 669 var args []any 670 671 args = append(args, parentChangeId) 672 673 for _, filter := range filters { 674 conditions = append(conditions, filter.Condition()) 675 args = append(args, filter.Arg()...) 676 } 677 678 whereClause := "" 679 if conditions != nil { 680 whereClause = " where " + strings.Join(conditions, " and ") 681 } 682 683 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 684 _, err := e.Exec(query, args...) 685 686 return err 687} 688 689// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 690// otherwise submissions are immutable 691func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 692 var conditions []string 693 var args []any 694 695 args = append(args, sourceRev) 696 args = append(args, newPatch) 697 698 for _, filter := range filters { 699 conditions = append(conditions, filter.Condition()) 700 args = append(args, filter.Arg()...) 701 } 702 703 whereClause := "" 704 if conditions != nil { 705 whereClause = " where " + strings.Join(conditions, " and ") 706 } 707 708 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 709 _, err := e.Exec(query, args...) 710 711 return err 712} 713 714func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 715 row := e.QueryRow(` 716 select 717 count(case when state = ? then 1 end) as open_count, 718 count(case when state = ? then 1 end) as merged_count, 719 count(case when state = ? then 1 end) as closed_count, 720 count(case when state = ? then 1 end) as deleted_count 721 from pulls 722 where repo_at = ?`, 723 models.PullOpen, 724 models.PullMerged, 725 models.PullClosed, 726 models.PullDeleted, 727 repoAt, 728 ) 729 730 var count models.PullCount 731 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 732 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 733 } 734 735 return count, nil 736} 737 738// change-id parent-change-id 739// 740// 4 w ,-------- z (TOP) 741// 3 z <----',------- y 742// 2 y <-----',------ x 743// 1 x <------' nil (BOT) 744// 745// `w` is parent of none, so it is the top of the stack 746func GetStack(e Execer, stackId string) (models.Stack, error) { 747 unorderedPulls, err := GetPulls( 748 e, 749 FilterEq("stack_id", stackId), 750 FilterNotEq("state", models.PullDeleted), 751 ) 752 if err != nil { 753 return nil, err 754 } 755 // map of parent-change-id to pull 756 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 757 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 758 for _, p := range unorderedPulls { 759 changeIdMap[p.ChangeId] = p 760 if p.ParentChangeId != "" { 761 parentMap[p.ParentChangeId] = p 762 } 763 } 764 765 // the top of the stack is the pull that is not a parent of any pull 766 var topPull *models.Pull 767 for _, maybeTop := range unorderedPulls { 768 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 769 topPull = maybeTop 770 break 771 } 772 } 773 774 pulls := []*models.Pull{} 775 for { 776 pulls = append(pulls, topPull) 777 if topPull.ParentChangeId != "" { 778 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 779 topPull = next 780 } else { 781 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 782 } 783 } else { 784 break 785 } 786 } 787 788 return pulls, nil 789} 790 791func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 792 pulls, err := GetPulls( 793 e, 794 FilterEq("stack_id", stackId), 795 FilterEq("state", models.PullDeleted), 796 ) 797 if err != nil { 798 return nil, err 799 } 800 801 return pulls, nil 802}