It's for when you want to get notifications about likes on your reposts.
1package main

import (
	"context"
	"encoding/json"
	"log"
	"log/slog"
	"net/http"
	"sync/atomic"
	"time"

	"github.com/bluesky-social/indigo/api/bsky"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/bluesky-social/jetstream/pkg/client"
	"github.com/bluesky-social/jetstream/pkg/models"
	"github.com/cornelk/hashmap"
	"github.com/google/uuid"
	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
)

// Set is a map-backed set whose values take no space; test membership with
// `_, ok := s[k]` and add with `s[k] = struct{}{}`.
type Set[T comparable] map[T]struct{}

// Values accepted by the "listenTo" query parameter of /subscribe/{did}.
const (
	// ListenTypeNone starts the subscriber with an empty listen set that the
	// client replaces itself with "update_listen_to" messages.
	ListenTypeNone = "none"
	// ListenTypeFollows listens to every account the subscriber follows and
	// keeps that set updated from follow events. It is the default.
	ListenTypeFollows = "follows"
)

27type SubscriberData struct {
28 SubscribedTo syntax.DID
29 Conn *websocket.Conn
30 ListenType string
31 ListenTo Set[syntax.DID]
32}
33
34type UserData struct {
35 targets *hashmap.Map[string, *SubscriberData]
36 likes map[syntax.RecordKey]bsky.FeedLike
37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow]
38 followsCursor atomic.Pointer[string]
39}
40
41type NotificationMessage struct {
42 Liked bool `json:"liked"`
43 ByDid syntax.DID `json:"did"`
44 RepostURI syntax.ATURI `json:"repost_uri"`
45}
46
47type SubscriberMessage struct {
48 Type string `json:"type"`
49 Content json.RawMessage `json:"content"`
50}
51
52type SubscriberUpdateListenTo struct {
53 ListenTo []syntax.DID `json:"listen_to"`
54}
55
56var (
57 // storing the subscriber data in both Should Be Fine
58 // we dont modify subscriber data at the same time in two places
59 subscribers = hashmap.New[string, *SubscriberData]()
60 userData = hashmap.New[syntax.DID, *UserData]()
61
62 likeStream *client.Client
63 followStream *client.Client
64
65 upgrader = websocket.Upgrader{
66 CheckOrigin: func(r *http.Request) bool {
67 return true
68 },
69 }
70
71 logger *slog.Logger
72)
73
74func getSubscriberDids() []string {
75 dids := make([]string, 0, subscribers.Len())
76 subscribers.Range(func(s string, sd *SubscriberData) bool {
77 dids = append(dids, string(sd.SubscribedTo))
78 return true
79 })
80 return dids
81}
82
83func getUserData(did syntax.DID) *UserData {
84 ud, _ := userData.GetOrInsert(did, &UserData{
85 targets: hashmap.New[string, *SubscriberData](),
86 likes: make(map[syntax.RecordKey]bsky.FeedLike),
87 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](),
88 })
89 return ud
90}
91
92func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) {
93 ud := getUserData(did)
94 ud.targets.Insert(sid, sd)
95}
96
97func stopListeningTo(sid string, did syntax.DID) {
98 if ud, exists := userData.Get(did); exists {
99 ud.targets.Del(sid)
100 }
101}
102
// main starts the like and follow jetstream loops and serves the
// /subscribe/{did} websocket endpoint on :8080.
func main() {
	logger = slog.Default()

	go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts)
	go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts)

	r := mux.NewRouter()
	r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods(http.MethodGet)

	srv := &http.Server{
		Addr:    ":8080",
		Handler: r,
		// Bound how long a client may take to send its request headers, so
		// slow clients cannot hold connections open before the upgrade.
		// No ReadTimeout/WriteTimeout: subscriber websockets stay open for as
		// long as the client keeps them.
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Println("server starting on :8080")
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("error while serving: %s", err)
	}
}

