A community-based topic aggregation platform built on atproto.
1package integration
2
import (
	"Coves/internal/api/middleware"
	"Coves/internal/atproto/oauth"
	"Coves/internal/core/users"
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth"
	"github.com/bluesky-social/indigo/atproto/syntax"
)
22
23// getTestPDSURL returns the PDS URL for testing from env var or default
24func getTestPDSURL() string {
25 pdsURL := os.Getenv("PDS_URL")
26 if pdsURL == "" {
27 pdsURL = "http://localhost:3001"
28 }
29 return pdsURL
30}
31
32// getTestInstanceDID returns the instance DID for testing from env var or default
33func getTestInstanceDID() string {
34 instanceDID := os.Getenv("INSTANCE_DID")
35 if instanceDID == "" {
36 instanceDID = "did:web:test.coves.social"
37 }
38 return instanceDID
39}
40
41// createTestUser creates a test user in the database for use in integration tests
42// Returns the created user or fails the test
43func createTestUser(t *testing.T, db *sql.DB, handle, did string) *users.User {
44 t.Helper()
45
46 ctx := context.Background()
47
48 // Create user directly in DB for speed
49 query := `
50 INSERT INTO users (did, handle, pds_url, created_at, updated_at)
51 VALUES ($1, $2, $3, NOW(), NOW())
52 RETURNING did, handle, pds_url, created_at, updated_at
53 `
54
55 user := &users.User{}
56 err := db.QueryRowContext(ctx, query, did, handle, getTestPDSURL()).Scan(
57 &user.DID,
58 &user.Handle,
59 &user.PDSURL,
60 &user.CreatedAt,
61 &user.UpdatedAt,
62 )
63 if err != nil {
64 t.Fatalf("Failed to create test user: %v", err)
65 }
66
67 return user
68}
69
70// contains checks if string s contains substring substr
71// Helper for error message assertions
72func contains(s, substr string) bool {
73 return strings.Contains(s, substr)
74}
75
76// authenticateWithPDS authenticates with PDS to get access token and DID
77// Used for setting up test environments that need PDS credentials
78func authenticateWithPDS(pdsURL, handle, password string) (string, string, error) {
79 // Call com.atproto.server.createSession
80 sessionReq := map[string]string{
81 "identifier": handle,
82 "password": password,
83 }
84
85 reqBody, marshalErr := json.Marshal(sessionReq)
86 if marshalErr != nil {
87 return "", "", fmt.Errorf("failed to marshal session request: %w", marshalErr)
88 }
89 resp, err := http.Post(
90 pdsURL+"/xrpc/com.atproto.server.createSession",
91 "application/json",
92 bytes.NewBuffer(reqBody),
93 )
94 if err != nil {
95 return "", "", fmt.Errorf("failed to create session: %w", err)
96 }
97 defer func() { _ = resp.Body.Close() }()
98
99 if resp.StatusCode != http.StatusOK {
100 body, readErr := io.ReadAll(resp.Body)
101 if readErr != nil {
102 return "", "", fmt.Errorf("PDS auth failed (status %d, failed to read body: %w)", resp.StatusCode, readErr)
103 }
104 return "", "", fmt.Errorf("PDS auth failed (status %d): %s", resp.StatusCode, string(body))
105 }
106
107 var sessionResp struct {
108 AccessJwt string `json:"accessJwt"`
109 DID string `json:"did"`
110 }
111
112 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil {
113 return "", "", fmt.Errorf("failed to decode session response: %w", err)
114 }
115
116 return sessionResp.AccessJwt, sessionResp.DID, nil
117}
118
119// generateTID generates a simple timestamp-based identifier for testing
120// In production, PDS generates proper TIDs
121func generateTID() string {
122 return fmt.Sprintf("3k%d", time.Now().UnixNano()/1000)
123}
124
125// createPDSAccount creates a new account on PDS and returns access token + DID
126// This is used for E2E tests that need real PDS accounts
127func createPDSAccount(pdsURL, handle, email, password string) (accessToken, did string, err error) {
128 // Call com.atproto.server.createAccount
129 reqBody := map[string]string{
130 "handle": handle,
131 "email": email,
132 "password": password,
133 }
134
135 reqJSON, marshalErr := json.Marshal(reqBody)
136 if marshalErr != nil {
137 return "", "", fmt.Errorf("failed to marshal account request: %w", marshalErr)
138 }
139
140 resp, httpErr := http.Post(
141 pdsURL+"/xrpc/com.atproto.server.createAccount",
142 "application/json",
143 bytes.NewBuffer(reqJSON),
144 )
145 if httpErr != nil {
146 return "", "", fmt.Errorf("failed to create account: %w", httpErr)
147 }
148 defer func() { _ = resp.Body.Close() }()
149
150 if resp.StatusCode != http.StatusOK {
151 body, readErr := io.ReadAll(resp.Body)
152 if readErr != nil {
153 return "", "", fmt.Errorf("account creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr)
154 }
155 return "", "", fmt.Errorf("account creation failed (status %d): %s", resp.StatusCode, string(body))
156 }
157
158 var accountResp struct {
159 AccessJwt string `json:"accessJwt"`
160 DID string `json:"did"`
161 }
162
163 if decodeErr := json.NewDecoder(resp.Body).Decode(&accountResp); decodeErr != nil {
164 return "", "", fmt.Errorf("failed to decode account response: %w", decodeErr)
165 }
166
167 return accountResp.AccessJwt, accountResp.DID, nil
168}
169
170// writePDSRecord writes a record to PDS via com.atproto.repo.createRecord
171// Returns the AT-URI and CID of the created record
172func writePDSRecord(pdsURL, accessToken, repo, collection, rkey string, record interface{}) (uri, cid string, err error) {
173 reqBody := map[string]interface{}{
174 "repo": repo,
175 "collection": collection,
176 "record": record,
177 }
178
179 // If rkey is provided, include it
180 if rkey != "" {
181 reqBody["rkey"] = rkey
182 }
183
184 reqJSON, marshalErr := json.Marshal(reqBody)
185 if marshalErr != nil {
186 return "", "", fmt.Errorf("failed to marshal record request: %w", marshalErr)
187 }
188
189 req, reqErr := http.NewRequest("POST", pdsURL+"/xrpc/com.atproto.repo.createRecord", bytes.NewBuffer(reqJSON))
190 if reqErr != nil {
191 return "", "", fmt.Errorf("failed to create request: %w", reqErr)
192 }
193
194 req.Header.Set("Content-Type", "application/json")
195 req.Header.Set("Authorization", "Bearer "+accessToken)
196
197 resp, httpErr := http.DefaultClient.Do(req)
198 if httpErr != nil {
199 return "", "", fmt.Errorf("failed to write record: %w", httpErr)
200 }
201 defer func() { _ = resp.Body.Close() }()
202
203 if resp.StatusCode != http.StatusOK {
204 body, readErr := io.ReadAll(resp.Body)
205 if readErr != nil {
206 return "", "", fmt.Errorf("record creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr)
207 }
208 return "", "", fmt.Errorf("record creation failed (status %d): %s", resp.StatusCode, string(body))
209 }
210
211 var recordResp struct {
212 URI string `json:"uri"`
213 CID string `json:"cid"`
214 }
215
216 if decodeErr := json.NewDecoder(resp.Body).Decode(&recordResp); decodeErr != nil {
217 return "", "", fmt.Errorf("failed to decode record response: %w", decodeErr)
218 }
219
220 return recordResp.URI, recordResp.CID, nil
221}
222
223// createFeedTestCommunity creates a test community for feed tests
224// Returns the community DID or an error
225func createFeedTestCommunity(db *sql.DB, ctx context.Context, name, ownerHandle string) (string, error) {
226 // Get configuration from env vars
227 pdsURL := getTestPDSURL()
228 instanceDID := getTestInstanceDID()
229
230 // Create owner user first (directly insert to avoid service dependencies)
231 ownerDID := fmt.Sprintf("did:plc:%s", ownerHandle)
232 _, err := db.ExecContext(ctx, `
233 INSERT INTO users (did, handle, pds_url, created_at)
234 VALUES ($1, $2, $3, NOW())
235 ON CONFLICT (did) DO NOTHING
236 `, ownerDID, ownerHandle, pdsURL)
237 if err != nil {
238 return "", err
239 }
240
241 // Create community
242 communityDID := fmt.Sprintf("did:plc:community-%s", name)
243 _, err = db.ExecContext(ctx, `
244 INSERT INTO communities (did, name, owner_did, created_by_did, hosted_by_did, handle, pds_url, created_at)
245 VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
246 ON CONFLICT (did) DO NOTHING
247 `, communityDID, name, ownerDID, ownerDID, instanceDID, fmt.Sprintf("%s.coves.social", name), pdsURL)
248
249 return communityDID, err
250}
251
252// createTestPost creates a test post and returns its URI
253func createTestPost(t *testing.T, db *sql.DB, communityDID, authorDID, title string, score int, createdAt time.Time) string {
254 t.Helper()
255
256 ctx := context.Background()
257
258 // Create author user if not exists (directly insert to avoid service dependencies)
259 _, _ = db.ExecContext(ctx, `
260 INSERT INTO users (did, handle, pds_url, created_at)
261 VALUES ($1, $2, $3, NOW())
262 ON CONFLICT (did) DO NOTHING
263 `, authorDID, fmt.Sprintf("%s.bsky.social", authorDID), getTestPDSURL())
264
265 // Generate URI
266 rkey := fmt.Sprintf("post-%d", time.Now().UnixNano())
267 uri := fmt.Sprintf("at://%s/social.coves.community.post/%s", communityDID, rkey)
268
269 // Insert post
270 _, err := db.ExecContext(ctx, `
271 INSERT INTO posts (uri, cid, rkey, author_did, community_did, title, created_at, score, upvote_count)
272 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
273 `, uri, "bafytest", rkey, authorDID, communityDID, title, createdAt, score, score)
274 if err != nil {
275 t.Fatalf("Failed to create test post: %v", err)
276 }
277
278 return uri
279}
280
281// MockSessionUnsealer is a mock implementation of SessionUnsealer for testing
282// It returns predefined sessions based on token value
283type MockSessionUnsealer struct {
284 sessions map[string]*oauth.SealedSession
285}
286
287// NewMockSessionUnsealer creates a new mock unsealer
288func NewMockSessionUnsealer() *MockSessionUnsealer {
289 return &MockSessionUnsealer{
290 sessions: make(map[string]*oauth.SealedSession),
291 }
292}
293
294// AddSession adds a token -> session mapping
295func (m *MockSessionUnsealer) AddSession(token, did, sessionID string) {
296 m.sessions[token] = &oauth.SealedSession{
297 DID: did,
298 SessionID: sessionID,
299 ExpiresAt: time.Now().Add(1 * time.Hour).Unix(),
300 }
301}
302
303// UnsealSession returns the predefined session for a token
304func (m *MockSessionUnsealer) UnsealSession(token string) (*oauth.SealedSession, error) {
305 if sess, ok := m.sessions[token]; ok {
306 return sess, nil
307 }
308 return nil, fmt.Errorf("unknown token")
309}
310
311// MockOAuthStore is a mock implementation of ClientAuthStore for testing
312type MockOAuthStore struct {
313 sessions map[string]*oauthlib.ClientSessionData
314}
315
316// NewMockOAuthStore creates a new mock OAuth store
317func NewMockOAuthStore() *MockOAuthStore {
318 return &MockOAuthStore{
319 sessions: make(map[string]*oauthlib.ClientSessionData),
320 }
321}
322
323// AddSession adds a session to the store
324func (m *MockOAuthStore) AddSession(did, sessionID, accessToken string) {
325 key := did + ":" + sessionID
326 parsedDID, _ := syntax.ParseDID(did)
327 m.sessions[key] = &oauthlib.ClientSessionData{
328 AccountDID: parsedDID,
329 SessionID: sessionID,
330 AccessToken: accessToken,
331 }
332}
333
334// GetSession implements ClientAuthStore
335func (m *MockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauthlib.ClientSessionData, error) {
336 key := did.String() + ":" + sessionID
337 if sess, ok := m.sessions[key]; ok {
338 return sess, nil
339 }
340 return nil, fmt.Errorf("session not found")
341}
342
343// SaveSession implements ClientAuthStore
344func (m *MockOAuthStore) SaveSession(ctx context.Context, sess oauthlib.ClientSessionData) error {
345 key := sess.AccountDID.String() + ":" + sess.SessionID
346 m.sessions[key] = &sess
347 return nil
348}
349
350// DeleteSession implements ClientAuthStore
351func (m *MockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
352 key := did.String() + ":" + sessionID
353 delete(m.sessions, key)
354 return nil
355}
356
357// GetAuthRequestInfo implements ClientAuthStore
358func (m *MockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauthlib.AuthRequestData, error) {
359 return nil, fmt.Errorf("not implemented in mock")
360}
361
362// SaveAuthRequestInfo implements ClientAuthStore
363func (m *MockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauthlib.AuthRequestData) error {
364 return nil
365}
366
367// DeleteAuthRequestInfo implements ClientAuthStore
368func (m *MockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
369 return nil
370}
371
372// CreateTestOAuthMiddleware creates an OAuth middleware with mock implementations for testing
373// The returned middleware accepts a test token that maps to the specified userDID
374func CreateTestOAuthMiddleware(userDID string) (*middleware.OAuthAuthMiddleware, string) {
375 unsealer := NewMockSessionUnsealer()
376 store := NewMockOAuthStore()
377
378 testToken := "test-token-" + userDID
379 sessionID := "test-session-123"
380
381 // Add the test session
382 unsealer.AddSession(testToken, userDID, sessionID)
383 store.AddSession(userDID, sessionID, "test-access-token")
384
385 authMiddleware := middleware.NewOAuthAuthMiddleware(unsealer, store)
386 return authMiddleware, testToken
387}
388
389// E2EOAuthMiddleware wraps OAuth middleware for E2E testing with multiple users
390type E2EOAuthMiddleware struct {
391 *middleware.OAuthAuthMiddleware
392 unsealer *MockSessionUnsealer
393 store *MockOAuthStore
394}
395
396// NewE2EOAuthMiddleware creates an OAuth middleware for E2E testing
397func NewE2EOAuthMiddleware() *E2EOAuthMiddleware {
398 unsealer := NewMockSessionUnsealer()
399 store := NewMockOAuthStore()
400 m := middleware.NewOAuthAuthMiddleware(unsealer, store)
401 return &E2EOAuthMiddleware{m, unsealer, store}
402}
403
404// AddUser registers a user DID and returns the token to use in Authorization header
405func (e *E2EOAuthMiddleware) AddUser(did string) string {
406 token := "test-token-" + did
407 sessionID := "session-" + did
408 e.unsealer.AddSession(token, did, sessionID)
409 e.store.AddSession(did, sessionID, "access-token-"+did)
410 return token
411}