1package knotclient
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "math/rand"
9 "net/url"
10 "sync"
11 "time"
12
13 "tangled.sh/tangled.sh/core/appview/cache"
14 "tangled.sh/tangled.sh/core/log"
15
16 "github.com/gorilla/websocket"
17)
18
19type ProcessFunc func(source EventSource, message Message) error
20
21type Message struct {
22 Rkey string
23 Nsid string
24 // do not full deserialize this portion of the message, processFunc can do that
25 EventJson json.RawMessage `json:"event"`
26}
27
// ConsumerConfig configures an EventConsumer. NewEventConsumer fills in a
// default for RetryInterval, MaxRetryInterval, ConnectionTimeout, WorkerCount,
// QueueSize, Logger and CursorStore when they are left unset.
type ConsumerConfig struct {
	// Sources is the initial set of knots to stream from; more can be added
	// at runtime with AddSource.
	Sources map[EventSource]struct{}
	// ProcessFunc is called by a worker for every decoded message. Errors it
	// returns are logged and the message is not retried.
	ProcessFunc ProcessFunc
	// RetryInterval is the initial reconnect delay (default 15m). After each
	// connection attempt it doubles, up to MaxRetryInterval, and each wait
	// adds random jitter of less than 20% of the current interval.
	RetryInterval time.Duration
	// MaxRetryInterval caps the growth of the reconnect delay (default 1h).
	MaxRetryInterval time.Duration
	// ConnectionTimeout bounds the websocket dial (default 10s).
	ConnectionTimeout time.Duration
	// WorkerCount is the number of goroutines running ProcessFunc (default 5).
	WorkerCount int
	// QueueSize is the capacity of the buffered queue between connection
	// loops and workers (default 100).
	QueueSize int
	// Logger receives consumer diagnostics (default log.New("eventconsumer")).
	Logger *slog.Logger
	// Dev makes the consumer connect with ws:// instead of wss://.
	Dev bool
	// CursorStore records each knot's most recent rkey so a reconnect can
	// resume from it (default: an in-memory MemoryCursorStore).
	CursorStore CursorStore
}
40
// EventSource identifies one knot to stream events from. It is comparable and
// is used as the key for both the consumer's source set and its live
// connections.
type EventSource struct {
	Knot string
}

// NewEventSource returns the EventSource for knot, which is used as the host
// part of the knot's websocket URL.
func NewEventSource(knot string) EventSource {
	var src EventSource
	src.Knot = knot
	return src
}
50
// EventConsumer streams events from a set of knots over websockets and hands
// each text frame to a pool of workers, which decode it and call
// ConsumerConfig.ProcessFunc. Construct it with NewEventConsumer.
type EventConsumer struct {
	cfg ConsumerConfig
	// wg counts running worker and connection-loop goroutines; Stop waits on it.
	wg     sync.WaitGroup
	dialer *websocket.Dialer
	// connMap maps each EventSource to its live *websocket.Conn so that Stop
	// can close it.
	connMap sync.Map
	// jobQueue is the buffered hand-off from connection loops to workers.
	jobQueue chan job
	logger   *slog.Logger
	// randSource is a PRNG seeded from the clock in NewEventConsumer.
	randSource *rand.Rand

	// rw lock over edits to consumer config
	mu sync.RWMutex
}
63
// CursorStore persists, per knot, the cursor used to resume that knot's event
// stream after a reconnect. Workers record each message's rkey via Set.
type CursorStore interface {
	// Set records cursor as the latest cursor for knot.
	Set(knot, cursor string)
	// Get returns the stored cursor for knot. An empty string means "no
	// cursor", and the stream is then opened without a cursor parameter.
	Get(knot string) (cursor string)
}
68
69type RedisCursorStore struct {
70 rdb *cache.Cache
71}
72
73func NewRedisCursorStore(cache *cache.Cache) RedisCursorStore {
74 return RedisCursorStore{
75 rdb: cache,
76 }
77}
78
79const (
80 cursorKey = "cursor:%s"
81)
82
83func (r *RedisCursorStore) Set(knot, cursor string) {
84 key := fmt.Sprintf(cursorKey, knot)
85 r.rdb.Set(context.Background(), key, cursor, 0)
86}
87
88func (r *RedisCursorStore) Get(knot string) (cursor string) {
89 key := fmt.Sprintf(cursorKey, knot)
90 val, err := r.rdb.Get(context.Background(), key).Result()
91 if err != nil {
92 return ""
93 }
94
95 return val
96}
97
// MemoryCursorStore is a process-local CursorStore backed by a sync.Map. Its
// zero value is ready to use; cursors are lost when the process exits.
type MemoryCursorStore struct {
	store sync.Map
}

// Set records cursor as the latest cursor for knot, replacing any previous one.
func (m *MemoryCursorStore) Set(knot, cursor string) {
	m.store.Store(knot, cursor)
}