// handleSubscribe serves GET /subscribe/{did}. It upgrades the request to a
// websocket and keeps the subscriber registered, so NotificationMessage values
// can be written to it, until the connection fails.
//
// The optional "listenTo" query parameter selects whose likes are watched:
// ListenTypeFollows (the default) uses every account {did} follows;
// ListenTypeNone starts empty and lets the client send "update_listen_to"
// messages. Any other value is rejected with 400.
func handleSubscribe(w http.ResponseWriter, r *http.Request) {
	did, err := syntax.ParseDID(mux.Vars(r)["did"])
	if err != nil {
		http.Error(w, "not a valid did", http.StatusBadRequest)
		return
	}

	listenType := r.URL.Query().Get("listenTo")
	if listenType == "" {
		listenType = ListenTypeFollows
	}
	// Validate before upgrading: after the websocket upgrade hijacks the
	// connection, http.Error can no longer reach the client.
	if listenType != ListenTypeFollows && listenType != ListenTypeNone {
		http.Error(w, "invalid listen type", http.StatusBadRequest)
		return
	}

	sid := uuid.New().String()
	logger := logger.With("did", did, "subscriberId", sid)

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Error("WebSocket upgrade failed", "error", err)
		return
	}
	defer conn.Close()

	logger.Info("new subscriber")

	pdsURI, err := findUserPDS(r.Context(), did)
	if err != nil {
		logger.Error("cant resolve user pds", "error", err)
		return
	}
	logger = logger.With("pds", pdsURI)

	ud := getUserData(did)
	sd := &SubscriberData{
		SubscribedTo: did,
		Conn:         conn,
		ListenType:   listenType,
		ListenTo:     make(Set[syntax.DID]),
	}

	if listenType == ListenTypeFollows {
		xrpcClient := &xrpc.Client{Host: pdsURI}
		listenTo, err := loadFollowListenSet(r.Context(), xrpcClient, ud, did)
		if err != nil {
			logger.Error("error fetching follows", "error", err)
			return
		}
		sd.ListenTo = listenTo
		logger.Info("fetched follows")
	}

	subscribers.Set(sid, sd)
	for listenDID := range sd.ListenTo {
		startListeningTo(sid, sd, listenDID)
	}
	updateFollowStreamOpts()
	// Unregister when the connection ends. The subscriber is removed from the
	// subscribers registry first, then its listens are torn down.
	defer func() {
		subscribers.Del(sid)
		for listenDID := range sd.ListenTo {
			stopListeningTo(sid, listenDID)
		}
		updateFollowStreamOpts()
	}()

	logger.Info("serving subscriber")
	readSubscriberMessages(logger, sid, sd)
}

// loadFollowListenSet fetches did's follows starting from the cached cursor,
// merges them into ud.follows and returns the subjects of ALL cached follows.
//
// The full cache is used (not only the newly fetched page), because with a
// stored cursor a later subscription only fetches records past that cursor.
func loadFollowListenSet(ctx context.Context, xrpcClient *xrpc.Client, ud *UserData, did syntax.DID) (Set[syntax.DID], error) {
	follows, err := fetchFollows(ctx, xrpcClient, ud.followsCursor.Load(), did)
	if err != nil {
		return nil, err
	}
	for _, f := range follows {
		ud.follows.Insert(f.rkey, f.follow)
	}
	if n := len(follows); n > 0 {
		// Copy the key so the cursor does not keep the whole slice alive.
		cursor := string(follows[n-1].rkey)
		ud.followsCursor.Store(&cursor)
	}

	listenTo := make(Set[syntax.DID], ud.follows.Len())
	ud.follows.Range(func(_ syntax.RecordKey, f bsky.GraphFollow) bool {
		listenTo[syntax.DID(f.Subject)] = struct{}{}
		return true
	})
	return listenTo, nil
}

// readSubscriberMessages handles client messages on sd.Conn until a read fails.
func readSubscriberMessages(logger *slog.Logger, sid string, sd *SubscriberData) {
	for {
		var msg SubscriberMessage
		if err := sd.Conn.ReadJSON(&msg); err != nil {
			logger.Info("WebSocket connection closed", "error", err)
			return
		}
		if msg.Type != "update_listen_to" {
			continue
		}
		// Only clients that manage their own listen set may replace it.
		if sd.ListenType != ListenTypeNone {
			continue
		}

		var update SubscriberUpdateListenTo
		if err := json.Unmarshal(msg.Content, &update); err != nil {
			logger.Info("invalid message", "error", err)
			continue
		}
		// Replace the whole listen set with the requested one.
		for listenDID := range sd.ListenTo {
			stopListeningTo(sid, listenDID)
			delete(sd.ListenTo, listenDID)
		}
		for _, listenDID := range update.ListenTo {
			sd.ListenTo[listenDID] = struct{}{}
			startListeningTo(sid, sd, listenDID)
		}
	}
}

