It's for when you want to get notifications when people like posts via your reposts.
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "log" 7 "log/slog" 8 "net/http" 9 "sync/atomic" 10 11 "github.com/bluesky-social/indigo/api/bsky" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "github.com/bluesky-social/indigo/xrpc" 14 "github.com/bluesky-social/jetstream/pkg/client" 15 "github.com/bluesky-social/jetstream/pkg/models" 16 "github.com/cornelk/hashmap" 17 "github.com/google/uuid" 18 "github.com/gorilla/mux" 19 "github.com/gorilla/websocket" 20) 21 22type Set[T comparable] map[T]struct{} 23 24const ListenTypeNone = "none" 25const ListenTypeFollows = "follows" 26 27type SubscriberData struct { 28 SubscribedTo syntax.DID 29 Conn *websocket.Conn 30 ListenType string 31 ListenTo Set[syntax.DID] 32} 33 34type UserData struct { 35 targets *hashmap.Map[string, *SubscriberData] 36 likes map[syntax.RecordKey]bsky.FeedLike 37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow] 38 followsCursor atomic.Pointer[string] 39} 40 41type NotificationMessage struct { 42 Liked bool `json:"liked"` 43 ByDid syntax.DID `json:"did"` 44 RepostURI syntax.ATURI `json:"repost_uri"` 45} 46 47type SubscriberMessage struct { 48 Type string `json:"type"` 49 Content json.RawMessage `json:"content"` 50} 51 52type SubscriberUpdateListenTo struct { 53 ListenTo []syntax.DID `json:"listen_to"` 54} 55 56var ( 57 // storing the subscriber data in both Should Be Fine 58 // we dont modify subscriber data at the same time in two places 59 subscribers = hashmap.New[string, *SubscriberData]() 60 userData = hashmap.New[syntax.DID, *UserData]() 61 62 likeStream *client.Client 63 followStream *client.Client 64 65 upgrader = websocket.Upgrader{ 66 CheckOrigin: func(r *http.Request) bool { 67 return true 68 }, 69 } 70 71 logger *slog.Logger 72) 73 74func getSubscriberDids() []string { 75 _dids := make(Set[string], subscribers.Len()) 76 subscribers.Range(func(s string, sd *SubscriberData) bool { 77 _dids[string(sd.SubscribedTo)] = struct{}{} 78 return true 79 }) 80 dids := 
make([]string, 0, len(_dids)) 81 for k := range _dids { 82 dids = append(dids, k) 83 } 84 return dids 85} 86 87func getUserData(did syntax.DID) *UserData { 88 ud, _ := userData.GetOrInsert(did, &UserData{ 89 targets: hashmap.New[string, *SubscriberData](), 90 likes: make(map[syntax.RecordKey]bsky.FeedLike), 91 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](), 92 }) 93 return ud 94} 95 96func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) { 97 ud := getUserData(did) 98 ud.targets.Insert(sid, sd) 99} 100 101func stopListeningTo(sid string, did syntax.DID) { 102 if ud, exists := userData.Get(did); exists { 103 ud.targets.Del(sid) 104 } 105} 106 107func main() { 108 logger = slog.Default() 109 110 go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts) 111 go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts) 112 113 r := mux.NewRouter() 114 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET") 115 116 log.Println("server starting on :8080") 117 if err := http.ListenAndServe(":8080", r); err != nil { 118 log.Fatalf("error while serving: %s", err) 119 } 120} 121 122func handleSubscribe(w http.ResponseWriter, r *http.Request) { 123 vars := mux.Vars(r) 124 did, err := syntax.ParseDID(vars["did"]) 125 if err != nil { 126 http.Error(w, "not a valid did", http.StatusBadRequest) 127 return 128 } 129 sid := uuid.New().String() 130 131 query := r.URL.Query() 132 listenType := query.Get("listenTo") 133 if len(listenType) == 0 { 134 listenType = ListenTypeFollows 135 } 136 137 logger := logger.With("did", did, "subscriberId", sid) 138 139 conn, err := upgrader.Upgrade(w, r, nil) 140 if err != nil { 141 logger.Error("WebSocket upgrade failed", "error", err) 142 return 143 } 144 defer conn.Close() 145 146 logger.Info("new subscriber") 147 148 pdsURI, err := findUserPDS(r.Context(), did) 149 if err != nil { 150 logger.Error("cant resolve user pds", "error", err) 
151 return 152 } 153 logger = logger.With("pds", pdsURI) 154 155 xrpcClient := &xrpc.Client{ 156 Host: pdsURI, 157 } 158 159 ud := getUserData(did) 160 sd := &SubscriberData{ 161 SubscribedTo: did, 162 Conn: conn, 163 ListenType: listenType, 164 } 165 166 switch listenType { 167 case ListenTypeFollows: 168 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did) 169 if err != nil { 170 logger.Error("error fetching follows", "error", err) 171 return 172 } 173 sd.ListenTo = make(Set[syntax.DID]) 174 if len(follows) > 0 { 175 // store cursor for later requests so we dont have to fetch the whole thing again 176 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey)) 177 for _, f := range follows { 178 ud.follows.Insert(f.rkey, f.follow) 179 sd.ListenTo[syntax.DID(f.follow.Subject)] = struct{}{} 180 } 181 } 182 logger.Info("fetched follows") 183 case ListenTypeNone: 184 sd.ListenTo = make(Set[syntax.DID]) 185 default: 186 http.Error(w, "invalid listen type", http.StatusBadRequest) 187 return 188 } 189 190 subscribers.Set(sid, sd) 191 for listenDid := range sd.ListenTo { 192 startListeningTo(sid, sd, listenDid) 193 } 194 updateFollowStreamOpts() 195 // delete subscriber after we are done 196 defer func() { 197 for listenDid := range sd.ListenTo { 198 stopListeningTo(sid, listenDid) 199 } 200 subscribers.Del(sid) 201 updateFollowStreamOpts() 202 }() 203 204 logger.Info("serving subscriber") 205 206 for { 207 var msg SubscriberMessage 208 err := conn.ReadJSON(&msg) 209 if err != nil { 210 logger.Info("WebSocket connection closed", "error", err) 211 break 212 } 213 switch msg.Type { 214 case "update_listen_to": 215 // only allow this if we arent managing listen to 216 if sd.ListenType != ListenTypeNone { 217 continue 218 } 219 220 var innerMsg SubscriberUpdateListenTo 221 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil { 222 logger.Info("invalid message", "error", err) 223 break 224 } 225 // remove all current listens and add 
the ones the user requested 226 for listenDid := range sd.ListenTo { 227 stopListeningTo(sid, listenDid) 228 delete(sd.ListenTo, listenDid) 229 } 230 for _, listenDid := range innerMsg.ListenTo { 231 sd.ListenTo[listenDid] = struct{}{} 232 startListeningTo(sid, sd, listenDid) 233 } 234 } 235 } 236} 237 238func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload { 239 return models.SubscriberOptionsUpdatePayload{ 240 WantedCollections: []string{"app.bsky.feed.like"}, 241 } 242} 243 244func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload { 245 return models.SubscriberOptionsUpdatePayload{ 246 WantedCollections: []string{"app.bsky.graph.follow"}, 247 WantedDIDs: getSubscriberDids(), 248 } 249} 250 251func updateFollowStreamOpts() { 252 opts := getFollowStreamOpts() 253 err := followStream.SendOptionsUpdate(opts) 254 if err != nil { 255 logger.Error("couldnt update follow stream opts", "error", err) 256 return 257 } 258 logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs)) 259} 260 261func HandleLikeEvent(ctx context.Context, event *models.Event) error { 262 if event == nil || event.Commit == nil { 263 return nil 264 } 265 266 byDid := syntax.DID(event.Did) 267 // skip handling event if its not from a source we are listening to 268 ud, exists := userData.Get(byDid) 269 if !exists || ud.targets.Len() == 0 { 270 return nil 271 } 272 273 deleted := event.Commit.Operation == models.CommitOperationDelete 274 rkey := syntax.RecordKey(event.Commit.RKey) 275 276 var like bsky.FeedLike 277 if deleted { 278 if l, exists := ud.likes[rkey]; exists { 279 like = l 280 defer delete(ud.likes, rkey) 281 } else { 282 logger.Error("like record not found", "rkey", rkey) 283 return nil 284 } 285 } else { 286 if err := json.Unmarshal(event.Commit.Record, &like); err != nil { 287 logger.Error("failed to unmarshal like", "error", err) 288 return nil 289 } 290 } 291 292 // if there is no via it means its not a repost anyway 293 if like.Via == nil { 294 
return nil 295 } 296 297 // store for later when it gets deleted so we can fetch the record 298 if !deleted { 299 ud.likes[rkey] = like 300 } 301 302 repostURI := syntax.ATURI(like.Via.Uri) 303 // if not a repost we dont care 304 if repostURI.Collection() != "app.bsky.feed.repost" { 305 return nil 306 } 307 reposterDID, err := repostURI.Authority().AsDID() 308 if err != nil { 309 return err 310 } 311 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 312 if sd.SubscribedTo != reposterDID { 313 return true 314 } 315 316 notification := NotificationMessage{ 317 Liked: !deleted, 318 ByDid: byDid, 319 RepostURI: repostURI, 320 } 321 322 if err := sd.Conn.WriteJSON(notification); err != nil { 323 logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err) 324 } 325 return true 326 }) 327 328 return nil 329} 330 331func HandleFollowEvent(ctx context.Context, event *models.Event) error { 332 if event == nil || event.Commit == nil { 333 return nil 334 } 335 336 byDid := syntax.DID(event.Did) 337 ud, exists := userData.Get(byDid) 338 if !exists || ud.targets.Len() == 0 { 339 return nil 340 } 341 342 deleted := event.Commit.Operation == models.CommitOperationDelete 343 rkey := syntax.RecordKey(event.Commit.RKey) 344 345 switch event.Commit.Collection { 346 case "app.bsky.graph.follow": 347 var r bsky.GraphFollow 348 if deleted { 349 if f, exists := ud.follows.Get(rkey); exists { 350 r = f 351 } else { 352 logger.Error("follow record not found", "rkey", rkey) 353 return nil 354 } 355 ud.follows.Del(rkey) 356 } else { 357 if err := unmarshalEvent(event, &r); err != nil { 358 logger.Error("could not unmarshal follow event", "error", err) 359 return nil 360 } 361 ud.follows.Insert(rkey, r) 362 } 363 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 364 // if we arent managing then we dont need to update anything 365 if sd.ListenType != ListenTypeFollows { 366 return true 367 } 368 subjectDid := syntax.DID(r.Subject) 369 if 
deleted { 370 stopListeningTo(sid, subjectDid) 371 delete(sd.ListenTo, subjectDid) 372 } else { 373 sd.ListenTo[subjectDid] = struct{}{} 374 startListeningTo(sid, sd, subjectDid) 375 } 376 return true 377 }) 378 } 379 380 return nil 381} 382 383func unmarshalEvent[v any](event *models.Event, val *v) error { 384 if err := json.Unmarshal(event.Commit.Record, val); err != nil { 385 logger.Error("failed to unmarshal", "error", err, "raw", event.Commit.Record) 386 return nil 387 } 388 return nil 389}