A community-based topic aggregation platform built on atproto.
1package postgres
2
import (
	"Coves/internal/core/communityFeeds"
	"Coves/internal/core/posts"
	"context"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"

	"github.com/lib/pq"
)
17
// postgresFeedRepo is the PostgreSQL-backed implementation of
// communityFeeds.Repository returned by NewCommunityFeedRepository.
type postgresFeedRepo struct {
	// db is the database handle every feed query is issued against.
	db *sql.DB
}
21
// hotRankExpression is the SQL expression for computing the hot rank:
//
//	score / (age_in_hours + 2)^1.5
//
// The age is clamped at zero with GREATEST. Without the clamp, a post whose
// created_at lies two or more hours in the future (for example if created_at
// comes from a client-supplied timestamp and that client's clock is skewed)
// makes POWER's base zero or negative. PostgreSQL rejects that with
// "division by zero" or "a negative number raised to a non-integer power
// yields a complex result", which fails the entire hot feed query. For posts
// with a non-negative age the clamp is a no-op.
//
// NOTE: Uses NOW() which means hot_rank changes over time - this is expected behavior
// for hot sorting (posts naturally age out). Slight time drift between cursor creation
// and usage may cause minor reordering but won't drop posts entirely (unlike using raw score).
const hotRankExpression = `(p.score / POWER(GREATEST(EXTRACT(EPOCH FROM (NOW() - p.created_at)), 0)/3600 + 2, 1.5))`

