1package db
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.org/core/appview/models"
13 "tangled.org/core/appview/pagination"
14)
15
16func CreateNotification(e Execer, notification *models.Notification) error {
17 query := `
18 INSERT INTO notifications (recipient_did, actor_did, type, entity_type, entity_id, read, repo_id, issue_id, pull_id)
19 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
20 `
21
22 result, err := e.Exec(query,
23 notification.RecipientDid,
24 notification.ActorDid,
25 string(notification.Type),
26 notification.EntityType,
27 notification.EntityId,
28 notification.Read,
29 notification.RepoId,
30 notification.IssueId,
31 notification.PullId,
32 )
33 if err != nil {
34 return fmt.Errorf("failed to create notification: %w", err)
35 }
36
37 id, err := result.LastInsertId()
38 if err != nil {
39 return fmt.Errorf("failed to get notification ID: %w", err)
40 }
41
42 notification.ID = id
43 return nil
44}
45
46// GetNotificationsPaginated retrieves notifications with filters and pagination
47func GetNotificationsPaginated(e Execer, page pagination.Page, filters ...filter) ([]*models.Notification, error) {
48 var conditions []string
49 var args []any
50
51 for _, filter := range filters {
52 conditions = append(conditions, filter.Condition())
53 args = append(args, filter.Arg()...)
54 }
55
56 whereClause := ""
57 if len(conditions) > 0 {
58 whereClause = "WHERE " + conditions[0]
59 for _, condition := range conditions[1:] {
60 whereClause += " AND " + condition
61 }
62 }
63 pageClause := ""
64 if page.Limit > 0 {
65 pageClause = " limit ? offset ? "
66 args = append(args, page.Limit, page.Offset)
67 }
68
69 query := fmt.Sprintf(`
70 select id, recipient_did, actor_did, type, entity_type, entity_id, read, created, repo_id, issue_id, pull_id
71 from notifications
72 %s
73 order by created desc
74 %s
75 `, whereClause, pageClause)
76
77 rows, err := e.QueryContext(context.Background(), query, args...)
78 if err != nil {
79 return nil, fmt.Errorf("failed to query notifications: %w", err)
80 }
81 defer rows.Close()
82
83 var notifications []*models.Notification
84 for rows.Next() {
85 var n models.Notification
86 var typeStr string
87 var createdStr string
88 err := rows.Scan(
89 &n.ID,
90 &n.RecipientDid,
91 &n.ActorDid,
92 &typeStr,
93 &n.EntityType,
94 &n.EntityId,
95 &n.Read,
96 &createdStr,
97 &n.RepoId,
98 &n.IssueId,
99 &n.PullId,
100 )
101 if err != nil {
102 return nil, fmt.Errorf("failed to scan notification: %w", err)
103 }
104 n.Type = models.NotificationType(typeStr)
105 n.Created, err = time.Parse(time.RFC3339, createdStr)
106 if err != nil {
107 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
108 }
109 notifications = append(notifications, &n)
110 }
111
112 return notifications, nil
113}
114
115// GetNotificationsWithEntities retrieves notifications with their related entities
116func GetNotificationsWithEntities(e Execer, page pagination.Page, filters ...filter) ([]*models.NotificationWithEntity, error) {
117 var conditions []string
118 var args []any
119
120 for _, filter := range filters {
121 conditions = append(conditions, filter.Condition())
122 args = append(args, filter.Arg()...)
123 }
124
125 whereClause := ""
126 if len(conditions) > 0 {
127 whereClause = "WHERE " + conditions[0]
128 for _, condition := range conditions[1:] {
129 whereClause += " AND " + condition
130 }
131 }
132
133 query := fmt.Sprintf(`
134 select
135 n.id, n.recipient_did, n.actor_did, n.type, n.entity_type, n.entity_id,
136 n.read, n.created, n.repo_id, n.issue_id, n.pull_id,
137 r.id as r_id, r.did as r_did, r.name as r_name, r.description as r_description, r.website as r_website, r.topics as r_topics,
138 i.id as i_id, i.did as i_did, i.issue_id as i_issue_id, i.title as i_title, i.open as i_open,
139 p.id as p_id, p.owner_did as p_owner_did, p.pull_id as p_pull_id, p.title as p_title, p.state as p_state
140 from notifications n
141 left join repos r on n.repo_id = r.id
142 left join issues i on n.issue_id = i.id
143 left join pulls p on n.pull_id = p.id
144 %s
145 order by n.created desc
146 limit ? offset ?
147 `, whereClause)
148
149 args = append(args, page.Limit, page.Offset)
150
151 rows, err := e.QueryContext(context.Background(), query, args...)
152 if err != nil {
153 return nil, fmt.Errorf("failed to query notifications with entities: %w", err)
154 }
155 defer rows.Close()
156
157 var notifications []*models.NotificationWithEntity
158 for rows.Next() {
159 var n models.Notification
160 var typeStr string
161 var createdStr string
162 var repo models.Repo
163 var issue models.Issue
164 var pull models.Pull
165 var rId, iId, pId sql.NullInt64
166 var rDid, rName, rDescription, rWebsite, rTopicStr sql.NullString
167 var iDid sql.NullString
168 var iIssueId sql.NullInt64
169 var iTitle sql.NullString
170 var iOpen sql.NullBool
171 var pOwnerDid sql.NullString
172 var pPullId sql.NullInt64
173 var pTitle sql.NullString
174 var pState sql.NullInt64
175
176 err := rows.Scan(
177 &n.ID, &n.RecipientDid, &n.ActorDid, &typeStr, &n.EntityType, &n.EntityId,
178 &n.Read, &createdStr, &n.RepoId, &n.IssueId, &n.PullId,
179 &rId, &rDid, &rName, &rDescription, &rWebsite, &rTopicStr,
180 &iId, &iDid, &iIssueId, &iTitle, &iOpen,
181 &pId, &pOwnerDid, &pPullId, &pTitle, &pState,
182 )
183 if err != nil {
184 return nil, fmt.Errorf("failed to scan notification with entities: %w", err)
185 }
186
187 n.Type = models.NotificationType(typeStr)
188 n.Created, err = time.Parse(time.RFC3339, createdStr)
189 if err != nil {
190 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
191 }
192
193 nwe := &models.NotificationWithEntity{Notification: &n}
194
195 // populate repo if present
196 if rId.Valid {
197 repo.Id = rId.Int64
198 if rDid.Valid {
199 repo.Did = rDid.String
200 }
201 if rName.Valid {
202 repo.Name = rName.String
203 }
204 if rDescription.Valid {
205 repo.Description = rDescription.String
206 }
207 if rWebsite.Valid {
208 repo.Website = rWebsite.String
209 }
210 if rTopicStr.Valid {
211 repo.Topics = strings.Fields(rTopicStr.String)
212 }
213 nwe.Repo = &repo
214 }
215
216 // populate issue if present
217 if iId.Valid {
218 issue.Id = iId.Int64
219 if iDid.Valid {
220 issue.Did = iDid.String
221 }
222 if iIssueId.Valid {
223 issue.IssueId = int(iIssueId.Int64)
224 }
225 if iTitle.Valid {
226 issue.Title = iTitle.String
227 }
228 if iOpen.Valid {
229 issue.Open = iOpen.Bool
230 }
231 nwe.Issue = &issue
232 }
233
234 // populate pull if present
235 if pId.Valid {
236 pull.ID = int(pId.Int64)
237 if pOwnerDid.Valid {
238 pull.OwnerDid = pOwnerDid.String
239 }
240 if pPullId.Valid {
241 pull.PullId = int(pPullId.Int64)
242 }
243 if pTitle.Valid {
244 pull.Title = pTitle.String
245 }
246 if pState.Valid {
247 pull.State = models.PullState(pState.Int64)
248 }
249 nwe.Pull = &pull
250 }
251
252 notifications = append(notifications, nwe)
253 }
254
255 return notifications, nil
256}
257
258// GetNotifications retrieves notifications with filters
259func GetNotifications(e Execer, filters ...filter) ([]*models.Notification, error) {
260 return GetNotificationsPaginated(e, pagination.FirstPage(), filters...)
261}
262
263func CountNotifications(e Execer, filters ...filter) (int64, error) {
264 var conditions []string
265 var args []any
266 for _, filter := range filters {
267 conditions = append(conditions, filter.Condition())
268 args = append(args, filter.Arg()...)
269 }
270
271 whereClause := ""
272 if conditions != nil {
273 whereClause = " where " + strings.Join(conditions, " and ")
274 }
275
276 query := fmt.Sprintf(`select count(1) from notifications %s`, whereClause)
277 var count int64
278 err := e.QueryRow(query, args...).Scan(&count)
279
280 if !errors.Is(err, sql.ErrNoRows) && err != nil {
281 return 0, err
282 }
283
284 return count, nil
285}
286
287func MarkNotificationRead(e Execer, notificationID int64, userDID string) error {
288 idFilter := FilterEq("id", notificationID)
289 recipientFilter := FilterEq("recipient_did", userDID)
290
291 query := fmt.Sprintf(`
292 UPDATE notifications
293 SET read = 1
294 WHERE %s AND %s
295 `, idFilter.Condition(), recipientFilter.Condition())
296
297 args := append(idFilter.Arg(), recipientFilter.Arg()...)
298
299 result, err := e.Exec(query, args...)
300 if err != nil {
301 return fmt.Errorf("failed to mark notification as read: %w", err)
302 }
303
304 rowsAffected, err := result.RowsAffected()
305 if err != nil {
306 return fmt.Errorf("failed to get rows affected: %w", err)
307 }
308
309 if rowsAffected == 0 {
310 return fmt.Errorf("notification not found or access denied")
311 }
312
313 return nil
314}
315
316func MarkAllNotificationsRead(e Execer, userDID string) error {
317 recipientFilter := FilterEq("recipient_did", userDID)
318 readFilter := FilterEq("read", 0)
319
320 query := fmt.Sprintf(`
321 UPDATE notifications
322 SET read = 1
323 WHERE %s AND %s
324 `, recipientFilter.Condition(), readFilter.Condition())
325
326 args := append(recipientFilter.Arg(), readFilter.Arg()...)
327
328 _, err := e.Exec(query, args...)
329 if err != nil {
330 return fmt.Errorf("failed to mark all notifications as read: %w", err)
331 }
332
333 return nil
334}
335
336func DeleteNotification(e Execer, notificationID int64, userDID string) error {
337 idFilter := FilterEq("id", notificationID)
338 recipientFilter := FilterEq("recipient_did", userDID)
339
340 query := fmt.Sprintf(`
341 DELETE FROM notifications
342 WHERE %s AND %s
343 `, idFilter.Condition(), recipientFilter.Condition())
344
345 args := append(idFilter.Arg(), recipientFilter.Arg()...)
346
347 result, err := e.Exec(query, args...)
348 if err != nil {
349 return fmt.Errorf("failed to delete notification: %w", err)
350 }
351
352 rowsAffected, err := result.RowsAffected()
353 if err != nil {
354 return fmt.Errorf("failed to get rows affected: %w", err)
355 }
356
357 if rowsAffected == 0 {
358 return fmt.Errorf("notification not found or access denied")
359 }
360
361 return nil
362}
363
364func GetNotificationPreference(e Execer, userDid string) (*models.NotificationPreferences, error) {
365 prefs, err := GetNotificationPreferences(e, FilterEq("user_did", userDid))
366 if err != nil {
367 return nil, err
368 }
369
370 p, ok := prefs[syntax.DID(userDid)]
371 if !ok {
372 return models.DefaultNotificationPreferences(syntax.DID(userDid)), nil
373 }
374
375 return p, nil
376}
377
378func GetNotificationPreferences(e Execer, filters ...filter) (map[syntax.DID]*models.NotificationPreferences, error) {
379 prefsMap := make(map[syntax.DID]*models.NotificationPreferences)
380
381 var conditions []string
382 var args []any
383 for _, filter := range filters {
384 conditions = append(conditions, filter.Condition())
385 args = append(args, filter.Arg()...)
386 }
387
388 whereClause := ""
389 if conditions != nil {
390 whereClause = " where " + strings.Join(conditions, " and ")
391 }
392
393 query := fmt.Sprintf(`
394 select
395 id,
396 user_did,
397 repo_starred,
398 issue_created,
399 issue_commented,
400 pull_created,
401 pull_commented,
402 followed,
403 user_mentioned,
404 pull_merged,
405 issue_closed,
406 email_notifications
407 from
408 notification_preferences
409 %s
410 `, whereClause)
411
412 rows, err := e.Query(query, args...)
413 if err != nil {
414 return nil, err
415 }
416 defer rows.Close()
417
418 for rows.Next() {
419 var prefs models.NotificationPreferences
420 if err := rows.Scan(
421 &prefs.ID,
422 &prefs.UserDid,
423 &prefs.RepoStarred,
424 &prefs.IssueCreated,
425 &prefs.IssueCommented,
426 &prefs.PullCreated,
427 &prefs.PullCommented,
428 &prefs.Followed,
429 &prefs.UserMentioned,
430 &prefs.PullMerged,
431 &prefs.IssueClosed,
432 &prefs.EmailNotifications,
433 ); err != nil {
434 return nil, err
435 }
436
437 prefsMap[prefs.UserDid] = &prefs
438 }
439
440 if err := rows.Err(); err != nil {
441 return nil, err
442 }
443
444 return prefsMap, nil
445}
446
447func (d *DB) UpdateNotificationPreferences(ctx context.Context, prefs *models.NotificationPreferences) error {
448 query := `
449 INSERT OR REPLACE INTO notification_preferences
450 (user_did, repo_starred, issue_created, issue_commented, pull_created,
451 pull_commented, followed, user_mentioned, pull_merged, issue_closed,
452 email_notifications)
453 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
454 `
455
456 result, err := d.DB.ExecContext(ctx, query,
457 prefs.UserDid,
458 prefs.RepoStarred,
459 prefs.IssueCreated,
460 prefs.IssueCommented,
461 prefs.PullCreated,
462 prefs.PullCommented,
463 prefs.Followed,
464 prefs.UserMentioned,
465 prefs.PullMerged,
466 prefs.IssueClosed,
467 prefs.EmailNotifications,
468 )
469 if err != nil {
470 return fmt.Errorf("failed to update notification preferences: %w", err)
471 }
472
473 if prefs.ID == 0 {
474 id, err := result.LastInsertId()
475 if err != nil {
476 return fmt.Errorf("failed to get preferences ID: %w", err)
477 }
478 prefs.ID = id
479 }
480
481 return nil
482}
483
484func (d *DB) ClearOldNotifications(ctx context.Context, olderThan time.Duration) error {
485 cutoff := time.Now().Add(-olderThan)
486 createdFilter := FilterLte("created", cutoff)
487
488 query := fmt.Sprintf(`
489 DELETE FROM notifications
490 WHERE %s
491 `, createdFilter.Condition())
492
493 _, err := d.DB.ExecContext(ctx, query, createdFilter.Arg()...)
494 if err != nil {
495 return fmt.Errorf("failed to cleanup old notifications: %w", err)
496 }
497
498 return nil
499}