It's for when you want to get notifications when someone likes a post through one of your reposts.
1package main
2
import (
	"context"
	"encoding/json"
	"log"
	"log/slog"
	"net/http"
	"sync/atomic"
	"time"

	"github.com/bluesky-social/indigo/api/bsky"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/bluesky-social/jetstream/pkg/client"
	"github.com/bluesky-social/jetstream/pkg/models"
	"github.com/cornelk/hashmap"
	"github.com/google/uuid"
	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
)
21
// Set is a generic set: a map whose zero-width values carry no data, so
// membership is tested with the comma-ok map lookup.
type Set[E comparable] map[E]struct{}
23
// Values accepted for the "listenTo" query parameter of /subscribe/{did}.
const (
	// ListenTypeNone starts the subscriber with an empty listen set; the client
	// manages it itself with "update_listen_to" messages.
	ListenTypeNone = "none"
	// ListenTypeFollows (the default) listens to the accounts the subscriber
	// follows, and the server updates that set as follow records change.
	ListenTypeFollows = "follows"
)
26
27type SubscriberData struct {
28 SubscribedTo syntax.DID
29 Conn *websocket.Conn
30 ListenType string
31 ListenTo Set[syntax.DID]
32}
33
// UserData is the per-account state, created on demand by getUserData.
type UserData struct {
	// targets maps subscriber ID to each subscriber whose listen set contains
	// this account.
	targets *hashmap.Map[string, *SubscriberData]
	// likes caches this account's like records that carry a "via" reference,
	// keyed by record key, so a later delete event (which has no record body)
	// can still be resolved. Entries are removed when the like is deleted.
	// NOTE(review): plain map with no lock.
	likes map[syntax.RecordKey]bsky.FeedLike
	// follows caches this account's follow records, keyed by record key.
	follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow]
	// followsCursor is the record key of the last follow returned by
	// fetchFollows, passed back on the next fetch to avoid refetching everything.
	followsCursor atomic.Pointer[string]
}
40
41type NotificationMessage struct {
42 Liked bool `json:"liked"`
43 ByDid syntax.DID `json:"did"`
44 RepostURI syntax.ATURI `json:"repost_uri"`
45 PostURI syntax.ATURI `json:"post_uri"`
46}
47
// SubscriberMessage is the envelope of every message a subscriber sends over
// its websocket. Content is kept raw and decoded according to Type.
type SubscriberMessage struct {
	Type    string          `json:"type"`    // message kind, e.g. "update_listen_to"
	Content json.RawMessage `json:"content"` // type-specific payload
}
52
53type SubscriberUpdateListenTo struct {
54 ListenTo []syntax.DID `json:"listen_to"`
55}
56
57var (
58 // storing the subscriber data in both Should Be Fine
59 // we dont modify subscriber data at the same time in two places
60 subscribers = hashmap.New[string, *SubscriberData]()
61 userData = hashmap.New[syntax.DID, *UserData]()
62
63 likeStream *client.Client
64 followStream *client.Client
65
66 upgrader = websocket.Upgrader{
67 CheckOrigin: func(r *http.Request) bool {
68 return true
69 },
70 }
71
72 logger *slog.Logger
73)
74
75func getSubscriberDids() []string {
76 _dids := make(Set[string], subscribers.Len())
77 subscribers.Range(func(s string, sd *SubscriberData) bool {
78 _dids[string(sd.SubscribedTo)] = struct{}{}
79 return true
80 })
81 dids := make([]string, 0, len(_dids))
82 for k := range _dids {
83 dids = append(dids, k)
84 }
85 return dids
86}
87
88func getUserData(did syntax.DID) *UserData {
89 ud, _ := userData.GetOrInsert(did, &UserData{
90 targets: hashmap.New[string, *SubscriberData](),
91 likes: make(map[syntax.RecordKey]bsky.FeedLike),
92 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](),
93 })
94 return ud
95}
96
97func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) {
98 ud := getUserData(did)
99 ud.targets.Insert(sid, sd)
100}
101
102func stopListeningTo(sid string, did syntax.DID) {
103 if ud, exists := userData.Get(did); exists {
104 ud.targets.Del(sid)
105 }
106}
107
108func main() {
109 logger = slog.Default()
110
111 go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts)
112 go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts)
113
114 r := mux.NewRouter()
115 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET")
116
117 log.Println("server starting on :8080")
118 if err := http.ListenAndServe(":8080", r); err != nil {
119 log.Fatalf("error while serving: %s", err)
120 }
121}
122
123func handleSubscribe(w http.ResponseWriter, r *http.Request) {
124 vars := mux.Vars(r)
125 did, err := syntax.ParseDID(vars["did"])
126 if err != nil {
127 http.Error(w, "not a valid did", http.StatusBadRequest)
128 return
129 }
130 sid := uuid.New().String()
131
132 query := r.URL.Query()
133 listenType := query.Get("listenTo")
134 if len(listenType) == 0 {
135 listenType = ListenTypeFollows
136 }
137
138 logger := logger.With("did", did, "subscriberId", sid)
139
140 conn, err := upgrader.Upgrade(w, r, nil)
141 if err != nil {
142 logger.Error("WebSocket upgrade failed", "error", err)
143 return
144 }
145 defer conn.Close()
146
147 logger.Info("new subscriber")
148
149 pdsURI, err := findUserPDS(r.Context(), did)
150 if err != nil {
151 logger.Error("cant resolve user pds", "error", err)
152 return
153 }
154 logger = logger.With("pds", pdsURI)
155
156 xrpcClient := &xrpc.Client{
157 Host: pdsURI,
158 }
159
160 ud := getUserData(did)
161 sd := &SubscriberData{
162 SubscribedTo: did,
163 Conn: conn,
164 ListenType: listenType,
165 }
166
167 switch listenType {
168 case ListenTypeFollows:
169 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did)
170 if err != nil {
171 logger.Error("error fetching follows", "error", err)
172 return
173 }
174 sd.ListenTo = make(Set[syntax.DID])
175 // use we have stored
176 ud.follows.Range(func(rk syntax.RecordKey, f bsky.GraphFollow) bool {
177 sd.ListenTo[syntax.DID(f.Subject)] = struct{}{}
178 return true
179 })
180 if len(follows) > 0 {
181 // store cursor for later requests so we dont have to fetch the whole thing again
182 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey))
183 for _, f := range follows {
184 ud.follows.Insert(f.rkey, f.follow)
185 sd.ListenTo[syntax.DID(f.follow.Subject)] = struct{}{}
186 }
187 }
188 logger.Info("fetched follows")
189 case ListenTypeNone:
190 sd.ListenTo = make(Set[syntax.DID])
191 default:
192 http.Error(w, "invalid listen type", http.StatusBadRequest)
193 return
194 }
195
196 subscribers.Set(sid, sd)
197 for listenDid := range sd.ListenTo {
198 startListeningTo(sid, sd, listenDid)
199 }
200 updateFollowStreamOpts()
201 // delete subscriber after we are done
202 defer func() {
203 for listenDid := range sd.ListenTo {
204 stopListeningTo(sid, listenDid)
205 }
206 subscribers.Del(sid)
207 updateFollowStreamOpts()
208 }()
209
210 logger.Info("serving subscriber")
211
212 for {
213 var msg SubscriberMessage
214 err := conn.ReadJSON(&msg)
215 if err != nil {
216 logger.Info("WebSocket connection closed", "error", err)
217 break
218 }
219 switch msg.Type {
220 case "update_listen_to":
221 // only allow this if we arent managing listen to
222 if sd.ListenType != ListenTypeNone {
223 continue
224 }
225
226 var innerMsg SubscriberUpdateListenTo
227 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil {
228 logger.Info("invalid message", "error", err)
229 break
230 }
231 // remove all current listens and add the ones the user requested
232 for listenDid := range sd.ListenTo {
233 stopListeningTo(sid, listenDid)
234 delete(sd.ListenTo, listenDid)
235 }
236 for _, listenDid := range innerMsg.ListenTo {
237 sd.ListenTo[listenDid] = struct{}{}
238 startListeningTo(sid, sd, listenDid)
239 }
240 }
241 }
242}
243
244func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload {
245 return models.SubscriberOptionsUpdatePayload{
246 WantedCollections: []string{"app.bsky.feed.like"},
247 }
248}
249
250func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload {
251 return models.SubscriberOptionsUpdatePayload{
252 WantedCollections: []string{"app.bsky.graph.follow"},
253 WantedDIDs: getSubscriberDids(),
254 }
255}
256
257func updateFollowStreamOpts() {
258 opts := getFollowStreamOpts()
259 err := followStream.SendOptionsUpdate(opts)
260 if err != nil {
261 logger.Error("couldnt update follow stream opts", "error", err)
262 return
263 }
264 logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs))
265}
266
267func HandleLikeEvent(ctx context.Context, event *models.Event) error {
268 if event == nil || event.Commit == nil {
269 return nil
270 }
271
272 byDid := syntax.DID(event.Did)
273 // skip handling event if its not from a source we are listening to
274 ud, exists := userData.Get(byDid)
275 if !exists || ud.targets.Len() == 0 {
276 return nil
277 }
278
279 deleted := event.Commit.Operation == models.CommitOperationDelete
280 rkey := syntax.RecordKey(event.Commit.RKey)
281
282 var like bsky.FeedLike
283 if deleted {
284 if l, exists := ud.likes[rkey]; exists {
285 like = l
286 defer delete(ud.likes, rkey)
287 } else {
288 logger.Error("like record not found", "rkey", rkey)
289 return nil
290 }
291 } else {
292 if err := json.Unmarshal(event.Commit.Record, &like); err != nil {
293 logger.Error("failed to unmarshal like", "error", err)
294 return nil
295 }
296 }
297
298 // if there is no via it means its not a repost anyway
299 if like.Via == nil {
300 return nil
301 }
302
303 // store for later when it gets deleted so we can fetch the record
304 if !deleted {
305 ud.likes[rkey] = like
306 }
307
308 repostURI := syntax.ATURI(like.Via.Uri)
309 // if not a repost we dont care
310 if repostURI.Collection() != "app.bsky.feed.repost" {
311 return nil
312 }
313 reposterDID, err := repostURI.Authority().AsDID()
314 if err != nil {
315 return err
316 }
317 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
318 if sd.SubscribedTo != reposterDID {
319 return true
320 }
321
322 notification := NotificationMessage{
323 Liked: !deleted,
324 ByDid: byDid,
325 RepostURI: repostURI,
326 PostURI: syntax.ATURI(like.Subject.Uri),
327 }
328
329 if err := sd.Conn.WriteJSON(notification); err != nil {
330 logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err)
331 }
332 return true
333 })
334
335 return nil
336}
337
338func HandleFollowEvent(ctx context.Context, event *models.Event) error {
339 if event == nil || event.Commit == nil {
340 return nil
341 }
342
343 byDid := syntax.DID(event.Did)
344 ud, exists := userData.Get(byDid)
345 if !exists {
346 return nil
347 }
348
349 deleted := event.Commit.Operation == models.CommitOperationDelete
350 rkey := syntax.RecordKey(event.Commit.RKey)
351
352 switch event.Commit.Collection {
353 case "app.bsky.graph.follow":
354 var r bsky.GraphFollow
355 if deleted {
356 if f, exists := ud.follows.Get(rkey); exists {
357 r = f
358 } else {
359 logger.Error("follow record not found", "rkey", rkey)
360 return nil
361 }
362 ud.follows.Del(rkey)
363 } else {
364 if err := unmarshalEvent(event, &r); err != nil {
365 logger.Error("could not unmarshal follow event", "error", err)
366 return nil
367 }
368 ud.follows.Insert(rkey, r)
369 }
370 ud.targets.Range(func(sid string, sd *SubscriberData) bool {
371 // if we arent managing then we dont need to update anything
372 if sd.ListenType != ListenTypeFollows {
373 return true
374 }
375 subjectDid := syntax.DID(r.Subject)
376 if deleted {
377 stopListeningTo(sid, subjectDid)
378 delete(sd.ListenTo, subjectDid)
379 } else {
380 sd.ListenTo[subjectDid] = struct{}{}
381 startListeningTo(sid, sd, subjectDid)
382 }
383 return true
384 })
385 }
386
387 return nil
388}
389
390func unmarshalEvent[v any](event *models.Event, val *v) error {
391 if err := json.Unmarshal(event.Commit.Record, val); err != nil {
392 logger.Error("failed to unmarshal", "error", err, "raw", event.Commit.Record)
393 return nil
394 }
395 return nil
396}