forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
at master 7.3 kB view raw
1package oauth 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "time" 8 9 "github.com/bluesky-social/indigo/atproto/auth/oauth" 10 "github.com/bluesky-social/indigo/atproto/syntax" 11 "github.com/redis/go-redis/v9" 12) 13 14type RedisStoreConfig struct { 15 RedisURL string 16 17 // The purpose of these limits is to avoid dead sessions hanging around in the db indefinitely. 18 // The durations here should be *at least as long as* the expected duration of the oauth session itself. 19 SessionExpiryDuration time.Duration // duration since session creation (max TTL) 20 SessionInactivityDuration time.Duration // duration since last session update 21 AuthRequestExpiryDuration time.Duration // duration since auth request creation 22} 23 24// redis-backed implementation of ClientAuthStore. 25type RedisStore struct { 26 client *redis.Client 27 cfg *RedisStoreConfig 28} 29 30var _ oauth.ClientAuthStore = &RedisStore{} 31 32type sessionMetadata struct { 33 CreatedAt time.Time `json:"created_at"` 34 UpdatedAt time.Time `json:"updated_at"` 35} 36 37func NewRedisStore(cfg *RedisStoreConfig) (*RedisStore, error) { 38 if cfg == nil { 39 return nil, fmt.Errorf("missing cfg") 40 } 41 if cfg.RedisURL == "" { 42 return nil, fmt.Errorf("missing RedisURL") 43 } 44 if cfg.SessionExpiryDuration == 0 { 45 return nil, fmt.Errorf("missing SessionExpiryDuration") 46 } 47 if cfg.SessionInactivityDuration == 0 { 48 return nil, fmt.Errorf("missing SessionInactivityDuration") 49 } 50 if cfg.AuthRequestExpiryDuration == 0 { 51 return nil, fmt.Errorf("missing AuthRequestExpiryDuration") 52 } 53 54 opts, err := redis.ParseURL(cfg.RedisURL) 55 if err != nil { 56 return nil, fmt.Errorf("failed to parse redis URL: %w", err) 57 } 58 59 client := redis.NewClient(opts) 60 61 // test the connection 62 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 63 defer cancel() 64 65 if err := client.Ping(ctx).Err(); err != nil { 66 return nil, fmt.Errorf("failed to connect to redis: %w", err) 67 } 68 69 return &RedisStore{ 70 client: client, 71 cfg: cfg, 72 }, nil 73} 74 75func (r *RedisStore) Close() error { 76 return r.client.Close() 77} 78 79func sessionKey(did syntax.DID, sessionID string) string { 80 return fmt.Sprintf("oauth:session:%s:%s", did, sessionID) 81} 82 83func sessionMetadataKey(did syntax.DID, sessionID string) string { 84 return fmt.Sprintf("oauth:session_meta:%s:%s", did, sessionID) 85} 86 87func authRequestKey(state string) string { 88 return fmt.Sprintf("oauth:auth_request:%s", state) 89} 90 91func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 92 key := sessionKey(did, sessionID) 93 metaKey := sessionMetadataKey(did, sessionID) 94 95 // Check metadata for inactivity expiry 96 metaData, err := r.client.Get(ctx, metaKey).Bytes() 97 if err == redis.Nil { 98 return nil, fmt.Errorf("session not found: %s", did) 99 } 100 if err != nil { 101 return nil, fmt.Errorf("failed to get session metadata: %w", err) 102 } 103 104 var meta sessionMetadata 105 if err := json.Unmarshal(metaData, &meta); err != nil { 106 return nil, fmt.Errorf("failed to unmarshal session metadata: %w", err) 107 } 108 109 // Check if session has been inactive for too long 110 inactiveThreshold := time.Now().Add(-r.cfg.SessionInactivityDuration) 111 if meta.UpdatedAt.Before(inactiveThreshold) { 112 // Session is inactive, delete it 113 r.client.Del(ctx, key, metaKey) 114 return nil, fmt.Errorf("session expired due to inactivity: %s", did) 115 } 116 117 // Get the actual session data 118 data, err := r.client.Get(ctx, key).Bytes() 119 if err == redis.Nil { 120 return nil, fmt.Errorf("session not found: %s", did) 121 } 122 if err != nil { 123 return nil, fmt.Errorf("failed to get session: %w", err) 124 } 125 126 var sess oauth.ClientSessionData 127 if err := json.Unmarshal(data, &sess); err != nil { 128 return nil, fmt.Errorf("failed to unmarshal session: %w", err) 129 } 130 131 return &sess, nil 132} 133 134func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 135 key := sessionKey(sess.AccountDID, sess.SessionID) 136 metaKey := sessionMetadataKey(sess.AccountDID, sess.SessionID) 137 138 data, err := json.Marshal(sess) 139 if err != nil { 140 return fmt.Errorf("failed to marshal session: %w", err) 141 } 142 143 // Check if session already exists to preserve CreatedAt 144 var meta sessionMetadata 145 existingMetaData, err := r.client.Get(ctx, metaKey).Bytes() 146 if err == redis.Nil { 147 // New session 148 meta = sessionMetadata{ 149 CreatedAt: time.Now(), 150 UpdatedAt: time.Now(), 151 } 152 } else if err != nil { 153 return fmt.Errorf("failed to check existing session metadata: %w", err) 154 } else { 155 // Existing session - preserve CreatedAt, update UpdatedAt 156 if err := json.Unmarshal(existingMetaData, &meta); err != nil { 157 return fmt.Errorf("failed to unmarshal existing session metadata: %w", err) 158 } 159 meta.UpdatedAt = time.Now() 160 } 161 162 // Calculate remaining TTL based on creation time 163 remainingTTL := r.cfg.SessionExpiryDuration - time.Since(meta.CreatedAt) 164 if remainingTTL <= 0 { 165 return fmt.Errorf("session has expired") 166 } 167 168 // Use the shorter of: remaining TTL or inactivity duration 169 ttl := min(r.cfg.SessionInactivityDuration, remainingTTL) 170 171 // Save session data 172 if err := r.client.Set(ctx, key, data, ttl).Err(); err != nil { 173 return fmt.Errorf("failed to save session: %w", err) 174 } 175 176 // Save metadata 177 metaData, err := json.Marshal(meta) 178 if err != nil { 179 return fmt.Errorf("failed to marshal session metadata: %w", err) 180 } 181 if err := r.client.Set(ctx, metaKey, metaData, ttl).Err(); err != nil { 182 return fmt.Errorf("failed to save session metadata: %w", err) 183 } 184 185 return nil 186} 187 188func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 189 key := sessionKey(did, sessionID) 190 metaKey := sessionMetadataKey(did, sessionID) 191 192 if err := r.client.Del(ctx, key, metaKey).Err(); err != nil { 193 return fmt.Errorf("failed to delete session: %w", err) 194 } 195 return nil 196} 197 198func (r *RedisStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 199 key := authRequestKey(state) 200 data, err := r.client.Get(ctx, key).Bytes() 201 if err == redis.Nil { 202 return nil, fmt.Errorf("request info not found: %s", state) 203 } 204 if err != nil { 205 return nil, fmt.Errorf("failed to get auth request: %w", err) 206 } 207 208 var req oauth.AuthRequestData 209 if err := json.Unmarshal(data, &req); err != nil { 210 return nil, fmt.Errorf("failed to unmarshal auth request: %w", err) 211 } 212 213 return &req, nil 214} 215 216func (r *RedisStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 217 key := authRequestKey(info.State) 218 219 // check if already exists (to match MemStore behavior) 220 exists, err := r.client.Exists(ctx, key).Result() 221 if err != nil { 222 return fmt.Errorf("failed to check auth request existence: %w", err) 223 } 224 if exists > 0 { 225 return fmt.Errorf("auth request already saved for state %s", info.State) 226 } 227 228 data, err := json.Marshal(info) 229 if err != nil { 230 return fmt.Errorf("failed to marshal auth request: %w", err) 231 } 232 233 if err := r.client.Set(ctx, key, data, r.cfg.AuthRequestExpiryDuration).Err(); err != nil { 234 return fmt.Errorf("failed to save auth request: %w", err) 235 } 236 237 return nil 238} 239 240func (r *RedisStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 241 key := authRequestKey(state) 242 if err := r.client.Del(ctx, key).Err(); err != nil { 243 return fmt.Errorf("failed to delete auth request: %w", err) 244 } 245 return nil 246}