A community-based topic aggregation platform built on atproto
1package postgres
2
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"strings"

	"Coves/internal/core/users"

	"github.com/lib/pq"
)
14
15type postgresUserRepo struct {
16 db *sql.DB
17}
18
19// NewUserRepository creates a new PostgreSQL user repository
20func NewUserRepository(db *sql.DB) users.UserRepository {
21 return &postgresUserRepo{db: db}
22}
23
24// Create inserts a new user into the users table
25func (r *postgresUserRepo) Create(ctx context.Context, user *users.User) (*users.User, error) {
26 query := `
27 INSERT INTO users (did, handle, pds_url)
28 VALUES ($1, $2, $3)
29 RETURNING did, handle, pds_url, created_at, updated_at`
30
31 err := r.db.QueryRowContext(ctx, query, user.DID, user.Handle, user.PDSURL).
32 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
33 if err != nil {
34 // Check for unique constraint violations
35 if strings.Contains(err.Error(), "duplicate key") {
36 if strings.Contains(err.Error(), "users_pkey") {
37 return nil, fmt.Errorf("user with DID already exists")
38 }
39 if strings.Contains(err.Error(), "users_handle_key") {
40 return nil, fmt.Errorf("handle already taken")
41 }
42 }
43 return nil, fmt.Errorf("failed to create user: %w", err)
44 }
45
46 return user, nil
47}
48
49// GetByDID retrieves a user by their DID
50func (r *postgresUserRepo) GetByDID(ctx context.Context, did string) (*users.User, error) {
51 user := &users.User{}
52 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = $1`
53
54 err := r.db.QueryRowContext(ctx, query, did).
55 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
56
57 if err == sql.ErrNoRows {
58 return nil, fmt.Errorf("user not found")
59 }
60 if err != nil {
61 return nil, fmt.Errorf("failed to get user by DID: %w", err)
62 }
63
64 return user, nil
65}
66
67// GetByHandle retrieves a user by their handle
68func (r *postgresUserRepo) GetByHandle(ctx context.Context, handle string) (*users.User, error) {
69 user := &users.User{}
70 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE handle = $1`
71
72 err := r.db.QueryRowContext(ctx, query, handle).
73 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
74
75 if err == sql.ErrNoRows {
76 return nil, fmt.Errorf("user not found")
77 }
78 if err != nil {
79 return nil, fmt.Errorf("failed to get user by handle: %w", err)
80 }
81
82 return user, nil
83}
84
85// UpdateHandle updates the handle for a user with the given DID
86func (r *postgresUserRepo) UpdateHandle(ctx context.Context, did, newHandle string) (*users.User, error) {
87 user := &users.User{}
88 query := `
89 UPDATE users
90 SET handle = $2, updated_at = NOW()
91 WHERE did = $1
92 RETURNING did, handle, pds_url, created_at, updated_at`
93
94 err := r.db.QueryRowContext(ctx, query, did, newHandle).
95 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
96
97 if err == sql.ErrNoRows {
98 return nil, fmt.Errorf("user not found")
99 }
100 if err != nil {
101 // Check for unique constraint violation on handle
102 if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "users_handle_key") {
103 return nil, fmt.Errorf("handle already taken")
104 }
105 return nil, fmt.Errorf("failed to update handle: %w", err)
106 }
107
108 return user, nil
109}
110
const (
	// MaxBatchSize is the largest number of DIDs a single GetByDIDs call
	// accepts; larger batches are rejected with an error before querying.
	MaxBatchSize = 1000
)
112
113// GetByDIDs retrieves multiple users by their DIDs in a single query
114// Returns a map of DID -> User for efficient lookups
115// Missing users are not included in the result map (no error for missing users)
116func (r *postgresUserRepo) GetByDIDs(ctx context.Context, dids []string) (map[string]*users.User, error) {
117 if len(dids) == 0 {
118 return make(map[string]*users.User), nil
119 }
120
121 // Validate batch size to prevent excessive memory usage and query timeouts
122 if len(dids) > MaxBatchSize {
123 return nil, fmt.Errorf("batch size %d exceeds maximum %d", len(dids), MaxBatchSize)
124 }
125
126 // Validate DID format to prevent SQL injection and malformed queries
127 // All atProto DIDs must start with "did:" prefix
128 for _, did := range dids {
129 if !strings.HasPrefix(did, "did:") {
130 return nil, fmt.Errorf("invalid DID format: %s", did)
131 }
132 }
133
134 // Build parameterized query with IN clause
135 // Use ANY($1) for PostgreSQL array support with pq.Array() for type conversion
136 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = ANY($1)`
137
138 rows, err := r.db.QueryContext(ctx, query, pq.Array(dids))
139 if err != nil {
140 return nil, fmt.Errorf("failed to query users by DIDs: %w", err)
141 }
142 defer func() {
143 if closeErr := rows.Close(); closeErr != nil {
144 log.Printf("Warning: Failed to close rows: %v", closeErr)
145 }
146 }()
147
148 // Build map of results
149 result := make(map[string]*users.User, len(dids))
150 for rows.Next() {
151 user := &users.User{}
152 err := rows.Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
153 if err != nil {
154 return nil, fmt.Errorf("failed to scan user row: %w", err)
155 }
156 result[user.DID] = user
157 }
158
159 if err = rows.Err(); err != nil {
160 return nil, fmt.Errorf("error iterating user rows: %w", err)
161 }
162
163 return result, nil
164}