A community-based topic aggregation platform built on atproto.
1package integration 2 3import ( 4 "Coves/internal/api/middleware" 5 "Coves/internal/atproto/oauth" 6 "Coves/internal/core/users" 7 "bytes" 8 "context" 9 "database/sql" 10 "encoding/json" 11 "fmt" 12 "io" 13 "net/http" 14 "os" 15 "strings" 16 "testing" 17 "time" 18 19 oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth" 20 "github.com/bluesky-social/indigo/atproto/syntax" 21) 22 23// getTestPDSURL returns the PDS URL for testing from env var or default 24func getTestPDSURL() string { 25 pdsURL := os.Getenv("PDS_URL") 26 if pdsURL == "" { 27 pdsURL = "http://localhost:3001" 28 } 29 return pdsURL 30} 31 32// getTestInstanceDID returns the instance DID for testing from env var or default 33func getTestInstanceDID() string { 34 instanceDID := os.Getenv("INSTANCE_DID") 35 if instanceDID == "" { 36 instanceDID = "did:web:test.coves.social" 37 } 38 return instanceDID 39} 40 41// createTestUser creates a test user in the database for use in integration tests 42// Returns the created user or fails the test 43func createTestUser(t *testing.T, db *sql.DB, handle, did string) *users.User { 44 t.Helper() 45 46 ctx := context.Background() 47 48 // Create user directly in DB for speed 49 query := ` 50 INSERT INTO users (did, handle, pds_url, created_at, updated_at) 51 VALUES ($1, $2, $3, NOW(), NOW()) 52 RETURNING did, handle, pds_url, created_at, updated_at 53 ` 54 55 user := &users.User{} 56 err := db.QueryRowContext(ctx, query, did, handle, getTestPDSURL()).Scan( 57 &user.DID, 58 &user.Handle, 59 &user.PDSURL, 60 &user.CreatedAt, 61 &user.UpdatedAt, 62 ) 63 if err != nil { 64 t.Fatalf("Failed to create test user: %v", err) 65 } 66 67 return user 68} 69 70// contains checks if string s contains substring substr 71// Helper for error message assertions 72func contains(s, substr string) bool { 73 return strings.Contains(s, substr) 74} 75 76// authenticateWithPDS authenticates with PDS to get access token and DID 77// Used for setting up test environments that need PDS 
credentials 78func authenticateWithPDS(pdsURL, handle, password string) (string, string, error) { 79 // Call com.atproto.server.createSession 80 sessionReq := map[string]string{ 81 "identifier": handle, 82 "password": password, 83 } 84 85 reqBody, marshalErr := json.Marshal(sessionReq) 86 if marshalErr != nil { 87 return "", "", fmt.Errorf("failed to marshal session request: %w", marshalErr) 88 } 89 resp, err := http.Post( 90 pdsURL+"/xrpc/com.atproto.server.createSession", 91 "application/json", 92 bytes.NewBuffer(reqBody), 93 ) 94 if err != nil { 95 return "", "", fmt.Errorf("failed to create session: %w", err) 96 } 97 defer func() { _ = resp.Body.Close() }() 98 99 if resp.StatusCode != http.StatusOK { 100 body, readErr := io.ReadAll(resp.Body) 101 if readErr != nil { 102 return "", "", fmt.Errorf("PDS auth failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 103 } 104 return "", "", fmt.Errorf("PDS auth failed (status %d): %s", resp.StatusCode, string(body)) 105 } 106 107 var sessionResp struct { 108 AccessJwt string `json:"accessJwt"` 109 DID string `json:"did"` 110 } 111 112 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil { 113 return "", "", fmt.Errorf("failed to decode session response: %w", err) 114 } 115 116 return sessionResp.AccessJwt, sessionResp.DID, nil 117} 118 119// generateTID generates a simple timestamp-based identifier for testing 120// In production, PDS generates proper TIDs 121func generateTID() string { 122 return fmt.Sprintf("3k%d", time.Now().UnixNano()/1000) 123} 124 125// createPDSAccount creates a new account on PDS and returns access token + DID 126// This is used for E2E tests that need real PDS accounts 127func createPDSAccount(pdsURL, handle, email, password string) (accessToken, did string, err error) { 128 // Call com.atproto.server.createAccount 129 reqBody := map[string]string{ 130 "handle": handle, 131 "email": email, 132 "password": password, 133 } 134 135 reqJSON, marshalErr := 
json.Marshal(reqBody) 136 if marshalErr != nil { 137 return "", "", fmt.Errorf("failed to marshal account request: %w", marshalErr) 138 } 139 140 resp, httpErr := http.Post( 141 pdsURL+"/xrpc/com.atproto.server.createAccount", 142 "application/json", 143 bytes.NewBuffer(reqJSON), 144 ) 145 if httpErr != nil { 146 return "", "", fmt.Errorf("failed to create account: %w", httpErr) 147 } 148 defer func() { _ = resp.Body.Close() }() 149 150 if resp.StatusCode != http.StatusOK { 151 body, readErr := io.ReadAll(resp.Body) 152 if readErr != nil { 153 return "", "", fmt.Errorf("account creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 154 } 155 return "", "", fmt.Errorf("account creation failed (status %d): %s", resp.StatusCode, string(body)) 156 } 157 158 var accountResp struct { 159 AccessJwt string `json:"accessJwt"` 160 DID string `json:"did"` 161 } 162 163 if decodeErr := json.NewDecoder(resp.Body).Decode(&accountResp); decodeErr != nil { 164 return "", "", fmt.Errorf("failed to decode account response: %w", decodeErr) 165 } 166 167 return accountResp.AccessJwt, accountResp.DID, nil 168} 169 170// writePDSRecord writes a record to PDS via com.atproto.repo.createRecord 171// Returns the AT-URI and CID of the created record 172func writePDSRecord(pdsURL, accessToken, repo, collection, rkey string, record interface{}) (uri, cid string, err error) { 173 reqBody := map[string]interface{}{ 174 "repo": repo, 175 "collection": collection, 176 "record": record, 177 } 178 179 // If rkey is provided, include it 180 if rkey != "" { 181 reqBody["rkey"] = rkey 182 } 183 184 reqJSON, marshalErr := json.Marshal(reqBody) 185 if marshalErr != nil { 186 return "", "", fmt.Errorf("failed to marshal record request: %w", marshalErr) 187 } 188 189 req, reqErr := http.NewRequest("POST", pdsURL+"/xrpc/com.atproto.repo.createRecord", bytes.NewBuffer(reqJSON)) 190 if reqErr != nil { 191 return "", "", fmt.Errorf("failed to create request: %w", reqErr) 192 } 193 
194 req.Header.Set("Content-Type", "application/json") 195 req.Header.Set("Authorization", "Bearer "+accessToken) 196 197 resp, httpErr := http.DefaultClient.Do(req) 198 if httpErr != nil { 199 return "", "", fmt.Errorf("failed to write record: %w", httpErr) 200 } 201 defer func() { _ = resp.Body.Close() }() 202 203 if resp.StatusCode != http.StatusOK { 204 body, readErr := io.ReadAll(resp.Body) 205 if readErr != nil { 206 return "", "", fmt.Errorf("record creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 207 } 208 return "", "", fmt.Errorf("record creation failed (status %d): %s", resp.StatusCode, string(body)) 209 } 210 211 var recordResp struct { 212 URI string `json:"uri"` 213 CID string `json:"cid"` 214 } 215 216 if decodeErr := json.NewDecoder(resp.Body).Decode(&recordResp); decodeErr != nil { 217 return "", "", fmt.Errorf("failed to decode record response: %w", decodeErr) 218 } 219 220 return recordResp.URI, recordResp.CID, nil 221} 222 223// createFeedTestCommunity creates a test community for feed tests 224// Returns the community DID or an error 225func createFeedTestCommunity(db *sql.DB, ctx context.Context, name, ownerHandle string) (string, error) { 226 // Get configuration from env vars 227 pdsURL := getTestPDSURL() 228 instanceDID := getTestInstanceDID() 229 230 // Create owner user first (directly insert to avoid service dependencies) 231 ownerDID := fmt.Sprintf("did:plc:%s", ownerHandle) 232 _, err := db.ExecContext(ctx, ` 233 INSERT INTO users (did, handle, pds_url, created_at) 234 VALUES ($1, $2, $3, NOW()) 235 ON CONFLICT (did) DO NOTHING 236 `, ownerDID, ownerHandle, pdsURL) 237 if err != nil { 238 return "", err 239 } 240 241 // Create community 242 communityDID := fmt.Sprintf("did:plc:community-%s", name) 243 _, err = db.ExecContext(ctx, ` 244 INSERT INTO communities (did, name, owner_did, created_by_did, hosted_by_did, handle, pds_url, created_at) 245 VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) 246 ON CONFLICT 
(did) DO NOTHING 247 `, communityDID, name, ownerDID, ownerDID, instanceDID, fmt.Sprintf("%s.coves.social", name), pdsURL) 248 249 return communityDID, err 250} 251 252// createTestPost creates a test post and returns its URI 253func createTestPost(t *testing.T, db *sql.DB, communityDID, authorDID, title string, score int, createdAt time.Time) string { 254 t.Helper() 255 256 ctx := context.Background() 257 258 // Create author user if not exists (directly insert to avoid service dependencies) 259 _, _ = db.ExecContext(ctx, ` 260 INSERT INTO users (did, handle, pds_url, created_at) 261 VALUES ($1, $2, $3, NOW()) 262 ON CONFLICT (did) DO NOTHING 263 `, authorDID, fmt.Sprintf("%s.bsky.social", authorDID), getTestPDSURL()) 264 265 // Generate URI 266 rkey := fmt.Sprintf("post-%d", time.Now().UnixNano()) 267 uri := fmt.Sprintf("at://%s/social.coves.community.post/%s", communityDID, rkey) 268 269 // Insert post 270 _, err := db.ExecContext(ctx, ` 271 INSERT INTO posts (uri, cid, rkey, author_did, community_did, title, created_at, score, upvote_count) 272 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) 273 `, uri, "bafytest", rkey, authorDID, communityDID, title, createdAt, score, score) 274 if err != nil { 275 t.Fatalf("Failed to create test post: %v", err) 276 } 277 278 return uri 279} 280 281// MockSessionUnsealer is a mock implementation of SessionUnsealer for testing 282// It returns predefined sessions based on token value 283type MockSessionUnsealer struct { 284 sessions map[string]*oauth.SealedSession 285} 286 287// NewMockSessionUnsealer creates a new mock unsealer 288func NewMockSessionUnsealer() *MockSessionUnsealer { 289 return &MockSessionUnsealer{ 290 sessions: make(map[string]*oauth.SealedSession), 291 } 292} 293 294// AddSession adds a token -> session mapping 295func (m *MockSessionUnsealer) AddSession(token, did, sessionID string) { 296 m.sessions[token] = &oauth.SealedSession{ 297 DID: did, 298 SessionID: sessionID, 299 ExpiresAt: time.Now().Add(1 * 
time.Hour).Unix(), 300 } 301} 302 303// UnsealSession returns the predefined session for a token 304func (m *MockSessionUnsealer) UnsealSession(token string) (*oauth.SealedSession, error) { 305 if sess, ok := m.sessions[token]; ok { 306 return sess, nil 307 } 308 return nil, fmt.Errorf("unknown token") 309} 310 311// MockOAuthStore is a mock implementation of ClientAuthStore for testing 312type MockOAuthStore struct { 313 sessions map[string]*oauthlib.ClientSessionData 314} 315 316// NewMockOAuthStore creates a new mock OAuth store 317func NewMockOAuthStore() *MockOAuthStore { 318 return &MockOAuthStore{ 319 sessions: make(map[string]*oauthlib.ClientSessionData), 320 } 321} 322 323// AddSession adds a session to the store 324func (m *MockOAuthStore) AddSession(did, sessionID, accessToken string) { 325 m.AddSessionWithPDS(did, sessionID, accessToken, getTestPDSURL()) 326} 327 328// AddSessionWithPDS adds a session to the store with a specific PDS URL 329func (m *MockOAuthStore) AddSessionWithPDS(did, sessionID, accessToken, pdsURL string) { 330 key := did + ":" + sessionID 331 parsedDID, _ := syntax.ParseDID(did) 332 m.sessions[key] = &oauthlib.ClientSessionData{ 333 AccountDID: parsedDID, 334 SessionID: sessionID, 335 AccessToken: accessToken, 336 HostURL: pdsURL, 337 } 338} 339 340// GetSession implements ClientAuthStore 341func (m *MockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauthlib.ClientSessionData, error) { 342 key := did.String() + ":" + sessionID 343 if sess, ok := m.sessions[key]; ok { 344 return sess, nil 345 } 346 return nil, fmt.Errorf("session not found") 347} 348 349// SaveSession implements ClientAuthStore 350func (m *MockOAuthStore) SaveSession(ctx context.Context, sess oauthlib.ClientSessionData) error { 351 key := sess.AccountDID.String() + ":" + sess.SessionID 352 m.sessions[key] = &sess 353 return nil 354} 355 356// DeleteSession implements ClientAuthStore 357func (m *MockOAuthStore) DeleteSession(ctx 
context.Context, did syntax.DID, sessionID string) error { 358 key := did.String() + ":" + sessionID 359 delete(m.sessions, key) 360 return nil 361} 362 363// GetAuthRequestInfo implements ClientAuthStore 364func (m *MockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauthlib.AuthRequestData, error) { 365 return nil, fmt.Errorf("not implemented in mock") 366} 367 368// SaveAuthRequestInfo implements ClientAuthStore 369func (m *MockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauthlib.AuthRequestData) error { 370 return nil 371} 372 373// DeleteAuthRequestInfo implements ClientAuthStore 374func (m *MockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 375 return nil 376} 377 378// CreateTestOAuthMiddleware creates an OAuth middleware with mock implementations for testing 379// The returned middleware accepts a test token that maps to the specified userDID 380func CreateTestOAuthMiddleware(userDID string) (*middleware.OAuthAuthMiddleware, string) { 381 unsealer := NewMockSessionUnsealer() 382 store := NewMockOAuthStore() 383 384 testToken := "test-token-" + userDID 385 sessionID := "test-session-123" 386 387 // Add the test session 388 unsealer.AddSession(testToken, userDID, sessionID) 389 store.AddSession(userDID, sessionID, "test-access-token") 390 391 authMiddleware := middleware.NewOAuthAuthMiddleware(unsealer, store) 392 return authMiddleware, testToken 393} 394 395// E2EOAuthMiddleware wraps OAuth middleware for E2E testing with multiple users 396type E2EOAuthMiddleware struct { 397 *middleware.OAuthAuthMiddleware 398 unsealer *MockSessionUnsealer 399 store *MockOAuthStore 400} 401 402// NewE2EOAuthMiddleware creates an OAuth middleware for E2E testing 403func NewE2EOAuthMiddleware() *E2EOAuthMiddleware { 404 unsealer := NewMockSessionUnsealer() 405 store := NewMockOAuthStore() 406 m := middleware.NewOAuthAuthMiddleware(unsealer, store) 407 return &E2EOAuthMiddleware{m, unsealer, store} 408} 409 410// 
AddUser registers a user DID and returns the token to use in Authorization header 411func (e *E2EOAuthMiddleware) AddUser(did string) string { 412 token := "test-token-" + did 413 sessionID := "session-" + did 414 e.unsealer.AddSession(token, did, sessionID) 415 e.store.AddSession(did, sessionID, "access-token-"+did) 416 return token 417} 418 419// AddUserWithPDSToken registers a user with their real PDS access token 420// Use this for E2E tests that need to write to the real PDS 421func (e *E2EOAuthMiddleware) AddUserWithPDSToken(did, pdsAccessToken, pdsURL string) string { 422 token := "test-token-" + did 423 sessionID := "session-" + did 424 e.unsealer.AddSession(token, did, sessionID) 425 e.store.AddSessionWithPDS(did, sessionID, pdsAccessToken, pdsURL) 426 return token 427}