A community-based topic aggregation platform built on atproto
1package integration 2 3import ( 4 "Coves/internal/api/middleware" 5 "Coves/internal/atproto/oauth" 6 "Coves/internal/core/users" 7 "bytes" 8 "context" 9 "database/sql" 10 "encoding/json" 11 "fmt" 12 "io" 13 "net/http" 14 "os" 15 "strings" 16 "testing" 17 "time" 18 19 oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth" 20 "github.com/bluesky-social/indigo/atproto/syntax" 21) 22 23// getTestPDSURL returns the PDS URL for testing from env var or default 24func getTestPDSURL() string { 25 pdsURL := os.Getenv("PDS_URL") 26 if pdsURL == "" { 27 pdsURL = "http://localhost:3001" 28 } 29 return pdsURL 30} 31 32// getTestInstanceDID returns the instance DID for testing from env var or default 33func getTestInstanceDID() string { 34 instanceDID := os.Getenv("INSTANCE_DID") 35 if instanceDID == "" { 36 instanceDID = "did:web:test.coves.social" 37 } 38 return instanceDID 39} 40 41// createTestUser creates a test user in the database for use in integration tests 42// Returns the created user or fails the test 43func createTestUser(t *testing.T, db *sql.DB, handle, did string) *users.User { 44 t.Helper() 45 46 ctx := context.Background() 47 48 // Create user directly in DB for speed 49 query := ` 50 INSERT INTO users (did, handle, pds_url, created_at, updated_at) 51 VALUES ($1, $2, $3, NOW(), NOW()) 52 RETURNING did, handle, pds_url, created_at, updated_at 53 ` 54 55 user := &users.User{} 56 err := db.QueryRowContext(ctx, query, did, handle, getTestPDSURL()).Scan( 57 &user.DID, 58 &user.Handle, 59 &user.PDSURL, 60 &user.CreatedAt, 61 &user.UpdatedAt, 62 ) 63 if err != nil { 64 t.Fatalf("Failed to create test user: %v", err) 65 } 66 67 return user 68} 69 70// contains checks if string s contains substring substr 71// Helper for error message assertions 72func contains(s, substr string) bool { 73 return strings.Contains(s, substr) 74} 75 76// authenticateWithPDS authenticates with PDS to get access token and DID 77// Used for setting up test environments that need PDS 
credentials 78func authenticateWithPDS(pdsURL, handle, password string) (string, string, error) { 79 // Call com.atproto.server.createSession 80 sessionReq := map[string]string{ 81 "identifier": handle, 82 "password": password, 83 } 84 85 reqBody, marshalErr := json.Marshal(sessionReq) 86 if marshalErr != nil { 87 return "", "", fmt.Errorf("failed to marshal session request: %w", marshalErr) 88 } 89 resp, err := http.Post( 90 pdsURL+"/xrpc/com.atproto.server.createSession", 91 "application/json", 92 bytes.NewBuffer(reqBody), 93 ) 94 if err != nil { 95 return "", "", fmt.Errorf("failed to create session: %w", err) 96 } 97 defer func() { _ = resp.Body.Close() }() 98 99 if resp.StatusCode != http.StatusOK { 100 body, readErr := io.ReadAll(resp.Body) 101 if readErr != nil { 102 return "", "", fmt.Errorf("PDS auth failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 103 } 104 return "", "", fmt.Errorf("PDS auth failed (status %d): %s", resp.StatusCode, string(body)) 105 } 106 107 var sessionResp struct { 108 AccessJwt string `json:"accessJwt"` 109 DID string `json:"did"` 110 } 111 112 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil { 113 return "", "", fmt.Errorf("failed to decode session response: %w", err) 114 } 115 116 return sessionResp.AccessJwt, sessionResp.DID, nil 117} 118 119// generateTID generates a simple timestamp-based identifier for testing 120// In production, PDS generates proper TIDs 121func generateTID() string { 122 return fmt.Sprintf("3k%d", time.Now().UnixNano()/1000) 123} 124 125// createPDSAccount creates a new account on PDS and returns access token + DID 126// This is used for E2E tests that need real PDS accounts 127func createPDSAccount(pdsURL, handle, email, password string) (accessToken, did string, err error) { 128 // Call com.atproto.server.createAccount 129 reqBody := map[string]string{ 130 "handle": handle, 131 "email": email, 132 "password": password, 133 } 134 135 reqJSON, marshalErr := 
json.Marshal(reqBody) 136 if marshalErr != nil { 137 return "", "", fmt.Errorf("failed to marshal account request: %w", marshalErr) 138 } 139 140 resp, httpErr := http.Post( 141 pdsURL+"/xrpc/com.atproto.server.createAccount", 142 "application/json", 143 bytes.NewBuffer(reqJSON), 144 ) 145 if httpErr != nil { 146 return "", "", fmt.Errorf("failed to create account: %w", httpErr) 147 } 148 defer func() { _ = resp.Body.Close() }() 149 150 if resp.StatusCode != http.StatusOK { 151 body, readErr := io.ReadAll(resp.Body) 152 if readErr != nil { 153 return "", "", fmt.Errorf("account creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 154 } 155 return "", "", fmt.Errorf("account creation failed (status %d): %s", resp.StatusCode, string(body)) 156 } 157 158 var accountResp struct { 159 AccessJwt string `json:"accessJwt"` 160 DID string `json:"did"` 161 } 162 163 if decodeErr := json.NewDecoder(resp.Body).Decode(&accountResp); decodeErr != nil { 164 return "", "", fmt.Errorf("failed to decode account response: %w", decodeErr) 165 } 166 167 return accountResp.AccessJwt, accountResp.DID, nil 168} 169 170// writePDSRecord writes a record to PDS via com.atproto.repo.createRecord 171// Returns the AT-URI and CID of the created record 172func writePDSRecord(pdsURL, accessToken, repo, collection, rkey string, record interface{}) (uri, cid string, err error) { 173 reqBody := map[string]interface{}{ 174 "repo": repo, 175 "collection": collection, 176 "record": record, 177 } 178 179 // If rkey is provided, include it 180 if rkey != "" { 181 reqBody["rkey"] = rkey 182 } 183 184 reqJSON, marshalErr := json.Marshal(reqBody) 185 if marshalErr != nil { 186 return "", "", fmt.Errorf("failed to marshal record request: %w", marshalErr) 187 } 188 189 req, reqErr := http.NewRequest("POST", pdsURL+"/xrpc/com.atproto.repo.createRecord", bytes.NewBuffer(reqJSON)) 190 if reqErr != nil { 191 return "", "", fmt.Errorf("failed to create request: %w", reqErr) 192 } 193 
194 req.Header.Set("Content-Type", "application/json") 195 req.Header.Set("Authorization", "Bearer "+accessToken) 196 197 resp, httpErr := http.DefaultClient.Do(req) 198 if httpErr != nil { 199 return "", "", fmt.Errorf("failed to write record: %w", httpErr) 200 } 201 defer func() { _ = resp.Body.Close() }() 202 203 if resp.StatusCode != http.StatusOK { 204 body, readErr := io.ReadAll(resp.Body) 205 if readErr != nil { 206 return "", "", fmt.Errorf("record creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 207 } 208 return "", "", fmt.Errorf("record creation failed (status %d): %s", resp.StatusCode, string(body)) 209 } 210 211 var recordResp struct { 212 URI string `json:"uri"` 213 CID string `json:"cid"` 214 } 215 216 if decodeErr := json.NewDecoder(resp.Body).Decode(&recordResp); decodeErr != nil { 217 return "", "", fmt.Errorf("failed to decode record response: %w", decodeErr) 218 } 219 220 return recordResp.URI, recordResp.CID, nil 221} 222 223// createFeedTestCommunity creates a test community for feed tests 224// Returns the community DID or an error 225func createFeedTestCommunity(db *sql.DB, ctx context.Context, name, ownerHandle string) (string, error) { 226 // Get configuration from env vars 227 pdsURL := getTestPDSURL() 228 instanceDID := getTestInstanceDID() 229 230 // Create owner user first (directly insert to avoid service dependencies) 231 ownerDID := fmt.Sprintf("did:plc:%s", ownerHandle) 232 _, err := db.ExecContext(ctx, ` 233 INSERT INTO users (did, handle, pds_url, created_at) 234 VALUES ($1, $2, $3, NOW()) 235 ON CONFLICT (did) DO NOTHING 236 `, ownerDID, ownerHandle, pdsURL) 237 if err != nil { 238 return "", err 239 } 240 241 // Create community 242 communityDID := fmt.Sprintf("did:plc:community-%s", name) 243 _, err = db.ExecContext(ctx, ` 244 INSERT INTO communities (did, name, owner_did, created_by_did, hosted_by_did, handle, pds_url, created_at) 245 VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) 246 ON CONFLICT 
(did) DO NOTHING 247 `, communityDID, name, ownerDID, ownerDID, instanceDID, fmt.Sprintf("%s.coves.social", name), pdsURL) 248 249 return communityDID, err 250} 251 252// createTestPost creates a test post and returns its URI 253func createTestPost(t *testing.T, db *sql.DB, communityDID, authorDID, title string, score int, createdAt time.Time) string { 254 t.Helper() 255 256 ctx := context.Background() 257 258 // Create author user if not exists (directly insert to avoid service dependencies) 259 _, _ = db.ExecContext(ctx, ` 260 INSERT INTO users (did, handle, pds_url, created_at) 261 VALUES ($1, $2, $3, NOW()) 262 ON CONFLICT (did) DO NOTHING 263 `, authorDID, fmt.Sprintf("%s.bsky.social", authorDID), getTestPDSURL()) 264 265 // Generate URI 266 rkey := fmt.Sprintf("post-%d", time.Now().UnixNano()) 267 uri := fmt.Sprintf("at://%s/social.coves.community.post/%s", communityDID, rkey) 268 269 // Insert post 270 _, err := db.ExecContext(ctx, ` 271 INSERT INTO posts (uri, cid, rkey, author_did, community_did, title, created_at, score, upvote_count) 272 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) 273 `, uri, "bafytest", rkey, authorDID, communityDID, title, createdAt, score, score) 274 if err != nil { 275 t.Fatalf("Failed to create test post: %v", err) 276 } 277 278 return uri 279} 280 281// MockSessionUnsealer is a mock implementation of SessionUnsealer for testing 282// It returns predefined sessions based on token value 283type MockSessionUnsealer struct { 284 sessions map[string]*oauth.SealedSession 285} 286 287// NewMockSessionUnsealer creates a new mock unsealer 288func NewMockSessionUnsealer() *MockSessionUnsealer { 289 return &MockSessionUnsealer{ 290 sessions: make(map[string]*oauth.SealedSession), 291 } 292} 293 294// AddSession adds a token -> session mapping 295func (m *MockSessionUnsealer) AddSession(token, did, sessionID string) { 296 m.sessions[token] = &oauth.SealedSession{ 297 DID: did, 298 SessionID: sessionID, 299 ExpiresAt: time.Now().Add(1 * 
time.Hour).Unix(), 300 } 301} 302 303// UnsealSession returns the predefined session for a token 304func (m *MockSessionUnsealer) UnsealSession(token string) (*oauth.SealedSession, error) { 305 if sess, ok := m.sessions[token]; ok { 306 return sess, nil 307 } 308 return nil, fmt.Errorf("unknown token") 309} 310 311// MockOAuthStore is a mock implementation of ClientAuthStore for testing 312type MockOAuthStore struct { 313 sessions map[string]*oauthlib.ClientSessionData 314} 315 316// NewMockOAuthStore creates a new mock OAuth store 317func NewMockOAuthStore() *MockOAuthStore { 318 return &MockOAuthStore{ 319 sessions: make(map[string]*oauthlib.ClientSessionData), 320 } 321} 322 323// AddSession adds a session to the store 324func (m *MockOAuthStore) AddSession(did, sessionID, accessToken string) { 325 key := did + ":" + sessionID 326 parsedDID, _ := syntax.ParseDID(did) 327 m.sessions[key] = &oauthlib.ClientSessionData{ 328 AccountDID: parsedDID, 329 SessionID: sessionID, 330 AccessToken: accessToken, 331 } 332} 333 334// GetSession implements ClientAuthStore 335func (m *MockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauthlib.ClientSessionData, error) { 336 key := did.String() + ":" + sessionID 337 if sess, ok := m.sessions[key]; ok { 338 return sess, nil 339 } 340 return nil, fmt.Errorf("session not found") 341} 342 343// SaveSession implements ClientAuthStore 344func (m *MockOAuthStore) SaveSession(ctx context.Context, sess oauthlib.ClientSessionData) error { 345 key := sess.AccountDID.String() + ":" + sess.SessionID 346 m.sessions[key] = &sess 347 return nil 348} 349 350// DeleteSession implements ClientAuthStore 351func (m *MockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 352 key := did.String() + ":" + sessionID 353 delete(m.sessions, key) 354 return nil 355} 356 357// GetAuthRequestInfo implements ClientAuthStore 358func (m *MockOAuthStore) GetAuthRequestInfo(ctx context.Context, 
state string) (*oauthlib.AuthRequestData, error) { 359 return nil, fmt.Errorf("not implemented in mock") 360} 361 362// SaveAuthRequestInfo implements ClientAuthStore 363func (m *MockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauthlib.AuthRequestData) error { 364 return nil 365} 366 367// DeleteAuthRequestInfo implements ClientAuthStore 368func (m *MockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 369 return nil 370} 371 372// CreateTestOAuthMiddleware creates an OAuth middleware with mock implementations for testing 373// The returned middleware accepts a test token that maps to the specified userDID 374func CreateTestOAuthMiddleware(userDID string) (*middleware.OAuthAuthMiddleware, string) { 375 unsealer := NewMockSessionUnsealer() 376 store := NewMockOAuthStore() 377 378 testToken := "test-token-" + userDID 379 sessionID := "test-session-123" 380 381 // Add the test session 382 unsealer.AddSession(testToken, userDID, sessionID) 383 store.AddSession(userDID, sessionID, "test-access-token") 384 385 authMiddleware := middleware.NewOAuthAuthMiddleware(unsealer, store) 386 return authMiddleware, testToken 387} 388 389// E2EOAuthMiddleware wraps OAuth middleware for E2E testing with multiple users 390type E2EOAuthMiddleware struct { 391 *middleware.OAuthAuthMiddleware 392 unsealer *MockSessionUnsealer 393 store *MockOAuthStore 394} 395 396// NewE2EOAuthMiddleware creates an OAuth middleware for E2E testing 397func NewE2EOAuthMiddleware() *E2EOAuthMiddleware { 398 unsealer := NewMockSessionUnsealer() 399 store := NewMockOAuthStore() 400 m := middleware.NewOAuthAuthMiddleware(unsealer, store) 401 return &E2EOAuthMiddleware{m, unsealer, store} 402} 403 404// AddUser registers a user DID and returns the token to use in Authorization header 405func (e *E2EOAuthMiddleware) AddUser(did string) string { 406 token := "test-token-" + did 407 sessionID := "session-" + did 408 e.unsealer.AddSession(token, did, sessionID) 409 
e.store.AddSession(did, sessionID, "access-token-"+did) 410 return token 411}