forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package db
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.org/core/appview/models"
13 "tangled.org/core/appview/pagination"
14 "tangled.org/core/orm"
15)
16
17func CreateNotification(e Execer, notification *models.Notification) error {
18 query := `
19 INSERT INTO notifications (recipient_did, actor_did, type, entity_type, entity_id, read, repo_id, issue_id, pull_id)
20 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
21 `
22
23 result, err := e.Exec(query,
24 notification.RecipientDid,
25 notification.ActorDid,
26 string(notification.Type),
27 notification.EntityType,
28 notification.EntityId,
29 notification.Read,
30 notification.RepoId,
31 notification.IssueId,
32 notification.PullId,
33 )
34 if err != nil {
35 return fmt.Errorf("failed to create notification: %w", err)
36 }
37
38 id, err := result.LastInsertId()
39 if err != nil {
40 return fmt.Errorf("failed to get notification ID: %w", err)
41 }
42
43 notification.ID = id
44 return nil
45}
46
47// GetNotificationsPaginated retrieves notifications with filters and pagination
48func GetNotificationsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Notification, error) {
49 var conditions []string
50 var args []any
51
52 for _, filter := range filters {
53 conditions = append(conditions, filter.Condition())
54 args = append(args, filter.Arg()...)
55 }
56
57 whereClause := ""
58 if len(conditions) > 0 {
59 whereClause = "WHERE " + conditions[0]
60 for _, condition := range conditions[1:] {
61 whereClause += " AND " + condition
62 }
63 }
64 pageClause := ""
65 if page.Limit > 0 {
66 pageClause = " limit ? offset ? "
67 args = append(args, page.Limit, page.Offset)
68 }
69
70 query := fmt.Sprintf(`
71 select id, recipient_did, actor_did, type, entity_type, entity_id, read, created, repo_id, issue_id, pull_id
72 from notifications
73 %s
74 order by created desc
75 %s
76 `, whereClause, pageClause)
77
78 rows, err := e.QueryContext(context.Background(), query, args...)
79 if err != nil {
80 return nil, fmt.Errorf("failed to query notifications: %w", err)
81 }
82 defer rows.Close()
83
84 var notifications []*models.Notification
85 for rows.Next() {
86 var n models.Notification
87 var typeStr string
88 var createdStr string
89 err := rows.Scan(
90 &n.ID,
91 &n.RecipientDid,
92 &n.ActorDid,
93 &typeStr,
94 &n.EntityType,
95 &n.EntityId,
96 &n.Read,
97 &createdStr,
98 &n.RepoId,
99 &n.IssueId,
100 &n.PullId,
101 )
102 if err != nil {
103 return nil, fmt.Errorf("failed to scan notification: %w", err)
104 }
105 n.Type = models.NotificationType(typeStr)
106 n.Created, err = time.Parse(time.RFC3339, createdStr)
107 if err != nil {
108 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
109 }
110 notifications = append(notifications, &n)
111 }
112
113 return notifications, nil
114}
115
116// GetNotificationsWithEntities retrieves notifications with their related entities
117func GetNotificationsWithEntities(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.NotificationWithEntity, error) {
118 var conditions []string
119 var args []any
120
121 for _, filter := range filters {
122 conditions = append(conditions, filter.Condition())
123 args = append(args, filter.Arg()...)
124 }
125
126 whereClause := ""
127 if len(conditions) > 0 {
128 whereClause = "WHERE " + conditions[0]
129 for _, condition := range conditions[1:] {
130 whereClause += " AND " + condition
131 }
132 }
133
134 query := fmt.Sprintf(`
135 select
136 n.id, n.recipient_did, n.actor_did, n.type, n.entity_type, n.entity_id,
137 n.read, n.created, n.repo_id, n.issue_id, n.pull_id,
138 r.id as r_id, r.did as r_did, r.name as r_name, r.description as r_description, r.website as r_website, r.topics as r_topics,
139 i.id as i_id, i.did as i_did, i.issue_id as i_issue_id, i.title as i_title, i.open as i_open,
140 p.id as p_id, p.owner_did as p_owner_did, p.pull_id as p_pull_id, p.title as p_title, p.state as p_state
141 from notifications n
142 left join repos r on n.repo_id = r.id
143 left join issues i on n.issue_id = i.id
144 left join pulls p on n.pull_id = p.id
145 %s
146 order by n.created desc
147 limit ? offset ?
148 `, whereClause)
149
150 args = append(args, page.Limit, page.Offset)
151
152 rows, err := e.QueryContext(context.Background(), query, args...)
153 if err != nil {
154 return nil, fmt.Errorf("failed to query notifications with entities: %w", err)
155 }
156 defer rows.Close()
157
158 var notifications []*models.NotificationWithEntity
159 for rows.Next() {
160 var n models.Notification
161 var typeStr string
162 var createdStr string
163 var repo models.Repo
164 var issue models.Issue
165 var pull models.Pull
166 var rId, iId, pId sql.NullInt64
167 var rDid, rName, rDescription, rWebsite, rTopicStr sql.NullString
168 var iDid sql.NullString
169 var iIssueId sql.NullInt64
170 var iTitle sql.NullString
171 var iOpen sql.NullBool
172 var pOwnerDid sql.NullString
173 var pPullId sql.NullInt64
174 var pTitle sql.NullString
175 var pState sql.NullInt64
176
177 err := rows.Scan(
178 &n.ID, &n.RecipientDid, &n.ActorDid, &typeStr, &n.EntityType, &n.EntityId,
179 &n.Read, &createdStr, &n.RepoId, &n.IssueId, &n.PullId,
180 &rId, &rDid, &rName, &rDescription, &rWebsite, &rTopicStr,
181 &iId, &iDid, &iIssueId, &iTitle, &iOpen,
182 &pId, &pOwnerDid, &pPullId, &pTitle, &pState,
183 )
184 if err != nil {
185 return nil, fmt.Errorf("failed to scan notification with entities: %w", err)
186 }
187
188 n.Type = models.NotificationType(typeStr)
189 n.Created, err = time.Parse(time.RFC3339, createdStr)
190 if err != nil {
191 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
192 }
193
194 nwe := &models.NotificationWithEntity{Notification: &n}
195
196 // populate repo if present
197 if rId.Valid {
198 repo.Id = rId.Int64
199 if rDid.Valid {
200 repo.Did = rDid.String
201 }
202 if rName.Valid {
203 repo.Name = rName.String
204 }
205 if rDescription.Valid {
206 repo.Description = rDescription.String
207 }
208 if rWebsite.Valid {
209 repo.Website = rWebsite.String
210 }
211 if rTopicStr.Valid {
212 repo.Topics = strings.Fields(rTopicStr.String)
213 }
214 nwe.Repo = &repo
215 }
216
217 // populate issue if present
218 if iId.Valid {
219 issue.Id = iId.Int64
220 if iDid.Valid {
221 issue.Did = iDid.String
222 }
223 if iIssueId.Valid {
224 issue.IssueId = int(iIssueId.Int64)
225 }
226 if iTitle.Valid {
227 issue.Title = iTitle.String
228 }
229 if iOpen.Valid {
230 issue.Open = iOpen.Bool
231 }
232 nwe.Issue = &issue
233 }
234
235 // populate pull if present
236 if pId.Valid {
237 pull.ID = int(pId.Int64)
238 if pOwnerDid.Valid {
239 pull.OwnerDid = pOwnerDid.String
240 }
241 if pPullId.Valid {
242 pull.PullId = int(pPullId.Int64)
243 }
244 if pTitle.Valid {
245 pull.Title = pTitle.String
246 }
247 if pState.Valid {
248 pull.State = models.PullState(pState.Int64)
249 }
250 nwe.Pull = &pull
251 }
252
253 notifications = append(notifications, nwe)
254 }
255
256 return notifications, nil
257}
258
259// GetNotifications retrieves notifications with filters
260func GetNotifications(e Execer, filters ...orm.Filter) ([]*models.Notification, error) {
261 return GetNotificationsPaginated(e, pagination.FirstPage(), filters...)
262}
263
264func CountNotifications(e Execer, filters ...orm.Filter) (int64, error) {
265 var conditions []string
266 var args []any
267 for _, filter := range filters {
268 conditions = append(conditions, filter.Condition())
269 args = append(args, filter.Arg()...)
270 }
271
272 whereClause := ""
273 if conditions != nil {
274 whereClause = " where " + strings.Join(conditions, " and ")
275 }
276
277 query := fmt.Sprintf(`select count(1) from notifications %s`, whereClause)
278 var count int64
279 err := e.QueryRow(query, args...).Scan(&count)
280
281 if !errors.Is(err, sql.ErrNoRows) && err != nil {
282 return 0, err
283 }
284
285 return count, nil
286}
287
288func MarkNotificationRead(e Execer, notificationID int64, userDID string) error {
289 idFilter := orm.FilterEq("id", notificationID)
290 recipientFilter := orm.FilterEq("recipient_did", userDID)
291
292 query := fmt.Sprintf(`
293 UPDATE notifications
294 SET read = 1
295 WHERE %s AND %s
296 `, idFilter.Condition(), recipientFilter.Condition())
297
298 args := append(idFilter.Arg(), recipientFilter.Arg()...)
299
300 result, err := e.Exec(query, args...)
301 if err != nil {
302 return fmt.Errorf("failed to mark notification as read: %w", err)
303 }
304
305 rowsAffected, err := result.RowsAffected()
306 if err != nil {
307 return fmt.Errorf("failed to get rows affected: %w", err)
308 }
309
310 if rowsAffected == 0 {
311 return fmt.Errorf("notification not found or access denied")
312 }
313
314 return nil
315}
316
317func MarkAllNotificationsRead(e Execer, userDID string) error {
318 recipientFilter := orm.FilterEq("recipient_did", userDID)
319 readFilter := orm.FilterEq("read", 0)
320
321 query := fmt.Sprintf(`
322 UPDATE notifications
323 SET read = 1
324 WHERE %s AND %s
325 `, recipientFilter.Condition(), readFilter.Condition())
326
327 args := append(recipientFilter.Arg(), readFilter.Arg()...)
328
329 _, err := e.Exec(query, args...)
330 if err != nil {
331 return fmt.Errorf("failed to mark all notifications as read: %w", err)
332 }
333
334 return nil
335}
336
337func DeleteNotification(e Execer, notificationID int64, userDID string) error {
338 idFilter := orm.FilterEq("id", notificationID)
339 recipientFilter := orm.FilterEq("recipient_did", userDID)
340
341 query := fmt.Sprintf(`
342 DELETE FROM notifications
343 WHERE %s AND %s
344 `, idFilter.Condition(), recipientFilter.Condition())
345
346 args := append(idFilter.Arg(), recipientFilter.Arg()...)
347
348 result, err := e.Exec(query, args...)
349 if err != nil {
350 return fmt.Errorf("failed to delete notification: %w", err)
351 }
352
353 rowsAffected, err := result.RowsAffected()
354 if err != nil {
355 return fmt.Errorf("failed to get rows affected: %w", err)
356 }
357
358 if rowsAffected == 0 {
359 return fmt.Errorf("notification not found or access denied")
360 }
361
362 return nil
363}
364
365func GetNotificationPreference(e Execer, userDid string) (*models.NotificationPreferences, error) {
366 prefs, err := GetNotificationPreferences(e, orm.FilterEq("user_did", userDid))
367 if err != nil {
368 return nil, err
369 }
370
371 p, ok := prefs[syntax.DID(userDid)]
372 if !ok {
373 return models.DefaultNotificationPreferences(syntax.DID(userDid)), nil
374 }
375
376 return p, nil
377}
378
379func GetNotificationPreferences(e Execer, filters ...orm.Filter) (map[syntax.DID]*models.NotificationPreferences, error) {
380 prefsMap := make(map[syntax.DID]*models.NotificationPreferences)
381
382 var conditions []string
383 var args []any
384 for _, filter := range filters {
385 conditions = append(conditions, filter.Condition())
386 args = append(args, filter.Arg()...)
387 }
388
389 whereClause := ""
390 if conditions != nil {
391 whereClause = " where " + strings.Join(conditions, " and ")
392 }
393
394 query := fmt.Sprintf(`
395 select
396 id,
397 user_did,
398 repo_starred,
399 issue_created,
400 issue_commented,
401 pull_created,
402 pull_commented,
403 followed,
404 user_mentioned,
405 pull_merged,
406 issue_closed,
407 email_notifications
408 from
409 notification_preferences
410 %s
411 `, whereClause)
412
413 rows, err := e.Query(query, args...)
414 if err != nil {
415 return nil, err
416 }
417 defer rows.Close()
418
419 for rows.Next() {
420 var prefs models.NotificationPreferences
421 if err := rows.Scan(
422 &prefs.ID,
423 &prefs.UserDid,
424 &prefs.RepoStarred,
425 &prefs.IssueCreated,
426 &prefs.IssueCommented,
427 &prefs.PullCreated,
428 &prefs.PullCommented,
429 &prefs.Followed,
430 &prefs.UserMentioned,
431 &prefs.PullMerged,
432 &prefs.IssueClosed,
433 &prefs.EmailNotifications,
434 ); err != nil {
435 return nil, err
436 }
437
438 prefsMap[prefs.UserDid] = &prefs
439 }
440
441 if err := rows.Err(); err != nil {
442 return nil, err
443 }
444
445 return prefsMap, nil
446}
447
448func (d *DB) UpdateNotificationPreferences(ctx context.Context, prefs *models.NotificationPreferences) error {
449 query := `
450 INSERT OR REPLACE INTO notification_preferences
451 (user_did, repo_starred, issue_created, issue_commented, pull_created,
452 pull_commented, followed, user_mentioned, pull_merged, issue_closed,
453 email_notifications)
454 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
455 `
456
457 result, err := d.DB.ExecContext(ctx, query,
458 prefs.UserDid,
459 prefs.RepoStarred,
460 prefs.IssueCreated,
461 prefs.IssueCommented,
462 prefs.PullCreated,
463 prefs.PullCommented,
464 prefs.Followed,
465 prefs.UserMentioned,
466 prefs.PullMerged,
467 prefs.IssueClosed,
468 prefs.EmailNotifications,
469 )
470 if err != nil {
471 return fmt.Errorf("failed to update notification preferences: %w", err)
472 }
473
474 if prefs.ID == 0 {
475 id, err := result.LastInsertId()
476 if err != nil {
477 return fmt.Errorf("failed to get preferences ID: %w", err)
478 }
479 prefs.ID = id
480 }
481
482 return nil
483}
484
485func (d *DB) ClearOldNotifications(ctx context.Context, olderThan time.Duration) error {
486 cutoff := time.Now().Add(-olderThan)
487 createdFilter := orm.FilterLte("created", cutoff)
488
489 query := fmt.Sprintf(`
490 DELETE FROM notifications
491 WHERE %s
492 `, createdFilter.Condition())
493
494 _, err := d.DB.ExecContext(ctx, query, createdFilter.Arg()...)
495 if err != nil {
496 return fmt.Errorf("failed to cleanup old notifications: %w", err)
497 }
498
499 return nil
500}