1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.sh/tangled.sh/core/types"
13)
14
// PullState enumerates the lifecycle states of a pull request. Values of this
// type are passed as query parameters for the pulls.state column elsewhere in
// this file, so the numeric value of each constant must stay stable.
type PullState int

const (
	PullClosed PullState = iota // 0
	PullOpen                    // 1
	PullMerged                  // 2
)

// String returns the lowercase name of the state. Any value that is not one
// of the declared constants is reported as "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	// PullClosed and every unrecognised value share the same label.
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return PullOpen == p }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return PullMerged == p }

// IsClosed reports whether the state is PullClosed.
func (p PullState) IsClosed() bool { return PullClosed == p }
45
46type Pull struct {
47 // ids
48 ID int
49 PullId int
50
51 // at ids
52 RepoAt syntax.ATURI
53 OwnerDid string
54 Rkey string
55 PullAt syntax.ATURI
56
57 // content
58 Title string
59 Body string
60 TargetBranch string
61 State PullState
62 Submissions []*PullSubmission
63
64 // meta
65 Created time.Time
66}
67
68type PullSubmission struct {
69 // ids
70 ID int
71 PullId int
72
73 // at ids
74 RepoAt syntax.ATURI
75
76 // content
77 RoundNumber int
78 Patch string
79 Comments []PullComment
80
81 // meta
82 Created time.Time
83}
84
// PullComment is a row of the pull_comments table: a comment made on one
// particular submission of a pull request.
type PullComment struct {
	// ids: ID is the pull_comments row id; SubmissionId refers to the id of
	// the submission the comment belongs to.
	ID, PullId, SubmissionId int

	// at ids (repo_at, owner_did and comment_at columns)
	RepoAt, OwnerDid, CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
102
103func (p *Pull) LatestPatch() string {
104 latestSubmission := p.Submissions[p.LastRoundNumber()]
105 return latestSubmission.Patch
106}
107
108func (p *Pull) LastRoundNumber() int {
109 return len(p.Submissions) - 1
110}
111
112func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
113 patch := s.Patch
114
115 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
116 if err != nil {
117 log.Println(err)
118 }
119
120 nd := types.NiceDiff{}
121 nd.Commit.Parent = targetBranch
122
123 for _, d := range diffs {
124 ndiff := types.Diff{}
125 ndiff.Name.New = d.NewName
126 ndiff.Name.Old = d.OldName
127 ndiff.IsBinary = d.IsBinary
128 ndiff.IsNew = d.IsNew
129 ndiff.IsDelete = d.IsDelete
130 ndiff.IsCopy = d.IsCopy
131 ndiff.IsRename = d.IsRename
132
133 for _, tf := range d.TextFragments {
134 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
135 for _, l := range tf.Lines {
136 switch l.Op {
137 case gitdiff.OpAdd:
138 nd.Stat.Insertions += 1
139 case gitdiff.OpDelete:
140 nd.Stat.Deletions += 1
141 }
142 }
143 }
144
145 nd.Diff = append(nd.Diff, ndiff)
146 }
147
148 nd.Stat.FilesChanged = len(diffs)
149
150 return nd
151}
152
153func NewPull(tx *sql.Tx, pull *Pull) error {
154 defer tx.Rollback()
155
156 _, err := tx.Exec(`
157 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
158 values (?, 1)
159 `, pull.RepoAt)
160 if err != nil {
161 return err
162 }
163
164 var nextId int
165 err = tx.QueryRow(`
166 update repo_pull_seqs
167 set next_pull_id = next_pull_id + 1
168 where repo_at = ?
169 returning next_pull_id - 1
170 `, pull.RepoAt).Scan(&nextId)
171 if err != nil {
172 return err
173 }
174
175 pull.PullId = nextId
176 pull.State = PullOpen
177
178 _, err = tx.Exec(`
179 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state)
180 values (?, ?, ?, ?, ?, ?, ?, ?)
181 `, pull.RepoAt, pull.OwnerDid, pull.PullId, pull.Title, pull.TargetBranch, pull.Body, pull.Rkey, pull.State)
182 if err != nil {
183 return err
184 }
185
186 _, err = tx.Exec(`
187 insert into pull_submissions (pull_id, repo_at, round_number, patch)
188 values (?, ?, ?, ?)
189 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch)
190 if err != nil {
191 return err
192 }
193
194 if err := tx.Commit(); err != nil {
195 return err
196 }
197
198 return nil
199}
200
201func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
202 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
203 return err
204}
205
206func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
207 var pullAt string
208 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
209 return pullAt, err
210}
211
212func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
213 var pullId int
214 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
215 return pullId - 1, err
216}
217
218func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
219 var pulls []Pull
220
221 rows, err := e.Query(`
222 select
223 owner_did,
224 pull_id,
225 created,
226 title,
227 state,
228 target_branch,
229 pull_at,
230 body,
231 rkey
232 from
233 pulls
234 where
235 repo_at = ? and state = ?
236 order by
237 created desc`, repoAt, state)
238 if err != nil {
239 return nil, err
240 }
241 defer rows.Close()
242
243 for rows.Next() {
244 var pull Pull
245 var createdAt string
246 err := rows.Scan(
247 &pull.OwnerDid,
248 &pull.PullId,
249 &createdAt,
250 &pull.Title,
251 &pull.State,
252 &pull.TargetBranch,
253 &pull.PullAt,
254 &pull.Body,
255 &pull.Rkey,
256 )
257 if err != nil {
258 return nil, err
259 }
260
261 createdTime, err := time.Parse(time.RFC3339, createdAt)
262 if err != nil {
263 return nil, err
264 }
265 pull.Created = createdTime
266
267 pulls = append(pulls, pull)
268 }
269
270 if err := rows.Err(); err != nil {
271 return nil, err
272 }
273
274 return pulls, nil
275}
276
277func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
278 query := `
279 select
280 owner_did,
281 pull_id,
282 created,
283 title,
284 state,
285 target_branch,
286 pull_at,
287 repo_at,
288 body,
289 rkey
290 from
291 pulls
292 where
293 repo_at = ? and pull_id = ?
294 `
295 row := e.QueryRow(query, repoAt, pullId)
296
297 var pull Pull
298 var createdAt string
299 err := row.Scan(
300 &pull.OwnerDid,
301 &pull.PullId,
302 &createdAt,
303 &pull.Title,
304 &pull.State,
305 &pull.TargetBranch,
306 &pull.PullAt,
307 &pull.RepoAt,
308 &pull.Body,
309 &pull.Rkey,
310 )
311 if err != nil {
312 return nil, err
313 }
314
315 createdTime, err := time.Parse(time.RFC3339, createdAt)
316 if err != nil {
317 return nil, err
318 }
319 pull.Created = createdTime
320
321 submissionsQuery := `
322 select
323 id, pull_id, repo_at, round_number, patch, created
324 from
325 pull_submissions
326 where
327 repo_at = ? and pull_id = ?
328 `
329 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
330 if err != nil {
331 return nil, err
332 }
333 defer submissionsRows.Close()
334
335 submissionsMap := make(map[int]*PullSubmission)
336
337 for submissionsRows.Next() {
338 var submission PullSubmission
339 var submissionCreatedStr string
340 err := submissionsRows.Scan(
341 &submission.ID,
342 &submission.PullId,
343 &submission.RepoAt,
344 &submission.RoundNumber,
345 &submission.Patch,
346 &submissionCreatedStr,
347 )
348 if err != nil {
349 return nil, err
350 }
351
352 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
353 if err != nil {
354 return nil, err
355 }
356 submission.Created = submissionCreatedTime
357
358 submissionsMap[submission.ID] = &submission
359 }
360 if err = submissionsRows.Close(); err != nil {
361 return nil, err
362 }
363 if len(submissionsMap) == 0 {
364 return &pull, nil
365 }
366
367 var args []any
368 for k := range submissionsMap {
369 args = append(args, k)
370 }
371 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
372 commentsQuery := fmt.Sprintf(`
373 select
374 id,
375 pull_id,
376 submission_id,
377 repo_at,
378 owner_did,
379 comment_at,
380 body,
381 created
382 from
383 pull_comments
384 where
385 submission_id IN (%s)
386 order by
387 created asc
388 `, inClause)
389 commentsRows, err := e.Query(commentsQuery, args...)
390 if err != nil {
391 return nil, err
392 }
393 defer commentsRows.Close()
394
395 for commentsRows.Next() {
396 var comment PullComment
397 var commentCreatedStr string
398 err := commentsRows.Scan(
399 &comment.ID,
400 &comment.PullId,
401 &comment.SubmissionId,
402 &comment.RepoAt,
403 &comment.OwnerDid,
404 &comment.CommentAt,
405 &comment.Body,
406 &commentCreatedStr,
407 )
408 if err != nil {
409 return nil, err
410 }
411
412 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
413 if err != nil {
414 return nil, err
415 }
416 comment.Created = commentCreatedTime
417
418 // Add the comment to its submission
419 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
420 submission.Comments = append(submission.Comments, comment)
421 }
422
423 }
424 if err = commentsRows.Err(); err != nil {
425 return nil, err
426 }
427
428 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
429 for _, submission := range submissionsMap {
430 pull.Submissions[submission.RoundNumber] = submission
431 }
432
433 return &pull, nil
434}
435
436func NewPullComment(e Execer, comment *PullComment) (int64, error) {
437 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
438 res, err := e.Exec(
439 query,
440 comment.OwnerDid,
441 comment.RepoAt,
442 comment.SubmissionId,
443 comment.CommentAt,
444 comment.PullId,
445 comment.Body,
446 )
447 if err != nil {
448 return 0, err
449 }
450
451 i, err := res.LastInsertId()
452 if err != nil {
453 return 0, err
454 }
455
456 return i, nil
457}
458
459func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
460 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
461 return err
462}
463
464func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
465 err := SetPullState(e, repoAt, pullId, PullClosed)
466 return err
467}
468
469func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
470 err := SetPullState(e, repoAt, pullId, PullOpen)
471 return err
472}
473
474func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
475 err := SetPullState(e, repoAt, pullId, PullMerged)
476 return err
477}
478
479func ResubmitPull(e Execer, pull *Pull, newPatch string) error {
480 newRoundNumber := len(pull.Submissions)
481 _, err := e.Exec(`
482 insert into pull_submissions (pull_id, repo_at, round_number, patch)
483 values (?, ?, ?, ?)
484 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch)
485
486 return err
487}
488
// PullCount holds the number of pulls of one repo in each state, as computed
// by GetPullCount.
type PullCount struct {
	Open, Merged, Closed int
}
494
495func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
496 row := e.QueryRow(`
497 select
498 count(case when state = ? then 1 end) as open_count,
499 count(case when state = ? then 1 end) as merged_count,
500 count(case when state = ? then 1 end) as closed_count
501 from pulls
502 where repo_at = ?`,
503 PullOpen,
504 PullMerged,
505 PullClosed,
506 repoAt,
507 )
508
509 var count PullCount
510 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
511 return PullCount{0, 0, 0}, err
512 }
513
514 return count, nil
515}