A community-based topic aggregation platform built on atproto
1package postgres
2
3import (
4 "context"
5 "database/sql"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "strconv"
10 "strings"
11 "time"
12
13 "Coves/internal/core/communityFeeds"
14 "Coves/internal/core/posts"
15)
16
// postgresFeedRepo is the PostgreSQL-backed community feed repository.
// Its only state is the shared database handle used by every query.
type postgresFeedRepo struct {
	db *sql.DB // handle (connection pool) all feed queries run against
}
20
21// sortClauses maps sort types to safe SQL ORDER BY clauses
22// This whitelist prevents SQL injection via dynamic ORDER BY construction
23var sortClauses = map[string]string{
24 "hot": `(p.score / POWER(EXTRACT(EPOCH FROM (NOW() - p.created_at))/3600 + 2, 1.5)) DESC, p.created_at DESC, p.uri DESC`,
25 "top": `p.score DESC, p.created_at DESC, p.uri DESC`,
26 "new": `p.created_at DESC, p.uri DESC`,
27}
28
// hotRankExpression is the SQL expression for a post's hot rank:
//
//	p.score / (hours_since_created_at + 2) ^ 1.5
//
// It is evaluated against NOW(), so a post's rank changes as time passes. That
// is intended: it is what lets older posts age out of the hot feed. A cursor
// captured on one request and replayed on a later one may see slightly
// different ranks, which can cause minor reordering across pages. Comparing
// cursors against this expression (rather than raw p.score) keeps pagination
// aligned with the hot ORDER BY. Keep this in sync with the "hot" entry of
// sortClauses.
const hotRankExpression = "(p.score / POWER(" +
	"EXTRACT(EPOCH FROM (NOW() - p.created_at))/3600" + // post age in hours
	" + 2, 1.5))" // +2 hour offset, gravity exponent 1.5