234func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload {
235 return models.SubscriberOptionsUpdatePayload{
236 WantedCollections: []string{"app.bsky.feed.like"},
237 }
238}
239
240func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload {
241 return models.SubscriberOptionsUpdatePayload{
242 WantedCollections: []string{"app.bsky.graph.follow"},
243 WantedDIDs: getSubscriberDids(),
244 }
245}
246
247func updateFollowStreamOpts() {
248 opts := getFollowStreamOpts()
249 err := followStream.SendOptionsUpdate(opts)
250 if err != nil {
251 logger.Error("couldnt update follow stream opts", "error", err)
252 return
253 }
254 logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs))
255}
256
// HandleLikeEvent processes one like-stream event. When an account someone
// listens to creates or deletes a like whose "via" points at a repost, every
// listening subscriber who authored that repost gets a NotificationMessage.
//
// Likes with a "via" are cached per liker so their deletion can be reported.
// Events from accounts nobody listens to, likes without a "via" and vias that
// are not reposts are ignored. An error is returned only when the repost URI's
// authority is not a DID.
func HandleLikeEvent(ctx context.Context, event *models.Event) error {
	if event == nil || event.Commit == nil {
		return nil
	}

	likerDID := syntax.DID(event.Did)
	ud, ok := userData.Get(likerDID)
	if !ok || ud.targets.Len() == 0 {
		return nil
	}

	rkey := syntax.RecordKey(event.Commit.RKey)
	isDelete := event.Commit.Operation == models.CommitOperationDelete

	var like bsky.FeedLike
	if isDelete {
		// Deletions carry no record here; recover it from the cache.
		cached, found := ud.likes[rkey]
		if !found {
			logger.Error("like record not found", "rkey", rkey)
			return nil
		}
		like = cached
		defer delete(ud.likes, rkey)
	} else if err := json.Unmarshal(event.Commit.Record, &like); err != nil {
		logger.Error("failed to unmarshal like", "error", err)
		return nil
	}

	// Without a "via" the like cannot have come through a repost.
	if like.Via == nil {
		return nil
	}
	if !isDelete {
		ud.likes[rkey] = like
	}

	repostURI := syntax.ATURI(like.Via.Uri)
	if repostURI.Collection() != "app.bsky.feed.repost" {
		return nil
	}
	reposterDID, err := repostURI.Authority().AsDID()
	if err != nil {
		return err
	}

	notification := NotificationMessage{
		Liked:     !isDelete,
		ByDid:     likerDID,
		RepostURI: repostURI,
	}
	ud.targets.Range(func(_ string, sd *SubscriberData) bool {
		if sd.SubscribedTo == reposterDID {
			if err := sd.Conn.WriteJSON(notification); err != nil {
				logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err)
			}
		}
		return true
	})
	return nil
}

// HandleFollowEvent keeps follow-derived listen sets in sync with follow
// records published by subscribers.
//
// The event's author is a subscriber (the follow stream only wants subscriber
// DIDs). Its follow records are cached by record key in its own UserData so
// that a deletion can be mapped back to the followed subject. Every
// ListenTypeFollows connection subscribed AS the author then starts or stops
// listening to that subject. Always returns nil; problems are logged.
func HandleFollowEvent(ctx context.Context, event *models.Event) error {
	if event == nil || event.Commit == nil {
		return nil
	}
	if event.Commit.Collection != "app.bsky.graph.follow" {
		return nil
	}

	authorDID := syntax.DID(event.Did)
	// The author's follow cache lives in its own UserData (created when it
	// subscribed). Note this is NOT ud.targets: targets are the subscribers
	// listening TO the author, not the author's own connections.
	ud, exists := userData.Get(authorDID)
	if !exists {
		return nil
	}

	deleted := event.Commit.Operation == models.CommitOperationDelete
	rkey := syntax.RecordKey(event.Commit.RKey)

	var follow bsky.GraphFollow
	if deleted {
		cached, found := ud.follows.Get(rkey)
		if !found {
			logger.Error("follow record not found", "rkey", rkey)
			return nil
		}
		follow = cached
		ud.follows.Del(rkey)
	} else {
		if err := unmarshalEvent(event, &follow); err != nil {
			logger.Error("could not unmarshal follow event", "error", err)
			return nil
		}
		// Never listen to an empty DID because of a malformed record.
		if follow.Subject == "" {
			logger.Error("follow record has no subject", "rkey", rkey)
			return nil
		}
		ud.follows.Insert(rkey, follow)
	}

	subjectDID := syntax.DID(follow.Subject)
	subscribers.Range(func(sid string, sd *SubscriberData) bool {
		// Only the author's own connections whose listen set we manage.
		if sd.SubscribedTo != authorDID || sd.ListenType != ListenTypeFollows {
			return true
		}
		// NOTE(review): sd.ListenTo is an unlocked map that the connection's
		// handler also ranges over during cleanup on another goroutine.
		if deleted {
			stopListeningTo(sid, subjectDID)
			delete(sd.ListenTo, subjectDID)
		} else {
			sd.ListenTo[subjectDID] = struct{}{}
			startListeningTo(sid, sd, subjectDID)
		}
		return true
	})
	return nil
}

// unmarshalEvent decodes the JSON record carried by event's commit into val.
//
// On failure the raw record is logged as text and the decode error is
// returned, so callers never proceed with a zero-value record. The event must
// have a non-nil Commit.
func unmarshalEvent[T any](event *models.Event, val *T) error {
	if err := json.Unmarshal(event.Commit.Record, val); err != nil {
		logger.Error("failed to unmarshal", "error", err, "raw", string(event.Commit.Record))
		return err
	}
	return nil
}