// sortClauses maps sort types to safe SQL ORDER BY clauses.
// This whitelist prevents SQL injection via dynamic ORDER BY construction.
//
// The "hot" entry is built from hotRankExpression so that the ORDER BY, the
// selected hot_rank column and the hot cursor filter always share one formula.
var sortClauses = map[string]string{
	"hot": hotRankExpression + ` DESC, p.created_at DESC, p.uri DESC`,
	"top": `p.score DESC, p.created_at DESC, p.uri DESC`,
	"new": `p.created_at DESC, p.uri DESC`,
}
35
36// NewCommunityFeedRepository creates a new PostgreSQL feed repository
37func NewCommunityFeedRepository(db *sql.DB) communityFeeds.Repository {
38 return &postgresFeedRepo{db: db}
39}
40
41// GetCommunityFeed retrieves posts from a community with sorting and pagination
42// Single query with JOINs for optimal performance
43func (r *postgresFeedRepo) GetCommunityFeed(ctx context.Context, req communityFeeds.GetCommunityFeedRequest) ([]*communityFeeds.FeedViewPost, *string, error) {
44 // Build ORDER BY clause based on sort type
45 orderBy, timeFilter := r.buildSortClause(req.Sort, req.Timeframe)
46
47 // Build cursor filter for pagination
48 cursorFilter, cursorValues, err := r.parseCursor(req.Cursor, req.Sort)
49 if err != nil {
50 return nil, nil, communityFeeds.ErrInvalidCursor
51 }
52
53 // Build the main query
54 // For hot sort, we need to compute and return the hot_rank for cursor building
55 var selectClause string
56 if req.Sort == "hot" {
57 selectClause = fmt.Sprintf(`
58 SELECT
59 p.uri, p.cid, p.rkey,
60 p.author_did, u.handle as author_handle,
61 p.community_did, c.name as community_name, c.avatar_cid as community_avatar,
62 p.title, p.content, p.content_facets, p.embed, p.content_labels,
63 p.created_at, p.edited_at, p.indexed_at,
64 p.upvote_count, p.downvote_count, p.score, p.comment_count,
65 %s as hot_rank
66 FROM posts p`, hotRankExpression)
67 } else {
68 selectClause = `
69 SELECT
70 p.uri, p.cid, p.rkey,
71 p.author_did, u.handle as author_handle,
72 p.community_did, c.name as community_name, c.avatar_cid as community_avatar,
73 p.title, p.content, p.content_facets, p.embed, p.content_labels,
74 p.created_at, p.edited_at, p.indexed_at,
75 p.upvote_count, p.downvote_count, p.score, p.comment_count,
76 NULL::numeric as hot_rank
77 FROM posts p`
78 }
79
80 query := fmt.Sprintf(`
81 %s
82 INNER JOIN users u ON p.author_did = u.did
83 INNER JOIN communities c ON p.community_did = c.did
84 WHERE p.community_did = $1
85 AND p.deleted_at IS NULL
86 %s
87 %s
88 ORDER BY %s
89 LIMIT $2
90 `, selectClause, timeFilter, cursorFilter, orderBy)
91
92 // Prepare query arguments
93 args := []interface{}{req.Community, req.Limit + 1} // +1 to check for next page
94 args = append(args, cursorValues...)
95
96 // Execute query
97 rows, err := r.db.QueryContext(ctx, query, args...)
98 if err != nil {
99 return nil, nil, fmt.Errorf("failed to query community feed: %w", err)
100 }
101 defer func() {
102 if err := rows.Close(); err != nil {
103 // Log close errors (non-fatal but worth noting)
104 fmt.Printf("Warning: failed to close rows: %v\n", err)
105 }
106 }()
107
108 // Scan results
109 var feedPosts []*communityFeeds.FeedViewPost
110 var hotRanks []float64 // Store hot ranks for cursor building
111 for rows.Next() {
112 feedPost, hotRank, err := r.scanFeedViewPost(rows)
113 if err != nil {
114 return nil, nil, fmt.Errorf("failed to scan feed post: %w", err)
115 }
116 feedPosts = append(feedPosts, feedPost)
117 hotRanks = append(hotRanks, hotRank)
118 }
119
120 if err := rows.Err(); err != nil {
121 return nil, nil, fmt.Errorf("error iterating feed results: %w", err)
122 }
123
124 // Handle pagination cursor
125 var cursor *string
126 if len(feedPosts) > req.Limit && req.Limit > 0 {
127 feedPosts = feedPosts[:req.Limit]
128 hotRanks = hotRanks[:req.Limit]
129 lastPost := feedPosts[len(feedPosts)-1].Post
130 lastHotRank := hotRanks[len(hotRanks)-1]
131 cursorStr := r.buildCursor(lastPost, req.Sort, lastHotRank)
132 cursor = &cursorStr
133 }
134
135 return feedPosts, cursor, nil
136}
137
138// buildSortClause returns the ORDER BY SQL and optional time filter
139func (r *postgresFeedRepo) buildSortClause(sort, timeframe string) (string, string) {
140 // Use whitelist map for ORDER BY clause (defense-in-depth against SQL injection)
141 orderBy := sortClauses[sort]
142 if orderBy == "" {
143 orderBy = sortClauses["hot"] // safe default
144 }
145
146 // Add time filter for "top" sort
147 var timeFilter string
148 if sort == "top" {
149 timeFilter = r.buildTimeFilter(timeframe)
150 }
151
152 return orderBy, timeFilter
153}
154
155// buildTimeFilter returns SQL filter for timeframe
156func (r *postgresFeedRepo) buildTimeFilter(timeframe string) string {
157 if timeframe == "" || timeframe == "all" {
158 return ""
159 }
160
161 var interval string
162 switch timeframe {
163 case "hour":
164 interval = "1 hour"
165 case "day":
166 interval = "1 day"
167 case "week":
168 interval = "1 week"
169 case "month":
170 interval = "1 month"
171 case "year":
172 interval = "1 year"
173 default:
174 return ""
175 }
176
177 return fmt.Sprintf("AND p.created_at > NOW() - INTERVAL '%s'", interval)
178}
179
180// parseCursor decodes pagination cursor
181func (r *postgresFeedRepo) parseCursor(cursor *string, sort string) (string, []interface{}, error) {
182 if cursor == nil || *cursor == "" {
183 return "", nil, nil
184 }
185
186 // Decode base64 cursor
187 decoded, err := base64.StdEncoding.DecodeString(*cursor)
188 if err != nil {
189 return "", nil, fmt.Errorf("invalid cursor encoding")
190 }
191
192 // Parse cursor based on sort type using :: delimiter (Bluesky convention)
193 parts := strings.Split(string(decoded), "::")
194
195 switch sort {
196 case "new":
197 // Cursor format: timestamp::uri
198 if len(parts) != 2 {
199 return "", nil, fmt.Errorf("invalid cursor format")
200 }
201
202 createdAt := parts[0]
203 uri := parts[1]
204
205 // Validate timestamp format
206 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
207 return "", nil, fmt.Errorf("invalid cursor timestamp")
208 }
209
210 // Validate URI format (must be AT-URI)
211 if !strings.HasPrefix(uri, "at://") {
212 return "", nil, fmt.Errorf("invalid cursor URI")
213 }
214
215 filter := `AND (p.created_at < $3 OR (p.created_at = $3 AND p.uri < $4))`
216 return filter, []interface{}{createdAt, uri}, nil
217
218 case "top":
219 // Cursor format: score::timestamp::uri
220 if len(parts) != 3 {
221 return "", nil, fmt.Errorf("invalid cursor format for %s sort", sort)
222 }
223
224 scoreStr := parts[0]
225 createdAt := parts[1]
226 uri := parts[2]
227
228 // Validate score is numeric
229 score := 0
230 if _, err := fmt.Sscanf(scoreStr, "%d", &score); err != nil {
231 return "", nil, fmt.Errorf("invalid cursor score")
232 }
233
234 // Validate timestamp format
235 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
236 return "", nil, fmt.Errorf("invalid cursor timestamp")
237 }
238
239 // Validate URI format (must be AT-URI)
240 if !strings.HasPrefix(uri, "at://") {
241 return "", nil, fmt.Errorf("invalid cursor URI")
242 }
243
244 filter := `AND (p.score < $3 OR (p.score = $3 AND p.created_at < $4) OR (p.score = $3 AND p.created_at = $4 AND p.uri < $5))`
245 return filter, []interface{}{score, createdAt, uri}, nil
246
247 case "hot":
248 // Cursor format: hot_rank::timestamp::uri
249 // CRITICAL: Must use computed hot_rank, not raw score, to prevent pagination bugs
250 if len(parts) != 3 {
251 return "", nil, fmt.Errorf("invalid cursor format for hot sort")
252 }
253
254 hotRankStr := parts[0]
255 createdAt := parts[1]
256 uri := parts[2]
257
258 // Validate hot_rank is numeric (float)
259 hotRank := 0.0
260 if _, err := fmt.Sscanf(hotRankStr, "%f", &hotRank); err != nil {
261 return "", nil, fmt.Errorf("invalid cursor hot rank")
262 }
263
264 // Validate timestamp format
265 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
266 return "", nil, fmt.Errorf("invalid cursor timestamp")
267 }
268
269 // Validate URI format (must be AT-URI)
270 if !strings.HasPrefix(uri, "at://") {
271 return "", nil, fmt.Errorf("invalid cursor URI")
272 }
273
274 // CRITICAL: Compare against the computed hot_rank expression, not p.score
275 // This prevents dropping posts with higher raw scores but lower hot ranks
276 //
277 // NOTE: We exclude the exact cursor post by URI to handle time drift in hot_rank
278 // (hot_rank changes with NOW(), so the same post may have different ranks over time)
279 filter := fmt.Sprintf(`AND ((%s < $3 OR (%s = $3 AND p.created_at < $4) OR (%s = $3 AND p.created_at = $4 AND p.uri < $5)) AND p.uri != $6)`,
280 hotRankExpression, hotRankExpression, hotRankExpression)
281 return filter, []interface{}{hotRank, createdAt, uri, uri}, nil
282
283 default:
284 return "", nil, nil
285 }
286}
287
288// buildCursor creates pagination cursor from last post
289func (r *postgresFeedRepo) buildCursor(post *posts.PostView, sort string, hotRank float64) string {
290 var cursorStr string
291 // Use :: as delimiter following Bluesky convention
292 // Safe because :: doesn't appear in ISO timestamps or AT-URIs
293 const delimiter = "::"
294
295 switch sort {
296 case "new":
297 // Format: timestamp::uri (following Bluesky pattern)
298 cursorStr = fmt.Sprintf("%s%s%s", post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
299
300 case "top":
301 // Format: score::timestamp::uri
302 score := 0
303 if post.Stats != nil {
304 score = post.Stats.Score
305 }
306 cursorStr = fmt.Sprintf("%d%s%s%s%s", score, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
307
308 case "hot":
309 // Format: hot_rank::timestamp::uri
310 // CRITICAL: Use computed hot_rank with full precision to prevent pagination bugs
311 // Using 'g' format with -1 precision gives us full float64 precision without trailing zeros
312 // This prevents posts being dropped when hot ranks differ by <1e-6
313 hotRankStr := strconv.FormatFloat(hotRank, 'g', -1, 64)
314 cursorStr = fmt.Sprintf("%s%s%s%s%s", hotRankStr, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
315
316 default:
317 cursorStr = post.URI
318 }
319
320 return base64.StdEncoding.EncodeToString([]byte(cursorStr))
321}
322
323// scanFeedViewPost scans a row into FeedViewPost
324// Alpha: No viewer state - basic community feed only
325func (r *postgresFeedRepo) scanFeedViewPost(rows *sql.Rows) (*communityFeeds.FeedViewPost, float64, error) {
326 var (
327 postView posts.PostView
328 authorView posts.AuthorView
329 communityRef posts.CommunityRef
330 title, content sql.NullString
331 facets, embed sql.NullString
332 labels pq.StringArray
333 editedAt sql.NullTime
334 communityAvatar sql.NullString
335 hotRank sql.NullFloat64
336 )
337
338 err := rows.Scan(
339 &postView.URI, &postView.CID, &postView.RKey,
340 &authorView.DID, &authorView.Handle,
341 &communityRef.DID, &communityRef.Name, &communityAvatar,
342 &title, &content, &facets, &embed, &labels,
343 &postView.CreatedAt, &editedAt, &postView.IndexedAt,
344 &postView.UpvoteCount, &postView.DownvoteCount, &postView.Score, &postView.CommentCount,
345 &hotRank,
346 )
347 if err != nil {
348 return nil, 0, err
349 }
350
351 // Build author view (no display_name or avatar in users table yet)
352 postView.Author = &authorView
353
354 // Build community ref
355 communityRef.Avatar = nullStringPtr(communityAvatar)
356 postView.Community = &communityRef
357
358 // Set optional fields
359 postView.Title = nullStringPtr(title)
360 postView.Text = nullStringPtr(content)
361
362 // Parse facets JSON
363 if facets.Valid {
364 var facetArray []interface{}
365 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil {
366 postView.TextFacets = facetArray
367 }
368 }
369
370 // Parse embed JSON
371 if embed.Valid {
372 var embedData interface{}
373 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil {
374 postView.Embed = embedData
375 }
376 }
377
378 // Build stats
379 postView.Stats = &posts.PostStats{
380 Upvotes: postView.UpvoteCount,
381 Downvotes: postView.DownvoteCount,
382 Score: postView.Score,
383 CommentCount: postView.CommentCount,
384 }
385
386 // Alpha: No viewer state for basic feed
387 // TODO(feed-generator): Implement viewer state (saved, voted, blocked) in feed generator skeleton
388
389 // Build the record (required by lexicon - social.coves.post.record structure)
390 record := map[string]interface{}{
391 "$type": "social.coves.post.record",
392 "community": communityRef.DID,
393 "author": authorView.DID,
394 "createdAt": postView.CreatedAt.Format(time.RFC3339),
395 }
396
397 // Add optional fields to record if present
398 if title.Valid {
399 record["title"] = title.String
400 }
401 if content.Valid {
402 record["content"] = content.String
403 }
404 if facets.Valid {
405 var facetArray []interface{}
406 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil {
407 record["facets"] = facetArray
408 }
409 }
410 if embed.Valid {
411 var embedData interface{}
412 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil {
413 record["embed"] = embedData
414 }
415 }
416 if len(labels) > 0 {
417 record["contentLabels"] = labels
418 }
419
420 postView.Record = record
421
422 // Wrap in FeedViewPost
423 feedPost := &communityFeeds.FeedViewPost{
424 Post: &postView,
425 // Reason: nil, // TODO(feed-generator): Implement pinned posts
426 // Reply: nil, // TODO(feed-generator): Implement reply context
427 }
428
429 // Return the computed hot_rank (0.0 if NULL for non-hot sorts)
430 hotRankValue := 0.0
431 if hotRank.Valid {
432 hotRankValue = hotRank.Float64
433 }
434
435 return feedPost, hotRankValue, nil
436}
437
// nullStringPtr maps a SQL NULL to a nil pointer and a non-NULL value to a
// pointer to a fresh copy of its string.
func nullStringPtr(ns sql.NullString) *string {
	if ns.Valid {
		s := ns.String
		return &s
	}
	return nil
}