1package knotclient
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "math/rand"
9 "net/url"
10 "sync"
11 "time"
12
13 "tangled.sh/tangled.sh/core/knotclient/cursor"
14 "tangled.sh/tangled.sh/core/log"
15
16 "github.com/gorilla/websocket"
17)
18
// ProcessFunc is the callback invoked by the consumer's workers for every
// successfully decoded Message. source identifies the knot the message was
// read from. A non-nil error is logged by the worker; it does not stop the
// worker.
type ProcessFunc func(ctx context.Context, source EventSource, message Message) error
20
// Message is the envelope each queued websocket text frame is unmarshalled
// into before it is handed to the ProcessFunc.
type Message struct {
	// Rkey and Nsid have no JSON tags, so encoding/json matches them against
	// object keys by field name, case-insensitively.
	// NOTE(review): presumably the record key and namespace ID of the event —
	// confirm against the knot's event schema.
	Rkey string
	Nsid string
	// EventJson keeps the raw bytes of the "event" key. It is deliberately not
	// decoded here; the ProcessFunc unmarshals it into whatever type it expects.
	EventJson json.RawMessage `json:"event"`
}
27
// ConsumerConfig holds the settings for an EventConsumer. NewEventConsumer
// fills in defaults for the fields noted below when they are left unset.
type ConsumerConfig struct {
	// Sources is the set of knots to stream from; Start launches one
	// connection loop per entry.
	Sources map[EventSource]struct{}
	// ProcessFunc is called by the workers for every decoded message.
	ProcessFunc ProcessFunc
	// RetryInterval is the initial delay before reconnecting; the connection
	// loop doubles it after each attempt. Defaults to 15 minutes when left zero.
	RetryInterval time.Duration
	// MaxRetryInterval caps the doubled retry interval. Defaults to 1 hour
	// when left zero.
	MaxRetryInterval time.Duration
	// ConnectionTimeout is the timeout applied to the context used for
	// dialing. Defaults to 10 seconds when left zero.
	ConnectionTimeout time.Duration
	// WorkerCount is the number of worker goroutines started by Start.
	// Defaults to 5 when not positive.
	WorkerCount int
	// QueueSize is the capacity of the buffered job queue. Defaults to 100
	// when left zero.
	QueueSize int
	// Logger receives the consumer's log output. Defaults to
	// log.New("eventconsumer") when nil.
	Logger *slog.Logger
	// Dev selects the "ws" URL scheme instead of "wss".
	Dev bool
	// CursorStore persists the per-knot cursor that is sent when connecting.
	// Defaults to an in-memory cursor.MemoryStore when nil.
	CursorStore cursor.Store
}
40
41func NewConsumerConfig() *ConsumerConfig {
42 return &ConsumerConfig{
43 Sources: make(map[EventSource]struct{}),
44 }
45}
46
47func (cc *ConsumerConfig) AddEventSource(es EventSource) {
48 cc.Sources[es] = struct{}{}
49}
50
// EventSource identifies one knot to consume events from. It is a plain
// comparable value and is used as the key of ConsumerConfig.Sources and of
// the consumer's connection map.
type EventSource struct {
	// Knot is inserted verbatim between "<scheme>://" and "/events" when the
	// websocket URL is built.
	Knot string
}
54
55func NewEventSource(knot string) EventSource {
56 return EventSource{
57 Knot: knot,
58 }
59}
60
// EventConsumer streams events from a set of knots over websockets and fans
// the received messages out to a pool of workers through a buffered queue.
// Create one with NewEventConsumer.
type EventConsumer struct {
	// wg counts the worker and connection-loop goroutines; Stop waits on it.
	wg sync.WaitGroup
	// dialer opens the websocket connections.
	dialer *websocket.Dialer
	// connMap maps an EventSource to its live *websocket.Conn so that Stop
	// can close every open connection.
	connMap sync.Map
	// jobQueue carries raw messages from the connection readers to the workers.
	jobQueue chan job
	logger   *slog.Logger
	// randSource is seeded from the current time by NewEventConsumer. A
	// *rand.Rand created with rand.New is not safe for concurrent use.
	randSource *rand.Rand

	// cfgMu guards mutations of cfg; AddSource holds the write lock while it
	// adds to cfg.Sources.
	cfgMu sync.RWMutex
	cfg   ConsumerConfig
}
73
74func (e *EventConsumer) buildUrl(s EventSource, cursor int64) (*url.URL, error) {
75 scheme := "wss"
76 if e.cfg.Dev {
77 scheme = "ws"
78 }
79
80 u, err := url.Parse(scheme + "://" + s.Knot + "/events")
81 if err != nil {
82 return nil, err
83 }
84
85 if cursor != 0 {
86 query := url.Values{}
87 query.Add("cursor", fmt.Sprintf("%d", cursor))
88 u.RawQuery = query.Encode()
89 }
90 return u, nil
91}
92
// job is one unit of work on the consumer's queue: the raw bytes of a
// websocket text message together with the source it was read from.
type job struct {
	source  EventSource
	message []byte
}
97
98func NewEventConsumer(cfg ConsumerConfig) *EventConsumer {
99 if cfg.RetryInterval == 0 {
100 cfg.RetryInterval = 15 * time.Minute
101 }
102 if cfg.ConnectionTimeout == 0 {
103 cfg.ConnectionTimeout = 10 * time.Second
104 }
105 if cfg.WorkerCount <= 0 {
106 cfg.WorkerCount = 5
107 }
108 if cfg.MaxRetryInterval == 0 {
109 cfg.MaxRetryInterval = 1 * time.Hour
110 }
111 if cfg.Logger == nil {
112 cfg.Logger = log.New("eventconsumer")
113 }
114 if cfg.QueueSize == 0 {
115 cfg.QueueSize = 100
116 }
117 if cfg.CursorStore == nil {
118 cfg.CursorStore = &cursor.MemoryStore{}
119 }
120 return &EventConsumer{
121 cfg: cfg,
122 dialer: websocket.DefaultDialer,
123 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue
124 logger: cfg.Logger,
125 randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
126 }
127}
128
129func (c *EventConsumer) Start(ctx context.Context) {
130 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
131
132 // start workers
133 for range c.cfg.WorkerCount {
134 c.wg.Add(1)
135 go c.worker(ctx)
136 }
137
138 // start streaming
139 for source := range c.cfg.Sources {
140 c.wg.Add(1)
141 go c.startConnectionLoop(ctx, source)
142 }
143}
144
145func (c *EventConsumer) Stop() {
146 c.connMap.Range(func(_, val any) bool {
147 if conn, ok := val.(*websocket.Conn); ok {
148 conn.Close()
149 }
150 return true
151 })
152 c.wg.Wait()
153 close(c.jobQueue)
154}
155
156func (c *EventConsumer) AddSource(ctx context.Context, s EventSource) {
157 c.cfgMu.Lock()
158 c.cfg.Sources[s] = struct{}{}
159 c.wg.Add(1)
160 go c.startConnectionLoop(ctx, s)
161 c.cfgMu.Unlock()
162}
163
164func (c *EventConsumer) worker(ctx context.Context) {
165 defer c.wg.Done()
166 for {
167 select {
168 case <-ctx.Done():
169 return
170 case j, ok := <-c.jobQueue:
171 if !ok {
172 return
173 }
174
175 var msg Message
176 err := json.Unmarshal(j.message, &msg)
177 if err != nil {
178 c.logger.Error("error deserializing message", "source", j.source.Knot, "err", err)
179 return
180 }
181
182 // update cursor
183 c.cfg.CursorStore.Set(j.source.Knot, time.Now().UnixNano())
184
185 if err := c.cfg.ProcessFunc(ctx, j.source, msg); err != nil {
186 c.logger.Error("error processing message", "source", j.source, "err", err)
187 }
188 }
189 }
190}
191
192func (c *EventConsumer) startConnectionLoop(ctx context.Context, source EventSource) {
193 defer c.wg.Done()
194 retryInterval := c.cfg.RetryInterval
195 for {
196 select {
197 case <-ctx.Done():
198 return
199 default:
200 err := c.runConnection(ctx, source)
201 if err != nil {
202 c.logger.Error("connection failed", "source", source, "err", err)
203 }
204
205 // apply jitter
206 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5))
207 delay := retryInterval + jitter
208
209 if retryInterval < c.cfg.MaxRetryInterval {
210 retryInterval *= 2
211 if retryInterval > c.cfg.MaxRetryInterval {
212 retryInterval = c.cfg.MaxRetryInterval
213 }
214 }
215 c.logger.Info("retrying connection", "source", source, "delay", delay)
216 select {
217 case <-time.After(delay):
218 case <-ctx.Done():
219 return
220 }
221 }
222 }
223}
224
225func (c *EventConsumer) runConnection(ctx context.Context, source EventSource) error {
226 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
227 defer cancel()
228
229 cursor := c.cfg.CursorStore.Get(source.Knot)
230
231 u, err := c.buildUrl(source, cursor)
232 if err != nil {
233 return err
234 }
235
236 c.logger.Info("connecting", "url", u.String())
237 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil)
238 if err != nil {
239 return err
240 }
241 defer conn.Close()
242 c.connMap.Store(source, conn)
243 defer c.connMap.Delete(source)
244
245 c.logger.Info("connected", "source", source)
246
247 for {
248 select {
249 case <-ctx.Done():
250 return nil
251 default:
252 msgType, msg, err := conn.ReadMessage()
253 if err != nil {
254 return err
255 }
256 if msgType != websocket.TextMessage {
257 continue
258 }
259 select {
260 case c.jobQueue <- job{source: source, message: msg}:
261 case <-ctx.Done():
262 return nil
263 }
264 }
265 }
266}