forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "maps" 7 "slices" 8 "sort" 9 "strings" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.org/core/appview/models" 14) 15 16func NewPull(tx *sql.Tx, pull *models.Pull) error { 17 _, err := tx.Exec(` 18 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 19 values (?, 1) 20 `, pull.RepoAt) 21 if err != nil { 22 return err 23 } 24 25 var nextId int 26 err = tx.QueryRow(` 27 update repo_pull_seqs 28 set next_pull_id = next_pull_id + 1 29 where repo_at = ? 30 returning next_pull_id - 1 31 `, pull.RepoAt).Scan(&nextId) 32 if err != nil { 33 return err 34 } 35 36 pull.PullId = nextId 37 pull.State = models.PullOpen 38 39 var sourceBranch, sourceRepoAt *string 40 if pull.PullSource != nil { 41 sourceBranch = &pull.PullSource.Branch 42 if pull.PullSource.RepoAt != nil { 43 x := pull.PullSource.RepoAt.String() 44 sourceRepoAt = &x 45 } 46 } 47 48 var stackId, changeId, parentChangeId *string 49 if pull.StackId != "" { 50 stackId = &pull.StackId 51 } 52 if pull.ChangeId != "" { 53 changeId = &pull.ChangeId 54 } 55 if pull.ParentChangeId != "" { 56 parentChangeId = &pull.ParentChangeId 57 } 58 59 result, err := tx.Exec( 60 ` 61 insert into pulls ( 62 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 63 ) 64 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 65 pull.RepoAt, 66 pull.OwnerDid, 67 pull.PullId, 68 pull.Title, 69 pull.TargetBranch, 70 pull.Body, 71 pull.Rkey, 72 pull.State, 73 sourceBranch, 74 sourceRepoAt, 75 stackId, 76 changeId, 77 parentChangeId, 78 ) 79 if err != nil { 80 return err 81 } 82 83 // Set the database primary key ID 84 id, err := result.LastInsertId() 85 if err != nil { 86 return err 87 } 88 pull.ID = int(id) 89 90 _, err = tx.Exec(` 91 insert into pull_submissions (pull_at, round_number, patch, source_rev) 92 values (?, ?, ?, ?) 93 `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 94 return err 95} 96 97func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 98 pull, err := GetPull(e, repoAt, pullId) 99 if err != nil { 100 return "", err 101 } 102 return pull.PullAt(), err 103} 104 105func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 106 var pullId int 107 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 108 return pullId - 1, err 109} 110 111func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 112 pulls := make(map[syntax.ATURI]*models.Pull) 113 114 var conditions []string 115 var args []any 116 for _, filter := range filters { 117 conditions = append(conditions, filter.Condition()) 118 args = append(args, filter.Arg()...) 119 } 120 121 whereClause := "" 122 if conditions != nil { 123 whereClause = " where " + strings.Join(conditions, " and ") 124 } 125 limitClause := "" 126 if limit != 0 { 127 limitClause = fmt.Sprintf(" limit %d ", limit) 128 } 129 130 query := fmt.Sprintf(` 131 select 132 id, 133 owner_did, 134 repo_at, 135 pull_id, 136 created, 137 title, 138 state, 139 target_branch, 140 body, 141 rkey, 142 source_branch, 143 source_repo_at, 144 stack_id, 145 change_id, 146 parent_change_id 147 from 148 pulls 149 %s 150 order by 151 created desc 152 %s 153 `, whereClause, limitClause) 154 155 rows, err := e.Query(query, args...) 156 if err != nil { 157 return nil, err 158 } 159 defer rows.Close() 160 161 for rows.Next() { 162 var pull models.Pull 163 var createdAt string 164 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 165 err := rows.Scan( 166 &pull.ID, 167 &pull.OwnerDid, 168 &pull.RepoAt, 169 &pull.PullId, 170 &createdAt, 171 &pull.Title, 172 &pull.State, 173 &pull.TargetBranch, 174 &pull.Body, 175 &pull.Rkey, 176 &sourceBranch, 177 &sourceRepoAt, 178 &stackId, 179 &changeId, 180 &parentChangeId, 181 ) 182 if err != nil { 183 return nil, err 184 } 185 186 createdTime, err := time.Parse(time.RFC3339, createdAt) 187 if err != nil { 188 return nil, err 189 } 190 pull.Created = createdTime 191 192 if sourceBranch.Valid { 193 pull.PullSource = &models.PullSource{ 194 Branch: sourceBranch.String, 195 } 196 if sourceRepoAt.Valid { 197 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 198 if err != nil { 199 return nil, err 200 } 201 pull.PullSource.RepoAt = &sourceRepoAtParsed 202 } 203 } 204 205 if stackId.Valid { 206 pull.StackId = stackId.String 207 } 208 if changeId.Valid { 209 pull.ChangeId = changeId.String 210 } 211 if parentChangeId.Valid { 212 pull.ParentChangeId = parentChangeId.String 213 } 214 215 pulls[pull.PullAt()] = &pull 216 } 217 218 var pullAts []syntax.ATURI 219 for _, p := range pulls { 220 pullAts = append(pullAts, p.PullAt()) 221 } 222 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) 223 if err != nil { 224 return nil, fmt.Errorf("failed to get submissions: %w", err) 225 } 226 227 for pullAt, submissions := range submissionsMap { 228 if p, ok := pulls[pullAt]; ok { 229 p.Submissions = submissions 230 } 231 } 232 233 orderedByPullId := []*models.Pull{} 234 for _, p := range pulls { 235 orderedByPullId = append(orderedByPullId, p) 236 } 237 sort.Slice(orderedByPullId, func(i, j int) bool { 238 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 239 }) 240 241 return orderedByPullId, nil 242} 243 244func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 245 return GetPullsWithLimit(e, 0, filters...) 246} 247 248func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 249 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) 250 if err != nil { 251 return nil, err 252 } 253 if pulls == nil { 254 return nil, sql.ErrNoRows 255 } 256 257 return pulls[0], nil 258} 259 260// mapping from pull -> pull submissions 261func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 262 var conditions []string 263 var args []any 264 for _, filter := range filters { 265 conditions = append(conditions, filter.Condition()) 266 args = append(args, filter.Arg()...) 267 } 268 269 whereClause := "" 270 if conditions != nil { 271 whereClause = " where " + strings.Join(conditions, " and ") 272 } 273 274 query := fmt.Sprintf(` 275 select 276 id, 277 pull_at, 278 round_number, 279 patch, 280 created, 281 source_rev 282 from 283 pull_submissions 284 %s 285 order by 286 round_number asc 287 `, whereClause) 288 289 rows, err := e.Query(query, args...) 290 if err != nil { 291 return nil, err 292 } 293 defer rows.Close() 294 295 submissionMap := make(map[int]*models.PullSubmission) 296 297 for rows.Next() { 298 var submission models.PullSubmission 299 var createdAt string 300 var sourceRev sql.NullString 301 err := rows.Scan( 302 &submission.ID, 303 &submission.PullAt, 304 &submission.RoundNumber, 305 &submission.Patch, 306 &createdAt, 307 &sourceRev, 308 ) 309 if err != nil { 310 return nil, err 311 } 312 313 createdTime, err := time.Parse(time.RFC3339, createdAt) 314 if err != nil { 315 return nil, err 316 } 317 submission.Created = createdTime 318 319 if sourceRev.Valid { 320 submission.SourceRev = sourceRev.String 321 } 322 323 submissionMap[submission.ID] = &submission 324 } 325 326 if err := rows.Err(); err != nil { 327 return nil, err 328 } 329 330 // Get comments for all submissions using GetPullComments 331 submissionIds := slices.Collect(maps.Keys(submissionMap)) 332 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) 333 if err != nil { 334 return nil, err 335 } 336 for _, comment := range comments { 337 if submission, ok := submissionMap[comment.SubmissionId]; ok { 338 submission.Comments = append(submission.Comments, comment) 339 } 340 } 341 342 // order the submissions by pull_at 343 m := make(map[syntax.ATURI][]*models.PullSubmission) 344 for _, s := range submissionMap { 345 m[s.PullAt] = append(m[s.PullAt], s) 346 } 347 348 return m, nil 349} 350 351func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { 352 var conditions []string 353 var args []any 354 for _, filter := range filters { 355 conditions = append(conditions, filter.Condition()) 356 args = append(args, filter.Arg()...) 357 } 358 359 whereClause := "" 360 if conditions != nil { 361 whereClause = " where " + strings.Join(conditions, " and ") 362 } 363 364 query := fmt.Sprintf(` 365 select 366 id, 367 pull_id, 368 submission_id, 369 repo_at, 370 owner_did, 371 comment_at, 372 body, 373 created 374 from 375 pull_comments 376 %s 377 order by 378 created asc 379 `, whereClause) 380 381 rows, err := e.Query(query, args...) 382 if err != nil { 383 return nil, err 384 } 385 defer rows.Close() 386 387 var comments []models.PullComment 388 for rows.Next() { 389 var comment models.PullComment 390 var createdAt string 391 err := rows.Scan( 392 &comment.ID, 393 &comment.PullId, 394 &comment.SubmissionId, 395 &comment.RepoAt, 396 &comment.OwnerDid, 397 &comment.CommentAt, 398 &comment.Body, 399 &createdAt, 400 ) 401 if err != nil { 402 return nil, err 403 } 404 405 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 406 comment.Created = t 407 } 408 409 comments = append(comments, comment) 410 } 411 412 if err := rows.Err(); err != nil { 413 return nil, err 414 } 415 416 return comments, nil 417} 418 419// timeframe here is directly passed into the sql query filter, and any 420// timeframe in the past should be negative; e.g.: "-3 months" 421func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 422 var pulls []models.Pull 423 424 rows, err := e.Query(` 425 select 426 p.owner_did, 427 p.repo_at, 428 p.pull_id, 429 p.created, 430 p.title, 431 p.state, 432 r.did, 433 r.name, 434 r.knot, 435 r.rkey, 436 r.created 437 from 438 pulls p 439 join 440 repos r on p.repo_at = r.at_uri 441 where 442 p.owner_did = ? and p.created >= date ('now', ?) 443 order by 444 p.created desc`, did, timeframe) 445 if err != nil { 446 return nil, err 447 } 448 defer rows.Close() 449 450 for rows.Next() { 451 var pull models.Pull 452 var repo models.Repo 453 var pullCreatedAt, repoCreatedAt string 454 err := rows.Scan( 455 &pull.OwnerDid, 456 &pull.RepoAt, 457 &pull.PullId, 458 &pullCreatedAt, 459 &pull.Title, 460 &pull.State, 461 &repo.Did, 462 &repo.Name, 463 &repo.Knot, 464 &repo.Rkey, 465 &repoCreatedAt, 466 ) 467 if err != nil { 468 return nil, err 469 } 470 471 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 472 if err != nil { 473 return nil, err 474 } 475 pull.Created = pullCreatedTime 476 477 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 478 if err != nil { 479 return nil, err 480 } 481 repo.Created = repoCreatedTime 482 483 pull.Repo = &repo 484 485 pulls = append(pulls, pull) 486 } 487 488 if err := rows.Err(); err != nil { 489 return nil, err 490 } 491 492 return pulls, nil 493} 494 495func NewPullComment(e Execer, comment *models.PullComment) (int64, error) { 496 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 497 res, err := e.Exec( 498 query, 499 comment.OwnerDid, 500 comment.RepoAt, 501 comment.SubmissionId, 502 comment.CommentAt, 503 comment.PullId, 504 comment.Body, 505 ) 506 if err != nil { 507 return 0, err 508 } 509 510 i, err := res.LastInsertId() 511 if err != nil { 512 return 0, err 513 } 514 515 return i, nil 516} 517 518func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 519 _, err := e.Exec( 520 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 521 pullState, 522 repoAt, 523 pullId, 524 models.PullDeleted, // only update state of non-deleted pulls 525 models.PullMerged, // only update state of non-merged pulls 526 ) 527 return err 528} 529 530func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 531 err := SetPullState(e, repoAt, pullId, models.PullClosed) 532 return err 533} 534 535func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 536 err := SetPullState(e, repoAt, pullId, models.PullOpen) 537 return err 538} 539 540func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 541 err := SetPullState(e, repoAt, pullId, models.PullMerged) 542 return err 543} 544 545func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 546 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 547 return err 548} 549 550func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error { 551 newRoundNumber := len(pull.Submissions) 552 _, err := e.Exec(` 553 insert into pull_submissions (pull_at, round_number, patch, source_rev) 554 values (?, ?, ?, ?) 555 `, pull.PullAt(), newRoundNumber, newPatch, sourceRev) 556 557 return err 558} 559 560func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 561 var conditions []string 562 var args []any 563 564 args = append(args, parentChangeId) 565 566 for _, filter := range filters { 567 conditions = append(conditions, filter.Condition()) 568 args = append(args, filter.Arg()...) 569 } 570 571 whereClause := "" 572 if conditions != nil { 573 whereClause = " where " + strings.Join(conditions, " and ") 574 } 575 576 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 577 _, err := e.Exec(query, args...) 578 579 return err 580} 581 582// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 583// otherwise submissions are immutable 584func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 585 var conditions []string 586 var args []any 587 588 args = append(args, sourceRev) 589 args = append(args, newPatch) 590 591 for _, filter := range filters { 592 conditions = append(conditions, filter.Condition()) 593 args = append(args, filter.Arg()...) 594 } 595 596 whereClause := "" 597 if conditions != nil { 598 whereClause = " where " + strings.Join(conditions, " and ") 599 } 600 601 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 602 _, err := e.Exec(query, args...) 603 604 return err 605} 606 607func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 608 row := e.QueryRow(` 609 select 610 count(case when state = ? then 1 end) as open_count, 611 count(case when state = ? then 1 end) as merged_count, 612 count(case when state = ? then 1 end) as closed_count, 613 count(case when state = ? then 1 end) as deleted_count 614 from pulls 615 where repo_at = ?`, 616 models.PullOpen, 617 models.PullMerged, 618 models.PullClosed, 619 models.PullDeleted, 620 repoAt, 621 ) 622 623 var count models.PullCount 624 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 625 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 626 } 627 628 return count, nil 629} 630 631// change-id parent-change-id 632// 633// 4 w ,-------- z (TOP) 634// 3 z <----',------- y 635// 2 y <-----',------ x 636// 1 x <------' nil (BOT) 637// 638// `w` is parent of none, so it is the top of the stack 639func GetStack(e Execer, stackId string) (models.Stack, error) { 640 unorderedPulls, err := GetPulls( 641 e, 642 FilterEq("stack_id", stackId), 643 FilterNotEq("state", models.PullDeleted), 644 ) 645 if err != nil { 646 return nil, err 647 } 648 // map of parent-change-id to pull 649 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 650 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 651 for _, p := range unorderedPulls { 652 changeIdMap[p.ChangeId] = p 653 if p.ParentChangeId != "" { 654 parentMap[p.ParentChangeId] = p 655 } 656 } 657 658 // the top of the stack is the pull that is not a parent of any pull 659 var topPull *models.Pull 660 for _, maybeTop := range unorderedPulls { 661 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 662 topPull = maybeTop 663 break 664 } 665 } 666 667 pulls := []*models.Pull{} 668 for { 669 pulls = append(pulls, topPull) 670 if topPull.ParentChangeId != "" { 671 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 672 topPull = next 673 } else { 674 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 675 } 676 } else { 677 break 678 } 679 } 680 681 return pulls, nil 682} 683 684func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 685 pulls, err := GetPulls( 686 e, 687 FilterEq("stack_id", stackId), 688 FilterEq("state", models.PullDeleted), 689 ) 690 if err != nil { 691 return nil, err 692 } 693 694 return pulls, nil 695}