forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "fmt" 7 "maps" 8 "slices" 9 "sort" 10 "strings" 11 "time" 12 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 "tangled.org/core/appview/models" 15) 16 17func NewPull(tx *sql.Tx, pull *models.Pull) error { 18 _, err := tx.Exec(` 19 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 20 values (?, 1) 21 `, pull.RepoAt) 22 if err != nil { 23 return err 24 } 25 26 var nextId int 27 err = tx.QueryRow(` 28 update repo_pull_seqs 29 set next_pull_id = next_pull_id + 1 30 where repo_at = ? 31 returning next_pull_id - 1 32 `, pull.RepoAt).Scan(&nextId) 33 if err != nil { 34 return err 35 } 36 37 pull.PullId = nextId 38 pull.State = models.PullOpen 39 40 var sourceBranch, sourceRepoAt *string 41 if pull.PullSource != nil { 42 sourceBranch = &pull.PullSource.Branch 43 if pull.PullSource.RepoAt != nil { 44 x := pull.PullSource.RepoAt.String() 45 sourceRepoAt = &x 46 } 47 } 48 49 var stackId, changeId, parentChangeId *string 50 if pull.StackId != "" { 51 stackId = &pull.StackId 52 } 53 if pull.ChangeId != "" { 54 changeId = &pull.ChangeId 55 } 56 if pull.ParentChangeId != "" { 57 parentChangeId = &pull.ParentChangeId 58 } 59 60 result, err := tx.Exec( 61 ` 62 insert into pulls ( 63 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 64 ) 65 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 66 pull.RepoAt, 67 pull.OwnerDid, 68 pull.PullId, 69 pull.Title, 70 pull.TargetBranch, 71 pull.Body, 72 pull.Rkey, 73 pull.State, 74 sourceBranch, 75 sourceRepoAt, 76 stackId, 77 changeId, 78 parentChangeId, 79 ) 80 if err != nil { 81 return err 82 } 83 84 // Set the database primary key ID 85 id, err := result.LastInsertId() 86 if err != nil { 87 return err 88 } 89 pull.ID = int(id) 90 91 _, err = tx.Exec(` 92 insert into pull_submissions (pull_at, round_number, patch, source_rev) 93 values (?, ?, ?, ?) 94 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 95 return err 96} 97 98func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 99 pull, err := GetPull(e, repoAt, pullId) 100 if err != nil { 101 return "", err 102 } 103 return pull.PullAt(), err 104} 105 106func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 107 var pullId int 108 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 109 return pullId - 1, err 110} 111 112func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 113 pulls := make(map[syntax.ATURI]*models.Pull) 114 115 var conditions []string 116 var args []any 117 for _, filter := range filters { 118 conditions = append(conditions, filter.Condition()) 119 args = append(args, filter.Arg()...) 120 } 121 122 whereClause := "" 123 if conditions != nil { 124 whereClause = " where " + strings.Join(conditions, " and ") 125 } 126 limitClause := "" 127 if limit != 0 { 128 limitClause = fmt.Sprintf(" limit %d ", limit) 129 } 130 131 query := fmt.Sprintf(` 132 select 133 id, 134 owner_did, 135 repo_at, 136 pull_id, 137 created, 138 title, 139 state, 140 target_branch, 141 body, 142 rkey, 143 source_branch, 144 source_repo_at, 145 stack_id, 146 change_id, 147 parent_change_id 148 from 149 pulls 150 %s 151 order by 152 created desc 153 %s 154 `, whereClause, limitClause) 155 156 rows, err := e.Query(query, args...) 157 if err != nil { 158 return nil, err 159 } 160 defer rows.Close() 161 162 for rows.Next() { 163 var pull models.Pull 164 var createdAt string 165 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 166 err := rows.Scan( 167 &pull.ID, 168 &pull.OwnerDid, 169 &pull.RepoAt, 170 &pull.PullId, 171 &createdAt, 172 &pull.Title, 173 &pull.State, 174 &pull.TargetBranch, 175 &pull.Body, 176 &pull.Rkey, 177 &sourceBranch, 178 &sourceRepoAt, 179 &stackId, 180 &changeId, 181 &parentChangeId, 182 ) 183 if err != nil { 184 return nil, err 185 } 186 187 createdTime, err := time.Parse(time.RFC3339, createdAt) 188 if err != nil { 189 return nil, err 190 } 191 pull.Created = createdTime 192 193 if sourceBranch.Valid { 194 pull.PullSource = &models.PullSource{ 195 Branch: sourceBranch.String, 196 } 197 if sourceRepoAt.Valid { 198 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 199 if err != nil { 200 return nil, err 201 } 202 pull.PullSource.RepoAt = &sourceRepoAtParsed 203 } 204 } 205 206 if stackId.Valid { 207 pull.StackId = stackId.String 208 } 209 if changeId.Valid { 210 pull.ChangeId = changeId.String 211 } 212 if parentChangeId.Valid { 213 pull.ParentChangeId = parentChangeId.String 214 } 215 216 pulls[pull.PullAt()] = &pull 217 } 218 219 var pullAts []syntax.ATURI 220 for _, p := range pulls { 221 pullAts = append(pullAts, p.PullAt()) 222 } 223 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) 224 if err != nil { 225 return nil, fmt.Errorf("failed to get submissions: %w", err) 226 } 227 228 for pullAt, submissions := range submissionsMap { 229 if p, ok := pulls[pullAt]; ok { 230 p.Submissions = submissions 231 } 232 } 233 // collect allLabels for each issue 234 allLabels, err := GetLabels(e, FilterIn("subject", pullAts)) 235 if err != nil { 236 return nil, fmt.Errorf("failed to query labels: %w", err) 237 } 238 for pullAt, labels := range allLabels { 239 if p, ok := pulls[pullAt]; ok { 240 p.Labels = labels 241 } 242 } 243 244 orderedByPullId := []*models.Pull{} 245 for _, p := range pulls { 246 orderedByPullId = append(orderedByPullId, p) 247 } 248 sort.Slice(orderedByPullId, func(i, j int) bool { 249 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 250 }) 251 252 return orderedByPullId, nil 253} 254 255func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 256 return GetPullsWithLimit(e, 0, filters...) 257} 258 259func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 260 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) 261 if err != nil { 262 return nil, err 263 } 264 if pulls == nil { 265 return nil, sql.ErrNoRows 266 } 267 268 return pulls[0], nil 269} 270 271// mapping from pull -> pull submissions 272func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 273 var conditions []string 274 var args []any 275 for _, filter := range filters { 276 conditions = append(conditions, filter.Condition()) 277 args = append(args, filter.Arg()...) 278 } 279 280 whereClause := "" 281 if conditions != nil { 282 whereClause = " where " + strings.Join(conditions, " and ") 283 } 284 285 query := fmt.Sprintf(` 286 select 287 id, 288 pull_at, 289 round_number, 290 patch, 291 created, 292 source_rev 293 from 294 pull_submissions 295 %s 296 order by 297 round_number asc 298 `, whereClause) 299 300 rows, err := e.Query(query, args...) 301 if err != nil { 302 return nil, err 303 } 304 defer rows.Close() 305 306 submissionMap := make(map[int]*models.PullSubmission) 307 308 for rows.Next() { 309 var submission models.PullSubmission 310 var createdAt string 311 var sourceRev sql.NullString 312 err := rows.Scan( 313 &submission.ID, 314 &submission.PullAt, 315 &submission.RoundNumber, 316 &submission.Patch, 317 &createdAt, 318 &sourceRev, 319 ) 320 if err != nil { 321 return nil, err 322 } 323 324 createdTime, err := time.Parse(time.RFC3339, createdAt) 325 if err != nil { 326 return nil, err 327 } 328 submission.Created = createdTime 329 330 if sourceRev.Valid { 331 submission.SourceRev = sourceRev.String 332 } 333 334 submissionMap[submission.ID] = &submission 335 } 336 337 if err := rows.Err(); err != nil { 338 return nil, err 339 } 340 341 // Get comments for all submissions using GetPullComments 342 submissionIds := slices.Collect(maps.Keys(submissionMap)) 343 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) 344 if err != nil { 345 return nil, err 346 } 347 for _, comment := range comments { 348 if submission, ok := submissionMap[comment.SubmissionId]; ok { 349 submission.Comments = append(submission.Comments, comment) 350 } 351 } 352 353 // group the submissions by pull_at 354 m := make(map[syntax.ATURI][]*models.PullSubmission) 355 for _, s := range submissionMap { 356 m[s.PullAt] = append(m[s.PullAt], s) 357 } 358 359 // sort each one by round number 360 for _, s := range m { 361 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 362 return cmp.Compare(a.RoundNumber, b.RoundNumber) 363 }) 364 } 365 366 return m, nil 367} 368 369func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { 370 var conditions []string 371 var args []any 372 for _, filter := range filters { 373 conditions = append(conditions, filter.Condition()) 374 args = append(args, filter.Arg()...) 375 } 376 377 whereClause := "" 378 if conditions != nil { 379 whereClause = " where " + strings.Join(conditions, " and ") 380 } 381 382 query := fmt.Sprintf(` 383 select 384 id, 385 pull_id, 386 submission_id, 387 repo_at, 388 owner_did, 389 comment_at, 390 body, 391 created 392 from 393 pull_comments 394 %s 395 order by 396 created asc 397 `, whereClause) 398 399 rows, err := e.Query(query, args...) 400 if err != nil { 401 return nil, err 402 } 403 defer rows.Close() 404 405 var comments []models.PullComment 406 for rows.Next() { 407 var comment models.PullComment 408 var createdAt string 409 err := rows.Scan( 410 &comment.ID, 411 &comment.PullId, 412 &comment.SubmissionId, 413 &comment.RepoAt, 414 &comment.OwnerDid, 415 &comment.CommentAt, 416 &comment.Body, 417 &createdAt, 418 ) 419 if err != nil { 420 return nil, err 421 } 422 423 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 424 comment.Created = t 425 } 426 427 comments = append(comments, comment) 428 } 429 430 if err := rows.Err(); err != nil { 431 return nil, err 432 } 433 434 return comments, nil 435} 436 437// timeframe here is directly passed into the sql query filter, and any 438// timeframe in the past should be negative; e.g.: "-3 months" 439func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 440 var pulls []models.Pull 441 442 rows, err := e.Query(` 443 select 444 p.owner_did, 445 p.repo_at, 446 p.pull_id, 447 p.created, 448 p.title, 449 p.state, 450 r.did, 451 r.name, 452 r.knot, 453 r.rkey, 454 r.created 455 from 456 pulls p 457 join 458 repos r on p.repo_at = r.at_uri 459 where 460 p.owner_did = ? and p.created >= date ('now', ?) 461 order by 462 p.created desc`, did, timeframe) 463 if err != nil { 464 return nil, err 465 } 466 defer rows.Close() 467 468 for rows.Next() { 469 var pull models.Pull 470 var repo models.Repo 471 var pullCreatedAt, repoCreatedAt string 472 err := rows.Scan( 473 &pull.OwnerDid, 474 &pull.RepoAt, 475 &pull.PullId, 476 &pullCreatedAt, 477 &pull.Title, 478 &pull.State, 479 &repo.Did, 480 &repo.Name, 481 &repo.Knot, 482 &repo.Rkey, 483 &repoCreatedAt, 484 ) 485 if err != nil { 486 return nil, err 487 } 488 489 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 490 if err != nil { 491 return nil, err 492 } 493 pull.Created = pullCreatedTime 494 495 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 496 if err != nil { 497 return nil, err 498 } 499 repo.Created = repoCreatedTime 500 501 pull.Repo = &repo 502 503 pulls = append(pulls, pull) 504 } 505 506 if err := rows.Err(); err != nil { 507 return nil, err 508 } 509 510 return pulls, nil 511} 512 513func NewPullComment(e Execer, comment *models.PullComment) (int64, error) { 514 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 515 res, err := e.Exec( 516 query, 517 comment.OwnerDid, 518 comment.RepoAt, 519 comment.SubmissionId, 520 comment.CommentAt, 521 comment.PullId, 522 comment.Body, 523 ) 524 if err != nil { 525 return 0, err 526 } 527 528 i, err := res.LastInsertId() 529 if err != nil { 530 return 0, err 531 } 532 533 return i, nil 534} 535 536func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 537 _, err := e.Exec( 538 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 539 pullState, 540 repoAt, 541 pullId, 542 models.PullDeleted, // only update state of non-deleted pulls 543 models.PullMerged, // only update state of non-merged pulls 544 ) 545 return err 546} 547 548func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 549 err := SetPullState(e, repoAt, pullId, models.PullClosed) 550 return err 551} 552 553func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 554 err := SetPullState(e, repoAt, pullId, models.PullOpen) 555 return err 556} 557 558func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 559 err := SetPullState(e, repoAt, pullId, models.PullMerged) 560 return err 561} 562 563func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 564 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 565 return err 566} 567 568func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error { 569 newRoundNumber := len(pull.Submissions) 570 _, err := e.Exec(` 571 insert into pull_submissions (pull_at, round_number, patch, source_rev) 572 values (?, ?, ?, ?) 573 `, pull.PullAt(), newRoundNumber, newPatch, sourceRev) 574 575 return err 576} 577 578func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 579 var conditions []string 580 var args []any 581 582 args = append(args, parentChangeId) 583 584 for _, filter := range filters { 585 conditions = append(conditions, filter.Condition()) 586 args = append(args, filter.Arg()...) 587 } 588 589 whereClause := "" 590 if conditions != nil { 591 whereClause = " where " + strings.Join(conditions, " and ") 592 } 593 594 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 595 _, err := e.Exec(query, args...) 596 597 return err 598} 599 600// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 601// otherwise submissions are immutable 602func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 603 var conditions []string 604 var args []any 605 606 args = append(args, sourceRev) 607 args = append(args, newPatch) 608 609 for _, filter := range filters { 610 conditions = append(conditions, filter.Condition()) 611 args = append(args, filter.Arg()...) 612 } 613 614 whereClause := "" 615 if conditions != nil { 616 whereClause = " where " + strings.Join(conditions, " and ") 617 } 618 619 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 620 _, err := e.Exec(query, args...) 621 622 return err 623} 624 625func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 626 row := e.QueryRow(` 627 select 628 count(case when state = ? then 1 end) as open_count, 629 count(case when state = ? then 1 end) as merged_count, 630 count(case when state = ? then 1 end) as closed_count, 631 count(case when state = ? then 1 end) as deleted_count 632 from pulls 633 where repo_at = ?`, 634 models.PullOpen, 635 models.PullMerged, 636 models.PullClosed, 637 models.PullDeleted, 638 repoAt, 639 ) 640 641 var count models.PullCount 642 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 643 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 644 } 645 646 return count, nil 647} 648 649// change-id parent-change-id 650// 651// 4 w ,-------- z (TOP) 652// 3 z <----',------- y 653// 2 y <-----',------ x 654// 1 x <------' nil (BOT) 655// 656// `w` is parent of none, so it is the top of the stack 657func GetStack(e Execer, stackId string) (models.Stack, error) { 658 unorderedPulls, err := GetPulls( 659 e, 660 FilterEq("stack_id", stackId), 661 FilterNotEq("state", models.PullDeleted), 662 ) 663 if err != nil { 664 return nil, err 665 } 666 // map of parent-change-id to pull 667 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 668 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 669 for _, p := range unorderedPulls { 670 changeIdMap[p.ChangeId] = p 671 if p.ParentChangeId != "" { 672 parentMap[p.ParentChangeId] = p 673 } 674 } 675 676 // the top of the stack is the pull that is not a parent of any pull 677 var topPull *models.Pull 678 for _, maybeTop := range unorderedPulls { 679 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 680 topPull = maybeTop 681 break 682 } 683 } 684 685 pulls := []*models.Pull{} 686 for { 687 pulls = append(pulls, topPull) 688 if topPull.ParentChangeId != "" { 689 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 690 topPull = next 691 } else { 692 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 693 } 694 } else { 695 break 696 } 697 } 698 699 return pulls, nil 700} 701 702func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 703 pulls, err := GetPulls( 704 e, 705 FilterEq("stack_id", stackId), 706 FilterEq("state", models.PullDeleted), 707 ) 708 if err != nil { 709 return nil, err 710 } 711 712 return pulls, nil 713}