34
35// NewCommunityFeedRepository creates a new PostgreSQL feed repository
36func NewCommunityFeedRepository(db *sql.DB) communityFeeds.Repository {
37 return &postgresFeedRepo{db: db}
38}
39
40// GetCommunityFeed retrieves posts from a community with sorting and pagination
41// Single query with JOINs for optimal performance
42func (r *postgresFeedRepo) GetCommunityFeed(ctx context.Context, req communityFeeds.GetCommunityFeedRequest) ([]*communityFeeds.FeedViewPost, *string, error) {
43 // Build ORDER BY clause based on sort type
44 orderBy, timeFilter := r.buildSortClause(req.Sort, req.Timeframe)
45
46 // Build cursor filter for pagination
47 cursorFilter, cursorValues, err := r.parseCursor(req.Cursor, req.Sort)
48 if err != nil {
49 return nil, nil, communityFeeds.ErrInvalidCursor
50 }
51
52 // Build the main query
53 // For hot sort, we need to compute and return the hot_rank for cursor building
54 var selectClause string
55 if req.Sort == "hot" {
56 selectClause = fmt.Sprintf(`
57 SELECT
58 p.uri, p.cid, p.rkey,
59 p.author_did, u.handle as author_handle,
60 p.community_did, c.name as community_name, c.avatar_cid as community_avatar,
61 p.title, p.content, p.content_facets, p.embed, p.content_labels,
62 p.created_at, p.edited_at, p.indexed_at,
63 p.upvote_count, p.downvote_count, p.score, p.comment_count,
64 %s as hot_rank
65 FROM posts p`, hotRankExpression)
66 } else {
67 selectClause = `
68 SELECT
69 p.uri, p.cid, p.rkey,
70 p.author_did, u.handle as author_handle,
71 p.community_did, c.name as community_name, c.avatar_cid as community_avatar,
72 p.title, p.content, p.content_facets, p.embed, p.content_labels,
73 p.created_at, p.edited_at, p.indexed_at,
74 p.upvote_count, p.downvote_count, p.score, p.comment_count,
75 NULL::numeric as hot_rank
76 FROM posts p`
77 }
78
79 query := fmt.Sprintf(`
80 %s
81 INNER JOIN users u ON p.author_did = u.did
82 INNER JOIN communities c ON p.community_did = c.did
83 WHERE p.community_did = $1
84 AND p.deleted_at IS NULL
85 %s
86 %s
87 ORDER BY %s
88 LIMIT $2
89 `, selectClause, timeFilter, cursorFilter, orderBy)
90
91 // Prepare query arguments
92 args := []interface{}{req.Community, req.Limit + 1} // +1 to check for next page
93 args = append(args, cursorValues...)
94
95 // Execute query
96 rows, err := r.db.QueryContext(ctx, query, args...)
97 if err != nil {
98 return nil, nil, fmt.Errorf("failed to query community feed: %w", err)
99 }
100 defer func() {
101 if err := rows.Close(); err != nil {
102 // Log close errors (non-fatal but worth noting)
103 fmt.Printf("Warning: failed to close rows: %v\n", err)
104 }
105 }()
106
107 // Scan results
108 var feedPosts []*communityFeeds.FeedViewPost
109 var hotRanks []float64 // Store hot ranks for cursor building
110 for rows.Next() {
111 feedPost, hotRank, err := r.scanFeedViewPost(rows)
112 if err != nil {
113 return nil, nil, fmt.Errorf("failed to scan feed post: %w", err)
114 }
115 feedPosts = append(feedPosts, feedPost)
116 hotRanks = append(hotRanks, hotRank)
117 }
118
119 if err := rows.Err(); err != nil {
120 return nil, nil, fmt.Errorf("error iterating feed results: %w", err)
121 }
122
123 // Handle pagination cursor
124 var cursor *string
125 if len(feedPosts) > req.Limit && req.Limit > 0 {
126 feedPosts = feedPosts[:req.Limit]
127 hotRanks = hotRanks[:req.Limit]
128 lastPost := feedPosts[len(feedPosts)-1].Post
129 lastHotRank := hotRanks[len(hotRanks)-1]
130 cursorStr := r.buildCursor(lastPost, req.Sort, lastHotRank)
131 cursor = &cursorStr
132 }
133
134 return feedPosts, cursor, nil
135}
136
137// buildSortClause returns the ORDER BY SQL and optional time filter
138func (r *postgresFeedRepo) buildSortClause(sort, timeframe string) (string, string) {
139 // Use whitelist map for ORDER BY clause (defense-in-depth against SQL injection)
140 orderBy := sortClauses[sort]
141 if orderBy == "" {
142 orderBy = sortClauses["hot"] // safe default
143 }
144
145 // Add time filter for "top" sort
146 var timeFilter string
147 if sort == "top" {
148 timeFilter = r.buildTimeFilter(timeframe)
149 }
150
151 return orderBy, timeFilter
152}
153
154// buildTimeFilter returns SQL filter for timeframe
155func (r *postgresFeedRepo) buildTimeFilter(timeframe string) string {
156 if timeframe == "" || timeframe == "all" {
157 return ""
158 }
159
160 var interval string
161 switch timeframe {
162 case "hour":
163 interval = "1 hour"
164 case "day":
165 interval = "1 day"
166 case "week":
167 interval = "1 week"
168 case "month":
169 interval = "1 month"
170 case "year":
171 interval = "1 year"
172 default:
173 return ""
174 }
175
176 return fmt.Sprintf("AND p.created_at > NOW() - INTERVAL '%s'", interval)
177}
178
179// parseCursor decodes pagination cursor
180func (r *postgresFeedRepo) parseCursor(cursor *string, sort string) (string, []interface{}, error) {
181 if cursor == nil || *cursor == "" {
182 return "", nil, nil
183 }
184
185 // Decode base64 cursor
186 decoded, err := base64.StdEncoding.DecodeString(*cursor)
187 if err != nil {
188 return "", nil, fmt.Errorf("invalid cursor encoding")
189 }
190
191 // Parse cursor based on sort type using :: delimiter (Bluesky convention)
192 parts := strings.Split(string(decoded), "::")
193
194 switch sort {
195 case "new":
196 // Cursor format: timestamp::uri
197 if len(parts) != 2 {
198 return "", nil, fmt.Errorf("invalid cursor format")
199 }
200
201 createdAt := parts[0]
202 uri := parts[1]
203
204 // Validate timestamp format
205 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
206 return "", nil, fmt.Errorf("invalid cursor timestamp")
207 }
208
209 // Validate URI format (must be AT-URI)
210 if !strings.HasPrefix(uri, "at://") {
211 return "", nil, fmt.Errorf("invalid cursor URI")
212 }
213
214 filter := `AND (p.created_at < $3 OR (p.created_at = $3 AND p.uri < $4))`
215 return filter, []interface{}{createdAt, uri}, nil
216
217 case "top":
218 // Cursor format: score::timestamp::uri
219 if len(parts) != 3 {
220 return "", nil, fmt.Errorf("invalid cursor format for %s sort", sort)
221 }
222
223 scoreStr := parts[0]
224 createdAt := parts[1]
225 uri := parts[2]
226
227 // Validate score is numeric
228 score := 0
229 if _, err := fmt.Sscanf(scoreStr, "%d", &score); err != nil {
230 return "", nil, fmt.Errorf("invalid cursor score")
231 }
232
233 // Validate timestamp format
234 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
235 return "", nil, fmt.Errorf("invalid cursor timestamp")
236 }
237
238 // Validate URI format (must be AT-URI)
239 if !strings.HasPrefix(uri, "at://") {
240 return "", nil, fmt.Errorf("invalid cursor URI")
241 }
242
243 filter := `AND (p.score < $3 OR (p.score = $3 AND p.created_at < $4) OR (p.score = $3 AND p.created_at = $4 AND p.uri < $5))`
244 return filter, []interface{}{score, createdAt, uri}, nil
245
246 case "hot":
247 // Cursor format: hot_rank::timestamp::uri
248 // CRITICAL: Must use computed hot_rank, not raw score, to prevent pagination bugs
249 if len(parts) != 3 {
250 return "", nil, fmt.Errorf("invalid cursor format for hot sort")
251 }
252
253 hotRankStr := parts[0]
254 createdAt := parts[1]
255 uri := parts[2]
256
257 // Validate hot_rank is numeric (float)
258 hotRank := 0.0
259 if _, err := fmt.Sscanf(hotRankStr, "%f", &hotRank); err != nil {
260 return "", nil, fmt.Errorf("invalid cursor hot rank")
261 }
262
263 // Validate timestamp format
264 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
265 return "", nil, fmt.Errorf("invalid cursor timestamp")
266 }
267
268 // Validate URI format (must be AT-URI)
269 if !strings.HasPrefix(uri, "at://") {
270 return "", nil, fmt.Errorf("invalid cursor URI")
271 }
272
273 // CRITICAL: Compare against the computed hot_rank expression, not p.score
274 // This prevents dropping posts with higher raw scores but lower hot ranks
275 //
276 // NOTE: We exclude the exact cursor post by URI to handle time drift in hot_rank
277 // (hot_rank changes with NOW(), so the same post may have different ranks over time)
278 filter := fmt.Sprintf(`AND ((%s < $3 OR (%s = $3 AND p.created_at < $4) OR (%s = $3 AND p.created_at = $4 AND p.uri < $5)) AND p.uri != $6)`,
279 hotRankExpression, hotRankExpression, hotRankExpression)
280 return filter, []interface{}{hotRank, createdAt, uri, uri}, nil
281
282 default:
283 return "", nil, nil
284 }
285}
286
287// buildCursor creates pagination cursor from last post
288func (r *postgresFeedRepo) buildCursor(post *posts.PostView, sort string, hotRank float64) string {
289 var cursorStr string
290 // Use :: as delimiter following Bluesky convention
291 // Safe because :: doesn't appear in ISO timestamps or AT-URIs
292 const delimiter = "::"
293
294 switch sort {
295 case "new":
296 // Format: timestamp::uri (following Bluesky pattern)
297 cursorStr = fmt.Sprintf("%s%s%s", post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
298
299 case "top":
300 // Format: score::timestamp::uri
301 score := 0
302 if post.Stats != nil {
303 score = post.Stats.Score
304 }
305 cursorStr = fmt.Sprintf("%d%s%s%s%s", score, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
306
307 case "hot":
308 // Format: hot_rank::timestamp::uri
309 // CRITICAL: Use computed hot_rank with full precision to prevent pagination bugs
310 // Using 'g' format with -1 precision gives us full float64 precision without trailing zeros
311 // This prevents posts being dropped when hot ranks differ by <1e-6
312 hotRankStr := strconv.FormatFloat(hotRank, 'g', -1, 64)
313 cursorStr = fmt.Sprintf("%s%s%s%s%s", hotRankStr, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
314
315 default:
316 cursorStr = post.URI
317 }
318
319 return base64.StdEncoding.EncodeToString([]byte(cursorStr))
320}
321
322// scanFeedViewPost scans a row into FeedViewPost
323// Alpha: No viewer state - basic community feed only
324func (r *postgresFeedRepo) scanFeedViewPost(rows *sql.Rows) (*communityFeeds.FeedViewPost, float64, error) {
325 var (
326 postView posts.PostView
327 authorView posts.AuthorView
328 communityRef posts.CommunityRef
329 title, content sql.NullString
330 facets, embed sql.NullString
331 labelsJSON sql.NullString
332 editedAt sql.NullTime
333 communityAvatar sql.NullString
334 hotRank sql.NullFloat64
335 )
336
337 err := rows.Scan(
338 &postView.URI, &postView.CID, &postView.RKey,
339 &authorView.DID, &authorView.Handle,
340 &communityRef.DID, &communityRef.Name, &communityAvatar,
341 &title, &content, &facets, &embed, &labelsJSON,
342 &postView.CreatedAt, &editedAt, &postView.IndexedAt,
343 &postView.UpvoteCount, &postView.DownvoteCount, &postView.Score, &postView.CommentCount,
344 &hotRank,
345 )
346 if err != nil {
347 return nil, 0, err
348 }
349
350 // Build author view (no display_name or avatar in users table yet)
351 postView.Author = &authorView
352
353 // Build community ref
354 communityRef.Avatar = nullStringPtr(communityAvatar)
355 postView.Community = &communityRef
356
357 // Set optional fields
358 postView.Title = nullStringPtr(title)
359 postView.Text = nullStringPtr(content)
360
361 // Parse facets JSON
362 if facets.Valid {
363 var facetArray []interface{}
364 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil {
365 postView.TextFacets = facetArray
366 }
367 }
368
369 // Parse embed JSON
370 if embed.Valid {
371 var embedData interface{}
372 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil {
373 postView.Embed = embedData
374 }
375 }
376
377 // Build stats
378 postView.Stats = &posts.PostStats{
379 Upvotes: postView.UpvoteCount,
380 Downvotes: postView.DownvoteCount,
381 Score: postView.Score,
382 CommentCount: postView.CommentCount,
383 }
384
385 // Alpha: No viewer state for basic feed
386 // TODO(feed-generator): Implement viewer state (saved, voted, blocked) in feed generator skeleton
387
388 // Build the record (required by lexicon - social.coves.community.post structure)
389 record := map[string]interface{}{
390 "$type": "social.coves.community.post",
391 "community": communityRef.DID,
392 "author": authorView.DID,
393 "createdAt": postView.CreatedAt.Format(time.RFC3339),
394 }
395
396 // Add optional fields to record if present
397 if title.Valid {
398 record["title"] = title.String
399 }
400 if content.Valid {
401 record["content"] = content.String
402 }
403 if facets.Valid {
404 var facetArray []interface{}
405 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil {
406 record["facets"] = facetArray
407 }
408 }
409 if embed.Valid {
410 var embedData interface{}
411 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil {
412 record["embed"] = embedData
413 }
414 }
415 if labelsJSON.Valid {
416 // Labels are stored as JSONB containing full com.atproto.label.defs#selfLabels structure
417 // Deserialize and include in record
418 var selfLabels posts.SelfLabels
419 if err := json.Unmarshal([]byte(labelsJSON.String), &selfLabels); err == nil {
420 record["labels"] = selfLabels
421 }
422 }
423
424 postView.Record = record
425
426 // Wrap in FeedViewPost
427 feedPost := &communityFeeds.FeedViewPost{
428 Post: &postView,
429 // Reason: nil, // TODO(feed-generator): Implement pinned posts
430 // Reply: nil, // TODO(feed-generator): Implement reply context
431 }
432
433 // Return the computed hot_rank (0.0 if NULL for non-hot sorts)
434 hotRankValue := 0.0
435 if hotRank.Valid {
436 hotRankValue = hotRank.Float64
437 }
438
439 return feedPost, hotRankValue, nil
440}
441
// nullStringPtr maps a sql.NullString to *string. It returns nil for a NULL
// value, and otherwise a pointer to a copy of the string.
func nullStringPtr(ns sql.NullString) *string {
	if ns.Valid {
		s := ns.String
		return &s
	}
	return nil
}