forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "sort" 8 "strings" 9 "time" 10 11 "github.com/bluekeyes/go-gitdiff/gitdiff" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/patchutil" 14 "tangled.sh/tangled.sh/core/types" 15) 16 17type PullState int 18 19const ( 20 PullClosed PullState = iota 21 PullOpen 22 PullMerged 23) 24 25func (p PullState) String() string { 26 switch p { 27 case PullOpen: 28 return "open" 29 case PullMerged: 30 return "merged" 31 case PullClosed: 32 return "closed" 33 default: 34 return "closed" 35 } 36} 37 38func (p PullState) IsOpen() bool { 39 return p == PullOpen 40} 41func (p PullState) IsMerged() bool { 42 return p == PullMerged 43} 44func (p PullState) IsClosed() bool { 45 return p == PullClosed 46} 47 48type Pull struct { 49 // ids 50 ID int 51 PullId int 52 53 // at ids 54 RepoAt syntax.ATURI 55 OwnerDid string 56 Rkey string 57 PullAt syntax.ATURI 58 59 // content 60 Title string 61 Body string 62 TargetBranch string 63 State PullState 64 Submissions []*PullSubmission 65 66 // meta 67 Created time.Time 68 PullSource *PullSource 69 70 // optionally, populate this when querying for reverse mappings 71 Repo *Repo 72} 73 74type PullSource struct { 75 Branch string 76 RepoAt *syntax.ATURI 77 78 // optionally populate this for reverse mappings 79 Repo *Repo 80} 81 82type PullSubmission struct { 83 // ids 84 ID int 85 PullId int 86 87 // at ids 88 RepoAt syntax.ATURI 89 90 // content 91 RoundNumber int 92 Patch string 93 Comments []PullComment 94 SourceRev string // include the rev that was used to create this submission: only for branch PRs 95 96 // meta 97 Created time.Time 98} 99 100type PullComment struct { 101 // ids 102 ID int 103 PullId int 104 SubmissionId int 105 106 // at ids 107 RepoAt string 108 OwnerDid string 109 CommentAt string 110 111 // content 112 Body string 113 114 // meta 115 Created time.Time 116} 117 118func (p *Pull) LatestPatch() string { 119 latestSubmission := p.Submissions[p.LastRoundNumber()] 120 return latestSubmission.Patch 121} 122 123func (p *Pull) LastRoundNumber() int { 124 return len(p.Submissions) - 1 125} 126 127func (p *Pull) IsPatchBased() bool { 128 return p.PullSource == nil 129} 130 131func (p *Pull) IsBranchBased() bool { 132 if p.PullSource != nil { 133 if p.PullSource.RepoAt != nil { 134 return p.PullSource.RepoAt == &p.RepoAt 135 } else { 136 // no repo specified 137 return true 138 } 139 } 140 return false 141} 142 143func (p *Pull) IsForkBased() bool { 144 if p.PullSource != nil { 145 if p.PullSource.RepoAt != nil { 146 // make sure repos are different 147 return p.PullSource.RepoAt != &p.RepoAt 148 } 149 } 150 return false 151} 152 153func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) { 154 patch := s.Patch 155 156 // if format-patch; then extract each patch 157 var diffs []*gitdiff.File 158 if patchutil.IsFormatPatch(patch) { 159 patches, err := patchutil.ExtractPatches(patch) 160 if err != nil { 161 return nil, err 162 } 163 var ps [][]*gitdiff.File 164 for _, p := range patches { 165 ps = append(ps, p.Files) 166 } 167 168 diffs = patchutil.CombineDiff(ps...) 169 } else { 170 d, _, err := gitdiff.Parse(strings.NewReader(patch)) 171 if err != nil { 172 return nil, err 173 } 174 diffs = d 175 } 176 177 return diffs, nil 178} 179 180func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff { 181 diffs, err := s.AsDiff(targetBranch) 182 if err != nil { 183 log.Println(err) 184 } 185 186 nd := types.NiceDiff{} 187 nd.Commit.Parent = targetBranch 188 189 for _, d := range diffs { 190 ndiff := types.Diff{} 191 ndiff.Name.New = d.NewName 192 ndiff.Name.Old = d.OldName 193 ndiff.IsBinary = d.IsBinary 194 ndiff.IsNew = d.IsNew 195 ndiff.IsDelete = d.IsDelete 196 ndiff.IsCopy = d.IsCopy 197 ndiff.IsRename = d.IsRename 198 199 for _, tf := range d.TextFragments { 200 ndiff.TextFragments = append(ndiff.TextFragments, *tf) 201 for _, l := range tf.Lines { 202 switch l.Op { 203 case gitdiff.OpAdd: 204 nd.Stat.Insertions += 1 205 case gitdiff.OpDelete: 206 nd.Stat.Deletions += 1 207 } 208 } 209 } 210 211 nd.Diff = append(nd.Diff, ndiff) 212 } 213 214 nd.Stat.FilesChanged = len(diffs) 215 216 return nd 217} 218 219func (s PullSubmission) IsFormatPatch() bool { 220 return patchutil.IsFormatPatch(s.Patch) 221} 222 223func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch { 224 patches, err := patchutil.ExtractPatches(s.Patch) 225 if err != nil { 226 log.Println("error extracting patches from submission:", err) 227 return []patchutil.FormatPatch{} 228 } 229 230 return patches 231} 232 233func NewPull(tx *sql.Tx, pull *Pull) error { 234 defer tx.Rollback() 235 236 _, err := tx.Exec(` 237 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 238 values (?, 1) 239 `, pull.RepoAt) 240 if err != nil { 241 return err 242 } 243 244 var nextId int 245 err = tx.QueryRow(` 246 update repo_pull_seqs 247 set next_pull_id = next_pull_id + 1 248 where repo_at = ? 249 returning next_pull_id - 1 250 `, pull.RepoAt).Scan(&nextId) 251 if err != nil { 252 return err 253 } 254 255 pull.PullId = nextId 256 pull.State = PullOpen 257 258 var sourceBranch, sourceRepoAt *string 259 if pull.PullSource != nil { 260 sourceBranch = &pull.PullSource.Branch 261 if pull.PullSource.RepoAt != nil { 262 x := pull.PullSource.RepoAt.String() 263 sourceRepoAt = &x 264 } 265 } 266 267 _, err = tx.Exec( 268 ` 269 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at) 270 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 271 pull.RepoAt, 272 pull.OwnerDid, 273 pull.PullId, 274 pull.Title, 275 pull.TargetBranch, 276 pull.Body, 277 pull.Rkey, 278 pull.State, 279 sourceBranch, 280 sourceRepoAt, 281 ) 282 if err != nil { 283 return err 284 } 285 286 _, err = tx.Exec(` 287 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 288 values (?, ?, ?, ?, ?) 289 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 290 if err != nil { 291 return err 292 } 293 294 if err := tx.Commit(); err != nil { 295 return err 296 } 297 298 return nil 299} 300 301func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error { 302 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId) 303 return err 304} 305 306func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) { 307 var pullAt string 308 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt) 309 return pullAt, err 310} 311 312func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 313 var pullId int 314 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 315 return pullId - 1, err 316} 317 318func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) { 319 pulls := make(map[int]*Pull) 320 321 rows, err := e.Query(` 322 select 323 owner_did, 324 pull_id, 325 created, 326 title, 327 state, 328 target_branch, 329 pull_at, 330 body, 331 rkey, 332 source_branch, 333 source_repo_at 334 from 335 pulls 336 where 337 repo_at = ? and state = ?`, repoAt, state) 338 if err != nil { 339 return nil, err 340 } 341 defer rows.Close() 342 343 for rows.Next() { 344 var pull Pull 345 var createdAt string 346 var sourceBranch, sourceRepoAt sql.NullString 347 err := rows.Scan( 348 &pull.OwnerDid, 349 &pull.PullId, 350 &createdAt, 351 &pull.Title, 352 &pull.State, 353 &pull.TargetBranch, 354 &pull.PullAt, 355 &pull.Body, 356 &pull.Rkey, 357 &sourceBranch, 358 &sourceRepoAt, 359 ) 360 if err != nil { 361 return nil, err 362 } 363 364 createdTime, err := time.Parse(time.RFC3339, createdAt) 365 if err != nil { 366 return nil, err 367 } 368 pull.Created = createdTime 369 370 if sourceBranch.Valid { 371 pull.PullSource = &PullSource{ 372 Branch: sourceBranch.String, 373 } 374 if sourceRepoAt.Valid { 375 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 376 if err != nil { 377 return nil, err 378 } 379 pull.PullSource.RepoAt = &sourceRepoAtParsed 380 } 381 } 382 383 pulls[pull.PullId] = &pull 384 } 385 386 // get latest round no. for each pull 387 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 388 submissionsQuery := fmt.Sprintf(` 389 select 390 id, pull_id, round_number 391 from 392 pull_submissions 393 where 394 repo_at = ? and pull_id in (%s) 395 `, inClause) 396 397 args := make([]any, len(pulls)+1) 398 args[0] = repoAt.String() 399 idx := 1 400 for _, p := range pulls { 401 args[idx] = p.PullId 402 idx += 1 403 } 404 submissionsRows, err := e.Query(submissionsQuery, args...) 405 if err != nil { 406 return nil, err 407 } 408 defer submissionsRows.Close() 409 410 for submissionsRows.Next() { 411 var s PullSubmission 412 err := submissionsRows.Scan( 413 &s.ID, 414 &s.PullId, 415 &s.RoundNumber, 416 ) 417 if err != nil { 418 return nil, err 419 } 420 421 if p, ok := pulls[s.PullId]; ok { 422 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 423 p.Submissions[s.RoundNumber] = &s 424 } 425 } 426 if err := rows.Err(); err != nil { 427 return nil, err 428 } 429 430 // get comment count on latest submission on each pull 431 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 432 commentsQuery := fmt.Sprintf(` 433 select 434 count(id), pull_id 435 from 436 pull_comments 437 where 438 submission_id in (%s) 439 group by 440 submission_id 441 `, inClause) 442 443 args = []any{} 444 for _, p := range pulls { 445 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 446 } 447 commentsRows, err := e.Query(commentsQuery, args...) 448 if err != nil { 449 return nil, err 450 } 451 defer commentsRows.Close() 452 453 for commentsRows.Next() { 454 var commentCount, pullId int 455 err := commentsRows.Scan( 456 &commentCount, 457 &pullId, 458 ) 459 if err != nil { 460 return nil, err 461 } 462 if p, ok := pulls[pullId]; ok { 463 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 464 } 465 } 466 if err := rows.Err(); err != nil { 467 return nil, err 468 } 469 470 orderedByDate := []*Pull{} 471 for _, p := range pulls { 472 orderedByDate = append(orderedByDate, p) 473 } 474 sort.Slice(orderedByDate, func(i, j int) bool { 475 return orderedByDate[i].Created.After(orderedByDate[j].Created) 476 }) 477 478 return orderedByDate, nil 479} 480 481func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 482 query := ` 483 select 484 owner_did, 485 pull_id, 486 created, 487 title, 488 state, 489 target_branch, 490 pull_at, 491 repo_at, 492 body, 493 rkey, 494 source_branch, 495 source_repo_at 496 from 497 pulls 498 where 499 repo_at = ? and pull_id = ? 500 ` 501 row := e.QueryRow(query, repoAt, pullId) 502 503 var pull Pull 504 var createdAt string 505 var sourceBranch, sourceRepoAt sql.NullString 506 err := row.Scan( 507 &pull.OwnerDid, 508 &pull.PullId, 509 &createdAt, 510 &pull.Title, 511 &pull.State, 512 &pull.TargetBranch, 513 &pull.PullAt, 514 &pull.RepoAt, 515 &pull.Body, 516 &pull.Rkey, 517 &sourceBranch, 518 &sourceRepoAt, 519 ) 520 if err != nil { 521 return nil, err 522 } 523 524 createdTime, err := time.Parse(time.RFC3339, createdAt) 525 if err != nil { 526 return nil, err 527 } 528 pull.Created = createdTime 529 530 // populate source 531 if sourceBranch.Valid { 532 pull.PullSource = &PullSource{ 533 Branch: sourceBranch.String, 534 } 535 if sourceRepoAt.Valid { 536 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 537 if err != nil { 538 return nil, err 539 } 540 pull.PullSource.RepoAt = &sourceRepoAtParsed 541 } 542 } 543 544 submissionsQuery := ` 545 select 546 id, pull_id, repo_at, round_number, patch, created, source_rev 547 from 548 pull_submissions 549 where 550 repo_at = ? and pull_id = ? 551 ` 552 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 553 if err != nil { 554 return nil, err 555 } 556 defer submissionsRows.Close() 557 558 submissionsMap := make(map[int]*PullSubmission) 559 560 for submissionsRows.Next() { 561 var submission PullSubmission 562 var submissionCreatedStr string 563 var submissionSourceRev sql.NullString 564 err := submissionsRows.Scan( 565 &submission.ID, 566 &submission.PullId, 567 &submission.RepoAt, 568 &submission.RoundNumber, 569 &submission.Patch, 570 &submissionCreatedStr, 571 &submissionSourceRev, 572 ) 573 if err != nil { 574 return nil, err 575 } 576 577 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 578 if err != nil { 579 return nil, err 580 } 581 submission.Created = submissionCreatedTime 582 583 if submissionSourceRev.Valid { 584 submission.SourceRev = submissionSourceRev.String 585 } 586 587 submissionsMap[submission.ID] = &submission 588 } 589 if err = submissionsRows.Close(); err != nil { 590 return nil, err 591 } 592 if len(submissionsMap) == 0 { 593 return &pull, nil 594 } 595 596 var args []any 597 for k := range submissionsMap { 598 args = append(args, k) 599 } 600 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 601 commentsQuery := fmt.Sprintf(` 602 select 603 id, 604 pull_id, 605 submission_id, 606 repo_at, 607 owner_did, 608 comment_at, 609 body, 610 created 611 from 612 pull_comments 613 where 614 submission_id IN (%s) 615 order by 616 created asc 617 `, inClause) 618 commentsRows, err := e.Query(commentsQuery, args...) 619 if err != nil { 620 return nil, err 621 } 622 defer commentsRows.Close() 623 624 for commentsRows.Next() { 625 var comment PullComment 626 var commentCreatedStr string 627 err := commentsRows.Scan( 628 &comment.ID, 629 &comment.PullId, 630 &comment.SubmissionId, 631 &comment.RepoAt, 632 &comment.OwnerDid, 633 &comment.CommentAt, 634 &comment.Body, 635 &commentCreatedStr, 636 ) 637 if err != nil { 638 return nil, err 639 } 640 641 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 642 if err != nil { 643 return nil, err 644 } 645 comment.Created = commentCreatedTime 646 647 // Add the comment to its submission 648 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 649 submission.Comments = append(submission.Comments, comment) 650 } 651 652 } 653 if err = commentsRows.Err(); err != nil { 654 return nil, err 655 } 656 657 var pullSourceRepo *Repo 658 if pull.PullSource != nil { 659 if pull.PullSource.RepoAt != nil { 660 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 661 if err != nil { 662 log.Printf("failed to get repo by at uri: %v", err) 663 } else { 664 pull.PullSource.Repo = pullSourceRepo 665 } 666 } 667 } 668 669 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 670 for _, submission := range submissionsMap { 671 pull.Submissions[submission.RoundNumber] = submission 672 } 673 674 return &pull, nil 675} 676 677// timeframe here is directly passed into the sql query filter, and any 678// timeframe in the past should be negative; e.g.: "-3 months" 679func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 680 var pulls []Pull 681 682 rows, err := e.Query(` 683 select 684 p.owner_did, 685 p.repo_at, 686 p.pull_id, 687 p.created, 688 p.title, 689 p.state, 690 r.did, 691 r.name, 692 r.knot, 693 r.rkey, 694 r.created 695 from 696 pulls p 697 join 698 repos r on p.repo_at = r.at_uri 699 where 700 p.owner_did = ? and p.created >= date ('now', ?) 701 order by 702 p.created desc`, did, timeframe) 703 if err != nil { 704 return nil, err 705 } 706 defer rows.Close() 707 708 for rows.Next() { 709 var pull Pull 710 var repo Repo 711 var pullCreatedAt, repoCreatedAt string 712 err := rows.Scan( 713 &pull.OwnerDid, 714 &pull.RepoAt, 715 &pull.PullId, 716 &pullCreatedAt, 717 &pull.Title, 718 &pull.State, 719 &repo.Did, 720 &repo.Name, 721 &repo.Knot, 722 &repo.Rkey, 723 &repoCreatedAt, 724 ) 725 if err != nil { 726 return nil, err 727 } 728 729 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 730 if err != nil { 731 return nil, err 732 } 733 pull.Created = pullCreatedTime 734 735 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 736 if err != nil { 737 return nil, err 738 } 739 repo.Created = repoCreatedTime 740 741 pull.Repo = &repo 742 743 pulls = append(pulls, pull) 744 } 745 746 if err := rows.Err(); err != nil { 747 return nil, err 748 } 749 750 return pulls, nil 751} 752 753func NewPullComment(e Execer, comment *PullComment) (int64, error) { 754 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 755 res, err := e.Exec( 756 query, 757 comment.OwnerDid, 758 comment.RepoAt, 759 comment.SubmissionId, 760 comment.CommentAt, 761 comment.PullId, 762 comment.Body, 763 ) 764 if err != nil { 765 return 0, err 766 } 767 768 i, err := res.LastInsertId() 769 if err != nil { 770 return 0, err 771 } 772 773 return i, nil 774} 775 776func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 777 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId) 778 return err 779} 780 781func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 782 err := SetPullState(e, repoAt, pullId, PullClosed) 783 return err 784} 785 786func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 787 err := SetPullState(e, repoAt, pullId, PullOpen) 788 return err 789} 790 791func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 792 err := SetPullState(e, repoAt, pullId, PullMerged) 793 return err 794} 795 796func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 797 newRoundNumber := len(pull.Submissions) 798 _, err := e.Exec(` 799 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 800 values (?, ?, ?, ?, ?) 801 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 802 803 return err 804} 805 806type PullCount struct { 807 Open int 808 Merged int 809 Closed int 810} 811 812func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 813 row := e.QueryRow(` 814 select 815 count(case when state = ? then 1 end) as open_count, 816 count(case when state = ? then 1 end) as merged_count, 817 count(case when state = ? then 1 end) as closed_count 818 from pulls 819 where repo_at = ?`, 820 PullOpen, 821 PullMerged, 822 PullClosed, 823 repoAt, 824 ) 825 826 var count PullCount 827 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil { 828 return PullCount{0, 0, 0}, err 829 } 830 831 return count, nil 832}