appview/oauth: invalidate sessions if inactive for too long #722

merged
opened by oppi.li targeting master from push-rnnvqlrqspsv

if sessions are inactive for too long, tokens will not be refreshed, and calling authorized xrpc methods will error out with invalid_grant. this changeset does two things:

  • tracks the last time a session was active using a new redis pair: oauth:session_meta:<did>:<session>, this is updated every time SaveSession is called
  • checks for session inactivity every time GetSession is called, and deletes the session if so

this way, GetSession will never return a session with expired tokens.

Signed-off-by: oppiliappan me@oppi.li

Changed files
+116 -12
appview
+6 -1
appview/oauth/oauth.go
···
jwksUri := clientUri + "/oauth/jwks.json"
-
authStore, err := NewRedisStore(config.Redis.ToURL())
if err != nil {
return nil, err
}
···
jwksUri := clientUri + "/oauth/jwks.json"
+
authStore, err := NewRedisStore(&RedisStoreConfig{
+
RedisURL: config.Redis.ToURL(),
+
SessionExpiryDuration: time.Hour * 24 * 90,
+
SessionInactivityDuration: time.Hour * 24 * 14,
+
AuthRequestExpiryDuration: time.Minute * 30,
+
})
if err != nil {
return nil, err
}
+110 -11
appview/oauth/store.go
···
"github.com/redis/go-redis/v9"
)
// redis-backed implementation of ClientAuthStore.
type RedisStore struct {
-
client *redis.Client
-
SessionTTL time.Duration
-
AuthRequestTTL time.Duration
}
var _ oauth.ClientAuthStore = &RedisStore{}
-
func NewRedisStore(redisURL string) (*RedisStore, error) {
-
opts, err := redis.ParseURL(redisURL)
if err != nil {
return nil, fmt.Errorf("failed to parse redis URL: %w", err)
}
···
}
return &RedisStore{
-
client: client,
-
SessionTTL: 30 * 24 * time.Hour, // 30 days
-
AuthRequestTTL: 10 * time.Minute, // 10 minutes
}, nil
}
···
return fmt.Sprintf("oauth:session:%s:%s", did, sessionID)
}
func authRequestKey(state string) string {
return fmt.Sprintf("oauth:auth_request:%s", state)
}
func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
key := sessionKey(did, sessionID)
data, err := r.client.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, fmt.Errorf("session not found: %s", did)
···
func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
key := sessionKey(sess.AccountDID, sess.SessionID)
data, err := json.Marshal(sess)
if err != nil {
return fmt.Errorf("failed to marshal session: %w", err)
}
-
if err := r.client.Set(ctx, key, data, r.SessionTTL).Err(); err != nil {
return fmt.Errorf("failed to save session: %w", err)
}
return nil
}
func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
key := sessionKey(did, sessionID)
-
if err := r.client.Del(ctx, key).Err(); err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
return nil
···
return fmt.Errorf("failed to marshal auth request: %w", err)
}
-
if err := r.client.Set(ctx, key, data, r.AuthRequestTTL).Err(); err != nil {
return fmt.Errorf("failed to save auth request: %w", err)
}
···
"github.com/redis/go-redis/v9"
)
+
type RedisStoreConfig struct {
+
RedisURL string
+
+
// The purpose of these limits is to avoid dead sessions hanging around in the db indefinitely.
+
// The durations here should be *at least as long as* the expected duration of the oauth session itself.
+
SessionExpiryDuration time.Duration // duration since session creation (max TTL)
+
SessionInactivityDuration time.Duration // duration since last session update
+
AuthRequestExpiryDuration time.Duration // duration since auth request creation
+
}
+
// redis-backed implementation of ClientAuthStore.
type RedisStore struct {
+
client *redis.Client
+
cfg *RedisStoreConfig
}
var _ oauth.ClientAuthStore = &RedisStore{}
+
type sessionMetadata struct {
+
CreatedAt time.Time `json:"created_at"`
+
UpdatedAt time.Time `json:"updated_at"`
+
}
+
+
func NewRedisStore(cfg *RedisStoreConfig) (*RedisStore, error) {
+
if cfg == nil {
+
return nil, fmt.Errorf("missing cfg")
+
}
+
if cfg.RedisURL == "" {
+
return nil, fmt.Errorf("missing RedisURL")
+
}
+
if cfg.SessionExpiryDuration == 0 {
+
return nil, fmt.Errorf("missing SessionExpiryDuration")
+
}
+
if cfg.SessionInactivityDuration == 0 {
+
return nil, fmt.Errorf("missing SessionInactivityDuration")
+
}
+
if cfg.AuthRequestExpiryDuration == 0 {
+
return nil, fmt.Errorf("missing AuthRequestExpiryDuration")
+
}
+
+
opts, err := redis.ParseURL(cfg.RedisURL)
if err != nil {
return nil, fmt.Errorf("failed to parse redis URL: %w", err)
}
···
}
return &RedisStore{
+
client: client,
+
cfg: cfg,
}, nil
}
···
return fmt.Sprintf("oauth:session:%s:%s", did, sessionID)
}
+
func sessionMetadataKey(did syntax.DID, sessionID string) string {
+
return fmt.Sprintf("oauth:session_meta:%s:%s", did, sessionID)
+
}
+
func authRequestKey(state string) string {
return fmt.Sprintf("oauth:auth_request:%s", state)
}
func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
key := sessionKey(did, sessionID)
+
metaKey := sessionMetadataKey(did, sessionID)
+
+
// Check metadata for inactivity expiry
+
metaData, err := r.client.Get(ctx, metaKey).Bytes()
+
if err == redis.Nil {
+
return nil, fmt.Errorf("session not found: %s", did)
+
}
+
if err != nil {
+
return nil, fmt.Errorf("failed to get session metadata: %w", err)
+
}
+
+
var meta sessionMetadata
+
if err := json.Unmarshal(metaData, &meta); err != nil {
+
return nil, fmt.Errorf("failed to unmarshal session metadata: %w", err)
+
}
+
+
// Check if session has been inactive for too long
+
inactiveThreshold := time.Now().Add(-r.cfg.SessionInactivityDuration)
+
if meta.UpdatedAt.Before(inactiveThreshold) {
+
// Session is inactive, delete it
+
r.client.Del(ctx, key, metaKey)
+
return nil, fmt.Errorf("session expired due to inactivity: %s", did)
+
}
+
+
// Get the actual session data
data, err := r.client.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, fmt.Errorf("session not found: %s", did)
···
func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
key := sessionKey(sess.AccountDID, sess.SessionID)
+
metaKey := sessionMetadataKey(sess.AccountDID, sess.SessionID)
data, err := json.Marshal(sess)
if err != nil {
return fmt.Errorf("failed to marshal session: %w", err)
}
+
// Check if session already exists to preserve CreatedAt
+
var meta sessionMetadata
+
existingMetaData, err := r.client.Get(ctx, metaKey).Bytes()
+
if err == redis.Nil {
+
// New session
+
meta = sessionMetadata{
+
CreatedAt: time.Now(),
+
UpdatedAt: time.Now(),
+
}
+
} else if err != nil {
+
return fmt.Errorf("failed to check existing session metadata: %w", err)
+
} else {
+
// Existing session - preserve CreatedAt, update UpdatedAt
+
if err := json.Unmarshal(existingMetaData, &meta); err != nil {
+
return fmt.Errorf("failed to unmarshal existing session metadata: %w", err)
+
}
+
meta.UpdatedAt = time.Now()
+
}
+
+
// Calculate remaining TTL based on creation time
+
remainingTTL := r.cfg.SessionExpiryDuration - time.Since(meta.CreatedAt)
+
if remainingTTL <= 0 {
+
return fmt.Errorf("session has expired")
+
}
+
+
// Use the shorter of: remaining TTL or inactivity duration
+
ttl := min(r.cfg.SessionInactivityDuration, remainingTTL)
+
+
// Save session data
+
if err := r.client.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("failed to save session: %w", err)
}
+
// Save metadata
+
metaData, err := json.Marshal(meta)
+
if err != nil {
+
return fmt.Errorf("failed to marshal session metadata: %w", err)
+
}
+
if err := r.client.Set(ctx, metaKey, metaData, ttl).Err(); err != nil {
+
return fmt.Errorf("failed to save session metadata: %w", err)
+
}
+
return nil
}
func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
key := sessionKey(did, sessionID)
+
metaKey := sessionMetadataKey(did, sessionID)
+
+
if err := r.client.Del(ctx, key, metaKey).Err(); err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
return nil
···
return fmt.Errorf("failed to marshal auth request: %w", err)
}
+
if err := r.client.Set(ctx, key, data, r.cfg.AuthRequestExpiryDuration).Err(); err != nil {
return fmt.Errorf("failed to save auth request: %w", err)
}