// A community-based topic aggregation platform built on atproto.
1package integration 2 3import ( 4 "Coves/internal/api/middleware" 5 "Coves/internal/atproto/oauth" 6 "Coves/internal/atproto/pds" 7 "Coves/internal/core/users" 8 "Coves/internal/core/votes" 9 "bytes" 10 "context" 11 "database/sql" 12 "encoding/json" 13 "fmt" 14 "io" 15 "net/http" 16 "os" 17 "strings" 18 "testing" 19 "time" 20 21 oauthlib "github.com/bluesky-social/indigo/atproto/auth/oauth" 22 "github.com/bluesky-social/indigo/atproto/syntax" 23) 24 25// getTestPDSURL returns the PDS URL for testing from env var or default 26func getTestPDSURL() string { 27 pdsURL := os.Getenv("PDS_URL") 28 if pdsURL == "" { 29 pdsURL = "http://localhost:3001" 30 } 31 return pdsURL 32} 33 34// getTestInstanceDID returns the instance DID for testing from env var or default 35func getTestInstanceDID() string { 36 instanceDID := os.Getenv("INSTANCE_DID") 37 if instanceDID == "" { 38 instanceDID = "did:web:test.coves.social" 39 } 40 return instanceDID 41} 42 43// createTestUser creates a test user in the database for use in integration tests 44// Returns the created user or fails the test 45func createTestUser(t *testing.T, db *sql.DB, handle, did string) *users.User { 46 t.Helper() 47 48 ctx := context.Background() 49 50 // Create user directly in DB for speed 51 query := ` 52 INSERT INTO users (did, handle, pds_url, created_at, updated_at) 53 VALUES ($1, $2, $3, NOW(), NOW()) 54 RETURNING did, handle, pds_url, created_at, updated_at 55 ` 56 57 user := &users.User{} 58 err := db.QueryRowContext(ctx, query, did, handle, getTestPDSURL()).Scan( 59 &user.DID, 60 &user.Handle, 61 &user.PDSURL, 62 &user.CreatedAt, 63 &user.UpdatedAt, 64 ) 65 if err != nil { 66 t.Fatalf("Failed to create test user: %v", err) 67 } 68 69 return user 70} 71 72// contains checks if string s contains substring substr 73// Helper for error message assertions 74func contains(s, substr string) bool { 75 return strings.Contains(s, substr) 76} 77 78// authenticateWithPDS authenticates with PDS to get access token and 
DID 79// Used for setting up test environments that need PDS credentials 80func authenticateWithPDS(pdsURL, handle, password string) (string, string, error) { 81 // Call com.atproto.server.createSession 82 sessionReq := map[string]string{ 83 "identifier": handle, 84 "password": password, 85 } 86 87 reqBody, marshalErr := json.Marshal(sessionReq) 88 if marshalErr != nil { 89 return "", "", fmt.Errorf("failed to marshal session request: %w", marshalErr) 90 } 91 resp, err := http.Post( 92 pdsURL+"/xrpc/com.atproto.server.createSession", 93 "application/json", 94 bytes.NewBuffer(reqBody), 95 ) 96 if err != nil { 97 return "", "", fmt.Errorf("failed to create session: %w", err) 98 } 99 defer func() { _ = resp.Body.Close() }() 100 101 if resp.StatusCode != http.StatusOK { 102 body, readErr := io.ReadAll(resp.Body) 103 if readErr != nil { 104 return "", "", fmt.Errorf("PDS auth failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 105 } 106 return "", "", fmt.Errorf("PDS auth failed (status %d): %s", resp.StatusCode, string(body)) 107 } 108 109 var sessionResp struct { 110 AccessJwt string `json:"accessJwt"` 111 DID string `json:"did"` 112 } 113 114 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil { 115 return "", "", fmt.Errorf("failed to decode session response: %w", err) 116 } 117 118 return sessionResp.AccessJwt, sessionResp.DID, nil 119} 120 121// generateTID generates a simple timestamp-based identifier for testing 122// In production, PDS generates proper TIDs 123func generateTID() string { 124 return fmt.Sprintf("3k%d", time.Now().UnixNano()/1000) 125} 126 127// createPDSAccount creates a new account on PDS and returns access token + DID 128// This is used for E2E tests that need real PDS accounts 129func createPDSAccount(pdsURL, handle, email, password string) (accessToken, did string, err error) { 130 // Call com.atproto.server.createAccount 131 reqBody := map[string]string{ 132 "handle": handle, 133 "email": email, 134 
"password": password, 135 } 136 137 reqJSON, marshalErr := json.Marshal(reqBody) 138 if marshalErr != nil { 139 return "", "", fmt.Errorf("failed to marshal account request: %w", marshalErr) 140 } 141 142 resp, httpErr := http.Post( 143 pdsURL+"/xrpc/com.atproto.server.createAccount", 144 "application/json", 145 bytes.NewBuffer(reqJSON), 146 ) 147 if httpErr != nil { 148 return "", "", fmt.Errorf("failed to create account: %w", httpErr) 149 } 150 defer func() { _ = resp.Body.Close() }() 151 152 if resp.StatusCode != http.StatusOK { 153 body, readErr := io.ReadAll(resp.Body) 154 if readErr != nil { 155 return "", "", fmt.Errorf("account creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 156 } 157 return "", "", fmt.Errorf("account creation failed (status %d): %s", resp.StatusCode, string(body)) 158 } 159 160 var accountResp struct { 161 AccessJwt string `json:"accessJwt"` 162 DID string `json:"did"` 163 } 164 165 if decodeErr := json.NewDecoder(resp.Body).Decode(&accountResp); decodeErr != nil { 166 return "", "", fmt.Errorf("failed to decode account response: %w", decodeErr) 167 } 168 169 return accountResp.AccessJwt, accountResp.DID, nil 170} 171 172// writePDSRecord writes a record to PDS via com.atproto.repo.createRecord 173// Returns the AT-URI and CID of the created record 174func writePDSRecord(pdsURL, accessToken, repo, collection, rkey string, record interface{}) (uri, cid string, err error) { 175 reqBody := map[string]interface{}{ 176 "repo": repo, 177 "collection": collection, 178 "record": record, 179 } 180 181 // If rkey is provided, include it 182 if rkey != "" { 183 reqBody["rkey"] = rkey 184 } 185 186 reqJSON, marshalErr := json.Marshal(reqBody) 187 if marshalErr != nil { 188 return "", "", fmt.Errorf("failed to marshal record request: %w", marshalErr) 189 } 190 191 req, reqErr := http.NewRequest("POST", pdsURL+"/xrpc/com.atproto.repo.createRecord", bytes.NewBuffer(reqJSON)) 192 if reqErr != nil { 193 return "", "", 
fmt.Errorf("failed to create request: %w", reqErr) 194 } 195 196 req.Header.Set("Content-Type", "application/json") 197 req.Header.Set("Authorization", "Bearer "+accessToken) 198 199 resp, httpErr := http.DefaultClient.Do(req) 200 if httpErr != nil { 201 return "", "", fmt.Errorf("failed to write record: %w", httpErr) 202 } 203 defer func() { _ = resp.Body.Close() }() 204 205 if resp.StatusCode != http.StatusOK { 206 body, readErr := io.ReadAll(resp.Body) 207 if readErr != nil { 208 return "", "", fmt.Errorf("record creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr) 209 } 210 return "", "", fmt.Errorf("record creation failed (status %d): %s", resp.StatusCode, string(body)) 211 } 212 213 var recordResp struct { 214 URI string `json:"uri"` 215 CID string `json:"cid"` 216 } 217 218 if decodeErr := json.NewDecoder(resp.Body).Decode(&recordResp); decodeErr != nil { 219 return "", "", fmt.Errorf("failed to decode record response: %w", decodeErr) 220 } 221 222 return recordResp.URI, recordResp.CID, nil 223} 224 225// createFeedTestCommunity creates a test community for feed tests 226// Returns the community DID or an error 227func createFeedTestCommunity(db *sql.DB, ctx context.Context, name, ownerHandle string) (string, error) { 228 // Get configuration from env vars 229 pdsURL := getTestPDSURL() 230 instanceDID := getTestInstanceDID() 231 232 // Create owner user first (directly insert to avoid service dependencies) 233 ownerDID := fmt.Sprintf("did:plc:%s", ownerHandle) 234 _, err := db.ExecContext(ctx, ` 235 INSERT INTO users (did, handle, pds_url, created_at) 236 VALUES ($1, $2, $3, NOW()) 237 ON CONFLICT (did) DO NOTHING 238 `, ownerDID, ownerHandle, pdsURL) 239 if err != nil { 240 return "", err 241 } 242 243 // Create community 244 communityDID := fmt.Sprintf("did:plc:community-%s", name) 245 _, err = db.ExecContext(ctx, ` 246 INSERT INTO communities (did, name, owner_did, created_by_did, hosted_by_did, handle, pds_url, created_at) 247 
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) 248 ON CONFLICT (did) DO NOTHING 249 `, communityDID, name, ownerDID, ownerDID, instanceDID, fmt.Sprintf("%s.coves.social", name), pdsURL) 250 251 return communityDID, err 252} 253 254// createTestPost creates a test post and returns its URI 255func createTestPost(t *testing.T, db *sql.DB, communityDID, authorDID, title string, score int, createdAt time.Time) string { 256 t.Helper() 257 258 ctx := context.Background() 259 260 // Create author user if not exists (directly insert to avoid service dependencies) 261 _, _ = db.ExecContext(ctx, ` 262 INSERT INTO users (did, handle, pds_url, created_at) 263 VALUES ($1, $2, $3, NOW()) 264 ON CONFLICT (did) DO NOTHING 265 `, authorDID, fmt.Sprintf("%s.bsky.social", authorDID), getTestPDSURL()) 266 267 // Generate URI 268 rkey := fmt.Sprintf("post-%d", time.Now().UnixNano()) 269 uri := fmt.Sprintf("at://%s/social.coves.community.post/%s", communityDID, rkey) 270 271 // Insert post 272 _, err := db.ExecContext(ctx, ` 273 INSERT INTO posts (uri, cid, rkey, author_did, community_did, title, created_at, score, upvote_count) 274 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) 275 `, uri, "bafytest", rkey, authorDID, communityDID, title, createdAt, score, score) 276 if err != nil { 277 t.Fatalf("Failed to create test post: %v", err) 278 } 279 280 return uri 281} 282 283// MockSessionUnsealer is a mock implementation of SessionUnsealer for testing 284// It returns predefined sessions based on token value 285type MockSessionUnsealer struct { 286 sessions map[string]*oauth.SealedSession 287} 288 289// NewMockSessionUnsealer creates a new mock unsealer 290func NewMockSessionUnsealer() *MockSessionUnsealer { 291 return &MockSessionUnsealer{ 292 sessions: make(map[string]*oauth.SealedSession), 293 } 294} 295 296// AddSession adds a token -> session mapping 297func (m *MockSessionUnsealer) AddSession(token, did, sessionID string) { 298 m.sessions[token] = &oauth.SealedSession{ 299 DID: did, 300 
SessionID: sessionID, 301 ExpiresAt: time.Now().Add(1 * time.Hour).Unix(), 302 } 303} 304 305// UnsealSession returns the predefined session for a token 306func (m *MockSessionUnsealer) UnsealSession(token string) (*oauth.SealedSession, error) { 307 if sess, ok := m.sessions[token]; ok { 308 return sess, nil 309 } 310 return nil, fmt.Errorf("unknown token") 311} 312 313// MockOAuthStore is a mock implementation of ClientAuthStore for testing 314type MockOAuthStore struct { 315 sessions map[string]*oauthlib.ClientSessionData 316} 317 318// NewMockOAuthStore creates a new mock OAuth store 319func NewMockOAuthStore() *MockOAuthStore { 320 return &MockOAuthStore{ 321 sessions: make(map[string]*oauthlib.ClientSessionData), 322 } 323} 324 325// AddSession adds a session to the store 326func (m *MockOAuthStore) AddSession(did, sessionID, accessToken string) { 327 m.AddSessionWithPDS(did, sessionID, accessToken, getTestPDSURL()) 328} 329 330// AddSessionWithPDS adds a session to the store with a specific PDS URL 331func (m *MockOAuthStore) AddSessionWithPDS(did, sessionID, accessToken, pdsURL string) { 332 key := did + ":" + sessionID 333 parsedDID, _ := syntax.ParseDID(did) 334 m.sessions[key] = &oauthlib.ClientSessionData{ 335 AccountDID: parsedDID, 336 SessionID: sessionID, 337 AccessToken: accessToken, 338 HostURL: pdsURL, 339 } 340} 341 342// GetSession implements ClientAuthStore 343func (m *MockOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauthlib.ClientSessionData, error) { 344 key := did.String() + ":" + sessionID 345 if sess, ok := m.sessions[key]; ok { 346 return sess, nil 347 } 348 return nil, fmt.Errorf("session not found") 349} 350 351// SaveSession implements ClientAuthStore 352func (m *MockOAuthStore) SaveSession(ctx context.Context, sess oauthlib.ClientSessionData) error { 353 key := sess.AccountDID.String() + ":" + sess.SessionID 354 m.sessions[key] = &sess 355 return nil 356} 357 358// DeleteSession implements 
ClientAuthStore 359func (m *MockOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 360 key := did.String() + ":" + sessionID 361 delete(m.sessions, key) 362 return nil 363} 364 365// GetAuthRequestInfo implements ClientAuthStore 366func (m *MockOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauthlib.AuthRequestData, error) { 367 return nil, fmt.Errorf("not implemented in mock") 368} 369 370// SaveAuthRequestInfo implements ClientAuthStore 371func (m *MockOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauthlib.AuthRequestData) error { 372 return nil 373} 374 375// DeleteAuthRequestInfo implements ClientAuthStore 376func (m *MockOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 377 return nil 378} 379 380// CreateTestOAuthMiddleware creates an OAuth middleware with mock implementations for testing 381// The returned middleware accepts a test token that maps to the specified userDID 382func CreateTestOAuthMiddleware(userDID string) (*middleware.OAuthAuthMiddleware, string) { 383 unsealer := NewMockSessionUnsealer() 384 store := NewMockOAuthStore() 385 386 testToken := "test-token-" + userDID 387 sessionID := "test-session-123" 388 389 // Add the test session 390 unsealer.AddSession(testToken, userDID, sessionID) 391 store.AddSession(userDID, sessionID, "test-access-token") 392 393 authMiddleware := middleware.NewOAuthAuthMiddleware(unsealer, store) 394 return authMiddleware, testToken 395} 396 397// E2EOAuthMiddleware wraps OAuth middleware for E2E testing with multiple users 398type E2EOAuthMiddleware struct { 399 *middleware.OAuthAuthMiddleware 400 unsealer *MockSessionUnsealer 401 store *MockOAuthStore 402} 403 404// NewE2EOAuthMiddleware creates an OAuth middleware for E2E testing 405func NewE2EOAuthMiddleware() *E2EOAuthMiddleware { 406 unsealer := NewMockSessionUnsealer() 407 store := NewMockOAuthStore() 408 m := middleware.NewOAuthAuthMiddleware(unsealer, store) 409 
return &E2EOAuthMiddleware{m, unsealer, store} 410} 411 412// AddUser registers a user DID and returns the token to use in Authorization header 413func (e *E2EOAuthMiddleware) AddUser(did string) string { 414 token := "test-token-" + did 415 sessionID := "session-" + did 416 e.unsealer.AddSession(token, did, sessionID) 417 e.store.AddSession(did, sessionID, "access-token-"+did) 418 return token 419} 420 421// AddUserWithPDSToken registers a user with their real PDS access token 422// Use this for E2E tests that need to write to the real PDS 423func (e *E2EOAuthMiddleware) AddUserWithPDSToken(did, pdsAccessToken, pdsURL string) string { 424 token := "test-token-" + did 425 sessionID := "session-" + did 426 e.unsealer.AddSession(token, did, sessionID) 427 e.store.AddSessionWithPDS(did, sessionID, pdsAccessToken, pdsURL) 428 return token 429} 430 431// PasswordAuthPDSClientFactory creates a PDSClientFactory that uses password-based Bearer auth. 432// This is for E2E tests that use createSession instead of OAuth. 433// The factory extracts the access token and host URL from the session data. 434func PasswordAuthPDSClientFactory() votes.PDSClientFactory { 435 return func(ctx context.Context, session *oauthlib.ClientSessionData) (pds.Client, error) { 436 if session.AccessToken == "" { 437 return nil, fmt.Errorf("session has no access token") 438 } 439 if session.HostURL == "" { 440 return nil, fmt.Errorf("session has no host URL") 441 } 442 443 return pds.NewFromAccessToken(session.HostURL, session.AccountDID.String(), session.AccessToken) 444 } 445}