forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "sort" 8 "strings" 9 "time" 10 11 "github.com/bluekeyes/go-gitdiff/gitdiff" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/patchutil" 14 "tangled.sh/tangled.sh/core/types" 15) 16 17type PullState int 18 19const ( 20 PullClosed PullState = iota 21 PullOpen 22 PullMerged 23) 24 25func (p PullState) String() string { 26 switch p { 27 case PullOpen: 28 return "open" 29 case PullMerged: 30 return "merged" 31 case PullClosed: 32 return "closed" 33 default: 34 return "closed" 35 } 36} 37 38func (p PullState) IsOpen() bool { 39 return p == PullOpen 40} 41func (p PullState) IsMerged() bool { 42 return p == PullMerged 43} 44func (p PullState) IsClosed() bool { 45 return p == PullClosed 46} 47 48type Pull struct { 49 // ids 50 ID int 51 PullId int 52 53 // at ids 54 RepoAt syntax.ATURI 55 OwnerDid string 56 Rkey string 57 PullAt syntax.ATURI 58 59 // content 60 Title string 61 Body string 62 TargetBranch string 63 State PullState 64 Submissions []*PullSubmission 65 66 // meta 67 Created time.Time 68 PullSource *PullSource 69 70 // optionally, populate this when querying for reverse mappings 71 Repo *Repo 72} 73 74type PullSource struct { 75 Branch string 76 RepoAt *syntax.ATURI 77 78 // optionally populate this for reverse mappings 79 Repo *Repo 80} 81 82type PullSubmission struct { 83 // ids 84 ID int 85 PullId int 86 87 // at ids 88 RepoAt syntax.ATURI 89 90 // content 91 RoundNumber int 92 Patch string 93 Comments []PullComment 94 SourceRev string // include the rev that was used to create this submission: only for branch PRs 95 96 // meta 97 Created time.Time 98} 99 100type PullComment struct { 101 // ids 102 ID int 103 PullId int 104 SubmissionId int 105 106 // at ids 107 RepoAt string 108 OwnerDid string 109 CommentAt string 110 111 // content 112 Body string 113 114 // meta 115 Created time.Time 116} 117 118func (p *Pull) LatestPatch() string { 119 latestSubmission := p.Submissions[p.LastRoundNumber()] 120 return latestSubmission.Patch 121} 122 123func (p *Pull) LastRoundNumber() int { 124 return len(p.Submissions) - 1 125} 126 127func (p *Pull) IsPatchBased() bool { 128 return p.PullSource == nil 129} 130 131func (p *Pull) IsBranchBased() bool { 132 if p.PullSource != nil { 133 if p.PullSource.RepoAt != nil { 134 return p.PullSource.RepoAt == &p.RepoAt 135 } else { 136 // no repo specified 137 return true 138 } 139 } 140 return false 141} 142 143func (p *Pull) IsForkBased() bool { 144 if p.PullSource != nil { 145 if p.PullSource.RepoAt != nil { 146 // make sure repos are different 147 return p.PullSource.RepoAt != &p.RepoAt 148 } 149 } 150 return false 151} 152 153func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff { 154 patch := s.Patch 155 156 diffs, _, err := gitdiff.Parse(strings.NewReader(patch)) 157 if err != nil { 158 log.Println(err) 159 } 160 161 nd := types.NiceDiff{} 162 nd.Commit.Parent = targetBranch 163 164 for _, d := range diffs { 165 ndiff := types.Diff{} 166 ndiff.Name.New = d.NewName 167 ndiff.Name.Old = d.OldName 168 ndiff.IsBinary = d.IsBinary 169 ndiff.IsNew = d.IsNew 170 ndiff.IsDelete = d.IsDelete 171 ndiff.IsCopy = d.IsCopy 172 ndiff.IsRename = d.IsRename 173 174 for _, tf := range d.TextFragments { 175 ndiff.TextFragments = append(ndiff.TextFragments, *tf) 176 for _, l := range tf.Lines { 177 switch l.Op { 178 case gitdiff.OpAdd: 179 nd.Stat.Insertions += 1 180 case gitdiff.OpDelete: 181 nd.Stat.Deletions += 1 182 } 183 } 184 } 185 186 nd.Diff = append(nd.Diff, ndiff) 187 } 188 189 nd.Stat.FilesChanged = len(diffs) 190 191 return nd 192} 193 194func (s PullSubmission) IsFormatPatch() bool { 195 return patchutil.IsFormatPatch(s.Patch) 196} 197 198func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch { 199 patches, err := patchutil.ExtractPatches(s.Patch) 200 if err != nil { 201 log.Println("error extracting patches from submission:", err) 202 return []patchutil.FormatPatch{} 203 } 204 205 return patches 206} 207 208func NewPull(tx *sql.Tx, pull *Pull) error { 209 defer tx.Rollback() 210 211 _, err := tx.Exec(` 212 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 213 values (?, 1) 214 `, pull.RepoAt) 215 if err != nil { 216 return err 217 } 218 219 var nextId int 220 err = tx.QueryRow(` 221 update repo_pull_seqs 222 set next_pull_id = next_pull_id + 1 223 where repo_at = ? 224 returning next_pull_id - 1 225 `, pull.RepoAt).Scan(&nextId) 226 if err != nil { 227 return err 228 } 229 230 pull.PullId = nextId 231 pull.State = PullOpen 232 233 var sourceBranch, sourceRepoAt *string 234 if pull.PullSource != nil { 235 sourceBranch = &pull.PullSource.Branch 236 if pull.PullSource.RepoAt != nil { 237 x := pull.PullSource.RepoAt.String() 238 sourceRepoAt = &x 239 } 240 } 241 242 _, err = tx.Exec( 243 ` 244 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at) 245 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 246 pull.RepoAt, 247 pull.OwnerDid, 248 pull.PullId, 249 pull.Title, 250 pull.TargetBranch, 251 pull.Body, 252 pull.Rkey, 253 pull.State, 254 sourceBranch, 255 sourceRepoAt, 256 ) 257 if err != nil { 258 return err 259 } 260 261 _, err = tx.Exec(` 262 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 263 values (?, ?, ?, ?, ?) 264 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 265 if err != nil { 266 return err 267 } 268 269 if err := tx.Commit(); err != nil { 270 return err 271 } 272 273 return nil 274} 275 276func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error { 277 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId) 278 return err 279} 280 281func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) { 282 var pullAt string 283 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt) 284 return pullAt, err 285} 286 287func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 288 var pullId int 289 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 290 return pullId - 1, err 291} 292 293func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) { 294 pulls := make(map[int]*Pull) 295 296 rows, err := e.Query(` 297 select 298 owner_did, 299 pull_id, 300 created, 301 title, 302 state, 303 target_branch, 304 pull_at, 305 body, 306 rkey, 307 source_branch, 308 source_repo_at 309 from 310 pulls 311 where 312 repo_at = ? and state = ?`, repoAt, state) 313 if err != nil { 314 return nil, err 315 } 316 defer rows.Close() 317 318 for rows.Next() { 319 var pull Pull 320 var createdAt string 321 var sourceBranch, sourceRepoAt sql.NullString 322 err := rows.Scan( 323 &pull.OwnerDid, 324 &pull.PullId, 325 &createdAt, 326 &pull.Title, 327 &pull.State, 328 &pull.TargetBranch, 329 &pull.PullAt, 330 &pull.Body, 331 &pull.Rkey, 332 &sourceBranch, 333 &sourceRepoAt, 334 ) 335 if err != nil { 336 return nil, err 337 } 338 339 createdTime, err := time.Parse(time.RFC3339, createdAt) 340 if err != nil { 341 return nil, err 342 } 343 pull.Created = createdTime 344 345 if sourceBranch.Valid { 346 pull.PullSource = &PullSource{ 347 Branch: sourceBranch.String, 348 } 349 if sourceRepoAt.Valid { 350 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 351 if err != nil { 352 return nil, err 353 } 354 pull.PullSource.RepoAt = &sourceRepoAtParsed 355 } 356 } 357 358 pulls[pull.PullId] = &pull 359 } 360 361 // get latest round no. for each pull 362 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 363 submissionsQuery := fmt.Sprintf(` 364 select 365 id, pull_id, round_number 366 from 367 pull_submissions 368 where 369 repo_at = ? and pull_id in (%s) 370 `, inClause) 371 372 args := make([]any, len(pulls)+1) 373 args[0] = repoAt.String() 374 idx := 1 375 for _, p := range pulls { 376 args[idx] = p.PullId 377 idx += 1 378 } 379 submissionsRows, err := e.Query(submissionsQuery, args...) 380 if err != nil { 381 return nil, err 382 } 383 defer submissionsRows.Close() 384 385 for submissionsRows.Next() { 386 var s PullSubmission 387 err := submissionsRows.Scan( 388 &s.ID, 389 &s.PullId, 390 &s.RoundNumber, 391 ) 392 if err != nil { 393 return nil, err 394 } 395 396 if p, ok := pulls[s.PullId]; ok { 397 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 398 p.Submissions[s.RoundNumber] = &s 399 } 400 } 401 if err := rows.Err(); err != nil { 402 return nil, err 403 } 404 405 // get comment count on latest submission on each pull 406 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 407 commentsQuery := fmt.Sprintf(` 408 select 409 count(id), pull_id 410 from 411 pull_comments 412 where 413 submission_id in (%s) 414 group by 415 submission_id 416 `, inClause) 417 418 args = []any{} 419 for _, p := range pulls { 420 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 421 } 422 commentsRows, err := e.Query(commentsQuery, args...) 423 if err != nil { 424 return nil, err 425 } 426 defer commentsRows.Close() 427 428 for commentsRows.Next() { 429 var commentCount, pullId int 430 err := commentsRows.Scan( 431 &commentCount, 432 &pullId, 433 ) 434 if err != nil { 435 return nil, err 436 } 437 if p, ok := pulls[pullId]; ok { 438 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 439 } 440 } 441 if err := rows.Err(); err != nil { 442 return nil, err 443 } 444 445 orderedByDate := []*Pull{} 446 for _, p := range pulls { 447 orderedByDate = append(orderedByDate, p) 448 } 449 sort.Slice(orderedByDate, func(i, j int) bool { 450 return orderedByDate[i].Created.After(orderedByDate[j].Created) 451 }) 452 453 return orderedByDate, nil 454} 455 456func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 457 query := ` 458 select 459 owner_did, 460 pull_id, 461 created, 462 title, 463 state, 464 target_branch, 465 pull_at, 466 repo_at, 467 body, 468 rkey, 469 source_branch, 470 source_repo_at 471 from 472 pulls 473 where 474 repo_at = ? and pull_id = ? 475 ` 476 row := e.QueryRow(query, repoAt, pullId) 477 478 var pull Pull 479 var createdAt string 480 var sourceBranch, sourceRepoAt sql.NullString 481 err := row.Scan( 482 &pull.OwnerDid, 483 &pull.PullId, 484 &createdAt, 485 &pull.Title, 486 &pull.State, 487 &pull.TargetBranch, 488 &pull.PullAt, 489 &pull.RepoAt, 490 &pull.Body, 491 &pull.Rkey, 492 &sourceBranch, 493 &sourceRepoAt, 494 ) 495 if err != nil { 496 return nil, err 497 } 498 499 createdTime, err := time.Parse(time.RFC3339, createdAt) 500 if err != nil { 501 return nil, err 502 } 503 pull.Created = createdTime 504 505 // populate source 506 if sourceBranch.Valid { 507 pull.PullSource = &PullSource{ 508 Branch: sourceBranch.String, 509 } 510 if sourceRepoAt.Valid { 511 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 512 if err != nil { 513 return nil, err 514 } 515 pull.PullSource.RepoAt = &sourceRepoAtParsed 516 } 517 } 518 519 submissionsQuery := ` 520 select 521 id, pull_id, repo_at, round_number, patch, created, source_rev 522 from 523 pull_submissions 524 where 525 repo_at = ? and pull_id = ? 526 ` 527 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 528 if err != nil { 529 return nil, err 530 } 531 defer submissionsRows.Close() 532 533 submissionsMap := make(map[int]*PullSubmission) 534 535 for submissionsRows.Next() { 536 var submission PullSubmission 537 var submissionCreatedStr string 538 var submissionSourceRev sql.NullString 539 err := submissionsRows.Scan( 540 &submission.ID, 541 &submission.PullId, 542 &submission.RepoAt, 543 &submission.RoundNumber, 544 &submission.Patch, 545 &submissionCreatedStr, 546 &submissionSourceRev, 547 ) 548 if err != nil { 549 return nil, err 550 } 551 552 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 553 if err != nil { 554 return nil, err 555 } 556 submission.Created = submissionCreatedTime 557 558 if submissionSourceRev.Valid { 559 submission.SourceRev = submissionSourceRev.String 560 } 561 562 submissionsMap[submission.ID] = &submission 563 } 564 if err = submissionsRows.Close(); err != nil { 565 return nil, err 566 } 567 if len(submissionsMap) == 0 { 568 return &pull, nil 569 } 570 571 var args []any 572 for k := range submissionsMap { 573 args = append(args, k) 574 } 575 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 576 commentsQuery := fmt.Sprintf(` 577 select 578 id, 579 pull_id, 580 submission_id, 581 repo_at, 582 owner_did, 583 comment_at, 584 body, 585 created 586 from 587 pull_comments 588 where 589 submission_id IN (%s) 590 order by 591 created asc 592 `, inClause) 593 commentsRows, err := e.Query(commentsQuery, args...) 594 if err != nil { 595 return nil, err 596 } 597 defer commentsRows.Close() 598 599 for commentsRows.Next() { 600 var comment PullComment 601 var commentCreatedStr string 602 err := commentsRows.Scan( 603 &comment.ID, 604 &comment.PullId, 605 &comment.SubmissionId, 606 &comment.RepoAt, 607 &comment.OwnerDid, 608 &comment.CommentAt, 609 &comment.Body, 610 &commentCreatedStr, 611 ) 612 if err != nil { 613 return nil, err 614 } 615 616 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 617 if err != nil { 618 return nil, err 619 } 620 comment.Created = commentCreatedTime 621 622 // Add the comment to its submission 623 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 624 submission.Comments = append(submission.Comments, comment) 625 } 626 627 } 628 if err = commentsRows.Err(); err != nil { 629 return nil, err 630 } 631 632 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 633 for _, submission := range submissionsMap { 634 pull.Submissions[submission.RoundNumber] = submission 635 } 636 637 return &pull, nil 638} 639 640// timeframe here is directly passed into the sql query filter, and any 641// timeframe in the past should be negative; e.g.: "-3 months" 642func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 643 var pulls []Pull 644 645 rows, err := e.Query(` 646 select 647 p.owner_did, 648 p.repo_at, 649 p.pull_id, 650 p.created, 651 p.title, 652 p.state, 653 r.did, 654 r.name, 655 r.knot, 656 r.rkey, 657 r.created 658 from 659 pulls p 660 join 661 repos r on p.repo_at = r.at_uri 662 where 663 p.owner_did = ? and p.created >= date ('now', ?) 664 order by 665 p.created desc`, did, timeframe) 666 if err != nil { 667 return nil, err 668 } 669 defer rows.Close() 670 671 for rows.Next() { 672 var pull Pull 673 var repo Repo 674 var pullCreatedAt, repoCreatedAt string 675 err := rows.Scan( 676 &pull.OwnerDid, 677 &pull.RepoAt, 678 &pull.PullId, 679 &pullCreatedAt, 680 &pull.Title, 681 &pull.State, 682 &repo.Did, 683 &repo.Name, 684 &repo.Knot, 685 &repo.Rkey, 686 &repoCreatedAt, 687 ) 688 if err != nil { 689 return nil, err 690 } 691 692 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 693 if err != nil { 694 return nil, err 695 } 696 pull.Created = pullCreatedTime 697 698 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 699 if err != nil { 700 return nil, err 701 } 702 repo.Created = repoCreatedTime 703 704 pull.Repo = &repo 705 706 pulls = append(pulls, pull) 707 } 708 709 if err := rows.Err(); err != nil { 710 return nil, err 711 } 712 713 return pulls, nil 714} 715 716func NewPullComment(e Execer, comment *PullComment) (int64, error) { 717 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 718 res, err := e.Exec( 719 query, 720 comment.OwnerDid, 721 comment.RepoAt, 722 comment.SubmissionId, 723 comment.CommentAt, 724 comment.PullId, 725 comment.Body, 726 ) 727 if err != nil { 728 return 0, err 729 } 730 731 i, err := res.LastInsertId() 732 if err != nil { 733 return 0, err 734 } 735 736 return i, nil 737} 738 739func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 740 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId) 741 return err 742} 743 744func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 745 err := SetPullState(e, repoAt, pullId, PullClosed) 746 return err 747} 748 749func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 750 err := SetPullState(e, repoAt, pullId, PullOpen) 751 return err 752} 753 754func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 755 err := SetPullState(e, repoAt, pullId, PullMerged) 756 return err 757} 758 759func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 760 newRoundNumber := len(pull.Submissions) 761 _, err := e.Exec(` 762 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 763 values (?, ?, ?, ?, ?) 764 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 765 766 return err 767} 768 769type PullCount struct { 770 Open int 771 Merged int 772 Closed int 773} 774 775func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 776 row := e.QueryRow(` 777 select 778 count(case when state = ? then 1 end) as open_count, 779 count(case when state = ? then 1 end) as merged_count, 780 count(case when state = ? then 1 end) as closed_count 781 from pulls 782 where repo_at = ?`, 783 PullOpen, 784 PullMerged, 785 PullClosed, 786 repoAt, 787 ) 788 789 var count PullCount 790 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil { 791 return PullCount{0, 0, 0}, err 792 } 793 794 return count, nil 795}