A community-based topic aggregation platform built on atproto.
1package oauth
2
3import (
4 "context"
5 "database/sql"
6 "os"
7 "testing"
8
9 "github.com/bluesky-social/indigo/atproto/auth/oauth"
10 "github.com/bluesky-social/indigo/atproto/syntax"
11 _ "github.com/lib/pq"
12 "github.com/pressly/goose/v3"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
17// setupTestDB creates a test database connection and runs migrations
18func setupTestDB(t *testing.T) *sql.DB {
19 dsn := os.Getenv("TEST_DATABASE_URL")
20 if dsn == "" {
21 dsn = "postgres://test_user:test_password@localhost:5434/coves_test?sslmode=disable"
22 }
23
24 db, err := sql.Open("postgres", dsn)
25 require.NoError(t, err, "Failed to connect to test database")
26
27 // Run migrations
28 require.NoError(t, goose.Up(db, "../../db/migrations"), "Failed to run migrations")
29
30 return db
31}
32
33// cleanupOAuth removes all test OAuth data from the database
34func cleanupOAuth(t *testing.T, db *sql.DB) {
35 _, err := db.Exec("DELETE FROM oauth_sessions WHERE did LIKE 'did:plc:test%'")
36 require.NoError(t, err, "Failed to cleanup oauth_sessions")
37
38 _, err = db.Exec("DELETE FROM oauth_requests WHERE state LIKE 'test%'")
39 require.NoError(t, err, "Failed to cleanup oauth_requests")
40}
41
42func TestPostgresOAuthStore_SaveAndGetSession(t *testing.T) {
43 db := setupTestDB(t)
44 defer func() { _ = db.Close() }()
45 defer cleanupOAuth(t, db)
46
47 store := NewPostgresOAuthStore(db, 0) // Use default TTL
48 ctx := context.Background()
49
50 did, err := syntax.ParseDID("did:plc:test123abc")
51 require.NoError(t, err)
52
53 session := oauth.ClientSessionData{
54 AccountDID: did,
55 SessionID: "session123",
56 HostURL: "https://pds.example.com",
57 AuthServerURL: "https://auth.example.com",
58 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
59 AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
60 Scopes: []string{"atproto"},
61 AccessToken: "at_test_token_abc123",
62 RefreshToken: "rt_test_token_xyz789",
63 DPoPAuthServerNonce: "nonce_auth_123",
64 DPoPHostNonce: "nonce_host_456",
65 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
66 }
67
68 // Save session
69 err = store.SaveSession(ctx, session)
70 assert.NoError(t, err)
71
72 // Retrieve session
73 retrieved, err := store.GetSession(ctx, did, "session123")
74 assert.NoError(t, err)
75 assert.NotNil(t, retrieved)
76 assert.Equal(t, session.AccountDID.String(), retrieved.AccountDID.String())
77 assert.Equal(t, session.SessionID, retrieved.SessionID)
78 assert.Equal(t, session.HostURL, retrieved.HostURL)
79 assert.Equal(t, session.AuthServerURL, retrieved.AuthServerURL)
80 assert.Equal(t, session.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
81 assert.Equal(t, session.AccessToken, retrieved.AccessToken)
82 assert.Equal(t, session.RefreshToken, retrieved.RefreshToken)
83 assert.Equal(t, session.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
84 assert.Equal(t, session.DPoPHostNonce, retrieved.DPoPHostNonce)
85 assert.Equal(t, session.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
86 assert.Equal(t, session.Scopes, retrieved.Scopes)
87}
88
89func TestPostgresOAuthStore_SaveSession_Upsert(t *testing.T) {
90 db := setupTestDB(t)
91 defer func() { _ = db.Close() }()
92 defer cleanupOAuth(t, db)
93
94 store := NewPostgresOAuthStore(db, 0) // Use default TTL
95 ctx := context.Background()
96
97 did, err := syntax.ParseDID("did:plc:testupsert")
98 require.NoError(t, err)
99
100 // Initial session
101 session1 := oauth.ClientSessionData{
102 AccountDID: did,
103 SessionID: "session_upsert",
104 HostURL: "https://pds1.example.com",
105 AuthServerURL: "https://auth1.example.com",
106 AuthServerTokenEndpoint: "https://auth1.example.com/oauth/token",
107 Scopes: []string{"atproto"},
108 AccessToken: "old_access_token",
109 RefreshToken: "old_refresh_token",
110 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
111 }
112
113 err = store.SaveSession(ctx, session1)
114 require.NoError(t, err)
115
116 // Updated session (same DID and session ID)
117 session2 := oauth.ClientSessionData{
118 AccountDID: did,
119 SessionID: "session_upsert",
120 HostURL: "https://pds2.example.com",
121 AuthServerURL: "https://auth2.example.com",
122 AuthServerTokenEndpoint: "https://auth2.example.com/oauth/token",
123 Scopes: []string{"atproto", "transition:generic"},
124 AccessToken: "new_access_token",
125 RefreshToken: "new_refresh_token",
126 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
127 }
128
129 // Save again - should update
130 err = store.SaveSession(ctx, session2)
131 assert.NoError(t, err)
132
133 // Retrieve should get updated values
134 retrieved, err := store.GetSession(ctx, did, "session_upsert")
135 assert.NoError(t, err)
136 assert.Equal(t, "new_access_token", retrieved.AccessToken)
137 assert.Equal(t, "new_refresh_token", retrieved.RefreshToken)
138 assert.Equal(t, "https://pds2.example.com", retrieved.HostURL)
139 assert.Equal(t, []string{"atproto", "transition:generic"}, retrieved.Scopes)
140}
141
142func TestPostgresOAuthStore_GetSession_NotFound(t *testing.T) {
143 db := setupTestDB(t)
144 defer func() { _ = db.Close() }()
145
146 store := NewPostgresOAuthStore(db, 0) // Use default TTL
147 ctx := context.Background()
148
149 did, err := syntax.ParseDID("did:plc:nonexistent")
150 require.NoError(t, err)
151
152 _, err = store.GetSession(ctx, did, "nonexistent_session")
153 assert.ErrorIs(t, err, ErrSessionNotFound)
154}
155
156func TestPostgresOAuthStore_DeleteSession(t *testing.T) {
157 db := setupTestDB(t)
158 defer func() { _ = db.Close() }()
159 defer cleanupOAuth(t, db)
160
161 store := NewPostgresOAuthStore(db, 0) // Use default TTL
162 ctx := context.Background()
163
164 did, err := syntax.ParseDID("did:plc:testdelete")
165 require.NoError(t, err)
166
167 session := oauth.ClientSessionData{
168 AccountDID: did,
169 SessionID: "session_delete",
170 HostURL: "https://pds.example.com",
171 AuthServerURL: "https://auth.example.com",
172 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
173 Scopes: []string{"atproto"},
174 AccessToken: "test_token",
175 RefreshToken: "test_refresh",
176 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
177 }
178
179 // Save session
180 err = store.SaveSession(ctx, session)
181 require.NoError(t, err)
182
183 // Delete session
184 err = store.DeleteSession(ctx, did, "session_delete")
185 assert.NoError(t, err)
186
187 // Verify session is gone
188 _, err = store.GetSession(ctx, did, "session_delete")
189 assert.ErrorIs(t, err, ErrSessionNotFound)
190}
191
192func TestPostgresOAuthStore_DeleteSession_NotFound(t *testing.T) {
193 db := setupTestDB(t)
194 defer func() { _ = db.Close() }()
195
196 store := NewPostgresOAuthStore(db, 0) // Use default TTL
197 ctx := context.Background()
198
199 did, err := syntax.ParseDID("did:plc:nonexistent")
200 require.NoError(t, err)
201
202 err = store.DeleteSession(ctx, did, "nonexistent_session")
203 assert.ErrorIs(t, err, ErrSessionNotFound)
204}
205
206func TestPostgresOAuthStore_SaveAndGetAuthRequestInfo(t *testing.T) {
207 db := setupTestDB(t)
208 defer func() { _ = db.Close() }()
209 defer cleanupOAuth(t, db)
210
211 store := NewPostgresOAuthStore(db, 0) // Use default TTL
212 ctx := context.Background()
213
214 did, err := syntax.ParseDID("did:plc:testrequest")
215 require.NoError(t, err)
216
217 info := oauth.AuthRequestData{
218 State: "test_state_abc123",
219 AuthServerURL: "https://auth.example.com",
220 AccountDID: &did,
221 Scopes: []string{"atproto"},
222 RequestURI: "urn:ietf:params:oauth:request_uri:abc123",
223 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
224 AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke",
225 PKCEVerifier: "verifier_xyz789",
226 DPoPAuthServerNonce: "nonce_abc",
227 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
228 }
229
230 // Save auth request info
231 err = store.SaveAuthRequestInfo(ctx, info)
232 assert.NoError(t, err)
233
234 // Retrieve auth request info
235 retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_abc123")
236 assert.NoError(t, err)
237 assert.NotNil(t, retrieved)
238 assert.Equal(t, info.State, retrieved.State)
239 assert.Equal(t, info.AuthServerURL, retrieved.AuthServerURL)
240 assert.NotNil(t, retrieved.AccountDID)
241 assert.Equal(t, info.AccountDID.String(), retrieved.AccountDID.String())
242 assert.Equal(t, info.Scopes, retrieved.Scopes)
243 assert.Equal(t, info.RequestURI, retrieved.RequestURI)
244 assert.Equal(t, info.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint)
245 assert.Equal(t, info.PKCEVerifier, retrieved.PKCEVerifier)
246 assert.Equal(t, info.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce)
247 assert.Equal(t, info.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase)
248}
249
250func TestPostgresOAuthStore_SaveAuthRequestInfo_NoDID(t *testing.T) {
251 db := setupTestDB(t)
252 defer func() { _ = db.Close() }()
253 defer cleanupOAuth(t, db)
254
255 store := NewPostgresOAuthStore(db, 0) // Use default TTL
256 ctx := context.Background()
257
258 info := oauth.AuthRequestData{
259 State: "test_state_nodid",
260 AuthServerURL: "https://auth.example.com",
261 AccountDID: nil, // No DID provided
262 Scopes: []string{"atproto"},
263 RequestURI: "urn:ietf:params:oauth:request_uri:nodid",
264 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
265 PKCEVerifier: "verifier_nodid",
266 DPoPAuthServerNonce: "nonce_nodid",
267 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
268 }
269
270 // Save auth request info without DID
271 err := store.SaveAuthRequestInfo(ctx, info)
272 assert.NoError(t, err)
273
274 // Retrieve and verify DID is nil
275 retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_nodid")
276 assert.NoError(t, err)
277 assert.Nil(t, retrieved.AccountDID)
278 assert.Equal(t, info.State, retrieved.State)
279}
280
281func TestPostgresOAuthStore_GetAuthRequestInfo_NotFound(t *testing.T) {
282 db := setupTestDB(t)
283 defer func() { _ = db.Close() }()
284
285 store := NewPostgresOAuthStore(db, 0) // Use default TTL
286 ctx := context.Background()
287
288 _, err := store.GetAuthRequestInfo(ctx, "nonexistent_state")
289 assert.ErrorIs(t, err, ErrAuthRequestNotFound)
290}
291
292func TestPostgresOAuthStore_DeleteAuthRequestInfo(t *testing.T) {
293 db := setupTestDB(t)
294 defer func() { _ = db.Close() }()
295 defer cleanupOAuth(t, db)
296
297 store := NewPostgresOAuthStore(db, 0) // Use default TTL
298 ctx := context.Background()
299
300 info := oauth.AuthRequestData{
301 State: "test_state_delete",
302 AuthServerURL: "https://auth.example.com",
303 Scopes: []string{"atproto"},
304 RequestURI: "urn:ietf:params:oauth:request_uri:delete",
305 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
306 PKCEVerifier: "verifier_delete",
307 DPoPAuthServerNonce: "nonce_delete",
308 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
309 }
310
311 // Save auth request info
312 err := store.SaveAuthRequestInfo(ctx, info)
313 require.NoError(t, err)
314
315 // Delete auth request info
316 err = store.DeleteAuthRequestInfo(ctx, "test_state_delete")
317 assert.NoError(t, err)
318
319 // Verify it's gone
320 _, err = store.GetAuthRequestInfo(ctx, "test_state_delete")
321 assert.ErrorIs(t, err, ErrAuthRequestNotFound)
322}
323
324func TestPostgresOAuthStore_DeleteAuthRequestInfo_NotFound(t *testing.T) {
325 db := setupTestDB(t)
326 defer func() { _ = db.Close() }()
327
328 store := NewPostgresOAuthStore(db, 0) // Use default TTL
329 ctx := context.Background()
330
331 err := store.DeleteAuthRequestInfo(ctx, "nonexistent_state")
332 assert.ErrorIs(t, err, ErrAuthRequestNotFound)
333}
334
335func TestPostgresOAuthStore_CleanupExpiredSessions(t *testing.T) {
336 db := setupTestDB(t)
337 defer func() { _ = db.Close() }()
338 defer cleanupOAuth(t, db)
339
340 storeInterface := NewPostgresOAuthStore(db, 0) // Use default TTL
341 store, ok := storeInterface.(*PostgresOAuthStore)
342 require.True(t, ok, "store should be *PostgresOAuthStore")
343 ctx := context.Background()
344
345 did1, err := syntax.ParseDID("did:plc:testexpired1")
346 require.NoError(t, err)
347 did2, err := syntax.ParseDID("did:plc:testexpired2")
348 require.NoError(t, err)
349
350 // Create an expired session (manually insert with past expiration)
351 _, err = db.ExecContext(ctx, `
352 INSERT INTO oauth_sessions (
353 did, session_id, handle, pds_url, host_url,
354 access_token, refresh_token,
355 dpop_private_key_multibase, auth_server_iss,
356 auth_server_token_endpoint, scopes,
357 expires_at, created_at
358 ) VALUES (
359 $1, $2, $3, $4, $5,
360 $6, $7,
361 $8, $9,
362 $10, $11,
363 NOW() - INTERVAL '1 day', NOW()
364 )
365 `, did1.String(), "expired_session", "test.handle", "https://pds.example.com", "https://pds.example.com",
366 "expired_token", "expired_refresh",
367 "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", "https://auth.example.com",
368 "https://auth.example.com/oauth/token", `{"atproto"}`)
369 require.NoError(t, err)
370
371 // Create a valid session
372 validSession := oauth.ClientSessionData{
373 AccountDID: did2,
374 SessionID: "valid_session",
375 HostURL: "https://pds.example.com",
376 AuthServerURL: "https://auth.example.com",
377 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
378 Scopes: []string{"atproto"},
379 AccessToken: "valid_token",
380 RefreshToken: "valid_refresh",
381 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
382 }
383 err = store.SaveSession(ctx, validSession)
384 require.NoError(t, err)
385
386 // Cleanup expired sessions
387 count, err := store.CleanupExpiredSessions(ctx)
388 assert.NoError(t, err)
389 assert.Equal(t, int64(1), count, "Should delete 1 expired session")
390
391 // Verify expired session is gone
392 _, err = store.GetSession(ctx, did1, "expired_session")
393 assert.ErrorIs(t, err, ErrSessionNotFound)
394
395 // Verify valid session still exists
396 _, err = store.GetSession(ctx, did2, "valid_session")
397 assert.NoError(t, err)
398}
399
400func TestPostgresOAuthStore_CleanupExpiredAuthRequests(t *testing.T) {
401 db := setupTestDB(t)
402 defer func() { _ = db.Close() }()
403 defer cleanupOAuth(t, db)
404
405 storeInterface := NewPostgresOAuthStore(db, 0)
406 pgStore, ok := storeInterface.(*PostgresOAuthStore)
407 require.True(t, ok, "store should be *PostgresOAuthStore")
408 store := oauth.ClientAuthStore(pgStore)
409 ctx := context.Background()
410
411 // Create an old auth request (manually insert with old timestamp)
412 _, err := db.ExecContext(ctx, `
413 INSERT INTO oauth_requests (
414 state, did, handle, pds_url, pkce_verifier,
415 dpop_private_key_multibase, dpop_authserver_nonce,
416 auth_server_iss, request_uri,
417 auth_server_token_endpoint, scopes,
418 created_at
419 ) VALUES (
420 $1, $2, $3, $4, $5,
421 $6, $7,
422 $8, $9,
423 $10, $11,
424 NOW() - INTERVAL '1 hour'
425 )
426 `, "test_old_state", "did:plc:testold", "test.handle", "https://pds.example.com",
427 "old_verifier", "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
428 "nonce_old", "https://auth.example.com", "urn:ietf:params:oauth:request_uri:old",
429 "https://auth.example.com/oauth/token", `{"atproto"}`)
430 require.NoError(t, err)
431
432 // Create a recent auth request
433 recentInfo := oauth.AuthRequestData{
434 State: "test_recent_state",
435 AuthServerURL: "https://auth.example.com",
436 Scopes: []string{"atproto"},
437 RequestURI: "urn:ietf:params:oauth:request_uri:recent",
438 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
439 PKCEVerifier: "recent_verifier",
440 DPoPAuthServerNonce: "nonce_recent",
441 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
442 }
443 err = store.SaveAuthRequestInfo(ctx, recentInfo)
444 require.NoError(t, err)
445
446 // Cleanup expired auth requests (older than 30 minutes)
447 count, err := pgStore.CleanupExpiredAuthRequests(ctx)
448 assert.NoError(t, err)
449 assert.Equal(t, int64(1), count, "Should delete 1 expired auth request")
450
451 // Verify old request is gone
452 _, err = store.GetAuthRequestInfo(ctx, "test_old_state")
453 assert.ErrorIs(t, err, ErrAuthRequestNotFound)
454
455 // Verify recent request still exists
456 _, err = store.GetAuthRequestInfo(ctx, "test_recent_state")
457 assert.NoError(t, err)
458}
459
460func TestPostgresOAuthStore_MultipleSessions(t *testing.T) {
461 db := setupTestDB(t)
462 defer func() { _ = db.Close() }()
463 defer cleanupOAuth(t, db)
464
465 store := NewPostgresOAuthStore(db, 0) // Use default TTL
466 ctx := context.Background()
467
468 did, err := syntax.ParseDID("did:plc:testmulti")
469 require.NoError(t, err)
470
471 // Create multiple sessions for the same DID
472 session1 := oauth.ClientSessionData{
473 AccountDID: did,
474 SessionID: "browser1",
475 HostURL: "https://pds.example.com",
476 AuthServerURL: "https://auth.example.com",
477 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
478 Scopes: []string{"atproto"},
479 AccessToken: "token_browser1",
480 RefreshToken: "refresh_browser1",
481 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH",
482 }
483
484 session2 := oauth.ClientSessionData{
485 AccountDID: did,
486 SessionID: "mobile_app",
487 HostURL: "https://pds.example.com",
488 AuthServerURL: "https://auth.example.com",
489 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token",
490 Scopes: []string{"atproto"},
491 AccessToken: "token_mobile",
492 RefreshToken: "refresh_mobile",
493 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX",
494 }
495
496 // Save both sessions
497 err = store.SaveSession(ctx, session1)
498 require.NoError(t, err)
499 err = store.SaveSession(ctx, session2)
500 require.NoError(t, err)
501
502 // Retrieve both sessions
503 retrieved1, err := store.GetSession(ctx, did, "browser1")
504 assert.NoError(t, err)
505 assert.Equal(t, "token_browser1", retrieved1.AccessToken)
506
507 retrieved2, err := store.GetSession(ctx, did, "mobile_app")
508 assert.NoError(t, err)
509 assert.Equal(t, "token_mobile", retrieved2.AccessToken)
510
511 // Delete one session
512 err = store.DeleteSession(ctx, did, "browser1")
513 assert.NoError(t, err)
514
515 // Verify only browser1 is deleted
516 _, err = store.GetSession(ctx, did, "browser1")
517 assert.ErrorIs(t, err, ErrSessionNotFound)
518
519 // mobile_app should still exist
520 _, err = store.GetSession(ctx, did, "mobile_app")
521 assert.NoError(t, err)
522}