It's for when you want to get notifications for likes that come through your reposts.
1package main
2
3import (
4 "context"
5 "encoding/json"
6 "log"
7 "log/slog"
8 "net/http"
9 "sync/atomic"
10 "time"
11
12 "github.com/bluesky-social/indigo/api/bsky"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "github.com/bluesky-social/indigo/xrpc"
15 "github.com/bluesky-social/jetstream/pkg/models"
16 "github.com/cornelk/hashmap"
17 "github.com/google/uuid"
18 "github.com/gorilla/mux"
19 "github.com/gorilla/websocket"
20)
21
// Set is a generic set backed by a map with zero-width values. Membership is
// checked with a comma-ok map lookup, and len reports the number of elements.
type Set[E comparable] map[E]struct{}
23
// Values accepted by the "listenTo" query parameter of /subscribe/{did}.
// When the parameter is absent, ListenTypeFollows is used.
const (
	// ListenTypeNone starts the subscriber with an empty listen set that the
	// client replaces via "update_listen_to" messages.
	ListenTypeNone = "none"
	// ListenTypeFollows derives the listen set from the subscribed actor's
	// follows and keeps it updated from follow events.
	ListenTypeFollows = "follows"
)
26
27type SubscriberData struct {
28 forActor syntax.DID
29 conn *websocket.Conn
30 listenType string
31 listenTo Set[syntax.DID]
32}
33
34type ActorData struct {
35 targets *hashmap.Map[string, *SubscriberData]
36 likes *hashmap.Map[syntax.RecordKey, bsky.FeedLike]
37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow]
38 followsCursor atomic.Pointer[string]
39 profile *bsky.ActorDefs_ProfileViewDetailed
40 profileFetchedAt time.Time
41}
42
43type NotificationActor struct {
44 DID syntax.DID `json:"did"` // the DID of the actor that (un)liked the post
45 Profile *bsky.ActorDefs_ProfileViewDetailed `json:"profile"` // the detailed profile of the actor that (un)liked the post
46}
47
48type NotificationMessage struct {
49 Liked bool `json:"liked"` // whether the message was liked or unliked
50 Actor NotificationActor `json:"actor"` // information about the actor that (un)liked
51 Record bsky.FeedLike `json:"record"` // the raw like record
52 Time int64 `json:"time"` // when the like event came in
53}
54
// SubscriberMessage is the envelope of every JSON message a subscriber sends
// over its websocket. Content is kept raw and decoded according to Type.
type SubscriberMessage struct {
	// Type selects the message kind, e.g. "update_listen_to".
	Type string `json:"type"`
	// Content is the undecoded payload for Type.
	Content json.RawMessage `json:"content"`
}
59
60type SubscriberUpdateListenTo struct {
61 ListenTo []syntax.DID `json:"listen_to"`
62}
63
64var (
65 // storing the subscriber data in both Should Be Fine
66 // we dont modify subscriber data at the same time in two places
67 subscribers = hashmap.New[string, *SubscriberData]()
68 actorData = hashmap.New[syntax.DID, *ActorData]()
69
70 likeStreams StreamManager
71 followStreams StreamManager
72
73 upgrader = websocket.Upgrader{
74 CheckOrigin: func(r *http.Request) bool {
75 return true
76 },
77 }
78
79 logger *slog.Logger
80)
81
82func getSubscriberDids() []string {
83 _dids := make(Set[string], subscribers.Len())
84 subscribers.Range(func(s string, sd *SubscriberData) bool {
85 _dids[string(sd.forActor)] = struct{}{}
86 return true
87 })
88 dids := make([]string, 0, len(_dids))
89 for k := range _dids {
90 dids = append(dids, k)
91 }
92 return dids
93}
94
95func getLikeDids() []string {
96 _dids := make(Set[string], subscribers.Len()*5000)
97 subscribers.Range(func(s string, sd *SubscriberData) bool {
98 for did := range sd.listenTo {
99 _dids[string(did)] = struct{}{}
100 }
101 return true
102 })
103 dids := make([]string, 0, len(_dids))
104 for k := range _dids {
105 dids = append(dids, k)
106 }
107 return dids
108}
109
110func getActorData(did syntax.DID) *ActorData {
111 ud, _ := actorData.GetOrInsert(did, &ActorData{
112 targets: hashmap.New[string, *SubscriberData](),
113 likes: hashmap.New[syntax.RecordKey, bsky.FeedLike](),
114 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](),
115 })
116 return ud
117}
118
119func markActorForLikes(sid string, sd *SubscriberData, did syntax.DID) {
120 ud := getActorData(did)
121 ud.targets.Insert(sid, sd)
122}
123
124func unmarkActorForLikes(sid string, did syntax.DID) {
125 if ud, exists := actorData.Get(did); exists {
126 ud.targets.Del(sid)
127 // remove actor data if no subscribers are attached to it
128 if ud.targets.Len() == 0 {
129 actorData.Del(did)
130 }
131 }
132}
133
134func main() {
135 logger = slog.Default()
136
137 likeStreams = NewStreamManager(logger, "like-tracker", HandleLikeEvent, getLikeStreamOpts)
138 followStreams = NewStreamManager(logger, "subscriber", HandleFollowEvent, getFollowStreamOpts)
139
140 r := mux.NewRouter()
141 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET")
142
143 log.Println("server starting on :8080")
144 if err := http.ListenAndServe(":8080", r); err != nil {
145 log.Fatalf("error while serving: %s", err)
146 }
147}
148
149func handleSubscribe(w http.ResponseWriter, r *http.Request) {
150 vars := mux.Vars(r)
151 did, err := syntax.ParseDID(vars["did"])
152 if err != nil {
153 http.Error(w, "not a valid did", http.StatusBadRequest)
154 return
155 }
156 sid := uuid.New().String()
157
158 query := r.URL.Query()
159 listenType := query.Get("listenTo")
160 if len(listenType) == 0 {
161 listenType = ListenTypeFollows
162 }
163
164 logger := logger.With("did", did, "subscriberId", sid)
165
166 conn, err := upgrader.Upgrade(w, r, nil)
167 if err != nil {
168 logger.Error("WebSocket upgrade failed", "error", err)
169 return
170 }
171 defer conn.Close()
172
173 logger.Info("new subscriber")
174
175 pdsURI, err := findUserPDS(r.Context(), did)
176 if err != nil {
177 logger.Error("cant resolve user pds", "error", err)
178 return
179 }
180 logger = logger.With("pds", pdsURI)
181
182 xrpcClient := &xrpc.Client{
183 Host: pdsURI,
184 }
185
186 ud := getActorData(did)
187 sd := &SubscriberData{
188 forActor: did,
189 conn: conn,
190 listenType: listenType,
191 }
192
193 switch listenType {
194 case ListenTypeFollows:
195 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did)
196 if err != nil {
197 logger.Error("error fetching follows", "error", err)
198 return
199 }
200 sd.listenTo = make(Set[syntax.DID])
201 // use we have stored
202 ud.follows.Range(func(rk syntax.RecordKey, f bsky.GraphFollow) bool {
203 sd.listenTo[syntax.DID(f.Subject)] = struct{}{}
204 return true
205 })
206 if len(follows) > 0 {
207 // store cursor for later requests so we dont have to fetch the whole thing again
208 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey))
209 for _, f := range follows {
210 ud.follows.Insert(f.rkey, f.follow)
211 sd.listenTo[syntax.DID(f.follow.Subject)] = struct{}{}
212 }
213 }
214 logger.Info("fetched follows")
215 case ListenTypeNone:
216 sd.listenTo = make(Set[syntax.DID])
217 default:
218 http.Error(w, "invalid listen type", http.StatusBadRequest)
219 return
220 }
221
222 subscribers.Set(sid, sd)
223 for listenDid := range sd.listenTo {
224 markActorForLikes(sid, sd, listenDid)
225 }
226 updateStreamOpts()
227 // delete subscriber after we are done
228 defer func() {
229 for listenDid := range sd.listenTo {
230 unmarkActorForLikes(sid, listenDid)
231 }
232 subscribers.Del(sid)
233 updateStreamOpts()
234 }()
235
236 logger.Info("serving subscriber")
237
238 for {
239 var msg SubscriberMessage
240 err := conn.ReadJSON(&msg)
241 if err != nil {
242 logger.Info("WebSocket connection closed", "error", err)
243 break
244 }
245 switch msg.Type {
246 case "update_listen_to":
247 // only allow this if we arent managing listen to
248 if sd.listenType != ListenTypeNone {
249 continue
250 }
251
252 var innerMsg SubscriberUpdateListenTo
253 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil {
254 logger.Info("invalid message", "error", err)
255 break
256 }
257
258 // remove all current listens and add the ones the user requested
259 for listenDid := range sd.listenTo {
260 unmarkActorForLikes(sid, listenDid)
261 delete(sd.listenTo, listenDid)
262 }
263 for _, listenDid := range innerMsg.ListenTo {
264 sd.listenTo[listenDid] = struct{}{}
265 markActorForLikes(sid, sd, listenDid)
266 }
267
268 updateStreamOpts()
269 }
270 }
271}
272
273func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload {
274 return models.SubscriberOptionsUpdatePayload{
275 WantedCollections: []string{"app.bsky.feed.like"},
276 WantedDIDs: getLikeDids(),
277 }
278}
279
280func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload {
281 return models.SubscriberOptionsUpdatePayload{
282 WantedCollections: []string{"app.bsky.graph.follow"},
283 WantedDIDs: getSubscriberDids(),
284 }
285}
286
287func updateStreamOpts() {
288 likeStreams.updateOpts()
289 followStreams.updateOpts()
290}
291
292func HandleLikeEvent(ctx context.Context, event *models.Event) error {
293 if event == nil || event.Commit == nil {
294 return nil
295 }
296
297 byDid := syntax.DID(event.Did)
298 // skip handling event if its not from a source we are listening to
299 ud, exists := actorData.Get(byDid)
300 if !exists || ud.targets.Len() == 0 {
301 return nil
302 }
303
304 logger := logger.With("actor", byDid, "type", "like")
305
306 deleted := event.Commit.Operation == models.CommitOperationDelete
307 rkey := syntax.RecordKey(event.Commit.RKey)
308
309 var like bsky.FeedLike
310 if deleted {
311 if l, exists := ud.likes.Get(rkey); exists {
312 like = l
313 defer ud.likes.Del(rkey)
314 } else {
315 logger.Error("like record not found", "rkey", rkey)
316 return nil
317 }
318 } else if err := unmarshalEvent(event, &like); err != nil {
319 return nil
320 }
321
322 // if there is no via it means its not a repost anyway
323 if like.Via == nil {
324 return nil
325 }
326
327 // store for later when it gets deleted so we can fetch the record
328 if !deleted {
329 ud.likes.Insert(rkey, like)
330 }
331
332 repostURI := syntax.ATURI(like.Via.Uri)
333 // if not a repost we dont care
334 if repostURI.Collection() != "app.bsky.feed.repost" {
335 return nil
336 }
337 reposterDID, err := repostURI.Authority().AsDID()
338 if err != nil {
339 return err
340 }
341 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
342 if sd.forActor != reposterDID {
343 return true
344 }
345
346 if ud.profile == nil || time.Since(ud.profileFetchedAt) > time.Hour*24 {
347 profile, err := fetchProfile(ctx, byDid)
348 if err != nil {
349 logger.Error("cant fetch profile", "error", err)
350 } else {
351 ud.profile = profile
352 ud.profileFetchedAt = time.Now()
353 }
354 }
355
356 notification := NotificationMessage{
357 Liked: !deleted,
358 Actor: NotificationActor{
359 DID: byDid,
360 Profile: ud.profile,
361 },
362 Record: like,
363 Time: event.TimeUS,
364 }
365
366 if err := sd.conn.WriteJSON(notification); err != nil {
367 logger.Error("failed to send notification", "error", err)
368 }
369 return true
370 })
371
372 return nil
373}
374
375func HandleFollowEvent(ctx context.Context, event *models.Event) error {
376 if event == nil || event.Commit == nil {
377 return nil
378 }
379
380 byDid := syntax.DID(event.Did)
381 ud, exists := actorData.Get(byDid)
382 if !exists || ud.targets.Len() == 0 {
383 return nil
384 }
385
386 deleted := event.Commit.Operation == models.CommitOperationDelete
387 rkey := syntax.RecordKey(event.Commit.RKey)
388
389 switch event.Commit.Collection {
390 case "app.bsky.graph.follow":
391 var r bsky.GraphFollow
392 if deleted {
393 if f, exists := ud.follows.Get(rkey); exists {
394 r = f
395 } else {
396 // most likely no ListenTypeFollows subscriber attached on actor
397 logger.Warn("follow record not found", "rkey", rkey, "actor", byDid)
398 return nil
399 }
400 ud.follows.Del(rkey)
401 } else {
402 if err := unmarshalEvent(event, &r); err != nil {
403 return nil
404 }
405 ud.follows.Insert(rkey, r)
406 }
407
408 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
409 // if we arent managing then we dont need to update anything
410 if sd.listenType != ListenTypeFollows {
411 return true
412 }
413 subjectDid := syntax.DID(r.Subject)
414 if deleted {
415 unmarkActorForLikes(sid, subjectDid)
416 delete(sd.listenTo, subjectDid)
417 } else {
418 sd.listenTo[subjectDid] = struct{}{}
419 markActorForLikes(sid, sd, subjectDid)
420 }
421 return true
422 })
423 }
424
425 return nil
426}
427
428func unmarshalEvent[v any](event *models.Event, val *v) error {
429 if err := json.Unmarshal(event.Commit.Record, val); err != nil {
430 logger.Error("cant unmarshal record", "error", err, "raw", event.Commit.Record)
431 return err
432 }
433 return nil
434}