This service is for when you want to get like notifications for your reposts, i.e. to be told when someone likes (or unlikes) a post via one of your reposts.
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "log" 7 "log/slog" 8 "net/http" 9 "sync/atomic" 10 11 "github.com/bluesky-social/indigo/api/bsky" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "github.com/bluesky-social/indigo/xrpc" 14 "github.com/bluesky-social/jetstream/pkg/client" 15 "github.com/bluesky-social/jetstream/pkg/models" 16 "github.com/cornelk/hashmap" 17 "github.com/google/uuid" 18 "github.com/gorilla/mux" 19 "github.com/gorilla/websocket" 20) 21 22type Set[T comparable] map[T]struct{} 23 24const ListenTypeNone = "none" 25const ListenTypeFollows = "follows" 26 27type SubscriberData struct { 28 SubscribedTo syntax.DID 29 Conn *websocket.Conn 30 ListenType string 31 ListenTo Set[syntax.DID] 32} 33 34type UserData struct { 35 targets *hashmap.Map[string, *SubscriberData] 36 likes map[syntax.RecordKey]bsky.FeedLike 37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow] 38 followsCursor atomic.Pointer[string] 39} 40 41type NotificationMessage struct { 42 Liked bool `json:"liked"` 43 ByDid syntax.DID `json:"did"` 44 RepostURI syntax.ATURI `json:"repost_uri"` 45} 46 47type SubscriberMessage struct { 48 Type string `json:"type"` 49 Content json.RawMessage `json:"content"` 50} 51 52type SubscriberUpdateListenTo struct { 53 ListenTo []syntax.DID `json:"listen_to"` 54} 55 56var ( 57 // storing the subscriber data in both Should Be Fine 58 // we dont modify subscriber data at the same time in two places 59 subscribers = hashmap.New[string, *SubscriberData]() 60 userData = hashmap.New[syntax.DID, *UserData]() 61 62 likeStream *client.Client 63 followStream *client.Client 64 65 upgrader = websocket.Upgrader{ 66 CheckOrigin: func(r *http.Request) bool { 67 return true 68 }, 69 } 70 71 logger *slog.Logger 72) 73 74func getSubscriberDids() []string { 75 dids := make([]string, 0, subscribers.Len()) 76 subscribers.Range(func(s string, sd *SubscriberData) bool { 77 dids = append(dids, string(sd.SubscribedTo)) 78 return true 79 }) 80 return dids 
81} 82 83func getUserData(did syntax.DID) *UserData { 84 ud, _ := userData.GetOrInsert(did, &UserData{ 85 targets: hashmap.New[string, *SubscriberData](), 86 likes: make(map[syntax.RecordKey]bsky.FeedLike), 87 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](), 88 }) 89 return ud 90} 91 92func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) { 93 ud := getUserData(did) 94 ud.targets.Insert(sid, sd) 95} 96 97func stopListeningTo(sid string, did syntax.DID) { 98 if ud, exists := userData.Get(did); exists { 99 ud.targets.Del(sid) 100 } 101} 102 103func main() { 104 logger = slog.Default() 105 106 go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts) 107 go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts) 108 109 r := mux.NewRouter() 110 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET") 111 112 log.Println("server starting on :8080") 113 if err := http.ListenAndServe(":8080", r); err != nil { 114 log.Fatalf("error while serving: %s", err) 115 } 116} 117 118func handleSubscribe(w http.ResponseWriter, r *http.Request) { 119 vars := mux.Vars(r) 120 did, err := syntax.ParseDID(vars["did"]) 121 if err != nil { 122 http.Error(w, "not a valid did", http.StatusBadRequest) 123 return 124 } 125 sid := uuid.New().String() 126 127 query := r.URL.Query() 128 listenType := query.Get("listenTo") 129 if len(listenType) == 0 { 130 listenType = ListenTypeFollows 131 } 132 133 logger := logger.With("did", did, "subscriberId", sid) 134 135 conn, err := upgrader.Upgrade(w, r, nil) 136 if err != nil { 137 logger.Error("WebSocket upgrade failed", "error", err) 138 return 139 } 140 defer conn.Close() 141 142 logger.Info("new subscriber") 143 144 pdsURI, err := findUserPDS(r.Context(), did) 145 if err != nil { 146 logger.Error("cant resolve user pds", "error", err) 147 return 148 } 149 logger = logger.With("pds", pdsURI) 150 151 xrpcClient := &xrpc.Client{ 152 Host: 
pdsURI, 153 } 154 155 ud := getUserData(did) 156 sd := &SubscriberData{ 157 SubscribedTo: did, 158 Conn: conn, 159 ListenType: listenType, 160 } 161 162 switch listenType { 163 case ListenTypeFollows: 164 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did) 165 if err != nil { 166 logger.Error("error fetching follows", "error", err) 167 return 168 } 169 sd.ListenTo = make(Set[syntax.DID]) 170 if len(follows) > 0 { 171 // store cursor for later requests so we dont have to fetch the whole thing again 172 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey)) 173 for _, f := range follows { 174 ud.follows.Insert(f.rkey, f.follow) 175 sd.ListenTo[syntax.DID(f.follow.Subject)] = struct{}{} 176 } 177 } 178 logger.Info("fetched follows") 179 case ListenTypeNone: 180 sd.ListenTo = make(Set[syntax.DID]) 181 default: 182 http.Error(w, "invalid listen type", http.StatusBadRequest) 183 return 184 } 185 186 subscribers.Set(sid, sd) 187 for listenDid := range sd.ListenTo { 188 startListeningTo(sid, sd, listenDid) 189 } 190 updateFollowStreamOpts() 191 // delete subscriber after we are done 192 defer func() { 193 for listenDid := range sd.ListenTo { 194 stopListeningTo(sid, listenDid) 195 } 196 subscribers.Del(sid) 197 updateFollowStreamOpts() 198 }() 199 200 logger.Info("serving subscriber") 201 202 for { 203 var msg SubscriberMessage 204 err := conn.ReadJSON(&msg) 205 if err != nil { 206 logger.Info("WebSocket connection closed", "error", err) 207 break 208 } 209 switch msg.Type { 210 case "update_listen_to": 211 // only allow this if we arent managing listen to 212 if sd.ListenType != ListenTypeNone { 213 continue 214 } 215 216 var innerMsg SubscriberUpdateListenTo 217 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil { 218 logger.Info("invalid message", "error", err) 219 break 220 } 221 // remove all current listens and add the ones the user requested 222 for listenDid := range sd.ListenTo { 223 stopListeningTo(sid, 
listenDid) 224 delete(sd.ListenTo, listenDid) 225 } 226 for _, listenDid := range innerMsg.ListenTo { 227 sd.ListenTo[listenDid] = struct{}{} 228 startListeningTo(sid, sd, listenDid) 229 } 230 } 231 } 232} 233 234func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload { 235 return models.SubscriberOptionsUpdatePayload{ 236 WantedCollections: []string{"app.bsky.feed.like"}, 237 } 238} 239 240func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload { 241 return models.SubscriberOptionsUpdatePayload{ 242 WantedCollections: []string{"app.bsky.graph.follow"}, 243 WantedDIDs: getSubscriberDids(), 244 } 245} 246 247func updateFollowStreamOpts() { 248 opts := getFollowStreamOpts() 249 err := followStream.SendOptionsUpdate(opts) 250 if err != nil { 251 logger.Error("couldnt update follow stream opts", "error", err) 252 return 253 } 254 logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs)) 255} 256 257func HandleLikeEvent(ctx context.Context, event *models.Event) error { 258 if event == nil || event.Commit == nil { 259 return nil 260 } 261 262 byDid := syntax.DID(event.Did) 263 // skip handling event if its not from a source we are listening to 264 ud, exists := userData.Get(byDid) 265 if !exists || ud.targets.Len() == 0 { 266 return nil 267 } 268 269 deleted := event.Commit.Operation == models.CommitOperationDelete 270 rkey := syntax.RecordKey(event.Commit.RKey) 271 272 var like bsky.FeedLike 273 if deleted { 274 if l, exists := ud.likes[rkey]; exists { 275 like = l 276 defer delete(ud.likes, rkey) 277 } else { 278 logger.Error("like record not found", "rkey", rkey) 279 return nil 280 } 281 } else { 282 if err := json.Unmarshal(event.Commit.Record, &like); err != nil { 283 logger.Error("failed to unmarshal like", "error", err) 284 return nil 285 } 286 } 287 288 // if there is no via it means its not a repost anyway 289 if like.Via == nil { 290 return nil 291 } 292 293 // store for later when it gets deleted so we can fetch the record 
294 if !deleted { 295 ud.likes[rkey] = like 296 } 297 298 repostURI := syntax.ATURI(like.Via.Uri) 299 // if not a repost we dont care 300 if repostURI.Collection() != "app.bsky.feed.repost" { 301 return nil 302 } 303 reposterDID, err := repostURI.Authority().AsDID() 304 if err != nil { 305 return err 306 } 307 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 308 if sd.SubscribedTo != reposterDID { 309 return true 310 } 311 312 notification := NotificationMessage{ 313 Liked: !deleted, 314 ByDid: byDid, 315 RepostURI: repostURI, 316 } 317 318 if err := sd.Conn.WriteJSON(notification); err != nil { 319 logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err) 320 } 321 return true 322 }) 323 324 return nil 325} 326 327func HandleFollowEvent(ctx context.Context, event *models.Event) error { 328 if event == nil || event.Commit == nil { 329 return nil 330 } 331 332 byDid := syntax.DID(event.Did) 333 ud, exists := userData.Get(byDid) 334 if !exists || ud.targets.Len() == 0 { 335 return nil 336 } 337 338 deleted := event.Commit.Operation == models.CommitOperationDelete 339 rkey := syntax.RecordKey(event.Commit.RKey) 340 341 switch event.Commit.Collection { 342 case "app.bsky.graph.follow": 343 var r bsky.GraphFollow 344 if deleted { 345 if f, exists := ud.follows.Get(rkey); exists { 346 r = f 347 } else { 348 logger.Error("follow record not found", "rkey", rkey) 349 return nil 350 } 351 ud.follows.Del(rkey) 352 } else { 353 if err := unmarshalEvent(event, &r); err != nil { 354 logger.Error("could not unmarshal follow event", "error", err) 355 return nil 356 } 357 ud.follows.Insert(rkey, r) 358 } 359 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 360 // if we arent managing then we dont need to update anything 361 if sd.ListenType != ListenTypeFollows { 362 return true 363 } 364 subjectDid := syntax.DID(r.Subject) 365 if deleted { 366 stopListeningTo(sid, subjectDid) 367 delete(sd.ListenTo, subjectDid) 368 } else { 
369 sd.ListenTo[subjectDid] = struct{}{} 370 startListeningTo(sid, sd, subjectDid) 371 } 372 return true 373 }) 374 } 375 376 return nil 377} 378 379func unmarshalEvent[v any](event *models.Event, val *v) error { 380 if err := json.Unmarshal(event.Commit.Record, val); err != nil { 381 logger.Error("failed to unmarshal", "error", err, "raw", event.Commit.Record) 382 return nil 383 } 384 return nil 385}