From 494ace72ef4ac8810939b4706e587ea02cbbc2be Mon Sep 17 00:00:00 2001 From: oppiliappan Date: Sat, 27 Sep 2025 18:29:24 +0100 Subject: [PATCH] appview/db: refactor GetPulls Change-Id: kntqukzzonxxqotsnllrqwtsoovyknkl instead of using a massive left join, it now uses a few FilterIns. we can also get rid of the GetPull helper, it is a specialization of GetPulls that returns a single pull request. Signed-off-by: oppiliappan --- appview/db/db.go | 140 +++++++++++++++++ appview/db/pulls.go | 347 +++++++++++++---------------------------- appview/models/pull.go | 5 +- 3 files changed, 252 insertions(+), 240 deletions(-) diff --git a/appview/db/db.go b/appview/db/db.go index a4b39560..cabfaa1b 100644 --- a/appview/db/db.go +++ b/appview/db/db.go @@ -954,6 +954,146 @@ func Make(dbPath string) (*DB, error) { return err }) + // add generated at_uri column to pulls table + // + // this requires a full table recreation because stored columns + // cannot be added via alter + // + // disable foreign-keys for the next migration + conn.ExecContext(ctx, "pragma foreign_keys = off;") + runMigration(conn, "add-at-uri-to-pulls", func(tx *sql.Tx) error { + _, err := tx.Exec(` + create table if not exists pulls_new ( + -- identifiers + id integer primary key autoincrement, + pull_id integer not null, + at_uri text generated always as ('at://' || owner_did || '/' || 'sh.tangled.repo.pull' || '/' || rkey) stored, + + -- at identifiers + repo_at text not null, + owner_did text not null, + rkey text not null, + + -- content + title text not null, + body text not null, + target_branch text not null, + state integer not null default 0 check (state in (0, 1, 2, 3)), -- closed, open, merged, deleted + + -- source info + source_branch text, + source_repo_at text, + + -- stacking + stack_id text, + change_id text, + parent_change_id text, + + -- meta + created text not null default (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')), + + -- constraints + unique(repo_at, pull_id), + unique(at_uri), + foreign key (repo_at) references repos(at_uri) on delete cascade + ); + `) + if err != nil { + return err + } + + // transfer data + _, err = tx.Exec(` + insert into pulls_new ( + id, pull_id, repo_at, owner_did, rkey, + title, body, target_branch, state, + source_branch, source_repo_at, + stack_id, change_id, parent_change_id, + created + ) + select + id, pull_id, repo_at, owner_did, rkey, + title, body, target_branch, state, + source_branch, source_repo_at, + stack_id, change_id, parent_change_id, + created + from pulls; + `) + if err != nil { + return err + } + + // drop old table + _, err = tx.Exec(`drop table pulls`) + if err != nil { + return err + } + + // rename new table + _, err = tx.Exec(`alter table pulls_new rename to pulls`) + return err + }) + conn.ExecContext(ctx, "pragma foreign_keys = on;") + + // remove repo_at and pull_id from pull_submissions and replace with pull_at + // + // this requires a full table recreation because stored columns + // cannot be added via alter + // + // disable foreign-keys for the next migration + conn.ExecContext(ctx, "pragma foreign_keys = off;") + runMigration(conn, "remove-repo-at-pull-id-from-pull-submissions", func(tx *sql.Tx) error { + _, err := tx.Exec(` + create table if not exists pull_submissions_new ( + -- identifiers + id integer primary key autoincrement, + pull_at text not null, + + -- content, these are immutable, and require a resubmission to update + round_number integer not null default 0, + patch text, + source_rev text, + + -- meta + created text not null default (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')), + + -- constraints + unique(pull_at, round_number), + foreign key (pull_at) references pulls(at_uri) on delete cascade + ); + `) + if err != nil { + return err + } + + // transfer data, constructing pull_at from pulls table + _, err = tx.Exec(` + insert into pull_submissions_new (id, pull_at, round_number, patch, created) + select + ps.id, + 'at://' || p.owner_did || '/sh.tangled.repo.pull/' || p.rkey, + ps.round_number, + ps.patch, + ps.created + from pull_submissions ps + join pulls p on ps.repo_at = p.repo_at and ps.pull_id = p.pull_id; + `) + if err != nil { + return err + } + + // drop old table + _, err = tx.Exec(`drop table pull_submissions`) + if err != nil { + return err + } + + // rename new table + _, err = tx.Exec(`alter table pull_submissions_new rename to pull_submissions`) + return err + }) + conn.ExecContext(ctx, "pragma foreign_keys = on;") + return &DB{db}, nil } diff --git a/appview/db/pulls.go b/appview/db/pulls.go index aad64fdc..7bc6a2af 100644 --- a/appview/db/pulls.go +++ b/appview/db/pulls.go @@ -3,7 +3,8 @@ package db import ( "database/sql" "fmt" - "log" + "maps" + "slices" "sort" "strings" "time" @@ -87,9 +88,9 @@ func NewPull(tx *sql.Tx, pull *models.Pull) error { pull.ID = int(id) _, err = tx.Exec(` - insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) - values (?, ?, ?, ?, ?) - `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) + insert into pull_submissions (pull_at, round_number, patch, source_rev) + values (?, ?, ?, ?) + `, pull.PullAt(), 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) return err } @@ -108,7 +109,7 @@ func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { } func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { - pulls := make(map[int]*models.Pull) + pulls := make(map[syntax.ATURI]*models.Pull) var conditions []string var args []any @@ -211,110 +212,23 @@ func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, pull.ParentChangeId = parentChangeId.String } - pulls[pull.PullId] = &pull + pulls[pull.PullAt()] = &pull } - // get latest round no. for each pull - inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") - submissionsQuery := fmt.Sprintf(` - select - id, pull_id, round_number, patch, created, source_rev - from - pull_submissions - where - repo_at in (%s) and pull_id in (%s) - `, inClause, inClause) - - args = make([]any, len(pulls)*2) - idx := 0 - for _, p := range pulls { - args[idx] = p.RepoAt - idx += 1 - } - for _, p := range pulls { - args[idx] = p.PullId - idx += 1 - } - submissionsRows, err := e.Query(submissionsQuery, args...) - if err != nil { - return nil, err - } - defer submissionsRows.Close() - - for submissionsRows.Next() { - var s models.PullSubmission - var sourceRev sql.NullString - var createdAt string - err := submissionsRows.Scan( - &s.ID, - &s.PullId, - &s.RoundNumber, - &s.Patch, - &createdAt, - &sourceRev, - ) - if err != nil { - return nil, err - } - - createdTime, err := time.Parse(time.RFC3339, createdAt) - if err != nil { - return nil, err - } - s.Created = createdTime - - if sourceRev.Valid { - s.SourceRev = sourceRev.String - } - - if p, ok := pulls[s.PullId]; ok { - p.Submissions = make([]*models.PullSubmission, s.RoundNumber+1) - p.Submissions[s.RoundNumber] = &s - } - } - if err := rows.Err(); err != nil { - return nil, err - } - - // get comment count on latest submission on each pull - inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") - commentsQuery := fmt.Sprintf(` - select - count(id), pull_id - from - pull_comments - where - submission_id in (%s) - group by - submission_id - `, inClause) - - args = []any{} + var pullAts []syntax.ATURI for _, p := range pulls { - args = append(args, p.Submissions[p.LastRoundNumber()].ID) + pullAts = append(pullAts, p.PullAt()) } - commentsRows, err := e.Query(commentsQuery, args...) + submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get submissions: %w", err) } - defer commentsRows.Close() - for commentsRows.Next() { - var commentCount, pullId int - err := commentsRows.Scan( - &commentCount, - &pullId, - ) - if err != nil { - return nil, err - } - if p, ok := pulls[pullId]; ok { - p.Submissions[p.LastRoundNumber()].Comments = make([]models.PullComment, commentCount) + for pullAt, submissions := range submissionsMap { + if p, ok := pulls[pullAt]; ok { + p.Submissions = submissions } } - if err := rows.Err(); err != nil { - return nil, err - } orderedByPullId := []*models.Pull{} for _, p := range pulls { @@ -332,142 +246,122 @@ func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { } func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { - query := ` - select - id, - owner_did, - pull_id, - created, - title, - state, - target_branch, - repo_at, - body, - rkey, - source_branch, - source_repo_at, - stack_id, - change_id, - parent_change_id - from - pulls - where - repo_at = ? and pull_id = ? - ` - row := e.QueryRow(query, repoAt, pullId) - - var pull models.Pull - var createdAt string - var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString - err := row.Scan( - &pull.ID, - &pull.OwnerDid, - &pull.PullId, - &createdAt, - &pull.Title, - &pull.State, - &pull.TargetBranch, - &pull.RepoAt, - &pull.Body, - &pull.Rkey, - &sourceBranch, - &sourceRepoAt, - &stackId, - &changeId, - &parentChangeId, - ) + pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) if err != nil { return nil, err } - - createdTime, err := time.Parse(time.RFC3339, createdAt) - if err != nil { - return nil, err + if pulls == nil { + return nil, sql.ErrNoRows } - pull.Created = createdTime - // populate source - if sourceBranch.Valid { - pull.PullSource = &models.PullSource{ - Branch: sourceBranch.String, - } - if sourceRepoAt.Valid { - sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) - if err != nil { - return nil, err - } - pull.PullSource.RepoAt = &sourceRepoAtParsed - } - } + return pulls[0], nil +} - if stackId.Valid { - pull.StackId = stackId.String - } - if changeId.Valid { - pull.ChangeId = changeId.String +// mapping from pull -> pull submissions +func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { + var conditions []string + var args []any + for _, filter := range filters { + conditions = append(conditions, filter.Condition()) + args = append(args, filter.Arg()...) } - if parentChangeId.Valid { - pull.ParentChangeId = parentChangeId.String + + whereClause := "" + if conditions != nil { + whereClause = " where " + strings.Join(conditions, " and ") } - submissionsQuery := ` + query := fmt.Sprintf(` select - id, pull_id, repo_at, round_number, patch, created, source_rev + id, + pull_at, + round_number, + patch, + created, + source_rev from pull_submissions - where - repo_at = ? and pull_id = ? - ` - submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) + %s + order by + round_number asc + `, whereClause) + + rows, err := e.Query(query, args...) if err != nil { return nil, err } - defer submissionsRows.Close() + defer rows.Close() - submissionsMap := make(map[int]*models.PullSubmission) + submissionMap := make(map[int]*models.PullSubmission) - for submissionsRows.Next() { + for rows.Next() { var submission models.PullSubmission - var submissionCreatedStr string - var submissionSourceRev sql.NullString - err := submissionsRows.Scan( + var createdAt string + var sourceRev sql.NullString + err := rows.Scan( &submission.ID, - &submission.PullId, - &submission.RepoAt, + &submission.PullAt, &submission.RoundNumber, &submission.Patch, - &submissionCreatedStr, - &submissionSourceRev, + &createdAt, + &sourceRev, ) if err != nil { return nil, err } - submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) + createdTime, err := time.Parse(time.RFC3339, createdAt) if err != nil { return nil, err } - submission.Created = submissionCreatedTime + submission.Created = createdTime - if submissionSourceRev.Valid { - submission.SourceRev = submissionSourceRev.String + if sourceRev.Valid { + submission.SourceRev = sourceRev.String } - submissionsMap[submission.ID] = &submission + submissionMap[submission.ID] = &submission + } + + if err := rows.Err(); err != nil { + return nil, err } - if err = submissionsRows.Close(); err != nil { + + // Get comments for all submissions using GetPullComments + submissionIds := slices.Collect(maps.Keys(submissionMap)) + comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) + if err != nil { return nil, err } - if len(submissionsMap) == 0 { - return &pull, nil + for _, comment := range comments { + if submission, ok := submissionMap[comment.SubmissionId]; ok { + submission.Comments = append(submission.Comments, comment) + } + } + + // order the submissions by pull_at + m := make(map[syntax.ATURI][]*models.PullSubmission) + for _, s := range submissionMap { + m[s.PullAt] = append(m[s.PullAt], s) } + return m, nil +} + +func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { + var conditions []string var args []any - for k := range submissionsMap { - args = append(args, k) + for _, filter := range filters { + conditions = append(conditions, filter.Condition()) + args = append(args, filter.Arg()...) } - inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") - commentsQuery := fmt.Sprintf(` + + whereClause := "" + if conditions != nil { + whereClause = " where " + strings.Join(conditions, " and ") + } + + query := fmt.Sprintf(` select id, pull_id, @@ -479,21 +373,22 @@ func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { created from pull_comments - where - submission_id IN (%s) + %s order by created asc - `, inClause) - commentsRows, err := e.Query(commentsQuery, args...) + `, whereClause) + + rows, err := e.Query(query, args...) if err != nil { return nil, err } - defer commentsRows.Close() + defer rows.Close() - for commentsRows.Next() { + var comments []models.PullComment + for rows.Next() { var comment models.PullComment - var commentCreatedStr string - err := commentsRows.Scan( + var createdAt string + err := rows.Scan( &comment.ID, &comment.PullId, &comment.SubmissionId, @@ -501,46 +396,24 @@ func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { &comment.OwnerDid, &comment.CommentAt, &comment.Body, - &commentCreatedStr, + &createdAt, ) if err != nil { return nil, err } - commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) - if err != nil { - return nil, err + if t, err := time.Parse(time.RFC3339, createdAt); err == nil { + comment.Created = t } - comment.Created = commentCreatedTime - - // Add the comment to its submission - if submission, ok := submissionsMap[comment.SubmissionId]; ok { - submission.Comments = append(submission.Comments, comment) - } - - } - if err = commentsRows.Err(); err != nil { - return nil, err - } - var pullSourceRepo *models.Repo - if pull.PullSource != nil { - if pull.PullSource.RepoAt != nil { - pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) - if err != nil { - log.Printf("failed to get repo by at uri: %v", err) - } else { - pull.PullSource.Repo = pullSourceRepo - } - } + comments = append(comments, comment) } - pull.Submissions = make([]*models.PullSubmission, len(submissionsMap)) - for _, submission := range submissionsMap { - pull.Submissions[submission.RoundNumber] = submission + if err := rows.Err(); err != nil { + return nil, err } - return &pull, nil + return comments, nil } // timeframe here is directly passed into the sql query filter, and any @@ -677,9 +550,9 @@ func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { func ResubmitPull(e Execer, pull *models.Pull, newPatch, sourceRev string) error { newRoundNumber := len(pull.Submissions) _, err := e.Exec(` - insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) - values (?, ?, ?, ?, ?) - `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) + insert into pull_submissions (pull_at, round_number, patch, source_rev) + values (?, ?, ?, ?) + `, pull.PullAt(), newRoundNumber, newPatch, sourceRev) return err } diff --git a/appview/models/pull.go b/appview/models/pull.go index 2826b00b..e17f96ae 100644 --- a/appview/models/pull.go +++ b/appview/models/pull.go @@ -125,11 +125,10 @@ func (p PullSource) AsRecord() tangled.RepoPull_Source { type PullSubmission struct { // ids - ID int - PullId int + ID int // at ids - RepoAt syntax.ATURI + PullAt syntax.ATURI // content RoundNumber int -- 2.43.0