It's for when you want to get notifications about likes on your reposts.
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "log" 7 "log/slog" 8 "net/http" 9 "sync/atomic" 10 11 "github.com/bluesky-social/indigo/api/bsky" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "github.com/bluesky-social/indigo/xrpc" 14 "github.com/bluesky-social/jetstream/pkg/client" 15 "github.com/bluesky-social/jetstream/pkg/models" 16 "github.com/cornelk/hashmap" 17 "github.com/google/uuid" 18 "github.com/gorilla/mux" 19 "github.com/gorilla/websocket" 20) 21 22type Set[T comparable] map[T]struct{} 23 24const ListenTypeNone = "none" 25const ListenTypeFollows = "follows" 26 27type SubscriberData struct { 28 SubscribedTo syntax.DID 29 Conn *websocket.Conn 30 ListenType string 31 ListenTo Set[syntax.DID] 32} 33 34type UserData struct { 35 targets *hashmap.Map[string, *SubscriberData] 36 likes map[syntax.RecordKey]bsky.FeedLike 37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow] 38 followsCursor atomic.Pointer[string] 39} 40 41type NotificationMessage struct { 42 Liked bool `json:"liked"` 43 ByDid syntax.DID `json:"did"` 44 RepostURI syntax.ATURI `json:"repost_uri"` 45 PostURI syntax.ATURI `json:"post_uri"` 46} 47 48type SubscriberMessage struct { 49 Type string `json:"type"` 50 Content json.RawMessage `json:"content"` 51} 52 53type SubscriberUpdateListenTo struct { 54 ListenTo []syntax.DID `json:"listen_to"` 55} 56 57var ( 58 // storing the subscriber data in both Should Be Fine 59 // we dont modify subscriber data at the same time in two places 60 subscribers = hashmap.New[string, *SubscriberData]() 61 userData = hashmap.New[syntax.DID, *UserData]() 62 63 likeStream *client.Client 64 followStream *client.Client 65 66 upgrader = websocket.Upgrader{ 67 CheckOrigin: func(r *http.Request) bool { 68 return true 69 }, 70 } 71 72 logger *slog.Logger 73) 74 75func getSubscriberDids() []string { 76 _dids := make(Set[string], subscribers.Len()) 77 subscribers.Range(func(s string, sd *SubscriberData) bool { 78 _dids[string(sd.SubscribedTo)] = 
struct{}{} 79 return true 80 }) 81 dids := make([]string, 0, len(_dids)) 82 for k := range _dids { 83 dids = append(dids, k) 84 } 85 return dids 86} 87 88func getUserData(did syntax.DID) *UserData { 89 ud, _ := userData.GetOrInsert(did, &UserData{ 90 targets: hashmap.New[string, *SubscriberData](), 91 likes: make(map[syntax.RecordKey]bsky.FeedLike), 92 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](), 93 }) 94 return ud 95} 96 97func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) { 98 ud := getUserData(did) 99 ud.targets.Insert(sid, sd) 100} 101 102func stopListeningTo(sid string, did syntax.DID) { 103 if ud, exists := userData.Get(did); exists { 104 ud.targets.Del(sid) 105 } 106} 107 108func main() { 109 logger = slog.Default() 110 111 go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts) 112 go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts) 113 114 r := mux.NewRouter() 115 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET") 116 117 log.Println("server starting on :8080") 118 if err := http.ListenAndServe(":8080", r); err != nil { 119 log.Fatalf("error while serving: %s", err) 120 } 121} 122 123func handleSubscribe(w http.ResponseWriter, r *http.Request) { 124 vars := mux.Vars(r) 125 did, err := syntax.ParseDID(vars["did"]) 126 if err != nil { 127 http.Error(w, "not a valid did", http.StatusBadRequest) 128 return 129 } 130 sid := uuid.New().String() 131 132 query := r.URL.Query() 133 listenType := query.Get("listenTo") 134 if len(listenType) == 0 { 135 listenType = ListenTypeFollows 136 } 137 138 logger := logger.With("did", did, "subscriberId", sid) 139 140 conn, err := upgrader.Upgrade(w, r, nil) 141 if err != nil { 142 logger.Error("WebSocket upgrade failed", "error", err) 143 return 144 } 145 defer conn.Close() 146 147 logger.Info("new subscriber") 148 149 pdsURI, err := findUserPDS(r.Context(), did) 150 if err != nil { 151 
logger.Error("cant resolve user pds", "error", err) 152 return 153 } 154 logger = logger.With("pds", pdsURI) 155 156 xrpcClient := &xrpc.Client{ 157 Host: pdsURI, 158 } 159 160 ud := getUserData(did) 161 sd := &SubscriberData{ 162 SubscribedTo: did, 163 Conn: conn, 164 ListenType: listenType, 165 } 166 167 switch listenType { 168 case ListenTypeFollows: 169 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did) 170 if err != nil { 171 logger.Error("error fetching follows", "error", err) 172 return 173 } 174 sd.ListenTo = make(Set[syntax.DID]) 175 // use we have stored 176 ud.follows.Range(func(rk syntax.RecordKey, f bsky.GraphFollow) bool { 177 sd.ListenTo[syntax.DID(f.Subject)] = struct{}{} 178 return true 179 }) 180 if len(follows) > 0 { 181 // store cursor for later requests so we dont have to fetch the whole thing again 182 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey)) 183 for _, f := range follows { 184 ud.follows.Insert(f.rkey, f.follow) 185 sd.ListenTo[syntax.DID(f.follow.Subject)] = struct{}{} 186 } 187 } 188 logger.Info("fetched follows") 189 case ListenTypeNone: 190 sd.ListenTo = make(Set[syntax.DID]) 191 default: 192 http.Error(w, "invalid listen type", http.StatusBadRequest) 193 return 194 } 195 196 subscribers.Set(sid, sd) 197 for listenDid := range sd.ListenTo { 198 startListeningTo(sid, sd, listenDid) 199 } 200 updateFollowStreamOpts() 201 // delete subscriber after we are done 202 defer func() { 203 for listenDid := range sd.ListenTo { 204 stopListeningTo(sid, listenDid) 205 } 206 subscribers.Del(sid) 207 updateFollowStreamOpts() 208 }() 209 210 logger.Info("serving subscriber") 211 212 for { 213 var msg SubscriberMessage 214 err := conn.ReadJSON(&msg) 215 if err != nil { 216 logger.Info("WebSocket connection closed", "error", err) 217 break 218 } 219 switch msg.Type { 220 case "update_listen_to": 221 // only allow this if we arent managing listen to 222 if sd.ListenType != ListenTypeNone { 223 
continue 224 } 225 226 var innerMsg SubscriberUpdateListenTo 227 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil { 228 logger.Info("invalid message", "error", err) 229 break 230 } 231 // remove all current listens and add the ones the user requested 232 for listenDid := range sd.ListenTo { 233 stopListeningTo(sid, listenDid) 234 delete(sd.ListenTo, listenDid) 235 } 236 for _, listenDid := range innerMsg.ListenTo { 237 sd.ListenTo[listenDid] = struct{}{} 238 startListeningTo(sid, sd, listenDid) 239 } 240 } 241 } 242} 243 244func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload { 245 return models.SubscriberOptionsUpdatePayload{ 246 WantedCollections: []string{"app.bsky.feed.like"}, 247 } 248} 249 250func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload { 251 return models.SubscriberOptionsUpdatePayload{ 252 WantedCollections: []string{"app.bsky.graph.follow"}, 253 WantedDIDs: getSubscriberDids(), 254 } 255} 256 257func updateFollowStreamOpts() { 258 opts := getFollowStreamOpts() 259 err := followStream.SendOptionsUpdate(opts) 260 if err != nil { 261 logger.Error("couldnt update follow stream opts", "error", err) 262 return 263 } 264 logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs)) 265} 266 267func HandleLikeEvent(ctx context.Context, event *models.Event) error { 268 if event == nil || event.Commit == nil { 269 return nil 270 } 271 272 byDid := syntax.DID(event.Did) 273 // skip handling event if its not from a source we are listening to 274 ud, exists := userData.Get(byDid) 275 if !exists || ud.targets.Len() == 0 { 276 return nil 277 } 278 279 deleted := event.Commit.Operation == models.CommitOperationDelete 280 rkey := syntax.RecordKey(event.Commit.RKey) 281 282 var like bsky.FeedLike 283 if deleted { 284 if l, exists := ud.likes[rkey]; exists { 285 like = l 286 defer delete(ud.likes, rkey) 287 } else { 288 logger.Error("like record not found", "rkey", rkey) 289 return nil 290 } 291 } else { 292 if err := 
json.Unmarshal(event.Commit.Record, &like); err != nil { 293 logger.Error("failed to unmarshal like", "error", err) 294 return nil 295 } 296 } 297 298 // if there is no via it means its not a repost anyway 299 if like.Via == nil { 300 return nil 301 } 302 303 // store for later when it gets deleted so we can fetch the record 304 if !deleted { 305 ud.likes[rkey] = like 306 } 307 308 repostURI := syntax.ATURI(like.Via.Uri) 309 // if not a repost we dont care 310 if repostURI.Collection() != "app.bsky.feed.repost" { 311 return nil 312 } 313 reposterDID, err := repostURI.Authority().AsDID() 314 if err != nil { 315 return err 316 } 317 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 318 if sd.SubscribedTo != reposterDID { 319 return true 320 } 321 322 notification := NotificationMessage{ 323 Liked: !deleted, 324 ByDid: byDid, 325 RepostURI: repostURI, 326 PostURI: syntax.ATURI(like.Subject.Uri), 327 } 328 329 if err := sd.Conn.WriteJSON(notification); err != nil { 330 logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err) 331 } 332 return true 333 }) 334 335 return nil 336} 337 338func HandleFollowEvent(ctx context.Context, event *models.Event) error { 339 if event == nil || event.Commit == nil { 340 return nil 341 } 342 343 byDid := syntax.DID(event.Did) 344 ud, exists := userData.Get(byDid) 345 if !exists || ud.targets.Len() == 0 { 346 return nil 347 } 348 349 deleted := event.Commit.Operation == models.CommitOperationDelete 350 rkey := syntax.RecordKey(event.Commit.RKey) 351 352 switch event.Commit.Collection { 353 case "app.bsky.graph.follow": 354 var r bsky.GraphFollow 355 if deleted { 356 if f, exists := ud.follows.Get(rkey); exists { 357 r = f 358 } else { 359 logger.Error("follow record not found", "rkey", rkey) 360 return nil 361 } 362 ud.follows.Del(rkey) 363 } else { 364 if err := unmarshalEvent(event, &r); err != nil { 365 logger.Error("could not unmarshal follow event", "error", err) 366 return nil 367 } 368 
ud.follows.Insert(rkey, r) 369 } 370 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 371 // if we arent managing then we dont need to update anything 372 if sd.ListenType != ListenTypeFollows { 373 return true 374 } 375 subjectDid := syntax.DID(r.Subject) 376 if deleted { 377 stopListeningTo(sid, subjectDid) 378 delete(sd.ListenTo, subjectDid) 379 } else { 380 sd.ListenTo[subjectDid] = struct{}{} 381 startListeningTo(sid, sd, subjectDid) 382 } 383 return true 384 }) 385 } 386 387 return nil 388} 389 390func unmarshalEvent[v any](event *models.Event, val *v) error { 391 if err := json.Unmarshal(event.Commit.Record, val); err != nil { 392 logger.Error("failed to unmarshal", "error", err, "raw", event.Commit.Record) 393 return nil 394 } 395 return nil 396}