1package jetstream
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "sync"
8 "time"
9
10 "github.com/bluesky-social/jetstream/pkg/client"
11 "github.com/bluesky-social/jetstream/pkg/client/schedulers/sequential"
12 "github.com/bluesky-social/jetstream/pkg/models"
13 "github.com/sotangled/tangled/log"
14)
15
// DB is the storage the client uses to persist its Jetstream cursor, a
// timestamp expressed in microseconds (time_us).
type DB interface {
	// GetLastTimeUs returns the stored cursor, or an error when it cannot be
	// read.
	GetLastTimeUs() (timeUs int64, err error)
	// SaveLastTimeUs stores timeUs; the client calls it after GetLastTimeUs
	// has failed.
	SaveLastTimeUs(timeUs int64) error
	// UpdateLastTimeUs replaces the stored cursor with timeUs; the client
	// calls it when the stored cursor is too old to resume from.
	UpdateLastTimeUs(timeUs int64) error
}
21
22type JetstreamSubscriber struct {
23 client *client.Client
24 cancel context.CancelFunc
25 dids []string
26 ident string
27 running bool
28}
29
30type JetstreamClient struct {
31 cfg *client.ClientConfig
32 baseIdent string
33 l *slog.Logger
34 db DB
35 waitForDid bool
36 maxDidsPerSubscriber int
37
38 mu sync.RWMutex
39 subscribers []*JetstreamSubscriber
40 processFunc func(context.Context, *models.Event) error
41 subscriberWg sync.WaitGroup
42}
43
44func (j *JetstreamClient) AddDid(did string) {
45 if did == "" {
46 return
47 }
48 j.mu.Lock()
49 defer j.mu.Unlock()
50
51 // Just add to the config for now, actual subscriber management happens in UpdateDids
52 j.cfg.WantedDids = append(j.cfg.WantedDids, did)
53}
54
55func (j *JetstreamClient) UpdateDids(dids []string) {
56 j.mu.Lock()
57 for _, did := range dids {
58 if did != "" {
59 j.cfg.WantedDids = append(j.cfg.WantedDids, did)
60 }
61 }
62
63 needRebalance := j.processFunc != nil
64 j.mu.Unlock()
65
66 if needRebalance {
67 j.rebalanceSubscribers()
68 }
69}
70
71func NewJetstreamClient(endpoint, ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) {
72 if cfg == nil {
73 cfg = client.DefaultClientConfig()
74 cfg.WebsocketURL = endpoint
75 cfg.WantedCollections = collections
76 }
77
78 return &JetstreamClient{
79 cfg: cfg,
80 baseIdent: ident,
81 db: db,
82 l: logger,
83 waitForDid: waitForDid,
84 subscribers: make([]*JetstreamSubscriber, 0),
85 maxDidsPerSubscriber: 100,
86 }, nil
87}
88
89// StartJetstream starts the jetstream client and processes events using the provided processFunc.
90// The caller is responsible for saving the last time_us to the database (just use your db.SaveLastTimeUs).
91func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error {
92 j.mu.Lock()
93 j.processFunc = processFunc
94 j.mu.Unlock()
95
96 if j.waitForDid {
97 // Start a goroutine to wait for DIDs and then start subscribers
98 go func() {
99 for {
100 j.mu.RLock()
101 hasDids := len(j.cfg.WantedDids) > 0
102 j.mu.RUnlock()
103
104 if hasDids {
105 j.l.Info("done waiting for did, starting subscribers")
106 j.rebalanceSubscribers()
107 return
108 }
109 time.Sleep(time.Second)
110 }
111 }()
112 } else {
113 // Start subscribers immediately
114 j.rebalanceSubscribers()
115 }
116
117 return nil
118}
119
120// rebalanceSubscribers creates, updates, or removes subscribers based on the current list of DIDs
121func (j *JetstreamClient) rebalanceSubscribers() {
122 j.mu.Lock()
123 defer j.mu.Unlock()
124
125 if j.processFunc == nil {
126 j.l.Warn("cannot rebalance subscribers without a process function")
127 return
128 }
129
130 // stop all subscribers first
131 for _, sub := range j.subscribers {
132 if sub.running && sub.cancel != nil {
133 sub.cancel()
134 sub.running = false
135 }
136 }
137
138 // calculate how many subscribers we need
139 totalDids := len(j.cfg.WantedDids)
140 subscribersNeeded := (totalDids + j.maxDidsPerSubscriber - 1) / j.maxDidsPerSubscriber // ceiling division
141
142 // create or reuse subscribers as needed
143 j.subscribers = j.subscribers[:0]
144
145 for i := range subscribersNeeded {
146 startIdx := i * j.maxDidsPerSubscriber
147 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids)
148
149 subscriberDids := j.cfg.WantedDids[startIdx:endIdx]
150
151 subCfg := *j.cfg
152 subCfg.WantedDids = subscriberDids
153
154 ident := fmt.Sprintf("%s-%d", j.baseIdent, i)
155 subscriber := &JetstreamSubscriber{
156 dids: subscriberDids,
157 ident: ident,
158 }
159 j.subscribers = append(j.subscribers, subscriber)
160
161 j.subscriberWg.Add(1)
162 go j.startSubscriber(subscriber, &subCfg)
163 }
164}
165
166// startSubscriber initializes and starts a single subscriber
167func (j *JetstreamClient) startSubscriber(sub *JetstreamSubscriber, cfg *client.ClientConfig) {
168 defer j.subscriberWg.Done()
169
170 logger := j.l.With("subscriber", sub.ident)
171 logger.Info("starting subscriber", "dids_count", len(sub.dids))
172
173 sched := sequential.NewScheduler(sub.ident, logger, j.processFunc)
174
175 client, err := client.NewClient(cfg, log.New("jetstream-"+sub.ident), sched)
176 if err != nil {
177 logger.Error("failed to create jetstream client", "error", err)
178 return
179 }
180
181 sub.client = client
182
183 j.mu.Lock()
184 sub.running = true
185 j.mu.Unlock()
186
187 j.connectAndReadForSubscriber(sub)
188}
189
190func (j *JetstreamClient) connectAndReadForSubscriber(sub *JetstreamSubscriber) {
191 ctx := context.Background()
192 l := j.l.With("subscriber", sub.ident)
193
194 for {
195 // Check if this subscriber should still be running
196 j.mu.RLock()
197 running := sub.running
198 j.mu.RUnlock()
199
200 if !running {
201 l.Info("subscriber marked for shutdown")
202 return
203 }
204
205 cursor := j.getLastTimeUs(ctx)
206
207 connCtx, cancel := context.WithCancel(ctx)
208
209 j.mu.Lock()
210 sub.cancel = cancel
211 j.mu.Unlock()
212
213 l.Info("connecting subscriber to jetstream")
214 if err := sub.client.ConnectAndRead(connCtx, cursor); err != nil {
215 l.Error("error reading jetstream", "error", err)
216 cancel()
217 time.Sleep(time.Second) // Small backoff before retry
218 continue
219 }
220
221 select {
222 case <-ctx.Done():
223 l.Info("context done, stopping subscriber")
224 return
225 case <-connCtx.Done():
226 l.Info("connection context done, reconnecting")
227 continue
228 }
229 }
230}
231
232// GetRunningSubscribersCount returns the total number of currently running subscribers
233func (j *JetstreamClient) GetRunningSubscribersCount() int {
234 j.mu.RLock()
235 defer j.mu.RUnlock()
236
237 runningCount := 0
238 for _, sub := range j.subscribers {
239 if sub.running {
240 runningCount++
241 }
242 }
243
244 return runningCount
245}
246
247// Shutdown gracefully stops all subscribers
248func (j *JetstreamClient) Shutdown() {
249 j.mu.Lock()
250
251 // Cancel all subscribers
252 for _, sub := range j.subscribers {
253 if sub.running && sub.cancel != nil {
254 sub.cancel()
255 sub.running = false
256 }
257 }
258
259 j.mu.Unlock()
260
261 // Wait for all subscribers to complete
262 j.subscriberWg.Wait()
263 j.l.Info("all subscribers shut down", "total_subscribers", len(j.subscribers), "running_subscribers", j.GetRunningSubscribersCount())
264}
265
266func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 {
267 l := log.FromContext(ctx)
268 lastTimeUs, err := j.db.GetLastTimeUs()
269 if err != nil {
270 l.Warn("couldn't get last time us, starting from now", "error", err)
271 lastTimeUs = time.Now().UnixMicro()
272 err = j.db.SaveLastTimeUs(lastTimeUs)
273 if err != nil {
274 l.Error("failed to save last time us", "error", err)
275 }
276 }
277
278 // If last time is older than 2 days, start from now
279 if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 {
280 lastTimeUs = time.Now().UnixMicro()
281 l.Warn("last time us is older than 2 days; discarding that and starting from now")
282 err = j.db.UpdateLastTimeUs(lastTimeUs)
283 if err != nil {
284 l.Error("failed to save last time us", "error", err)
285 }
286 }
287
288 l.Info("found last time_us", "time_us", lastTimeUs, "running_subscribers", j.GetRunningSubscribersCount())
289 return &lastTimeUs
290}