1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "strings"
8 "time"
9
10 "github.com/bluekeyes/go-gitdiff/gitdiff"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "github.com/sotangled/tangled/types"
13)
14
// PullState is the lifecycle state of a pull. Its integer value is what is
// written to the pulls.state column (SetPullState passes it straight to the
// query), so the constants below must keep their numeric values.
type PullState int

// Pull states. The numbers are spelled out explicitly because they are
// persisted in the database.
const (
	PullClosed PullState = 0
	PullOpen   PullState = 1
	PullMerged PullState = 2
)

// String returns the lower-case name of the state. Any value other than
// PullOpen or PullMerged, including out-of-range values, renders as "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is PullClosed (closed without a merge).
func (p PullState) IsClosed() bool { return p == PullClosed }
45
46type Pull struct {
47 // ids
48 ID int
49 PullId int
50
51 // at ids
52 RepoAt syntax.ATURI
53 OwnerDid string
54 Rkey string
55 PullAt syntax.ATURI
56
57 // content
58 Title string
59 Body string
60 TargetBranch string
61 State PullState
62 Submissions []*PullSubmission
63
64 // meta
65 Created time.Time
66}
67
68type PullSubmission struct {
69 // ids
70 ID int
71 PullId int
72
73 // at ids
74 RepoAt syntax.ATURI
75
76 // content
77 RoundNumber int
78 Patch string
79 Comments []PullComment
80
81 // meta
82 Created time.Time
83}
84
// PullComment is a single comment left on one submission (round) of a pull.
type PullComment struct {
	// ID is the comment's row id, PullId the pull's number, and SubmissionId
	// the id of the pull_submissions row the comment is attached to.
	ID, PullId, SubmissionId int

	// AT Protocol identifiers. Unlike Pull and PullSubmission, these are held
	// as plain strings rather than syntax.ATURI.
	RepoAt, OwnerDid, CommentAt string

	// Body is the comment text.
	Body string

	// Created is parsed (RFC 3339) from the pull_comments.created column.
	Created time.Time
}
102
103func (p *Pull) LatestPatch() string {
104 latestSubmission := p.Submissions[len(p.Submissions)-1]
105 return latestSubmission.Patch
106}
107
108func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
109 patch := s.Patch
110
111 diffs, _, err := gitdiff.Parse(strings.NewReader(patch))
112 if err != nil {
113 log.Println(err)
114 }
115
116 nd := types.NiceDiff{}
117 nd.Commit.Parent = targetBranch
118
119 for _, d := range diffs {
120 ndiff := types.Diff{}
121 ndiff.Name.New = d.NewName
122 ndiff.Name.Old = d.OldName
123 ndiff.IsBinary = d.IsBinary
124 ndiff.IsNew = d.IsNew
125 ndiff.IsDelete = d.IsDelete
126 ndiff.IsCopy = d.IsCopy
127 ndiff.IsRename = d.IsRename
128
129 for _, tf := range d.TextFragments {
130 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
131 for _, l := range tf.Lines {
132 switch l.Op {
133 case gitdiff.OpAdd:
134 nd.Stat.Insertions += 1
135 case gitdiff.OpDelete:
136 nd.Stat.Deletions += 1
137 }
138 }
139 }
140
141 nd.Diff = append(nd.Diff, ndiff)
142 }
143
144 nd.Stat.FilesChanged = len(diffs)
145
146 return nd
147}
148
149func NewPull(tx *sql.Tx, pull *Pull) error {
150 defer tx.Rollback()
151
152 _, err := tx.Exec(`
153 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
154 values (?, 1)
155 `, pull.RepoAt)
156 if err != nil {
157 return err
158 }
159
160 var nextId int
161 err = tx.QueryRow(`
162 update repo_pull_seqs
163 set next_pull_id = next_pull_id + 1
164 where repo_at = ?
165 returning next_pull_id - 1
166 `, pull.RepoAt).Scan(&nextId)
167 if err != nil {
168 return err
169 }
170
171 pull.PullId = nextId
172 pull.State = PullOpen
173
174 _, err = tx.Exec(`
175 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state)
176 values (?, ?, ?, ?, ?, ?, ?, ?)
177 `, pull.RepoAt, pull.OwnerDid, pull.PullId, pull.Title, pull.TargetBranch, pull.Body, pull.Rkey, pull.State)
178 if err != nil {
179 return err
180 }
181
182 _, err = tx.Exec(`
183 insert into pull_submissions (pull_id, repo_at, round_number, patch)
184 values (?, ?, ?, ?)
185 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch)
186 if err != nil {
187 return err
188 }
189
190 if err := tx.Commit(); err != nil {
191 return err
192 }
193
194 return nil
195}
196
197func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
198 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
199 return err
200}
201
202func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
203 var pullAt string
204 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
205 return pullAt, err
206}
207
208func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
209 var pullId int
210 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
211 return pullId - 1, err
212}
213
214func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
215 var pulls []Pull
216
217 rows, err := e.Query(`
218 select
219 owner_did,
220 pull_id,
221 created,
222 title,
223 state,
224 target_branch,
225 pull_at,
226 body,
227 rkey
228 from
229 pulls
230 where
231 repo_at = ? and state = ?
232 order by
233 created desc`, repoAt, state)
234 if err != nil {
235 return nil, err
236 }
237 defer rows.Close()
238
239 for rows.Next() {
240 var pull Pull
241 var createdAt string
242 err := rows.Scan(
243 &pull.OwnerDid,
244 &pull.PullId,
245 &createdAt,
246 &pull.Title,
247 &pull.State,
248 &pull.TargetBranch,
249 &pull.PullAt,
250 &pull.Body,
251 &pull.Rkey,
252 )
253 if err != nil {
254 return nil, err
255 }
256
257 createdTime, err := time.Parse(time.RFC3339, createdAt)
258 if err != nil {
259 return nil, err
260 }
261 pull.Created = createdTime
262
263 pulls = append(pulls, pull)
264 }
265
266 if err := rows.Err(); err != nil {
267 return nil, err
268 }
269
270 return pulls, nil
271}
272
273func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
274 query := `
275 select
276 owner_did,
277 pull_id,
278 created,
279 title,
280 state,
281 target_branch,
282 pull_at,
283 repo_at,
284 body,
285 rkey
286 from
287 pulls
288 where
289 repo_at = ? and pull_id = ?
290 `
291 row := e.QueryRow(query, repoAt, pullId)
292
293 var pull Pull
294 var createdAt string
295 err := row.Scan(
296 &pull.OwnerDid,
297 &pull.PullId,
298 &createdAt,
299 &pull.Title,
300 &pull.State,
301 &pull.TargetBranch,
302 &pull.PullAt,
303 &pull.RepoAt,
304 &pull.Body,
305 &pull.Rkey,
306 )
307 if err != nil {
308 return nil, err
309 }
310
311 createdTime, err := time.Parse(time.RFC3339, createdAt)
312 if err != nil {
313 return nil, err
314 }
315 pull.Created = createdTime
316
317 submissionsQuery := `
318 select
319 id, pull_id, repo_at, round_number, patch, created
320 from
321 pull_submissions
322 where
323 repo_at = ? and pull_id = ?
324 `
325 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
326 if err != nil {
327 return nil, err
328 }
329 defer submissionsRows.Close()
330
331 submissionsMap := make(map[int]*PullSubmission)
332
333 for submissionsRows.Next() {
334 var submission PullSubmission
335 var submissionCreatedStr string
336 err := submissionsRows.Scan(
337 &submission.ID,
338 &submission.PullId,
339 &submission.RepoAt,
340 &submission.RoundNumber,
341 &submission.Patch,
342 &submissionCreatedStr,
343 )
344 if err != nil {
345 return nil, err
346 }
347
348 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
349 if err != nil {
350 return nil, err
351 }
352 submission.Created = submissionCreatedTime
353
354 submissionsMap[submission.ID] = &submission
355 }
356 if err = submissionsRows.Close(); err != nil {
357 return nil, err
358 }
359 if len(submissionsMap) == 0 {
360 return &pull, nil
361 }
362
363 var args []any
364 for k := range submissionsMap {
365 args = append(args, k)
366 }
367 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
368 commentsQuery := fmt.Sprintf(`
369 select
370 id,
371 pull_id,
372 submission_id,
373 repo_at,
374 owner_did,
375 comment_at,
376 body,
377 created
378 from
379 pull_comments
380 where
381 submission_id IN (%s)
382 order by
383 created asc
384 `, inClause)
385 commentsRows, err := e.Query(commentsQuery, args...)
386 if err != nil {
387 return nil, err
388 }
389 defer commentsRows.Close()
390
391 for commentsRows.Next() {
392 var comment PullComment
393 var commentCreatedStr string
394 err := commentsRows.Scan(
395 &comment.ID,
396 &comment.PullId,
397 &comment.SubmissionId,
398 &comment.RepoAt,
399 &comment.OwnerDid,
400 &comment.CommentAt,
401 &comment.Body,
402 &commentCreatedStr,
403 )
404 if err != nil {
405 return nil, err
406 }
407
408 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
409 if err != nil {
410 return nil, err
411 }
412 comment.Created = commentCreatedTime
413
414 // Add the comment to its submission
415 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
416 submission.Comments = append(submission.Comments, comment)
417 }
418
419 }
420 if err = commentsRows.Err(); err != nil {
421 return nil, err
422 }
423
424 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
425 for _, submission := range submissionsMap {
426 pull.Submissions[submission.RoundNumber] = submission
427 }
428
429 return &pull, nil
430}
431
432func NewPullComment(e Execer, comment *PullComment) (int64, error) {
433 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
434 res, err := e.Exec(
435 query,
436 comment.OwnerDid,
437 comment.RepoAt,
438 comment.SubmissionId,
439 comment.CommentAt,
440 comment.PullId,
441 comment.Body,
442 )
443 if err != nil {
444 return 0, err
445 }
446
447 i, err := res.LastInsertId()
448 if err != nil {
449 return 0, err
450 }
451
452 return i, nil
453}
454
455func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
456 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
457 return err
458}
459
460func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
461 err := SetPullState(e, repoAt, pullId, PullClosed)
462 return err
463}
464
465func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
466 err := SetPullState(e, repoAt, pullId, PullOpen)
467 return err
468}
469
470func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
471 err := SetPullState(e, repoAt, pullId, PullMerged)
472 return err
473}
474
475func ResubmitPull(e Execer, pull *Pull, newPatch string) error {
476 newRoundNumber := len(pull.Submissions)
477 _, err := e.Exec(`
478 insert into pull_submissions (pull_id, repo_at, round_number, patch)
479 values (?, ?, ?, ?)
480 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch)
481
482 return err
483}
484
// PullCount holds the number of a repository's pulls in each state, as
// computed by GetPullCount.
type PullCount struct {
	Open, Merged, Closed int
}
490
491func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
492 row := e.QueryRow(`
493 select
494 count(case when state = ? then 1 end) as open_count,
495 count(case when state = ? then 1 end) as merged_count,
496 count(case when state = ? then 1 end) as closed_count
497 from pulls
498 where repo_at = ?`,
499 PullOpen,
500 PullMerged,
501 PullClosed,
502 repoAt,
503 )
504
505 var count PullCount
506 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
507 return PullCount{0, 0, 0}, err
508 }
509
510 return count, nil
511}