1package eventconsumer
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "math/rand"
8 "net/url"
9 "sync"
10 "time"
11
12 "tangled.sh/tangled.sh/core/eventconsumer/cursor"
13 "tangled.sh/tangled.sh/core/log"
14
15 "github.com/avast/retry-go/v4"
16 "github.com/gorilla/websocket"
17)
18
// ProcessFunc handles one decoded Message received from source.
//
// Consumer workers call it once per text frame after the frame's JSON
// envelope has been decoded into a Message. A returned error is only
// logged; the message is not retried.
type ProcessFunc func(ctx context.Context, source Source, message Message) error
20
// Message is the envelope of a single event frame read from a Source.
//
// Rkey and Nsid carry no json tags, so encoding/json matches them against
// the incoming object keys case-insensitively (e.g. "rkey", "nsid").
type Message struct {
	Rkey string

	Nsid string

	// EventJson is left as raw bytes on purpose: the message is not fully
	// deserialized here, the ProcessFunc decodes it into whatever type it needs.
	EventJson json.RawMessage `json:"event"`
}
27
// ConsumerConfig configures a Consumer. Zero-valued fields are replaced
// with defaults by NewConsumer.
type ConsumerConfig struct {
	// Sources is the set of sources to stream from; each gets its own
	// connection loop.
	Sources map[Source]struct{}

	// ProcessFunc is invoked by the workers for every decoded message.
	ProcessFunc ProcessFunc

	// RetryInterval is the initial delay of the exponential dial backoff.
	RetryInterval time.Duration

	// MaxRetryInterval caps the dial backoff delay.
	MaxRetryInterval time.Duration

	// ConnectionTimeout bounds each individual dial attempt.
	ConnectionTimeout time.Duration

	// WorkerCount is the number of goroutines draining the job queue.
	WorkerCount int

	// QueueSize is the buffer size of the job queue between the connection
	// loops and the workers.
	QueueSize int

	// Logger receives all consumer log output.
	Logger *slog.Logger

	// Dev is passed through to Source.Url.
	Dev bool

	// CursorStore persists a per-source cursor keyed by Source.Key().
	CursorStore cursor.Store
}
40
41func NewConsumerConfig() *ConsumerConfig {
42 return &ConsumerConfig{
43 Sources: make(map[Source]struct{}),
44 }
45}
46
// Source is an upstream event stream the Consumer connects to.
//
// Source values are used as keys of ConsumerConfig.Sources and of the
// consumer's connection map, so implementations must be comparable; a
// non-comparable dynamic type panics at runtime when used as a key.
type Source interface {
	// Url returns the websocket URL to start streaming events from.
	// cursor is the value read from the CursorStore for Key(); dev is
	// ConsumerConfig.Dev.
	Url(cursor int64, dev bool) (*url.URL, error)

	// Key is the cache key for cursor storage (also used in log output).
	Key() string
}
53
// Consumer streams messages from a set of Sources over websockets and
// hands them to a pool of workers that run ConsumerConfig.ProcessFunc.
// Construct it with NewConsumer.
type Consumer struct {
	// wg tracks worker goroutines and per-source connection loops.
	wg sync.WaitGroup

	// dialer opens the websocket connections.
	dialer *websocket.Dialer

	// connMap maps Source -> live *websocket.Conn so Stop can close them.
	connMap sync.Map

	// jobQueue carries raw frames from connection loops to workers.
	jobQueue chan job

	logger *slog.Logger

	// NOTE(review): randSource is initialised but not used in this file;
	// confirm whether anything else in the package relies on it.
	randSource *rand.Rand

	// rw lock over edits to ConsumerConfig (AddSource mutates cfg.Sources)
	cfgMu sync.RWMutex
	cfg   ConsumerConfig
}
66
// job is one raw text frame queued for a worker, tagged with the Source
// it was read from.
type job struct {
	source Source

	// message is the undecoded frame payload.
	message []byte
}
71
// NewConsumer builds a Consumer from cfg, replacing unset (zero) or
// invalid (negative) settings with defaults:
//
//   - RetryInterval 15m, MaxRetryInterval 1h, ConnectionTimeout 10s
//   - WorkerCount 5, QueueSize 100
//   - Logger log.New("consumer"), CursorStore an in-memory cursor.MemoryStore
//
// A nil Sources map is replaced with an empty one so that AddSource can
// register sources later without panicking. cfg.Sources is not copied; the
// Consumer shares the caller's map.
func NewConsumer(cfg ConsumerConfig) *Consumer {
	if cfg.Sources == nil {
		// AddSource writes into this map; a nil map would panic there.
		cfg.Sources = make(map[Source]struct{})
	}
	if cfg.RetryInterval <= 0 {
		cfg.RetryInterval = 15 * time.Minute
	}
	if cfg.ConnectionTimeout <= 0 {
		// a negative timeout would make every dial attempt fail immediately
		cfg.ConnectionTimeout = 10 * time.Second
	}
	if cfg.WorkerCount <= 0 {
		cfg.WorkerCount = 5
	}
	if cfg.MaxRetryInterval <= 0 {
		cfg.MaxRetryInterval = 1 * time.Hour
	}
	if cfg.Logger == nil {
		cfg.Logger = log.New("consumer")
	}
	if cfg.QueueSize <= 0 {
		// make(chan job, n) panics for negative n
		cfg.QueueSize = 100
	}
	if cfg.CursorStore == nil {
		cfg.CursorStore = &cursor.MemoryStore{}
	}
	return &Consumer{
		cfg:        cfg,
		dialer:     websocket.DefaultDialer,
		jobQueue:   make(chan job, cfg.QueueSize), // buffered job queue
		logger:     cfg.Logger,
		randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}
102
// Start launches cfg.WorkerCount worker goroutines and one connection loop
// per configured source, then returns immediately. Connection loops run
// until ctx is cancelled; workers stop when ctx is cancelled or the job
// queue is closed.
//
// cfgMu is held for reading while cfg.Sources is iterated, so a concurrent
// AddSource cannot mutate the map mid-iteration (a data race otherwise).
func (c *Consumer) Start(ctx context.Context) {
	c.cfgMu.RLock()
	defer c.cfgMu.RUnlock()

	c.cfg.Logger.Info("starting consumer", "config", c.cfg)

	// start workers
	for range c.cfg.WorkerCount {
		c.wg.Add(1)
		go c.worker(ctx)
	}

	// start streaming
	for source := range c.cfg.Sources {
		c.wg.Add(1)
		go c.startConnectionLoop(ctx, source)
	}
}
118
119func (c *Consumer) Stop() {
120 c.connMap.Range(func(_, val any) bool {
121 if conn, ok := val.(*websocket.Conn); ok {
122 conn.Close()
123 }
124 return true
125 })
126 c.wg.Wait()
127 close(c.jobQueue)
128}
129
// AddSource registers s and starts a connection loop for it under ctx,
// unless the consumer is already listening to s.
//
// The existence check and the insertion happen under the same write lock,
// so concurrent calls for the same source cannot both start a loop, and
// the map is never read while another goroutine writes it.
func (c *Consumer) AddSource(ctx context.Context, s Source) {
	c.cfgMu.Lock()
	defer c.cfgMu.Unlock()

	// we are already listening to this source
	if _, ok := c.cfg.Sources[s]; ok {
		c.logger.Info("source already present", "source", s)
		return
	}

	if c.cfg.Sources == nil {
		// a Consumer built from a ConsumerConfig literal may have a nil map
		c.cfg.Sources = make(map[Source]struct{})
	}
	c.cfg.Sources[s] = struct{}{}

	c.wg.Add(1)
	go c.startConnectionLoop(ctx, s)
}
143
// worker drains the job queue until ctx is cancelled or the queue is
// closed. For each job it decodes the Message envelope, records a cursor
// for the job's source (the current wall-clock time in Unix nanoseconds),
// and runs ProcessFunc. Decode and processing errors are logged and the
// worker moves on to the next job.
func (c *Consumer) worker(ctx context.Context) {
	defer c.wg.Done()
	for {
		select {
		case <-ctx.Done():
			return
		case j, ok := <-c.jobQueue:
			if !ok {
				return
			}

			var msg Message
			if err := json.Unmarshal(j.message, &msg); err != nil {
				// Skip the bad frame. Returning here would kill this worker
				// for good; enough malformed frames would leave no workers
				// and stall every connection loop on a full queue.
				c.logger.Error("error deserializing message", "source", j.source.Key(), "err", err)
				continue
			}

			// update cursor
			c.cfg.CursorStore.Set(j.source.Key(), time.Now().UnixNano())

			if err := c.cfg.ProcessFunc(ctx, j.source, msg); err != nil {
				c.logger.Error("error processing message", "source", j.source.Key(), "err", err)
			}
		}
	}
}
171
172func (c *Consumer) startConnectionLoop(ctx context.Context, source Source) {
173 defer c.wg.Done()
174
175 // attempt connection initially
176 err := c.runConnection(ctx, source)
177 if err != nil {
178 c.logger.Error("failed to run connection", "err", err)
179 }
180
181 timer := time.NewTimer(1 * time.Minute)
182 defer timer.Stop()
183
184 // every subsequent attempt is delayed by 1 minute
185 for {
186 select {
187 case <-ctx.Done():
188 return
189 case <-timer.C:
190 err := c.runConnection(ctx, source)
191 if err != nil {
192 c.logger.Error("failed to run connection", "err", err)
193 }
194 timer.Reset(1 * time.Minute)
195 }
196 }
197}
198
// runConnection performs one connection session for source.
//
// It builds the stream URL from the cursor stored under source.Key(), then
// dials it. Failed dials are retried indefinitely with exponential backoff
// (RetryInterval up to MaxRetryInterval), and each attempt is bounded by
// ConnectionTimeout. Once connected, it forwards every text frame to the
// job queue.
//
// It returns nil when ctx is cancelled. Otherwise it returns the URL, dial
// or read error that ended the session; reconnecting is the caller's job.
func (c *Consumer) runConnection(ctx context.Context, source Source) error {
	cursor := c.cfg.CursorStore.Get(source.Key())

	u, err := source.Url(cursor, c.cfg.Dev)
	if err != nil {
		return err
	}

	c.logger.Info("connecting", "url", u.String())

	retryOpts := []retry.Option{
		retry.Attempts(0), // infinite attempts
		retry.DelayType(retry.BackOffDelay),
		retry.Delay(c.cfg.RetryInterval),
		retry.MaxDelay(c.cfg.MaxRetryInterval),
		retry.MaxJitter(c.cfg.RetryInterval / 5),
		retry.OnRetry(func(n uint, err error) {
			c.logger.Info("retrying connection",
				"source", source,
				"url", u.String(),
				"attempt", n+1,
				"err", err,
			)
		}),
		retry.Context(ctx),
	}

	var conn *websocket.Conn

	err = retry.Do(func() error {
		connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
		defer cancel()
		conn, _, err = c.dialer.DialContext(connCtx, u.String(), nil)
		return err
	}, retryOpts...)
	if err != nil {
		return err
	}

	c.connMap.Store(source, conn)
	defer conn.Close()
	defer c.connMap.Delete(source)

	// ReadMessage does not observe ctx. Closing the connection on
	// cancellation unblocks a pending read instead of waiting for the next
	// frame or for Stop. gorilla/websocket allows Close concurrently with a
	// reader.
	stopCloseOnCancel := context.AfterFunc(ctx, func() { conn.Close() })
	defer stopCloseOnCancel()

	c.logger.Info("connected", "source", source)

	for {
		select {
		case <-ctx.Done():
			return nil
		default:
			msgType, msg, err := conn.ReadMessage()
			if err != nil {
				if ctx.Err() != nil {
					// the read failed because cancellation closed the conn
					return nil
				}
				return err
			}
			if msgType != websocket.TextMessage {
				continue
			}
			select {
			case c.jobQueue <- job{source: source, message: msg}:
			case <-ctx.Done():
				return nil
			}
		}
	}
}