A community-based topic aggregation platform built on atproto
1package postgres 2 3import ( 4 "Coves/internal/core/communityFeeds" 5 "Coves/internal/core/posts" 6 "context" 7 "database/sql" 8 "encoding/base64" 9 "encoding/json" 10 "fmt" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/lib/pq" 16) 17 18type postgresFeedRepo struct { 19 db *sql.DB 20} 21 22// sortClauses maps sort types to safe SQL ORDER BY clauses 23// This whitelist prevents SQL injection via dynamic ORDER BY construction 24var sortClauses = map[string]string{ 25 "hot": `(p.score / POWER(EXTRACT(EPOCH FROM (NOW() - p.created_at))/3600 + 2, 1.5)) DESC, p.created_at DESC, p.uri DESC`, 26 "top": `p.score DESC, p.created_at DESC, p.uri DESC`, 27 "new": `p.created_at DESC, p.uri DESC`, 28} 29 30// hotRankExpression is the SQL expression for computing the hot rank 31// NOTE: Uses NOW() which means hot_rank changes over time - this is expected behavior 32// for hot sorting (posts naturally age out). Slight time drift between cursor creation 33// and usage may cause minor reordering but won't drop posts entirely (unlike using raw score). 
34const hotRankExpression = `(p.score / POWER(EXTRACT(EPOCH FROM (NOW() - p.created_at))/3600 + 2, 1.5))` 35 36// NewCommunityFeedRepository creates a new PostgreSQL feed repository 37func NewCommunityFeedRepository(db *sql.DB) communityFeeds.Repository { 38 return &postgresFeedRepo{db: db} 39} 40 41// GetCommunityFeed retrieves posts from a community with sorting and pagination 42// Single query with JOINs for optimal performance 43func (r *postgresFeedRepo) GetCommunityFeed(ctx context.Context, req communityFeeds.GetCommunityFeedRequest) ([]*communityFeeds.FeedViewPost, *string, error) { 44 // Build ORDER BY clause based on sort type 45 orderBy, timeFilter := r.buildSortClause(req.Sort, req.Timeframe) 46 47 // Build cursor filter for pagination 48 cursorFilter, cursorValues, err := r.parseCursor(req.Cursor, req.Sort) 49 if err != nil { 50 return nil, nil, communityFeeds.ErrInvalidCursor 51 } 52 53 // Build the main query 54 // For hot sort, we need to compute and return the hot_rank for cursor building 55 var selectClause string 56 if req.Sort == "hot" { 57 selectClause = fmt.Sprintf(` 58 SELECT 59 p.uri, p.cid, p.rkey, 60 p.author_did, u.handle as author_handle, 61 p.community_did, c.name as community_name, c.avatar_cid as community_avatar, 62 p.title, p.content, p.content_facets, p.embed, p.content_labels, 63 p.created_at, p.edited_at, p.indexed_at, 64 p.upvote_count, p.downvote_count, p.score, p.comment_count, 65 %s as hot_rank 66 FROM posts p`, hotRankExpression) 67 } else { 68 selectClause = ` 69 SELECT 70 p.uri, p.cid, p.rkey, 71 p.author_did, u.handle as author_handle, 72 p.community_did, c.name as community_name, c.avatar_cid as community_avatar, 73 p.title, p.content, p.content_facets, p.embed, p.content_labels, 74 p.created_at, p.edited_at, p.indexed_at, 75 p.upvote_count, p.downvote_count, p.score, p.comment_count, 76 NULL::numeric as hot_rank 77 FROM posts p` 78 } 79 80 query := fmt.Sprintf(` 81 %s 82 INNER JOIN users u ON p.author_did = u.did 83 
INNER JOIN communities c ON p.community_did = c.did 84 WHERE p.community_did = $1 85 AND p.deleted_at IS NULL 86 %s 87 %s 88 ORDER BY %s 89 LIMIT $2 90 `, selectClause, timeFilter, cursorFilter, orderBy) 91 92 // Prepare query arguments 93 args := []interface{}{req.Community, req.Limit + 1} // +1 to check for next page 94 args = append(args, cursorValues...) 95 96 // Execute query 97 rows, err := r.db.QueryContext(ctx, query, args...) 98 if err != nil { 99 return nil, nil, fmt.Errorf("failed to query community feed: %w", err) 100 } 101 defer func() { 102 if err := rows.Close(); err != nil { 103 // Log close errors (non-fatal but worth noting) 104 fmt.Printf("Warning: failed to close rows: %v\n", err) 105 } 106 }() 107 108 // Scan results 109 var feedPosts []*communityFeeds.FeedViewPost 110 var hotRanks []float64 // Store hot ranks for cursor building 111 for rows.Next() { 112 feedPost, hotRank, err := r.scanFeedViewPost(rows) 113 if err != nil { 114 return nil, nil, fmt.Errorf("failed to scan feed post: %w", err) 115 } 116 feedPosts = append(feedPosts, feedPost) 117 hotRanks = append(hotRanks, hotRank) 118 } 119 120 if err := rows.Err(); err != nil { 121 return nil, nil, fmt.Errorf("error iterating feed results: %w", err) 122 } 123 124 // Handle pagination cursor 125 var cursor *string 126 if len(feedPosts) > req.Limit && req.Limit > 0 { 127 feedPosts = feedPosts[:req.Limit] 128 hotRanks = hotRanks[:req.Limit] 129 lastPost := feedPosts[len(feedPosts)-1].Post 130 lastHotRank := hotRanks[len(hotRanks)-1] 131 cursorStr := r.buildCursor(lastPost, req.Sort, lastHotRank) 132 cursor = &cursorStr 133 } 134 135 return feedPosts, cursor, nil 136} 137 138// buildSortClause returns the ORDER BY SQL and optional time filter 139func (r *postgresFeedRepo) buildSortClause(sort, timeframe string) (string, string) { 140 // Use whitelist map for ORDER BY clause (defense-in-depth against SQL injection) 141 orderBy := sortClauses[sort] 142 if orderBy == "" { 143 orderBy = 
sortClauses["hot"] // safe default 144 } 145 146 // Add time filter for "top" sort 147 var timeFilter string 148 if sort == "top" { 149 timeFilter = r.buildTimeFilter(timeframe) 150 } 151 152 return orderBy, timeFilter 153} 154 155// buildTimeFilter returns SQL filter for timeframe 156func (r *postgresFeedRepo) buildTimeFilter(timeframe string) string { 157 if timeframe == "" || timeframe == "all" { 158 return "" 159 } 160 161 var interval string 162 switch timeframe { 163 case "hour": 164 interval = "1 hour" 165 case "day": 166 interval = "1 day" 167 case "week": 168 interval = "1 week" 169 case "month": 170 interval = "1 month" 171 case "year": 172 interval = "1 year" 173 default: 174 return "" 175 } 176 177 return fmt.Sprintf("AND p.created_at > NOW() - INTERVAL '%s'", interval) 178} 179 180// parseCursor decodes pagination cursor 181func (r *postgresFeedRepo) parseCursor(cursor *string, sort string) (string, []interface{}, error) { 182 if cursor == nil || *cursor == "" { 183 return "", nil, nil 184 } 185 186 // Decode base64 cursor 187 decoded, err := base64.StdEncoding.DecodeString(*cursor) 188 if err != nil { 189 return "", nil, fmt.Errorf("invalid cursor encoding") 190 } 191 192 // Parse cursor based on sort type using :: delimiter (Bluesky convention) 193 parts := strings.Split(string(decoded), "::") 194 195 switch sort { 196 case "new": 197 // Cursor format: timestamp::uri 198 if len(parts) != 2 { 199 return "", nil, fmt.Errorf("invalid cursor format") 200 } 201 202 createdAt := parts[0] 203 uri := parts[1] 204 205 // Validate timestamp format 206 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil { 207 return "", nil, fmt.Errorf("invalid cursor timestamp") 208 } 209 210 // Validate URI format (must be AT-URI) 211 if !strings.HasPrefix(uri, "at://") { 212 return "", nil, fmt.Errorf("invalid cursor URI") 213 } 214 215 filter := `AND (p.created_at < $3 OR (p.created_at = $3 AND p.uri < $4))` 216 return filter, []interface{}{createdAt, uri}, nil 
217 218 case "top": 219 // Cursor format: score::timestamp::uri 220 if len(parts) != 3 { 221 return "", nil, fmt.Errorf("invalid cursor format for %s sort", sort) 222 } 223 224 scoreStr := parts[0] 225 createdAt := parts[1] 226 uri := parts[2] 227 228 // Validate score is numeric 229 score := 0 230 if _, err := fmt.Sscanf(scoreStr, "%d", &score); err != nil { 231 return "", nil, fmt.Errorf("invalid cursor score") 232 } 233 234 // Validate timestamp format 235 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil { 236 return "", nil, fmt.Errorf("invalid cursor timestamp") 237 } 238 239 // Validate URI format (must be AT-URI) 240 if !strings.HasPrefix(uri, "at://") { 241 return "", nil, fmt.Errorf("invalid cursor URI") 242 } 243 244 filter := `AND (p.score < $3 OR (p.score = $3 AND p.created_at < $4) OR (p.score = $3 AND p.created_at = $4 AND p.uri < $5))` 245 return filter, []interface{}{score, createdAt, uri}, nil 246 247 case "hot": 248 // Cursor format: hot_rank::timestamp::uri 249 // CRITICAL: Must use computed hot_rank, not raw score, to prevent pagination bugs 250 if len(parts) != 3 { 251 return "", nil, fmt.Errorf("invalid cursor format for hot sort") 252 } 253 254 hotRankStr := parts[0] 255 createdAt := parts[1] 256 uri := parts[2] 257 258 // Validate hot_rank is numeric (float) 259 hotRank := 0.0 260 if _, err := fmt.Sscanf(hotRankStr, "%f", &hotRank); err != nil { 261 return "", nil, fmt.Errorf("invalid cursor hot rank") 262 } 263 264 // Validate timestamp format 265 if _, err := time.Parse(time.RFC3339Nano, createdAt); err != nil { 266 return "", nil, fmt.Errorf("invalid cursor timestamp") 267 } 268 269 // Validate URI format (must be AT-URI) 270 if !strings.HasPrefix(uri, "at://") { 271 return "", nil, fmt.Errorf("invalid cursor URI") 272 } 273 274 // CRITICAL: Compare against the computed hot_rank expression, not p.score 275 // This prevents dropping posts with higher raw scores but lower hot ranks 276 // 277 // NOTE: We exclude the exact 
cursor post by URI to handle time drift in hot_rank 278 // (hot_rank changes with NOW(), so the same post may have different ranks over time) 279 filter := fmt.Sprintf(`AND ((%s < $3 OR (%s = $3 AND p.created_at < $4) OR (%s = $3 AND p.created_at = $4 AND p.uri < $5)) AND p.uri != $6)`, 280 hotRankExpression, hotRankExpression, hotRankExpression) 281 return filter, []interface{}{hotRank, createdAt, uri, uri}, nil 282 283 default: 284 return "", nil, nil 285 } 286} 287 288// buildCursor creates pagination cursor from last post 289func (r *postgresFeedRepo) buildCursor(post *posts.PostView, sort string, hotRank float64) string { 290 var cursorStr string 291 // Use :: as delimiter following Bluesky convention 292 // Safe because :: doesn't appear in ISO timestamps or AT-URIs 293 const delimiter = "::" 294 295 switch sort { 296 case "new": 297 // Format: timestamp::uri (following Bluesky pattern) 298 cursorStr = fmt.Sprintf("%s%s%s", post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI) 299 300 case "top": 301 // Format: score::timestamp::uri 302 score := 0 303 if post.Stats != nil { 304 score = post.Stats.Score 305 } 306 cursorStr = fmt.Sprintf("%d%s%s%s%s", score, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI) 307 308 case "hot": 309 // Format: hot_rank::timestamp::uri 310 // CRITICAL: Use computed hot_rank with full precision to prevent pagination bugs 311 // Using 'g' format with -1 precision gives us full float64 precision without trailing zeros 312 // This prevents posts being dropped when hot ranks differ by <1e-6 313 hotRankStr := strconv.FormatFloat(hotRank, 'g', -1, 64) 314 cursorStr = fmt.Sprintf("%s%s%s%s%s", hotRankStr, delimiter, post.CreatedAt.Format(time.RFC3339Nano), delimiter, post.URI) 315 316 default: 317 cursorStr = post.URI 318 } 319 320 return base64.StdEncoding.EncodeToString([]byte(cursorStr)) 321} 322 323// scanFeedViewPost scans a row into FeedViewPost 324// Alpha: No viewer state - basic community feed only 
325func (r *postgresFeedRepo) scanFeedViewPost(rows *sql.Rows) (*communityFeeds.FeedViewPost, float64, error) { 326 var ( 327 postView posts.PostView 328 authorView posts.AuthorView 329 communityRef posts.CommunityRef 330 title, content sql.NullString 331 facets, embed sql.NullString 332 labels pq.StringArray 333 editedAt sql.NullTime 334 communityAvatar sql.NullString 335 hotRank sql.NullFloat64 336 ) 337 338 err := rows.Scan( 339 &postView.URI, &postView.CID, &postView.RKey, 340 &authorView.DID, &authorView.Handle, 341 &communityRef.DID, &communityRef.Name, &communityAvatar, 342 &title, &content, &facets, &embed, &labels, 343 &postView.CreatedAt, &editedAt, &postView.IndexedAt, 344 &postView.UpvoteCount, &postView.DownvoteCount, &postView.Score, &postView.CommentCount, 345 &hotRank, 346 ) 347 if err != nil { 348 return nil, 0, err 349 } 350 351 // Build author view (no display_name or avatar in users table yet) 352 postView.Author = &authorView 353 354 // Build community ref 355 communityRef.Avatar = nullStringPtr(communityAvatar) 356 postView.Community = &communityRef 357 358 // Set optional fields 359 postView.Title = nullStringPtr(title) 360 postView.Text = nullStringPtr(content) 361 362 // Parse facets JSON 363 if facets.Valid { 364 var facetArray []interface{} 365 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil { 366 postView.TextFacets = facetArray 367 } 368 } 369 370 // Parse embed JSON 371 if embed.Valid { 372 var embedData interface{} 373 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil { 374 postView.Embed = embedData 375 } 376 } 377 378 // Build stats 379 postView.Stats = &posts.PostStats{ 380 Upvotes: postView.UpvoteCount, 381 Downvotes: postView.DownvoteCount, 382 Score: postView.Score, 383 CommentCount: postView.CommentCount, 384 } 385 386 // Alpha: No viewer state for basic feed 387 // TODO(feed-generator): Implement viewer state (saved, voted, blocked) in feed generator skeleton 388 389 // Build the 
record (required by lexicon - social.coves.post.record structure) 390 record := map[string]interface{}{ 391 "$type": "social.coves.post.record", 392 "community": communityRef.DID, 393 "author": authorView.DID, 394 "createdAt": postView.CreatedAt.Format(time.RFC3339), 395 } 396 397 // Add optional fields to record if present 398 if title.Valid { 399 record["title"] = title.String 400 } 401 if content.Valid { 402 record["content"] = content.String 403 } 404 if facets.Valid { 405 var facetArray []interface{} 406 if err := json.Unmarshal([]byte(facets.String), &facetArray); err == nil { 407 record["facets"] = facetArray 408 } 409 } 410 if embed.Valid { 411 var embedData interface{} 412 if err := json.Unmarshal([]byte(embed.String), &embedData); err == nil { 413 record["embed"] = embedData 414 } 415 } 416 if len(labels) > 0 { 417 record["contentLabels"] = labels 418 } 419 420 postView.Record = record 421 422 // Wrap in FeedViewPost 423 feedPost := &communityFeeds.FeedViewPost{ 424 Post: &postView, 425 // Reason: nil, // TODO(feed-generator): Implement pinned posts 426 // Reply: nil, // TODO(feed-generator): Implement reply context 427 } 428 429 // Return the computed hot_rank (0.0 if NULL for non-hot sorts) 430 hotRankValue := 0.0 431 if hotRank.Valid { 432 hotRankValue = hotRank.Float64 433 } 434 435 return feedPost, hotRankValue, nil 436} 437 438// Helper function to convert sql.NullString to *string 439func nullStringPtr(ns sql.NullString) *string { 440 if !ns.Valid { 441 return nil 442 } 443 return &ns.String 444}