A community-based topic aggregation platform built on atproto
1package oauth
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log/slog"
9 "net/url"
10 "strings"
11 "time"
12
13 "github.com/bluesky-social/indigo/atproto/auth/oauth"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "github.com/lib/pq"
16)
17
// Sentinel errors returned by the store; callers can match them with errors.Is.
var (
	// ErrSessionNotFound is returned when no oauth_sessions row matches the
	// requested DID and session ID. GetSession also returns it for rows whose
	// expires_at has already passed.
	ErrSessionNotFound = errors.New("oauth session not found")
	// ErrAuthRequestNotFound is returned when no oauth_requests row matches the
	// requested state.
	ErrAuthRequestNotFound = errors.New("oauth auth request not found")
)
22
// PostgresOAuthStore implements the oauth.ClientAuthStore interface using PostgreSQL.
// Sessions are kept in the oauth_sessions table and pending authorization
// requests in the oauth_requests table.
type PostgresOAuthStore struct {
	db         *sql.DB       // database handle used for every query
	sessionTTL time.Duration // added to time.Now() to compute expires_at whenever a session is saved
}
28
29// NewPostgresOAuthStore creates a new PostgreSQL-backed OAuth store
30func NewPostgresOAuthStore(db *sql.DB, sessionTTL time.Duration) oauth.ClientAuthStore {
31 if sessionTTL == 0 {
32 sessionTTL = 7 * 24 * time.Hour // Default to 7 days
33 }
34 return &PostgresOAuthStore{
35 db: db,
36 sessionTTL: sessionTTL,
37 }
38}
39
40// GetSession retrieves a session by DID and session ID
41func (s *PostgresOAuthStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
42 query := `
43 SELECT
44 did, session_id, host_url, auth_server_iss,
45 auth_server_token_endpoint, auth_server_revocation_endpoint,
46 scopes, access_token, refresh_token,
47 dpop_authserver_nonce, dpop_pds_nonce, dpop_private_key_multibase
48 FROM oauth_sessions
49 WHERE did = $1 AND session_id = $2 AND expires_at > NOW()
50 `
51
52 var session oauth.ClientSessionData
53 var authServerIss, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
54 var hostURL, dpopPrivateKeyMultibase sql.NullString
55 var scopes pq.StringArray
56 var dpopAuthServerNonce, dpopHostNonce sql.NullString
57
58 err := s.db.QueryRowContext(ctx, query, did.String(), sessionID).Scan(
59 &session.AccountDID,
60 &session.SessionID,
61 &hostURL,
62 &authServerIss,
63 &authServerTokenEndpoint,
64 &authServerRevocationEndpoint,
65 &scopes,
66 &session.AccessToken,
67 &session.RefreshToken,
68 &dpopAuthServerNonce,
69 &dpopHostNonce,
70 &dpopPrivateKeyMultibase,
71 )
72
73 if err == sql.ErrNoRows {
74 return nil, ErrSessionNotFound
75 }
76 if err != nil {
77 return nil, fmt.Errorf("failed to get session: %w", err)
78 }
79
80 // Convert nullable fields
81 if hostURL.Valid {
82 session.HostURL = hostURL.String
83 }
84 if authServerIss.Valid {
85 session.AuthServerURL = authServerIss.String
86 }
87 if authServerTokenEndpoint.Valid {
88 session.AuthServerTokenEndpoint = authServerTokenEndpoint.String
89 }
90 if authServerRevocationEndpoint.Valid {
91 session.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
92 }
93 if dpopAuthServerNonce.Valid {
94 session.DPoPAuthServerNonce = dpopAuthServerNonce.String
95 }
96 if dpopHostNonce.Valid {
97 session.DPoPHostNonce = dpopHostNonce.String
98 }
99 if dpopPrivateKeyMultibase.Valid {
100 session.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
101 }
102 session.Scopes = scopes
103
104 return &session, nil
105}
106
107// SaveSession saves or updates a session (upsert operation)
108func (s *PostgresOAuthStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
109 // Input validation per atProto OAuth security requirements
110
111 // Validate DID format
112 if _, err := syntax.ParseDID(sess.AccountDID.String()); err != nil {
113 return fmt.Errorf("invalid DID format: %w", err)
114 }
115
116 // Validate token lengths (max 10000 chars to prevent memory issues)
117 const maxTokenLength = 10000
118 if len(sess.AccessToken) > maxTokenLength {
119 return fmt.Errorf("access_token exceeds maximum length of %d characters", maxTokenLength)
120 }
121 if len(sess.RefreshToken) > maxTokenLength {
122 return fmt.Errorf("refresh_token exceeds maximum length of %d characters", maxTokenLength)
123 }
124
125 // Validate session ID is not empty
126 if sess.SessionID == "" {
127 return fmt.Errorf("session_id cannot be empty")
128 }
129
130 // Validate URLs if provided
131 if sess.HostURL != "" {
132 if _, err := url.Parse(sess.HostURL); err != nil {
133 return fmt.Errorf("invalid host_url: %w", err)
134 }
135 }
136 if sess.AuthServerURL != "" {
137 if _, err := url.Parse(sess.AuthServerURL); err != nil {
138 return fmt.Errorf("invalid auth_server URL: %w", err)
139 }
140 }
141 if sess.AuthServerTokenEndpoint != "" {
142 if _, err := url.Parse(sess.AuthServerTokenEndpoint); err != nil {
143 return fmt.Errorf("invalid auth_server_token_endpoint: %w", err)
144 }
145 }
146 if sess.AuthServerRevocationEndpoint != "" {
147 if _, err := url.Parse(sess.AuthServerRevocationEndpoint); err != nil {
148 return fmt.Errorf("invalid auth_server_revocation_endpoint: %w", err)
149 }
150 }
151
152 query := `
153 INSERT INTO oauth_sessions (
154 did, session_id, handle, pds_url, host_url,
155 access_token, refresh_token,
156 dpop_private_jwk, dpop_private_key_multibase,
157 dpop_authserver_nonce, dpop_pds_nonce,
158 auth_server_iss, auth_server_token_endpoint, auth_server_revocation_endpoint,
159 scopes, expires_at, created_at, updated_at
160 ) VALUES (
161 $1, $2, $3, $4, $5,
162 $6, $7,
163 NULL, $8,
164 $9, $10,
165 $11, $12, $13,
166 $14, $15, NOW(), NOW()
167 )
168 ON CONFLICT (did, session_id) DO UPDATE SET
169 handle = EXCLUDED.handle,
170 pds_url = EXCLUDED.pds_url,
171 host_url = EXCLUDED.host_url,
172 access_token = EXCLUDED.access_token,
173 refresh_token = EXCLUDED.refresh_token,
174 dpop_private_key_multibase = EXCLUDED.dpop_private_key_multibase,
175 dpop_authserver_nonce = EXCLUDED.dpop_authserver_nonce,
176 dpop_pds_nonce = EXCLUDED.dpop_pds_nonce,
177 auth_server_iss = EXCLUDED.auth_server_iss,
178 auth_server_token_endpoint = EXCLUDED.auth_server_token_endpoint,
179 auth_server_revocation_endpoint = EXCLUDED.auth_server_revocation_endpoint,
180 scopes = EXCLUDED.scopes,
181 expires_at = EXCLUDED.expires_at,
182 updated_at = NOW()
183 `
184
185 // Calculate token expiration using configured TTL
186 expiresAt := time.Now().Add(s.sessionTTL)
187
188 // Convert empty strings to NULL for optional fields
189 var authServerRevocationEndpoint sql.NullString
190 if sess.AuthServerRevocationEndpoint != "" {
191 authServerRevocationEndpoint.String = sess.AuthServerRevocationEndpoint
192 authServerRevocationEndpoint.Valid = true
193 }
194
195 // Extract handle from DID (placeholder - in real implementation, resolve from identity)
196 // For now, use DID as handle since we don't have the handle in ClientSessionData
197 handle := sess.AccountDID.String()
198
199 // Use HostURL as PDS URL
200 pdsURL := sess.HostURL
201 if pdsURL == "" {
202 pdsURL = sess.AuthServerURL // Fallback to auth server URL
203 }
204
205 _, err := s.db.ExecContext(
206 ctx, query,
207 sess.AccountDID.String(),
208 sess.SessionID,
209 handle,
210 pdsURL,
211 sess.HostURL,
212 sess.AccessToken,
213 sess.RefreshToken,
214 sess.DPoPPrivateKeyMultibase,
215 sess.DPoPAuthServerNonce,
216 sess.DPoPHostNonce,
217 sess.AuthServerURL,
218 sess.AuthServerTokenEndpoint,
219 authServerRevocationEndpoint,
220 pq.Array(sess.Scopes),
221 expiresAt,
222 )
223 if err != nil {
224 return fmt.Errorf("failed to save session: %w", err)
225 }
226
227 return nil
228}
229
230// DeleteSession deletes a session by DID and session ID
231func (s *PostgresOAuthStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
232 query := `DELETE FROM oauth_sessions WHERE did = $1 AND session_id = $2`
233
234 result, err := s.db.ExecContext(ctx, query, did.String(), sessionID)
235 if err != nil {
236 return fmt.Errorf("failed to delete session: %w", err)
237 }
238
239 rows, err := result.RowsAffected()
240 if err != nil {
241 return fmt.Errorf("failed to get rows affected: %w", err)
242 }
243
244 if rows == 0 {
245 return ErrSessionNotFound
246 }
247
248 return nil
249}
250
251// GetAuthRequestInfo retrieves auth request information by state
252func (s *PostgresOAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
253 query := `
254 SELECT
255 state, did, handle, pds_url, pkce_verifier,
256 dpop_private_key_multibase, dpop_authserver_nonce,
257 auth_server_iss, request_uri,
258 auth_server_token_endpoint, auth_server_revocation_endpoint,
259 scopes, created_at
260 FROM oauth_requests
261 WHERE state = $1
262 `
263
264 var info oauth.AuthRequestData
265 var did, handle, pdsURL sql.NullString
266 var dpopPrivateKeyMultibase, dpopAuthServerNonce sql.NullString
267 var requestURI, authServerTokenEndpoint, authServerRevocationEndpoint sql.NullString
268 var scopes pq.StringArray
269 var createdAt time.Time
270
271 err := s.db.QueryRowContext(ctx, query, state).Scan(
272 &info.State,
273 &did,
274 &handle,
275 &pdsURL,
276 &info.PKCEVerifier,
277 &dpopPrivateKeyMultibase,
278 &dpopAuthServerNonce,
279 &info.AuthServerURL,
280 &requestURI,
281 &authServerTokenEndpoint,
282 &authServerRevocationEndpoint,
283 &scopes,
284 &createdAt,
285 )
286
287 if err == sql.ErrNoRows {
288 return nil, ErrAuthRequestNotFound
289 }
290 if err != nil {
291 return nil, fmt.Errorf("failed to get auth request info: %w", err)
292 }
293
294 // Parse DID if present
295 if did.Valid && did.String != "" {
296 parsedDID, err := syntax.ParseDID(did.String)
297 if err != nil {
298 return nil, fmt.Errorf("failed to parse DID: %w", err)
299 }
300 info.AccountDID = &parsedDID
301 }
302
303 // Convert nullable fields
304 if dpopPrivateKeyMultibase.Valid {
305 info.DPoPPrivateKeyMultibase = dpopPrivateKeyMultibase.String
306 }
307 if dpopAuthServerNonce.Valid {
308 info.DPoPAuthServerNonce = dpopAuthServerNonce.String
309 }
310 if requestURI.Valid {
311 info.RequestURI = requestURI.String
312 }
313 if authServerTokenEndpoint.Valid {
314 info.AuthServerTokenEndpoint = authServerTokenEndpoint.String
315 }
316 if authServerRevocationEndpoint.Valid {
317 info.AuthServerRevocationEndpoint = authServerRevocationEndpoint.String
318 }
319 info.Scopes = scopes
320
321 return &info, nil
322}
323
324// SaveAuthRequestInfo saves auth request information (create only, not upsert)
325func (s *PostgresOAuthStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
326 query := `
327 INSERT INTO oauth_requests (
328 state, did, handle, pds_url, pkce_verifier,
329 dpop_private_key_multibase, dpop_authserver_nonce,
330 auth_server_iss, request_uri,
331 auth_server_token_endpoint, auth_server_revocation_endpoint,
332 scopes, return_url, created_at
333 ) VALUES (
334 $1, $2, $3, $4, $5,
335 $6, $7,
336 $8, $9,
337 $10, $11,
338 $12, NULL, NOW()
339 )
340 `
341
342 // Extract DID string if present
343 var didStr sql.NullString
344 if info.AccountDID != nil {
345 didStr.String = info.AccountDID.String()
346 didStr.Valid = true
347 }
348
349 // Convert empty strings to NULL for optional fields
350 var authServerRevocationEndpoint sql.NullString
351 if info.AuthServerRevocationEndpoint != "" {
352 authServerRevocationEndpoint.String = info.AuthServerRevocationEndpoint
353 authServerRevocationEndpoint.Valid = true
354 }
355
356 // Placeholder values for handle and pds_url (not in AuthRequestData)
357 // In production, these would be resolved during the auth flow
358 handle := ""
359 pdsURL := ""
360 if info.AccountDID != nil {
361 handle = info.AccountDID.String() // Temporary placeholder
362 pdsURL = info.AuthServerURL // Temporary placeholder
363 }
364
365 _, err := s.db.ExecContext(
366 ctx, query,
367 info.State,
368 didStr,
369 handle,
370 pdsURL,
371 info.PKCEVerifier,
372 info.DPoPPrivateKeyMultibase,
373 info.DPoPAuthServerNonce,
374 info.AuthServerURL,
375 info.RequestURI,
376 info.AuthServerTokenEndpoint,
377 authServerRevocationEndpoint,
378 pq.Array(info.Scopes),
379 )
380 if err != nil {
381 // Check for duplicate state
382 if strings.Contains(err.Error(), "duplicate key") && strings.Contains(err.Error(), "oauth_requests_state_key") {
383 return fmt.Errorf("auth request with state already exists: %s", info.State)
384 }
385 return fmt.Errorf("failed to save auth request info: %w", err)
386 }
387
388 return nil
389}
390
391// DeleteAuthRequestInfo deletes auth request information by state
392func (s *PostgresOAuthStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
393 query := `DELETE FROM oauth_requests WHERE state = $1`
394
395 result, err := s.db.ExecContext(ctx, query, state)
396 if err != nil {
397 return fmt.Errorf("failed to delete auth request info: %w", err)
398 }
399
400 rows, err := result.RowsAffected()
401 if err != nil {
402 return fmt.Errorf("failed to get rows affected: %w", err)
403 }
404
405 if rows == 0 {
406 return ErrAuthRequestNotFound
407 }
408
409 return nil
410}
411
412// CleanupExpiredSessions removes sessions that have expired
413// Should be called periodically (e.g., via cron job)
414func (s *PostgresOAuthStore) CleanupExpiredSessions(ctx context.Context) (int64, error) {
415 query := `DELETE FROM oauth_sessions WHERE expires_at < NOW()`
416
417 result, err := s.db.ExecContext(ctx, query)
418 if err != nil {
419 return 0, fmt.Errorf("failed to cleanup expired sessions: %w", err)
420 }
421
422 rows, err := result.RowsAffected()
423 if err != nil {
424 return 0, fmt.Errorf("failed to get rows affected: %w", err)
425 }
426
427 return rows, nil
428}
429
430// CleanupExpiredAuthRequests removes auth requests older than 30 minutes
431// Should be called periodically (e.g., via cron job)
432func (s *PostgresOAuthStore) CleanupExpiredAuthRequests(ctx context.Context) (int64, error) {
433 query := `DELETE FROM oauth_requests WHERE created_at < NOW() - INTERVAL '30 minutes'`
434
435 result, err := s.db.ExecContext(ctx, query)
436 if err != nil {
437 return 0, fmt.Errorf("failed to cleanup expired auth requests: %w", err)
438 }
439
440 rows, err := result.RowsAffected()
441 if err != nil {
442 return 0, fmt.Errorf("failed to get rows affected: %w", err)
443 }
444
445 return rows, nil
446}
447
// MobileOAuthData carries the extra values a mobile OAuth flow needs on top of
// the regular auth request.
type MobileOAuthData struct {
	CSRFToken   string // persisted in oauth_requests.mobile_csrf_token
	RedirectURI string // persisted in oauth_requests.mobile_redirect_uri
}

