1package knotclient
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "math/rand"
9 "net/url"
10 "sync"
11 "time"
12
13 "tangled.sh/tangled.sh/core/knotclient/cursor"
14 "tangled.sh/tangled.sh/core/log"
15
16 "github.com/gorilla/websocket"
17)
18
19type ProcessFunc func(ctx context.Context, source EventSource, message Message) error
20
// Message is the envelope that workers decode from each queued websocket
// text message.
type Message struct {
	// Rkey and Nsid carry no JSON tags, so they are filled from the keys that
	// encoding/json matches to the field names.
	// NOTE(review): presumably the record key and NSID of the event; confirm
	// against the knot's event schema.
	Rkey string
	Nsid string
	// EventJson is the raw "event" payload. It is intentionally left
	// undecoded here; the ProcessFunc is expected to unmarshal it into
	// whatever type it needs.
	EventJson json.RawMessage `json:"event"`
}
27
28type ConsumerConfig struct {
29 Sources map[EventSource]struct{}
30 ProcessFunc ProcessFunc
31 RetryInterval time.Duration
32 MaxRetryInterval time.Duration
33 ConnectionTimeout time.Duration
34 WorkerCount int
35 QueueSize int
36 Logger *slog.Logger
37 Dev bool
38 CursorStore cursor.Store
39}
40
41func NewConsumerConfig() *ConsumerConfig {
42 return &ConsumerConfig{
43 Sources: make(map[EventSource]struct{}),
44 }
45}
46
// EventSource identifies a knot to stream events from. It is a comparable
// value and is used as a map key throughout the consumer.
type EventSource struct {
	// Knot is the host part of the knot's event endpoint URL.
	Knot string
}
50
51func NewEventSource(knot string) EventSource {
52 return EventSource{
53 Knot: knot,
54 }
55}
56
57type EventConsumer struct {
58 wg sync.WaitGroup
59 dialer *websocket.Dialer
60 connMap sync.Map
61 jobQueue chan job
62 logger *slog.Logger
63 randSource *rand.Rand
64
65 // rw lock over edits to ConsumerConfig
66 cfgMu sync.RWMutex
67 cfg ConsumerConfig
68}
69
70func (e *EventConsumer) buildUrl(s EventSource, cursor int64) (*url.URL, error) {
71 scheme := "wss"
72 if e.cfg.Dev {
73 scheme = "ws"
74 }
75
76 u, err := url.Parse(scheme + "://" + s.Knot + "/events")
77 if err != nil {
78 return nil, err
79 }
80
81 if cursor != 0 {
82 query := url.Values{}
83 query.Add("cursor", fmt.Sprintf("%d", cursor))
84 u.RawQuery = query.Encode()
85 }
86 return u, nil
87}
88
89type job struct {
90 source EventSource
91 message []byte
92}
93
94func NewEventConsumer(cfg ConsumerConfig) *EventConsumer {
95 if cfg.RetryInterval == 0 {
96 cfg.RetryInterval = 15 * time.Minute
97 }
98 if cfg.ConnectionTimeout == 0 {
99 cfg.ConnectionTimeout = 10 * time.Second
100 }
101 if cfg.WorkerCount <= 0 {
102 cfg.WorkerCount = 5
103 }
104 if cfg.MaxRetryInterval == 0 {
105 cfg.MaxRetryInterval = 1 * time.Hour
106 }
107 if cfg.Logger == nil {
108 cfg.Logger = log.New("eventconsumer")
109 }
110 if cfg.QueueSize == 0 {
111 cfg.QueueSize = 100
112 }
113 if cfg.CursorStore == nil {
114 cfg.CursorStore = &cursor.MemoryStore{}
115 }
116 return &EventConsumer{
117 cfg: cfg,
118 dialer: websocket.DefaultDialer,
119 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue
120 logger: cfg.Logger,
121 randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
122 }
123}
124
125func (c *EventConsumer) Start(ctx context.Context) {
126 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
127
128 // start workers
129 for range c.cfg.WorkerCount {
130 c.wg.Add(1)
131 go c.worker(ctx)
132 }
133
134 // start streaming
135 for source := range c.cfg.Sources {
136 c.wg.Add(1)
137 go c.startConnectionLoop(ctx, source)
138 }
139}
140
141func (c *EventConsumer) Stop() {
142 c.connMap.Range(func(_, val any) bool {
143 if conn, ok := val.(*websocket.Conn); ok {
144 conn.Close()
145 }
146 return true
147 })
148 c.wg.Wait()
149 close(c.jobQueue)
150}
151
152func (c *EventConsumer) AddSource(ctx context.Context, s EventSource) {
153 c.cfgMu.Lock()
154 c.cfg.Sources[s] = struct{}{}
155 c.wg.Add(1)
156 go c.startConnectionLoop(ctx, s)
157 c.cfgMu.Unlock()
158}
159
160func (c *EventConsumer) worker(ctx context.Context) {
161 defer c.wg.Done()
162 for {
163 select {
164 case <-ctx.Done():
165 return
166 case j, ok := <-c.jobQueue:
167 if !ok {
168 return
169 }
170
171 var msg Message
172 err := json.Unmarshal(j.message, &msg)
173 if err != nil {
174 c.logger.Error("error deserializing message", "source", j.source.Knot, "err", err)
175 return
176 }
177
178 // update cursor
179 c.cfg.CursorStore.Set(j.source.Knot, time.Now().UnixNano())
180
181 if err := c.cfg.ProcessFunc(ctx, j.source, msg); err != nil {
182 c.logger.Error("error processing message", "source", j.source, "err", err)
183 }
184 }
185 }
186}
187
188func (c *EventConsumer) startConnectionLoop(ctx context.Context, source EventSource) {
189 defer c.wg.Done()
190 retryInterval := c.cfg.RetryInterval
191 for {
192 select {
193 case <-ctx.Done():
194 return
195 default:
196 err := c.runConnection(ctx, source)
197 if err != nil {
198 c.logger.Error("connection failed", "source", source, "err", err)
199 }
200
201 // apply jitter
202 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5))
203 delay := retryInterval + jitter
204
205 if retryInterval < c.cfg.MaxRetryInterval {
206 retryInterval *= 2
207 if retryInterval > c.cfg.MaxRetryInterval {
208 retryInterval = c.cfg.MaxRetryInterval
209 }
210 }
211 c.logger.Info("retrying connection", "source", source, "delay", delay)
212 select {
213 case <-time.After(delay):
214 case <-ctx.Done():
215 return
216 }
217 }
218 }
219}
220
221func (c *EventConsumer) runConnection(ctx context.Context, source EventSource) error {
222 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
223 defer cancel()
224
225 cursor := c.cfg.CursorStore.Get(source.Knot)
226
227 u, err := c.buildUrl(source, cursor)
228 if err != nil {
229 return err
230 }
231
232 c.logger.Info("connecting", "url", u.String())
233 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil)
234 if err != nil {
235 return err
236 }
237 defer conn.Close()
238 c.connMap.Store(source, conn)
239 defer c.connMap.Delete(source)
240
241 c.logger.Info("connected", "source", source)
242
243 for {
244 select {
245 case <-ctx.Done():
246 return nil
247 default:
248 msgType, msg, err := conn.ReadMessage()
249 if err != nil {
250 return err
251 }
252 if msgType != websocket.TextMessage {
253 continue
254 }
255 select {
256 case c.jobQueue <- job{source: source, message: msg}:
257 case <-ctx.Done():
258 return nil
259 }
260 }
261 }
262}