It's for when you want to get notifications when someone likes a post via one of your reposts.
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "log" 7 "log/slog" 8 "net/http" 9 "sync/atomic" 10 "time" 11 12 "github.com/bluesky-social/indigo/api/bsky" 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 "github.com/bluesky-social/indigo/xrpc" 15 "github.com/bluesky-social/jetstream/pkg/models" 16 "github.com/cornelk/hashmap" 17 "github.com/google/uuid" 18 "github.com/gorilla/mux" 19 "github.com/gorilla/websocket" 20) 21 22type Set[T comparable] map[T]struct{} 23 24const ListenTypeNone = "none" 25const ListenTypeFollows = "follows" 26 27type SubscriberData struct { 28 forActor syntax.DID 29 conn *websocket.Conn 30 listenType string 31 listenTo Set[syntax.DID] 32} 33 34type ActorData struct { 35 targets *hashmap.Map[string, *SubscriberData] 36 likes *hashmap.Map[syntax.RecordKey, bsky.FeedLike] 37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow] 38 followsCursor atomic.Pointer[string] 39 profile *bsky.ActorDefs_ProfileViewDetailed 40 profileFetchedAt time.Time 41} 42 43type NotificationActor struct { 44 DID syntax.DID `json:"did"` // the DID of the actor that (un)liked the post 45 Profile *bsky.ActorDefs_ProfileViewDetailed `json:"profile"` // the detailed profile of the actor that (un)liked the post 46} 47 48type NotificationMessage struct { 49 Liked bool `json:"liked"` // whether the message was liked or unliked 50 Actor NotificationActor `json:"actor"` // information about the actor that (un)liked 51 Record bsky.FeedLike `json:"record"` // the raw like record 52 Time int64 `json:"time"` // when the like event came in 53} 54 55type SubscriberMessage struct { 56 Type string `json:"type"` 57 Content json.RawMessage `json:"content"` 58} 59 60type SubscriberUpdateListenTo struct { 61 ListenTo []syntax.DID `json:"listen_to"` 62} 63 64var ( 65 // storing the subscriber data in both Should Be Fine 66 // we dont modify subscriber data at the same time in two places 67 subscribers = hashmap.New[string, *SubscriberData]() 68 actorData = 
hashmap.New[syntax.DID, *ActorData]() 69 70 likeStreams *StreamManager 71 followStreams *StreamManager 72 73 upgrader = websocket.Upgrader{ 74 CheckOrigin: func(r *http.Request) bool { 75 return true 76 }, 77 } 78 79 logger *slog.Logger 80) 81 82func getSubscriberDids() []string { 83 _dids := make(Set[string], subscribers.Len()) 84 subscribers.Range(func(s string, sd *SubscriberData) bool { 85 _dids[string(sd.forActor)] = struct{}{} 86 return true 87 }) 88 dids := make([]string, 0, len(_dids)) 89 for k := range _dids { 90 dids = append(dids, k) 91 } 92 return dids 93} 94 95func getLikeDids() []string { 96 _dids := make(Set[string], subscribers.Len()*5000) 97 subscribers.Range(func(s string, sd *SubscriberData) bool { 98 for did := range sd.listenTo { 99 _dids[string(did)] = struct{}{} 100 } 101 return true 102 }) 103 dids := make([]string, 0, len(_dids)) 104 for k := range _dids { 105 dids = append(dids, k) 106 } 107 return dids 108} 109 110func getActorData(did syntax.DID) *ActorData { 111 ud, _ := actorData.GetOrInsert(did, &ActorData{ 112 targets: hashmap.New[string, *SubscriberData](), 113 likes: hashmap.New[syntax.RecordKey, bsky.FeedLike](), 114 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](), 115 }) 116 return ud 117} 118 119func markActorForLikes(sid string, sd *SubscriberData, did syntax.DID) { 120 ud := getActorData(did) 121 ud.targets.Insert(sid, sd) 122} 123 124func unmarkActorForLikes(sid string, did syntax.DID) { 125 if ud, exists := actorData.Get(did); exists { 126 ud.targets.Del(sid) 127 } 128} 129 130func main() { 131 logger = slog.Default() 132 133 likeStreams = NewStreamManager(logger, "like-tracker", HandleLikeEvent, getLikeStreamOpts) 134 followStreams = NewStreamManager(logger, "subscriber", HandleFollowEvent, getFollowStreamOpts) 135 136 r := mux.NewRouter() 137 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET") 138 139 log.Println("server starting on :8080") 140 if err := http.ListenAndServe(":8080", r); err != nil { 
141 log.Fatalf("error while serving: %s", err) 142 } 143} 144 145func handleSubscribe(w http.ResponseWriter, r *http.Request) { 146 vars := mux.Vars(r) 147 did, err := syntax.ParseDID(vars["did"]) 148 if err != nil { 149 http.Error(w, "not a valid did", http.StatusBadRequest) 150 return 151 } 152 sid := uuid.New().String() 153 154 query := r.URL.Query() 155 listenType := query.Get("listenTo") 156 if len(listenType) == 0 { 157 listenType = ListenTypeFollows 158 } 159 160 logger := logger.With("did", did, "subscriberId", sid) 161 162 conn, err := upgrader.Upgrade(w, r, nil) 163 if err != nil { 164 logger.Error("WebSocket upgrade failed", "error", err) 165 return 166 } 167 defer conn.Close() 168 169 logger.Info("new subscriber") 170 171 pdsURI, err := findUserPDS(r.Context(), did) 172 if err != nil { 173 logger.Error("cant resolve user pds", "error", err) 174 return 175 } 176 logger = logger.With("pds", pdsURI) 177 178 xrpcClient := &xrpc.Client{ 179 Host: pdsURI, 180 } 181 182 ud := getActorData(did) 183 sd := &SubscriberData{ 184 forActor: did, 185 conn: conn, 186 listenType: listenType, 187 } 188 189 switch listenType { 190 case ListenTypeFollows: 191 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did) 192 if err != nil { 193 logger.Error("error fetching follows", "error", err) 194 return 195 } 196 sd.listenTo = make(Set[syntax.DID]) 197 // use we have stored 198 ud.follows.Range(func(rk syntax.RecordKey, f bsky.GraphFollow) bool { 199 sd.listenTo[syntax.DID(f.Subject)] = struct{}{} 200 return true 201 }) 202 if len(follows) > 0 { 203 // store cursor for later requests so we dont have to fetch the whole thing again 204 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey)) 205 for _, f := range follows { 206 ud.follows.Insert(f.rkey, f.follow) 207 sd.listenTo[syntax.DID(f.follow.Subject)] = struct{}{} 208 } 209 } 210 logger.Info("fetched follows") 211 case ListenTypeNone: 212 sd.listenTo = make(Set[syntax.DID]) 213 default: 
214 http.Error(w, "invalid listen type", http.StatusBadRequest) 215 return 216 } 217 218 subscribers.Set(sid, sd) 219 for listenDid := range sd.listenTo { 220 markActorForLikes(sid, sd, listenDid) 221 } 222 updateStreamOpts() 223 // delete subscriber after we are done 224 defer func() { 225 for listenDid := range sd.listenTo { 226 unmarkActorForLikes(sid, listenDid) 227 } 228 subscribers.Del(sid) 229 updateStreamOpts() 230 }() 231 232 logger.Info("serving subscriber") 233 234 for { 235 var msg SubscriberMessage 236 err := conn.ReadJSON(&msg) 237 if err != nil { 238 logger.Info("WebSocket connection closed", "error", err) 239 break 240 } 241 switch msg.Type { 242 case "update_listen_to": 243 // only allow this if we arent managing listen to 244 if sd.listenType != ListenTypeNone { 245 continue 246 } 247 248 var innerMsg SubscriberUpdateListenTo 249 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil { 250 logger.Info("invalid message", "error", err) 251 break 252 } 253 254 // remove all current listens and add the ones the user requested 255 for listenDid := range sd.listenTo { 256 unmarkActorForLikes(sid, listenDid) 257 delete(sd.listenTo, listenDid) 258 } 259 for _, listenDid := range innerMsg.ListenTo { 260 sd.listenTo[listenDid] = struct{}{} 261 markActorForLikes(sid, sd, listenDid) 262 } 263 264 updateStreamOpts() 265 } 266 } 267} 268 269func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload { 270 return models.SubscriberOptionsUpdatePayload{ 271 WantedCollections: []string{"app.bsky.feed.like"}, 272 WantedDIDs: getLikeDids(), 273 } 274} 275 276func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload { 277 return models.SubscriberOptionsUpdatePayload{ 278 WantedCollections: []string{"app.bsky.graph.follow"}, 279 WantedDIDs: getSubscriberDids(), 280 } 281} 282 283func updateStreamOpts() { 284 likeStreams.updateOpts() 285 followStreams.updateOpts() 286} 287 288func HandleLikeEvent(ctx context.Context, event *models.Event) error { 289 if 
event == nil || event.Commit == nil { 290 return nil 291 } 292 293 byDid := syntax.DID(event.Did) 294 // skip handling event if its not from a source we are listening to 295 ud, exists := actorData.Get(byDid) 296 if !exists || ud.targets.Len() == 0 { 297 return nil 298 } 299 300 logger := logger.With("actor", byDid, "type", "like") 301 302 deleted := event.Commit.Operation == models.CommitOperationDelete 303 rkey := syntax.RecordKey(event.Commit.RKey) 304 305 var like bsky.FeedLike 306 if deleted { 307 if l, exists := ud.likes.Get(rkey); exists { 308 like = l 309 defer ud.likes.Del(rkey) 310 } else { 311 logger.Error("like record not found", "rkey", rkey) 312 return nil 313 } 314 } else if err := unmarshalEvent(event, &like); err != nil { 315 return nil 316 } 317 318 // if there is no via it means its not a repost anyway 319 if like.Via == nil { 320 return nil 321 } 322 323 // store for later when it gets deleted so we can fetch the record 324 if !deleted { 325 ud.likes.Insert(rkey, like) 326 } 327 328 repostURI := syntax.ATURI(like.Via.Uri) 329 // if not a repost we dont care 330 if repostURI.Collection() != "app.bsky.feed.repost" { 331 return nil 332 } 333 reposterDID, err := repostURI.Authority().AsDID() 334 if err != nil { 335 return err 336 } 337 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 338 if sd.forActor != reposterDID { 339 return true 340 } 341 342 if ud.profile == nil || time.Since(ud.profileFetchedAt) > time.Hour*24 { 343 profile, err := fetchProfile(ctx, byDid) 344 if err != nil { 345 logger.Error("cant fetch profile", "error", err) 346 } else { 347 ud.profile = profile 348 ud.profileFetchedAt = time.Now() 349 } 350 } 351 352 notification := NotificationMessage{ 353 Liked: !deleted, 354 Actor: NotificationActor{ 355 DID: byDid, 356 Profile: ud.profile, 357 }, 358 Record: like, 359 Time: event.TimeUS, 360 } 361 362 if err := sd.conn.WriteJSON(notification); err != nil { 363 logger.Error("failed to send notification", "error", err) 364 
} 365 return true 366 }) 367 368 return nil 369} 370 371func HandleFollowEvent(ctx context.Context, event *models.Event) error { 372 if event == nil || event.Commit == nil { 373 return nil 374 } 375 376 byDid := syntax.DID(event.Did) 377 ud, exists := actorData.Get(byDid) 378 if !exists || ud.targets.Len() == 0 { 379 return nil 380 } 381 382 deleted := event.Commit.Operation == models.CommitOperationDelete 383 rkey := syntax.RecordKey(event.Commit.RKey) 384 385 switch event.Commit.Collection { 386 case "app.bsky.graph.follow": 387 var r bsky.GraphFollow 388 if deleted { 389 if f, exists := ud.follows.Get(rkey); exists { 390 r = f 391 } else { 392 // most likely no ListenTypeFollows subscriber attached on actor 393 logger.Warn("follow record not found", "rkey", rkey, "actor", byDid) 394 return nil 395 } 396 ud.follows.Del(rkey) 397 } else { 398 if err := unmarshalEvent(event, &r); err != nil { 399 return nil 400 } 401 ud.follows.Insert(rkey, r) 402 } 403 404 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 405 // if we arent managing then we dont need to update anything 406 if sd.listenType != ListenTypeFollows { 407 return true 408 } 409 subjectDid := syntax.DID(r.Subject) 410 if deleted { 411 unmarkActorForLikes(sid, subjectDid) 412 delete(sd.listenTo, subjectDid) 413 } else { 414 sd.listenTo[subjectDid] = struct{}{} 415 markActorForLikes(sid, sd, subjectDid) 416 } 417 return true 418 }) 419 } 420 421 return nil 422} 423 424func unmarshalEvent[v any](event *models.Event, val *v) error { 425 if err := json.Unmarshal(event.Commit.Record, val); err != nil { 426 logger.Error("cant unmarshal record", "error", err, "raw", event.Commit.Record) 427 return err 428 } 429 return nil 430}