forked from tangled.org/core
this repo has no description
1package jetstream 2 3import ( 4 "context" 5 "fmt" 6 "log/slog" 7 "sync" 8 "time" 9 10 "github.com/bluesky-social/jetstream/pkg/client" 11 "github.com/bluesky-social/jetstream/pkg/client/schedulers/sequential" 12 "github.com/bluesky-social/jetstream/pkg/models" 13 "github.com/sotangled/tangled/log" 14) 15 16type DB interface { 17 GetLastTimeUs() (int64, error) 18 SaveLastTimeUs(int64) error 19 UpdateLastTimeUs(int64) error 20} 21 22type JetstreamSubscriber struct { 23 client *client.Client 24 cancel context.CancelFunc 25 dids []string 26 ident string 27 running bool 28} 29 30type JetstreamClient struct { 31 cfg *client.ClientConfig 32 baseIdent string 33 l *slog.Logger 34 db DB 35 waitForDid bool 36 maxDidsPerSubscriber int 37 38 mu sync.RWMutex 39 subscribers []*JetstreamSubscriber 40 processFunc func(context.Context, *models.Event) error 41 subscriberWg sync.WaitGroup 42} 43 44func (j *JetstreamClient) AddDid(did string) { 45 if did == "" { 46 return 47 } 48 j.mu.Lock() 49 defer j.mu.Unlock() 50 51 // Just add to the config for now, actual subscriber management happens in UpdateDids 52 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 53} 54 55func (j *JetstreamClient) UpdateDids(dids []string) { 56 j.mu.Lock() 57 for _, did := range dids { 58 if did != "" { 59 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 60 } 61 } 62 63 needRebalance := j.processFunc != nil 64 j.mu.Unlock() 65 66 if needRebalance { 67 j.rebalanceSubscribers() 68 } 69} 70 71func NewJetstreamClient(endpoint, ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) { 72 if cfg == nil { 73 cfg = client.DefaultClientConfig() 74 cfg.WebsocketURL = endpoint 75 cfg.WantedCollections = collections 76 } 77 78 return &JetstreamClient{ 79 cfg: cfg, 80 baseIdent: ident, 81 db: db, 82 l: logger, 83 waitForDid: waitForDid, 84 subscribers: make([]*JetstreamSubscriber, 0), 85 maxDidsPerSubscriber: 100, 86 }, nil 87} 88 89// StartJetstream starts the jetstream client and processes events using the provided processFunc. 90// The caller is responsible for saving the last time_us to the database (just use your db.SaveLastTimeUs). 91func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error { 92 j.mu.Lock() 93 j.processFunc = processFunc 94 j.mu.Unlock() 95 96 if j.waitForDid { 97 // Start a goroutine to wait for DIDs and then start subscribers 98 go func() { 99 for { 100 j.mu.RLock() 101 hasDids := len(j.cfg.WantedDids) > 0 102 j.mu.RUnlock() 103 104 if hasDids { 105 j.l.Info("done waiting for did, starting subscribers") 106 j.rebalanceSubscribers() 107 return 108 } 109 time.Sleep(time.Second) 110 } 111 }() 112 } else { 113 // Start subscribers immediately 114 j.rebalanceSubscribers() 115 } 116 117 return nil 118} 119 120// rebalanceSubscribers creates, updates, or removes subscribers based on the current list of DIDs 121func (j *JetstreamClient) rebalanceSubscribers() { 122 j.mu.Lock() 123 defer j.mu.Unlock() 124 125 if j.processFunc == nil { 126 j.l.Warn("cannot rebalance subscribers without a process function") 127 return 128 } 129 130 // stop all subscribers first 131 for _, sub := range j.subscribers { 132 if sub.running && sub.cancel != nil { 133 sub.cancel() 134 sub.running = false 135 } 136 } 137 138 // calculate how many subscribers we need 139 totalDids := len(j.cfg.WantedDids) 140 subscribersNeeded := (totalDids + j.maxDidsPerSubscriber - 1) / j.maxDidsPerSubscriber // ceiling division 141 142 // create or reuse subscribers as needed 143 j.subscribers = j.subscribers[:0] 144 145 for i := range subscribersNeeded { 146 startIdx := i * j.maxDidsPerSubscriber 147 endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids) 148 149 subscriberDids := j.cfg.WantedDids[startIdx:endIdx] 150 151 subCfg := *j.cfg 152 subCfg.WantedDids = subscriberDids 153 154 ident := fmt.Sprintf("%s-%d", j.baseIdent, i) 155 subscriber := &JetstreamSubscriber{ 156 dids: subscriberDids, 157 ident: ident, 158 } 159 j.subscribers = append(j.subscribers, subscriber) 160 161 j.subscriberWg.Add(1) 162 go j.startSubscriber(subscriber, &subCfg) 163 } 164} 165 166// startSubscriber initializes and starts a single subscriber 167func (j *JetstreamClient) startSubscriber(sub *JetstreamSubscriber, cfg *client.ClientConfig) { 168 defer j.subscriberWg.Done() 169 170 logger := j.l.With("subscriber", sub.ident) 171 logger.Info("starting subscriber", "dids_count", len(sub.dids)) 172 173 sched := sequential.NewScheduler(sub.ident, logger, j.processFunc) 174 175 client, err := client.NewClient(cfg, log.New("jetstream-"+sub.ident), sched) 176 if err != nil { 177 logger.Error("failed to create jetstream client", "error", err) 178 return 179 } 180 181 sub.client = client 182 183 j.mu.Lock() 184 sub.running = true 185 j.mu.Unlock() 186 187 j.connectAndReadForSubscriber(sub) 188} 189 190func (j *JetstreamClient) connectAndReadForSubscriber(sub *JetstreamSubscriber) { 191 ctx := context.Background() 192 l := j.l.With("subscriber", sub.ident) 193 194 for { 195 // Check if this subscriber should still be running 196 j.mu.RLock() 197 running := sub.running 198 j.mu.RUnlock() 199 200 if !running { 201 l.Info("subscriber marked for shutdown") 202 return 203 } 204 205 cursor := j.getLastTimeUs(ctx) 206 207 connCtx, cancel := context.WithCancel(ctx) 208 209 j.mu.Lock() 210 sub.cancel = cancel 211 j.mu.Unlock() 212 213 l.Info("connecting subscriber to jetstream") 214 if err := sub.client.ConnectAndRead(connCtx, cursor); err != nil { 215 l.Error("error reading jetstream", "error", err) 216 cancel() 217 time.Sleep(time.Second) // Small backoff before retry 218 continue 219 } 220 221 select { 222 case <-ctx.Done(): 223 l.Info("context done, stopping subscriber") 224 return 225 case <-connCtx.Done(): 226 l.Info("connection context done, reconnecting") 227 continue 228 } 229 } 230} 231 232// GetRunningSubscribersCount returns the total number of currently running subscribers 233func (j *JetstreamClient) GetRunningSubscribersCount() int { 234 j.mu.RLock() 235 defer j.mu.RUnlock() 236 237 runningCount := 0 238 for _, sub := range j.subscribers { 239 if sub.running { 240 runningCount++ 241 } 242 } 243 244 return runningCount 245} 246 247// Shutdown gracefully stops all subscribers 248func (j *JetstreamClient) Shutdown() { 249 j.mu.Lock() 250 251 // Cancel all subscribers 252 for _, sub := range j.subscribers { 253 if sub.running && sub.cancel != nil { 254 sub.cancel() 255 sub.running = false 256 } 257 } 258 259 j.mu.Unlock() 260 261 // Wait for all subscribers to complete 262 j.subscriberWg.Wait() 263 j.l.Info("all subscribers shut down", "total_subscribers", len(j.subscribers), "running_subscribers", j.GetRunningSubscribersCount()) 264} 265 266func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 { 267 l := log.FromContext(ctx) 268 lastTimeUs, err := j.db.GetLastTimeUs() 269 if err != nil { 270 l.Warn("couldn't get last time us, starting from now", "error", err) 271 lastTimeUs = time.Now().UnixMicro() 272 err = j.db.SaveLastTimeUs(lastTimeUs) 273 if err != nil { 274 l.Error("failed to save last time us", "error", err) 275 } 276 } 277 278 // If last time is older than 2 days, start from now 279 if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 { 280 lastTimeUs = time.Now().UnixMicro() 281 l.Warn("last time us is older than 2 days; discarding that and starting from now") 282 err = j.db.UpdateLastTimeUs(lastTimeUs) 283 if err != nil { 284 l.Error("failed to save last time us", "error", err) 285 } 286 } 287 288 l.Info("found last time_us", "time_us", lastTimeUs, "running_subscribers", j.GetRunningSubscribersCount()) 289 return &lastTimeUs 290}