It's for when you want to get notifications when someone likes a post through your repost.
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "log" 7 "log/slog" 8 "net/http" 9 "sync/atomic" 10 11 "github.com/bluesky-social/indigo/api/bsky" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "github.com/bluesky-social/indigo/xrpc" 14 "github.com/bluesky-social/jetstream/pkg/client" 15 "github.com/bluesky-social/jetstream/pkg/models" 16 "github.com/cornelk/hashmap" 17 "github.com/google/uuid" 18 "github.com/gorilla/mux" 19 "github.com/gorilla/websocket" 20) 21 22type Set[T comparable] map[T]struct{} 23 24const ListenTypeNone = "none" 25const ListenTypeFollows = "follows" 26 27type SubscriberData struct { 28 SubscribedTo syntax.DID 29 Conn *websocket.Conn 30 ListenType string 31 ListenTo Set[syntax.DID] 32} 33 34type UserData struct { 35 targets *hashmap.Map[string, *SubscriberData] 36 likes map[syntax.RecordKey]bsky.FeedLike 37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow] 38 followsCursor atomic.Pointer[string] 39} 40 41type NotificationMessage struct { 42 Liked bool `json:"liked"` 43 ByDid syntax.DID `json:"did"` 44 RepostURI syntax.ATURI `json:"repost_uri"` 45 PostURI syntax.ATURI `json:"post_uri"` 46} 47 48type SubscriberMessage struct { 49 Type string `json:"type"` 50 Content json.RawMessage `json:"content"` 51} 52 53type SubscriberUpdateListenTo struct { 54 ListenTo []syntax.DID `json:"listen_to"` 55} 56 57var ( 58 // storing the subscriber data in both Should Be Fine 59 // we dont modify subscriber data at the same time in two places 60 subscribers = hashmap.New[string, *SubscriberData]() 61 userData = hashmap.New[syntax.DID, *UserData]() 62 63 likeStream *client.Client 64 followStream *client.Client 65 66 upgrader = websocket.Upgrader{ 67 CheckOrigin: func(r *http.Request) bool { 68 return true 69 }, 70 } 71 72 logger *slog.Logger 73) 74 75func getSubscriberDids() []string { 76 _dids := make(Set[string], subscribers.Len()) 77 subscribers.Range(func(s string, sd *SubscriberData) bool { 78 _dids[string(sd.SubscribedTo)] = 
struct{}{} 79 return true 80 }) 81 dids := make([]string, 0, len(_dids)) 82 for k := range _dids { 83 dids = append(dids, k) 84 } 85 return dids 86} 87 88func getUserData(did syntax.DID) *UserData { 89 ud, _ := userData.GetOrInsert(did, &UserData{ 90 targets: hashmap.New[string, *SubscriberData](), 91 likes: make(map[syntax.RecordKey]bsky.FeedLike), 92 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](), 93 }) 94 return ud 95} 96 97func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) { 98 ud := getUserData(did) 99 ud.targets.Insert(sid, sd) 100} 101 102func stopListeningTo(sid string, did syntax.DID) { 103 if ud, exists := userData.Get(did); exists { 104 ud.targets.Del(sid) 105 } 106} 107 108func main() { 109 logger = slog.Default() 110 111 go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts) 112 go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts) 113 114 r := mux.NewRouter() 115 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET") 116 117 log.Println("server starting on :8080") 118 if err := http.ListenAndServe(":8080", r); err != nil { 119 log.Fatalf("error while serving: %s", err) 120 } 121} 122 123func handleSubscribe(w http.ResponseWriter, r *http.Request) { 124 vars := mux.Vars(r) 125 did, err := syntax.ParseDID(vars["did"]) 126 if err != nil { 127 http.Error(w, "not a valid did", http.StatusBadRequest) 128 return 129 } 130 sid := uuid.New().String() 131 132 query := r.URL.Query() 133 listenType := query.Get("listenTo") 134 if len(listenType) == 0 { 135 listenType = ListenTypeFollows 136 } 137 138 logger := logger.With("did", did, "subscriberId", sid) 139 140 conn, err := upgrader.Upgrade(w, r, nil) 141 if err != nil { 142 logger.Error("WebSocket upgrade failed", "error", err) 143 return 144 } 145 defer conn.Close() 146 147 logger.Info("new subscriber") 148 149 pdsURI, err := findUserPDS(r.Context(), did) 150 if err != nil { 151 
logger.Error("cant resolve user pds", "error", err) 152 return 153 } 154 logger = logger.With("pds", pdsURI) 155 156 xrpcClient := &xrpc.Client{ 157 Host: pdsURI, 158 } 159 160 ud := getUserData(did) 161 sd := &SubscriberData{ 162 SubscribedTo: did, 163 Conn: conn, 164 ListenType: listenType, 165 } 166 167 switch listenType { 168 case ListenTypeFollows: 169 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did) 170 if err != nil { 171 logger.Error("error fetching follows", "error", err) 172 return 173 } 174 sd.ListenTo = make(Set[syntax.DID]) 175 if len(follows) > 0 { 176 // store cursor for later requests so we dont have to fetch the whole thing again 177 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey)) 178 for _, f := range follows { 179 ud.follows.Insert(f.rkey, f.follow) 180 sd.ListenTo[syntax.DID(f.follow.Subject)] = struct{}{} 181 } 182 } 183 logger.Info("fetched follows") 184 case ListenTypeNone: 185 sd.ListenTo = make(Set[syntax.DID]) 186 default: 187 http.Error(w, "invalid listen type", http.StatusBadRequest) 188 return 189 } 190 191 subscribers.Set(sid, sd) 192 for listenDid := range sd.ListenTo { 193 startListeningTo(sid, sd, listenDid) 194 } 195 updateFollowStreamOpts() 196 // delete subscriber after we are done 197 defer func() { 198 for listenDid := range sd.ListenTo { 199 stopListeningTo(sid, listenDid) 200 } 201 subscribers.Del(sid) 202 updateFollowStreamOpts() 203 }() 204 205 logger.Info("serving subscriber") 206 207 for { 208 var msg SubscriberMessage 209 err := conn.ReadJSON(&msg) 210 if err != nil { 211 logger.Info("WebSocket connection closed", "error", err) 212 break 213 } 214 switch msg.Type { 215 case "update_listen_to": 216 // only allow this if we arent managing listen to 217 if sd.ListenType != ListenTypeNone { 218 continue 219 } 220 221 var innerMsg SubscriberUpdateListenTo 222 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil { 223 logger.Info("invalid message", "error", err) 224 
break 225 } 226 // remove all current listens and add the ones the user requested 227 for listenDid := range sd.ListenTo { 228 stopListeningTo(sid, listenDid) 229 delete(sd.ListenTo, listenDid) 230 } 231 for _, listenDid := range innerMsg.ListenTo { 232 sd.ListenTo[listenDid] = struct{}{} 233 startListeningTo(sid, sd, listenDid) 234 } 235 } 236 } 237} 238 239func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload { 240 return models.SubscriberOptionsUpdatePayload{ 241 WantedCollections: []string{"app.bsky.feed.like"}, 242 } 243} 244 245func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload { 246 return models.SubscriberOptionsUpdatePayload{ 247 WantedCollections: []string{"app.bsky.graph.follow"}, 248 WantedDIDs: getSubscriberDids(), 249 } 250} 251 252func updateFollowStreamOpts() { 253 opts := getFollowStreamOpts() 254 err := followStream.SendOptionsUpdate(opts) 255 if err != nil { 256 logger.Error("couldnt update follow stream opts", "error", err) 257 return 258 } 259 logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs)) 260} 261 262func HandleLikeEvent(ctx context.Context, event *models.Event) error { 263 if event == nil || event.Commit == nil { 264 return nil 265 } 266 267 byDid := syntax.DID(event.Did) 268 // skip handling event if its not from a source we are listening to 269 ud, exists := userData.Get(byDid) 270 if !exists || ud.targets.Len() == 0 { 271 return nil 272 } 273 274 deleted := event.Commit.Operation == models.CommitOperationDelete 275 rkey := syntax.RecordKey(event.Commit.RKey) 276 277 var like bsky.FeedLike 278 if deleted { 279 if l, exists := ud.likes[rkey]; exists { 280 like = l 281 defer delete(ud.likes, rkey) 282 } else { 283 logger.Error("like record not found", "rkey", rkey) 284 return nil 285 } 286 } else { 287 if err := json.Unmarshal(event.Commit.Record, &like); err != nil { 288 logger.Error("failed to unmarshal like", "error", err) 289 return nil 290 } 291 } 292 293 // if there is no via it means 
its not a repost anyway 294 if like.Via == nil { 295 return nil 296 } 297 298 // store for later when it gets deleted so we can fetch the record 299 if !deleted { 300 ud.likes[rkey] = like 301 } 302 303 repostURI := syntax.ATURI(like.Via.Uri) 304 // if not a repost we dont care 305 if repostURI.Collection() != "app.bsky.feed.repost" { 306 return nil 307 } 308 reposterDID, err := repostURI.Authority().AsDID() 309 if err != nil { 310 return err 311 } 312 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 313 if sd.SubscribedTo != reposterDID { 314 return true 315 } 316 317 notification := NotificationMessage{ 318 Liked: !deleted, 319 ByDid: byDid, 320 RepostURI: repostURI, 321 PostURI: syntax.ATURI(like.Subject.Uri), 322 } 323 324 if err := sd.Conn.WriteJSON(notification); err != nil { 325 logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err) 326 } 327 return true 328 }) 329 330 return nil 331} 332 333func HandleFollowEvent(ctx context.Context, event *models.Event) error { 334 if event == nil || event.Commit == nil { 335 return nil 336 } 337 338 byDid := syntax.DID(event.Did) 339 ud, exists := userData.Get(byDid) 340 if !exists || ud.targets.Len() == 0 { 341 return nil 342 } 343 344 deleted := event.Commit.Operation == models.CommitOperationDelete 345 rkey := syntax.RecordKey(event.Commit.RKey) 346 347 switch event.Commit.Collection { 348 case "app.bsky.graph.follow": 349 var r bsky.GraphFollow 350 if deleted { 351 if f, exists := ud.follows.Get(rkey); exists { 352 r = f 353 } else { 354 logger.Error("follow record not found", "rkey", rkey) 355 return nil 356 } 357 ud.follows.Del(rkey) 358 } else { 359 if err := unmarshalEvent(event, &r); err != nil { 360 logger.Error("could not unmarshal follow event", "error", err) 361 return nil 362 } 363 ud.follows.Insert(rkey, r) 364 } 365 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 366 // if we arent managing then we dont need to update anything 367 if sd.ListenType 
!= ListenTypeFollows { 368 return true 369 } 370 subjectDid := syntax.DID(r.Subject) 371 if deleted { 372 stopListeningTo(sid, subjectDid) 373 delete(sd.ListenTo, subjectDid) 374 } else { 375 sd.ListenTo[subjectDid] = struct{}{} 376 startListeningTo(sid, sd, subjectDid) 377 } 378 return true 379 }) 380 } 381 382 return nil 383} 384 385func unmarshalEvent[v any](event *models.Event, val *v) error { 386 if err := json.Unmarshal(event.Commit.Record, val); err != nil { 387 logger.Error("failed to unmarshal", "error", err, "raw", event.Commit.Record) 388 return nil 389 } 390 return nil 391}