forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "sort" 8 "strings" 9 "time" 10 11 "github.com/bluesky-social/indigo/atproto/syntax" 12 "tangled.org/core/appview/models" 13) 14 15func NewPull(tx *sql.Tx, pull *models.Pull) error { 16 _, err := tx.Exec(` 17 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 18 values (?, 1) 19 `, pull.RepoAt) 20 if err != nil { 21 return err 22 } 23 24 var nextId int 25 err = tx.QueryRow(` 26 update repo_pull_seqs 27 set next_pull_id = next_pull_id + 1 28 where repo_at = ? 29 returning next_pull_id - 1 30 `, pull.RepoAt).Scan(&nextId) 31 if err != nil { 32 return err 33 } 34 35 pull.PullId = nextId 36 pull.State = models.PullOpen 37 38 var sourceBranch, sourceRepoAt *string 39 if pull.PullSource != nil { 40 sourceBranch = &pull.PullSource.Branch 41 if pull.PullSource.RepoAt != nil { 42 x := pull.PullSource.RepoAt.String() 43 sourceRepoAt = &x 44 } 45 } 46 47 var stackId, changeId, parentChangeId *string 48 if pull.StackId != "" { 49 stackId = &pull.StackId 50 } 51 if pull.ChangeId != "" { 52 changeId = &pull.ChangeId 53 } 54 if pull.ParentChangeId != "" { 55 parentChangeId = &pull.ParentChangeId 56 } 57 58 result, err := tx.Exec( 59 ` 60 insert into pulls ( 61 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 62 ) 63 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 64 pull.RepoAt, 65 pull.OwnerDid, 66 pull.PullId, 67 pull.Title, 68 pull.TargetBranch, 69 pull.Body, 70 pull.Rkey, 71 pull.State, 72 sourceBranch, 73 sourceRepoAt, 74 stackId, 75 changeId, 76 parentChangeId, 77 ) 78 if err != nil { 79 return err 80 } 81 82 // Set the database primary key ID 83 id, err := result.LastInsertId() 84 if err != nil { 85 return err 86 } 87 pull.ID = int(id) 88 89 _, err = tx.Exec(` 90 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 91 values (?, ?, ?, ?, ?) 92 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 93 return err 94} 95 96func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 97 pull, err := GetPull(e, repoAt, pullId) 98 if err != nil { 99 return "", err 100 } 101 return pull.PullAt(), err 102} 103 104func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 105 var pullId int 106 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 107 return pullId - 1, err 108} 109 110func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 111 pulls := make(map[int]*models.Pull) 112 113 var conditions []string 114 var args []any 115 for _, filter := range filters { 116 conditions = append(conditions, filter.Condition()) 117 args = append(args, filter.Arg()...) 118 } 119 120 whereClause := "" 121 if conditions != nil { 122 whereClause = " where " + strings.Join(conditions, " and ") 123 } 124 limitClause := "" 125 if limit != 0 { 126 limitClause = fmt.Sprintf(" limit %d ", limit) 127 } 128 129 query := fmt.Sprintf(` 130 select 131 id, 132 owner_did, 133 repo_at, 134 pull_id, 135 created, 136 title, 137 state, 138 target_branch, 139 body, 140 rkey, 141 source_branch, 142 source_repo_at, 143 stack_id, 144 change_id, 145 parent_change_id 146 from 147 pulls 148 %s 149 order by 150 created desc 151 %s 152 `, whereClause, limitClause) 153 154 rows, err := e.Query(query, args...) 155 if err != nil { 156 return nil, err 157 } 158 defer rows.Close() 159 160 for rows.Next() { 161 var pull models.Pull 162 var createdAt string 163 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 164 err := rows.Scan( 165 &pull.ID, 166 &pull.OwnerDid, 167 &pull.RepoAt, 168 &pull.PullId, 169 &createdAt, 170 &pull.Title, 171 &pull.State, 172 &pull.TargetBranch, 173 &pull.Body, 174 &pull.Rkey, 175 &sourceBranch, 176 &sourceRepoAt, 177 &stackId, 178 &changeId, 179 &parentChangeId, 180 ) 181 if err != nil { 182 return nil, err 183 } 184 185 createdTime, err := time.Parse(time.RFC3339, createdAt) 186 if err != nil { 187 return nil, err 188 } 189 pull.Created = createdTime 190 191 if sourceBranch.Valid { 192 pull.PullSource = &models.PullSource{ 193 Branch: sourceBranch.String, 194 } 195 if sourceRepoAt.Valid { 196 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 197 if err != nil { 198 return nil, err 199 } 200 pull.PullSource.RepoAt = &sourceRepoAtParsed 201 } 202 } 203 204 if stackId.Valid { 205 pull.StackId = stackId.String 206 } 207 if changeId.Valid { 208 pull.ChangeId = changeId.String 209 } 210 if parentChangeId.Valid { 211 pull.ParentChangeId = parentChangeId.String 212 } 213 214 pulls[pull.PullId] = &pull 215 } 216 217 // get latest round no. for each pull 218 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 219 submissionsQuery := fmt.Sprintf(` 220 select 221 id, pull_id, round_number, patch, created, source_rev 222 from 223 pull_submissions 224 where 225 repo_at in (%s) and pull_id in (%s) 226 `, inClause, inClause) 227 228 args = make([]any, len(pulls)*2) 229 idx := 0 230 for _, p := range pulls { 231 args[idx] = p.RepoAt 232 idx += 1 233 } 234 for _, p := range pulls { 235 args[idx] = p.PullId 236 idx += 1 237 } 238 submissionsRows, err := e.Query(submissionsQuery, args...) 239 if err != nil { 240 return nil, err 241 } 242 defer submissionsRows.Close() 243 244 for submissionsRows.Next() { 245 var s models.PullSubmission 246 var sourceRev sql.NullString 247 var createdAt string 248 err := submissionsRows.Scan( 249 &s.ID, 250 &s.PullId, 251 &s.RoundNumber, 252 &s.Patch, 253 &createdAt, 254 &sourceRev, 255 ) 256 if err != nil { 257 return nil, err 258 } 259 260 createdTime, err := time.Parse(time.RFC3339, createdAt) 261 if err != nil { 262 return nil, err 263 } 264 s.Created = createdTime 265 266 if sourceRev.Valid { 267 s.SourceRev = sourceRev.String 268 } 269 270 if p, ok := pulls[s.PullId]; ok { 271 p.Submissions = make([]*models.PullSubmission, s.RoundNumber+1) 272 p.Submissions[s.RoundNumber] = &s 273 } 274 } 275 if err := rows.Err(); err != nil { 276 return nil, err 277 } 278 279 // get comment count on latest submission on each pull 280 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 281 commentsQuery := fmt.Sprintf(` 282 select 283 count(id), pull_id 284 from 285 pull_comments 286 where 287 submission_id in (%s) 288 group by 289 submission_id 290 `, inClause) 291 292 args = []any{} 293 for _, p := range pulls { 294 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 295 } 296 commentsRows, err := e.Query(commentsQuery, args...) 297 if err != nil { 298 return nil, err 299 } 300 defer commentsRows.Close() 301 302 for commentsRows.Next() { 303 var commentCount, pullId int 304 err := commentsRows.Scan( 305 &commentCount, 306 &pullId, 307 ) 308 if err != nil { 309 return nil, err 310 } 311 if p, ok := pulls[pullId]; ok { 312 p.Submissions[p.LastRoundNumber()].Comments = make([]models.PullComment, commentCount) 313 } 314 } 315 if err := rows.Err(); err != nil { 316 return nil, err 317 } 318 319 orderedByPullId := []*models.Pull{} 320 for _, p := range pulls { 321 orderedByPullId = append(orderedByPullId, p) 322 } 323 sort.Slice(orderedByPullId, func(i, j int) bool { 324 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 325 }) 326 327 return orderedByPullId, nil 328} 329 330func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 331 return GetPullsWithLimit(e, 0, filters...) 332} 333 334func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 335 query := ` 336 select 337 id, 338 owner_did, 339 pull_id, 340 created, 341 title, 342 state, 343 target_branch, 344 repo_at, 345 body, 346 rkey, 347 source_branch, 348 source_repo_at, 349 stack_id, 350 change_id, 351 parent_change_id 352 from 353 pulls 354 where 355 repo_at = ? and pull_id = ? 356 ` 357 row := e.QueryRow(query, repoAt, pullId) 358 359 var pull models.Pull 360 var createdAt string 361 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 362 err := row.Scan( 363 &pull.ID, 364 &pull.OwnerDid, 365 &pull.PullId, 366 &createdAt, 367 &pull.Title, 368 &pull.State, 369 &pull.TargetBranch, 370 &pull.RepoAt, 371 &pull.Body, 372 &pull.Rkey, 373 &sourceBranch, 374 &sourceRepoAt, 375 &stackId, 376 &changeId, 377 &parentChangeId, 378 ) 379 if err != nil { 380 return nil, err 381 } 382 383 createdTime, err := time.Parse(time.RFC3339, createdAt) 384 if err != nil { 385 return nil, err 386 } 387 pull.Created = createdTime 388 389 // populate source 390 if sourceBranch.Valid { 391 pull.PullSource = &models.PullSource{ 392 Branch: sourceBranch.String, 393 } 394 if sourceRepoAt.Valid { 395 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 396 if err != nil { 397 return nil, err 398 } 399 pull.PullSource.RepoAt = &sourceRepoAtParsed 400 } 401 } 402 403 if stackId.Valid { 404 pull.StackId = stackId.String 405 } 406 if changeId.Valid { 407 pull.ChangeId = changeId.String 408 } 409 if parentChangeId.Valid { 410 pull.ParentChangeId = parentChangeId.String 411 } 412 413 submissionsQuery := ` 414 select 415 id, pull_id, repo_at, round_number, patch, created, source_rev 416 from 417 pull_submissions 418 where 419 repo_at = ? and pull_id = ? 420 ` 421 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 422 if err != nil { 423 return nil, err 424 } 425 defer submissionsRows.Close() 426 427 submissionsMap := make(map[int]*models.PullSubmission) 428 429 for submissionsRows.Next() { 430 var submission models.PullSubmission 431 var submissionCreatedStr string 432 var submissionSourceRev sql.NullString 433 err := submissionsRows.Scan( 434 &submission.ID, 435 &submission.PullId, 436 &submission.RepoAt, 437 &submission.RoundNumber, 438 &submission.Patch, 439 &submissionCreatedStr, 440 &submissionSourceRev, 441 ) 442 if err != nil { 443 return nil, err 444 } 445 446 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 447 if err != nil { 448 return nil, err 449 } 450 submission.Created = submissionCreatedTime 451 452 if submissionSourceRev.Valid { 453 submission.SourceRev = submissionSourceRev.String 454 } 455 456 submissionsMap[submission.ID] = &submission 457 } 458 if err = submissionsRows.Close(); err != nil { 459 return nil, err 460 } 461 if len(submissionsMap) == 0 { 462 return &pull, nil 463 } 464 465 var args []any 466 for k := range submissionsMap { 467 args = append(args, k) 468 } 469 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 470 commentsQuery := fmt.Sprintf(` 471 select 472 id, 473 pull_id, 474 submission_id, 475 repo_at, 476 owner_did, 477 comment_at, 478 body, 479 created 480 from 481 pull_comments 482 where 483 submission_id IN (%s) 484 order by 485 created asc 486 `, inClause) 487 commentsRows, err := e.Query(commentsQuery, args...) 488 if err != nil { 489 return nil, err 490 } 491 defer commentsRows.Close() 492 493 for commentsRows.Next() { 494 var comment models.PullComment 495 var commentCreatedStr string 496 err := commentsRows.Scan( 497 &comment.ID, 498 &comment.PullId, 499 &comment.SubmissionId, 500 &comment.RepoAt, 501 &comment.OwnerDid, 502 &comment.CommentAt, 503 &comment.Body, 504 &commentCreatedStr, 505 ) 506 if err != nil { 507 return nil, err 508 } 509 510 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 511 if err != nil { 512 return nil, err 513 } 514 comment.Created = commentCreatedTime 515 516 // Add the comment to its submission 517 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 518 submission.Comments = append(submission.Comments, comment) 519 } 520 521 } 522 if err = commentsRows.Err(); err != nil { 523 return nil, err 524 } 525 526 var pullSourceRepo *models.Repo 527 if pull.PullSource != nil { 528 if pull.PullSource.RepoAt != nil { 529 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 530 if err != nil { 531 log.Printf("failed to get repo by at uri: %v", err) 532 } else { 533 pull.PullSource.Repo = pullSourceRepo 534 } 535 } 536 } 537 538 pull.Submissions = make([]*models.PullSubmission, len(submissionsMap)) 539 for _, submission := range submissionsMap { 540 pull.Submissions[submission.RoundNumber] = submission 541 } 542 543 return &pull, nil 544} 545 546// timeframe here is directly passed into the sql query filter, and any 547// timeframe in the past should be negative; e.g.: "-3 months" 548func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 549 var pulls []models.Pull 550 551 rows, err := e.Query(` 552 select 553 p.owner_did, 554 p.repo_at, 555 p.pull_id, 556 p.created, 557 p.title, 558 p.state, 559 r.did, 560 r.name, 561 r.knot, 562 r.rkey, 563 r.created 564 from 565 pulls p 566 join 567 repos r on p.repo_at = r.at_uri 568 where 569 p.owner_did = ? and p.created >= date ('now', ?) 570 order by 571 p.created desc`, did, timeframe) 572 if err != nil { 573 return nil, err 574 } 575 defer rows.Close() 576 577 for rows.Next() { 578 var pull models.Pull 579 var repo models.Repo 580 var pullCreatedAt, repoCreatedAt string 581 err := rows.Scan( 582 &pull.OwnerDid, 583 &pull.RepoAt, 584 &pull.PullId, 585 &pullCreatedAt, 586 &pull.Title, 587 &pull.State, 588 &repo.Did, 589 &repo.Name, 590 &repo.Knot, 591 &repo.Rkey, 592 &repoCreatedAt, 593 ) 594 if err != nil { 595 return nil, err 596 } 597 598 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 599 if err != nil { 600 return nil, err 601 } 602 pull.Created = pullCreatedTime 603 604 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 605 if err != nil { 606 return nil, err 607 } 608 repo.Created = repoCreatedTime 609 610 pull.Repo = &repo 611 612 pulls = append(pulls, pull) 613 } 614 615 if err := rows.Err(); err != nil { 616 return nil, err 617 } 618 619 return pulls, nil 620} 621 622func NewPullComment(e Execer, comment *models.PullComment) (int64, error) { 623 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 624 res, err := e.Exec( 625 query, 626 comment.OwnerDid, 627 comment.RepoAt, 628 comment.SubmissionId, 629 comment.CommentAt, 630 comment.PullId, 631 comment.Body, 632 ) 633 if err != nil { 634 return 0, err 635 } 636 637 i, err := res.LastInsertId() 638 if err != nil { 639 return 0, err 640 } 641 642 return i, nil 643} 644 645func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 646 _, err := e.Exec( 647 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 648 pullState, 649 repoAt, 650 pullId, 651 models.PullDeleted, // only update state of non-deleted pulls 652 models.PullMerged, // only update state of non-merged pulls 653 ) 654 return err 655} 656 657func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 658 err := SetPullState(e, repoAt, pullId, models.PullClosed) 659 return err 660} 661 662func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 663 err := SetPullState(e, repoAt, pullId, models.PullOpen) 664 return err 665} 666 667func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 668 err := SetPullState(e, repoAt, pullId, models.PullMerged) 669 return err 670} 671 672func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 673 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 674 return err 675} 676 677func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error { 678 newRoundNumber := len(pull.Submissions) 679 _, err := e.Exec(` 680 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 681 values (?, ?, ?, ?, ?) 682 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 683 684 return err 685} 686 687func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 688 var conditions []string 689 var args []any 690 691 args = append(args, parentChangeId) 692 693 for _, filter := range filters { 694 conditions = append(conditions, filter.Condition()) 695 args = append(args, filter.Arg()...) 696 } 697 698 whereClause := "" 699 if conditions != nil { 700 whereClause = " where " + strings.Join(conditions, " and ") 701 } 702 703 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 704 _, err := e.Exec(query, args...) 705 706 return err 707} 708 709// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 710// otherwise submissions are immutable 711func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 712 var conditions []string 713 var args []any 714 715 args = append(args, sourceRev) 716 args = append(args, newPatch) 717 718 for _, filter := range filters { 719 conditions = append(conditions, filter.Condition()) 720 args = append(args, filter.Arg()...) 721 } 722 723 whereClause := "" 724 if conditions != nil { 725 whereClause = " where " + strings.Join(conditions, " and ") 726 } 727 728 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 729 _, err := e.Exec(query, args...) 730 731 return err 732} 733 734func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 735 row := e.QueryRow(` 736 select 737 count(case when state = ? then 1 end) as open_count, 738 count(case when state = ? then 1 end) as merged_count, 739 count(case when state = ? then 1 end) as closed_count, 740 count(case when state = ? then 1 end) as deleted_count 741 from pulls 742 where repo_at = ?`, 743 models.PullOpen, 744 models.PullMerged, 745 models.PullClosed, 746 models.PullDeleted, 747 repoAt, 748 ) 749 750 var count models.PullCount 751 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 752 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 753 } 754 755 return count, nil 756} 757 758// change-id parent-change-id 759// 760// 4 w ,-------- z (TOP) 761// 3 z <----',------- y 762// 2 y <-----',------ x 763// 1 x <------' nil (BOT) 764// 765// `w` is parent of none, so it is the top of the stack 766func GetStack(e Execer, stackId string) (models.Stack, error) { 767 unorderedPulls, err := GetPulls( 768 e, 769 FilterEq("stack_id", stackId), 770 FilterNotEq("state", models.PullDeleted), 771 ) 772 if err != nil { 773 return nil, err 774 } 775 // map of parent-change-id to pull 776 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 777 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 778 for _, p := range unorderedPulls { 779 changeIdMap[p.ChangeId] = p 780 if p.ParentChangeId != "" { 781 parentMap[p.ParentChangeId] = p 782 } 783 } 784 785 // the top of the stack is the pull that is not a parent of any pull 786 var topPull *models.Pull 787 for _, maybeTop := range unorderedPulls { 788 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 789 topPull = maybeTop 790 break 791 } 792 } 793 794 pulls := []*models.Pull{} 795 for { 796 pulls = append(pulls, topPull) 797 if topPull.ParentChangeId != "" { 798 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 799 topPull = next 800 } else { 801 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 802 } 803 } else { 804 break 805 } 806 } 807 808 return pulls, nil 809} 810 811func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 812 pulls, err := GetPulls( 813 e, 814 FilterEq("stack_id", stackId), 815 FilterEq("state", models.PullDeleted), 816 ) 817 if err != nil { 818 return nil, err 819 } 820 821 return pulls, nil 822}