1package jetstream
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "sync"
8 "time"
9
10 "github.com/bluesky-social/jetstream/pkg/client"
11 "github.com/bluesky-social/jetstream/pkg/client/schedulers/sequential"
12 "github.com/bluesky-social/jetstream/pkg/models"
13 "github.com/sotangled/tangled/log"
14)
15
// DB stores the time_us cursor that subscribers resume from.
//
// Within this package, getLastTimeUs reads the cursor with GetLastTimeUs,
// writes it with SaveLastTimeUs when that read fails, and overwrites it with
// UpdateLastTimeUs when the stored value is too old.
type DB interface {
	// GetLastTimeUs returns the stored cursor.
	GetLastTimeUs() (timeUs int64, err error)
	// UpdateLastTimeUs overwrites the stored cursor.
	UpdateLastTimeUs(timeUs int64) error
	// SaveLastTimeUs stores the cursor.
	SaveLastTimeUs(timeUs int64) error
}
21
// JetstreamSubscriber is a single jetstream connection that serves one chunk
// of the owning JetstreamClient's wanted DIDs.
//
// running and cancel are read and written under the owning JetstreamClient's
// mu.
type JetstreamSubscriber struct {
	client  *client.Client     // jetstream client created by startSubscriber
	cancel  context.CancelFunc // cancels the current connection attempt; nil until the first connect
	dids    []string           // chunk of DIDs assigned to this subscriber
	ident   string             // "<baseIdent>-<index>", used in logs and as the scheduler ident
	running bool               // cleared to ask the subscriber's read loop to stop
}
29
// JetstreamClient spreads the DIDs in cfg.WantedDids across several
// JetstreamSubscribers, each handling at most maxDidsPerSubscriber DIDs.
type JetstreamClient struct {
	cfg                  *client.ClientConfig // base config; WantedDids holds the full DID list
	baseIdent            string               // prefix for subscriber idents
	l                    *slog.Logger
	db                   DB   // storage for the time_us cursor
	waitForDid           bool // if set, StartJetstream delays subscribers until a DID is wanted
	maxDidsPerSubscriber int  // chunk size used by rebalanceSubscribers

	// mu guards cfg.WantedDids, subscribers, processFunc and the subscribers'
	// running/cancel fields.
	mu           sync.RWMutex
	subscribers  []*JetstreamSubscriber
	processFunc  func(context.Context, *models.Event) error // event handler set by StartJetstream
	subscriberWg sync.WaitGroup                             // tracks startSubscriber goroutines; waited on by Shutdown
}
43
44func (j *JetstreamClient) AddDid(did string) {
45 if did == "" {
46 return
47 }
48 j.mu.Lock()
49 defer j.mu.Unlock()
50
51 // Just add to the config for now, actual subscriber management happens in UpdateDids
52 j.cfg.WantedDids = append(j.cfg.WantedDids, did)
53}
54
55func (j *JetstreamClient) UpdateDids(dids []string) {
56 j.mu.Lock()
57 for _, did := range dids {
58 if did != "" {
59 j.cfg.WantedDids = append(j.cfg.WantedDids, did)
60 }
61 }
62
63 needRebalance := j.processFunc != nil
64 j.mu.Unlock()
65
66 if needRebalance {
67 j.rebalanceSubscribers()
68 }
69}
70
71func NewJetstreamClient(endpoint, ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) {
72 if cfg == nil {
73 cfg = client.DefaultClientConfig()
74 cfg.WebsocketURL = endpoint
75 cfg.WantedCollections = collections
76 }
77
78 return &JetstreamClient{
79 cfg: cfg,
80 baseIdent: ident,
81 db: db,
82 l: logger,
83 waitForDid: waitForDid,
84 subscribers: make([]*JetstreamSubscriber, 0),
85 maxDidsPerSubscriber: 100,
86 }, nil
87}
88
89// StartJetstream starts the jetstream client and processes events using the provided processFunc.
90// The caller is responsible for saving the last time_us to the database (just use your db.SaveLastTimeUs).
91func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error {
92 j.mu.Lock()
93 j.processFunc = processFunc
94 j.mu.Unlock()
95
96 if j.waitForDid {
97 // Start a goroutine to wait for DIDs and then start subscribers
98 go func() {
99 for {
100 j.mu.RLock()
101 hasDids := len(j.cfg.WantedDids) > 0
102 j.mu.RUnlock()
103
104 if hasDids {
105 j.l.Info("done waiting for did, starting subscribers")
106 j.rebalanceSubscribers()
107 return
108 }
109 time.Sleep(time.Second)
110 }
111 }()
112 } else {
113 // Start subscribers immediately
114 j.rebalanceSubscribers()
115 }
116
117 return nil
118}
119
120// rebalanceSubscribers creates, updates, or removes subscribers based on the current list of DIDs
121func (j *JetstreamClient) rebalanceSubscribers() {
122 j.mu.Lock()
123 defer j.mu.Unlock()
124
125 if j.processFunc == nil {
126 j.l.Warn("cannot rebalance subscribers without a process function")
127 return
128 }
129
130 // calculate how many subscribers we need
131 totalDids := len(j.cfg.WantedDids)
132 subscribersNeeded := (totalDids + j.maxDidsPerSubscriber - 1) / j.maxDidsPerSubscriber // ceiling division
133
134 // first case: no subscribers yet; create all needed subscribers
135 if len(j.subscribers) == 0 {
136 for i := range subscribersNeeded {
137 startIdx := i * j.maxDidsPerSubscriber
138 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids)
139
140 subscriberDids := j.cfg.WantedDids[startIdx:endIdx]
141
142 subCfg := *j.cfg
143 subCfg.WantedDids = subscriberDids
144
145 ident := fmt.Sprintf("%s-%d", j.baseIdent, i)
146 subscriber := &JetstreamSubscriber{
147 dids: subscriberDids,
148 ident: ident,
149 }
150 j.subscribers = append(j.subscribers, subscriber)
151
152 j.subscriberWg.Add(1)
153 go j.startSubscriber(subscriber, &subCfg)
154 }
155 return
156 }
157
158 // second case: we have more subscribers than needed, stop extra subscribers
159 if len(j.subscribers) > subscribersNeeded {
160 for i := subscribersNeeded; i < len(j.subscribers); i++ {
161 sub := j.subscribers[i]
162 if sub.running && sub.cancel != nil {
163 sub.cancel()
164 sub.running = false
165 }
166 }
167 j.subscribers = j.subscribers[:subscribersNeeded]
168 }
169
170 // third case: we need more subscribers
171 if len(j.subscribers) < subscribersNeeded {
172 existingCount := len(j.subscribers)
173 // Create additional subscribers
174 for i := existingCount; i < subscribersNeeded; i++ {
175 startIdx := i * j.maxDidsPerSubscriber
176 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids)
177
178 subscriberDids := j.cfg.WantedDids[startIdx:endIdx]
179
180 subCfg := *j.cfg
181 subCfg.WantedDids = subscriberDids
182
183 ident := fmt.Sprintf("%s-%d", j.baseIdent, i)
184 subscriber := &JetstreamSubscriber{
185 dids: subscriberDids,
186 ident: ident,
187 }
188 j.subscribers = append(j.subscribers, subscriber)
189
190 j.subscriberWg.Add(1)
191 go j.startSubscriber(subscriber, &subCfg)
192 }
193 }
194
195 // fourth case: update existing subscribers with new wantedDids
196 for i := 0; i < subscribersNeeded && i < len(j.subscribers); i++ {
197 startIdx := i * j.maxDidsPerSubscriber
198 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids)
199 newDids := j.cfg.WantedDids[startIdx:endIdx]
200
201 // if the dids for this subscriber have changed, restart it
202 sub := j.subscribers[i]
203 if !didSlicesEqual(sub.dids, newDids) {
204 j.l.Info("subscriber DIDs changed, updating",
205 "subscriber", sub.ident,
206 "old_count", len(sub.dids),
207 "new_count", len(newDids))
208
209 if sub.running && sub.cancel != nil {
210 sub.cancel()
211 sub.running = false
212 }
213
214 subCfg := *j.cfg
215 subCfg.WantedDids = newDids
216
217 sub.dids = newDids
218
219 j.subscriberWg.Add(1)
220 go j.startSubscriber(sub, &subCfg)
221 }
222 }
223}
224
// didSlicesEqual reports whether a and b hold the same DIDs with the same
// multiplicities, ignoring order.
//
// Counting occurrences, rather than checking set membership, keeps inputs
// like ["x", "y"] and ["x", "x"] from comparing equal.
func didSlicesEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	counts := make(map[string]int, len(a))
	for _, did := range a {
		counts[did]++
	}

	for _, did := range b {
		if counts[did] == 0 {
			return false
		}
		counts[did]--
	}

	// Equal lengths and no count going negative mean every count is now zero.
	return true
}
243
244// startSubscriber initializes and starts a single subscriber
245func (j *JetstreamClient) startSubscriber(sub *JetstreamSubscriber, cfg *client.ClientConfig) {
246 defer j.subscriberWg.Done()
247
248 logger := j.l.With("subscriber", sub.ident)
249 logger.Info("starting subscriber", "dids_count", len(sub.dids))
250
251 sched := sequential.NewScheduler(sub.ident, logger, j.processFunc)
252
253 client, err := client.NewClient(cfg, log.New("jetstream-"+sub.ident), sched)
254 if err != nil {
255 logger.Error("failed to create jetstream client", "error", err)
256 return
257 }
258
259 sub.client = client
260
261 j.mu.Lock()
262 sub.running = true
263 j.mu.Unlock()
264
265 j.connectAndReadForSubscriber(sub)
266}
267
268func (j *JetstreamClient) connectAndReadForSubscriber(sub *JetstreamSubscriber) {
269 ctx := context.Background()
270 l := j.l.With("subscriber", sub.ident)
271
272 for {
273 // Check if this subscriber should still be running
274 j.mu.RLock()
275 running := sub.running
276 j.mu.RUnlock()
277
278 if !running {
279 l.Info("subscriber marked for shutdown")
280 return
281 }
282
283 cursor := j.getLastTimeUs(ctx)
284
285 connCtx, cancel := context.WithCancel(ctx)
286
287 j.mu.Lock()
288 sub.cancel = cancel
289 j.mu.Unlock()
290
291 l.Info("connecting subscriber to jetstream")
292 if err := sub.client.ConnectAndRead(connCtx, cursor); err != nil {
293 l.Error("error reading jetstream", "error", err)
294 cancel()
295 time.Sleep(time.Second) // Small backoff before retry
296 continue
297 }
298
299 select {
300 case <-ctx.Done():
301 l.Info("context done, stopping subscriber")
302 return
303 case <-connCtx.Done():
304 l.Info("connection context done, reconnecting")
305 continue
306 }
307 }
308}
309
310// GetRunningSubscribersCount returns the total number of currently running subscribers
311func (j *JetstreamClient) GetRunningSubscribersCount() int {
312 j.mu.RLock()
313 defer j.mu.RUnlock()
314
315 runningCount := 0
316 for _, sub := range j.subscribers {
317 if sub.running {
318 runningCount++
319 }
320 }
321
322 return runningCount
323}
324
325// Shutdown gracefully stops all subscribers
326func (j *JetstreamClient) Shutdown() {
327 j.mu.Lock()
328
329 // Cancel all subscribers
330 for _, sub := range j.subscribers {
331 if sub.running && sub.cancel != nil {
332 sub.cancel()
333 sub.running = false
334 }
335 }
336
337 j.mu.Unlock()
338
339 // Wait for all subscribers to complete
340 j.subscriberWg.Wait()
341 j.l.Info("all subscribers shut down", "total_subscribers", len(j.subscribers), "running_subscribers", j.GetRunningSubscribersCount())
342}
343
344func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 {
345 l := log.FromContext(ctx)
346 lastTimeUs, err := j.db.GetLastTimeUs()
347 if err != nil {
348 l.Warn("couldn't get last time us, starting from now", "error", err)
349 lastTimeUs = time.Now().UnixMicro()
350 err = j.db.SaveLastTimeUs(lastTimeUs)
351 if err != nil {
352 l.Error("failed to save last time us", "error", err)
353 }
354 }
355
356 // If last time is older than 2 days, start from now
357 if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 {
358 lastTimeUs = time.Now().UnixMicro()
359 l.Warn("last time us is older than 2 days; discarding that and starting from now")
360 err = j.db.UpdateLastTimeUs(lastTimeUs)
361 if err != nil {
362 l.Error("failed to save last time us", "error", err)
363 }
364 }
365
366 l.Info("found last time_us", "time_us", lastTimeUs, "running_subscribers", j.GetRunningSubscribersCount())
367 return &lastTimeUs
368}