It's for when you want to get notified when someone likes a post through one of your reposts.
1package main
2
3import (
4 "context"
5 "encoding/json"
6 "log"
7 "log/slog"
8 "net/http"
9
10 "github.com/bluesky-social/indigo/api/bsky"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "github.com/bluesky-social/indigo/xrpc"
13 "github.com/bluesky-social/jetstream/pkg/client"
14 "github.com/bluesky-social/jetstream/pkg/models"
15 "github.com/cornelk/hashmap"
16 "github.com/gorilla/mux"
17 "github.com/gorilla/websocket"
18)
19
// Set is a minimal set type: a map whose zero-width values carry no data, so
// membership is checked with the comma-ok form of a map lookup.
type Set[K comparable] map[K]struct{}
21
// Values accepted by the listenTo query parameter of /subscribe/{did}.
const (
	// ListenTypeNone starts the subscriber with an empty listen set that the
	// client replaces itself via "update_listen_to" messages.
	ListenTypeNone = "none"
	// ListenTypeFollows listens to every account the subscriber follows and
	// keeps that set in sync with follow/unfollow events.
	ListenTypeFollows = "follows"
)
24
25type SubscriberData struct {
26 DID syntax.DID
27 Conn *websocket.Conn
28 ListenType string
29 ListenTo Set[syntax.DID]
30 follows map[syntax.RecordKey]bsky.GraphFollow
31}
32
33type ListeneeData struct {
34 targets *hashmap.Map[syntax.DID, *SubscriberData]
35 likes map[syntax.RecordKey]bsky.FeedLike
36}
37
38type NotificationMessage struct {
39 Liked bool `json:"liked"`
40 ByDid syntax.DID `json:"did"`
41 RepostURI syntax.ATURI `json:"repost_uri"`
42}
43
44type SubscriberMessage struct {
45 Type string `json:"type"`
46 Content json.RawMessage `json:"content"`
47}
48
49type SubscriberUpdateListenTo struct {
50 ListenTo []syntax.DID `json:"listen_to"`
51}
52
53var (
54 // storing the subscriber data in both Should Be Fine
55 // we dont modify subscriber data at the same time in two places
56 subscribers = hashmap.New[syntax.DID, *SubscriberData]()
57 listeningTo = hashmap.New[syntax.DID, *ListeneeData]()
58
59 likeStream *client.Client
60 subscriberStream *client.Client
61
62 upgrader = websocket.Upgrader{
63 CheckOrigin: func(r *http.Request) bool {
64 return true
65 },
66 }
67
68 logger *slog.Logger
69)
70
71func getSubscriberDids() []string {
72 dids := make([]string, 0, subscribers.Len())
73 subscribers.Range(func(s syntax.DID, sd *SubscriberData) bool {
74 dids = append(dids, string(s))
75 return true
76 })
77 return dids
78}
79
80func startListeningTo(sd *SubscriberData, did syntax.DID) {
81 ld, _ := listeningTo.GetOrInsert(did, &ListeneeData{
82 targets: hashmap.New[syntax.DID, *SubscriberData](),
83 likes: make(map[syntax.RecordKey]bsky.FeedLike),
84 })
85 ld.targets.Insert(sd.DID, sd)
86}
87
88func stopListeningTo(subscriberDid, did syntax.DID) {
89 if ld, exists := listeningTo.Get(did); exists {
90 ld.targets.Del(subscriberDid)
91 }
92}
93
94func main() {
95 logger = slog.Default()
96
97 go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts)
98 go startJetstreamLoop(logger, &subscriberStream, "subscriber", HandleSubscriberEvent, getSubscriberStreamOpts)
99
100 r := mux.NewRouter()
101 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET")
102
103 log.Println("server starting on :8080")
104 if err := http.ListenAndServe(":8080", r); err != nil {
105 log.Fatalf("error while serving: %s", err)
106 }
107}
108
109func handleSubscribe(w http.ResponseWriter, r *http.Request) {
110 vars := mux.Vars(r)
111 did, err := syntax.ParseDID(vars["did"])
112 if err != nil {
113 http.Error(w, "not a valid did", http.StatusBadRequest)
114 return
115 }
116
117 query := r.URL.Query()
118 listenType := query.Get("listenTo")
119 if len(listenType) == 0 {
120 listenType = ListenTypeFollows
121 }
122
123 logger := logger.With("did", did)
124
125 conn, err := upgrader.Upgrade(w, r, nil)
126 if err != nil {
127 logger.Error("WebSocket upgrade failed", "error", err)
128 return
129 }
130 defer conn.Close()
131
132 logger.Info("new subscriber")
133
134 pdsURI, err := findUserPDS(r.Context(), did)
135 if err != nil {
136 logger.Error("cant resolve user pds", "error", err)
137 return
138 }
139 logger = logger.With("pds", pdsURI)
140
141 xrpcClient := &xrpc.Client{
142 Host: pdsURI,
143 }
144
145 sd := &SubscriberData{
146 DID: did,
147 Conn: conn,
148 ListenType: listenType,
149 }
150
151 switch listenType {
152 case ListenTypeFollows:
153 follows, err := fetchFollows(r.Context(), xrpcClient, did)
154 if err != nil {
155 logger.Error("error fetching follows", "error", err)
156 return
157 }
158 logger.Info("fetched follows")
159 sd.follows = follows
160 sd.ListenTo = make(Set[syntax.DID])
161 for _, follow := range follows {
162 sd.ListenTo[syntax.DID(follow.Subject)] = struct{}{}
163 }
164 case ListenTypeNone:
165 sd.ListenTo = make(Set[syntax.DID])
166 default:
167 http.Error(w, "invalid listen type", http.StatusBadRequest)
168 return
169 }
170
171 subscribers.Set(sd.DID, sd)
172 for listenDid := range sd.ListenTo {
173 startListeningTo(sd, listenDid)
174 }
175 updateSubscriberStreamOpts()
176 // delete subscriber after we are done
177 defer func() {
178 for listenDid := range sd.ListenTo {
179 stopListeningTo(sd.DID, listenDid)
180 }
181 subscribers.Del(sd.DID)
182 updateSubscriberStreamOpts()
183 }()
184
185 logger.Info("serving subscriber")
186
187 for {
188 var msg SubscriberMessage
189 err := conn.ReadJSON(&msg)
190 if err != nil {
191 logger.Info("WebSocket connection closed", "error", err)
192 break
193 }
194 switch msg.Type {
195 case "update_listen_to":
196 // only allow this if we arent managing listen to
197 if sd.ListenType != ListenTypeNone {
198 continue
199 }
200
201 var innerMsg SubscriberUpdateListenTo
202 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil {
203 logger.Info("invalid message", "error", err)
204 break
205 }
206 // remove all current listens and add the ones the user requested
207 for listenDid := range sd.ListenTo {
208 stopListeningTo(sd.DID, listenDid)
209 delete(sd.ListenTo, listenDid)
210 }
211 for _, listenDid := range innerMsg.ListenTo {
212 sd.ListenTo[listenDid] = struct{}{}
213 startListeningTo(sd, listenDid)
214 }
215 }
216 }
217}
218
219func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload {
220 return models.SubscriberOptionsUpdatePayload{
221 WantedCollections: []string{"app.bsky.feed.like"},
222 }
223}
224
225func getSubscriberStreamOpts() models.SubscriberOptionsUpdatePayload {
226 return models.SubscriberOptionsUpdatePayload{
227 WantedCollections: []string{"app.bsky.feed.repost", "app.bsky.graph.follow"},
228 WantedDIDs: getSubscriberDids(),
229 }
230}
231
232func updateSubscriberStreamOpts() {
233 opts := getSubscriberStreamOpts()
234 err := subscriberStream.SendOptionsUpdate(opts)
235 if err != nil {
236 logger.Error("couldnt update subscriber stream opts", "error", err)
237 return
238 }
239 logger.Info("updated subscriber stream opts", "userCount", len(opts.WantedDIDs))
240}
241
242func HandleLikeEvent(ctx context.Context, event *models.Event) error {
243 if event == nil || event.Commit == nil {
244 return nil
245 }
246
247 byDid := syntax.DID(event.Did)
248 // skip handling event if its not from a source we are listening to
249 ld, exists := listeningTo.Get(byDid)
250 if !exists {
251 return nil
252 }
253
254 deleted := event.Commit.Operation == models.CommitOperationDelete
255 rkey := syntax.RecordKey(event.Commit.RKey)
256
257 var like bsky.FeedLike
258 if deleted {
259 if l, exists := ld.likes[rkey]; exists {
260 like = l
261 defer delete(ld.likes, rkey)
262 } else {
263 logger.Error("like record not found", "rkey", rkey)
264 return nil
265 }
266 } else {
267 if err := json.Unmarshal(event.Commit.Record, &like); err != nil {
268 logger.Error("failed to unmarshal like", "error", err)
269 return nil
270 }
271 }
272
273 // if there is no via it means its not a repost anyway
274 if like.Via == nil {
275 return nil
276 }
277
278 // store for later when it gets deleted so we can fetch the record
279 if !deleted {
280 ld.likes[rkey] = like
281 }
282
283 repostURI := syntax.ATURI(like.Via.Uri)
284 // if not a repost we dont care
285 if repostURI.Collection() != "app.bsky.feed.repost" {
286 return nil
287 }
288 reposterDID, err := repostURI.Authority().AsDID()
289 if err != nil {
290 return err
291 }
292 if sd, exists := ld.targets.Get(reposterDID); exists {
293 notification := NotificationMessage{
294 Liked: !deleted,
295 ByDid: byDid,
296 RepostURI: repostURI,
297 }
298
299 if err := sd.Conn.WriteJSON(notification); err != nil {
300 logger.Error("failed to send notification", "subscriber", sd.DID, "error", err)
301 }
302 }
303
304 return nil
305}
306
307func HandleSubscriberEvent(ctx context.Context, event *models.Event) error {
308 if event == nil || event.Commit == nil {
309 return nil
310 }
311
312 byDid := syntax.DID(event.Did)
313 sd, exists := subscribers.Get(byDid)
314 if !exists {
315 return nil
316 }
317
318 deleted := event.Commit.Operation == models.CommitOperationDelete
319 rkey := syntax.RecordKey(event.Commit.RKey)
320
321 switch event.Commit.Collection {
322 case "app.bsky.graph.follow":
323 // if we arent managing then we dont need to update anything
324 if sd.ListenType != ListenTypeFollows {
325 return nil
326 }
327 var r bsky.GraphFollow
328 if deleted {
329 if f, exists := sd.follows[rkey]; exists {
330 r = f
331 } else {
332 logger.Error("follow record not found", "rkey", rkey)
333 return nil
334 }
335 subjectDid := syntax.DID(r.Subject)
336 stopListeningTo(sd.DID, subjectDid)
337 delete(sd.ListenTo, subjectDid)
338 delete(sd.follows, rkey)
339 } else {
340 if err := unmarshalEvent(event, &r); err != nil {
341 return err
342 }
343 subjectDid := syntax.DID(r.Subject)
344 sd.ListenTo[subjectDid] = struct{}{}
345 sd.follows[rkey] = r
346 startListeningTo(sd, subjectDid)
347 }
348 }
349
350 return nil
351}
352
353func unmarshalEvent[v any](event *models.Event, val *v) error {
354 if err := json.Unmarshal(event.Commit.Record, val); err != nil {
355 logger.Error("failed to unmarshal", "error", err, "raw", event.Commit.Record)
356 return nil
357 }
358 return nil
359}