forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package knotclient 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "log/slog" 8 "math/rand" 9 "net/url" 10 "sync" 11 "time" 12 13 "tangled.sh/tangled.sh/core/knotclient/cursor" 14 "tangled.sh/tangled.sh/core/log" 15 16 "github.com/gorilla/websocket" 17) 18 19type ProcessFunc func(ctx context.Context, source EventSource, message Message) error 20 21type Message struct { 22 Rkey string 23 Nsid string 24 // do not full deserialize this portion of the message, processFunc can do that 25 EventJson json.RawMessage `json:"event"` 26} 27 28type ConsumerConfig struct { 29 Sources map[EventSource]struct{} 30 ProcessFunc ProcessFunc 31 RetryInterval time.Duration 32 MaxRetryInterval time.Duration 33 ConnectionTimeout time.Duration 34 WorkerCount int 35 QueueSize int 36 Logger *slog.Logger 37 Dev bool 38 CursorStore cursor.Store 39} 40 41func NewConsumerConfig() *ConsumerConfig { 42 return &ConsumerConfig{ 43 Sources: make(map[EventSource]struct{}), 44 } 45} 46 47type EventSource struct { 48 Knot string 49} 50 51func NewEventSource(knot string) EventSource { 52 return EventSource{ 53 Knot: knot, 54 } 55} 56 57type EventConsumer struct { 58 wg sync.WaitGroup 59 dialer *websocket.Dialer 60 connMap sync.Map 61 jobQueue chan job 62 logger *slog.Logger 63 randSource *rand.Rand 64 65 // rw lock over edits to ConsumerConfig 66 cfgMu sync.RWMutex 67 cfg ConsumerConfig 68} 69 70func (e *EventConsumer) buildUrl(s EventSource, cursor int64) (*url.URL, error) { 71 scheme := "wss" 72 if e.cfg.Dev { 73 scheme = "ws" 74 } 75 76 u, err := url.Parse(scheme + "://" + s.Knot + "/events") 77 if err != nil { 78 return nil, err 79 } 80 81 if cursor != 0 { 82 query := url.Values{} 83 query.Add("cursor", fmt.Sprintf("%d", cursor)) 84 u.RawQuery = query.Encode() 85 } 86 return u, nil 87} 88 89type job struct { 90 source EventSource 91 message []byte 92} 93 94func NewEventConsumer(cfg ConsumerConfig) *EventConsumer { 95 if cfg.RetryInterval == 0 { 96 cfg.RetryInterval = 15 * time.Minute 97 } 98 if cfg.ConnectionTimeout == 0 { 99 cfg.ConnectionTimeout = 10 * time.Second 100 } 101 if cfg.WorkerCount <= 0 { 102 cfg.WorkerCount = 5 103 } 104 if cfg.MaxRetryInterval == 0 { 105 cfg.MaxRetryInterval = 1 * time.Hour 106 } 107 if cfg.Logger == nil { 108 cfg.Logger = log.New("eventconsumer") 109 } 110 if cfg.QueueSize == 0 { 111 cfg.QueueSize = 100 112 } 113 if cfg.CursorStore == nil { 114 cfg.CursorStore = &cursor.MemoryStore{} 115 } 116 return &EventConsumer{ 117 cfg: cfg, 118 dialer: websocket.DefaultDialer, 119 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue 120 logger: cfg.Logger, 121 randSource: rand.New(rand.NewSource(time.Now().UnixNano())), 122 } 123} 124 125func (c *EventConsumer) Start(ctx context.Context) { 126 c.cfg.Logger.Info("starting consumer", "config", c.cfg) 127 128 // start workers 129 for range c.cfg.WorkerCount { 130 c.wg.Add(1) 131 go c.worker(ctx) 132 } 133 134 // start streaming 135 for source := range c.cfg.Sources { 136 c.wg.Add(1) 137 go c.startConnectionLoop(ctx, source) 138 } 139} 140 141func (c *EventConsumer) Stop() { 142 c.connMap.Range(func(_, val any) bool { 143 if conn, ok := val.(*websocket.Conn); ok { 144 conn.Close() 145 } 146 return true 147 }) 148 c.wg.Wait() 149 close(c.jobQueue) 150} 151 152func (c *EventConsumer) AddSource(ctx context.Context, s EventSource) { 153 c.cfgMu.Lock() 154 c.cfg.Sources[s] = struct{}{} 155 c.wg.Add(1) 156 go c.startConnectionLoop(ctx, s) 157 c.cfgMu.Unlock() 158} 159 160func (c *EventConsumer) worker(ctx context.Context) { 161 defer c.wg.Done() 162 for { 163 select { 164 case <-ctx.Done(): 165 return 166 case j, ok := <-c.jobQueue: 167 if !ok { 168 return 169 } 170 171 var msg Message 172 err := json.Unmarshal(j.message, &msg) 173 if err != nil { 174 c.logger.Error("error deserializing message", "source", j.source.Knot, "err", err) 175 return 176 } 177 178 // update cursor 179 c.cfg.CursorStore.Set(j.source.Knot, time.Now().UnixNano()) 180 181 if err := c.cfg.ProcessFunc(ctx, j.source, msg); err != nil { 182 c.logger.Error("error processing message", "source", j.source, "err", err) 183 } 184 } 185 } 186} 187 188func (c *EventConsumer) startConnectionLoop(ctx context.Context, source EventSource) { 189 defer c.wg.Done() 190 retryInterval := c.cfg.RetryInterval 191 for { 192 select { 193 case <-ctx.Done(): 194 return 195 default: 196 err := c.runConnection(ctx, source) 197 if err != nil { 198 c.logger.Error("connection failed", "source", source, "err", err) 199 } 200 201 // apply jitter 202 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5)) 203 delay := retryInterval + jitter 204 205 if retryInterval < c.cfg.MaxRetryInterval { 206 retryInterval *= 2 207 if retryInterval > c.cfg.MaxRetryInterval { 208 retryInterval = c.cfg.MaxRetryInterval 209 } 210 } 211 c.logger.Info("retrying connection", "source", source, "delay", delay) 212 select { 213 case <-time.After(delay): 214 case <-ctx.Done(): 215 return 216 } 217 } 218 } 219} 220 221func (c *EventConsumer) runConnection(ctx context.Context, source EventSource) error { 222 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout) 223 defer cancel() 224 225 cursor := c.cfg.CursorStore.Get(source.Knot) 226 227 u, err := c.buildUrl(source, cursor) 228 if err != nil { 229 return err 230 } 231 232 c.logger.Info("connecting", "url", u.String()) 233 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil) 234 if err != nil { 235 return err 236 } 237 defer conn.Close() 238 c.connMap.Store(source, conn) 239 defer c.connMap.Delete(source) 240 241 c.logger.Info("connected", "source", source) 242 243 for { 244 select { 245 case <-ctx.Done(): 246 return nil 247 default: 248 msgType, msg, err := conn.ReadMessage() 249 if err != nil { 250 return err 251 } 252 if msgType != websocket.TextMessage { 253 continue 254 } 255 select { 256 case c.jobQueue <- job{source: source, message: msg}: 257 case <-ctx.Done(): 258 return nil 259 } 260 } 261 } 262}