Use this when you want to be notified about likes that come in through your reposts.
Branch: `main` · 11 kB · view raw
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "io" 8 "log" 9 "log/slog" 10 "net/http" 11 "sync/atomic" 12 "time" 13 14 "github.com/bluesky-social/indigo/api/bsky" 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 "github.com/bluesky-social/indigo/xrpc" 17 "github.com/bluesky-social/jetstream/pkg/models" 18 "github.com/cornelk/hashmap" 19 "github.com/google/uuid" 20 "github.com/gorilla/mux" 21 "github.com/gorilla/websocket" 22) 23 24type Set[T comparable] map[T]struct{} 25 26const ListenTypeNone = "none" 27const ListenTypeFollows = "follows" 28 29type SubscriberData struct { 30 forActor syntax.DID 31 conn *websocket.Conn 32 listenType string 33 listenTo Set[syntax.DID] 34} 35 36type ActorData struct { 37 targets *hashmap.Map[string, *SubscriberData] 38 likes *hashmap.Map[syntax.RecordKey, bsky.FeedLike] 39 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow] 40 followsCursor atomic.Pointer[string] 41 profile *bsky.ActorDefs_ProfileViewDetailed 42 profileFetchedAt time.Time 43} 44 45type NotificationActor struct { 46 DID syntax.DID `json:"did"` // the DID of the actor that (un)liked the post 47 Profile *bsky.ActorDefs_ProfileViewDetailed `json:"profile"` // the detailed profile of the actor that (un)liked the post 48} 49 50type NotificationMessage struct { 51 Liked bool `json:"liked"` // whether the message was liked or unliked 52 Actor NotificationActor `json:"actor"` // information about the actor that (un)liked 53 Record bsky.FeedLike `json:"record"` // the raw like record 54 Time int64 `json:"time"` // when the like event came in 55} 56 57type SubscriberMessage struct { 58 Type string `json:"type"` 59 Content json.RawMessage `json:"content"` 60} 61 62type SubscriberUpdateListenTo struct { 63 ListenTo []syntax.DID `json:"listen_to"` 64} 65 66var ( 67 // storing the subscriber data in both Should Be Fine 68 // we dont modify subscriber data at the same time in two places 69 subscribers = hashmap.New[string, *SubscriberData]() 70 
actorData = hashmap.New[syntax.DID, *ActorData]() 71 72 likeStreams *StreamManager 73 followStreams *StreamManager 74 75 upgrader = websocket.Upgrader{ 76 CheckOrigin: func(r *http.Request) bool { 77 return true 78 }, 79 } 80 81 logger *slog.Logger 82) 83 84func getSubscriberDids() []string { 85 _dids := make(Set[string], subscribers.Len()) 86 subscribers.Range(func(s string, sd *SubscriberData) bool { 87 _dids[string(sd.forActor)] = struct{}{} 88 return true 89 }) 90 dids := make([]string, 0, len(_dids)) 91 for k := range _dids { 92 dids = append(dids, k) 93 } 94 return dids 95} 96 97func getLikeDids() []string { 98 _dids := make(Set[string], subscribers.Len()*5000) 99 subscribers.Range(func(s string, sd *SubscriberData) bool { 100 for did := range sd.listenTo { 101 _dids[string(did)] = struct{}{} 102 } 103 return true 104 }) 105 dids := make([]string, 0, len(_dids)) 106 for k := range _dids { 107 dids = append(dids, k) 108 } 109 return dids 110} 111 112func getActorData(did syntax.DID) *ActorData { 113 ud, _ := actorData.GetOrInsert(did, &ActorData{ 114 targets: hashmap.New[string, *SubscriberData](), 115 likes: hashmap.New[syntax.RecordKey, bsky.FeedLike](), 116 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](), 117 }) 118 return ud 119} 120 121func markActorForLikes(sid string, sd *SubscriberData, did syntax.DID) { 122 ud := getActorData(did) 123 ud.targets.Insert(sid, sd) 124} 125 126func unmarkActorForLikes(sid string, did syntax.DID) { 127 if ud, exists := actorData.Get(did); exists { 128 ud.targets.Del(sid) 129 } 130} 131 132func main() { 133 logger = slog.Default() 134 135 likeStreams = NewStreamManager(logger, "like-tracker", HandleLikeEvent, getLikeStreamOpts) 136 followStreams = NewStreamManager(logger, "subscriber", HandleFollowEvent, getFollowStreamOpts) 137 138 r := mux.NewRouter() 139 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET") 140 141 log.Println("server starting on :8080") 142 if err := http.ListenAndServe(":8080", 
r); err != nil { 143 log.Fatalf("error while serving: %s", err) 144 } 145} 146 147func handleSubscribe(w http.ResponseWriter, r *http.Request) { 148 vars := mux.Vars(r) 149 did, err := syntax.ParseDID(vars["did"]) 150 if err != nil { 151 http.Error(w, "not a valid did", http.StatusBadRequest) 152 return 153 } 154 sid := uuid.New().String() 155 156 query := r.URL.Query() 157 listenType := query.Get("listenTo") 158 if len(listenType) == 0 { 159 listenType = ListenTypeFollows 160 } 161 162 logger := logger.With("did", did, "subscriberId", sid) 163 164 conn, err := upgrader.Upgrade(w, r, nil) 165 if err != nil { 166 logger.Error("WebSocket upgrade failed", "error", err) 167 return 168 } 169 defer conn.Close() 170 171 logger.Info("new subscriber") 172 173 pdsURI, err := findUserPDS(r.Context(), did) 174 if err != nil { 175 logger.Error("cant resolve user pds", "error", err) 176 return 177 } 178 logger = logger.With("pds", pdsURI) 179 180 xrpcClient := &xrpc.Client{ 181 Host: pdsURI, 182 } 183 184 ud := getActorData(did) 185 sd := &SubscriberData{ 186 forActor: did, 187 conn: conn, 188 listenType: listenType, 189 } 190 191 switch listenType { 192 case ListenTypeFollows: 193 follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did) 194 if err != nil { 195 logger.Error("error fetching follows", "error", err) 196 return 197 } 198 sd.listenTo = make(Set[syntax.DID]) 199 // use we have stored 200 ud.follows.Range(func(rk syntax.RecordKey, f bsky.GraphFollow) bool { 201 sd.listenTo[syntax.DID(f.Subject)] = struct{}{} 202 return true 203 }) 204 if len(follows) > 0 { 205 // store cursor for later requests so we dont have to fetch the whole thing again 206 ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey)) 207 for _, f := range follows { 208 ud.follows.Insert(f.rkey, f.follow) 209 sd.listenTo[syntax.DID(f.follow.Subject)] = struct{}{} 210 } 211 } 212 logger.Info("fetched follows") 213 case ListenTypeNone: 214 sd.listenTo = 
make(Set[syntax.DID]) 215 default: 216 http.Error(w, "invalid listen type", http.StatusBadRequest) 217 return 218 } 219 220 subscribers.Set(sid, sd) 221 for listenDid := range sd.listenTo { 222 markActorForLikes(sid, sd, listenDid) 223 } 224 updateStreamOpts() 225 // delete subscriber after we are done 226 defer func() { 227 for listenDid := range sd.listenTo { 228 unmarkActorForLikes(sid, listenDid) 229 } 230 subscribers.Del(sid) 231 updateStreamOpts() 232 }() 233 234 logger.Info("serving subscriber") 235 236 // send pings 237 go func() { 238 for { 239 select { 240 case <-r.Context().Done(): 241 return 242 default: 243 conn.WriteMessage(websocket.PingMessage, []byte{}) 244 time.Sleep(time.Second * 15) 245 } 246 } 247 }() 248 249 for { 250 var msg SubscriberMessage 251 err := conn.ReadJSON(&msg) 252 // dont kill websocket if its no value 253 if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { 254 logger.Error("WebSocket connection closed", "error", err) 255 break 256 } 257 switch msg.Type { 258 case "update_listen_to": 259 // only allow this if we arent managing listen to 260 if sd.listenType != ListenTypeNone { 261 continue 262 } 263 264 var innerMsg SubscriberUpdateListenTo 265 if err := json.Unmarshal(msg.Content, &innerMsg); err != nil { 266 logger.Debug("invalid message", "error", err) 267 break 268 } 269 270 // remove all current listens and add the ones the user requested 271 for listenDid := range sd.listenTo { 272 unmarkActorForLikes(sid, listenDid) 273 delete(sd.listenTo, listenDid) 274 } 275 for _, listenDid := range innerMsg.ListenTo { 276 sd.listenTo[listenDid] = struct{}{} 277 markActorForLikes(sid, sd, listenDid) 278 } 279 280 updateStreamOpts() 281 } 282 } 283} 284 285func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload { 286 return models.SubscriberOptionsUpdatePayload{ 287 WantedCollections: []string{"app.bsky.feed.like"}, 288 WantedDIDs: getLikeDids(), 289 } 290} 291 292func getFollowStreamOpts() 
models.SubscriberOptionsUpdatePayload { 293 return models.SubscriberOptionsUpdatePayload{ 294 WantedCollections: []string{"app.bsky.graph.follow"}, 295 WantedDIDs: getSubscriberDids(), 296 } 297} 298 299func updateStreamOpts() { 300 likeStreams.updateOpts() 301 followStreams.updateOpts() 302} 303 304func HandleLikeEvent(ctx context.Context, event *models.Event) error { 305 if event == nil || event.Commit == nil { 306 return nil 307 } 308 309 byDid := syntax.DID(event.Did) 310 // skip handling event if its not from a source we are listening to 311 ud, exists := actorData.Get(byDid) 312 if !exists || ud.targets.Len() == 0 { 313 return nil 314 } 315 316 logger := logger.With("actor", byDid, "type", "like") 317 318 deleted := event.Commit.Operation == models.CommitOperationDelete 319 rkey := syntax.RecordKey(event.Commit.RKey) 320 321 var like bsky.FeedLike 322 if deleted { 323 if l, exists := ud.likes.Get(rkey); exists { 324 like = l 325 defer ud.likes.Del(rkey) 326 } else { 327 logger.Error("like record not found", "rkey", rkey) 328 return nil 329 } 330 } else if err := unmarshalEvent(event, &like); err != nil { 331 return nil 332 } 333 334 // if there is no via it means its not a repost anyway 335 if like.Via == nil { 336 return nil 337 } 338 339 // store for later when it gets deleted so we can fetch the record 340 if !deleted { 341 ud.likes.Insert(rkey, like) 342 } 343 344 repostURI := syntax.ATURI(like.Via.Uri) 345 // if not a repost we dont care 346 if repostURI.Collection() != "app.bsky.feed.repost" { 347 return nil 348 } 349 reposterDID, err := repostURI.Authority().AsDID() 350 if err != nil { 351 return err 352 } 353 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 354 if sd.forActor != reposterDID { 355 return true 356 } 357 358 if ud.profile == nil || time.Since(ud.profileFetchedAt) > time.Hour*24 { 359 profile, err := fetchProfile(ctx, byDid) 360 if err != nil { 361 logger.Error("cant fetch profile", "error", err) 362 } else { 363 ud.profile = 
profile 364 ud.profileFetchedAt = time.Now() 365 } 366 } 367 368 notification := NotificationMessage{ 369 Liked: !deleted, 370 Actor: NotificationActor{ 371 DID: byDid, 372 Profile: ud.profile, 373 }, 374 Record: like, 375 Time: event.TimeUS, 376 } 377 378 if err := sd.conn.WriteJSON(notification); err != nil { 379 logger.Error("failed to send notification", "error", err) 380 } 381 return true 382 }) 383 384 return nil 385} 386 387func HandleFollowEvent(ctx context.Context, event *models.Event) error { 388 if event == nil || event.Commit == nil { 389 return nil 390 } 391 392 byDid := syntax.DID(event.Did) 393 ud, exists := actorData.Get(byDid) 394 if !exists || ud.targets.Len() == 0 { 395 return nil 396 } 397 398 deleted := event.Commit.Operation == models.CommitOperationDelete 399 rkey := syntax.RecordKey(event.Commit.RKey) 400 401 switch event.Commit.Collection { 402 case "app.bsky.graph.follow": 403 var r bsky.GraphFollow 404 if deleted { 405 if f, exists := ud.follows.Get(rkey); exists { 406 r = f 407 } else { 408 // most likely no ListenTypeFollows subscriber attached on actor 409 logger.Warn("follow record not found", "rkey", rkey, "actor", byDid) 410 return nil 411 } 412 ud.follows.Del(rkey) 413 } else { 414 if err := unmarshalEvent(event, &r); err != nil { 415 return nil 416 } 417 ud.follows.Insert(rkey, r) 418 } 419 420 ud.targets.Range(func(sid string, sd *SubscriberData) bool { 421 // if we arent managing then we dont need to update anything 422 if sd.listenType != ListenTypeFollows { 423 return true 424 } 425 subjectDid := syntax.DID(r.Subject) 426 if deleted { 427 unmarkActorForLikes(sid, subjectDid) 428 delete(sd.listenTo, subjectDid) 429 } else { 430 sd.listenTo[subjectDid] = struct{}{} 431 markActorForLikes(sid, sd, subjectDid) 432 } 433 return true 434 }) 435 } 436 437 return nil 438} 439 440func unmarshalEvent[v any](event *models.Event, val *v) error { 441 if err := json.Unmarshal(event.Commit.Record, val); err != nil { 442 logger.Error("cant 
unmarshal record", "error", err, "raw", event.Commit.Record) 443 return err 444 } 445 return nil 446}