// Get returns the cursor stored for knot, or "" if none has been stored.
func (m *MemoryCursorStore) Get(knot string) (cursor string) {
	v, found := m.store.Load(knot)
	if !found {
		return ""
	}
	// A non-string value leaves cursor at its zero value, "".
	cursor, _ = v.(string)
	return cursor
}
115
116func (e *EventConsumer) buildUrl(s EventSource, cursor string) (*url.URL, error) {
117 scheme := "wss"
118 if e.cfg.Dev {
119 scheme = "ws"
120 }
121
122 u, err := url.Parse(scheme + "://" + s.Knot + "/events")
123 if err != nil {
124 return nil, err
125 }
126
127 if cursor != "" {
128 query := url.Values{}
129 query.Add("cursor", cursor)
130 u.RawQuery = query.Encode()
131 }
132 return u, nil
133}
134
// job is one raw websocket text frame read from source, queued for a worker
// to decode and process.
type job struct {
	source  EventSource
	message []byte
}
139
140func NewEventConsumer(cfg ConsumerConfig) *EventConsumer {
141 if cfg.RetryInterval == 0 {
142 cfg.RetryInterval = 15 * time.Minute
143 }
144 if cfg.ConnectionTimeout == 0 {
145 cfg.ConnectionTimeout = 10 * time.Second
146 }
147 if cfg.WorkerCount <= 0 {
148 cfg.WorkerCount = 5
149 }
150 if cfg.MaxRetryInterval == 0 {
151 cfg.MaxRetryInterval = 1 * time.Hour
152 }
153 if cfg.Logger == nil {
154 cfg.Logger = log.New("eventconsumer")
155 }
156 if cfg.QueueSize == 0 {
157 cfg.QueueSize = 100
158 }
159 if cfg.CursorStore == nil {
160 cfg.CursorStore = &MemoryCursorStore{}
161 }
162 return &EventConsumer{
163 cfg: cfg,
164 dialer: websocket.DefaultDialer,
165 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue
166 logger: cfg.Logger,
167 randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
168 }
169}
170
171func (c *EventConsumer) Start(ctx context.Context) {
172 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
173
174 // start workers
175 for range c.cfg.WorkerCount {
176 c.wg.Add(1)
177 go c.worker(ctx)
178 }
179
180 // start streaming
181 for source := range c.cfg.Sources {
182 c.wg.Add(1)
183 go c.startConnectionLoop(ctx, source)
184 }
185}
186
187func (c *EventConsumer) Stop() {
188 c.connMap.Range(func(_, val any) bool {
189 if conn, ok := val.(*websocket.Conn); ok {
190 conn.Close()
191 }
192 return true
193 })
194 c.wg.Wait()
195 close(c.jobQueue)
196}
197
198func (c *EventConsumer) AddSource(ctx context.Context, s EventSource) {
199 c.mu.Lock()
200 c.cfg.Sources[s] = struct{}{}
201 c.wg.Add(1)
202 go c.startConnectionLoop(ctx, s)
203 c.mu.Unlock()
204}
205
206func (c *EventConsumer) worker(ctx context.Context) {
207 defer c.wg.Done()
208 for {
209 select {
210 case <-ctx.Done():
211 return
212 case j, ok := <-c.jobQueue:
213 if !ok {
214 return
215 }
216
217 var msg Message
218 err := json.Unmarshal(j.message, &msg)
219 if err != nil {
220 c.logger.Error("error deserializing message", "source", j.source.Knot, "err", err)
221 return
222 }
223
224 // update cursor
225 c.cfg.CursorStore.Set(j.source.Knot, msg.Rkey)
226
227 if err := c.cfg.ProcessFunc(j.source, msg); err != nil {
228 c.logger.Error("error processing message", "source", j.source, "err", err)
229 }
230 }
231 }
232}
233
234func (c *EventConsumer) startConnectionLoop(ctx context.Context, source EventSource) {
235 defer c.wg.Done()
236 retryInterval := c.cfg.RetryInterval
237 for {
238 select {
239 case <-ctx.Done():
240 return
241 default:
242 err := c.runConnection(ctx, source)
243 if err != nil {
244 c.logger.Error("connection failed", "source", source, "err", err)
245 }
246
247 // apply jitter
248 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5))
249 delay := retryInterval + jitter
250
251 if retryInterval < c.cfg.MaxRetryInterval {
252 retryInterval *= 2
253 if retryInterval > c.cfg.MaxRetryInterval {
254 retryInterval = c.cfg.MaxRetryInterval
255 }
256 }
257 c.logger.Info("retrying connection", "source", source, "delay", delay)
258 select {
259 case <-time.After(delay):
260 case <-ctx.Done():
261 return
262 }
263 }
264 }
265}
266
267func (c *EventConsumer) runConnection(ctx context.Context, source EventSource) error {
268 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
269 defer cancel()
270
271 cursor := c.cfg.CursorStore.Get(source.Knot)
272
273 u, err := c.buildUrl(source, cursor)
274 if err != nil {
275 return err
276 }
277
278 c.logger.Info("connecting", "url", u.String())
279 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil)
280 if err != nil {
281 return err
282 }
283 defer conn.Close()
284 c.connMap.Store(source, conn)
285 defer c.connMap.Delete(source)
286
287 c.logger.Info("connected", "source", source)
288
289 for {
290 select {
291 case <-ctx.Done():
292 return nil
293 default:
294 msgType, msg, err := conn.ReadMessage()
295 if err != nil {
296 return err
297 }
298 if msgType != websocket.TextMessage {
299 continue
300 }
301 select {
302 case c.jobQueue <- job{source: source, message: msg}:
303 case <-ctx.Done():
304 return nil
305 }
306 }
307 }
308}