forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16) 17 18func NewPull(tx *sql.Tx, pull *models.Pull) error { 19 _, err := tx.Exec(` 20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 21 values (?, 1) 22 `, pull.RepoAt) 23 if err != nil { 24 return err 25 } 26 27 var nextId int 28 err = tx.QueryRow(` 29 update repo_pull_seqs 30 set next_pull_id = next_pull_id + 1 31 where repo_at = ? 32 returning next_pull_id - 1 33 `, pull.RepoAt).Scan(&nextId) 34 if err != nil { 35 return err 36 } 37 38 pull.PullId = nextId 39 pull.State = models.PullOpen 40 41 var sourceBranch, sourceRepoAt *string 42 if pull.PullSource != nil { 43 sourceBranch = &pull.PullSource.Branch 44 if pull.PullSource.RepoAt != nil { 45 x := pull.PullSource.RepoAt.String() 46 sourceRepoAt = &x 47 } 48 } 49 50 var stackId, changeId, parentChangeId *string 51 if pull.StackId != "" { 52 stackId = &pull.StackId 53 } 54 if pull.ChangeId != "" { 55 changeId = &pull.ChangeId 56 } 57 if pull.ParentChangeId != "" { 58 parentChangeId = &pull.ParentChangeId 59 } 60 61 result, err := tx.Exec( 62 ` 63 insert into pulls ( 64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 65 ) 66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 67 pull.RepoAt, 68 pull.OwnerDid, 69 pull.PullId, 70 pull.Title, 71 pull.TargetBranch, 72 pull.Body, 73 pull.Rkey, 74 pull.State, 75 sourceBranch, 76 sourceRepoAt, 77 stackId, 78 changeId, 79 parentChangeId, 80 ) 81 if err != nil { 82 return err 83 } 84 85 // Set the database primary key ID 86 id, err := result.LastInsertId() 87 if err != nil { 88 return err 89 } 90 pull.ID = int(id) 91 92 _, err = tx.Exec(` 93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 94 values (?, ?, ?, ?, ?) 95 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 96 return err 97} 98 99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 100 pull, err := GetPull(e, repoAt, pullId) 101 if err != nil { 102 return "", err 103 } 104 return pull.PullAt(), err 105} 106 107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 108 var pullId int 109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 110 return pullId - 1, err 111} 112 113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 114 pulls := make(map[syntax.ATURI]*models.Pull) 115 116 var conditions []string 117 var args []any 118 for _, filter := range filters { 119 conditions = append(conditions, filter.Condition()) 120 args = append(args, filter.Arg()...) 121 } 122 123 whereClause := "" 124 if conditions != nil { 125 whereClause = " where " + strings.Join(conditions, " and ") 126 } 127 limitClause := "" 128 if limit != 0 { 129 limitClause = fmt.Sprintf(" limit %d ", limit) 130 } 131 132 query := fmt.Sprintf(` 133 select 134 id, 135 owner_did, 136 repo_at, 137 pull_id, 138 created, 139 title, 140 state, 141 target_branch, 142 body, 143 rkey, 144 source_branch, 145 source_repo_at, 146 stack_id, 147 change_id, 148 parent_change_id 149 from 150 pulls 151 %s 152 order by 153 created desc 154 %s 155 `, whereClause, limitClause) 156 157 rows, err := e.Query(query, args...) 158 if err != nil { 159 return nil, err 160 } 161 defer rows.Close() 162 163 for rows.Next() { 164 var pull models.Pull 165 var createdAt string 166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 167 err := rows.Scan( 168 &pull.ID, 169 &pull.OwnerDid, 170 &pull.RepoAt, 171 &pull.PullId, 172 &createdAt, 173 &pull.Title, 174 &pull.State, 175 &pull.TargetBranch, 176 &pull.Body, 177 &pull.Rkey, 178 &sourceBranch, 179 &sourceRepoAt, 180 &stackId, 181 &changeId, 182 &parentChangeId, 183 ) 184 if err != nil { 185 return nil, err 186 } 187 188 createdTime, err := time.Parse(time.RFC3339, createdAt) 189 if err != nil { 190 return nil, err 191 } 192 pull.Created = createdTime 193 194 if sourceBranch.Valid { 195 pull.PullSource = &models.PullSource{ 196 Branch: sourceBranch.String, 197 } 198 if sourceRepoAt.Valid { 199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 200 if err != nil { 201 return nil, err 202 } 203 pull.PullSource.RepoAt = &sourceRepoAtParsed 204 } 205 } 206 207 if stackId.Valid { 208 pull.StackId = stackId.String 209 } 210 if changeId.Valid { 211 pull.ChangeId = changeId.String 212 } 213 if parentChangeId.Valid { 214 pull.ParentChangeId = parentChangeId.String 215 } 216 217 pulls[pull.PullAt()] = &pull 218 } 219 220 var pullAts []syntax.ATURI 221 for _, p := range pulls { 222 pullAts = append(pullAts, p.PullAt()) 223 } 224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) 225 if err != nil { 226 return nil, fmt.Errorf("failed to get submissions: %w", err) 227 } 228 229 for pullAt, submissions := range submissionsMap { 230 if p, ok := pulls[pullAt]; ok { 231 p.Submissions = submissions 232 } 233 } 234 235 // collect allLabels for each issue 236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts)) 237 if err != nil { 238 return nil, fmt.Errorf("failed to query labels: %w", err) 239 } 240 for pullAt, labels := range allLabels { 241 if p, ok := pulls[pullAt]; ok { 242 p.Labels = labels 243 } 244 } 245 246 // collect pull source for all pulls that need it 247 var sourceAts []syntax.ATURI 248 for _, p := range pulls { 249 if p.PullSource != nil && p.PullSource.RepoAt != nil { 250 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 251 } 252 } 253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts)) 254 if err != nil && !errors.Is(err, sql.ErrNoRows) { 255 return nil, fmt.Errorf("failed to get source repos: %w", err) 256 } 257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 258 for _, r := range sourceRepos { 259 sourceRepoMap[r.RepoAt()] = &r 260 } 261 for _, p := range pulls { 262 if p.PullSource != nil && p.PullSource.RepoAt != nil { 263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 264 p.PullSource.Repo = sourceRepo 265 } 266 } 267 } 268 269 orderedByPullId := []*models.Pull{} 270 for _, p := range pulls { 271 orderedByPullId = append(orderedByPullId, p) 272 } 273 sort.Slice(orderedByPullId, func(i, j int) bool { 274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 275 }) 276 277 return orderedByPullId, nil 278} 279 280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 281 return GetPullsWithLimit(e, 0, filters...) 282} 283 284func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 285 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) 286 if err != nil { 287 return nil, err 288 } 289 if pulls == nil { 290 return nil, sql.ErrNoRows 291 } 292 293 return pulls[0], nil 294} 295 296// mapping from pull -> pull submissions 297func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 298 var conditions []string 299 var args []any 300 for _, filter := range filters { 301 conditions = append(conditions, filter.Condition()) 302 args = append(args, filter.Arg()...) 303 } 304 305 whereClause := "" 306 if conditions != nil { 307 whereClause = " where " + strings.Join(conditions, " and ") 308 } 309 310 query := fmt.Sprintf(` 311 select 312 id, 313 pull_at, 314 round_number, 315 patch, 316 combined, 317 created, 318 source_rev 319 from 320 pull_submissions 321 %s 322 order by 323 round_number asc 324 `, whereClause) 325 326 rows, err := e.Query(query, args...) 327 if err != nil { 328 return nil, err 329 } 330 defer rows.Close() 331 332 submissionMap := make(map[int]*models.PullSubmission) 333 334 for rows.Next() { 335 var submission models.PullSubmission 336 var submissionCreatedStr string 337 var submissionSourceRev, submissionCombined sql.NullString 338 err := rows.Scan( 339 &submission.ID, 340 &submission.PullAt, 341 &submission.RoundNumber, 342 &submission.Patch, 343 &submissionCombined, 344 &submissionCreatedStr, 345 &submissionSourceRev, 346 ) 347 if err != nil { 348 return nil, err 349 } 350 351 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 352 submission.Created = t 353 } 354 355 if submissionSourceRev.Valid { 356 submission.SourceRev = submissionSourceRev.String 357 } 358 359 if submissionCombined.Valid { 360 submission.Combined = submissionCombined.String 361 } 362 363 submissionMap[submission.ID] = &submission 364 } 365 366 if err := rows.Err(); err != nil { 367 return nil, err 368 } 369 370 // Get comments for all submissions using GetPullComments 371 submissionIds := slices.Collect(maps.Keys(submissionMap)) 372 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) 373 if err != nil { 374 return nil, err 375 } 376 for _, comment := range comments { 377 if submission, ok := submissionMap[comment.SubmissionId]; ok { 378 submission.Comments = append(submission.Comments, comment) 379 } 380 } 381 382 // group the submissions by pull_at 383 m := make(map[syntax.ATURI][]*models.PullSubmission) 384 for _, s := range submissionMap { 385 m[s.PullAt] = append(m[s.PullAt], s) 386 } 387 388 // sort each one by round number 389 for _, s := range m { 390 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 391 return cmp.Compare(a.RoundNumber, b.RoundNumber) 392 }) 393 } 394 395 return m, nil 396} 397 398func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { 399 var conditions []string 400 var args []any 401 for _, filter := range filters { 402 conditions = append(conditions, filter.Condition()) 403 args = append(args, filter.Arg()...) 404 } 405 406 whereClause := "" 407 if conditions != nil { 408 whereClause = " where " + strings.Join(conditions, " and ") 409 } 410 411 query := fmt.Sprintf(` 412 select 413 id, 414 pull_id, 415 submission_id, 416 repo_at, 417 owner_did, 418 comment_at, 419 body, 420 created 421 from 422 pull_comments 423 %s 424 order by 425 created asc 426 `, whereClause) 427 428 rows, err := e.Query(query, args...) 429 if err != nil { 430 return nil, err 431 } 432 defer rows.Close() 433 434 var comments []models.PullComment 435 for rows.Next() { 436 var comment models.PullComment 437 var createdAt string 438 err := rows.Scan( 439 &comment.ID, 440 &comment.PullId, 441 &comment.SubmissionId, 442 &comment.RepoAt, 443 &comment.OwnerDid, 444 &comment.CommentAt, 445 &comment.Body, 446 &createdAt, 447 ) 448 if err != nil { 449 return nil, err 450 } 451 452 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 453 comment.Created = t 454 } 455 456 comments = append(comments, comment) 457 } 458 459 if err := rows.Err(); err != nil { 460 return nil, err 461 } 462 463 return comments, nil 464} 465 466// timeframe here is directly passed into the sql query filter, and any 467// timeframe in the past should be negative; e.g.: "-3 months" 468func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 469 var pulls []models.Pull 470 471 rows, err := e.Query(` 472 select 473 p.owner_did, 474 p.repo_at, 475 p.pull_id, 476 p.created, 477 p.title, 478 p.state, 479 r.did, 480 r.name, 481 r.knot, 482 r.rkey, 483 r.created 484 from 485 pulls p 486 join 487 repos r on p.repo_at = r.at_uri 488 where 489 p.owner_did = ? and p.created >= date ('now', ?) 490 order by 491 p.created desc`, did, timeframe) 492 if err != nil { 493 return nil, err 494 } 495 defer rows.Close() 496 497 for rows.Next() { 498 var pull models.Pull 499 var repo models.Repo 500 var pullCreatedAt, repoCreatedAt string 501 err := rows.Scan( 502 &pull.OwnerDid, 503 &pull.RepoAt, 504 &pull.PullId, 505 &pullCreatedAt, 506 &pull.Title, 507 &pull.State, 508 &repo.Did, 509 &repo.Name, 510 &repo.Knot, 511 &repo.Rkey, 512 &repoCreatedAt, 513 ) 514 if err != nil { 515 return nil, err 516 } 517 518 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 519 if err != nil { 520 return nil, err 521 } 522 pull.Created = pullCreatedTime 523 524 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 525 if err != nil { 526 return nil, err 527 } 528 repo.Created = repoCreatedTime 529 530 pull.Repo = &repo 531 532 pulls = append(pulls, pull) 533 } 534 535 if err := rows.Err(); err != nil { 536 return nil, err 537 } 538 539 return pulls, nil 540} 541 542func NewPullComment(e Execer, comment *models.PullComment) (int64, error) { 543 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 544 res, err := e.Exec( 545 query, 546 comment.OwnerDid, 547 comment.RepoAt, 548 comment.SubmissionId, 549 comment.CommentAt, 550 comment.PullId, 551 comment.Body, 552 ) 553 if err != nil { 554 return 0, err 555 } 556 557 i, err := res.LastInsertId() 558 if err != nil { 559 return 0, err 560 } 561 562 return i, nil 563} 564 565func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 566 _, err := e.Exec( 567 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 568 pullState, 569 repoAt, 570 pullId, 571 models.PullDeleted, // only update state of non-deleted pulls 572 models.PullMerged, // only update state of non-merged pulls 573 ) 574 return err 575} 576 577func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 578 err := SetPullState(e, repoAt, pullId, models.PullClosed) 579 return err 580} 581 582func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 583 err := SetPullState(e, repoAt, pullId, models.PullOpen) 584 return err 585} 586 587func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 588 err := SetPullState(e, repoAt, pullId, models.PullMerged) 589 return err 590} 591 592func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 593 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 594 return err 595} 596 597func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 598 _, err := e.Exec(` 599 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 600 values (?, ?, ?, ?, ?) 601 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 602 603 return err 604} 605 606func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 607 var conditions []string 608 var args []any 609 610 args = append(args, parentChangeId) 611 612 for _, filter := range filters { 613 conditions = append(conditions, filter.Condition()) 614 args = append(args, filter.Arg()...) 615 } 616 617 whereClause := "" 618 if conditions != nil { 619 whereClause = " where " + strings.Join(conditions, " and ") 620 } 621 622 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 623 _, err := e.Exec(query, args...) 624 625 return err 626} 627 628// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 629// otherwise submissions are immutable 630func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 631 var conditions []string 632 var args []any 633 634 args = append(args, sourceRev) 635 args = append(args, newPatch) 636 637 for _, filter := range filters { 638 conditions = append(conditions, filter.Condition()) 639 args = append(args, filter.Arg()...) 640 } 641 642 whereClause := "" 643 if conditions != nil { 644 whereClause = " where " + strings.Join(conditions, " and ") 645 } 646 647 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 648 _, err := e.Exec(query, args...) 649 650 return err 651} 652 653func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 654 row := e.QueryRow(` 655 select 656 count(case when state = ? then 1 end) as open_count, 657 count(case when state = ? then 1 end) as merged_count, 658 count(case when state = ? then 1 end) as closed_count, 659 count(case when state = ? then 1 end) as deleted_count 660 from pulls 661 where repo_at = ?`, 662 models.PullOpen, 663 models.PullMerged, 664 models.PullClosed, 665 models.PullDeleted, 666 repoAt, 667 ) 668 669 var count models.PullCount 670 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 671 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 672 } 673 674 return count, nil 675} 676 677// change-id parent-change-id 678// 679// 4 w ,-------- z (TOP) 680// 3 z <----',------- y 681// 2 y <-----',------ x 682// 1 x <------' nil (BOT) 683// 684// `w` is parent of none, so it is the top of the stack 685func GetStack(e Execer, stackId string) (models.Stack, error) { 686 unorderedPulls, err := GetPulls( 687 e, 688 FilterEq("stack_id", stackId), 689 FilterNotEq("state", models.PullDeleted), 690 ) 691 if err != nil { 692 return nil, err 693 } 694 // map of parent-change-id to pull 695 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 696 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 697 for _, p := range unorderedPulls { 698 changeIdMap[p.ChangeId] = p 699 if p.ParentChangeId != "" { 700 parentMap[p.ParentChangeId] = p 701 } 702 } 703 704 // the top of the stack is the pull that is not a parent of any pull 705 var topPull *models.Pull 706 for _, maybeTop := range unorderedPulls { 707 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 708 topPull = maybeTop 709 break 710 } 711 } 712 713 pulls := []*models.Pull{} 714 for { 715 pulls = append(pulls, topPull) 716 if topPull.ParentChangeId != "" { 717 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 718 topPull = next 719 } else { 720 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 721 } 722 } else { 723 break 724 } 725 } 726 727 return pulls, nil 728} 729 730func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 731 pulls, err := GetPulls( 732 e, 733 FilterEq("stack_id", stackId), 734 FilterEq("state", models.PullDeleted), 735 ) 736 if err != nil { 737 return nil, err 738 } 739 740 return pulls, nil 741}