A community-based topic aggregation platform built on atproto
1package postgres
2
3import (
4 "Coves/internal/core/communityFeeds"
5 "Coves/internal/core/posts"
6 "context"
7 "database/sql"
8 "encoding/base64"
9 "encoding/json"
10 "fmt"
11 "strconv"
12 "strings"
13 "time"
14)
15
// postgresFeedRepo is the PostgreSQL-backed feed repository handed out by
// NewCommunityFeedRepository.
type postgresFeedRepo struct {
	// db is the database handle the repository runs its feed queries on.
	db *sql.DB
}
19
// hotRankExpression is the SQL expression that computes a post's hot rank:
// the score divided by (age in hours + 2) raised to the power 1.5.
//
// NOTE: The expression uses NOW(), so a post's hot rank changes over time.
// That is the intended behavior for hot sorting (posts naturally age out).
// Slight time drift between cursor creation and cursor usage may cause minor
// reordering but won't drop posts entirely (unlike paginating on raw score).
const hotRankExpression = `(p.score / POWER(EXTRACT(EPOCH FROM (NOW() - p.created_at))/3600 + 2, 1.5))`

// sortClauses is the whitelist that maps each supported sort type to its
// ORDER BY clause. Only these fixed strings are ever spliced into a query,
// which prevents SQL injection through dynamic ORDER BY construction.
//
// The "hot" ordering is derived from hotRankExpression so the ranking used
// for ordering stays identical to the one used for cursor comparisons.
var sortClauses = map[string]string{
	"new": `p.created_at DESC, p.uri DESC`,
	"top": `p.score DESC, p.created_at DESC, p.uri DESC`,
	"hot": hotRankExpression + ` DESC, p.created_at DESC, p.uri DESC`,
}
33
34// NewCommunityFeedRepository creates a new PostgreSQL feed repository
35func NewCommunityFeedRepository(db *sql.DB) communityFeeds.Repository {
36 return &postgresFeedRepo{db: db}
37}
38
39// GetCommunityFeed retrieves posts from a community with sorting and pagination
40// Single query with JOINs for optimal performance
41func (r *postgresFeedRepo) GetCommunityFeed(ctx context.Context, req communityFeeds.GetCommunityFeedRequest) ([]*communityFeeds.FeedViewPost, *string, error) {
42 // Build ORDER BY clause based on sort type
43 orderBy, timeFilter := r.buildSortClause(req.Sort, req.Timeframe)
44
45 // Build cursor filter for pagination
46 cursorFilter, cursorValues, err := r.parseCursor(req.Cursor, req.Sort)
47 if err != nil {
48 return nil, nil, communityFeeds.ErrInvalidCursor
49 }
50
51 // Build the main query
52 // For hot sort, we need to compute and return the hot_rank for cursor building
53 var selectClause string
54 if req.Sort == "hot" {
55 selectClause = fmt.Sprintf(`
56 SELECT
57 p.uri, p.cid, p.rkey,
58 p.author_did, u.handle as author_handle,
59 p.community_did, c.name as community_name, c.avatar_cid as community_avatar,
60 p.title, p.content, p.content_facets, p.embed, p.content_labels,
61 p.created_at, p.edited_at, p.indexed_at,
62 p.upvote_count, p.downvote_count, p.score, p.comment_count,
63 %s as hot_rank
64 FROM posts p`, hotRankExpression)
65 } else {
66 selectClause = `
67 SELECT
68 p.uri, p.cid, p.rkey,
69 p.author_did, u.handle as author_handle,
70 p.community_did, c.name as community_name, c.avatar_cid as community_avatar,
71 p.title, p.content, p.content_facets, p.embed, p.content_labels,
72 p.created_at, p.edited_at, p.indexed_at,
73 p.upvote_count, p.downvote_count, p.score, p.comment_count,
74 NULL::numeric as hot_rank
75 FROM posts p`
76 }
77
78 query := fmt.Sprintf(`
79 %s
80 INNER JOIN users u ON p.author_did = u.did
81 INNER JOIN communities c ON p.community_did = c.did
82 WHERE p.community_did = $1
83 AND p.deleted_at IS NULL
84 %s
85 %s
86 ORDER BY %s
87 LIMIT $2
88 `, selectClause, timeFilter, cursorFilter, orderBy)
89
90 // Prepare query arguments
91 args := []interface{}{req.Community, req.Limit + 1} // +1 to check for next page
92 args = append(args, cursorValues...)
93
94 // Execute query
95 rows, err := r.db.QueryContext(ctx, query, args...)
96 if err != nil {
97 return nil, nil, fmt.Errorf("failed to query community feed: %w", err)
98 }
99 defer func() {
100 if err := rows.Close(); err != nil {
101 // Log close errors (non-fatal but worth noting)
102 fmt.Printf("Warning: failed to close rows: %v\n", err)
103 }
104 }()
105
106 // Scan results
107 var feedPosts []*communityFeeds.FeedViewPost
108 var hotRanks []float64 // Store hot ranks for cursor building
109 for rows.Next() {
110 feedPost, hotRank, err := r.scanFeedViewPost(rows)
111 if err != nil {
112 return nil, nil, fmt.Errorf("failed to scan feed post: %w", err)
113 }
114 feedPosts = append(feedPosts, feedPost)
115 hotRanks = append(hotRanks, hotRank)
116 }
117
118 if err := rows.Err(); err != nil {
119 return nil, nil, fmt.Errorf("error iterating feed results: %w", err)
120 }
121
122 // Handle pagination cursor
123 var cursor *string
124 if len(feedPosts) > req.Limit && req.Limit > 0 {
125 feedPosts = feedPosts[:req.Limit]
126 hotRanks = hotRanks[:req.Limit]
127 lastPost := feedPosts[len(feedPosts)-1].Post
128 lastHotRank := hotRanks[len(hotRanks)-1]
129 cursorStr := r.buildCursor(lastPost, req.Sort, lastHotRank)
130 cursor = &cursorStr
131 }
132
133 return feedPosts, cursor, nil
134}
135
136// buildSortClause returns the ORDER BY SQL and optional time filter
137func (r *postgresFeedRepo) buildSortClause(sort, timeframe string) (string, string) {
138 // Use whitelist map for ORDER BY clause (defense-in-depth against SQL injection)
139 orderBy := sortClauses[sort]
140 if orderBy == "" {
141 orderBy = sortClauses["hot"] // safe default
142 }
143
144 // Add time filter for "top" sort
145 var timeFilter string
146 if sort == "top" {
147 timeFilter = r.buildTimeFilter(timeframe)
148 }
149
150 return orderBy, timeFilter
151}
152
153// buildTimeFilter returns SQL filter for timeframe
154func (r *postgresFeedRepo) buildTimeFilter(timeframe string) string {
155 if timeframe == "" || timeframe == "all" {
156 return ""
157 }
158
159 var interval string
160 switch timeframe {
161 case "hour":
162 interval = "1 hour"
163 case "day":
164 interval = "1 day"
165 case "week":
166 interval = "1 week"
167 case "month":
168 interval = "1 month"
169 case "year":
170 interval = "1 year"
171 default:
172 return ""
173 }
174
175 return fmt.Sprintf("AND p.created_at > NOW() - INTERVAL '%s'", interval)
176}
177
178// parseCursor decodes pagination cursor
179func (r *postgresFeedRepo) parseCursor(cursor *string, sort string) (string, []interface{}, error) {
180 if cursor == nil || *cursor == "" {
181 return "", nil, nil
182 }
183
184 // Decode base64 cursor
185 decoded, err := base64.StdEncoding.DecodeString(*cursor)
186 if err != nil {
187 return "", nil, fmt.Errorf("invalid cursor encoding")
188 }
189
190 // Parse cursor based on sort type using :: delimiter (Bluesky convention)
191 parts := strings.Split(string(decoded), "::")
192
193 switch sort {
194 case "new":
195 // Cursor format: timestamp::uri
196 if len(parts) != 2 {
197 return "", nil, fmt.Errorf("invalid cursor format")
198 }
199
200 createdAt := parts[0]
201 uri := parts[1]
202
203 // Validate timestamp format
204 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
205 return "", nil, fmt.Errorf("invalid cursor timestamp")
206 }
207
208 // Validate URI format (must be AT-URI)
209 if !strings.HasPrefix(uri, "at://") {
210 return "", nil, fmt.Errorf("invalid cursor URI")
211 }
212
213 filter := `AND (p.created_at < $3 OR (p.created_at = $3 AND p.uri < $4))`
214 return filter, []interface{}{createdAt, uri}, nil
215
216 case "top":
217 // Cursor format: score::timestamp::uri
218 if len(parts) != 3 {
219 return "", nil, fmt.Errorf("invalid cursor format for %s sort", sort)
220 }
221
222 scoreStr := parts[0]
223 createdAt := parts[1]
224 uri := parts[2]
225
226 // Validate score is numeric
227 score := 0
228 if _, err := fmt.Sscanf(scoreStr, "%d", &score); err != nil {
229 return "", nil, fmt.Errorf("invalid cursor score")
230 }
231
232 // Validate timestamp format
233 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
234 return "", nil, fmt.Errorf("invalid cursor timestamp")
235 }
236
237 // Validate URI format (must be AT-URI)
238 if !strings.HasPrefix(uri, "at://") {
239 return "", nil, fmt.Errorf("invalid cursor URI")
240 }
241
242 filter := `AND (p.score < $3 OR (p.score = $3 AND p.created_at < $4) OR (p.score = $3 AND p.created_at = $4 AND p.uri < $5))`
243 return filter, []interface{}{score, createdAt, uri}, nil
244
245 case "hot":
246 // Cursor format: hot_rank::timestamp::uri
247 // CRITICAL: Must use computed hot_rank, not raw score, to prevent pagination bugs
248 if len(parts) != 3 {
249 return "", nil, fmt.Errorf("invalid cursor format for hot sort")
250 }
251
252 hotRankStr := parts[0]
253 createdAt := parts[1]
254 uri := parts[2]
255
256 // Validate hot_rank is numeric (float)
257 hotRank := 0.0
258 if _, err := fmt.Sscanf(hotRankStr, "%f", &hotRank); err != nil {
259 return "", nil, fmt.Errorf("invalid cursor hot rank")
260 }
261
262 // Validate timestamp format
263 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil {
264 return "", nil, fmt.Errorf("invalid cursor timestamp")
265 }
266
267 // Validate URI format (must be AT-URI)
268 if !strings.HasPrefix(uri, "at://") {
269 return "", nil, fmt.Errorf("invalid cursor URI")
270 }
271
272 // CRITICAL: Compare against the computed hot_rank expression, not p.score
273 // This prevents dropping posts with higher raw scores but lower hot ranks
274 //
275 // NOTE: We exclude the exact cursor post by URI to handle time drift in hot_rank
276 // (hot_rank changes with NOW(), so the same post may have different ranks over time)
277 filter := fmt.Sprintf(`AND ((%s < $3 OR (%s = $3 AND p.created_at < $4) OR (%s = $3 AND p.created_at = $4 AND p.uri < $5)) AND p.uri != $6)`,
278 hotRankExpression, hotRankExpression, hotRankExpression)
279 return filter, []interface{}{hotRank, createdAt, uri, uri}, nil
280
281 default:
282 return "", nil, nil
283 }
284}
285
286// buildCursor creates pagination cursor from last post
287func (r *postgresFeedRepo) buildCursor(post *posts.PostView, sort string, hotRank float64) string {
288 var cursorStr string
289 // Use :: as delimiter following Bluesky convention
290 // Safe because :: doesn't appear in ISO timestamps or AT-URIs
291 const delimiter = "::"
292
293 switch sort {
294 case "new":
295 // Format: timestamp::uri (following Bluesky pattern)
296 cursorStr = fmt.Sprintf("%s%s%s", post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
297
298 case "top":
299 // Format: score::timestamp::uri
300 score := 0
301 if post.Stats != nil {
302 score = post.Stats.Score
303 }
304 cursorStr = fmt.Sprintf("%d%s%s%s%s", score, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
305
306 case "hot":
307 // Format: hot_rank::timestamp::uri
308 // CRITICAL: Use computed hot_rank with full precision to prevent pagination bugs
309 // Using 'g' format with -1 precision gives us full float64 precision without trailing zeros
310 // This prevents posts being dropped when hot ranks differ by <1e-6
311 hotRankStr := strconv.FormatFloat(hotRank, 'g', -1, 64)
312 cursorStr = fmt.Sprintf("%s%s%s%s%s", hotRankStr, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI)
313
314 default:
315 cursorStr = post.URI
316 }
317
318 return base64.StdEncoding.EncodeToString([]byte(cursorStr))
319}
320
321// scanFeedViewPost scans a row into FeedViewPost
322// Alpha: No viewer state - basic community feed only
323func (r *postgresFeedRepo) scanFeedViewPost(rows *sql.Rows) (*communityFeeds.FeedViewPost, float64, error) {
324 var (
325 postView posts.PostView
326 authorView posts.AuthorView
327 communityRef posts.CommunityRef
328 title, content sql.NullString
329 facets, embed sql.NullString
330 labelsJSON sql.NullString
331 editedAt sql.NullTime
332 communityAvatar sql.NullString
333 hotRank sql.NullFloat64
334 )
335
336 err := rows.Scan(
337 &postView.URI, &postView.CID, &postView.RKey,
338 &authorView.DID, &authorView.Handle,
339 &communityRef.DID, &communityRef.Name, &communityAvatar,
340 &title, &content, &facets, &embed, &labelsJSON,
341 &postView.CreatedAt, &editedAt, &postView.IndexedAt,
342 &postView.UpvoteCount, &postView.DownvoteCount, &postView.Score, &postView.CommentCount,
343 &hotRank,
344 )
345 if err != nil {
346 return nil, 0, err
347 }
348
349 // Build author view (no display_name or avatar in users table yet)
350 postView.Author = &authorView
351
352 // Build community ref
353 communityRef.Avatar = nullStringPtr(communityAvatar)
354 postView.Community = &communityRef
355
356 // Set optional fields
357 postView.Title = nullStringPtr(title)
358 postView.Text = nullStringPtr(content)
359
360 // Parse facets JSON
361 if facets.Valid {
362 var facetArray []interface{}
363 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil {
364 postView.TextFacets = facetArray
365 }
366 }
367
368 // Parse embed JSON
369 if embed.Valid {
370 var embedData interface{}
371 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil {
372 postView.Embed = embedData
373 }
374 }
375
376 // Build stats
377 postView.Stats = &posts.PostStats{
378 Upvotes: postView.UpvoteCount,
379 Downvotes: postView.DownvoteCount,
380 Score: postView.Score,
381 CommentCount: postView.CommentCount,
382 }
383
384 // Alpha: No viewer state for basic feed
385 // TODO(feed-generator): Implement viewer state (saved, voted, blocked) in feed generator skeleton
386
387 // Build the record (required by lexicon - social.coves.community.post structure)
388 record := map[string]interface{}{
389 "$type": "social.coves.community.post",
390 "community": communityRef.DID,
391 "author": authorView.DID,
392 "createdAt": postView.CreatedAt.Format(time.RFC3339),
393 }
394
395 // Add optional fields to record if present
396 if title.Valid {
397 record["title"] = title.String
398 }
399 if content.Valid {
400 record["content"] = content.String
401 }
402 if facets.Valid {
403 var facetArray []interface{}
404 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil {
405 record["facets"] = facetArray
406 }
407 }
408 if embed.Valid {
409 var embedData interface{}
410 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil {
411 record["embed"] = embedData
412 }
413 }
414 if labelsJSON.Valid {
415 // Labels are stored as JSONB containing full com.atproto.label.defs#selfLabels structure
416 // Deserialize and include in record
417 var selfLabels posts.SelfLabels
418 if err := json.Unmarshal([]byte(labelsJSON.String), &selfLabels); err == nil {
419 record["labels"] = selfLabels
420 }
421 }
422
423 postView.Record = record
424
425 // Wrap in FeedViewPost
426 feedPost := &communityFeeds.FeedViewPost{
427 Post: &postView,
428 // Reason: nil, // TODO(feed-generator): Implement pinned posts
429 // Reply: nil, // TODO(feed-generator): Implement reply context
430 }
431
432 // Return the computed hot_rank (0.0 if NULL for non-hot sorts)
433 hotRankValue := 0.0
434 if hotRank.Valid {
435 hotRankValue = hotRank.Float64
436 }
437
438 return feedPost, hotRankValue, nil
439}
440
// nullStringPtr converts a sql.NullString into a *string: a pointer to a copy
// of the value when it is valid, nil when it represents SQL NULL.
func nullStringPtr(ns sql.NullString) *string {
	if ns.Valid {
		value := ns.String
		return &value
	}
	return nil
}