A community-based topic aggregation platform built on atproto
1package integration
2
3import (
4 "Coves/internal/atproto/auth"
5 "Coves/internal/core/users"
6 "bytes"
7 "context"
8 "database/sql"
9 "encoding/base64"
10 "encoding/json"
11 "fmt"
12 "io"
13 "net/http"
14 "strings"
15 "testing"
16 "time"
17
18 "github.com/golang-jwt/jwt/v5"
19)
20
21// createTestUser creates a test user in the database for use in integration tests
22// Returns the created user or fails the test
23func createTestUser(t *testing.T, db *sql.DB, handle, did string) *users.User {
24 t.Helper()
25
26 ctx := context.Background()
27
28 // Create user directly in DB for speed
29 query := `
30 INSERT INTO users (did, handle, pds_url, created_at, updated_at)
31 VALUES ($1, $2, $3, NOW(), NOW())
32 RETURNING did, handle, pds_url, created_at, updated_at
33 `
34
35 user := &users.User{}
36 err := db.QueryRowContext(ctx, query, did, handle, "http://localhost:3001").Scan(
37 &user.DID,
38 &user.Handle,
39 &user.PDSURL,
40 &user.CreatedAt,
41 &user.UpdatedAt,
42 )
43 if err != nil {
44 t.Fatalf("Failed to create test user: %v", err)
45 }
46
47 return user
48}
49
// contains reports whether substr occurs anywhere within s.
// The match is case-sensitive, and an empty substr always matches.
// It is a small helper for asserting on error message text.
func contains(s, substr string) bool {
	found := strings.Contains(s, substr)
	return found
}
55
// authenticateWithPDS logs in to a PDS by POSTing the credentials to
// com.atproto.server.createSession and returns the access token and DID from
// the response, in that order.
//
// pdsURL is the base URL of the PDS; trailing slashes are stripped so that
// "http://host" and "http://host/" address the same endpoint. handle is sent
// as the "identifier" field of the request.
//
// The exchange is bounded by a 30-second deadline so that an unresponsive PDS
// fails the call instead of blocking the test run indefinitely.
//
// A non-nil error is returned when the request cannot be built or sent, when
// the PDS answers with any status other than 200 (the response body is
// included in the message), or when the response body is not valid JSON.
// Both strings are empty whenever the error is non-nil.
func authenticateWithPDS(pdsURL, handle, password string) (string, string, error) {
	const requestTimeout = 30 * time.Second

	reqBody, marshalErr := json.Marshal(map[string]string{
		"identifier": handle,
		"password":   password,
	})
	if marshalErr != nil {
		return "", "", fmt.Errorf("failed to marshal session request: %w", marshalErr)
	}

	// The deadline covers connecting, sending, and reading the response body;
	// cancel is deferred so it only fires after the body has been consumed.
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	// Normalize the base URL so a trailing slash does not produce "//xrpc/...".
	endpoint := strings.TrimRight(pdsURL, "/") + "/xrpc/com.atproto.server.createSession"

	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
	if reqErr != nil {
		return "", "", fmt.Errorf("failed to create session: %w", reqErr)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, httpErr := http.DefaultClient.Do(req)
	if httpErr != nil {
		return "", "", fmt.Errorf("failed to create session: %w", httpErr)
	}
	// Close errors are deliberately ignored: the body has already been read
	// (or abandoned) by the time this runs and nothing can be done about them.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return "", "", fmt.Errorf("PDS auth failed (status %d, failed to read body: %w)", resp.StatusCode, readErr)
		}
		return "", "", fmt.Errorf("PDS auth failed (status %d): %s", resp.StatusCode, string(body))
	}

	var sessionResp struct {
		AccessJwt string `json:"accessJwt"`
		DID       string `json:"did"`
	}

	if decodeErr := json.NewDecoder(resp.Body).Decode(&sessionResp); decodeErr != nil {
		return "", "", fmt.Errorf("failed to decode session response: %w", decodeErr)
	}

	return sessionResp.AccessJwt, sessionResp.DID, nil
}
98
99// createSimpleTestJWT creates a minimal JWT for testing (Phase 1 - no signature)
100// In production, this would be a real OAuth token from PDS with proper signatures
101func createSimpleTestJWT(userDID string) string {
102 // Create minimal JWT claims using RegisteredClaims
103 // Use userDID as issuer since we don't have a proper PDS DID for testing
104 claims := auth.Claims{
105 RegisteredClaims: jwt.RegisteredClaims{
106 Subject: userDID,
107 Issuer: userDID, // Use DID as issuer for testing (valid per atProto)
108 Audience: jwt.ClaimStrings{"did:web:test.coves.social"},
109 IssuedAt: jwt.NewNumericDate(time.Now()),
110 ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
111 },
112 Scope: "com.atproto.access",
113 }
114
115 // For Phase 1 testing, we create an unsigned JWT
116 // The middleware is configured with skipVerify=true for testing
117 header := map[string]interface{}{
118 "alg": "none",
119 "typ": "JWT",
120 }
121
122 headerJSON, _ := json.Marshal(header)
123 claimsJSON, _ := json.Marshal(claims)
124
125 // Base64url encode (without padding)
126 headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
127 claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
128
129 // For "alg: none", signature is empty
130 return headerB64 + "." + claimsB64 + "."
131}
132
// generateTID returns a timestamp-derived identifier for use in tests: the
// literal prefix "3k" followed by the current Unix time in microseconds,
// written in decimal.
//
// It is a stand-in only; per the original note, real TIDs are generated by
// the PDS in production.
func generateTID() string {
	micros := time.Now().UnixNano() / int64(time.Microsecond)
	return fmt.Sprintf("3k%d", micros)
}
138
// createPDSAccount registers a new account on a PDS by POSTing handle, email
// and password to com.atproto.server.createAccount, and returns the access
// token and DID from the response. It is meant for E2E tests that need a real
// PDS account.
//
// pdsURL is the base URL of the PDS; trailing slashes are stripped so that
// "http://host" and "http://host/" address the same endpoint.
//
// The exchange is bounded by a 30-second deadline so that an unresponsive PDS
// fails the call instead of blocking the test run indefinitely.
//
// A non-nil error is returned when the request cannot be built or sent, when
// the PDS answers with any status other than 200 (the response body is
// included in the message), or when the response body is not valid JSON.
// accessToken and did are empty whenever err is non-nil.
func createPDSAccount(pdsURL, handle, email, password string) (accessToken, did string, err error) {
	const requestTimeout = 30 * time.Second

	reqJSON, marshalErr := json.Marshal(map[string]string{
		"handle":   handle,
		"email":    email,
		"password": password,
	})
	if marshalErr != nil {
		return "", "", fmt.Errorf("failed to marshal account request: %w", marshalErr)
	}

	// The deadline covers connecting, sending, and reading the response body;
	// cancel is deferred so it only fires after the body has been consumed.
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	// Normalize the base URL so a trailing slash does not produce "//xrpc/...".
	endpoint := strings.TrimRight(pdsURL, "/") + "/xrpc/com.atproto.server.createAccount"

	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqJSON))
	if reqErr != nil {
		return "", "", fmt.Errorf("failed to create account: %w", reqErr)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, httpErr := http.DefaultClient.Do(req)
	if httpErr != nil {
		return "", "", fmt.Errorf("failed to create account: %w", httpErr)
	}
	// Close errors are deliberately ignored: the body has already been read
	// (or abandoned) by the time this runs and nothing can be done about them.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return "", "", fmt.Errorf("account creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr)
		}
		return "", "", fmt.Errorf("account creation failed (status %d): %s", resp.StatusCode, string(body))
	}

	var accountResp struct {
		AccessJwt string `json:"accessJwt"`
		DID       string `json:"did"`
	}

	if decodeErr := json.NewDecoder(resp.Body).Decode(&accountResp); decodeErr != nil {
		return "", "", fmt.Errorf("failed to decode account response: %w", decodeErr)
	}

	return accountResp.AccessJwt, accountResp.DID, nil
}
183
// writePDSRecord creates a record on a PDS via com.atproto.repo.createRecord
// and returns the AT-URI and CID reported in the response.
//
// pdsURL is the base URL of the PDS; trailing slashes are stripped so that
// "http://host" and "http://host/" address the same endpoint. accessToken is
// sent as a Bearer token in the Authorization header. repo, collection and
// record are always included in the request body; rkey is included only when
// it is non-empty. record must be JSON-serializable.
//
// The exchange is bounded by a 30-second deadline so that an unresponsive PDS
// fails the call instead of blocking the test run indefinitely.
//
// A non-nil error is returned when the request body cannot be marshaled, when
// the request cannot be built or sent, when the PDS answers with any status
// other than 200 (the response body is included in the message), or when the
// response body is not valid JSON. uri and cid are empty whenever err is
// non-nil.
func writePDSRecord(pdsURL, accessToken, repo, collection, rkey string, record interface{}) (uri, cid string, err error) {
	const requestTimeout = 30 * time.Second

	reqBody := map[string]interface{}{
		"repo":       repo,
		"collection": collection,
		"record":     record,
	}

	// Only send rkey when the caller supplied one.
	if rkey != "" {
		reqBody["rkey"] = rkey
	}

	reqJSON, marshalErr := json.Marshal(reqBody)
	if marshalErr != nil {
		return "", "", fmt.Errorf("failed to marshal record request: %w", marshalErr)
	}

	// The deadline covers connecting, sending, and reading the response body;
	// cancel is deferred so it only fires after the body has been consumed.
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	// Normalize the base URL so a trailing slash does not produce "//xrpc/...".
	endpoint := strings.TrimRight(pdsURL, "/") + "/xrpc/com.atproto.repo.createRecord"

	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqJSON))
	if reqErr != nil {
		return "", "", fmt.Errorf("failed to create request: %w", reqErr)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, httpErr := http.DefaultClient.Do(req)
	if httpErr != nil {
		return "", "", fmt.Errorf("failed to write record: %w", httpErr)
	}
	// Close errors are deliberately ignored: the body has already been read
	// (or abandoned) by the time this runs and nothing can be done about them.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return "", "", fmt.Errorf("record creation failed (status %d, failed to read body: %w)", resp.StatusCode, readErr)
		}
		return "", "", fmt.Errorf("record creation failed (status %d): %s", resp.StatusCode, string(body))
	}

	var recordResp struct {
		URI string `json:"uri"`
		CID string `json:"cid"`
	}

	if decodeErr := json.NewDecoder(resp.Body).Decode(&recordResp); decodeErr != nil {
		return "", "", fmt.Errorf("failed to decode record response: %w", decodeErr)
	}

	return recordResp.URI, recordResp.CID, nil
}