It's for when you want to get notifications about likes on your reposts.
1package main
2
import (
	"context"
	"encoding/json"
	"log"
	"log/slog"
	"net/http"
	"sync/atomic"
	"time"

	"github.com/bluesky-social/indigo/api/bsky"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/bluesky-social/jetstream/pkg/client"
	"github.com/bluesky-social/jetstream/pkg/models"
	"github.com/cornelk/hashmap"
	"github.com/google/uuid"
	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
)
21
// Set is a generic set: an element is a member iff it is a key of the map;
// the zero-width values carry no data.
type Set[T comparable] map[T]struct{}

// Values accepted for the listenTo query parameter of /subscribe/{did}.
const (
	// ListenTypeNone starts the subscriber with an empty listen set that the
	// client replaces by sending "update_listen_to" messages.
	ListenTypeNone = "none"
	// ListenTypeFollows makes the server derive the listen set from the
	// subscribed account's follows. Used when listenTo is omitted.
	ListenTypeFollows = "follows"
)
26
27type SubscriberData struct {
28 SubscribedTo syntax.DID
29 Conn *websocket.Conn
30 ListenType string
31 ListenTo Set[syntax.DID]
32}
33
34type UserData struct {
35 targets *hashmap.Map[string, *SubscriberData]
36 likes map[syntax.RecordKey]bsky.FeedLike
37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow]
38 followsCursor atomic.Pointer[string]
39}
40
41type NotificationMessage struct {
42 Liked bool `json:"liked"`
43 ByDid syntax.DID `json:"did"`
44 RepostURI syntax.ATURI `json:"repost_uri"`
45}
46
47type SubscriberMessage struct {
48 Type string `json:"type"`
49 Content json.RawMessage `json:"content"`
50}
51
52type SubscriberUpdateListenTo struct {
53 ListenTo []syntax.DID `json:"listen_to"`
54}
55
56var (
57 // storing the subscriber data in both Should Be Fine
58 // we dont modify subscriber data at the same time in two places
59 subscribers = hashmap.New[string, *SubscriberData]()
60 userData = hashmap.New[syntax.DID, *UserData]()
61
62 likeStream *client.Client
63 followStream *client.Client
64
65 upgrader = websocket.Upgrader{
66 CheckOrigin: func(r *http.Request) bool {
67 return true
68 },
69 }
70
71 logger *slog.Logger
72)
73
74func getSubscriberDids() []string {
75 _dids := make(Set[string], subscribers.Len())
76 subscribers.Range(func(s string, sd *SubscriberData) bool {
77 _dids[string(sd.SubscribedTo)] = struct{}{}
78 return true
79 })
80 dids := make([]string, 0, len(_dids))
81 for k := range _dids {
82 dids = append(dids, k)
83 }
84 return dids
85}
86
87func getUserData(did syntax.DID) *UserData {
88 ud, _ := userData.GetOrInsert(did, &UserData{
89 targets: hashmap.New[string, *SubscriberData](),
90 likes: make(map[syntax.RecordKey]bsky.FeedLike),
91 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](),
92 })
93 return ud
94}
95
96func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) {
97 ud := getUserData(did)
98 ud.targets.Insert(sid, sd)
99}
100
101func stopListeningTo(sid string, did syntax.DID) {
102 if ud, exists := userData.Get(did); exists {
103 ud.targets.Del(sid)
104 }
105}
106
107func main() {
108 logger = slog.Default()
109
110 go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts)
111 go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts)
112
113 r := mux.NewRouter()
114 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET")
115
116 log.Println("server starting on :8080")
117 if err := http.ListenAndServe(":8080", r); err != nil {
118 log.Fatalf("error while serving: %s", err)
119 }
120}
121
122func handleSubscribe(w http.ResponseWriter, r *http.Request) {
123 vars := mux.Vars(r)
124 did, err := syntax.ParseDID(vars["did"])
125 if err != nil {
126 http.Error(w, "not a valid did", http.StatusBadRequest)
127 return
128 }
129 sid := uuid.New().String()
130
131 query := r.URL.Query()
132 listenType := query.Get("listenTo")
133 if len(listenType) == 0 {
134 listenType = ListenTypeFollows
135 }
136
137 logger := logger.With("did", did, "subscriberId", sid)
138
139 conn, err := upgrader.Upgrade(w, r, nil)
140 if err != nil {
141 logger.Error("WebSocket upgrade failed", "error", err)
142 return
143 }
144 defer conn.Close()
145
146 logger.Info("new subscriber")
147
148 pdsURI, err := findUserPDS(r.Context(), did)
149 if err != nil {
150 logger.Error("cant resolve user pds", "error", err)
151 return
152 }
153 logger = logger.With("pds", pdsURI)
154
155 xrpcClient := &xrpc.Client{
156 Host: pdsURI,
157 }
158
159 ud := getUserData(did)
160 sd := &SubscriberData{
161 SubscribedTo: did,
162 Conn: conn,
163 ListenType: listenType,
164 }
165
166 switch listenType {
167 case ListenTypeFollows:
168 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did)
169 if err != nil {
170 logger.Error("error fetching follows", "error", err)
171 return
172 }
173 sd.ListenTo = make(Set[syntax.DID])
174 if len(follows) > 0 {
175 // store cursor for later requests so we dont have to fetch the whole thing again
176 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey))
177 for _, f := range follows {
178 ud.follows.Insert(f.rkey, f.follow)
179 sd.ListenTo[syntax.DID(f.follow.Subject)] = struct{}{}
180 }
181 }
182 logger.Info("fetched follows")
183 case ListenTypeNone:
184 sd.ListenTo = make(Set[syntax.DID])
185 default:
186 http.Error(w, "invalid listen type", http.StatusBadRequest)
187 return
188 }
189
190 subscribers.Set(sid, sd)
191 for listenDid := range sd.ListenTo {
192 startListeningTo(sid, sd, listenDid)
193 }
194 updateFollowStreamOpts()
195 // delete subscriber after we are done
196 defer func() {
197 for listenDid := range sd.ListenTo {
198 stopListeningTo(sid, listenDid)
199 }
200 subscribers.Del(sid)
201 updateFollowStreamOpts()
202 }()
203
204 logger.Info("serving subscriber")
205
206 for {
207 var msg SubscriberMessage
208 err := conn.ReadJSON(&msg)
209 if err != nil {
210 logger.Info("WebSocket connection closed", "error", err)
211 break
212 }
213 switch msg.Type {
214 case "update_listen_to":
215 // only allow this if we arent managing listen to
216 if sd.ListenType != ListenTypeNone {
217 continue
218 }
219
220 var innerMsg SubscriberUpdateListenTo
221 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil {
222 logger.Info("invalid message", "error", err)
223 break
224 }
225 // remove all current listens and add the ones the user requested
226 for listenDid := range sd.ListenTo {
227 stopListeningTo(sid, listenDid)
228 delete(sd.ListenTo, listenDid)
229 }
230 for _, listenDid := range innerMsg.ListenTo {
231 sd.ListenTo[listenDid] = struct{}{}
232 startListeningTo(sid, sd, listenDid)
233 }
234 }
235 }
236}
237
238func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload {
239 return models.SubscriberOptionsUpdatePayload{
240 WantedCollections: []string{"app.bsky.feed.like"},
241 }
242}
243
244func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload {
245 return models.SubscriberOptionsUpdatePayload{
246 WantedCollections: []string{"app.bsky.graph.follow"},
247 WantedDIDs: getSubscriberDids(),
248 }
249}
250
251func updateFollowStreamOpts() {
252 opts := getFollowStreamOpts()
253 err := followStream.SendOptionsUpdate(opts)
254 if err != nil {
255 logger.Error("couldnt update follow stream opts", "error", err)
256 return
257 }
258 logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs))
259}
260
261func HandleLikeEvent(ctx context.Context, event *models.Event) error {
262 if event == nil || event.Commit == nil {
263 return nil
264 }
265
266 byDid := syntax.DID(event.Did)
267 // skip handling event if its not from a source we are listening to
268 ud, exists := userData.Get(byDid)
269 if !exists || ud.targets.Len() == 0 {
270 return nil
271 }
272
273 deleted := event.Commit.Operation == models.CommitOperationDelete
274 rkey := syntax.RecordKey(event.Commit.RKey)
275
276 var like bsky.FeedLike
277 if deleted {
278 if l, exists := ud.likes[rkey]; exists {
279 like = l
280 defer delete(ud.likes, rkey)
281 } else {
282 logger.Error("like record not found", "rkey", rkey)
283 return nil
284 }
285 } else {
286 if err := json.Unmarshal(event.Commit.Record, &like); err != nil {
287 logger.Error("failed to unmarshal like", "error", err)
288 return nil
289 }
290 }
291
292 // if there is no via it means its not a repost anyway
293 if like.Via == nil {
294 return nil
295 }
296
297 // store for later when it gets deleted so we can fetch the record
298 if !deleted {
299 ud.likes[rkey] = like
300 }
301
302 repostURI := syntax.ATURI(like.Via.Uri)
303 // if not a repost we dont care
304 if repostURI.Collection() != "app.bsky.feed.repost" {
305 return nil
306 }
307 reposterDID, err := repostURI.Authority().AsDID()
308 if err != nil {
309 return err
310 }
311 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
312 if sd.SubscribedTo != reposterDID {
313 return true
314 }
315
316 notification := NotificationMessage{
317 Liked: !deleted,
318 ByDid: byDid,
319 RepostURI: repostURI,
320 }
321
322 if err := sd.Conn.WriteJSON(notification); err != nil {
323 logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err)
324 }
325 return true
326 })
327
328 return nil
329}
330
331func HandleFollowEvent(ctx context.Context, event *models.Event) error {
332 if event == nil || event.Commit == nil {
333 return nil
334 }
335
336 byDid := syntax.DID(event.Did)
337 ud, exists := userData.Get(byDid)
338 if !exists || ud.targets.Len() == 0 {
339 return nil
340 }
341
342 deleted := event.Commit.Operation == models.CommitOperationDelete
343 rkey := syntax.RecordKey(event.Commit.RKey)
344
345 switch event.Commit.Collection {
346 case "app.bsky.graph.follow":
347 var r bsky.GraphFollow
348 if deleted {
349 if f, exists := ud.follows.Get(rkey); exists {
350 r = f
351 } else {
352 logger.Error("follow record not found", "rkey", rkey)
353 return nil
354 }
355 ud.follows.Del(rkey)
356 } else {
357 if err := unmarshalEvent(event, &r); err != nil {
358 logger.Error("could not unmarshal follow event", "error", err)
359 return nil
360 }
361 ud.follows.Insert(rkey, r)
362 }
363 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
364 // if we arent managing then we dont need to update anything
365 if sd.ListenType != ListenTypeFollows {
366 return true
367 }
368 subjectDid := syntax.DID(r.Subject)
369 if deleted {
370 stopListeningTo(sid, subjectDid)
371 delete(sd.ListenTo, subjectDid)
372 } else {
373 sd.ListenTo[subjectDid] = struct{}{}
374 startListeningTo(sid, sd, subjectDid)
375 }
376 return true
377 })
378 }
379
380 return nil
381}
382
383func unmarshalEvent[v any](event *models.Event, val *v) error {
384 if err := json.Unmarshal(event.Commit.Record, val); err != nil {
385 logger.Error("failed to unmarshal", "error", err, "raw", event.Commit.Record)
386 return nil
387 }
388 return nil
389}