A community-based topic aggregation platform built on atproto
1package postgres
2
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"Coves/internal/core/users"

	"github.com/lib/pq"
)
12
// postgresUserRepo is the PostgreSQL-backed implementation returned by
// NewUserRepository as a users.UserRepository. Every query it issues goes
// through the single *sql.DB handle it holds.
type postgresUserRepo struct {
	db *sql.DB // shared database handle (database/sql connection pool)
}
16
17// NewUserRepository creates a new PostgreSQL user repository
18func NewUserRepository(db *sql.DB) users.UserRepository {
19 return &postgresUserRepo{db: db}
20}
21
22// Create inserts a new user into the users table
23func (r *postgresUserRepo) Create(ctx context.Context, user *users.User) (*users.User, error) {
24 query := `
25 INSERT INTO users (did, handle, pds_url)
26 VALUES ($1, $2, $3)
27 RETURNING did, handle, pds_url, created_at, updated_at`
28
29 err := r.db.QueryRowContext(ctx, query, user.DID, user.Handle, user.PDSURL).
30 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
31 if err != nil {
32 // Check for unique constraint violations
33 if strings.Contains(err.Error(), "duplicate key") {
34 if strings.Contains(err.Error(), "users_pkey") {
35 return nil, fmt.Errorf("user with DID already exists")
36 }
37 if strings.Contains(err.Error(), "users_handle_key") {
38 return nil, fmt.Errorf("handle already taken")
39 }
40 }
41 return nil, fmt.Errorf("failed to create user: %w", err)
42 }
43
44 return user, nil
45}
46
47// GetByDID retrieves a user by their DID
48func (r *postgresUserRepo) GetByDID(ctx context.Context, did string) (*users.User, error) {
49 user := &users.User{}
50 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = $1`
51
52 err := r.db.QueryRowContext(ctx, query, did).
53 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
54
55 if err == sql.ErrNoRows {
56 return nil, fmt.Errorf("user not found")
57 }
58 if err != nil {
59 return nil, fmt.Errorf("failed to get user by DID: %w", err)
60 }
61
62 return user, nil
63}
64
65// GetByHandle retrieves a user by their handle
66func (r *postgresUserRepo) GetByHandle(ctx context.Context, handle string) (*users.User, error) {
67 user := &users.User{}
68 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE handle = $1`
69
70 err := r.db.QueryRowContext(ctx, query, handle).
71 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
72
73 if err == sql.ErrNoRows {
74 return nil, fmt.Errorf("user not found")
75 }
76 if err != nil {
77 return nil, fmt.Errorf("failed to get user by handle: %w", err)
78 }
79
80 return user, nil
81}
82
83// UpdateHandle updates the handle for a user with the given DID
84func (r *postgresUserRepo) UpdateHandle(ctx context.Context, did, newHandle string) (*users.User, error) {
85 user := &users.User{}
86 query := `
87 UPDATE users
88 SET handle = $2, updated_at = NOW()
89 WHERE did = $1
90 RETURNING did, handle, pds_url, created_at, updated_at`
91
92 err := r.db.QueryRowContext(ctx, query, did, newHandle).
93 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
94
95 if err == sql.ErrNoRows {
96 return nil, fmt.Errorf("user not found")
97 }
98 if err != nil {
99 // Check for unique constraint violation on handle
100 if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "users_handle_key") {
101 return nil, fmt.Errorf("handle already taken")
102 }
103 return nil, fmt.Errorf("failed to update handle: %w", err)
104 }
105
106 return user, nil
107}
108
// MaxBatchSize is the largest number of DIDs GetByDIDs accepts in one call;
// larger requests are rejected with an error rather than split into batches.
const MaxBatchSize = 1000
110
111// GetByDIDs retrieves multiple users by their DIDs in a single query
112// Returns a map of DID -> User for efficient lookups
113// Missing users are not included in the result map (no error for missing users)
114func (r *postgresUserRepo) GetByDIDs(ctx context.Context, dids []string) (map[string]*users.User, error) {
115 if len(dids) == 0 {
116 return make(map[string]*users.User), nil
117 }
118
119 // Validate batch size to prevent excessive memory usage and query timeouts
120 if len(dids) > MaxBatchSize {
121 return nil, fmt.Errorf("batch size %d exceeds maximum %d", len(dids), MaxBatchSize)
122 }
123
124 // Validate DID format to prevent SQL injection and malformed queries
125 // All atProto DIDs must start with "did:" prefix
126 for _, did := range dids {
127 if !strings.HasPrefix(did, "did:") {
128 return nil, fmt.Errorf("invalid DID format: %s", did)
129 }
130 }
131
132 // Build parameterized query with IN clause
133 // Use ANY($1) for PostgreSQL array support with pq.Array() for type conversion
134 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = ANY($1)`
135
136 rows, err := r.db.QueryContext(ctx, query, pq.Array(dids))
137 if err != nil {
138 return nil, fmt.Errorf("failed to query users by DIDs: %w", err)
139 }
140 defer rows.Close()
141
142 // Build map of results
143 result := make(map[string]*users.User, len(dids))
144 for rows.Next() {
145 user := &users.User{}
146 err := rows.Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
147 if err != nil {
148 return nil, fmt.Errorf("failed to scan user row: %w", err)
149 }
150 result[user.DID] = user
151 }
152
153 if err = rows.Err(); err != nil {
154 return nil, fmt.Errorf("error iterating user rows: %w", err)
155 }
156
157 return result, nil
158}