forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "sort" 8 "strings" 9 "time" 10 11 "github.com/bluekeyes/go-gitdiff/gitdiff" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/api/tangled" 14 "tangled.sh/tangled.sh/core/patchutil" 15 "tangled.sh/tangled.sh/core/types" 16) 17 18type PullState int 19 20const ( 21 PullClosed PullState = iota 22 PullOpen 23 PullMerged 24) 25 26func (p PullState) String() string { 27 switch p { 28 case PullOpen: 29 return "open" 30 case PullMerged: 31 return "merged" 32 case PullClosed: 33 return "closed" 34 default: 35 return "closed" 36 } 37} 38 39func (p PullState) IsOpen() bool { 40 return p == PullOpen 41} 42func (p PullState) IsMerged() bool { 43 return p == PullMerged 44} 45func (p PullState) IsClosed() bool { 46 return p == PullClosed 47} 48 49type Pull struct { 50 // ids 51 ID int 52 PullId int 53 54 // at ids 55 RepoAt syntax.ATURI 56 OwnerDid string 57 Rkey string 58 59 // content 60 Title string 61 Body string 62 TargetBranch string 63 State PullState 64 Submissions []*PullSubmission 65 66 // meta 67 Created time.Time 68 PullSource *PullSource 69 70 // optionally, populate this when querying for reverse mappings 71 Repo *Repo 72} 73 74type PullSource struct { 75 Branch string 76 RepoAt *syntax.ATURI 77 78 // optionally populate this for reverse mappings 79 Repo *Repo 80} 81 82type PullSubmission struct { 83 // ids 84 ID int 85 PullId int 86 87 // at ids 88 RepoAt syntax.ATURI 89 90 // content 91 RoundNumber int 92 Patch string 93 Comments []PullComment 94 SourceRev string // include the rev that was used to create this submission: only for branch PRs 95 96 // meta 97 Created time.Time 98} 99 100type PullComment struct { 101 // ids 102 ID int 103 PullId int 104 SubmissionId int 105 106 // at ids 107 RepoAt string 108 OwnerDid string 109 CommentAt string 110 111 // content 112 Body string 113 114 // meta 115 Created time.Time 116} 117 118func (p *Pull) LatestPatch() string { 119 latestSubmission := p.Submissions[p.LastRoundNumber()] 120 return latestSubmission.Patch 121} 122 123func (p *Pull) PullAt() syntax.ATURI { 124 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 125} 126 127func (p *Pull) LastRoundNumber() int { 128 return len(p.Submissions) - 1 129} 130 131func (p *Pull) IsPatchBased() bool { 132 return p.PullSource == nil 133} 134 135func (p *Pull) IsBranchBased() bool { 136 if p.PullSource != nil { 137 if p.PullSource.RepoAt != nil { 138 return p.PullSource.RepoAt == &p.RepoAt 139 } else { 140 // no repo specified 141 return true 142 } 143 } 144 return false 145} 146 147func (p *Pull) IsForkBased() bool { 148 if p.PullSource != nil { 149 if p.PullSource.RepoAt != nil { 150 // make sure repos are different 151 return p.PullSource.RepoAt != &p.RepoAt 152 } 153 } 154 return false 155} 156 157func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) { 158 patch := s.Patch 159 160 // if format-patch; then extract each patch 161 var diffs []*gitdiff.File 162 if patchutil.IsFormatPatch(patch) { 163 patches, err := patchutil.ExtractPatches(patch) 164 if err != nil { 165 return nil, err 166 } 167 var ps [][]*gitdiff.File 168 for _, p := range patches { 169 ps = append(ps, p.Files) 170 } 171 172 diffs = patchutil.CombineDiff(ps...) 173 } else { 174 d, _, err := gitdiff.Parse(strings.NewReader(patch)) 175 if err != nil { 176 return nil, err 177 } 178 diffs = d 179 } 180 181 return diffs, nil 182} 183 184func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff { 185 diffs, err := s.AsDiff(targetBranch) 186 if err != nil { 187 log.Println(err) 188 } 189 190 nd := types.NiceDiff{} 191 nd.Commit.Parent = targetBranch 192 193 for _, d := range diffs { 194 ndiff := types.Diff{} 195 ndiff.Name.New = d.NewName 196 ndiff.Name.Old = d.OldName 197 ndiff.IsBinary = d.IsBinary 198 ndiff.IsNew = d.IsNew 199 ndiff.IsDelete = d.IsDelete 200 ndiff.IsCopy = d.IsCopy 201 ndiff.IsRename = d.IsRename 202 203 for _, tf := range d.TextFragments { 204 ndiff.TextFragments = append(ndiff.TextFragments, *tf) 205 for _, l := range tf.Lines { 206 switch l.Op { 207 case gitdiff.OpAdd: 208 nd.Stat.Insertions += 1 209 case gitdiff.OpDelete: 210 nd.Stat.Deletions += 1 211 } 212 } 213 } 214 215 nd.Diff = append(nd.Diff, ndiff) 216 } 217 218 nd.Stat.FilesChanged = len(diffs) 219 220 return nd 221} 222 223func (s PullSubmission) IsFormatPatch() bool { 224 return patchutil.IsFormatPatch(s.Patch) 225} 226 227func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch { 228 patches, err := patchutil.ExtractPatches(s.Patch) 229 if err != nil { 230 log.Println("error extracting patches from submission:", err) 231 return []patchutil.FormatPatch{} 232 } 233 234 return patches 235} 236 237func NewPull(tx *sql.Tx, pull *Pull) error { 238 defer tx.Rollback() 239 240 _, err := tx.Exec(` 241 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 242 values (?, 1) 243 `, pull.RepoAt) 244 if err != nil { 245 return err 246 } 247 248 var nextId int 249 err = tx.QueryRow(` 250 update repo_pull_seqs 251 set next_pull_id = next_pull_id + 1 252 where repo_at = ? 253 returning next_pull_id - 1 254 `, pull.RepoAt).Scan(&nextId) 255 if err != nil { 256 return err 257 } 258 259 pull.PullId = nextId 260 pull.State = PullOpen 261 262 var sourceBranch, sourceRepoAt *string 263 if pull.PullSource != nil { 264 sourceBranch = &pull.PullSource.Branch 265 if pull.PullSource.RepoAt != nil { 266 x := pull.PullSource.RepoAt.String() 267 sourceRepoAt = &x 268 } 269 } 270 271 _, err = tx.Exec( 272 ` 273 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at) 274 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 275 pull.RepoAt, 276 pull.OwnerDid, 277 pull.PullId, 278 pull.Title, 279 pull.TargetBranch, 280 pull.Body, 281 pull.Rkey, 282 pull.State, 283 sourceBranch, 284 sourceRepoAt, 285 ) 286 if err != nil { 287 return err 288 } 289 290 _, err = tx.Exec(` 291 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 292 values (?, ?, ?, ?, ?) 293 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 294 if err != nil { 295 return err 296 } 297 298 if err := tx.Commit(); err != nil { 299 return err 300 } 301 302 return nil 303} 304 305func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 306 pull, err := GetPull(e, repoAt, pullId) 307 if err != nil { 308 return "", err 309 } 310 return pull.PullAt(), err 311} 312 313func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 314 var pullId int 315 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 316 return pullId - 1, err 317} 318 319func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) { 320 pulls := make(map[int]*Pull) 321 322 rows, err := e.Query(` 323 select 324 owner_did, 325 pull_id, 326 created, 327 title, 328 state, 329 target_branch, 330 body, 331 rkey, 332 source_branch, 333 source_repo_at 334 from 335 pulls 336 where 337 repo_at = ? and state = ?`, repoAt, state) 338 if err != nil { 339 return nil, err 340 } 341 defer rows.Close() 342 343 for rows.Next() { 344 var pull Pull 345 var createdAt string 346 var sourceBranch, sourceRepoAt sql.NullString 347 err := rows.Scan( 348 &pull.OwnerDid, 349 &pull.PullId, 350 &createdAt, 351 &pull.Title, 352 &pull.State, 353 &pull.TargetBranch, 354 &pull.Body, 355 &pull.Rkey, 356 &sourceBranch, 357 &sourceRepoAt, 358 ) 359 if err != nil { 360 return nil, err 361 } 362 363 createdTime, err := time.Parse(time.RFC3339, createdAt) 364 if err != nil { 365 return nil, err 366 } 367 pull.Created = createdTime 368 369 if sourceBranch.Valid { 370 pull.PullSource = &PullSource{ 371 Branch: sourceBranch.String, 372 } 373 if sourceRepoAt.Valid { 374 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 375 if err != nil { 376 return nil, err 377 } 378 pull.PullSource.RepoAt = &sourceRepoAtParsed 379 } 380 } 381 382 pulls[pull.PullId] = &pull 383 } 384 385 // get latest round no. for each pull 386 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 387 submissionsQuery := fmt.Sprintf(` 388 select 389 id, pull_id, round_number 390 from 391 pull_submissions 392 where 393 repo_at = ? and pull_id in (%s) 394 `, inClause) 395 396 args := make([]any, len(pulls)+1) 397 args[0] = repoAt.String() 398 idx := 1 399 for _, p := range pulls { 400 args[idx] = p.PullId 401 idx += 1 402 } 403 submissionsRows, err := e.Query(submissionsQuery, args...) 404 if err != nil { 405 return nil, err 406 } 407 defer submissionsRows.Close() 408 409 for submissionsRows.Next() { 410 var s PullSubmission 411 err := submissionsRows.Scan( 412 &s.ID, 413 &s.PullId, 414 &s.RoundNumber, 415 ) 416 if err != nil { 417 return nil, err 418 } 419 420 if p, ok := pulls[s.PullId]; ok { 421 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 422 p.Submissions[s.RoundNumber] = &s 423 } 424 } 425 if err := rows.Err(); err != nil { 426 return nil, err 427 } 428 429 // get comment count on latest submission on each pull 430 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 431 commentsQuery := fmt.Sprintf(` 432 select 433 count(id), pull_id 434 from 435 pull_comments 436 where 437 submission_id in (%s) 438 group by 439 submission_id 440 `, inClause) 441 442 args = []any{} 443 for _, p := range pulls { 444 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 445 } 446 commentsRows, err := e.Query(commentsQuery, args...) 447 if err != nil { 448 return nil, err 449 } 450 defer commentsRows.Close() 451 452 for commentsRows.Next() { 453 var commentCount, pullId int 454 err := commentsRows.Scan( 455 &commentCount, 456 &pullId, 457 ) 458 if err != nil { 459 return nil, err 460 } 461 if p, ok := pulls[pullId]; ok { 462 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 463 } 464 } 465 if err := rows.Err(); err != nil { 466 return nil, err 467 } 468 469 orderedByDate := []*Pull{} 470 for _, p := range pulls { 471 orderedByDate = append(orderedByDate, p) 472 } 473 sort.Slice(orderedByDate, func(i, j int) bool { 474 return orderedByDate[i].Created.After(orderedByDate[j].Created) 475 }) 476 477 return orderedByDate, nil 478} 479 480func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 481 query := ` 482 select 483 owner_did, 484 pull_id, 485 created, 486 title, 487 state, 488 target_branch, 489 repo_at, 490 body, 491 rkey, 492 source_branch, 493 source_repo_at 494 from 495 pulls 496 where 497 repo_at = ? and pull_id = ? 498 ` 499 row := e.QueryRow(query, repoAt, pullId) 500 501 var pull Pull 502 var createdAt string 503 var sourceBranch, sourceRepoAt sql.NullString 504 err := row.Scan( 505 &pull.OwnerDid, 506 &pull.PullId, 507 &createdAt, 508 &pull.Title, 509 &pull.State, 510 &pull.TargetBranch, 511 &pull.RepoAt, 512 &pull.Body, 513 &pull.Rkey, 514 &sourceBranch, 515 &sourceRepoAt, 516 ) 517 if err != nil { 518 return nil, err 519 } 520 521 createdTime, err := time.Parse(time.RFC3339, createdAt) 522 if err != nil { 523 return nil, err 524 } 525 pull.Created = createdTime 526 527 // populate source 528 if sourceBranch.Valid { 529 pull.PullSource = &PullSource{ 530 Branch: sourceBranch.String, 531 } 532 if sourceRepoAt.Valid { 533 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 534 if err != nil { 535 return nil, err 536 } 537 pull.PullSource.RepoAt = &sourceRepoAtParsed 538 } 539 } 540 541 submissionsQuery := ` 542 select 543 id, pull_id, repo_at, round_number, patch, created, source_rev 544 from 545 pull_submissions 546 where 547 repo_at = ? and pull_id = ? 548 ` 549 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 550 if err != nil { 551 return nil, err 552 } 553 defer submissionsRows.Close() 554 555 submissionsMap := make(map[int]*PullSubmission) 556 557 for submissionsRows.Next() { 558 var submission PullSubmission 559 var submissionCreatedStr string 560 var submissionSourceRev sql.NullString 561 err := submissionsRows.Scan( 562 &submission.ID, 563 &submission.PullId, 564 &submission.RepoAt, 565 &submission.RoundNumber, 566 &submission.Patch, 567 &submissionCreatedStr, 568 &submissionSourceRev, 569 ) 570 if err != nil { 571 return nil, err 572 } 573 574 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 575 if err != nil { 576 return nil, err 577 } 578 submission.Created = submissionCreatedTime 579 580 if submissionSourceRev.Valid { 581 submission.SourceRev = submissionSourceRev.String 582 } 583 584 submissionsMap[submission.ID] = &submission 585 } 586 if err = submissionsRows.Close(); err != nil { 587 return nil, err 588 } 589 if len(submissionsMap) == 0 { 590 return &pull, nil 591 } 592 593 var args []any 594 for k := range submissionsMap { 595 args = append(args, k) 596 } 597 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 598 commentsQuery := fmt.Sprintf(` 599 select 600 id, 601 pull_id, 602 submission_id, 603 repo_at, 604 owner_did, 605 comment_at, 606 body, 607 created 608 from 609 pull_comments 610 where 611 submission_id IN (%s) 612 order by 613 created asc 614 `, inClause) 615 commentsRows, err := e.Query(commentsQuery, args...) 616 if err != nil { 617 return nil, err 618 } 619 defer commentsRows.Close() 620 621 for commentsRows.Next() { 622 var comment PullComment 623 var commentCreatedStr string 624 err := commentsRows.Scan( 625 &comment.ID, 626 &comment.PullId, 627 &comment.SubmissionId, 628 &comment.RepoAt, 629 &comment.OwnerDid, 630 &comment.CommentAt, 631 &comment.Body, 632 &commentCreatedStr, 633 ) 634 if err != nil { 635 return nil, err 636 } 637 638 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 639 if err != nil { 640 return nil, err 641 } 642 comment.Created = commentCreatedTime 643 644 // Add the comment to its submission 645 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 646 submission.Comments = append(submission.Comments, comment) 647 } 648 649 } 650 if err = commentsRows.Err(); err != nil { 651 return nil, err 652 } 653 654 var pullSourceRepo *Repo 655 if pull.PullSource != nil { 656 if pull.PullSource.RepoAt != nil { 657 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 658 if err != nil { 659 log.Printf("failed to get repo by at uri: %v", err) 660 } else { 661 pull.PullSource.Repo = pullSourceRepo 662 } 663 } 664 } 665 666 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 667 for _, submission := range submissionsMap { 668 pull.Submissions[submission.RoundNumber] = submission 669 } 670 671 return &pull, nil 672} 673 674// timeframe here is directly passed into the sql query filter, and any 675// timeframe in the past should be negative; e.g.: "-3 months" 676func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 677 var pulls []Pull 678 679 rows, err := e.Query(` 680 select 681 p.owner_did, 682 p.repo_at, 683 p.pull_id, 684 p.created, 685 p.title, 686 p.state, 687 r.did, 688 r.name, 689 r.knot, 690 r.rkey, 691 r.created 692 from 693 pulls p 694 join 695 repos r on p.repo_at = r.at_uri 696 where 697 p.owner_did = ? and p.created >= date ('now', ?) 698 order by 699 p.created desc`, did, timeframe) 700 if err != nil { 701 return nil, err 702 } 703 defer rows.Close() 704 705 for rows.Next() { 706 var pull Pull 707 var repo Repo 708 var pullCreatedAt, repoCreatedAt string 709 err := rows.Scan( 710 &pull.OwnerDid, 711 &pull.RepoAt, 712 &pull.PullId, 713 &pullCreatedAt, 714 &pull.Title, 715 &pull.State, 716 &repo.Did, 717 &repo.Name, 718 &repo.Knot, 719 &repo.Rkey, 720 &repoCreatedAt, 721 ) 722 if err != nil { 723 return nil, err 724 } 725 726 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 727 if err != nil { 728 return nil, err 729 } 730 pull.Created = pullCreatedTime 731 732 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 733 if err != nil { 734 return nil, err 735 } 736 repo.Created = repoCreatedTime 737 738 pull.Repo = &repo 739 740 pulls = append(pulls, pull) 741 } 742 743 if err := rows.Err(); err != nil { 744 return nil, err 745 } 746 747 return pulls, nil 748} 749 750func NewPullComment(e Execer, comment *PullComment) (int64, error) { 751 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 752 res, err := e.Exec( 753 query, 754 comment.OwnerDid, 755 comment.RepoAt, 756 comment.SubmissionId, 757 comment.CommentAt, 758 comment.PullId, 759 comment.Body, 760 ) 761 if err != nil { 762 return 0, err 763 } 764 765 i, err := res.LastInsertId() 766 if err != nil { 767 return 0, err 768 } 769 770 return i, nil 771} 772 773func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 774 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId) 775 return err 776} 777 778func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 779 err := SetPullState(e, repoAt, pullId, PullClosed) 780 return err 781} 782 783func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 784 err := SetPullState(e, repoAt, pullId, PullOpen) 785 return err 786} 787 788func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 789 err := SetPullState(e, repoAt, pullId, PullMerged) 790 return err 791} 792 793func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 794 newRoundNumber := len(pull.Submissions) 795 _, err := e.Exec(` 796 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 797 values (?, ?, ?, ?, ?) 798 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 799 800 return err 801} 802 803type PullCount struct { 804 Open int 805 Merged int 806 Closed int 807} 808 809func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 810 row := e.QueryRow(` 811 select 812 count(case when state = ? then 1 end) as open_count, 813 count(case when state = ? then 1 end) as merged_count, 814 count(case when state = ? then 1 end) as closed_count 815 from pulls 816 where repo_at = ?`, 817 PullOpen, 818 PullMerged, 819 PullClosed, 820 repoAt, 821 ) 822 823 var count PullCount 824 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil { 825 return PullCount{0, 0, 0}, err 826 } 827 828 return count, nil 829}