forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "time"
8
9 "github.com/bluesky-social/indigo/atproto/auth/oauth"
10 "github.com/bluesky-social/indigo/atproto/syntax"
11 "github.com/redis/go-redis/v9"
12)
13
// RedisStoreConfig configures a RedisStore. Every field is required:
// NewRedisStore rejects a config with an empty RedisURL or a zero duration.
type RedisStoreConfig struct {
	// RedisURL is the connection URL handed to redis.ParseURL.
	RedisURL string

	// The purpose of these limits is to avoid dead sessions hanging around in the db indefinitely.
	// The durations here should be *at least as long as* the expected duration of the oauth session itself.
	SessionExpiryDuration     time.Duration // duration since session creation (max TTL)
	SessionInactivityDuration time.Duration // duration since last session update
	AuthRequestExpiryDuration time.Duration // duration since auth request creation
}
23
24// redis-backed implementation of ClientAuthStore.
25type RedisStore struct {
26 client *redis.Client
27 cfg *RedisStoreConfig
28}
29
30var _ oauth.ClientAuthStore = &RedisStore{}
31
// sessionMetadata is stored next to each session (under sessionMetadataKey)
// and tracks its lifetime: CreatedAt bounds the session's absolute age
// against SessionExpiryDuration, and UpdatedAt drives inactivity expiry
// against SessionInactivityDuration.
type sessionMetadata struct {
	CreatedAt time.Time `json:"created_at"` // set when the session is first saved, preserved afterwards
	UpdatedAt time.Time `json:"updated_at"` // refreshed on every SaveSession
}
36
37func NewRedisStore(cfg *RedisStoreConfig) (*RedisStore, error) {
38 if cfg == nil {
39 return nil, fmt.Errorf("missing cfg")
40 }
41 if cfg.RedisURL == "" {
42 return nil, fmt.Errorf("missing RedisURL")
43 }
44 if cfg.SessionExpiryDuration == 0 {
45 return nil, fmt.Errorf("missing SessionExpiryDuration")
46 }
47 if cfg.SessionInactivityDuration == 0 {
48 return nil, fmt.Errorf("missing SessionInactivityDuration")
49 }
50 if cfg.AuthRequestExpiryDuration == 0 {
51 return nil, fmt.Errorf("missing AuthRequestExpiryDuration")
52 }
53
54 opts, err := redis.ParseURL(cfg.RedisURL)
55 if err != nil {
56 return nil, fmt.Errorf("failed to parse redis URL: %w", err)
57 }
58
59 client := redis.NewClient(opts)
60
61 // test the connection
62 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
63 defer cancel()
64
65 if err := client.Ping(ctx).Err(); err != nil {
66 return nil, fmt.Errorf("failed to connect to redis: %w", err)
67 }
68
69 return &RedisStore{
70 client: client,
71 cfg: cfg,
72 }, nil
73}
74
75func (r *RedisStore) Close() error {
76 return r.client.Close()
77}
78
79func sessionKey(did syntax.DID, sessionID string) string {
80 return fmt.Sprintf("oauth:session:%s:%s", did, sessionID)
81}
82
83func sessionMetadataKey(did syntax.DID, sessionID string) string {
84 return fmt.Sprintf("oauth:session_meta:%s:%s", did, sessionID)
85}
86
// authRequestKey names the redis key for the pending auth request identified
// by the OAuth state parameter: "oauth:auth_request:<state>".
func authRequestKey(state string) string {
	// Both operands are strings, so Sprint inserts no separator between them.
	return fmt.Sprint("oauth:auth_request:", state)
}
90
91func (r *RedisStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
92 key := sessionKey(did, sessionID)
93 metaKey := sessionMetadataKey(did, sessionID)
94
95 // Check metadata for inactivity expiry
96 metaData, err := r.client.Get(ctx, metaKey).Bytes()
97 if err == redis.Nil {
98 return nil, fmt.Errorf("session not found: %s", did)
99 }
100 if err != nil {
101 return nil, fmt.Errorf("failed to get session metadata: %w", err)
102 }
103
104 var meta sessionMetadata
105 if err := json.Unmarshal(metaData, &meta); err != nil {
106 return nil, fmt.Errorf("failed to unmarshal session metadata: %w", err)
107 }
108
109 // Check if session has been inactive for too long
110 inactiveThreshold := time.Now().Add(-r.cfg.SessionInactivityDuration)
111 if meta.UpdatedAt.Before(inactiveThreshold) {
112 // Session is inactive, delete it
113 r.client.Del(ctx, key, metaKey)
114 return nil, fmt.Errorf("session expired due to inactivity: %s", did)
115 }
116
117 // Get the actual session data
118 data, err := r.client.Get(ctx, key).Bytes()
119 if err == redis.Nil {
120 return nil, fmt.Errorf("session not found: %s", did)
121 }
122 if err != nil {
123 return nil, fmt.Errorf("failed to get session: %w", err)
124 }
125
126 var sess oauth.ClientSessionData
127 if err := json.Unmarshal(data, &sess); err != nil {
128 return nil, fmt.Errorf("failed to unmarshal session: %w", err)
129 }
130
131 return &sess, nil
132}
133
134func (r *RedisStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
135 key := sessionKey(sess.AccountDID, sess.SessionID)
136 metaKey := sessionMetadataKey(sess.AccountDID, sess.SessionID)
137
138 data, err := json.Marshal(sess)
139 if err != nil {
140 return fmt.Errorf("failed to marshal session: %w", err)
141 }
142
143 // Check if session already exists to preserve CreatedAt
144 var meta sessionMetadata
145 existingMetaData, err := r.client.Get(ctx, metaKey).Bytes()
146 if err == redis.Nil {
147 // New session
148 meta = sessionMetadata{
149 CreatedAt: time.Now(),
150 UpdatedAt: time.Now(),
151 }
152 } else if err != nil {
153 return fmt.Errorf("failed to check existing session metadata: %w", err)
154 } else {
155 // Existing session - preserve CreatedAt, update UpdatedAt
156 if err := json.Unmarshal(existingMetaData, &meta); err != nil {
157 return fmt.Errorf("failed to unmarshal existing session metadata: %w", err)
158 }
159 meta.UpdatedAt = time.Now()
160 }
161
162 // Calculate remaining TTL based on creation time
163 remainingTTL := r.cfg.SessionExpiryDuration - time.Since(meta.CreatedAt)
164 if remainingTTL <= 0 {
165 return fmt.Errorf("session has expired")
166 }
167
168 // Use the shorter of: remaining TTL or inactivity duration
169 ttl := min(r.cfg.SessionInactivityDuration, remainingTTL)
170
171 // Save session data
172 if err := r.client.Set(ctx, key, data, ttl).Err(); err != nil {
173 return fmt.Errorf("failed to save session: %w", err)
174 }
175
176 // Save metadata
177 metaData, err := json.Marshal(meta)
178 if err != nil {
179 return fmt.Errorf("failed to marshal session metadata: %w", err)
180 }
181 if err := r.client.Set(ctx, metaKey, metaData, ttl).Err(); err != nil {
182 return fmt.Errorf("failed to save session metadata: %w", err)
183 }
184
185 return nil
186}
187
188func (r *RedisStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
189 key := sessionKey(did, sessionID)
190 metaKey := sessionMetadataKey(did, sessionID)
191
192 if err := r.client.Del(ctx, key, metaKey).Err(); err != nil {
193 return fmt.Errorf("failed to delete session: %w", err)
194 }
195 return nil
196}
197
198func (r *RedisStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
199 key := authRequestKey(state)
200 data, err := r.client.Get(ctx, key).Bytes()
201 if err == redis.Nil {
202 return nil, fmt.Errorf("request info not found: %s", state)
203 }
204 if err != nil {
205 return nil, fmt.Errorf("failed to get auth request: %w", err)
206 }
207
208 var req oauth.AuthRequestData
209 if err := json.Unmarshal(data, &req); err != nil {
210 return nil, fmt.Errorf("failed to unmarshal auth request: %w", err)
211 }
212
213 return &req, nil
214}
215
216func (r *RedisStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
217 key := authRequestKey(info.State)
218
219 // check if already exists (to match MemStore behavior)
220 exists, err := r.client.Exists(ctx, key).Result()
221 if err != nil {
222 return fmt.Errorf("failed to check auth request existence: %w", err)
223 }
224 if exists > 0 {
225 return fmt.Errorf("auth request already saved for state %s", info.State)
226 }
227
228 data, err := json.Marshal(info)
229 if err != nil {
230 return fmt.Errorf("failed to marshal auth request: %w", err)
231 }
232
233 if err := r.client.Set(ctx, key, data, r.cfg.AuthRequestExpiryDuration).Err(); err != nil {
234 return fmt.Errorf("failed to save auth request: %w", err)
235 }
236
237 return nil
238}
239
240func (r *RedisStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
241 key := authRequestKey(state)
242 if err := r.client.Del(ctx, key).Err(); err != nil {
243 return fmt.Errorf("failed to delete auth request: %w", err)
244 }
245 return nil
246}