1package eventconsumer
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "math/rand"
8 "net/url"
9 "sync"
10 "time"
11
12 "tangled.sh/tangled.sh/core/eventconsumer/cursor"
13 "tangled.sh/tangled.sh/core/log"
14
15 "github.com/gorilla/websocket"
16)
17
18type ProcessFunc func(ctx context.Context, source Source, message Message) error
19
// Message is the envelope of one streamed event. Only the routing fields are
// decoded eagerly; the event body is kept raw.
type Message struct {
	Rkey, Nsid string
	// EventJson is the undecoded "event" field; ProcessFunc is expected to
	// unmarshal it into whatever concrete type it needs.
	EventJson json.RawMessage `json:"event"`
}
26
27type ConsumerConfig struct {
28 Sources map[Source]struct{}
29 ProcessFunc ProcessFunc
30 RetryInterval time.Duration
31 MaxRetryInterval time.Duration
32 ConnectionTimeout time.Duration
33 WorkerCount int
34 QueueSize int
35 Logger *slog.Logger
36 Dev bool
37 CursorStore cursor.Store
38}
39
40func NewConsumerConfig() *ConsumerConfig {
41 return &ConsumerConfig{
42 Sources: make(map[Source]struct{}),
43 }
44}
45
// Source is a streaming endpoint the Consumer connects to.
type Source interface {
	// Key returns the cache key under which this source's cursor is stored.
	Key() string
	// Url builds the websocket URL to start streaming events from, given the
	// last stored cursor value cur; dev is ConsumerConfig.Dev.
	Url(cur int64, dev bool) (*url.URL, error)
}
52
// Consumer streams events from a set of websocket Sources and fans the raw
// frames out to a pool of workers that invoke cfg.ProcessFunc.
type Consumer struct {
	// wg tracks worker and connection-loop goroutines; Stop waits on it.
	wg sync.WaitGroup
	// dialer opens the websocket connections.
	dialer *websocket.Dialer
	// connMap maps each currently connected Source to its *websocket.Conn so
	// Stop can close them.
	connMap sync.Map
	// jobQueue carries raw frames from connection loops to workers.
	jobQueue chan job
	logger   *slog.Logger
	// randSource supplies reconnect jitter.
	// NOTE(review): a *rand.Rand built with rand.New is not safe for
	// concurrent use (per math/rand docs); if several goroutines use it,
	// access must be serialized.
	randSource *rand.Rand

	// rw lock over edits to ConsumerConfig
	cfgMu sync.RWMutex
	cfg   ConsumerConfig
}
65
// job is one raw websocket text frame queued for a worker, tagged with the
// Source whose connection produced it.
type job struct {
	// source is the Source the frame arrived from.
	source Source
	// message is the undecoded frame payload; workers unmarshal it into a Message.
	message []byte
}
70
71func NewConsumer(cfg ConsumerConfig) *Consumer {
72 if cfg.RetryInterval == 0 {
73 cfg.RetryInterval = 15 * time.Minute
74 }
75 if cfg.ConnectionTimeout == 0 {
76 cfg.ConnectionTimeout = 10 * time.Second
77 }
78 if cfg.WorkerCount <= 0 {
79 cfg.WorkerCount = 5
80 }
81 if cfg.MaxRetryInterval == 0 {
82 cfg.MaxRetryInterval = 1 * time.Hour
83 }
84 if cfg.Logger == nil {
85 cfg.Logger = log.New("consumer")
86 }
87 if cfg.QueueSize == 0 {
88 cfg.QueueSize = 100
89 }
90 if cfg.CursorStore == nil {
91 cfg.CursorStore = &cursor.MemoryStore{}
92 }
93 return &Consumer{
94 cfg: cfg,
95 dialer: websocket.DefaultDialer,
96 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue
97 logger: cfg.Logger,
98 randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
99 }
100}
101
102func (c *Consumer) Start(ctx context.Context) {
103 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
104
105 // start workers
106 for range c.cfg.WorkerCount {
107 c.wg.Add(1)
108 go c.worker(ctx)
109 }
110
111 // start streaming
112 for source := range c.cfg.Sources {
113 c.wg.Add(1)
114 go c.startConnectionLoop(ctx, source)
115 }
116}
117
118func (c *Consumer) Stop() {
119 c.connMap.Range(func(_, val any) bool {
120 if conn, ok := val.(*websocket.Conn); ok {
121 conn.Close()
122 }
123 return true
124 })
125 c.wg.Wait()
126 close(c.jobQueue)
127}
128
129func (c *Consumer) AddSource(ctx context.Context, s Source) {
130 // we are already listening to this source
131 if _, ok := c.cfg.Sources[s]; ok {
132 c.logger.Info("source already present", "source", s)
133 return
134 }
135
136 c.cfgMu.Lock()
137 c.cfg.Sources[s] = struct{}{}
138 c.wg.Add(1)
139 go c.startConnectionLoop(ctx, s)
140 c.cfgMu.Unlock()
141}
142
143func (c *Consumer) worker(ctx context.Context) {
144 defer c.wg.Done()
145 for {
146 select {
147 case <-ctx.Done():
148 return
149 case j, ok := <-c.jobQueue:
150 if !ok {
151 return
152 }
153
154 var msg Message
155 err := json.Unmarshal(j.message, &msg)
156 if err != nil {
157 c.logger.Error("error deserializing message", "source", j.source.Key(), "err", err)
158 return
159 }
160
161 // update cursor
162 c.cfg.CursorStore.Set(j.source.Key(), time.Now().UnixNano())
163
164 if err := c.cfg.ProcessFunc(ctx, j.source, msg); err != nil {
165 c.logger.Error("error processing message", "source", j.source, "err", err)
166 }
167 }
168 }
169}
170
171func (c *Consumer) startConnectionLoop(ctx context.Context, source Source) {
172 defer c.wg.Done()
173 retryInterval := c.cfg.RetryInterval
174 for {
175 select {
176 case <-ctx.Done():
177 return
178 default:
179 err := c.runConnection(ctx, source)
180 if err != nil {
181 c.logger.Error("connection failed", "source", source, "err", err)
182 }
183
184 // apply jitter
185 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5))
186 delay := retryInterval + jitter
187
188 if retryInterval < c.cfg.MaxRetryInterval {
189 retryInterval *= 2
190 if retryInterval > c.cfg.MaxRetryInterval {
191 retryInterval = c.cfg.MaxRetryInterval
192 }
193 }
194 c.logger.Info("retrying connection", "source", source, "delay", delay)
195 select {
196 case <-time.After(delay):
197 case <-ctx.Done():
198 return
199 }
200 }
201 }
202}
203
204func (c *Consumer) runConnection(ctx context.Context, source Source) error {
205 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
206 defer cancel()
207
208 cursor := c.cfg.CursorStore.Get(source.Key())
209
210 u, err := source.Url(cursor, c.cfg.Dev)
211 if err != nil {
212 return err
213 }
214
215 c.logger.Info("connecting", "url", u.String())
216 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil)
217 if err != nil {
218 return err
219 }
220 defer conn.Close()
221 c.connMap.Store(source, conn)
222 defer c.connMap.Delete(source)
223
224 c.logger.Info("connected", "source", source)
225
226 for {
227 select {
228 case <-ctx.Done():
229 return nil
230 default:
231 msgType, msg, err := conn.ReadMessage()
232 if err != nil {
233 return err
234 }
235 if msgType != websocket.TextMessage {
236 continue
237 }
238 select {
239 case c.jobQueue <- job{source: source, message: msg}:
240 case <-ctx.Done():
241 return nil
242 }
243 }
244 }
245}