A community-based topic aggregation platform built on atproto
1package postgres 2 3import ( 4 "context" 5 "database/sql" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "strconv" 10 "strings" 11 "time" 12 13 "Coves/internal/core/communityFeeds" 14 "Coves/internal/core/posts" 15) 16 17type postgresFeedRepo struct { 18 db *sql.DB 19} 20 21// sortClauses maps sort types to safe SQL ORDER BY clauses 22// This whitelist prevents SQL injection via dynamic ORDER BY construction 23var sortClauses = map[string]string{ 24 "hot": `(p.score / POWER(EXTRACT(EPOCH FROM (NOW() - p.created_at))/3600 + 2, 1.5)) DESC, p.created_at DESC, p.uri DESC`, 25 "top": `p.score DESC, p.created_at DESC, p.uri DESC`, 26 "new": `p.created_at DESC, p.uri DESC`, 27} 28 29// hotRankExpression is the SQL expression for computing the hot rank 30// NOTE: Uses NOW() which means hot_rank changes over time - this is expected behavior 31// for hot sorting (posts naturally age out). Slight time drift between cursor creation 32// and usage may cause minor reordering but won't drop posts entirely (unlike using raw score). 
33const hotRankExpression = `(p.score / POWER(EXTRACT(EPOCH FROM (NOW() - p.created_at))/3600 + 2, 1.5))` 34 35// NewCommunityFeedRepository creates a new PostgreSQL feed repository 36func NewCommunityFeedRepository(db *sql.DB) communityFeeds.Repository { 37 return &postgresFeedRepo{db: db} 38} 39 40// GetCommunityFeed retrieves posts from a community with sorting and pagination 41// Single query with JOINs for optimal performance 42func (r *postgresFeedRepo) GetCommunityFeed(ctx context.Context, req communityFeeds.GetCommunityFeedRequest) ([]*communityFeeds.FeedViewPost, *string, error) { 43 // Build ORDER BY clause based on sort type 44 orderBy, timeFilter := r.buildSortClause(req.Sort, req.Timeframe) 45 46 // Build cursor filter for pagination 47 cursorFilter, cursorValues, err := r.parseCursor(req.Cursor, req.Sort) 48 if err != nil { 49 return nil, nil, communityFeeds.ErrInvalidCursor 50 } 51 52 // Build the main query 53 // For hot sort, we need to compute and return the hot_rank for cursor building 54 var selectClause string 55 if req.Sort == "hot" { 56 selectClause = fmt.Sprintf(` 57 SELECT 58 p.uri, p.cid, p.rkey, 59 p.author_did, u.handle as author_handle, 60 p.community_did, c.name as community_name, c.avatar_cid as community_avatar, 61 p.title, p.content, p.content_facets, p.embed, p.content_labels, 62 p.created_at, p.edited_at, p.indexed_at, 63 p.upvote_count, p.downvote_count, p.score, p.comment_count, 64 %s as hot_rank 65 FROM posts p`, hotRankExpression) 66 } else { 67 selectClause = ` 68 SELECT 69 p.uri, p.cid, p.rkey, 70 p.author_did, u.handle as author_handle, 71 p.community_did, c.name as community_name, c.avatar_cid as community_avatar, 72 p.title, p.content, p.content_facets, p.embed, p.content_labels, 73 p.created_at, p.edited_at, p.indexed_at, 74 p.upvote_count, p.downvote_count, p.score, p.comment_count, 75 NULL::numeric as hot_rank 76 FROM posts p` 77 } 78 79 query := fmt.Sprintf(` 80 %s 81 INNER JOIN users u ON p.author_did = u.did 82 
INNER JOIN communities c ON p.community_did = c.did 83 WHERE p.community_did = $1 84 AND p.deleted_at IS NULL 85 %s 86 %s 87 ORDER BY %s 88 LIMIT $2 89 `, selectClause, timeFilter, cursorFilter, orderBy) 90 91 // Prepare query arguments 92 args := []interface{}{req.Community, req.Limit + 1} // +1 to check for next page 93 args = append(args, cursorValues...) 94 95 // Execute query 96 rows, err := r.db.QueryContext(ctx, query, args...) 97 if err != nil { 98 return nil, nil, fmt.Errorf("failed to query community feed: %w", err) 99 } 100 defer func() { 101 if err := rows.Close(); err != nil { 102 // Log close errors (non-fatal but worth noting) 103 fmt.Printf("Warning: failed to close rows: %v\n", err) 104 } 105 }() 106 107 // Scan results 108 var feedPosts []*communityFeeds.FeedViewPost 109 var hotRanks []float64 // Store hot ranks for cursor building 110 for rows.Next() { 111 feedPost, hotRank, err := r.scanFeedViewPost(rows) 112 if err != nil { 113 return nil, nil, fmt.Errorf("failed to scan feed post: %w", err) 114 } 115 feedPosts = append(feedPosts, feedPost) 116 hotRanks = append(hotRanks, hotRank) 117 } 118 119 if err := rows.Err(); err != nil { 120 return nil, nil, fmt.Errorf("error iterating feed results: %w", err) 121 } 122 123 // Handle pagination cursor 124 var cursor *string 125 if len(feedPosts) > req.Limit && req.Limit > 0 { 126 feedPosts = feedPosts[:req.Limit] 127 hotRanks = hotRanks[:req.Limit] 128 lastPost := feedPosts[len(feedPosts)-1].Post 129 lastHotRank := hotRanks[len(hotRanks)-1] 130 cursorStr := r.buildCursor(lastPost, req.Sort, lastHotRank) 131 cursor = &cursorStr 132 } 133 134 return feedPosts, cursor, nil 135} 136 137// buildSortClause returns the ORDER BY SQL and optional time filter 138func (r *postgresFeedRepo) buildSortClause(sort, timeframe string) (string, string) { 139 // Use whitelist map for ORDER BY clause (defense-in-depth against SQL injection) 140 orderBy := sortClauses[sort] 141 if orderBy == "" { 142 orderBy = 
sortClauses["hot"] // safe default 143 } 144 145 // Add time filter for "top" sort 146 var timeFilter string 147 if sort == "top" { 148 timeFilter = r.buildTimeFilter(timeframe) 149 } 150 151 return orderBy, timeFilter 152} 153 154// buildTimeFilter returns SQL filter for timeframe 155func (r *postgresFeedRepo) buildTimeFilter(timeframe string) string { 156 if timeframe == "" || timeframe == "all" { 157 return "" 158 } 159 160 var interval string 161 switch timeframe { 162 case "hour": 163 interval = "1 hour" 164 case "day": 165 interval = "1 day" 166 case "week": 167 interval = "1 week" 168 case "month": 169 interval = "1 month" 170 case "year": 171 interval = "1 year" 172 default: 173 return "" 174 } 175 176 return fmt.Sprintf("AND p.created_at > NOW() - INTERVAL '%s'", interval) 177} 178 179// parseCursor decodes pagination cursor 180func (r *postgresFeedRepo) parseCursor(cursor *string, sort string) (string, []interface{}, error) { 181 if cursor == nil || *cursor == "" { 182 return "", nil, nil 183 } 184 185 // Decode base64 cursor 186 decoded, err := base64.StdEncoding.DecodeString(*cursor) 187 if err != nil { 188 return "", nil, fmt.Errorf("invalid cursor encoding") 189 } 190 191 // Parse cursor based on sort type using :: delimiter (Bluesky convention) 192 parts := strings.Split(string(decoded), "::") 193 194 switch sort { 195 case "new": 196 // Cursor format: timestamp::uri 197 if len(parts) != 2 { 198 return "", nil, fmt.Errorf("invalid cursor format") 199 } 200 201 createdAt := parts[0] 202 uri := parts[1] 203 204 // Validate timestamp format 205 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil { 206 return "", nil, fmt.Errorf("invalid cursor timestamp") 207 } 208 209 // Validate URI format (must be AT-URI) 210 if !strings.HasPrefix(uri, "at://") { 211 return "", nil, fmt.Errorf("invalid cursor URI") 212 } 213 214 filter := `AND (p.created_at < $3 OR (p.created_at = $3 AND p.uri < $4))` 215 return filter, []interface{}{createdAt, uri}, nil 
216 217 case "top": 218 // Cursor format: score::timestamp::uri 219 if len(parts) != 3 { 220 return "", nil, fmt.Errorf("invalid cursor format for %s sort", sort) 221 } 222 223 scoreStr := parts[0] 224 createdAt := parts[1] 225 uri := parts[2] 226 227 // Validate score is numeric 228 score := 0 229 if _, err := fmt.Sscanf(scoreStr, "%d", &score); err != nil { 230 return "", nil, fmt.Errorf("invalid cursor score") 231 } 232 233 // Validate timestamp format 234 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil { 235 return "", nil, fmt.Errorf("invalid cursor timestamp") 236 } 237 238 // Validate URI format (must be AT-URI) 239 if !strings.HasPrefix(uri, "at://") { 240 return "", nil, fmt.Errorf("invalid cursor URI") 241 } 242 243 filter := `AND (p.score < $3 OR (p.score = $3 AND p.created_at < $4) OR (p.score = $3 AND p.created_at = $4 AND p.uri < $5))` 244 return filter, []interface{}{score, createdAt, uri}, nil 245 246 case "hot": 247 // Cursor format: hot_rank::timestamp::uri 248 // CRITICAL: Must use computed hot_rank, not raw score, to prevent pagination bugs 249 if len(parts) != 3 { 250 return "", nil, fmt.Errorf("invalid cursor format for hot sort") 251 } 252 253 hotRankStr := parts[0] 254 createdAt := parts[1] 255 uri := parts[2] 256 257 // Validate hot_rank is numeric (float) 258 hotRank := 0.0 259 if _, err := fmt.Sscanf(hotRankStr, "%f", &hotRank); err != nil { 260 return "", nil, fmt.Errorf("invalid cursor hot rank") 261 } 262 263 // Validate timestamp format 264 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil { 265 return "", nil, fmt.Errorf("invalid cursor timestamp") 266 } 267 268 // Validate URI format (must be AT-URI) 269 if !strings.HasPrefix(uri, "at://") { 270 return "", nil, fmt.Errorf("invalid cursor URI") 271 } 272 273 // CRITICAL: Compare against the computed hot_rank expression, not p.score 274 // This prevents dropping posts with higher raw scores but lower hot ranks 275 // 276 // NOTE: We exclude the exact 
cursor post by URI to handle time drift in hot_rank 277 // (hot_rank changes with NOW(), so the same post may have different ranks over time) 278 filter := fmt.Sprintf(`AND ((%s < $3 OR (%s = $3 AND p.created_at < $4) OR (%s = $3 AND p.created_at = $4 AND p.uri < $5)) AND p.uri != $6)`, 279 hotRankExpression, hotRankExpression, hotRankExpression) 280 return filter, []interface{}{hotRank, createdAt, uri, uri}, nil 281 282 default: 283 return "", nil, nil 284 } 285} 286 287// buildCursor creates pagination cursor from last post 288func (r *postgresFeedRepo) buildCursor(post *posts.PostView, sort string, hotRank float64) string { 289 var cursorStr string 290 // Use :: as delimiter following Bluesky convention 291 // Safe because :: doesn't appear in ISO timestamps or AT-URIs 292 const delimiter = "::" 293 294 switch sort { 295 case "new": 296 // Format: timestamp::uri (following Bluesky pattern) 297 cursorStr = fmt.Sprintf("%s%s%s", post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI) 298 299 case "top": 300 // Format: score::timestamp::uri 301 score := 0 302 if post.Stats != nil { 303 score = post.Stats.Score 304 } 305 cursorStr = fmt.Sprintf("%d%s%s%s%s", score, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI) 306 307 case "hot": 308 // Format: hot_rank::timestamp::uri 309 // CRITICAL: Use computed hot_rank with full precision to prevent pagination bugs 310 // Using 'g' format with -1 precision gives us full float64 precision without trailing zeros 311 // This prevents posts being dropped when hot ranks differ by <1e-6 312 hotRankStr := strconv.FormatFloat(hotRank, 'g', -1, 64) 313 cursorStr = fmt.Sprintf("%s%s%s%s%s", hotRankStr, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI) 314 315 default: 316 cursorStr = post.URI 317 } 318 319 return base64.StdEncoding.EncodeToString([]byte(cursorStr)) 320} 321 322// scanFeedViewPost scans a row into FeedViewPost 323// Alpha: No viewer state - basic community feed only 
324func (r *postgresFeedRepo) scanFeedViewPost(rows *sql.Rows) (*communityFeeds.FeedViewPost, float64, error) { 325 var ( 326 postView posts.PostView 327 authorView posts.AuthorView 328 communityRef posts.CommunityRef 329 title, content sql.NullString 330 facets, embed sql.NullString 331 labelsJSON sql.NullString 332 editedAt sql.NullTime 333 communityAvatar sql.NullString 334 hotRank sql.NullFloat64 335 ) 336 337 err := rows.Scan( 338 &postView.URI, &postView.CID, &postView.RKey, 339 &authorView.DID, &authorView.Handle, 340 &communityRef.DID, &communityRef.Name, &communityAvatar, 341 &title, &content, &facets, &embed, &labelsJSON, 342 &postView.CreatedAt, &editedAt, &postView.IndexedAt, 343 &postView.UpvoteCount, &postView.DownvoteCount, &postView.Score, &postView.CommentCount, 344 &hotRank, 345 ) 346 if err != nil { 347 return nil, 0, err 348 } 349 350 // Build author view (no display_name or avatar in users table yet) 351 postView.Author = &authorView 352 353 // Build community ref 354 communityRef.Avatar = nullStringPtr(communityAvatar) 355 postView.Community = &communityRef 356 357 // Set optional fields 358 postView.Title = nullStringPtr(title) 359 postView.Text = nullStringPtr(content) 360 361 // Parse facets JSON 362 if facets.Valid { 363 var facetArray []interface{} 364 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil { 365 postView.TextFacets = facetArray 366 } 367 } 368 369 // Parse embed JSON 370 if embed.Valid { 371 var embedData interface{} 372 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil { 373 postView.Embed = embedData 374 } 375 } 376 377 // Build stats 378 postView.Stats = &posts.PostStats{ 379 Upvotes: postView.UpvoteCount, 380 Downvotes: postView.DownvoteCount, 381 Score: postView.Score, 382 CommentCount: postView.CommentCount, 383 } 384 385 // Alpha: No viewer state for basic feed 386 // TODO(feed-generator): Implement viewer state (saved, voted, blocked) in feed generator skeleton 387 388 // 
Build the record (required by lexicon - social.coves.community.post structure) 389 record := map[string]interface{}{ 390 "$type": "social.coves.community.post", 391 "community": communityRef.DID, 392 "author": authorView.DID, 393 "createdAt": postView.CreatedAt.Format(time.RFC3339), 394 } 395 396 // Add optional fields to record if present 397 if title.Valid { 398 record["title"] = title.String 399 } 400 if content.Valid { 401 record["content"] = content.String 402 } 403 if facets.Valid { 404 var facetArray []interface{} 405 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil { 406 record["facets"] = facetArray 407 } 408 } 409 if embed.Valid { 410 var embedData interface{} 411 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil { 412 record["embed"] = embedData 413 } 414 } 415 if labelsJSON.Valid { 416 // Labels are stored as JSONB containing full com.atproto.label.defs#selfLabels structure 417 // Deserialize and include in record 418 var selfLabels posts.SelfLabels 419 if err := json.Unmarshal([]byte(labelsJSON.String), &selfLabels); err == nil { 420 record["labels"] = selfLabels 421 } 422 } 423 424 postView.Record = record 425 426 // Wrap in FeedViewPost 427 feedPost := &communityFeeds.FeedViewPost{ 428 Post: &postView, 429 // Reason: nil, // TODO(feed-generator): Implement pinned posts 430 // Reply: nil, // TODO(feed-generator): Implement reply context 431 } 432 433 // Return the computed hot_rank (0.0 if NULL for non-hot sorts) 434 hotRankValue := 0.0 435 if hotRank.Valid { 436 hotRankValue = hotRank.Float64 437 } 438 439 return feedPost, hotRankValue, nil 440} 441 442// Helper function to convert sql.NullString to *string 443func nullStringPtr(ns sql.NullString) *string { 444 if !ns.Valid { 445 return nil 446 } 447 return &ns.String 448}