1package jetstream
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "os"
8 "os/signal"
9 "sync"
10 "syscall"
11 "time"
12
13 "github.com/bluesky-social/jetstream/pkg/client"
14 "github.com/bluesky-social/jetstream/pkg/client/schedulers/sequential"
15 "github.com/bluesky-social/jetstream/pkg/models"
16 "tangled.sh/tangled.sh/core/log"
17)
18
// DB persists the jetstream resume cursor, expressed as a Unix timestamp in
// microseconds (the same unit as jetstream's time_us).
type DB interface {
	// SaveLastTimeUs stores timeUs as the cursor to resume from.
	SaveLastTimeUs(timeUs int64) error
	// GetLastTimeUs returns the most recently stored cursor.
	GetLastTimeUs() (int64, error)
}
23
// Set is a set of comparable values, stored as map keys with zero-size values.
type Set[E comparable] map[E]struct{}
25
26type JetstreamClient struct {
27 cfg *client.ClientConfig
28 client *client.Client
29 ident string
30 l *slog.Logger
31
32 wantedDids Set[string]
33 db DB
34 waitForDid bool
35 mu sync.RWMutex
36
37 cancel context.CancelFunc
38 cancelMu sync.Mutex
39}
40
41func (j *JetstreamClient) AddDid(did string) {
42 if did == "" {
43 return
44 }
45
46 j.l.Info("adding did to in-memory filter", "did", did)
47 j.mu.Lock()
48 j.wantedDids[did] = struct{}{}
49 j.mu.Unlock()
50}
51
52type processor func(context.Context, *models.Event) error
53
54func (j *JetstreamClient) withDidFilter(processFunc processor) processor {
55 // empty filter => all dids allowed
56 if len(j.wantedDids) == 0 {
57 return processFunc
58 }
59 // since this closure references j.WantedDids; it should auto-update
60 // existing instances of the closure when j.WantedDids is mutated
61 return func(ctx context.Context, evt *models.Event) error {
62 if _, ok := j.wantedDids[evt.Did]; ok {
63 return processFunc(ctx, evt)
64 } else {
65 return nil
66 }
67 }
68}
69
70func NewJetstreamClient(endpoint, ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) {
71 if cfg == nil {
72 cfg = client.DefaultClientConfig()
73 cfg.WebsocketURL = endpoint
74 cfg.WantedCollections = collections
75 }
76
77 return &JetstreamClient{
78 cfg: cfg,
79 ident: ident,
80 db: db,
81 l: logger,
82 wantedDids: make(map[string]struct{}),
83
84 // This will make the goroutine in StartJetstream wait until
85 // j.wantedDids has been populated, typically using addDids.
86 waitForDid: waitForDid,
87 }, nil
88}
89
90// StartJetstream starts the jetstream client and processes events using the provided processFunc.
91// The caller is responsible for saving the last time_us to the database (just use your db.UpdateLastTimeUs).
92func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error {
93 logger := j.l
94
95 sched := sequential.NewScheduler(j.ident, logger, j.withDidFilter(processFunc))
96
97 client, err := client.NewClient(j.cfg, log.New("jetstream"), sched)
98 if err != nil {
99 return fmt.Errorf("failed to create jetstream client: %w", err)
100 }
101 j.client = client
102
103 go func() {
104 if j.waitForDid {
105 for len(j.wantedDids) == 0 {
106 time.Sleep(time.Second)
107 }
108 }
109 logger.Info("done waiting for did")
110
111 go j.periodicLastTimeSave(ctx)
112 j.saveIfKilled(ctx)
113
114 j.connectAndRead(ctx)
115 }()
116
117 return nil
118}
119
120func (j *JetstreamClient) connectAndRead(ctx context.Context) {
121 l := log.FromContext(ctx)
122 for {
123 cursor := j.getLastTimeUs(ctx)
124
125 connCtx, cancel := context.WithCancel(ctx)
126 j.cancelMu.Lock()
127 j.cancel = cancel
128 j.cancelMu.Unlock()
129
130 if err := j.client.ConnectAndRead(connCtx, cursor); err != nil {
131 l.Error("error reading jetstream", "error", err)
132 cancel()
133 continue
134 }
135
136 select {
137 case <-ctx.Done():
138 l.Info("context done, stopping jetstream")
139 return
140 case <-connCtx.Done():
141 l.Info("connection context done, reconnecting")
142 continue
143 }
144 }
145}
146
147// save cursor periodically
148func (j *JetstreamClient) periodicLastTimeSave(ctx context.Context) {
149 ticker := time.NewTicker(time.Minute)
150 defer ticker.Stop()
151
152 for {
153 select {
154 case <-ctx.Done():
155 return
156 case <-ticker.C:
157 j.db.SaveLastTimeUs(time.Now().UnixMicro())
158 }
159 }
160}
161
162func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 {
163 l := log.FromContext(ctx)
164 lastTimeUs, err := j.db.GetLastTimeUs()
165 if err != nil {
166 l.Warn("couldn't get last time us, starting from now", "error", err)
167 lastTimeUs = time.Now().UnixMicro()
168 err = j.db.SaveLastTimeUs(lastTimeUs)
169 if err != nil {
170 l.Error("failed to save last time us", "error", err)
171 }
172 }
173
174 // If last time is older than 2 days, start from now
175 if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 {
176 lastTimeUs = time.Now().UnixMicro()
177 l.Warn("last time us is older than 2 days; discarding that and starting from now")
178 err = j.db.SaveLastTimeUs(lastTimeUs)
179 if err != nil {
180 l.Error("failed to save last time us", "error", err)
181 }
182 }
183
184 l.Info("found last time_us", "time_us", lastTimeUs)
185 return &lastTimeUs
186}
187
188func (j *JetstreamClient) saveIfKilled(ctx context.Context) context.Context {
189 ctxWithCancel, cancel := context.WithCancel(ctx)
190
191 sigChan := make(chan os.Signal, 1)
192
193 signal.Notify(sigChan,
194 syscall.SIGINT,
195 syscall.SIGTERM,
196 syscall.SIGQUIT,
197 syscall.SIGHUP,
198 syscall.SIGKILL,
199 syscall.SIGSTOP,
200 )
201
202 go func() {
203 sig := <-sigChan
204 j.l.Info("Received signal, initiating graceful shutdown", "signal", sig)
205
206 lastTimeUs := time.Now().UnixMicro()
207 if err := j.db.SaveLastTimeUs(lastTimeUs); err != nil {
208 j.l.Error("Failed to save last time during shutdown", "error", err)
209 }
210 j.l.Info("Saved lastTimeUs before shutdown", "lastTimeUs", lastTimeUs)
211
212 j.cancelMu.Lock()
213 if j.cancel != nil {
214 j.cancel()
215 }
216 j.cancelMu.Unlock()
217
218 cancel()
219
220 os.Exit(0)
221 }()
222
223 return ctxWithCancel
224}