// A community-based topic aggregation platform built on atproto.
1package postgres 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "log" 8 "strings" 9 10 "Coves/internal/core/users" 11 12 "github.com/lib/pq" 13) 14 15type postgresUserRepo struct { 16 db *sql.DB 17} 18 19// NewUserRepository creates a new PostgreSQL user repository 20func NewUserRepository(db *sql.DB) users.UserRepository { 21 return &postgresUserRepo{db: db} 22} 23 24// Create inserts a new user into the users table 25func (r *postgresUserRepo) Create(ctx context.Context, user *users.User) (*users.User, error) { 26 query := ` 27 INSERT INTO users (did, handle, pds_url) 28 VALUES ($1, $2, $3) 29 RETURNING did, handle, pds_url, created_at, updated_at` 30 31 err := r.db.QueryRowContext(ctx, query, user.DID, user.Handle, user.PDSURL). 32 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt) 33 if err != nil { 34 // Check for unique constraint violations 35 if strings.Contains(err.Error(), "duplicate key") { 36 if strings.Contains(err.Error(), "users_pkey") { 37 return nil, fmt.Errorf("user with DID already exists") 38 } 39 if strings.Contains(err.Error(), "users_handle_key") { 40 return nil, fmt.Errorf("handle already taken") 41 } 42 } 43 return nil, fmt.Errorf("failed to create user: %w", err) 44 } 45 46 return user, nil 47} 48 49// GetByDID retrieves a user by their DID 50func (r *postgresUserRepo) GetByDID(ctx context.Context, did string) (*users.User, error) { 51 user := &users.User{} 52 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = $1` 53 54 err := r.db.QueryRowContext(ctx, query, did). 
55 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt) 56 57 if err == sql.ErrNoRows { 58 return nil, fmt.Errorf("user not found") 59 } 60 if err != nil { 61 return nil, fmt.Errorf("failed to get user by DID: %w", err) 62 } 63 64 return user, nil 65} 66 67// GetByHandle retrieves a user by their handle 68func (r *postgresUserRepo) GetByHandle(ctx context.Context, handle string) (*users.User, error) { 69 user := &users.User{} 70 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE handle = $1` 71 72 err := r.db.QueryRowContext(ctx, query, handle). 73 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt) 74 75 if err == sql.ErrNoRows { 76 return nil, fmt.Errorf("user not found") 77 } 78 if err != nil { 79 return nil, fmt.Errorf("failed to get user by handle: %w", err) 80 } 81 82 return user, nil 83} 84 85// UpdateHandle updates the handle for a user with the given DID 86func (r *postgresUserRepo) UpdateHandle(ctx context.Context, did, newHandle string) (*users.User, error) { 87 user := &users.User{} 88 query := ` 89 UPDATE users 90 SET handle = $2, updated_at = NOW() 91 WHERE did = $1 92 RETURNING did, handle, pds_url, created_at, updated_at` 93 94 err := r.db.QueryRowContext(ctx, query, did, newHandle). 
95 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt) 96 97 if err == sql.ErrNoRows { 98 return nil, fmt.Errorf("user not found") 99 } 100 if err != nil { 101 // Check for unique constraint violation on handle 102 if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "users_handle_key") { 103 return nil, fmt.Errorf("handle already taken") 104 } 105 return nil, fmt.Errorf("failed to update handle: %w", err) 106 } 107 108 return user, nil 109} 110 111const MaxBatchSize = 1000 112 113// GetByDIDs retrieves multiple users by their DIDs in a single query 114// Returns a map of DID -> User for efficient lookups 115// Missing users are not included in the result map (no error for missing users) 116func (r *postgresUserRepo) GetByDIDs(ctx context.Context, dids []string) (map[string]*users.User, error) { 117 if len(dids) == 0 { 118 return make(map[string]*users.User), nil 119 } 120 121 // Validate batch size to prevent excessive memory usage and query timeouts 122 if len(dids) > MaxBatchSize { 123 return nil, fmt.Errorf("batch size %d exceeds maximum %d", len(dids), MaxBatchSize) 124 } 125 126 // Validate DID format to prevent SQL injection and malformed queries 127 // All atProto DIDs must start with "did:" prefix 128 for _, did := range dids { 129 if !strings.HasPrefix(did, "did:") { 130 return nil, fmt.Errorf("invalid DID format: %s", did) 131 } 132 } 133 134 // Build parameterized query with IN clause 135 // Use ANY($1) for PostgreSQL array support with pq.Array() for type conversion 136 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = ANY($1)` 137 138 rows, err := r.db.QueryContext(ctx, query, pq.Array(dids)) 139 if err != nil { 140 return nil, fmt.Errorf("failed to query users by DIDs: %w", err) 141 } 142 defer func() { 143 if closeErr := rows.Close(); closeErr != nil { 144 log.Printf("Warning: Failed to close rows: %v", closeErr) 145 } 146 }() 147 148 // Build map of results 
149 result := make(map[string]*users.User, len(dids)) 150 for rows.Next() { 151 user := &users.User{} 152 err := rows.Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt) 153 if err != nil { 154 return nil, fmt.Errorf("failed to scan user row: %w", err) 155 } 156 result[user.DID] = user 157 } 158 159 if err = rows.Err(); err != nil { 160 return nil, fmt.Errorf("error iterating user rows: %w", err) 161 } 162 163 return result, nil 164}