Forked from tangled.org/core (Monorepo for Tangled — https://tangled.org).
1package knotclient
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "math/rand"
9 "net/url"
10 "strconv"
11 "sync"
12 "time"
13
14 "tangled.sh/tangled.sh/core/appview/cache"
15 "tangled.sh/tangled.sh/core/log"
16
17 "github.com/gorilla/websocket"
18)
19
20type ProcessFunc func(ctx context.Context, source EventSource, message Message) error
21
// Message is the envelope of a single event frame from a knot's stream.
// Rkey and Nsid carry no struct tags, so encoding/json matches them against
// the incoming object keys by field name, case-insensitively.
type Message struct {
	Rkey, Nsid string

	// EventJson keeps the "event" payload undecoded; the ProcessFunc decides
	// how (and whether) to deserialize it.
	EventJson json.RawMessage `json:"event"`
}
28
29type ConsumerConfig struct {
30 Sources map[EventSource]struct{}
31 ProcessFunc ProcessFunc
32 RetryInterval time.Duration
33 MaxRetryInterval time.Duration
34 ConnectionTimeout time.Duration
35 WorkerCount int
36 QueueSize int
37 Logger *slog.Logger
38 Dev bool
39 CursorStore CursorStore
40}
41
42func NewConsumerConfig() *ConsumerConfig {
43 return &ConsumerConfig{
44 Sources: make(map[EventSource]struct{}),
45 }
46}
47
48func (cc *ConsumerConfig) AddEventSource(es EventSource) {
49 cc.Sources[es] = struct{}{}
50}
51
// EventSource identifies one knot whose event stream is consumed. Knot is
// placed between the URL scheme and "/events" when the stream URL is built.
type EventSource struct {
	Knot string
}

// NewEventSource returns the EventSource for knot.
func NewEventSource(knot string) EventSource {
	var es EventSource
	es.Knot = knot
	return es
}
61
62type EventConsumer struct {
63 wg sync.WaitGroup
64 dialer *websocket.Dialer
65 connMap sync.Map
66 jobQueue chan job
67 logger *slog.Logger
68 randSource *rand.Rand
69
70 // rw lock over edits to ConsumerConfig
71 cfgMu sync.RWMutex
72 cfg ConsumerConfig
73}
74
// CursorStore keeps, per knot, the cursor the consumer resumes a stream from.
// A cursor of 0 means "none": the consumer then connects without a cursor
// query parameter.
type CursorStore interface {
	Set(knot string, cursor int64)
	Get(knot string) int64
}
79
80type RedisCursorStore struct {
81 rdb *cache.Cache
82}
83
84func NewRedisCursorStore(cache *cache.Cache) RedisCursorStore {
85 return RedisCursorStore{
86 rdb: cache,
87 }
88}
89
// cursorKey is the fmt pattern for the cache key holding a knot's cursor.
const cursorKey = "cursor:%s"
93
94func (r *RedisCursorStore) Set(knot string, cursor int64) {
95 key := fmt.Sprintf(cursorKey, knot)
96 r.rdb.Set(context.Background(), key, cursor, 0)
97}
98
99func (r *RedisCursorStore) Get(knot string) (cursor int64) {
100 key := fmt.Sprintf(cursorKey, knot)
101 val, err := r.rdb.Get(context.Background(), key).Result()
102 if err != nil {
103 return 0
104 }
105
106 cursor, err = strconv.ParseInt(val, 10, 64)
107 if err != nil {
108 return 0 // optionally log parsing error
109 }
110
111 return cursor
112}
113
// MemoryCursorStore is a process-local CursorStore backed by a sync.Map, so
// the zero value is ready to use from multiple goroutines. Cursors do not
// survive a restart.
type MemoryCursorStore struct {
	store sync.Map
}

// Set records cursor as the current cursor for knot.
func (m *MemoryCursorStore) Set(knot string, cursor int64) {
	m.store.Store(knot, cursor)
}

