forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "sort" 8 "strings" 9 "time" 10 11 "github.com/bluesky-social/indigo/atproto/syntax" 12 "tangled.org/core/appview/models" 13) 14 15func NewPull(tx *sql.Tx, pull *models.Pull) error { 16 _, err := tx.Exec(` 17 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 18 values (?, 1) 19 `, pull.RepoAt) 20 if err != nil { 21 return err 22 } 23 24 var nextId int 25 err = tx.QueryRow(` 26 update repo_pull_seqs 27 set next_pull_id = next_pull_id + 1 28 where repo_at = ? 29 returning next_pull_id - 1 30 `, pull.RepoAt).Scan(&nextId) 31 if err != nil { 32 return err 33 } 34 35 pull.PullId = nextId 36 pull.State = models.PullOpen 37 38 var sourceBranch, sourceRepoAt *string 39 if pull.PullSource != nil { 40 sourceBranch = &pull.PullSource.Branch 41 if pull.PullSource.RepoAt != nil { 42 x := pull.PullSource.RepoAt.String() 43 sourceRepoAt = &x 44 } 45 } 46 47 var stackId, changeId, parentChangeId *string 48 if pull.StackId != "" { 49 stackId = &pull.StackId 50 } 51 if pull.ChangeId != "" { 52 changeId = &pull.ChangeId 53 } 54 if pull.ParentChangeId != "" { 55 parentChangeId = &pull.ParentChangeId 56 } 57 58 _, err = tx.Exec( 59 ` 60 insert into pulls ( 61 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 62 ) 63 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 64 pull.RepoAt, 65 pull.OwnerDid, 66 pull.PullId, 67 pull.Title, 68 pull.TargetBranch, 69 pull.Body, 70 pull.Rkey, 71 pull.State, 72 sourceBranch, 73 sourceRepoAt, 74 stackId, 75 changeId, 76 parentChangeId, 77 ) 78 if err != nil { 79 return err 80 } 81 82 _, err = tx.Exec(` 83 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 84 values (?, ?, ?, ?, ?) 85 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 86 return err 87} 88 89func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 90 pull, err := GetPull(e, repoAt, pullId) 91 if err != nil { 92 return "", err 93 } 94 return pull.PullAt(), err 95} 96 97func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 98 var pullId int 99 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 100 return pullId - 1, err 101} 102 103func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 104 pulls := make(map[int]*models.Pull) 105 106 var conditions []string 107 var args []any 108 for _, filter := range filters { 109 conditions = append(conditions, filter.Condition()) 110 args = append(args, filter.Arg()...) 111 } 112 113 whereClause := "" 114 if conditions != nil { 115 whereClause = " where " + strings.Join(conditions, " and ") 116 } 117 limitClause := "" 118 if limit != 0 { 119 limitClause = fmt.Sprintf(" limit %d ", limit) 120 } 121 122 query := fmt.Sprintf(` 123 select 124 owner_did, 125 repo_at, 126 pull_id, 127 created, 128 title, 129 state, 130 target_branch, 131 body, 132 rkey, 133 source_branch, 134 source_repo_at, 135 stack_id, 136 change_id, 137 parent_change_id 138 from 139 pulls 140 %s 141 order by 142 created desc 143 %s 144 `, whereClause, limitClause) 145 146 rows, err := e.Query(query, args...) 147 if err != nil { 148 return nil, err 149 } 150 defer rows.Close() 151 152 for rows.Next() { 153 var pull models.Pull 154 var createdAt string 155 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 156 err := rows.Scan( 157 &pull.OwnerDid, 158 &pull.RepoAt, 159 &pull.PullId, 160 &createdAt, 161 &pull.Title, 162 &pull.State, 163 &pull.TargetBranch, 164 &pull.Body, 165 &pull.Rkey, 166 &sourceBranch, 167 &sourceRepoAt, 168 &stackId, 169 &changeId, 170 &parentChangeId, 171 ) 172 if err != nil { 173 return nil, err 174 } 175 176 createdTime, err := time.Parse(time.RFC3339, createdAt) 177 if err != nil { 178 return nil, err 179 } 180 pull.Created = createdTime 181 182 if sourceBranch.Valid { 183 pull.PullSource = &models.PullSource{ 184 Branch: sourceBranch.String, 185 } 186 if sourceRepoAt.Valid { 187 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 188 if err != nil { 189 return nil, err 190 } 191 pull.PullSource.RepoAt = &sourceRepoAtParsed 192 } 193 } 194 195 if stackId.Valid { 196 pull.StackId = stackId.String 197 } 198 if changeId.Valid { 199 pull.ChangeId = changeId.String 200 } 201 if parentChangeId.Valid { 202 pull.ParentChangeId = parentChangeId.String 203 } 204 205 pulls[pull.PullId] = &pull 206 } 207 208 // get latest round no. for each pull 209 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 210 submissionsQuery := fmt.Sprintf(` 211 select 212 id, pull_id, round_number, patch, created, source_rev 213 from 214 pull_submissions 215 where 216 repo_at in (%s) and pull_id in (%s) 217 `, inClause, inClause) 218 219 args = make([]any, len(pulls)*2) 220 idx := 0 221 for _, p := range pulls { 222 args[idx] = p.RepoAt 223 idx += 1 224 } 225 for _, p := range pulls { 226 args[idx] = p.PullId 227 idx += 1 228 } 229 submissionsRows, err := e.Query(submissionsQuery, args...) 230 if err != nil { 231 return nil, err 232 } 233 defer submissionsRows.Close() 234 235 for submissionsRows.Next() { 236 var s models.PullSubmission 237 var sourceRev sql.NullString 238 var createdAt string 239 err := submissionsRows.Scan( 240 &s.ID, 241 &s.PullId, 242 &s.RoundNumber, 243 &s.Patch, 244 &createdAt, 245 &sourceRev, 246 ) 247 if err != nil { 248 return nil, err 249 } 250 251 createdTime, err := time.Parse(time.RFC3339, createdAt) 252 if err != nil { 253 return nil, err 254 } 255 s.Created = createdTime 256 257 if sourceRev.Valid { 258 s.SourceRev = sourceRev.String 259 } 260 261 if p, ok := pulls[s.PullId]; ok { 262 p.Submissions = make([]*models.PullSubmission, s.RoundNumber+1) 263 p.Submissions[s.RoundNumber] = &s 264 } 265 } 266 if err := rows.Err(); err != nil { 267 return nil, err 268 } 269 270 // get comment count on latest submission on each pull 271 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 272 commentsQuery := fmt.Sprintf(` 273 select 274 count(id), pull_id 275 from 276 pull_comments 277 where 278 submission_id in (%s) 279 group by 280 submission_id 281 `, inClause) 282 283 args = []any{} 284 for _, p := range pulls { 285 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 286 } 287 commentsRows, err := e.Query(commentsQuery, args...) 288 if err != nil { 289 return nil, err 290 } 291 defer commentsRows.Close() 292 293 for commentsRows.Next() { 294 var commentCount, pullId int 295 err := commentsRows.Scan( 296 &commentCount, 297 &pullId, 298 ) 299 if err != nil { 300 return nil, err 301 } 302 if p, ok := pulls[pullId]; ok { 303 p.Submissions[p.LastRoundNumber()].Comments = make([]models.PullComment, commentCount) 304 } 305 } 306 if err := rows.Err(); err != nil { 307 return nil, err 308 } 309 310 orderedByPullId := []*models.Pull{} 311 for _, p := range pulls { 312 orderedByPullId = append(orderedByPullId, p) 313 } 314 sort.Slice(orderedByPullId, func(i, j int) bool { 315 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 316 }) 317 318 return orderedByPullId, nil 319} 320 321func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 322 return GetPullsWithLimit(e, 0, filters...) 323} 324 325func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 326 query := ` 327 select 328 owner_did, 329 pull_id, 330 created, 331 title, 332 state, 333 target_branch, 334 repo_at, 335 body, 336 rkey, 337 source_branch, 338 source_repo_at, 339 stack_id, 340 change_id, 341 parent_change_id 342 from 343 pulls 344 where 345 repo_at = ? and pull_id = ? 346 ` 347 row := e.QueryRow(query, repoAt, pullId) 348 349 var pull models.Pull 350 var createdAt string 351 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 352 err := row.Scan( 353 &pull.OwnerDid, 354 &pull.PullId, 355 &createdAt, 356 &pull.Title, 357 &pull.State, 358 &pull.TargetBranch, 359 &pull.RepoAt, 360 &pull.Body, 361 &pull.Rkey, 362 &sourceBranch, 363 &sourceRepoAt, 364 &stackId, 365 &changeId, 366 &parentChangeId, 367 ) 368 if err != nil { 369 return nil, err 370 } 371 372 createdTime, err := time.Parse(time.RFC3339, createdAt) 373 if err != nil { 374 return nil, err 375 } 376 pull.Created = createdTime 377 378 // populate source 379 if sourceBranch.Valid { 380 pull.PullSource = &models.PullSource{ 381 Branch: sourceBranch.String, 382 } 383 if sourceRepoAt.Valid { 384 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 385 if err != nil { 386 return nil, err 387 } 388 pull.PullSource.RepoAt = &sourceRepoAtParsed 389 } 390 } 391 392 if stackId.Valid { 393 pull.StackId = stackId.String 394 } 395 if changeId.Valid { 396 pull.ChangeId = changeId.String 397 } 398 if parentChangeId.Valid { 399 pull.ParentChangeId = parentChangeId.String 400 } 401 402 submissionsQuery := ` 403 select 404 id, pull_id, repo_at, round_number, patch, created, source_rev 405 from 406 pull_submissions 407 where 408 repo_at = ? and pull_id = ? 409 ` 410 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 411 if err != nil { 412 return nil, err 413 } 414 defer submissionsRows.Close() 415 416 submissionsMap := make(map[int]*models.PullSubmission) 417 418 for submissionsRows.Next() { 419 var submission models.PullSubmission 420 var submissionCreatedStr string 421 var submissionSourceRev sql.NullString 422 err := submissionsRows.Scan( 423 &submission.ID, 424 &submission.PullId, 425 &submission.RepoAt, 426 &submission.RoundNumber, 427 &submission.Patch, 428 &submissionCreatedStr, 429 &submissionSourceRev, 430 ) 431 if err != nil { 432 return nil, err 433 } 434 435 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 436 if err != nil { 437 return nil, err 438 } 439 submission.Created = submissionCreatedTime 440 441 if submissionSourceRev.Valid { 442 submission.SourceRev = submissionSourceRev.String 443 } 444 445 submissionsMap[submission.ID] = &submission 446 } 447 if err = submissionsRows.Close(); err != nil { 448 return nil, err 449 } 450 if len(submissionsMap) == 0 { 451 return &pull, nil 452 } 453 454 var args []any 455 for k := range submissionsMap { 456 args = append(args, k) 457 } 458 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 459 commentsQuery := fmt.Sprintf(` 460 select 461 id, 462 pull_id, 463 submission_id, 464 repo_at, 465 owner_did, 466 comment_at, 467 body, 468 created 469 from 470 pull_comments 471 where 472 submission_id IN (%s) 473 order by 474 created asc 475 `, inClause) 476 commentsRows, err := e.Query(commentsQuery, args...) 477 if err != nil { 478 return nil, err 479 } 480 defer commentsRows.Close() 481 482 for commentsRows.Next() { 483 var comment models.PullComment 484 var commentCreatedStr string 485 err := commentsRows.Scan( 486 &comment.ID, 487 &comment.PullId, 488 &comment.SubmissionId, 489 &comment.RepoAt, 490 &comment.OwnerDid, 491 &comment.CommentAt, 492 &comment.Body, 493 &commentCreatedStr, 494 ) 495 if err != nil { 496 return nil, err 497 } 498 499 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 500 if err != nil { 501 return nil, err 502 } 503 comment.Created = commentCreatedTime 504 505 // Add the comment to its submission 506 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 507 submission.Comments = append(submission.Comments, comment) 508 } 509 510 } 511 if err = commentsRows.Err(); err != nil { 512 return nil, err 513 } 514 515 var pullSourceRepo *models.Repo 516 if pull.PullSource != nil { 517 if pull.PullSource.RepoAt != nil { 518 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 519 if err != nil { 520 log.Printf("failed to get repo by at uri: %v", err) 521 } else { 522 pull.PullSource.Repo = pullSourceRepo 523 } 524 } 525 } 526 527 pull.Submissions = make([]*models.PullSubmission, len(submissionsMap)) 528 for _, submission := range submissionsMap { 529 pull.Submissions[submission.RoundNumber] = submission 530 } 531 532 return &pull, nil 533} 534 535// timeframe here is directly passed into the sql query filter, and any 536// timeframe in the past should be negative; e.g.: "-3 months" 537func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 538 var pulls []models.Pull 539 540 rows, err := e.Query(` 541 select 542 p.owner_did, 543 p.repo_at, 544 p.pull_id, 545 p.created, 546 p.title, 547 p.state, 548 r.did, 549 r.name, 550 r.knot, 551 r.rkey, 552 r.created 553 from 554 pulls p 555 join 556 repos r on p.repo_at = r.at_uri 557 where 558 p.owner_did = ? and p.created >= date ('now', ?) 559 order by 560 p.created desc`, did, timeframe) 561 if err != nil { 562 return nil, err 563 } 564 defer rows.Close() 565 566 for rows.Next() { 567 var pull models.Pull 568 var repo models.Repo 569 var pullCreatedAt, repoCreatedAt string 570 err := rows.Scan( 571 &pull.OwnerDid, 572 &pull.RepoAt, 573 &pull.PullId, 574 &pullCreatedAt, 575 &pull.Title, 576 &pull.State, 577 &repo.Did, 578 &repo.Name, 579 &repo.Knot, 580 &repo.Rkey, 581 &repoCreatedAt, 582 ) 583 if err != nil { 584 return nil, err 585 } 586 587 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 588 if err != nil { 589 return nil, err 590 } 591 pull.Created = pullCreatedTime 592 593 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 594 if err != nil { 595 return nil, err 596 } 597 repo.Created = repoCreatedTime 598 599 pull.Repo = &repo 600 601 pulls = append(pulls, pull) 602 } 603 604 if err := rows.Err(); err != nil { 605 return nil, err 606 } 607 608 return pulls, nil 609} 610 611func NewPullComment(e Execer, comment *models.PullComment) (int64, error) { 612 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 613 res, err := e.Exec( 614 query, 615 comment.OwnerDid, 616 comment.RepoAt, 617 comment.SubmissionId, 618 comment.CommentAt, 619 comment.PullId, 620 comment.Body, 621 ) 622 if err != nil { 623 return 0, err 624 } 625 626 i, err := res.LastInsertId() 627 if err != nil { 628 return 0, err 629 } 630 631 return i, nil 632} 633 634func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 635 _, err := e.Exec( 636 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 637 pullState, 638 repoAt, 639 pullId, 640 models.PullDeleted, // only update state of non-deleted pulls 641 models.PullMerged, // only update state of non-merged pulls 642 ) 643 return err 644} 645 646func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 647 err := SetPullState(e, repoAt, pullId, models.PullClosed) 648 return err 649} 650 651func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 652 err := SetPullState(e, repoAt, pullId, models.PullOpen) 653 return err 654} 655 656func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 657 err := SetPullState(e, repoAt, pullId, models.PullMerged) 658 return err 659} 660 661func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 662 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 663 return err 664} 665 666func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error { 667 newRoundNumber := len(pull.Submissions) 668 _, err := e.Exec(` 669 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 670 values (?, ?, ?, ?, ?) 671 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 672 673 return err 674} 675 676func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 677 var conditions []string 678 var args []any 679 680 args = append(args, parentChangeId) 681 682 for _, filter := range filters { 683 conditions = append(conditions, filter.Condition()) 684 args = append(args, filter.Arg()...) 685 } 686 687 whereClause := "" 688 if conditions != nil { 689 whereClause = " where " + strings.Join(conditions, " and ") 690 } 691 692 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 693 _, err := e.Exec(query, args...) 694 695 return err 696} 697 698// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 699// otherwise submissions are immutable 700func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 701 var conditions []string 702 var args []any 703 704 args = append(args, sourceRev) 705 args = append(args, newPatch) 706 707 for _, filter := range filters { 708 conditions = append(conditions, filter.Condition()) 709 args = append(args, filter.Arg()...) 710 } 711 712 whereClause := "" 713 if conditions != nil { 714 whereClause = " where " + strings.Join(conditions, " and ") 715 } 716 717 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 718 _, err := e.Exec(query, args...) 719 720 return err 721} 722 723func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 724 row := e.QueryRow(` 725 select 726 count(case when state = ? then 1 end) as open_count, 727 count(case when state = ? then 1 end) as merged_count, 728 count(case when state = ? then 1 end) as closed_count, 729 count(case when state = ? then 1 end) as deleted_count 730 from pulls 731 where repo_at = ?`, 732 models.PullOpen, 733 models.PullMerged, 734 models.PullClosed, 735 models.PullDeleted, 736 repoAt, 737 ) 738 739 var count models.PullCount 740 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 741 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 742 } 743 744 return count, nil 745} 746 747// change-id parent-change-id 748// 749// 4 w ,-------- z (TOP) 750// 3 z <----',------- y 751// 2 y <-----',------ x 752// 1 x <------' nil (BOT) 753// 754// `w` is parent of none, so it is the top of the stack 755func GetStack(e Execer, stackId string) (models.Stack, error) { 756 unorderedPulls, err := GetPulls( 757 e, 758 FilterEq("stack_id", stackId), 759 FilterNotEq("state", models.PullDeleted), 760 ) 761 if err != nil { 762 return nil, err 763 } 764 // map of parent-change-id to pull 765 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 766 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 767 for _, p := range unorderedPulls { 768 changeIdMap[p.ChangeId] = p 769 if p.ParentChangeId != "" { 770 parentMap[p.ParentChangeId] = p 771 } 772 } 773 774 // the top of the stack is the pull that is not a parent of any pull 775 var topPull *models.Pull 776 for _, maybeTop := range unorderedPulls { 777 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 778 topPull = maybeTop 779 break 780 } 781 } 782 783 pulls := []*models.Pull{} 784 for { 785 pulls = append(pulls, topPull) 786 if topPull.ParentChangeId != "" { 787 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 788 topPull = next 789 } else { 790 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 791 } 792 } else { 793 break 794 } 795 } 796 797 return pulls, nil 798} 799 800func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 801 pulls, err := GetPulls( 802 e, 803 FilterEq("stack_id", stackId), 804 FilterEq("state", models.PullDeleted), 805 ) 806 if err != nil { 807 return nil, err 808 } 809 810 return pulls, nil 811}