A community-based topic aggregation platform built on atproto
at main 19 kB view raw
1package oauth 2 3import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "log/slog" 9 "net/url" 10 "strings" 11 "time" 12 13 "github.com/bluesky-social/indigo/atproto/auth/oauth" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "github.com/lib/pq" 16) 17 18var ( 19 ErrSessionNotFound = errors.New("oauth session not found") 20 ErrAuthRequestNotFound = errors.New("oauth auth request not found") 21) 22 23// PostgresOAuthStore implements oauth.ClientAuthStore interface using PostgreSQL 24type PostgresOAuthStore struct { 25 db *sql.DB 26 sessionTTL time.Duration 27} 28 29// NewPostgresOAuthStore creates a new PostgreSQL-backed OAuth store 30func NewPostgresOAuthStore(db *sql.DB, sessionTTL time.Duration) oauth.ClientAuthStore { 31 if sessionTTL == 0 { 32 sessionTTL = 7 * 24 * time.Hour // Default to 7 days 33 } 34 return &PostgresOAuthStore{ 35 db: db, 36 sessionTTL: sessionTTL, 37 } 38} 39 40// GetSession retrieves a session by DID and session ID 41func (s *PostgresOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 42 query := ` 43 SELECT 44 did, session_id, host_url, auth_server_iss, 45 auth_server_token_endpoint, auth_server_revocation_endpoint, 46 scopes, access_token, refresh_token, 47 dpop_authserver_nonce, dpop_pds_nonce, dpop_private_key_multibase 48 FROM oauth_sessions 49 WHERE did = $1 AND session_id = $2 AND expires_at > NOW() 50 ` 51 52 var session oauth.ClientSessionData 53 var authServerIss, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString 54 var hostURL, dpopPrivateKeyMultibase sql.NullString 55 var scopes pq.StringArray 56 var dpopAuthServerNonce, dpopHostNonce sql.NullString 57 58 err := s.db.QueryRowContext(ctx, query, did.String(), sessionID).Scan( 59 &session.AccountDID, 60 &session.SessionID, 61 &hostURL, 62 &authServerIss, 63 &authServerTokenEndpoint, 64 &authServerRevocationEndpoint, 65 &scopes, 66 &session.AccessToken, 67 &session.RefreshToken, 68 
&dpopAuthServerNonce, 69 &dpopHostNonce, 70 &dpopPrivateKeyMultibase, 71 ) 72 73 if err == sql.ErrNoRows { 74 return nil, ErrSessionNotFound 75 } 76 if err != nil { 77 return nil, fmt.Errorf("failed to get session: %w", err) 78 } 79 80 // Convert nullable fields 81 if hostURL.Valid { 82 session.HostURL = hostURL.String 83 } 84 if authServerIss.Valid { 85 session.AuthServerURL = authServerIss.String 86 } 87 if authServerTokenEndpoint.Valid { 88 session.AuthServerTokenEndpoint = authServerTokenEndpoint.String 89 } 90 if authServerRevocationEndpoint.Valid { 91 session.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String 92 } 93 if dpopAuthServerNonce.Valid { 94 session.DPoPAuthServerNonce = dpopAuthServerNonce.String 95 } 96 if dpopHostNonce.Valid { 97 session.DPoPHostNonce = dpopHostNonce.String 98 } 99 if dpopPrivateKeyMultibase.Valid { 100 session.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String 101 } 102 session.Scopes = scopes 103 104 return &session, nil 105} 106 107// SaveSession saves or updates a session (upsert operation) 108func (s *PostgresOAuthStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 109 // Input validation per atProto OAuth security requirements 110 111 // Validate DID format 112 if _, err := syntax.ParseDID(sess.AccountDID.String()); err != nil { 113 return fmt.Errorf("invalid DID format: %w", err) 114 } 115 116 // Validate token lengths (max 10000 chars to prevent memory issues) 117 const maxTokenLength = 10000 118 if len(sess.AccessToken) > maxTokenLength { 119 return fmt.Errorf("access_token exceeds maximum length of %d characters", maxTokenLength) 120 } 121 if len(sess.RefreshToken) > maxTokenLength { 122 return fmt.Errorf("refresh_token exceeds maximum length of %d characters", maxTokenLength) 123 } 124 125 // Validate session ID is not empty 126 if sess.SessionID == "" { 127 return fmt.Errorf("session_id cannot be empty") 128 } 129 130 // Validate URLs if provided 131 if sess.HostURL != "" 
{ 132 if _, err := url.Parse(sess.HostURL); err != nil { 133 return fmt.Errorf("invalid host_url: %w", err) 134 } 135 } 136 if sess.AuthServerURL != "" { 137 if _, err := url.Parse(sess.AuthServerURL); err != nil { 138 return fmt.Errorf("invalid auth_server URL: %w", err) 139 } 140 } 141 if sess.AuthServerTokenEndpoint != "" { 142 if _, err := url.Parse(sess.AuthServerTokenEndpoint); err != nil { 143 return fmt.Errorf("invalid auth_server_token_endpoint: %w", err) 144 } 145 } 146 if sess.AuthServerRevocationEndpoint != "" { 147 if _, err := url.Parse(sess.AuthServerRevocationEndpoint); err != nil { 148 return fmt.Errorf("invalid auth_server_revocation_endpoint: %w", err) 149 } 150 } 151 152 query := ` 153 INSERT INTO oauth_sessions ( 154 did, session_id, handle, pds_url, host_url, 155 access_token, refresh_token, 156 dpop_private_jwk, dpop_private_key_multibase, 157 dpop_authserver_nonce, dpop_pds_nonce, 158 auth_server_iss, auth_server_token_endpoint, auth_server_revocation_endpoint, 159 scopes, expires_at, created_at, updated_at 160 ) VALUES ( 161 $1, $2, $3, $4, $5, 162 $6, $7, 163 NULL, $8, 164 $9, $10, 165 $11, $12, $13, 166 $14, $15, NOW(), NOW() 167 ) 168 ON CONFLICT (did, session_id) DO UPDATE SET 169 handle = EXCLUDED.handle, 170 pds_url = EXCLUDED.pds_url, 171 host_url = EXCLUDED.host_url, 172 access_token = EXCLUDED.access_token, 173 refresh_token = EXCLUDED.refresh_token, 174 dpop_private_key_multibase = EXCLUDED.dpop_private_key_multibase, 175 dpop_authserver_nonce = EXCLUDED.dpop_authserver_nonce, 176 dpop_pds_nonce = EXCLUDED.dpop_pds_nonce, 177 auth_server_iss = EXCLUDED.auth_server_iss, 178 auth_server_token_endpoint = EXCLUDED.auth_server_token_endpoint, 179 auth_server_revocation_endpoint = EXCLUDED.auth_server_revocation_endpoint, 180 scopes = EXCLUDED.scopes, 181 expires_at = EXCLUDED.expires_at, 182 updated_at = NOW() 183 ` 184 185 // Calculate token expiration using configured TTL 186 expiresAt := time.Now().Add(s.sessionTTL) 187 188 // 
Convert empty strings to NULL for optional fields 189 var authServerRevocationEndpoint sql.NullString 190 if sess.AuthServerRevocationEndpoint != "" { 191 authServerRevocationEndpoint.String = sess.AuthServerRevocationEndpoint 192 authServerRevocationEndpoint.Valid = true 193 } 194 195 // Extract handle from DID (placeholder - in real implementation, resolve from identity) 196 // For now, use DID as handle since we don't have the handle in ClientSessionData 197 handle := sess.AccountDID.String() 198 199 // Use HostURL as PDS URL 200 pdsURL := sess.HostURL 201 if pdsURL == "" { 202 pdsURL = sess.AuthServerURL // Fallback to auth server URL 203 } 204 205 _, err := s.db.ExecContext( 206 ctx, query, 207 sess.AccountDID.String(), 208 sess.SessionID, 209 handle, 210 pdsURL, 211 sess.HostURL, 212 sess.AccessToken, 213 sess.RefreshToken, 214 sess.DPoPPrivateKeyMultibase, 215 sess.DPoPAuthServerNonce, 216 sess.DPoPHostNonce, 217 sess.AuthServerURL, 218 sess.AuthServerTokenEndpoint, 219 authServerRevocationEndpoint, 220 pq.Array(sess.Scopes), 221 expiresAt, 222 ) 223 if err != nil { 224 return fmt.Errorf("failed to save session: %w", err) 225 } 226 227 return nil 228} 229 230// DeleteSession deletes a session by DID and session ID 231func (s *PostgresOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 232 query := `DELETE FROM oauth_sessions WHERE did = $1 AND session_id = $2` 233 234 result, err := s.db.ExecContext(ctx, query, did.String(), sessionID) 235 if err != nil { 236 return fmt.Errorf("failed to delete session: %w", err) 237 } 238 239 rows, err := result.RowsAffected() 240 if err != nil { 241 return fmt.Errorf("failed to get rows affected: %w", err) 242 } 243 244 if rows == 0 { 245 return ErrSessionNotFound 246 } 247 248 return nil 249} 250 251// GetAuthRequestInfo retrieves auth request information by state 252func (s *PostgresOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 253 
query := ` 254 SELECT 255 state, did, handle, pds_url, pkce_verifier, 256 dpop_private_key_multibase, dpop_authserver_nonce, 257 auth_server_iss, request_uri, 258 auth_server_token_endpoint, auth_server_revocation_endpoint, 259 scopes, created_at 260 FROM oauth_requests 261 WHERE state = $1 262 ` 263 264 var info oauth.AuthRequestData 265 var did, handle, pdsURL sql.NullString 266 var dpopPrivateKeyMultibase, dpopAuthServerNonce sql.NullString 267 var requestURI, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString 268 var scopes pq.StringArray 269 var createdAt time.Time 270 271 err := s.db.QueryRowContext(ctx, query, state).Scan( 272 &info.State, 273 &did, 274 &handle, 275 &pdsURL, 276 &info.PKCEVerifier, 277 &dpopPrivateKeyMultibase, 278 &dpopAuthServerNonce, 279 &info.AuthServerURL, 280 &requestURI, 281 &authServerTokenEndpoint, 282 &authServerRevocationEndpoint, 283 &scopes, 284 &createdAt, 285 ) 286 287 if err == sql.ErrNoRows { 288 return nil, ErrAuthRequestNotFound 289 } 290 if err != nil { 291 return nil, fmt.Errorf("failed to get auth request info: %w", err) 292 } 293 294 // Parse DID if present 295 if did.Valid && did.String != "" { 296 parsedDID, err := syntax.ParseDID(did.String) 297 if err != nil { 298 return nil, fmt.Errorf("failed to parse DID: %w", err) 299 } 300 info.AccountDID = &parsedDID 301 } 302 303 // Convert nullable fields 304 if dpopPrivateKeyMultibase.Valid { 305 info.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String 306 } 307 if dpopAuthServerNonce.Valid { 308 info.DPoPAuthServerNonce = dpopAuthServerNonce.String 309 } 310 if requestURI.Valid { 311 info.RequestURI = requestURI.String 312 } 313 if authServerTokenEndpoint.Valid { 314 info.AuthServerTokenEndpoint = authServerTokenEndpoint.String 315 } 316 if authServerRevocationEndpoint.Valid { 317 info.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String 318 } 319 info.Scopes = scopes 320 321 return &info, nil 322} 323 324// SaveAuthRequestInfo saves 
auth request information (create only, not upsert) 325func (s *PostgresOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 326 query := ` 327 INSERT INTO oauth_requests ( 328 state, did, handle, pds_url, pkce_verifier, 329 dpop_private_key_multibase, dpop_authserver_nonce, 330 auth_server_iss, request_uri, 331 auth_server_token_endpoint, auth_server_revocation_endpoint, 332 scopes, return_url, created_at 333 ) VALUES ( 334 $1, $2, $3, $4, $5, 335 $6, $7, 336 $8, $9, 337 $10, $11, 338 $12, NULL, NOW() 339 ) 340 ` 341 342 // Extract DID string if present 343 var didStr sql.NullString 344 if info.AccountDID != nil { 345 didStr.String = info.AccountDID.String() 346 didStr.Valid = true 347 } 348 349 // Convert empty strings to NULL for optional fields 350 var authServerRevocationEndpoint sql.NullString 351 if info.AuthServerRevocationEndpoint != "" { 352 authServerRevocationEndpoint.String = info.AuthServerRevocationEndpoint 353 authServerRevocationEndpoint.Valid = true 354 } 355 356 // Placeholder values for handle and pds_url (not in AuthRequestData) 357 // In production, these would be resolved during the auth flow 358 handle := "" 359 pdsURL := "" 360 if info.AccountDID != nil { 361 handle = info.AccountDID.String() // Temporary placeholder 362 pdsURL = info.AuthServerURL // Temporary placeholder 363 } 364 365 _, err := s.db.ExecContext( 366 ctx, query, 367 info.State, 368 didStr, 369 handle, 370 pdsURL, 371 info.PKCEVerifier, 372 info.DPoPPrivateKeyMultibase, 373 info.DPoPAuthServerNonce, 374 info.AuthServerURL, 375 info.RequestURI, 376 info.AuthServerTokenEndpoint, 377 authServerRevocationEndpoint, 378 pq.Array(info.Scopes), 379 ) 380 if err != nil { 381 // Check for duplicate state 382 if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "oauth_requests_state_key") { 383 return fmt.Errorf("auth request with state already exists: %s", info.State) 384 } 385 return fmt.Errorf("failed to save auth 
request info: %w", err) 386 } 387 388 return nil 389} 390 391// DeleteAuthRequestInfo deletes auth request information by state 392func (s *PostgresOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 393 query := `DELETE FROM oauth_requests WHERE state = $1` 394 395 result, err := s.db.ExecContext(ctx, query, state) 396 if err != nil { 397 return fmt.Errorf("failed to delete auth request info: %w", err) 398 } 399 400 rows, err := result.RowsAffected() 401 if err != nil { 402 return fmt.Errorf("failed to get rows affected: %w", err) 403 } 404 405 if rows == 0 { 406 return ErrAuthRequestNotFound 407 } 408 409 return nil 410} 411 412// CleanupExpiredSessions removes sessions that have expired 413// Should be called periodically (e.g., via cron job) 414func (s *PostgresOAuthStore) CleanupExpiredSessions(ctx context.Context) (int64, error) { 415 query := `DELETE FROM oauth_sessions WHERE expires_at < NOW()` 416 417 result, err := s.db.ExecContext(ctx, query) 418 if err != nil { 419 return 0, fmt.Errorf("failed to cleanup expired sessions: %w", err) 420 } 421 422 rows, err := result.RowsAffected() 423 if err != nil { 424 return 0, fmt.Errorf("failed to get rows affected: %w", err) 425 } 426 427 return rows, nil 428} 429 430// CleanupExpiredAuthRequests removes auth requests older than 30 minutes 431// Should be called periodically (e.g., via cron job) 432func (s *PostgresOAuthStore) CleanupExpiredAuthRequests(ctx context.Context) (int64, error) { 433 query := `DELETE FROM oauth_requests WHERE created_at < NOW() - INTERVAL '30 minutes'` 434 435 result, err := s.db.ExecContext(ctx, query) 436 if err != nil { 437 return 0, fmt.Errorf("failed to cleanup expired auth requests: %w", err) 438 } 439 440 rows, err := result.RowsAffected() 441 if err != nil { 442 return 0, fmt.Errorf("failed to get rows affected: %w", err) 443 } 444 445 return rows, nil 446} 447 448// MobileOAuthData holds mobile-specific OAuth flow data 449type MobileOAuthData struct { 
450 CSRFToken string 451 RedirectURI string 452} 453 454// mobileFlowContextKey is the context key for mobile flow data 455type mobileFlowContextKey struct{} 456 457// ContextWithMobileFlowData adds mobile flow data to a context. 458// This is used by HandleMobileLogin to pass mobile data to the store wrapper, 459// which will save it when SaveAuthRequestInfo is called by indigo. 460func ContextWithMobileFlowData(ctx context.Context, data MobileOAuthData) context.Context { 461 return context.WithValue(ctx, mobileFlowContextKey{}, data) 462} 463 464// getMobileFlowDataFromContext retrieves mobile flow data from context, if present 465func getMobileFlowDataFromContext(ctx context.Context) (MobileOAuthData, bool) { 466 data, ok := ctx.Value(mobileFlowContextKey{}).(MobileOAuthData) 467 return data, ok 468} 469 470// MobileAwareClientStore is a marker interface that indicates a store is properly 471// configured for mobile OAuth flows. Only stores that intercept SaveAuthRequestInfo 472// to save mobile CSRF data should implement this interface. 473// This prevents silent mobile OAuth breakage when a plain PostgresOAuthStore is used. 474type MobileAwareClientStore interface { 475 IsMobileAware() bool 476} 477 478// MobileAwareStoreWrapper wraps a ClientAuthStore to automatically save mobile 479// CSRF data when SaveAuthRequestInfo is called during a mobile OAuth flow. 480// This is necessary because indigo's StartAuthFlow doesn't expose the OAuth state, 481// so we intercept the SaveAuthRequestInfo call to capture it. 482type MobileAwareStoreWrapper struct { 483 oauth.ClientAuthStore 484 mobileStore MobileOAuthStore 485} 486 487// IsMobileAware implements MobileAwareClientStore, indicating this store 488// properly saves mobile CSRF data during OAuth flow initiation. 
489func (w *MobileAwareStoreWrapper) IsMobileAware() bool { 490 return true 491} 492 493// NewMobileAwareStoreWrapper creates a wrapper that intercepts SaveAuthRequestInfo 494// to also save mobile CSRF data when present in context. 495func NewMobileAwareStoreWrapper(store oauth.ClientAuthStore) *MobileAwareStoreWrapper { 496 wrapper := &MobileAwareStoreWrapper{ 497 ClientAuthStore: store, 498 } 499 // Check if the underlying store implements MobileOAuthStore 500 if ms, ok := store.(MobileOAuthStore); ok { 501 wrapper.mobileStore = ms 502 } 503 return wrapper 504} 505 506// SaveAuthRequestInfo saves the auth request and also saves mobile CSRF data 507// if mobile flow data is present in the context. 508func (w *MobileAwareStoreWrapper) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 509 // First, save the auth request to the underlying store 510 if err := w.ClientAuthStore.SaveAuthRequestInfo(ctx, info); err != nil { 511 return err 512 } 513 514 // Check if this is a mobile flow (mobile data in context) 515 if mobileData, ok := getMobileFlowDataFromContext(ctx); ok && w.mobileStore != nil { 516 // Save mobile CSRF data tied to this OAuth state 517 // IMPORTANT: If this fails, we MUST propagate the error. Otherwise: 518 // 1. No server-side CSRF record is stored 519 // 2. Every mobile callback will "fail closed" to web flow 520 // 3. Mobile sign-in silently breaks with no indication 521 // Failing loudly here lets the user retry rather than being confused 522 // about why they're getting a web flow instead of mobile. 
523 if err := w.mobileStore.SaveMobileOAuthData(ctx, info.State, mobileData); err != nil { 524 slog.Error("failed to save mobile CSRF data - mobile login will fail", 525 "state", info.State, "error", err) 526 return fmt.Errorf("failed to save mobile OAuth data: %w", err) 527 } 528 } 529 530 return nil 531} 532 533// GetMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store 534func (w *MobileAwareStoreWrapper) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) { 535 if w.mobileStore != nil { 536 return w.mobileStore.GetMobileOAuthData(ctx, state) 537 } 538 return nil, nil 539} 540 541// SaveMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store 542func (w *MobileAwareStoreWrapper) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error { 543 if w.mobileStore != nil { 544 return w.mobileStore.SaveMobileOAuthData(ctx, state, data) 545 } 546 return nil 547} 548 549// UnwrapPostgresStore returns the underlying PostgresOAuthStore if present. 550// This is useful for accessing cleanup methods that aren't part of the interface. 551func (w *MobileAwareStoreWrapper) UnwrapPostgresStore() *PostgresOAuthStore { 552 if ps, ok := w.ClientAuthStore.(*PostgresOAuthStore); ok { 553 return ps 554 } 555 return nil 556} 557 558// SaveMobileOAuthData stores mobile CSRF data tied to an OAuth state 559// This ties the CSRF token to the OAuth flow via the state parameter, 560// which comes back through the OAuth response for server-side validation. 
561func (s *PostgresOAuthStore) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error { 562 query := ` 563 UPDATE oauth_requests 564 SET mobile_csrf_token = $2, mobile_redirect_uri = $3 565 WHERE state = $1 566 ` 567 568 result, err := s.db.ExecContext(ctx, query, state, data.CSRFToken, data.RedirectURI) 569 if err != nil { 570 return fmt.Errorf("failed to save mobile OAuth data: %w", err) 571 } 572 573 rows, err := result.RowsAffected() 574 if err != nil { 575 return fmt.Errorf("failed to get rows affected: %w", err) 576 } 577 578 if rows == 0 { 579 return ErrAuthRequestNotFound 580 } 581 582 return nil 583} 584 585// GetMobileOAuthData retrieves mobile CSRF data by OAuth state 586// This is called during callback to compare the server-side CSRF token 587// (retrieved by state from the OAuth response) against the cookie CSRF. 588func (s *PostgresOAuthStore) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) { 589 query := ` 590 SELECT mobile_csrf_token, mobile_redirect_uri 591 FROM oauth_requests 592 WHERE state = $1 593 ` 594 595 var csrfToken, redirectURI sql.NullString 596 err := s.db.QueryRowContext(ctx, query, state).Scan(&csrfToken, &redirectURI) 597 598 if err == sql.ErrNoRows { 599 return nil, ErrAuthRequestNotFound 600 } 601 if err != nil { 602 return nil, fmt.Errorf("failed to get mobile OAuth data: %w", err) 603 } 604 605 // Return nil if no mobile data was stored (this was a web flow) 606 if !csrfToken.Valid { 607 return nil, nil 608 } 609 610 return &MobileOAuthData{ 611 CSRFToken: csrfToken.String, 612 RedirectURI: redirectURI.String, 613 }, nil 614}