It's for when you want notifications whenever someone likes a post through one of your reposts.
1package main
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "io"
8 "log"
9 "log/slog"
10 "net/http"
11 "sync/atomic"
12 "time"
13
14 "github.com/bluesky-social/indigo/api/bsky"
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 "github.com/bluesky-social/indigo/xrpc"
17 "github.com/bluesky-social/jetstream/pkg/models"
18 "github.com/cornelk/hashmap"
19 "github.com/google/uuid"
20 "github.com/gorilla/mux"
21 "github.com/gorilla/websocket"
22)
23
// Set is a set of comparable values, represented as a map with zero-size values.
type Set[T comparable] map[T]struct{}
25
// Values accepted by the "listenTo" query parameter of /subscribe/{did}.
const (
	// ListenTypeNone starts with an empty listen set that the client manages
	// itself by sending "update_listen_to" messages.
	ListenTypeNone = "none"
	// ListenTypeFollows listens to the accounts the subscriber follows; this
	// is the default when the query parameter is absent.
	ListenTypeFollows = "follows"
)
28
// SubscriberData is the state of one connected WebSocket subscriber.
type SubscriberData struct {
	// forActor is the DID from the /subscribe/{did} path; notifications are
	// only sent for likes made through reposts by this actor.
	forActor syntax.DID
	// conn is the subscriber's upgraded WebSocket connection.
	conn *websocket.Conn
	// listenType is ListenTypeFollows or ListenTypeNone.
	listenType string
	// listenTo holds the actors whose likes are watched for this subscriber.
	// NOTE(review): plain map written by handleSubscribe and HandleFollowEvent
	// and read by getLikeDids, possibly from different goroutines — confirm
	// whether it needs a lock.
	listenTo Set[syntax.DID]
}
35
// ActorData is per-actor state shared by every subscriber, stored in actorData
// keyed by the actor's DID.
type ActorData struct {
	// targets holds the subscribers (keyed by subscriber ID) that listen to
	// this actor's likes.
	targets *hashmap.Map[string, *SubscriberData]
	// likes caches this actor's like records by rkey so that a later delete
	// event, which carries no record body, can still be reported.
	likes *hashmap.Map[syntax.RecordKey, bsky.FeedLike]
	// follows caches this actor's follow records by rkey.
	follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow]
	// followsCursor is the rkey of the last follow fetched from the PDS; it is
	// passed to fetchFollows so later fetches skip what is already cached.
	followsCursor atomic.Pointer[string]
	// profile is this actor's cached profile, refreshed by HandleLikeEvent
	// when older than 24 hours.
	profile *bsky.ActorDefs_ProfileViewDetailed
	// profileFetchedAt records when profile was last fetched.
	profileFetchedAt time.Time
}
44
// NotificationActor identifies the actor whose like triggered a notification.
type NotificationActor struct {
	DID     syntax.DID                          `json:"did"`     // the DID of the actor that (un)liked the post
	Profile *bsky.ActorDefs_ProfileViewDetailed `json:"profile"` // the detailed profile of the actor that (un)liked the post; may be nil if it could not be fetched
}
49
// NotificationMessage is the JSON payload written to a subscriber when a post
// is liked or unliked through one of the subscriber's reposts.
type NotificationMessage struct {
	Liked  bool              `json:"liked"`  // true for a new like, false when the like was deleted
	Actor  NotificationActor `json:"actor"`  // information about the actor that (un)liked
	Record bsky.FeedLike     `json:"record"` // the raw like record
	Time   int64             `json:"time"`   // the event's TimeUS from the like stream (microseconds)
}
56
// SubscriberMessage is the envelope for messages sent by a subscriber over
// its WebSocket. Content is decoded later according to Type.
type SubscriberMessage struct {
	Type    string          `json:"type"`
	Content json.RawMessage `json:"content"`
}
61
// SubscriberUpdateListenTo is the Content of an "update_listen_to" message:
// the complete replacement set of actors to listen to. It is only honored for
// subscribers using ListenTypeNone.
type SubscriberUpdateListenTo struct {
	ListenTo []syntax.DID `json:"listen_to"`
}
65
66var (
67 // storing the subscriber data in both Should Be Fine
68 // we dont modify subscriber data at the same time in two places
69 subscribers = hashmap.New[string, *SubscriberData]()
70 actorData = hashmap.New[syntax.DID, *ActorData]()
71
72 likeStreams *StreamManager
73 followStreams *StreamManager
74
75 upgrader = websocket.Upgrader{
76 CheckOrigin: func(r *http.Request) bool {
77 return true
78 },
79 }
80
81 logger *slog.Logger
82)
83
84func getSubscriberDids() []string {
85 _dids := make(Set[string], subscribers.Len())
86 subscribers.Range(func(s string, sd *SubscriberData) bool {
87 _dids[string(sd.forActor)] = struct{}{}
88 return true
89 })
90 dids := make([]string, 0, len(_dids))
91 for k := range _dids {
92 dids = append(dids, k)
93 }
94 return dids
95}
96
97func getLikeDids() []string {
98 _dids := make(Set[string], subscribers.Len()*5000)
99 subscribers.Range(func(s string, sd *SubscriberData) bool {
100 for did := range sd.listenTo {
101 _dids[string(did)] = struct{}{}
102 }
103 return true
104 })
105 dids := make([]string, 0, len(_dids))
106 for k := range _dids {
107 dids = append(dids, k)
108 }
109 return dids
110}
111
112func getActorData(did syntax.DID) *ActorData {
113 ud, _ := actorData.GetOrInsert(did, &ActorData{
114 targets: hashmap.New[string, *SubscriberData](),
115 likes: hashmap.New[syntax.RecordKey, bsky.FeedLike](),
116 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](),
117 })
118 return ud
119}
120
121func markActorForLikes(sid string, sd *SubscriberData, did syntax.DID) {
122 ud := getActorData(did)
123 ud.targets.Insert(sid, sd)
124}
125
126func unmarkActorForLikes(sid string, did syntax.DID) {
127 if ud, exists := actorData.Get(did); exists {
128 ud.targets.Del(sid)
129 }
130}
131
132func main() {
133 logger = slog.Default()
134
135 likeStreams = NewStreamManager(logger, "like-tracker", HandleLikeEvent, getLikeStreamOpts)
136 followStreams = NewStreamManager(logger, "subscriber", HandleFollowEvent, getFollowStreamOpts)
137
138 r := mux.NewRouter()
139 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET")
140
141 log.Println("server starting on :8080")
142 if err := http.ListenAndServe(":8080", r); err != nil {
143 log.Fatalf("error while serving: %s", err)
144 }
145}
146
147func handleSubscribe(w http.ResponseWriter, r *http.Request) {
148 vars := mux.Vars(r)
149 did, err := syntax.ParseDID(vars["did"])
150 if err != nil {
151 http.Error(w, "not a valid did", http.StatusBadRequest)
152 return
153 }
154 sid := uuid.New().String()
155
156 query := r.URL.Query()
157 listenType := query.Get("listenTo")
158 if len(listenType) == 0 {
159 listenType = ListenTypeFollows
160 }
161
162 logger := logger.With("did", did, "subscriberId", sid)
163
164 conn, err := upgrader.Upgrade(w, r, nil)
165 if err != nil {
166 logger.Error("WebSocket upgrade failed", "error", err)
167 return
168 }
169 defer conn.Close()
170
171 logger.Info("new subscriber")
172
173 pdsURI, err := findUserPDS(r.Context(), did)
174 if err != nil {
175 logger.Error("cant resolve user pds", "error", err)
176 return
177 }
178 logger = logger.With("pds", pdsURI)
179
180 xrpcClient := &xrpc.Client{
181 Host: pdsURI,
182 }
183
184 ud := getActorData(did)
185 sd := &SubscriberData{
186 forActor: did,
187 conn: conn,
188 listenType: listenType,
189 }
190
191 switch listenType {
192 case ListenTypeFollows:
193 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did)
194 if err != nil {
195 logger.Error("error fetching follows", "error", err)
196 return
197 }
198 sd.listenTo = make(Set[syntax.DID])
199 // use we have stored
200 ud.follows.Range(func(rk syntax.RecordKey, f bsky.GraphFollow) bool {
201 sd.listenTo[syntax.DID(f.Subject)] = struct{}{}
202 return true
203 })
204 if len(follows) > 0 {
205 // store cursor for later requests so we dont have to fetch the whole thing again
206 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey))
207 for _, f := range follows {
208 ud.follows.Insert(f.rkey, f.follow)
209 sd.listenTo[syntax.DID(f.follow.Subject)] = struct{}{}
210 }
211 }
212 logger.Info("fetched follows")
213 case ListenTypeNone:
214 sd.listenTo = make(Set[syntax.DID])
215 default:
216 http.Error(w, "invalid listen type", http.StatusBadRequest)
217 return
218 }
219
220 subscribers.Set(sid, sd)
221 for listenDid := range sd.listenTo {
222 markActorForLikes(sid, sd, listenDid)
223 }
224 updateStreamOpts()
225 // delete subscriber after we are done
226 defer func() {
227 for listenDid := range sd.listenTo {
228 unmarkActorForLikes(sid, listenDid)
229 }
230 subscribers.Del(sid)
231 updateStreamOpts()
232 }()
233
234 logger.Info("serving subscriber")
235
236 // send pings
237 go func() {
238 for {
239 select {
240 case <-r.Context().Done():
241 return
242 default:
243 conn.WriteMessage(websocket.PingMessage, []byte{})
244 time.Sleep(time.Second * 15)
245 }
246 }
247 }()
248
249 for {
250 var msg SubscriberMessage
251 err := conn.ReadJSON(&msg)
252 // dont kill websocket if its no value
253 if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
254 logger.Error("WebSocket connection closed", "error", err)
255 break
256 }
257 switch msg.Type {
258 case "update_listen_to":
259 // only allow this if we arent managing listen to
260 if sd.listenType != ListenTypeNone {
261 continue
262 }
263
264 var innerMsg SubscriberUpdateListenTo
265 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil {
266 logger.Debug("invalid message", "error", err)
267 break
268 }
269
270 // remove all current listens and add the ones the user requested
271 for listenDid := range sd.listenTo {
272 unmarkActorForLikes(sid, listenDid)
273 delete(sd.listenTo, listenDid)
274 }
275 for _, listenDid := range innerMsg.ListenTo {
276 sd.listenTo[listenDid] = struct{}{}
277 markActorForLikes(sid, sd, listenDid)
278 }
279
280 updateStreamOpts()
281 }
282 }
283}
284
285func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload {
286 return models.SubscriberOptionsUpdatePayload{
287 WantedCollections: []string{"app.bsky.feed.like"},
288 WantedDIDs: getLikeDids(),
289 }
290}
291
292func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload {
293 return models.SubscriberOptionsUpdatePayload{
294 WantedCollections: []string{"app.bsky.graph.follow"},
295 WantedDIDs: getSubscriberDids(),
296 }
297}
298
299func updateStreamOpts() {
300 likeStreams.updateOpts()
301 followStreams.updateOpts()
302}
303
304func HandleLikeEvent(ctx context.Context, event *models.Event) error {
305 if event == nil || event.Commit == nil {
306 return nil
307 }
308
309 byDid := syntax.DID(event.Did)
310 // skip handling event if its not from a source we are listening to
311 ud, exists := actorData.Get(byDid)
312 if !exists || ud.targets.Len() == 0 {
313 return nil
314 }
315
316 logger := logger.With("actor", byDid, "type", "like")
317
318 deleted := event.Commit.Operation == models.CommitOperationDelete
319 rkey := syntax.RecordKey(event.Commit.RKey)
320
321 var like bsky.FeedLike
322 if deleted {
323 if l, exists := ud.likes.Get(rkey); exists {
324 like = l
325 defer ud.likes.Del(rkey)
326 } else {
327 logger.Error("like record not found", "rkey", rkey)
328 return nil
329 }
330 } else if err := unmarshalEvent(event, &like); err != nil {
331 return nil
332 }
333
334 // if there is no via it means its not a repost anyway
335 if like.Via == nil {
336 return nil
337 }
338
339 // store for later when it gets deleted so we can fetch the record
340 if !deleted {
341 ud.likes.Insert(rkey, like)
342 }
343
344 repostURI := syntax.ATURI(like.Via.Uri)
345 // if not a repost we dont care
346 if repostURI.Collection() != "app.bsky.feed.repost" {
347 return nil
348 }
349 reposterDID, err := repostURI.Authority().AsDID()
350 if err != nil {
351 return err
352 }
353 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
354 if sd.forActor != reposterDID {
355 return true
356 }
357
358 if ud.profile == nil || time.Since(ud.profileFetchedAt) > time.Hour*24 {
359 profile, err := fetchProfile(ctx, byDid)
360 if err != nil {
361 logger.Error("cant fetch profile", "error", err)
362 } else {
363 ud.profile = profile
364 ud.profileFetchedAt = time.Now()
365 }
366 }
367
368 notification := NotificationMessage{
369 Liked: !deleted,
370 Actor: NotificationActor{
371 DID: byDid,
372 Profile: ud.profile,
373 },
374 Record: like,
375 Time: event.TimeUS,
376 }
377
378 if err := sd.conn.WriteJSON(notification); err != nil {
379 logger.Error("failed to send notification", "error", err)
380 }
381 return true
382 })
383
384 return nil
385}
386
387func HandleFollowEvent(ctx context.Context, event *models.Event) error {
388 if event == nil || event.Commit == nil {
389 return nil
390 }
391
392 byDid := syntax.DID(event.Did)
393 ud, exists := actorData.Get(byDid)
394 if !exists || ud.targets.Len() == 0 {
395 return nil
396 }
397
398 deleted := event.Commit.Operation == models.CommitOperationDelete
399 rkey := syntax.RecordKey(event.Commit.RKey)
400
401 switch event.Commit.Collection {
402 case "app.bsky.graph.follow":
403 var r bsky.GraphFollow
404 if deleted {
405 if f, exists := ud.follows.Get(rkey); exists {
406 r = f
407 } else {
408 // most likely no ListenTypeFollows subscriber attached on actor
409 logger.Warn("follow record not found", "rkey", rkey, "actor", byDid)
410 return nil
411 }
412 ud.follows.Del(rkey)
413 } else {
414 if err := unmarshalEvent(event, &r); err != nil {
415 return nil
416 }
417 ud.follows.Insert(rkey, r)
418 }
419
420 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
421 // if we arent managing then we dont need to update anything
422 if sd.listenType != ListenTypeFollows {
423 return true
424 }
425 subjectDid := syntax.DID(r.Subject)
426 if deleted {
427 unmarkActorForLikes(sid, subjectDid)
428 delete(sd.listenTo, subjectDid)
429 } else {
430 sd.listenTo[subjectDid] = struct{}{}
431 markActorForLikes(sid, sd, subjectDid)
432 }
433 return true
434 })
435 }
436
437 return nil
438}
439
440func unmarshalEvent[v any](event *models.Event, val *v) error {
441 if err := json.Unmarshal(event.Commit.Record, val); err != nil {
442 logger.Error("cant unmarshal record", "error", err, "raw", event.Commit.Record)
443 return err
444 }
445 return nil
446}