It's for when you want to get notifications about likes on posts you've reposted.
1package main
2
import (
	"context"
	"encoding/json"
	"log"
	"log/slog"
	"net/http"
	"time"

	"github.com/bluesky-social/indigo/api/bsky"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/bluesky-social/jetstream/pkg/client"
	"github.com/bluesky-social/jetstream/pkg/models"
	"github.com/cornelk/hashmap"
	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
)
18
// Set is a hash set of comparable elements: members are the map keys and
// the values are zero-width.
type Set[E comparable] map[E]struct{}
20
21// Data structures
22type SubscriberData struct {
23 DID string
24 Conn *websocket.Conn
25 ListenTo Set[string]
26 Reposts Set[string]
27}
28
// NotificationMessage is the JSON payload written to a subscriber's websocket
// when an account it listens to likes a post the subscriber reposted.
// Field order is significant: it fixes the key order of the encoded JSON.
type NotificationMessage struct {
	// Liked is false only for delete (unlike) commits.
	Liked bool `json:"liked"`
	// ByDid is the DID of the account that authored the like.
	ByDid string `json:"did"`
	// RepostURI is the URI of the reposted post that was liked.
	RepostURI string `json:"repost_uri"`
}
34
35// Global state
36var (
37 subscribers = hashmap.New[string, *SubscriberData]()
38 listeningTo = hashmap.New[string, *hashmap.Map[string, *SubscriberData]]()
39
40 likeStream *client.Client
41 subscriberStream *client.Client
42
43 upgrader = websocket.Upgrader{
44 CheckOrigin: func(r *http.Request) bool {
45 return true
46 },
47 }
48
49 logger *slog.Logger
50)
51
52func getFollowsDids() []string {
53 var dids []string
54 subscribers.Range(func(s string, sd *SubscriberData) bool {
55 for follow, _ := range sd.ListenTo {
56 dids = append(dids, follow)
57 }
58 return true
59 })
60 return dids
61}
62
63func getSubscriberDids() []string {
64 dids := make([]string, 0, subscribers.Len())
65 subscribers.Range(func(s string, sd *SubscriberData) bool {
66 dids = append(dids, s)
67 return true
68 })
69 return dids
70}
71
// main starts the two Jetstream consumers in the background and serves the
// websocket subscription endpoint on :8080.
func main() {
	logger = slog.Default()

	go likeStreamLoop(logger)
	go subscriberStreamLoop(logger)

	r := mux.NewRouter()
	r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET")

	srv := &http.Server{
		Addr:    ":8080",
		Handler: r,
		// http.ListenAndServe sets no timeouts, so a client could hold a
		// connection open indefinitely by trickling request headers. Bound
		// only the header phase; subscriptions themselves are long-lived.
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Println("Server starting on :8080")
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("error while serving: %s", err)
	}
}
86
// handleSubscribe serves GET /subscribe/{did}. It upgrades the request to a
// websocket, resolves the DID's PDS, and fetches the account's follows (used as
// the default set of accounts to listen to) and reposts. It then registers the
// subscriber in the global indexes and refreshes the Jetstream filters. Finally
// it blocks reading the socket until the client disconnects, at which point the
// registration is undone.
//
// NOTE(review): subscribers is keyed by DID, so two concurrent connections for
// the same DID overwrite each other and the first to disconnect unregisters
// both. sd.ListenTo is also read during cleanup while HandleSubscriberEvent
// may be mutating it from another goroutine.
func handleSubscribe(w http.ResponseWriter, r *http.Request) {
	did := mux.Vars(r)["did"]

	// Use a request-scoped logger. Assigning to the package-level logger here
	// would race with concurrent requests and stack "did"/"pds" attributes
	// from every past subscriber onto all later log lines.
	reqLog := logger.With("did", did)

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		reqLog.Error("WebSocket upgrade failed", "error", err)
		return
	}
	defer conn.Close()

	reqLog.Info("new subscriber")

	pdsURI, err := findUserPDS(r.Context(), did)
	if err != nil {
		reqLog.Error("cant resolve user pds", "error", err)
		return
	}
	reqLog = reqLog.With("pds", pdsURI)

	xrpcClient := &xrpc.Client{Host: pdsURI}
	// todo: implement skipping fetching follows and allow specifying users to listen to via websocket
	follows, err := fetchFollows(r.Context(), xrpcClient, did)
	if err != nil {
		reqLog.Error("error fetching follows", "error", err)
		return
	}
	reqLog.Info("fetched follows")

	reposts, err := fetchReposts(r.Context(), xrpcClient, did)
	if err != nil {
		reqLog.Error("error fetching reposts", "error", err)
		return
	}
	reqLog.Info("fetched reposts")

	sd := &SubscriberData{
		DID:  did,
		Conn: conn,
		// use user follows as default listen to
		ListenTo: follows,
		Reposts:  reposts,
	}

	subscribers.Set(sd.DID, sd)
	for listenDid := range sd.ListenTo {
		listenTo(sd, listenDid)
	}
	// Register cleanup immediately after registration, so the subscriber is
	// removed even if a later step (such as the option updates below) panics.
	defer func() {
		for listenDid := range sd.ListenTo {
			stopListeningTo(sd.DID, listenDid)
		}
		subscribers.Del(sd.DID)

		updateSubscriberStreamOpts()
		updateLikeStreamOpts()
	}()

	updateSubscriberStreamOpts()
	updateLikeStreamOpts()

	reqLog.Info("serving subscriber")

	// Incoming messages are discarded; the loop only detects disconnection.
	for {
		if _, _, err := conn.ReadMessage(); err != nil {
			reqLog.Info("WebSocket connection closed", "error", err)
			return
		}
	}
}
162
163func listenTo(sd *SubscriberData, did string) {
164 targetDids, _ := listeningTo.GetOrInsert(did, hashmap.New[string, *SubscriberData]())
165 targetDids.Insert(sd.DID, sd)
166}
167
168func stopListeningTo(subscriberDid, did string) {
169 if targetDids, exists := listeningTo.Get(did); exists {
170 targetDids.Del(subscriberDid)
171 }
172}
173
174func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload {
175 return models.SubscriberOptionsUpdatePayload{
176 WantedCollections: []string{"app.bsky.feed.like"},
177 // WantedDIDs: getFollowsDids(),
178 }
179}
180
181func getSubscriberStreamOpts() models.SubscriberOptionsUpdatePayload {
182 return models.SubscriberOptionsUpdatePayload{
183 WantedCollections: []string{"app.bsky.feed.repost", "app.bsky.graph.follow"},
184 WantedDIDs: getSubscriberDids(),
185 }
186}
187
188func updateLikeStreamOpts() {
189 opts := getLikeStreamOpts()
190 err := likeStream.SendOptionsUpdate(opts)
191 if err != nil {
192 logger.Error("couldnt update like stream opts", "error", err)
193 return
194 }
195 logger.Info("updated like stream opts", "requestedDids", len(opts.WantedDIDs))
196}
197
198func updateSubscriberStreamOpts() {
199 opts := getSubscriberStreamOpts()
200 err := subscriberStream.SendOptionsUpdate(opts)
201 if err != nil {
202 logger.Error("couldnt update subscriber stream opts", "error", err)
203 return
204 }
205 logger.Info("updated subscriber stream opts", "userCount", len(opts.WantedDIDs))
206}
207
208func likeStreamLoop(logger *slog.Logger) {
209 startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts)
210}
211
212func subscriberStreamLoop(logger *slog.Logger) {
213 startJetstreamLoop(logger, &subscriberStream, "subscriber", HandleSubscriberEvent, getSubscriberStreamOpts)
214}
215
// HandleLikeEvent processes one like event from Jetstream. If the like's
// author is listened to by a subscriber, and the liked post is in that
// subscriber's Reposts set, a NotificationMessage is written to the
// subscriber's websocket. It always returns nil; per-event problems are logged.
//
// NOTE(review): events with an empty Record are skipped. Jetstream delete
// commits presumably carry no record, so unlikes are likely never reported and
// Liked is effectively always true — confirm. Also, if the stream handler can
// run concurrently, WriteJSON on one connection may need serializing.
func HandleLikeEvent(ctx context.Context, event *models.Event) error {
	if event == nil || event.Commit == nil || len(event.Commit.Record) == 0 {
		return nil
	}

	// Skip events from authors that no subscriber listens to.
	targets, exists := listeningTo.Get(event.Did)
	if !exists {
		return nil
	}

	var like bsky.FeedLike
	if err := json.Unmarshal(event.Commit.Record, &like); err != nil {
		logger.Error("failed to unmarshal like", "error", err)
		return nil
	}
	// Subject is a pointer; a record without one would otherwise cause a nil
	// dereference below.
	if like.Subject == nil {
		return nil
	}
	subjectURI := like.Subject.Uri

	targets.Range(func(_ string, sd *SubscriberData) bool {
		// Reposts is a set, so a direct lookup replaces scanning every entry.
		if _, reposted := sd.Reposts[subjectURI]; !reposted {
			return true
		}
		notification := NotificationMessage{
			Liked:     event.Commit.Operation != models.CommitOperationDelete,
			ByDid:     event.Did,
			RepostURI: subjectURI,
		}
		if err := sd.Conn.WriteJSON(notification); err != nil {
			logger.Error("failed to send notification", "subscriber", sd.DID, "error", err)
		}
		return true
	})

	return nil
}
253
// HandleSubscriberEvent applies a subscriber's own commits from Jetstream to
// its in-memory state:
//   - app.bsky.feed.repost: adds or removes the reposted post URI in Reposts.
//   - app.bsky.graph.follow: adds or removes the followed DID in ListenTo and
//     in the listeningTo index.
//
// Other collections are ignored. It always returns nil.
//
// NOTE(review): these maps are mutated here while HandleLikeEvent reads them
// from the other stream's goroutine; they are not guarded by a lock.
func HandleSubscriberEvent(ctx context.Context, event *models.Event) error {
	if event == nil || event.Commit == nil {
		return nil
	}

	switch event.Commit.Collection {
	case "app.bsky.feed.repost":
		modifySubscribersWithEvent(
			event,
			func(s *SubscriberData, r bsky.FeedRepost) {
				if r.Subject == nil {
					return
				}
				delete(s.Reposts, r.Subject.Uri)
			},
			func(s *SubscriberData, r bsky.FeedRepost) {
				// Subject is a pointer; skip malformed records instead of panicking.
				if r.Subject == nil {
					return
				}
				// Writing to a nil map panics; the initial set comes from
				// fetchReposts, whose nil-ness is not visible here.
				if s.Reposts == nil {
					s.Reposts = make(Set[string])
				}
				s.Reposts[r.Subject.Uri] = struct{}{}
			},
		)
	case "app.bsky.graph.follow":
		modifySubscribersWithEvent(
			event,
			func(s *SubscriberData, r bsky.GraphFollow) {
				delete(s.ListenTo, r.Subject)
				stopListeningTo(s.DID, r.Subject)
			},
			func(s *SubscriberData, r bsky.GraphFollow) {
				if r.Subject == "" {
					return
				}
				if s.ListenTo == nil {
					s.ListenTo = make(Set[string])
				}
				s.ListenTo[r.Subject] = struct{}{}
				listenTo(s, r.Subject)
			},
		)
	}

	return nil
}
284
285type ModifyFunc[v any] func(*SubscriberData, v)
286
// modifySubscribersWithEvent handles an event authored by a registered
// subscriber. It decodes the event's record as V, then calls onDelete for
// delete commits and onUpdate for any other operation. Events with an empty
// record, events from authors that are not subscribers, and records that fail
// to decode are ignored; decode failures are logged. It currently always
// returns nil.
//
// NOTE(review): since empty records are skipped, onDelete only runs if a delete
// commit includes its record. Jetstream deletes presumably do not, so
// un-reposts and unfollows are likely never applied — confirm.
func modifySubscribersWithEvent[V any](event *models.Event, onDelete ModifyFunc[V], onUpdate ModifyFunc[V]) error {
	if len(event.Commit.Record) == 0 {
		return nil
	}

	// Look up the author first so records from non-subscribers are never decoded.
	subscriber, exists := subscribers.Get(event.Did)
	if !exists {
		return nil
	}

	var data V
	if err := json.Unmarshal(event.Commit.Record, &data); err != nil {
		// This helper is generic over record types, so report the actual
		// collection instead of assuming a repost.
		logger.Error("Failed to unmarshal record", "collection", event.Commit.Collection, "error", err, "raw", event.Commit.Record)
		return nil
	}

	if event.Commit.Operation == models.CommitOperationDelete {
		onDelete(subscriber, data)
	} else {
		onUpdate(subscriber, data)
	}
	return nil
}