A community-based topic aggregation platform built on atproto.
1package postgres
2
3import (
4 "Coves/internal/core/users"
5 "context"
6 "database/sql"
7 "fmt"
8 "log"
9 "strings"
10
11 "github.com/lib/pq"
12)
13
14type postgresUserRepo struct {
15 db *sql.DB
16}
17
18// NewUserRepository creates a new PostgreSQL user repository
19func NewUserRepository(db *sql.DB) users.UserRepository {
20 return &postgresUserRepo{db: db}
21}
22
23// Create inserts a new user into the users table
24func (r *postgresUserRepo) Create(ctx context.Context, user *users.User) (*users.User, error) {
25 query := `
26 INSERT INTO users (did, handle, pds_url)
27 VALUES ($1, $2, $3)
28 RETURNING did, handle, pds_url, created_at, updated_at`
29
30 err := r.db.QueryRowContext(ctx, query, user.DID, user.Handle, user.PDSURL).
31 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
32 if err != nil {
33 // Check for unique constraint violations
34 if strings.Contains(err.Error(), "duplicate key") {
35 if strings.Contains(err.Error(), "users_pkey") {
36 return nil, fmt.Errorf("user with DID already exists")
37 }
38 if strings.Contains(err.Error(), "users_handle_key") {
39 return nil, fmt.Errorf("handle already taken")
40 }
41 }
42 return nil, fmt.Errorf("failed to create user: %w", err)
43 }
44
45 return user, nil
46}
47
48// GetByDID retrieves a user by their DID
49func (r *postgresUserRepo) GetByDID(ctx context.Context, did string) (*users.User, error) {
50 user := &users.User{}
51 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = $1`
52
53 err := r.db.QueryRowContext(ctx, query, did).
54 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
55
56 if err == sql.ErrNoRows {
57 return nil, fmt.Errorf("user not found")
58 }
59 if err != nil {
60 return nil, fmt.Errorf("failed to get user by DID: %w", err)
61 }
62
63 return user, nil
64}
65
66// GetByHandle retrieves a user by their handle
67func (r *postgresUserRepo) GetByHandle(ctx context.Context, handle string) (*users.User, error) {
68 user := &users.User{}
69 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE handle = $1`
70
71 err := r.db.QueryRowContext(ctx, query, handle).
72 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
73
74 if err == sql.ErrNoRows {
75 return nil, fmt.Errorf("user not found")
76 }
77 if err != nil {
78 return nil, fmt.Errorf("failed to get user by handle: %w", err)
79 }
80
81 return user, nil
82}
83
84// UpdateHandle updates the handle for a user with the given DID
85func (r *postgresUserRepo) UpdateHandle(ctx context.Context, did, newHandle string) (*users.User, error) {
86 user := &users.User{}
87 query := `
88 UPDATE users
89 SET handle = $2, updated_at = NOW()
90 WHERE did = $1
91 RETURNING did, handle, pds_url, created_at, updated_at`
92
93 err := r.db.QueryRowContext(ctx, query, did, newHandle).
94 Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
95
96 if err == sql.ErrNoRows {
97 return nil, fmt.Errorf("user not found")
98 }
99 if err != nil {
100 // Check for unique constraint violation on handle
101 if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "users_handle_key") {
102 return nil, fmt.Errorf("handle already taken")
103 }
104 return nil, fmt.Errorf("failed to update handle: %w", err)
105 }
106
107 return user, nil
108}
109
// MaxBatchSize is the largest number of DIDs GetByDIDs accepts in one call;
// larger requests are rejected with an error before any query is issued.
const MaxBatchSize = 1_000
111
112// GetByDIDs retrieves multiple users by their DIDs in a single query
113// Returns a map of DID -> User for efficient lookups
114// Missing users are not included in the result map (no error for missing users)
115func (r *postgresUserRepo) GetByDIDs(ctx context.Context, dids []string) (map[string]*users.User, error) {
116 if len(dids) == 0 {
117 return make(map[string]*users.User), nil
118 }
119
120 // Validate batch size to prevent excessive memory usage and query timeouts
121 if len(dids) > MaxBatchSize {
122 return nil, fmt.Errorf("batch size %d exceeds maximum %d", len(dids), MaxBatchSize)
123 }
124
125 // Validate DID format to prevent SQL injection and malformed queries
126 // All atProto DIDs must start with "did:" prefix
127 for _, did := range dids {
128 if !strings.HasPrefix(did, "did:") {
129 return nil, fmt.Errorf("invalid DID format: %s", did)
130 }
131 }
132
133 // Build parameterized query with IN clause
134 // Use ANY($1) for PostgreSQL array support with pq.Array() for type conversion
135 query := `SELECT did, handle, pds_url, created_at, updated_at FROM users WHERE did = ANY($1)`
136
137 rows, err := r.db.QueryContext(ctx, query, pq.Array(dids))
138 if err != nil {
139 return nil, fmt.Errorf("failed to query users by DIDs: %w", err)
140 }
141 defer func() {
142 if closeErr := rows.Close(); closeErr != nil {
143 log.Printf("Warning: Failed to close rows: %v", closeErr)
144 }
145 }()
146
147 // Build map of results
148 result := make(map[string]*users.User, len(dids))
149 for rows.Next() {
150 user := &users.User{}
151 err := rows.Scan(&user.DID, &user.Handle, &user.PDSURL, &user.CreatedAt, &user.UpdatedAt)
152 if err != nil {
153 return nil, fmt.Errorf("failed to scan user row: %w", err)
154 }
155 result[user.DID] = user
156 }
157
158 if err = rows.Err(); err != nil {
159 return nil, fmt.Errorf("error iterating user rows: %w", err)
160 }
161
162 return result, nil
163}