From 771a3c7f7e104add66896248ad6c63e8c54958b7 Mon Sep 17 00:00:00 2001 From: brookjeynes Date: Wed, 29 Oct 2025 15:34:05 +1000 Subject: [PATCH] fix(oauth): invalidate sessions if inactive for too long Change-Id: zkmnpkqwrwwwlwsqpomtqtsozqxuxnqq Signed-off-by: brookjeynes --- internal/server/handlers/login.go | 2 +- internal/server/oauth/oauth.go | 15 ++-- internal/server/oauth/store.go | 128 ++++++++++++++++++++++++++---- 3 files changed, 125 insertions(+), 20 deletions(-) diff --git a/internal/server/handlers/login.go b/internal/server/handlers/login.go index a1fa0cd..9764824 100644 --- a/internal/server/handlers/login.go +++ b/internal/server/handlers/login.go @@ -83,7 +83,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { if err != nil { l.Error("failed to logout", "err", err) } else { - l.Error("logged out successfully") + l.Debug("logged out successfully") } if !h.Config.Core.Dev && did != "" { diff --git a/internal/server/oauth/oauth.go b/internal/server/oauth/oauth.go index eb3b483..c18cabc 100644 --- a/internal/server/oauth/oauth.go +++ b/internal/server/oauth/oauth.go @@ -47,7 +47,12 @@ func New(config *config.Config, ph posthog.Client, idResolver *atproto.Resolver, jwksUri := clientUri + "/oauth/jwks.json" - authStore, err := NewRedisStore(config.Redis.ToURL()) + authStore, err := NewRedisStore(&RedisStoreConfig{ + RedisURL: config.Redis.ToURL(), + SessionExpiryDuration: time.Hour * 24 * 90, + SessionInactivityDuration: time.Hour * 24 * 14, + AuthRequestExpiryDuration: time.Minute * 30, + }) if err != nil { return nil, err } @@ -138,14 +143,14 @@ func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { } func (o *OAuth) GetUser(r *http.Request) *types.OauthUser { - clientSession, err := o.SessionStore.Get(r, SessionName) - if err != nil || clientSession.IsNew { + sess, err := o.ResumeSession(r) + if err != nil { return nil } return &types.OauthUser{ - Did: clientSession.Values[SessionDid].(string), - Pds: clientSession.Values[SessionPds].(string), + Did: sess.Data.AccountDID.String(), + Pds: sess.Data.HostURL, } } diff --git a/internal/server/oauth/store.go b/internal/server/oauth/store.go index 0323fe9..4a48571 100644 --- a/internal/server/oauth/store.go +++ b/internal/server/oauth/store.go @@ -11,24 +11,55 @@ import ( "github.com/redis/go-redis/v9" ) -// redis-backed implementation of ClientAuthStore. +type RedisStoreConfig struct { + RedisURL string + + // The purpose of these limits is to avoid dead sessions hanging around in + // the db indefinitely. The durations here should be *at least as long as* + // the expected duration of the oauth session itself. + SessionExpiryDuration time.Duration // duration since session creation (max TTL) + SessionInactivityDuration time.Duration // duration since last session update + AuthRequestExpiryDuration time.Duration // duration since auth request creation +} + +// Redis-backed implementation of ClientAuthStore type RedisStore struct { - client *redis.Client - SessionTTL time.Duration - AuthRequestTTL time.Duration + client *redis.Client + cfg *RedisStoreConfig +} + +type sessionMetadata struct { + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } var _ oauth.ClientAuthStore = &RedisStore{} -func NewRedisStore(redisURL string) (*RedisStore, error) { - opts, err := redis.ParseURL(redisURL) +func NewRedisStore(cfg *RedisStoreConfig) (*RedisStore, error) { + if cfg == nil { + return nil, fmt.Errorf("missing cfg") + } + if cfg.RedisURL == "" { + return nil, fmt.Errorf("missing RedisURL") + } + if cfg.SessionExpiryDuration == 0 { + return nil, fmt.Errorf("missing SessionExpiryDuration") + } + if cfg.SessionInactivityDuration == 0 { + return nil, fmt.Errorf("missing SessionInactivityDuration") + } + if cfg.AuthRequestExpiryDuration == 0 { + return nil, fmt.Errorf("missing AuthRequestExpiryDuration") + } + + opts, err := redis.ParseURL(cfg.RedisURL) if err != nil { return nil, fmt.Errorf("failed to parse redis URL: %w", err) } client := redis.NewClient(opts) - // Test the connection. + // Test the connection ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -37,9 +68,8 @@ func NewRedisStore(redisURL string) (*RedisStore, error) { } return &RedisStore{ - client: client, - SessionTTL: 30 * 24 * time.Hour, // 30 days - AuthRequestTTL: 10 * time.Minute, // 10 minutes + client: client, + cfg: cfg, }, nil } @@ -51,12 +81,41 @@ func sessionKey(did syntax.DID, sessionID string) string { return fmt.Sprintf("oauth:session:%s:%s", did, sessionID) } +func sessionMetadataKey(did syntax.DID, sessionID string) string { + return fmt.Sprintf("oauth:session_meta:%s:%s", did, sessionID) +} + func authRequestKey(state string) string { return fmt.Sprintf("oauth:auth_request:%s", state) } func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { key := sessionKey(did, sessionID) + metaKey := sessionMetadataKey(did, sessionID) + + // Check metadata for inactivity expiry + metaData, err := r.client.Get(ctx, metaKey).Bytes() + if err == redis.Nil { + return nil, fmt.Errorf("session not found: %s", did) + } + if err != nil { + return nil, fmt.Errorf("failed to get session metadata: %w", err) + } + + var meta sessionMetadata + if err := json.Unmarshal(metaData, &meta); err != nil { + return nil, fmt.Errorf("failed to unmarshal session metadata: %w", err) + } + + // Check if session has been inactive for too long + inactiveThreshold := time.Now().Add(-r.cfg.SessionInactivityDuration) + if meta.UpdatedAt.Before(inactiveThreshold) { + // Session is inactive, delete it + r.client.Del(ctx, key, metaKey) + return nil, fmt.Errorf("session expired due to inactivity: %s", did) + } + + // Get the actual session data data, err := r.client.Get(ctx, key).Bytes() if err == redis.Nil { return nil, fmt.Errorf("session not found: %s", did) @@ -75,22 +134,63 @@ func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID s func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { key := sessionKey(sess.AccountDID, sess.SessionID) + metaKey := sessionMetadataKey(sess.AccountDID, sess.SessionID) data, err := json.Marshal(sess) if err != nil { return fmt.Errorf("failed to marshal session: %w", err) } - if err := r.client.Set(ctx, key, data, r.SessionTTL).Err(); err != nil { + // Check if session already exists to preserve CreatedAt + var meta sessionMetadata + existingMetaData, err := r.client.Get(ctx, metaKey).Bytes() + if err == redis.Nil { + // New session + meta = sessionMetadata{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + } else if err != nil { + return fmt.Errorf("failed to check existing session metadata: %w", err) + } else { + // Existing session - preserve CreatedAt, update UpdatedAt + if err := json.Unmarshal(existingMetaData, &meta); err != nil { + return fmt.Errorf("failed to unmarshal existing session metadata: %w", err) + } + meta.UpdatedAt = time.Now() + } + + // Calculate remaining TTL based on creation time + remainingTTL := r.cfg.SessionExpiryDuration - time.Since(meta.CreatedAt) + if remainingTTL <= 0 { + return fmt.Errorf("session has expired") + } + + // Use the shorter of: remaining TTL or inactivity duration + ttl := min(r.cfg.SessionInactivityDuration, remainingTTL) + + // Save session data + if err := r.client.Set(ctx, key, data, ttl).Err(); err != nil { return fmt.Errorf("failed to save session: %w", err) } + // Save metadata + metaData, err := json.Marshal(meta) + if err != nil { + return fmt.Errorf("failed to marshal session metadata: %w", err) + } + if err := r.client.Set(ctx, metaKey, metaData, ttl).Err(); err != nil { + return fmt.Errorf("failed to save session metadata: %w", err) + } + return nil } func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { key := sessionKey(did, sessionID) - if err := r.client.Del(ctx, key).Err(); err != nil { + metaKey := sessionMetadataKey(did, sessionID) + + if err := r.client.Del(ctx, key, metaKey).Err(); err != nil { return fmt.Errorf("failed to delete session: %w", err) } return nil @@ -117,7 +217,7 @@ func (r *RedisStore) GetAuthRequestInfo(ctx context.Context, state string) (*oau func (r *RedisStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { key := authRequestKey(info.State) - // check if already exists (to match MemStore behavior) + // Check if already exists (to match MemStore behavior) exists, err := r.client.Exists(ctx, key).Result() if err != nil { return fmt.Errorf("failed to check auth request existence: %w", err) @@ -131,7 +231,7 @@ func (r *RedisStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthReq return fmt.Errorf("failed to marshal auth request: %w", err) } - if err := r.client.Set(ctx, key, data, r.AuthRequestTTL).Err(); err != nil { + if err := r.client.Set(ctx, key, data, r.cfg.AuthRequestExpiryDuration).Err(); err != nil { return fmt.Errorf("failed to save auth request: %w", err) } -- 2.43.0