// Forked from tangled.org/core
// (Monorepo for Tangled — https://tangled.org)
1package db
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.org/core/appview/models"
13 "tangled.org/core/appview/pagination"
14)
15
16func CreateNotification(e Execer, notification *models.Notification) error {
17 query := `
18 INSERT INTO notifications (recipient_did, actor_did, type, entity_type, entity_id, read, repo_id, issue_id, pull_id)
19 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
20 `
21
22 result, err := e.Exec(query,
23 notification.RecipientDid,
24 notification.ActorDid,
25 string(notification.Type),
26 notification.EntityType,
27 notification.EntityId,
28 notification.Read,
29 notification.RepoId,
30 notification.IssueId,
31 notification.PullId,
32 )
33 if err != nil {
34 return fmt.Errorf("failed to create notification: %w", err)
35 }
36
37 id, err := result.LastInsertId()
38 if err != nil {
39 return fmt.Errorf("failed to get notification ID: %w", err)
40 }
41
42 notification.ID = id
43 return nil
44}
45
46// GetNotificationsPaginated retrieves notifications with filters and pagination
47func GetNotificationsPaginated(e Execer, page pagination.Page, filters ...filter) ([]*models.Notification, error) {
48 var conditions []string
49 var args []any
50
51 for _, filter := range filters {
52 conditions = append(conditions, filter.Condition())
53 args = append(args, filter.Arg()...)
54 }
55
56 whereClause := ""
57 if len(conditions) > 0 {
58 whereClause = "WHERE " + conditions[0]
59 for _, condition := range conditions[1:] {
60 whereClause += " AND " + condition
61 }
62 }
63 pageClause := ""
64 if page.Limit > 0 {
65 pageClause = " limit ? offset ? "
66 args = append(args, page.Limit, page.Offset)
67 }
68
69 query := fmt.Sprintf(`
70 select id, recipient_did, actor_did, type, entity_type, entity_id, read, created, repo_id, issue_id, pull_id
71 from notifications
72 %s
73 order by created desc
74 %s
75 `, whereClause, pageClause)
76
77 rows, err := e.QueryContext(context.Background(), query, args...)
78 if err != nil {
79 return nil, fmt.Errorf("failed to query notifications: %w", err)
80 }
81 defer rows.Close()
82
83 var notifications []*models.Notification
84 for rows.Next() {
85 var n models.Notification
86 var typeStr string
87 var createdStr string
88 err := rows.Scan(
89 &n.ID,
90 &n.RecipientDid,
91 &n.ActorDid,
92 &typeStr,
93 &n.EntityType,
94 &n.EntityId,
95 &n.Read,
96 &createdStr,
97 &n.RepoId,
98 &n.IssueId,
99 &n.PullId,
100 )
101 if err != nil {
102 return nil, fmt.Errorf("failed to scan notification: %w", err)
103 }
104 n.Type = models.NotificationType(typeStr)
105 n.Created, err = time.Parse(time.RFC3339, createdStr)
106 if err != nil {
107 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
108 }
109 notifications = append(notifications, &n)
110 }
111
112 return notifications, nil
113}
114
115// GetNotificationsWithEntities retrieves notifications with their related entities
116func GetNotificationsWithEntities(e Execer, page pagination.Page, filters ...filter) ([]*models.NotificationWithEntity, error) {
117 var conditions []string
118 var args []any
119
120 for _, filter := range filters {
121 conditions = append(conditions, filter.Condition())
122 args = append(args, filter.Arg()...)
123 }
124
125 whereClause := ""
126 if len(conditions) > 0 {
127 whereClause = "WHERE " + conditions[0]
128 for _, condition := range conditions[1:] {
129 whereClause += " AND " + condition
130 }
131 }
132
133 query := fmt.Sprintf(`
134 select
135 n.id, n.recipient_did, n.actor_did, n.type, n.entity_type, n.entity_id,
136 n.read, n.created, n.repo_id, n.issue_id, n.pull_id,
137 r.id as r_id, r.did as r_did, r.name as r_name, r.description as r_description, r.website as r_website, r.topics as r_topics,
138 i.id as i_id, i.did as i_did, i.issue_id as i_issue_id, i.title as i_title, i.open as i_open,
139 p.id as p_id, p.owner_did as p_owner_did, p.pull_id as p_pull_id, p.title as p_title, p.state as p_state
140 from notifications n
141 left join repos r on n.repo_id = r.id
142 left join issues i on n.issue_id = i.id
143 left join pulls p on n.pull_id = p.id
144 %s
145 order by n.created desc
146 limit ? offset ?
147 `, whereClause)
148
149 args = append(args, page.Limit, page.Offset)
150
151 rows, err := e.QueryContext(context.Background(), query, args...)
152 if err != nil {
153 return nil, fmt.Errorf("failed to query notifications with entities: %w", err)
154 }
155 defer rows.Close()
156
157 var notifications []*models.NotificationWithEntity
158 for rows.Next() {
159 var n models.Notification
160 var typeStr string
161 var createdStr string
162 var repo models.Repo
163 var issue models.Issue
164 var pull models.Pull
165 var rId, iId, pId sql.NullInt64
166 var rDid, rName, rDescription, rWebsite, rTopicStr sql.NullString
167 var iDid sql.NullString
168 var iIssueId sql.NullInt64
169 var iTitle sql.NullString
170 var iOpen sql.NullBool
171 var pOwnerDid sql.NullString
172 var pPullId sql.NullInt64
173 var pTitle sql.NullString
174 var pState sql.NullInt64
175
176 err := rows.Scan(
177 &n.ID, &n.RecipientDid, &n.ActorDid, &typeStr, &n.EntityType, &n.EntityId,
178 &n.Read, &createdStr, &n.RepoId, &n.IssueId, &n.PullId,
179 &rId, &rDid, &rName, &rDescription, &rWebsite, &rTopicStr,
180 &iId, &iDid, &iIssueId, &iTitle, &iOpen,
181 &pId, &pOwnerDid, &pPullId, &pTitle, &pState,
182 )
183 if err != nil {
184 return nil, fmt.Errorf("failed to scan notification with entities: %w", err)
185 }
186
187 n.Type = models.NotificationType(typeStr)
188 n.Created, err = time.Parse(time.RFC3339, createdStr)
189 if err != nil {
190 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
191 }
192
193 nwe := &models.NotificationWithEntity{Notification: &n}
194
195 // populate repo if present
196 if rId.Valid {
197 repo.Id = rId.Int64
198 if rDid.Valid {
199 repo.Did = rDid.String
200 }
201 if rName.Valid {
202 repo.Name = rName.String
203 }
204 if rDescription.Valid {
205 repo.Description = rDescription.String
206 }
207 if rWebsite.Valid {
208 repo.Website = rWebsite.String
209 }
210 if rTopicStr.Valid {
211 repo.Topics = strings.Fields(rTopicStr.String)
212 }
213 nwe.Repo = &repo
214 }
215
216 // populate issue if present
217 if iId.Valid {
218 issue.Id = iId.Int64
219 if iDid.Valid {
220 issue.Did = iDid.String
221 }
222 if iIssueId.Valid {
223 issue.IssueId = int(iIssueId.Int64)
224 }
225 if iTitle.Valid {
226 issue.Title = iTitle.String
227 }
228 if iOpen.Valid {
229 issue.Open = iOpen.Bool
230 }
231 nwe.Issue = &issue
232 }
233
234 // populate pull if present
235 if pId.Valid {
236 pull.ID = int(pId.Int64)
237 if pOwnerDid.Valid {
238 pull.OwnerDid = pOwnerDid.String
239 }
240 if pPullId.Valid {
241 pull.PullId = int(pPullId.Int64)
242 }
243 if pTitle.Valid {
244 pull.Title = pTitle.String
245 }
246 if pState.Valid {
247 pull.State = models.PullState(pState.Int64)
248 }
249 nwe.Pull = &pull
250 }
251
252 notifications = append(notifications, nwe)
253 }
254
255 return notifications, nil
256}
257
258// GetNotifications retrieves notifications with filters
259func GetNotifications(e Execer, filters ...filter) ([]*models.Notification, error) {
260 return GetNotificationsPaginated(e, pagination.FirstPage(), filters...)
261}
262
263func CountNotifications(e Execer, filters ...filter) (int64, error) {
264 var conditions []string
265 var args []any
266 for _, filter := range filters {
267 conditions = append(conditions, filter.Condition())
268 args = append(args, filter.Arg()...)
269 }
270
271 whereClause := ""
272 if conditions != nil {
273 whereClause = " where " + strings.Join(conditions, " and ")
274 }
275
276 query := fmt.Sprintf(`select count(1) from notifications %s`, whereClause)
277 var count int64
278 err := e.QueryRow(query, args...).Scan(&count)
279
280 if !errors.Is(err, sql.ErrNoRows) && err != nil {
281 return 0, err
282 }
283
284 return count, nil
285}
286
287func MarkNotificationRead(e Execer, notificationID int64, userDID string) error {
288 idFilter := FilterEq("id", notificationID)
289 recipientFilter := FilterEq("recipient_did", userDID)
290
291 query := fmt.Sprintf(`
292 UPDATE notifications
293 SET read = 1
294 WHERE %s AND %s
295 `, idFilter.Condition(), recipientFilter.Condition())
296
297 args := append(idFilter.Arg(), recipientFilter.Arg()...)
298
299 result, err := e.Exec(query, args...)
300 if err != nil {
301 return fmt.Errorf("failed to mark notification as read: %w", err)
302 }
303
304 rowsAffected, err := result.RowsAffected()
305 if err != nil {
306 return fmt.Errorf("failed to get rows affected: %w", err)
307 }
308
309 if rowsAffected == 0 {
310 return fmt.Errorf("notification not found or access denied")
311 }
312
313 return nil
314}
315
316func MarkAllNotificationsRead(e Execer, userDID string) error {
317 recipientFilter := FilterEq("recipient_did", userDID)
318 readFilter := FilterEq("read", 0)
319
320 query := fmt.Sprintf(`
321 UPDATE notifications
322 SET read = 1
323 WHERE %s AND %s
324 `, recipientFilter.Condition(), readFilter.Condition())
325
326 args := append(recipientFilter.Arg(), readFilter.Arg()...)
327
328 _, err := e.Exec(query, args...)
329 if err != nil {
330 return fmt.Errorf("failed to mark all notifications as read: %w", err)
331 }
332
333 return nil
334}
335
336func DeleteNotification(e Execer, notificationID int64, userDID string) error {
337 idFilter := FilterEq("id", notificationID)
338 recipientFilter := FilterEq("recipient_did", userDID)
339
340 query := fmt.Sprintf(`
341 DELETE FROM notifications
342 WHERE %s AND %s
343 `, idFilter.Condition(), recipientFilter.Condition())
344
345 args := append(idFilter.Arg(), recipientFilter.Arg()...)
346
347 result, err := e.Exec(query, args...)
348 if err != nil {
349 return fmt.Errorf("failed to delete notification: %w", err)
350 }
351
352 rowsAffected, err := result.RowsAffected()
353 if err != nil {
354 return fmt.Errorf("failed to get rows affected: %w", err)
355 }
356
357 if rowsAffected == 0 {
358 return fmt.Errorf("notification not found or access denied")
359 }
360
361 return nil
362}
363
364func GetNotificationPreference(e Execer, userDid string) (*models.NotificationPreferences, error) {
365 prefs, err := GetNotificationPreferences(e, FilterEq("user_did", userDid))
366 if err != nil {
367 return nil, err
368 }
369
370 p, ok := prefs[syntax.DID(userDid)]
371 if !ok {
372 return models.DefaultNotificationPreferences(syntax.DID(userDid)), nil
373 }
374
375 return p, nil
376}
377
378func GetNotificationPreferences(e Execer, filters ...filter) (map[syntax.DID]*models.NotificationPreferences, error) {
379 prefsMap := make(map[syntax.DID]*models.NotificationPreferences)
380
381 var conditions []string
382 var args []any
383 for _, filter := range filters {
384 conditions = append(conditions, filter.Condition())
385 args = append(args, filter.Arg()...)
386 }
387
388 whereClause := ""
389 if conditions != nil {
390 whereClause = " where " + strings.Join(conditions, " and ")
391 }
392
393 query := fmt.Sprintf(`
394 select
395 id,
396 user_did,
397 repo_starred,
398 issue_created,
399 issue_commented,
400 pull_created,
401 pull_commented,
402 followed,
403 pull_merged,
404 issue_closed,
405 email_notifications
406 from
407 notification_preferences
408 %s
409 `, whereClause)
410
411 rows, err := e.Query(query, args...)
412 if err != nil {
413 return nil, err
414 }
415 defer rows.Close()
416
417 for rows.Next() {
418 var prefs models.NotificationPreferences
419 if err := rows.Scan(
420 &prefs.ID,
421 &prefs.UserDid,
422 &prefs.RepoStarred,
423 &prefs.IssueCreated,
424 &prefs.IssueCommented,
425 &prefs.PullCreated,
426 &prefs.PullCommented,
427 &prefs.Followed,
428 &prefs.PullMerged,
429 &prefs.IssueClosed,
430 &prefs.EmailNotifications,
431 ); err != nil {
432 return nil, err
433 }
434
435 prefsMap[prefs.UserDid] = &prefs
436 }
437
438 if err := rows.Err(); err != nil {
439 return nil, err
440 }
441
442 return prefsMap, nil
443}
444
445func (d *DB) UpdateNotificationPreferences(ctx context.Context, prefs *models.NotificationPreferences) error {
446 query := `
447 INSERT OR REPLACE INTO notification_preferences
448 (user_did, repo_starred, issue_created, issue_commented, pull_created,
449 pull_commented, followed, pull_merged, issue_closed, email_notifications)
450 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
451 `
452
453 result, err := d.DB.ExecContext(ctx, query,
454 prefs.UserDid,
455 prefs.RepoStarred,
456 prefs.IssueCreated,
457 prefs.IssueCommented,
458 prefs.PullCreated,
459 prefs.PullCommented,
460 prefs.Followed,
461 prefs.PullMerged,
462 prefs.IssueClosed,
463 prefs.EmailNotifications,
464 )
465 if err != nil {
466 return fmt.Errorf("failed to update notification preferences: %w", err)
467 }
468
469 if prefs.ID == 0 {
470 id, err := result.LastInsertId()
471 if err != nil {
472 return fmt.Errorf("failed to get preferences ID: %w", err)
473 }
474 prefs.ID = id
475 }
476
477 return nil
478}
479
480func (d *DB) ClearOldNotifications(ctx context.Context, olderThan time.Duration) error {
481 cutoff := time.Now().Add(-olderThan)
482 createdFilter := FilterLte("created", cutoff)
483
484 query := fmt.Sprintf(`
485 DELETE FROM notifications
486 WHERE %s
487 `, createdFilter.Condition())
488
489 _, err := d.DB.ExecContext(ctx, query, createdFilter.Arg()...)
490 if err != nil {
491 return fmt.Errorf("failed to cleanup old notifications: %w", err)
492 }
493
494 return nil
495}