1package knotclient
2
3import (
4 "context"
5 "log/slog"
6 "math/rand"
7 "net/url"
8 "sync"
9 "time"
10
11 "tangled.sh/tangled.sh/core/log"
12
13 "github.com/gorilla/websocket"
14)
15
// ProcessFunc handles one message received from source. It is called by the
// consumer's workers; a returned error is logged and the message is dropped.
type ProcessFunc func(source string, message []byte) error
17
18type ConsumerConfig struct {
19 Sources []string
20 ProcessFunc ProcessFunc
21 RetryInterval time.Duration
22 MaxRetryInterval time.Duration
23 ConnectionTimeout time.Duration
24 WorkerCount int
25 QueueSize int
26 Logger *slog.Logger
27}
28
29type EventConsumer struct {
30 cfg ConsumerConfig
31 wg sync.WaitGroup
32 dialer *websocket.Dialer
33 connMap sync.Map
34 jobQueue chan job
35 logger *slog.Logger
36 randSource *rand.Rand
37}
38
// job is one unit of work for a worker: a raw message together with the
// source it was received from.
type job struct {
	// source is the source string the message arrived on.
	source string
	// message is the payload exactly as read from the websocket.
	message []byte
}
43
44func NewEventConsumer(cfg ConsumerConfig) *EventConsumer {
45 if cfg.RetryInterval == 0 {
46 cfg.RetryInterval = 15 * time.Minute
47 }
48 if cfg.ConnectionTimeout == 0 {
49 cfg.ConnectionTimeout = 10 * time.Second
50 }
51 if cfg.WorkerCount <= 0 {
52 cfg.WorkerCount = 5
53 }
54 if cfg.MaxRetryInterval == 0 {
55 cfg.MaxRetryInterval = 1 * time.Hour
56 }
57 if cfg.Logger == nil {
58 cfg.Logger = log.New("eventconsumer")
59 }
60 if cfg.QueueSize == 0 {
61 cfg.QueueSize = 100
62 }
63 return &EventConsumer{
64 cfg: cfg,
65 dialer: websocket.DefaultDialer,
66 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue
67 logger: cfg.Logger,
68 randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
69 }
70}
71
72func (c *EventConsumer) Start(ctx context.Context) {
73 // start workers
74 for range c.cfg.WorkerCount {
75 c.wg.Add(1)
76 go c.worker(ctx)
77 }
78
79 // start streaming
80 for _, source := range c.cfg.Sources {
81 c.wg.Add(1)
82 go c.startConnectionLoop(ctx, source)
83 }
84}
85
86func (c *EventConsumer) Stop() {
87 c.connMap.Range(func(_, val any) bool {
88 if conn, ok := val.(*websocket.Conn); ok {
89 conn.Close()
90 }
91 return true
92 })
93 c.wg.Wait()
94 close(c.jobQueue)
95}
96
97func (c *EventConsumer) worker(ctx context.Context) {
98 defer c.wg.Done()
99 for {
100 select {
101 case <-ctx.Done():
102 return
103 case j, ok := <-c.jobQueue:
104 if !ok {
105 return
106 }
107 if err := c.cfg.ProcessFunc(j.source, j.message); err != nil {
108 c.logger.Error("error processing message", "source", j.source, "err", err)
109 }
110 }
111 }
112}
113
114func (c *EventConsumer) startConnectionLoop(ctx context.Context, source string) {
115 defer c.wg.Done()
116 retryInterval := c.cfg.RetryInterval
117 for {
118 select {
119 case <-ctx.Done():
120 return
121 default:
122 err := c.runConnection(ctx, source)
123 if err != nil {
124 c.logger.Error("connection failed", "source", source, "err", err)
125 }
126
127 // apply jitter
128 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5))
129 delay := retryInterval + jitter
130
131 if retryInterval < c.cfg.MaxRetryInterval {
132 retryInterval *= 2
133 if retryInterval > c.cfg.MaxRetryInterval {
134 retryInterval = c.cfg.MaxRetryInterval
135 }
136 }
137 c.logger.Info("retrying connection", "source", source, "delay", delay)
138 select {
139 case <-time.After(delay):
140 case <-ctx.Done():
141 return
142 }
143 }
144 }
145}
146
147func (c *EventConsumer) runConnection(ctx context.Context, source string) error {
148 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
149 defer cancel()
150
151 u, err := url.Parse(source)
152 if err != nil {
153 return err
154 }
155
156 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil)
157 if err != nil {
158 return err
159 }
160 defer conn.Close()
161 c.connMap.Store(source, conn)
162 defer c.connMap.Delete(source)
163
164 c.logger.Info("connected", "source", source)
165
166 for {
167 select {
168 case <-ctx.Done():
169 return nil
170 default:
171 msgType, msg, err := conn.ReadMessage()
172 if err != nil {
173 return err
174 }
175 if msgType != websocket.TextMessage {
176 continue
177 }
178 select {
179 case c.jobQueue <- job{source: source, message: msg}:
180 case <-ctx.Done():
181 return nil
182 }
183 }
184 }
185}