forked from tangled.org/core
this repo has no description
1package jetstream 2 3import ( 4 "context" 5 "fmt" 6 "log/slog" 7 "sync" 8 "time" 9 10 "github.com/bluesky-social/jetstream/pkg/client" 11 "github.com/bluesky-social/jetstream/pkg/client/schedulers/sequential" 12 "github.com/bluesky-social/jetstream/pkg/models" 13 "github.com/sotangled/tangled/log" 14) 15 16type DB interface { 17 GetLastTimeUs() (int64, error) 18 SaveLastTimeUs(int64) error 19 UpdateLastTimeUs(int64) error 20} 21 22type JetstreamSubscriber struct { 23 client *client.Client 24 cancel context.CancelFunc 25 dids []string 26 ident string 27 running bool 28} 29 30type JetstreamClient struct { 31 cfg *client.ClientConfig 32 baseIdent string 33 l *slog.Logger 34 db DB 35 waitForDid bool 36 maxDidsPerSubscriber int 37 38 mu sync.RWMutex 39 subscribers []*JetstreamSubscriber 40 processFunc func(context.Context, *models.Event) error 41 subscriberWg sync.WaitGroup 42} 43 44func (j *JetstreamClient) AddDid(did string) { 45 if did == "" { 46 return 47 } 48 j.mu.Lock() 49 defer j.mu.Unlock() 50 51 // Just add to the config for now, actual subscriber management happens in UpdateDids 52 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 53} 54 55func (j *JetstreamClient) UpdateDids(dids []string) { 56 j.mu.Lock() 57 for _, did := range dids { 58 if did != "" { 59 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 60 } 61 } 62 63 needRebalance := j.processFunc != nil 64 j.mu.Unlock() 65 66 if needRebalance { 67 j.rebalanceSubscribers() 68 } 69} 70 71func NewJetstreamClient(endpoint, ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) { 72 if cfg == nil { 73 cfg = client.DefaultClientConfig() 74 cfg.WebsocketURL = endpoint 75 cfg.WantedCollections = collections 76 } 77 78 return &JetstreamClient{ 79 cfg: cfg, 80 baseIdent: ident, 81 db: db, 82 l: logger, 83 waitForDid: waitForDid, 84 subscribers: make([]*JetstreamSubscriber, 0), 85 maxDidsPerSubscriber: 100, 86 }, nil 87} 88 89// StartJetstream starts the jetstream client and processes events using the provided processFunc. 90// The caller is responsible for saving the last time_us to the database (just use your db.SaveLastTimeUs). 91func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error { 92 j.mu.Lock() 93 j.processFunc = processFunc 94 j.mu.Unlock() 95 96 if j.waitForDid { 97 // Start a goroutine to wait for DIDs and then start subscribers 98 go func() { 99 for { 100 j.mu.RLock() 101 hasDids := len(j.cfg.WantedDids) > 0 102 j.mu.RUnlock() 103 104 if hasDids { 105 j.l.Info("done waiting for did, starting subscribers") 106 j.rebalanceSubscribers() 107 return 108 } 109 time.Sleep(time.Second) 110 } 111 }() 112 } else { 113 // Start subscribers immediately 114 j.rebalanceSubscribers() 115 } 116 117 return nil 118} 119 120// rebalanceSubscribers creates, updates, or removes subscribers based on the current list of DIDs 121func (j *JetstreamClient) rebalanceSubscribers() { 122 j.mu.Lock() 123 defer j.mu.Unlock() 124 125 if j.processFunc == nil { 126 j.l.Warn("cannot rebalance subscribers without a process function") 127 return 128 } 129 130 // calculate how many subscribers we need 131 totalDids := len(j.cfg.WantedDids) 132 subscribersNeeded := (totalDids + j.maxDidsPerSubscriber - 1) / j.maxDidsPerSubscriber // ceiling division 133 134 // first case: no subscribers yet; create all needed subscribers 135 if len(j.subscribers) == 0 { 136 for i := range subscribersNeeded { 137 startIdx := i * j.maxDidsPerSubscriber 138 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids) 139 140 subscriberDids := j.cfg.WantedDids[startIdx:endIdx] 141 142 subCfg := *j.cfg 143 subCfg.WantedDids = subscriberDids 144 145 ident := fmt.Sprintf("%s-%d", j.baseIdent, i) 146 subscriber := &JetstreamSubscriber{ 147 dids: subscriberDids, 148 ident: ident, 149 } 150 j.subscribers = append(j.subscribers, subscriber) 151 152 j.subscriberWg.Add(1) 153 go j.startSubscriber(subscriber, &subCfg) 154 } 155 return 156 } 157 158 // second case: we have more subscribers than needed, stop extra subscribers 159 if len(j.subscribers) > subscribersNeeded { 160 for i := subscribersNeeded; i < len(j.subscribers); i++ { 161 sub := j.subscribers[i] 162 if sub.running && sub.cancel != nil { 163 sub.cancel() 164 sub.running = false 165 } 166 } 167 j.subscribers = j.subscribers[:subscribersNeeded] 168 } 169 170 // third case: we need more subscribers 171 if len(j.subscribers) < subscribersNeeded { 172 existingCount := len(j.subscribers) 173 // Create additional subscribers 174 for i := existingCount; i < subscribersNeeded; i++ { 175 startIdx := i * j.maxDidsPerSubscriber 176 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids) 177 178 subscriberDids := j.cfg.WantedDids[startIdx:endIdx] 179 180 subCfg := *j.cfg 181 subCfg.WantedDids = subscriberDids 182 183 ident := fmt.Sprintf("%s-%d", j.baseIdent, i) 184 subscriber := &JetstreamSubscriber{ 185 dids: subscriberDids, 186 ident: ident, 187 } 188 j.subscribers = append(j.subscribers, subscriber) 189 190 j.subscriberWg.Add(1) 191 go j.startSubscriber(subscriber, &subCfg) 192 } 193 } 194 195 // fourth case: update existing subscribers with new wantedDids 196 for i := 0; i < subscribersNeeded && i < len(j.subscribers); i++ { 197 startIdx := i * j.maxDidsPerSubscriber 198 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids) 199 newDids := j.cfg.WantedDids[startIdx:endIdx] 200 201 // if the dids for this subscriber have changed, restart it 202 sub := j.subscribers[i] 203 if !didSlicesEqual(sub.dids, newDids) { 204 j.l.Info("subscriber DIDs changed, updating", 205 "subscriber", sub.ident, 206 "old_count", len(sub.dids), 207 "new_count", len(newDids)) 208 209 if sub.running && sub.cancel != nil { 210 sub.cancel() 211 sub.running = false 212 } 213 214 subCfg := *j.cfg 215 subCfg.WantedDids = newDids 216 217 sub.dids = newDids 218 219 j.subscriberWg.Add(1) 220 go j.startSubscriber(sub, &subCfg) 221 } 222 } 223} 224 225func didSlicesEqual(a, b []string) bool { 226 if len(a) != len(b) { 227 return false 228 } 229 230 aMap := make(map[string]struct{}, len(a)) 231 for _, did := range a { 232 aMap[did] = struct{}{} 233 } 234 235 for _, did := range b { 236 if _, exists := aMap[did]; !exists { 237 return false 238 } 239 } 240 241 return true 242} 243 244// startSubscriber initializes and starts a single subscriber 245func (j *JetstreamClient) startSubscriber(sub *JetstreamSubscriber, cfg *client.ClientConfig) { 246 defer j.subscriberWg.Done() 247 248 logger := j.l.With("subscriber", sub.ident) 249 logger.Info("starting subscriber", "dids_count", len(sub.dids)) 250 251 sched := sequential.NewScheduler(sub.ident, logger, j.processFunc) 252 253 client, err := client.NewClient(cfg, log.New("jetstream-"+sub.ident), sched) 254 if err != nil { 255 logger.Error("failed to create jetstream client", "error", err) 256 return 257 } 258 259 sub.client = client 260 261 j.mu.Lock() 262 sub.running = true 263 j.mu.Unlock() 264 265 j.connectAndReadForSubscriber(sub) 266} 267 268func (j *JetstreamClient) connectAndReadForSubscriber(sub *JetstreamSubscriber) { 269 ctx := context.Background() 270 l := j.l.With("subscriber", sub.ident) 271 272 for { 273 // Check if this subscriber should still be running 274 j.mu.RLock() 275 running := sub.running 276 j.mu.RUnlock() 277 278 if !running { 279 l.Info("subscriber marked for shutdown") 280 return 281 } 282 283 cursor := j.getLastTimeUs(ctx) 284 285 connCtx, cancel := context.WithCancel(ctx) 286 287 j.mu.Lock() 288 sub.cancel = cancel 289 j.mu.Unlock() 290 291 l.Info("connecting subscriber to jetstream") 292 if err := sub.client.ConnectAndRead(connCtx, cursor); err != nil { 293 l.Error("error reading jetstream", "error", err) 294 cancel() 295 time.Sleep(time.Second) // Small backoff before retry 296 continue 297 } 298 299 select { 300 case <-ctx.Done(): 301 l.Info("context done, stopping subscriber") 302 return 303 case <-connCtx.Done(): 304 l.Info("connection context done, reconnecting") 305 continue 306 } 307 } 308} 309 310// GetRunningSubscribersCount returns the total number of currently running subscribers 311func (j *JetstreamClient) GetRunningSubscribersCount() int { 312 j.mu.RLock() 313 defer j.mu.RUnlock() 314 315 runningCount := 0 316 for _, sub := range j.subscribers { 317 if sub.running { 318 runningCount++ 319 } 320 } 321 322 return runningCount 323} 324 325// Shutdown gracefully stops all subscribers 326func (j *JetstreamClient) Shutdown() { 327 j.mu.Lock() 328 329 // Cancel all subscribers 330 for _, sub := range j.subscribers { 331 if sub.running && sub.cancel != nil { 332 sub.cancel() 333 sub.running = false 334 } 335 } 336 337 j.mu.Unlock() 338 339 // Wait for all subscribers to complete 340 j.subscriberWg.Wait() 341 j.l.Info("all subscribers shut down", "total_subscribers", len(j.subscribers), "running_subscribers", j.GetRunningSubscribersCount()) 342} 343 344func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 { 345 l := log.FromContext(ctx) 346 lastTimeUs, err := j.db.GetLastTimeUs() 347 if err != nil { 348 l.Warn("couldn't get last time us, starting from now", "error", err) 349 lastTimeUs = time.Now().UnixMicro() 350 err = j.db.SaveLastTimeUs(lastTimeUs) 351 if err != nil { 352 l.Error("failed to save last time us", "error", err) 353 } 354 } 355 356 // If last time is older than 2 days, start from now 357 if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 { 358 lastTimeUs = time.Now().UnixMicro() 359 l.Warn("last time us is older than 2 days; discarding that and starting from now") 360 err = j.db.UpdateLastTimeUs(lastTimeUs) 361 if err != nil { 362 l.Error("failed to save last time us", "error", err) 363 } 364 } 365 366 l.Info("found last time_us", "time_us", lastTimeUs, "running_subscribers", j.GetRunningSubscribersCount()) 367 return &lastTimeUs 368}