// mobileFlowContextKey is the unexported key type under which mobile flow data
// is stored in a context, so no other package can collide with it.
type mobileFlowContextKey struct{}

// ContextWithMobileFlowData returns a child of ctx that carries data.
// HandleMobileLogin uses it to hand mobile data to the store wrapper, which
// persists it when indigo calls SaveAuthRequestInfo.
func ContextWithMobileFlowData(ctx context.Context, data MobileOAuthData) context.Context {
	var key mobileFlowContextKey
	return context.WithValue(ctx, key, data)
}

// getMobileFlowDataFromContext returns the mobile flow data attached to ctx.
// The boolean is false, and the data is the zero value, when none is attached.
func getMobileFlowDataFromContext(ctx context.Context) (MobileOAuthData, bool) {
	var key mobileFlowContextKey
	if stored, ok := ctx.Value(key).(MobileOAuthData); ok {
		return stored, true
	}
	return MobileOAuthData{}, false
}
469
// MobileAwareClientStore is a marker interface that indicates a store is properly
// configured for mobile OAuth flows. Only stores that intercept SaveAuthRequestInfo
// to save mobile CSRF data should implement this interface.
// This prevents silent mobile OAuth breakage when a plain PostgresOAuthStore is used.
type MobileAwareClientStore interface {
	// IsMobileAware reports whether the store saves mobile CSRF data during
	// OAuth flow initiation.
	IsMobileAware() bool
}
477
// MobileAwareStoreWrapper wraps a ClientAuthStore to automatically save mobile
// CSRF data when SaveAuthRequestInfo is called during a mobile OAuth flow.
// This is necessary because indigo's StartAuthFlow doesn't expose the OAuth state,
// so we intercept the SaveAuthRequestInfo call to capture it.
type MobileAwareStoreWrapper struct {
	// Embedded store; every ClientAuthStore method the wrapper does not
	// override is forwarded to it.
	oauth.ClientAuthStore
	// mobileStore is the wrapped store viewed as a MobileOAuthStore. It is nil
	// when the wrapped store does not implement that interface.
	mobileStore MobileOAuthStore
}
486
487// IsMobileAware implements MobileAwareClientStore, indicating this store
488// properly saves mobile CSRF data during OAuth flow initiation.
489func (w *MobileAwareStoreWrapper) IsMobileAware() bool {
490 return true
491}
492
493// NewMobileAwareStoreWrapper creates a wrapper that intercepts SaveAuthRequestInfo
494// to also save mobile CSRF data when present in context.
495func NewMobileAwareStoreWrapper(store oauth.ClientAuthStore) *MobileAwareStoreWrapper {
496 wrapper := &MobileAwareStoreWrapper{
497 ClientAuthStore: store,
498 }
499 // Check if the underlying store implements MobileOAuthStore
500 if ms, ok := store.(MobileOAuthStore); ok {
501 wrapper.mobileStore = ms
502 }
503 return wrapper
504}
505
506// SaveAuthRequestInfo saves the auth request and also saves mobile CSRF data
507// if mobile flow data is present in the context.
508func (w *MobileAwareStoreWrapper) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
509 // First, save the auth request to the underlying store
510 if err := w.ClientAuthStore.SaveAuthRequestInfo(ctx, info); err != nil {
511 return err
512 }
513
514 // Check if this is a mobile flow (mobile data in context)
515 if mobileData, ok := getMobileFlowDataFromContext(ctx); ok && w.mobileStore != nil {
516 // Save mobile CSRF data tied to this OAuth state
517 // IMPORTANT: If this fails, we MUST propagate the error. Otherwise:
518 // 1. No server-side CSRF record is stored
519 // 2. Every mobile callback will "fail closed" to web flow
520 // 3. Mobile sign-in silently breaks with no indication
521 // Failing loudly here lets the user retry rather than being confused
522 // about why they're getting a web flow instead of mobile.
523 if err := w.mobileStore.SaveMobileOAuthData(ctx, info.State, mobileData); err != nil {
524 slog.Error("failed to save mobile CSRF data - mobile login will fail",
525 "state", info.State, "error", err)
526 return fmt.Errorf("failed to save mobile OAuth data: %w", err)
527 }
528 }
529
530 return nil
531}
532
533// GetMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
534func (w *MobileAwareStoreWrapper) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
535 if w.mobileStore != nil {
536 return w.mobileStore.GetMobileOAuthData(ctx, state)
537 }
538 return nil, nil
539}
540
541// SaveMobileOAuthData implements MobileOAuthStore interface by delegating to underlying store
542func (w *MobileAwareStoreWrapper) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
543 if w.mobileStore != nil {
544 return w.mobileStore.SaveMobileOAuthData(ctx, state, data)
545 }
546 return nil
547}
548
549// UnwrapPostgresStore returns the underlying PostgresOAuthStore if present.
550// This is useful for accessing cleanup methods that aren't part of the interface.
551func (w *MobileAwareStoreWrapper) UnwrapPostgresStore() *PostgresOAuthStore {
552 if ps, ok := w.ClientAuthStore.(*PostgresOAuthStore); ok {
553 return ps
554 }
555 return nil
556}
557
558// SaveMobileOAuthData stores mobile CSRF data tied to an OAuth state
559// This ties the CSRF token to the OAuth flow via the state parameter,
560// which comes back through the OAuth response for server-side validation.
561func (s *PostgresOAuthStore) SaveMobileOAuthData(ctx context.Context, state string, data MobileOAuthData) error {
562 query := `
563 UPDATE oauth_requests
564 SET mobile_csrf_token = $2, mobile_redirect_uri = $3
565 WHERE state = $1
566 `
567
568 result, err := s.db.ExecContext(ctx, query, state, data.CSRFToken, data.RedirectURI)
569 if err != nil {
570 return fmt.Errorf("failed to save mobile OAuth data: %w", err)
571 }
572
573 rows, err := result.RowsAffected()
574 if err != nil {
575 return fmt.Errorf("failed to get rows affected: %w", err)
576 }
577
578 if rows == 0 {
579 return ErrAuthRequestNotFound
580 }
581
582 return nil
583}
584
585// GetMobileOAuthData retrieves mobile CSRF data by OAuth state
586// This is called during callback to compare the server-side CSRF token
587// (retrieved by state from the OAuth response) against the cookie CSRF.
588func (s *PostgresOAuthStore) GetMobileOAuthData(ctx context.Context, state string) (*MobileOAuthData, error) {
589 query := `
590 SELECT mobile_csrf_token, mobile_redirect_uri
591 FROM oauth_requests
592 WHERE state = $1
593 `
594
595 var csrfToken, redirectURI sql.NullString
596 err := s.db.QueryRowContext(ctx, query, state).Scan(&csrfToken, &redirectURI)
597
598 if err == sql.ErrNoRows {
599 return nil, ErrAuthRequestNotFound
600 }
601 if err != nil {
602 return nil, fmt.Errorf("failed to get mobile OAuth data: %w", err)
603 }
604
605 // Return nil if no mobile data was stored (this was a web flow)
606 if !csrfToken.Valid {
607 return nil, nil
608 }
609
610 return &MobileOAuthData{
611 CSRFToken: csrfToken.String,
612 RedirectURI: redirectURI.String,
613 }, nil
614}