A community-based topic aggregation platform built on atproto
at main 18 kB view raw
1package oauth 2 3import ( 4 "context" 5 "database/sql" 6 "os" 7 "testing" 8 9 "github.com/bluesky-social/indigo/atproto/auth/oauth" 10 "github.com/bluesky-social/indigo/atproto/syntax" 11 _ "github.com/lib/pq" 12 "github.com/pressly/goose/v3" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/require" 15) 16 17// setupTestDB creates a test database connection and runs migrations 18func setupTestDB(t *testing.T) *sql.DB { 19 dsn := os.Getenv("TEST_DATABASE_URL") 20 if dsn == "" { 21 dsn = "postgres://test_user:test_password@localhost:5434/coves_test?sslmode=disable" 22 } 23 24 db, err := sql.Open("postgres", dsn) 25 require.NoError(t, err, "Failed to connect to test database") 26 27 // Run migrations 28 require.NoError(t, goose.Up(db, "../../db/migrations"), "Failed to run migrations") 29 30 return db 31} 32 33// cleanupOAuth removes all test OAuth data from the database 34func cleanupOAuth(t *testing.T, db *sql.DB) { 35 _, err := db.Exec("DELETE FROM oauth_sessions WHERE did LIKE 'did:plc:test%'") 36 require.NoError(t, err, "Failed to cleanup oauth_sessions") 37 38 _, err = db.Exec("DELETE FROM oauth_requests WHERE state LIKE 'test%'") 39 require.NoError(t, err, "Failed to cleanup oauth_requests") 40} 41 42func TestPostgresOAuthStore_SaveAndGetSession(t *testing.T) { 43 db := setupTestDB(t) 44 defer func() { _ = db.Close() }() 45 defer cleanupOAuth(t, db) 46 47 store := NewPostgresOAuthStore(db, 0) // Use default TTL 48 ctx := context.Background() 49 50 did, err := syntax.ParseDID("did:plc:test123abc") 51 require.NoError(t, err) 52 53 session := oauth.ClientSessionData{ 54 AccountDID: did, 55 SessionID: "session123", 56 HostURL: "https://pds.example.com", 57 AuthServerURL: "https://auth.example.com", 58 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 59 AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke", 60 Scopes: []string{"atproto"}, 61 AccessToken: "at_test_token_abc123", 62 RefreshToken: 
"rt_test_token_xyz789", 63 DPoPAuthServerNonce: "nonce_auth_123", 64 DPoPHostNonce: "nonce_host_456", 65 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 66 } 67 68 // Save session 69 err = store.SaveSession(ctx, session) 70 assert.NoError(t, err) 71 72 // Retrieve session 73 retrieved, err := store.GetSession(ctx, did, "session123") 74 assert.NoError(t, err) 75 assert.NotNil(t, retrieved) 76 assert.Equal(t, session.AccountDID.String(), retrieved.AccountDID.String()) 77 assert.Equal(t, session.SessionID, retrieved.SessionID) 78 assert.Equal(t, session.HostURL, retrieved.HostURL) 79 assert.Equal(t, session.AuthServerURL, retrieved.AuthServerURL) 80 assert.Equal(t, session.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint) 81 assert.Equal(t, session.AccessToken, retrieved.AccessToken) 82 assert.Equal(t, session.RefreshToken, retrieved.RefreshToken) 83 assert.Equal(t, session.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce) 84 assert.Equal(t, session.DPoPHostNonce, retrieved.DPoPHostNonce) 85 assert.Equal(t, session.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase) 86 assert.Equal(t, session.Scopes, retrieved.Scopes) 87} 88 89func TestPostgresOAuthStore_SaveSession_Upsert(t *testing.T) { 90 db := setupTestDB(t) 91 defer func() { _ = db.Close() }() 92 defer cleanupOAuth(t, db) 93 94 store := NewPostgresOAuthStore(db, 0) // Use default TTL 95 ctx := context.Background() 96 97 did, err := syntax.ParseDID("did:plc:testupsert") 98 require.NoError(t, err) 99 100 // Initial session 101 session1 := oauth.ClientSessionData{ 102 AccountDID: did, 103 SessionID: "session_upsert", 104 HostURL: "https://pds1.example.com", 105 AuthServerURL: "https://auth1.example.com", 106 AuthServerTokenEndpoint: "https://auth1.example.com/oauth/token", 107 Scopes: []string{"atproto"}, 108 AccessToken: "old_access_token", 109 RefreshToken: "old_refresh_token", 110 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 111 } 
112 113 err = store.SaveSession(ctx, session1) 114 require.NoError(t, err) 115 116 // Updated session (same DID and session ID) 117 session2 := oauth.ClientSessionData{ 118 AccountDID: did, 119 SessionID: "session_upsert", 120 HostURL: "https://pds2.example.com", 121 AuthServerURL: "https://auth2.example.com", 122 AuthServerTokenEndpoint: "https://auth2.example.com/oauth/token", 123 Scopes: []string{"atproto", "transition:generic"}, 124 AccessToken: "new_access_token", 125 RefreshToken: "new_refresh_token", 126 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX", 127 } 128 129 // Save again - should update 130 err = store.SaveSession(ctx, session2) 131 assert.NoError(t, err) 132 133 // Retrieve should get updated values 134 retrieved, err := store.GetSession(ctx, did, "session_upsert") 135 assert.NoError(t, err) 136 assert.Equal(t, "new_access_token", retrieved.AccessToken) 137 assert.Equal(t, "new_refresh_token", retrieved.RefreshToken) 138 assert.Equal(t, "https://pds2.example.com", retrieved.HostURL) 139 assert.Equal(t, []string{"atproto", "transition:generic"}, retrieved.Scopes) 140} 141 142func TestPostgresOAuthStore_GetSession_NotFound(t *testing.T) { 143 db := setupTestDB(t) 144 defer func() { _ = db.Close() }() 145 146 store := NewPostgresOAuthStore(db, 0) // Use default TTL 147 ctx := context.Background() 148 149 did, err := syntax.ParseDID("did:plc:nonexistent") 150 require.NoError(t, err) 151 152 _, err = store.GetSession(ctx, did, "nonexistent_session") 153 assert.ErrorIs(t, err, ErrSessionNotFound) 154} 155 156func TestPostgresOAuthStore_DeleteSession(t *testing.T) { 157 db := setupTestDB(t) 158 defer func() { _ = db.Close() }() 159 defer cleanupOAuth(t, db) 160 161 store := NewPostgresOAuthStore(db, 0) // Use default TTL 162 ctx := context.Background() 163 164 did, err := syntax.ParseDID("did:plc:testdelete") 165 require.NoError(t, err) 166 167 session := oauth.ClientSessionData{ 168 AccountDID: did, 169 SessionID: 
"session_delete", 170 HostURL: "https://pds.example.com", 171 AuthServerURL: "https://auth.example.com", 172 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 173 Scopes: []string{"atproto"}, 174 AccessToken: "test_token", 175 RefreshToken: "test_refresh", 176 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 177 } 178 179 // Save session 180 err = store.SaveSession(ctx, session) 181 require.NoError(t, err) 182 183 // Delete session 184 err = store.DeleteSession(ctx, did, "session_delete") 185 assert.NoError(t, err) 186 187 // Verify session is gone 188 _, err = store.GetSession(ctx, did, "session_delete") 189 assert.ErrorIs(t, err, ErrSessionNotFound) 190} 191 192func TestPostgresOAuthStore_DeleteSession_NotFound(t *testing.T) { 193 db := setupTestDB(t) 194 defer func() { _ = db.Close() }() 195 196 store := NewPostgresOAuthStore(db, 0) // Use default TTL 197 ctx := context.Background() 198 199 did, err := syntax.ParseDID("did:plc:nonexistent") 200 require.NoError(t, err) 201 202 err = store.DeleteSession(ctx, did, "nonexistent_session") 203 assert.ErrorIs(t, err, ErrSessionNotFound) 204} 205 206func TestPostgresOAuthStore_SaveAndGetAuthRequestInfo(t *testing.T) { 207 db := setupTestDB(t) 208 defer func() { _ = db.Close() }() 209 defer cleanupOAuth(t, db) 210 211 store := NewPostgresOAuthStore(db, 0) // Use default TTL 212 ctx := context.Background() 213 214 did, err := syntax.ParseDID("did:plc:testrequest") 215 require.NoError(t, err) 216 217 info := oauth.AuthRequestData{ 218 State: "test_state_abc123", 219 AuthServerURL: "https://auth.example.com", 220 AccountDID: &did, 221 Scopes: []string{"atproto"}, 222 RequestURI: "urn:ietf:params:oauth:request_uri:abc123", 223 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 224 AuthServerRevocationEndpoint: "https://auth.example.com/oauth/revoke", 225 PKCEVerifier: "verifier_xyz789", 226 DPoPAuthServerNonce: "nonce_abc", 227 DPoPPrivateKeyMultibase: 
"z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 228 } 229 230 // Save auth request info 231 err = store.SaveAuthRequestInfo(ctx, info) 232 assert.NoError(t, err) 233 234 // Retrieve auth request info 235 retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_abc123") 236 assert.NoError(t, err) 237 assert.NotNil(t, retrieved) 238 assert.Equal(t, info.State, retrieved.State) 239 assert.Equal(t, info.AuthServerURL, retrieved.AuthServerURL) 240 assert.NotNil(t, retrieved.AccountDID) 241 assert.Equal(t, info.AccountDID.String(), retrieved.AccountDID.String()) 242 assert.Equal(t, info.Scopes, retrieved.Scopes) 243 assert.Equal(t, info.RequestURI, retrieved.RequestURI) 244 assert.Equal(t, info.AuthServerTokenEndpoint, retrieved.AuthServerTokenEndpoint) 245 assert.Equal(t, info.PKCEVerifier, retrieved.PKCEVerifier) 246 assert.Equal(t, info.DPoPAuthServerNonce, retrieved.DPoPAuthServerNonce) 247 assert.Equal(t, info.DPoPPrivateKeyMultibase, retrieved.DPoPPrivateKeyMultibase) 248} 249 250func TestPostgresOAuthStore_SaveAuthRequestInfo_NoDID(t *testing.T) { 251 db := setupTestDB(t) 252 defer func() { _ = db.Close() }() 253 defer cleanupOAuth(t, db) 254 255 store := NewPostgresOAuthStore(db, 0) // Use default TTL 256 ctx := context.Background() 257 258 info := oauth.AuthRequestData{ 259 State: "test_state_nodid", 260 AuthServerURL: "https://auth.example.com", 261 AccountDID: nil, // No DID provided 262 Scopes: []string{"atproto"}, 263 RequestURI: "urn:ietf:params:oauth:request_uri:nodid", 264 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 265 PKCEVerifier: "verifier_nodid", 266 DPoPAuthServerNonce: "nonce_nodid", 267 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 268 } 269 270 // Save auth request info without DID 271 err := store.SaveAuthRequestInfo(ctx, info) 272 assert.NoError(t, err) 273 274 // Retrieve and verify DID is nil 275 retrieved, err := store.GetAuthRequestInfo(ctx, "test_state_nodid") 276 assert.NoError(t, 
err) 277 assert.Nil(t, retrieved.AccountDID) 278 assert.Equal(t, info.State, retrieved.State) 279} 280 281func TestPostgresOAuthStore_GetAuthRequestInfo_NotFound(t *testing.T) { 282 db := setupTestDB(t) 283 defer func() { _ = db.Close() }() 284 285 store := NewPostgresOAuthStore(db, 0) // Use default TTL 286 ctx := context.Background() 287 288 _, err := store.GetAuthRequestInfo(ctx, "nonexistent_state") 289 assert.ErrorIs(t, err, ErrAuthRequestNotFound) 290} 291 292func TestPostgresOAuthStore_DeleteAuthRequestInfo(t *testing.T) { 293 db := setupTestDB(t) 294 defer func() { _ = db.Close() }() 295 defer cleanupOAuth(t, db) 296 297 store := NewPostgresOAuthStore(db, 0) // Use default TTL 298 ctx := context.Background() 299 300 info := oauth.AuthRequestData{ 301 State: "test_state_delete", 302 AuthServerURL: "https://auth.example.com", 303 Scopes: []string{"atproto"}, 304 RequestURI: "urn:ietf:params:oauth:request_uri:delete", 305 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 306 PKCEVerifier: "verifier_delete", 307 DPoPAuthServerNonce: "nonce_delete", 308 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 309 } 310 311 // Save auth request info 312 err := store.SaveAuthRequestInfo(ctx, info) 313 require.NoError(t, err) 314 315 // Delete auth request info 316 err = store.DeleteAuthRequestInfo(ctx, "test_state_delete") 317 assert.NoError(t, err) 318 319 // Verify it's gone 320 _, err = store.GetAuthRequestInfo(ctx, "test_state_delete") 321 assert.ErrorIs(t, err, ErrAuthRequestNotFound) 322} 323 324func TestPostgresOAuthStore_DeleteAuthRequestInfo_NotFound(t *testing.T) { 325 db := setupTestDB(t) 326 defer func() { _ = db.Close() }() 327 328 store := NewPostgresOAuthStore(db, 0) // Use default TTL 329 ctx := context.Background() 330 331 err := store.DeleteAuthRequestInfo(ctx, "nonexistent_state") 332 assert.ErrorIs(t, err, ErrAuthRequestNotFound) 333} 334 335func TestPostgresOAuthStore_CleanupExpiredSessions(t *testing.T) { 
336 db := setupTestDB(t) 337 defer func() { _ = db.Close() }() 338 defer cleanupOAuth(t, db) 339 340 storeInterface := NewPostgresOAuthStore(db, 0) // Use default TTL 341 store, ok := storeInterface.(*PostgresOAuthStore) 342 require.True(t, ok, "store should be *PostgresOAuthStore") 343 ctx := context.Background() 344 345 did1, err := syntax.ParseDID("did:plc:testexpired1") 346 require.NoError(t, err) 347 did2, err := syntax.ParseDID("did:plc:testexpired2") 348 require.NoError(t, err) 349 350 // Create an expired session (manually insert with past expiration) 351 _, err = db.ExecContext(ctx, ` 352 INSERT INTO oauth_sessions ( 353 did, session_id, handle, pds_url, host_url, 354 access_token, refresh_token, 355 dpop_private_key_multibase, auth_server_iss, 356 auth_server_token_endpoint, scopes, 357 expires_at, created_at 358 ) VALUES ( 359 $1, $2, $3, $4, $5, 360 $6, $7, 361 $8, $9, 362 $10, $11, 363 NOW() - INTERVAL '1 day', NOW() 364 ) 365 `, did1.String(), "expired_session", "test.handle", "https://pds.example.com", "https://pds.example.com", 366 "expired_token", "expired_refresh", 367 "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", "https://auth.example.com", 368 "https://auth.example.com/oauth/token", `{"atproto"}`) 369 require.NoError(t, err) 370 371 // Create a valid session 372 validSession := oauth.ClientSessionData{ 373 AccountDID: did2, 374 SessionID: "valid_session", 375 HostURL: "https://pds.example.com", 376 AuthServerURL: "https://auth.example.com", 377 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 378 Scopes: []string{"atproto"}, 379 AccessToken: "valid_token", 380 RefreshToken: "valid_refresh", 381 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 382 } 383 err = store.SaveSession(ctx, validSession) 384 require.NoError(t, err) 385 386 // Cleanup expired sessions 387 count, err := store.CleanupExpiredSessions(ctx) 388 assert.NoError(t, err) 389 assert.Equal(t, int64(1), count, "Should delete 1 expired 
session") 390 391 // Verify expired session is gone 392 _, err = store.GetSession(ctx, did1, "expired_session") 393 assert.ErrorIs(t, err, ErrSessionNotFound) 394 395 // Verify valid session still exists 396 _, err = store.GetSession(ctx, did2, "valid_session") 397 assert.NoError(t, err) 398} 399 400func TestPostgresOAuthStore_CleanupExpiredAuthRequests(t *testing.T) { 401 db := setupTestDB(t) 402 defer func() { _ = db.Close() }() 403 defer cleanupOAuth(t, db) 404 405 storeInterface := NewPostgresOAuthStore(db, 0) 406 pgStore, ok := storeInterface.(*PostgresOAuthStore) 407 require.True(t, ok, "store should be *PostgresOAuthStore") 408 store := oauth.ClientAuthStore(pgStore) 409 ctx := context.Background() 410 411 // Create an old auth request (manually insert with old timestamp) 412 _, err := db.ExecContext(ctx, ` 413 INSERT INTO oauth_requests ( 414 state, did, handle, pds_url, pkce_verifier, 415 dpop_private_key_multibase, dpop_authserver_nonce, 416 auth_server_iss, request_uri, 417 auth_server_token_endpoint, scopes, 418 created_at 419 ) VALUES ( 420 $1, $2, $3, $4, $5, 421 $6, $7, 422 $8, $9, 423 $10, $11, 424 NOW() - INTERVAL '1 hour' 425 ) 426 `, "test_old_state", "did:plc:testold", "test.handle", "https://pds.example.com", 427 "old_verifier", "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 428 "nonce_old", "https://auth.example.com", "urn:ietf:params:oauth:request_uri:old", 429 "https://auth.example.com/oauth/token", `{"atproto"}`) 430 require.NoError(t, err) 431 432 // Create a recent auth request 433 recentInfo := oauth.AuthRequestData{ 434 State: "test_recent_state", 435 AuthServerURL: "https://auth.example.com", 436 Scopes: []string{"atproto"}, 437 RequestURI: "urn:ietf:params:oauth:request_uri:recent", 438 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 439 PKCEVerifier: "recent_verifier", 440 DPoPAuthServerNonce: "nonce_recent", 441 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 442 } 443 err = 
store.SaveAuthRequestInfo(ctx, recentInfo) 444 require.NoError(t, err) 445 446 // Cleanup expired auth requests (older than 30 minutes) 447 count, err := pgStore.CleanupExpiredAuthRequests(ctx) 448 assert.NoError(t, err) 449 assert.Equal(t, int64(1), count, "Should delete 1 expired auth request") 450 451 // Verify old request is gone 452 _, err = store.GetAuthRequestInfo(ctx, "test_old_state") 453 assert.ErrorIs(t, err, ErrAuthRequestNotFound) 454 455 // Verify recent request still exists 456 _, err = store.GetAuthRequestInfo(ctx, "test_recent_state") 457 assert.NoError(t, err) 458} 459 460func TestPostgresOAuthStore_MultipleSessions(t *testing.T) { 461 db := setupTestDB(t) 462 defer func() { _ = db.Close() }() 463 defer cleanupOAuth(t, db) 464 465 store := NewPostgresOAuthStore(db, 0) // Use default TTL 466 ctx := context.Background() 467 468 did, err := syntax.ParseDID("did:plc:testmulti") 469 require.NoError(t, err) 470 471 // Create multiple sessions for the same DID 472 session1 := oauth.ClientSessionData{ 473 AccountDID: did, 474 SessionID: "browser1", 475 HostURL: "https://pds.example.com", 476 AuthServerURL: "https://auth.example.com", 477 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 478 Scopes: []string{"atproto"}, 479 AccessToken: "token_browser1", 480 RefreshToken: "refresh_browser1", 481 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH", 482 } 483 484 session2 := oauth.ClientSessionData{ 485 AccountDID: did, 486 SessionID: "mobile_app", 487 HostURL: "https://pds.example.com", 488 AuthServerURL: "https://auth.example.com", 489 AuthServerTokenEndpoint: "https://auth.example.com/oauth/token", 490 Scopes: []string{"atproto"}, 491 AccessToken: "token_mobile", 492 RefreshToken: "refresh_mobile", 493 DPoPPrivateKeyMultibase: "z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktX", 494 } 495 496 // Save both sessions 497 err = store.SaveSession(ctx, session1) 498 require.NoError(t, err) 499 err = store.SaveSession(ctx, 
session2) 500 require.NoError(t, err) 501 502 // Retrieve both sessions 503 retrieved1, err := store.GetSession(ctx, did, "browser1") 504 assert.NoError(t, err) 505 assert.Equal(t, "token_browser1", retrieved1.AccessToken) 506 507 retrieved2, err := store.GetSession(ctx, did, "mobile_app") 508 assert.NoError(t, err) 509 assert.Equal(t, "token_mobile", retrieved2.AccessToken) 510 511 // Delete one session 512 err = store.DeleteSession(ctx, did, "browser1") 513 assert.NoError(t, err) 514 515 // Verify only browser1 is deleted 516 _, err = store.GetSession(ctx, did, "browser1") 517 assert.ErrorIs(t, err, ErrSessionNotFound) 518 519 // mobile_app should still exist 520 _, err = store.GetSession(ctx, did, "mobile_app") 521 assert.NoError(t, err) 522}