// Get returns the cursor recorded for knot, or 0 if there is none.
func (m *MemoryCursorStore) Get(knot string) (cursor int64) {
	v, found := m.store.Load(knot)
	if !found {
		return 0
	}
	cursor, _ = v.(int64)
	return cursor
}
131
132func (e *EventConsumer) buildUrl(s EventSource, cursor int64) (*url.URL, error) {
133 scheme := "wss"
134 if e.cfg.Dev {
135 scheme = "ws"
136 }
137
138 u, err := url.Parse(scheme + "://" + s.Knot + "/events")
139 if err != nil {
140 return nil, err
141 }
142
143 if cursor != 0 {
144 query := url.Values{}
145 query.Add("cursor", fmt.Sprintf("%d", cursor))
146 u.RawQuery = query.Encode()
147 }
148 return u, nil
149}
150
151type job struct {
152 source EventSource
153 message []byte
154}
155
156func NewEventConsumer(cfg ConsumerConfig) *EventConsumer {
157 if cfg.RetryInterval == 0 {
158 cfg.RetryInterval = 15 * time.Minute
159 }
160 if cfg.ConnectionTimeout == 0 {
161 cfg.ConnectionTimeout = 10 * time.Second
162 }
163 if cfg.WorkerCount <= 0 {
164 cfg.WorkerCount = 5
165 }
166 if cfg.MaxRetryInterval == 0 {
167 cfg.MaxRetryInterval = 1 * time.Hour
168 }
169 if cfg.Logger == nil {
170 cfg.Logger = log.New("eventconsumer")
171 }
172 if cfg.QueueSize == 0 {
173 cfg.QueueSize = 100
174 }
175 if cfg.CursorStore == nil {
176 cfg.CursorStore = &MemoryCursorStore{}
177 }
178 return &EventConsumer{
179 cfg: cfg,
180 dialer: websocket.DefaultDialer,
181 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue
182 logger: cfg.Logger,
183 randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
184 }
185}
186
187func (c *EventConsumer) Start(ctx context.Context) {
188 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
189
190 // start workers
191 for range c.cfg.WorkerCount {
192 c.wg.Add(1)
193 go c.worker(ctx)
194 }
195
196 // start streaming
197 for source := range c.cfg.Sources {
198 c.wg.Add(1)
199 go c.startConnectionLoop(ctx, source)
200 }
201}
202
203func (c *EventConsumer) Stop() {
204 c.connMap.Range(func(_, val any) bool {
205 if conn, ok := val.(*websocket.Conn); ok {
206 conn.Close()
207 }
208 return true
209 })
210 c.wg.Wait()
211 close(c.jobQueue)
212}
213
214func (c *EventConsumer) AddSource(ctx context.Context, s EventSource) {
215 c.cfgMu.Lock()
216 c.cfg.Sources[s] = struct{}{}
217 c.wg.Add(1)
218 go c.startConnectionLoop(ctx, s)
219 c.cfgMu.Unlock()
220}
221
222func (c *EventConsumer) worker(ctx context.Context) {
223 defer c.wg.Done()
224 for {
225 select {
226 case <-ctx.Done():
227 return
228 case j, ok := <-c.jobQueue:
229 if !ok {
230 return
231 }
232
233 var msg Message
234 err := json.Unmarshal(j.message, &msg)
235 if err != nil {
236 c.logger.Error("error deserializing message", "source", j.source.Knot, "err", err)
237 return
238 }
239
240 // update cursor
241 c.cfg.CursorStore.Set(j.source.Knot, time.Now().Unix())
242
243 if err := c.cfg.ProcessFunc(ctx, j.source, msg); err != nil {
244 c.logger.Error("error processing message", "source", j.source, "err", err)
245 }
246 }
247 }
248}
249
250func (c *EventConsumer) startConnectionLoop(ctx context.Context, source EventSource) {
251 defer c.wg.Done()
252 retryInterval := c.cfg.RetryInterval
253 for {
254 select {
255 case <-ctx.Done():
256 return
257 default:
258 err := c.runConnection(ctx, source)
259 if err != nil {
260 c.logger.Error("connection failed", "source", source, "err", err)
261 }
262
263 // apply jitter
264 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5))
265 delay := retryInterval + jitter
266
267 if retryInterval < c.cfg.MaxRetryInterval {
268 retryInterval *= 2
269 if retryInterval > c.cfg.MaxRetryInterval {
270 retryInterval = c.cfg.MaxRetryInterval
271 }
272 }
273 c.logger.Info("retrying connection", "source", source, "delay", delay)
274 select {
275 case <-time.After(delay):
276 case <-ctx.Done():
277 return
278 }
279 }
280 }
281}
282
283func (c *EventConsumer) runConnection(ctx context.Context, source EventSource) error {
284 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
285 defer cancel()
286
287 cursor := c.cfg.CursorStore.Get(source.Knot)
288
289 u, err := c.buildUrl(source, cursor)
290 if err != nil {
291 return err
292 }
293
294 c.logger.Info("connecting", "url", u.String())
295 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil)
296 if err != nil {
297 return err
298 }
299 defer conn.Close()
300 c.connMap.Store(source, conn)
301 defer c.connMap.Delete(source)
302
303 c.logger.Info("connected", "source", source)
304
305 for {
306 select {
307 case <-ctx.Done():
308 return nil
309 default:
310 msgType, msg, err := conn.ReadMessage()
311 if err != nil {
312 return err
313 }
314 if msgType != websocket.TextMessage {
315 continue
316 }
317 select {
318 case c.jobQueue <- job{source: source, message: msg}:
319 case <-ctx.Done():
320 return nil
321 }
322 }
323 }
324}