1package eventconsumer
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "math/rand"
8 "net/url"
9 "sync"
10 "time"
11
12 "tangled.sh/tangled.sh/core/eventconsumer/cursor"
13 "tangled.sh/tangled.sh/core/log"
14
15 "github.com/avast/retry-go/v4"
16 "github.com/gorilla/websocket"
17)
18
// ProcessFunc is the callback invoked by a worker for every message that was
// successfully decoded into a Message. It receives the context the consumer
// was started with and the Source the message arrived from. A non-nil error
// is logged by the worker; the message is not retried.
type ProcessFunc func(ctx context.Context, source Source, message Message) error
20
// Message is the envelope each websocket text frame is decoded into before it
// is handed to the ProcessFunc.
type Message struct {
	// NOTE(review): presumably the record key of the event; confirm with the producer.
	Rkey string
	// NOTE(review): presumably the NSID (collection) of the event; confirm with the producer.
	Nsid string
	// EventJson holds the raw "event" payload. It is deliberately left
	// undecoded here so that the ProcessFunc can deserialize it itself.
	EventJson json.RawMessage `json:"event"`
}
27
// ConsumerConfig carries the settings for a Consumer. Zero values of the
// optional fields are replaced with defaults by NewConsumer.
type ConsumerConfig struct {
	// Sources is the set of sources to stream from.
	Sources map[Source]struct{}

	// ProcessFunc is called for every decoded message.
	ProcessFunc ProcessFunc

	// RetryInterval is the base delay between connection attempts
	// (NewConsumer default: 15 minutes).
	RetryInterval time.Duration

	// MaxRetryInterval caps the delay between connection attempts
	// (NewConsumer default: 1 hour).
	MaxRetryInterval time.Duration

	// ConnectionTimeout bounds a single dial attempt
	// (NewConsumer default: 10 seconds).
	ConnectionTimeout time.Duration

	// WorkerCount is the number of worker goroutines draining the job queue
	// (NewConsumer default: 5).
	WorkerCount int

	// QueueSize is the capacity of the buffered job queue
	// (NewConsumer default: 100).
	QueueSize int

	// Logger receives the consumer's log output
	// (NewConsumer default: log.New("consumer")).
	Logger *slog.Logger

	// Dev is passed through unchanged to Source.Url.
	Dev bool

	// CursorStore persists the per-source cursor
	// (NewConsumer default: an in-memory cursor.MemoryStore).
	CursorStore cursor.Store
}
40
41func NewConsumerConfig() *ConsumerConfig {
42 return &ConsumerConfig{
43 Sources: make(map[Source]struct{}),
44 }
45}
46
// Source describes an event stream the consumer can connect to.
//
// Source values are used as map keys (see ConsumerConfig.Sources), so the
// dynamic type of an implementation must be comparable.
type Source interface {
	// Url returns the websocket URL to start streaming events from, given
	// the last stored cursor and the consumer's Dev flag.
	Url(cursor int64, dev bool) (*url.URL, error)
	// Key returns the cache key under which this source's cursor is stored.
	Key() string
}
53
// Consumer streams messages from a set of websocket sources and fans them out
// to a pool of workers through a buffered job queue.
type Consumer struct {
	// wg tracks every worker and connection-loop goroutine.
	wg sync.WaitGroup
	// dialer is used to open the websocket connections.
	dialer *websocket.Dialer
	// connMap maps a Source to its currently open *websocket.Conn.
	connMap sync.Map
	// jobQueue carries raw messages from the connection loops to the workers.
	jobQueue chan job
	logger *slog.Logger
	// randSource is seeded in NewConsumer; it is not read by any code
	// visible in this file.
	randSource *rand.Rand

	// rw lock over edits to ConsumerConfig
	cfgMu sync.RWMutex
	cfg ConsumerConfig
}
66
// job is one unit of work on the job queue: the raw bytes of a websocket text
// message together with the source it was read from.
type job struct {
	source Source
	message []byte
}
71
72func NewConsumer(cfg ConsumerConfig) *Consumer {
73 if cfg.RetryInterval == 0 {
74 cfg.RetryInterval = 15 * time.Minute
75 }
76 if cfg.ConnectionTimeout == 0 {
77 cfg.ConnectionTimeout = 10 * time.Second
78 }
79 if cfg.WorkerCount <= 0 {
80 cfg.WorkerCount = 5
81 }
82 if cfg.MaxRetryInterval == 0 {
83 cfg.MaxRetryInterval = 1 * time.Hour
84 }
85 if cfg.Logger == nil {
86 cfg.Logger = log.New("consumer")
87 }
88 if cfg.QueueSize == 0 {
89 cfg.QueueSize = 100
90 }
91 if cfg.CursorStore == nil {
92 cfg.CursorStore = &cursor.MemoryStore{}
93 }
94 return &Consumer{
95 cfg: cfg,
96 dialer: websocket.DefaultDialer,
97 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue
98 logger: cfg.Logger,
99 randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
100 }
101}
102
103func (c *Consumer) Start(ctx context.Context) {
104 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
105
106 // start workers
107 for range c.cfg.WorkerCount {
108 c.wg.Add(1)
109 go c.worker(ctx)
110 }
111
112 // start streaming
113 for source := range c.cfg.Sources {
114 c.wg.Add(1)
115 go c.startConnectionLoop(ctx, source)
116 }
117}
118
119func (c *Consumer) Stop() {
120 c.connMap.Range(func(_, val any) bool {
121 if conn, ok := val.(*websocket.Conn); ok {
122 conn.Close()
123 }
124 return true
125 })
126 c.wg.Wait()
127 close(c.jobQueue)
128}
129
130func (c *Consumer) AddSource(ctx context.Context, s Source) {
131 // we are already listening to this source
132 if _, ok := c.cfg.Sources[s]; ok {
133 c.logger.Info("source already present", "source", s)
134 return
135 }
136
137 c.cfgMu.Lock()
138 c.cfg.Sources[s] = struct{}{}
139 c.wg.Add(1)
140 go c.startConnectionLoop(ctx, s)
141 c.cfgMu.Unlock()
142}
143
144func (c *Consumer) worker(ctx context.Context) {
145 defer c.wg.Done()
146 for {
147 select {
148 case <-ctx.Done():
149 return
150 case j, ok := <-c.jobQueue:
151 if !ok {
152 return
153 }
154
155 var msg Message
156 err := json.Unmarshal(j.message, &msg)
157 if err != nil {
158 c.logger.Error("error deserializing message", "source", j.source.Key(), "err", err)
159 return
160 }
161
162 // update cursor
163 c.cfg.CursorStore.Set(j.source.Key(), time.Now().UnixNano())
164
165 if err := c.cfg.ProcessFunc(ctx, j.source, msg); err != nil {
166 c.logger.Error("error processing message", "source", j.source, "err", err)
167 }
168 }
169 }
170}
171
172func (c *Consumer) startConnectionLoop(ctx context.Context, source Source) {
173 defer c.wg.Done()
174
175 for {
176 select {
177 case <-ctx.Done():
178 return
179 default:
180 err := c.runConnection(ctx, source)
181 if err != nil {
182 c.logger.Error("failed to run connection", "err", err)
183 }
184 }
185 }
186}
187
188func (c *Consumer) runConnection(ctx context.Context, source Source) error {
189 cursor := c.cfg.CursorStore.Get(source.Key())
190
191 u, err := source.Url(cursor, c.cfg.Dev)
192 if err != nil {
193 return err
194 }
195
196 c.logger.Info("connecting", "url", u.String())
197
198 retryOpts := []retry.Option{
199 retry.Attempts(0), // infinite attempts
200 retry.DelayType(retry.BackOffDelay),
201 retry.Delay(c.cfg.RetryInterval),
202 retry.MaxDelay(c.cfg.MaxRetryInterval),
203 retry.MaxJitter(c.cfg.RetryInterval / 5),
204 retry.OnRetry(func(n uint, err error) {
205 c.logger.Info("retrying connection",
206 "source", source,
207 "url", u.String(),
208 "attempt", n+1,
209 "err", err,
210 )
211 }),
212 retry.Context(ctx),
213 }
214
215 var conn *websocket.Conn
216
217 err = retry.Do(func() error {
218 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
219 defer cancel()
220 conn, _, err = c.dialer.DialContext(connCtx, u.String(), nil)
221 return err
222 }, retryOpts...)
223 if err != nil {
224 return err
225 }
226
227 c.connMap.Store(source, conn)
228 defer conn.Close()
229 defer c.connMap.Delete(source)
230
231 c.logger.Info("connected", "source", source)
232
233 for {
234 select {
235 case <-ctx.Done():
236 return nil
237 default:
238 msgType, msg, err := conn.ReadMessage()
239 if err != nil {
240 return err
241 }
242 if msgType != websocket.TextMessage {
243 continue
244 }
245 select {
246 case c.jobQueue <- job{source: source, message: msg}:
247 case <-ctx.Done():
248 return nil
249 }
250 }
251 }
252}