its for when you want to get like notifications for your reposts

fix: actually make the ws subscribers separate oops

ptr.pet eea4e683 360d403f

verified
Changed files
+116 -82
+97 -71
main.go
···
"log"
"log/slog"
"net/http"
+
"sync/atomic"
"github.com/bluesky-social/indigo/api/bsky"
"github.com/bluesky-social/indigo/atproto/syntax"
···
"github.com/bluesky-social/jetstream/pkg/client"
"github.com/bluesky-social/jetstream/pkg/models"
"github.com/cornelk/hashmap"
+
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
···
const ListenTypeFollows = "follows"
type SubscriberData struct {
-
DID syntax.DID
-
Conn *websocket.Conn
-
ListenType string
-
ListenTo Set[syntax.DID]
-
follows map[syntax.RecordKey]bsky.GraphFollow
+
SubscribedTo syntax.DID
+
Conn *websocket.Conn
+
ListenType string
+
ListenTo Set[syntax.DID]
}
-
type ListeneeData struct {
-
targets *hashmap.Map[syntax.DID, *SubscriberData]
-
likes map[syntax.RecordKey]bsky.FeedLike
+
type UserData struct {
+
targets *hashmap.Map[string, *SubscriberData]
+
likes map[syntax.RecordKey]bsky.FeedLike
+
follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow]
+
followsCursor atomic.Pointer[string]
}
type NotificationMessage struct {
···
var (
// storing the subscriber data in both Should Be Fine
// we dont modify subscriber data at the same time in two places
-
subscribers = hashmap.New[syntax.DID, *SubscriberData]()
-
listeningTo = hashmap.New[syntax.DID, *ListeneeData]()
+
subscribers = hashmap.New[string, *SubscriberData]()
+
userData = hashmap.New[syntax.DID, *UserData]()
-
likeStream *client.Client
-
subscriberStream *client.Client
+
likeStream *client.Client
+
followStream *client.Client
upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
···
func getSubscriberDids() []string {
dids := make([]string, 0, subscribers.Len())
-
subscribers.Range(func(s syntax.DID, sd *SubscriberData) bool {
-
dids = append(dids, string(s))
+
subscribers.Range(func(s string, sd *SubscriberData) bool {
+
dids = append(dids, string(sd.SubscribedTo))
return true
})
return dids
}
-
func startListeningTo(sd *SubscriberData, did syntax.DID) {
-
ld, _ := listeningTo.GetOrInsert(did, &ListeneeData{
-
targets: hashmap.New[syntax.DID, *SubscriberData](),
+
func getUserData(did syntax.DID) *UserData {
+
ud, _ := userData.GetOrInsert(did, &UserData{
+
targets: hashmap.New[string, *SubscriberData](),
likes: make(map[syntax.RecordKey]bsky.FeedLike),
+
follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](),
})
-
ld.targets.Insert(sd.DID, sd)
+
return ud
}
-
func stopListeningTo(subscriberDid, did syntax.DID) {
-
if ld, exists := listeningTo.Get(did); exists {
-
ld.targets.Del(subscriberDid)
+
func startListeningTo(sid string, sd *SubscriberData, did syntax.DID) {
+
ud := getUserData(did)
+
ud.targets.Insert(sid, sd)
+
}
+
+
func stopListeningTo(sid string, did syntax.DID) {
+
if ud, exists := userData.Get(did); exists {
+
ud.targets.Del(sid)
}
}
···
logger = slog.Default()
go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, getLikeStreamOpts)
-
go startJetstreamLoop(logger, &subscriberStream, "subscriber", HandleSubscriberEvent, getSubscriberStreamOpts)
+
go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts)
r := mux.NewRouter()
r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET")
···
http.Error(w, "not a valid did", http.StatusBadRequest)
return
}
+
sid := uuid.New().String()
query := r.URL.Query()
listenType := query.Get("listenTo")
···
listenType = ListenTypeFollows
}
-
logger := logger.With("did", did)
+
logger := logger.With("did", did, "subscriberId", sid)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
···
Host: pdsURI,
}
+
ud := getUserData(did)
sd := &SubscriberData{
-
DID: did,
-
Conn: conn,
-
ListenType: listenType,
+
SubscribedTo: did,
+
Conn: conn,
+
ListenType: listenType,
}
switch listenType {
case ListenTypeFollows:
-
follows, err := fetchFollows(r.Context(), xrpcClient, did)
+
follows, err := fetchFollows(r.Context(), xrpcClient, ud.followsCursor.Load(), did)
if err != nil {
logger.Error("error fetching follows", "error", err)
return
}
-
logger.Info("fetched follows")
-
sd.follows = follows
sd.ListenTo = make(Set[syntax.DID])
-
for _, follow := range follows {
-
sd.ListenTo[syntax.DID(follow.Subject)] = struct{}{}
+
if len(follows) > 0 {
+
// store cursor for later requests so we dont have to fetch the whole thing again
+
ud.followsCursor.Store((*string)(&follows[len(follows)-1].rkey))
+
for _, f := range follows {
+
ud.follows.Insert(f.rkey, f.follow)
+
sd.ListenTo[syntax.DID(f.follow.Subject)] = struct{}{}
+
}
}
+
logger.Info("fetched follows")
case ListenTypeNone:
sd.ListenTo = make(Set[syntax.DID])
default:
···
return
}
-
subscribers.Set(sd.DID, sd)
+
subscribers.Set(sid, sd)
for listenDid := range sd.ListenTo {
-
startListeningTo(sd, listenDid)
+
startListeningTo(sid, sd, listenDid)
}
-
updateSubscriberStreamOpts()
+
updateFollowStreamOpts()
// delete subscriber after we are done
defer func() {
for listenDid := range sd.ListenTo {
-
stopListeningTo(sd.DID, listenDid)
+
stopListeningTo(sid, listenDid)
}
-
subscribers.Del(sd.DID)
-
updateSubscriberStreamOpts()
+
subscribers.Del(sid)
+
updateFollowStreamOpts()
}()
logger.Info("serving subscriber")
···
}
// remove all current listens and add the ones the user requested
for listenDid := range sd.ListenTo {
-
stopListeningTo(sd.DID, listenDid)
+
stopListeningTo(sid, listenDid)
delete(sd.ListenTo, listenDid)
}
for _, listenDid := range innerMsg.ListenTo {
sd.ListenTo[listenDid] = struct{}{}
-
startListeningTo(sd, listenDid)
+
startListeningTo(sid, sd, listenDid)
}
}
}
···
}
}
-
func getSubscriberStreamOpts() models.SubscriberOptionsUpdatePayload {
+
func getFollowStreamOpts() models.SubscriberOptionsUpdatePayload {
return models.SubscriberOptionsUpdatePayload{
-
WantedCollections: []string{"app.bsky.feed.repost", "app.bsky.graph.follow"},
+
WantedCollections: []string{"app.bsky.graph.follow"},
WantedDIDs: getSubscriberDids(),
}
}
-
func updateSubscriberStreamOpts() {
-
opts := getSubscriberStreamOpts()
-
err := subscriberStream.SendOptionsUpdate(opts)
+
func updateFollowStreamOpts() {
+
opts := getFollowStreamOpts()
+
err := followStream.SendOptionsUpdate(opts)
if err != nil {
-
logger.Error("couldnt update subscriber stream opts", "error", err)
+
logger.Error("couldnt update follow stream opts", "error", err)
return
}
-
logger.Info("updated subscriber stream opts", "userCount", len(opts.WantedDIDs))
+
logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs))
}
func HandleLikeEvent(ctx context.Context, event *models.Event) error {
···
byDid := syntax.DID(event.Did)
// skip handling event if its not from a source we are listening to
-
ld, exists := listeningTo.Get(byDid)
-
if !exists {
+
ud, exists := userData.Get(byDid)
+
if !exists || ud.targets.Len() == 0 {
return nil
}
···
var like bsky.FeedLike
if deleted {
-
if l, exists := ld.likes[rkey]; exists {
+
if l, exists := ud.likes[rkey]; exists {
like = l
-
defer delete(ld.likes, rkey)
+
defer delete(ud.likes, rkey)
} else {
logger.Error("like record not found", "rkey", rkey)
return nil
···
// store for later when it gets deleted so we can fetch the record
if !deleted {
-
ld.likes[rkey] = like
+
ud.likes[rkey] = like
}
repostURI := syntax.ATURI(like.Via.Uri)
···
if err != nil {
return err
}
-
if sd, exists := ld.targets.Get(reposterDID); exists {
+
ud.targets.Range(func(sid string, sd *SubscriberData) bool {
+
if sd.SubscribedTo != reposterDID {
+
return true
+
}
+
notification := NotificationMessage{
Liked: !deleted,
ByDid: byDid,
···
}
if err := sd.Conn.WriteJSON(notification); err != nil {
-
logger.Error("failed to send notification", "subscriber", sd.DID, "error", err)
+
logger.Error("failed to send notification", "subscriber", sd.SubscribedTo, "error", err)
}
-
}
+
return true
+
})
return nil
}
-
func HandleSubscriberEvent(ctx context.Context, event *models.Event) error {
+
func HandleFollowEvent(ctx context.Context, event *models.Event) error {
if event == nil || event.Commit == nil {
return nil
}
byDid := syntax.DID(event.Did)
-
sd, exists := subscribers.Get(byDid)
-
if !exists {
+
ud, exists := userData.Get(byDid)
+
if !exists || ud.targets.Len() == 0 {
return nil
}
···
switch event.Commit.Collection {
case "app.bsky.graph.follow":
-
// if we arent managing then we dont need to update anything
-
if sd.ListenType != ListenTypeFollows {
-
return nil
-
}
var r bsky.GraphFollow
if deleted {
-
if f, exists := sd.follows[rkey]; exists {
+
if f, exists := ud.follows.Get(rkey); exists {
r = f
} else {
logger.Error("follow record not found", "rkey", rkey)
return nil
}
-
subjectDid := syntax.DID(r.Subject)
-
stopListeningTo(sd.DID, subjectDid)
-
delete(sd.ListenTo, subjectDid)
-
delete(sd.follows, rkey)
+
ud.follows.Del(rkey)
} else {
if err := unmarshalEvent(event, &r); err != nil {
-
return err
+
logger.Error("could not unmarshal follow event", "error", err)
+
return nil
+
}
+
ud.follows.Insert(rkey, r)
+
}
+
ud.targets.Range(func(sid string, sd *SubscriberData) bool {
+
// if we arent managing then we dont need to update anything
+
if sd.ListenType != ListenTypeFollows {
+
return true
}
subjectDid := syntax.DID(r.Subject)
-
sd.ListenTo[subjectDid] = struct{}{}
-
sd.follows[rkey] = r
-
startListeningTo(sd, subjectDid)
-
}
+
if deleted {
+
stopListeningTo(sid, subjectDid)
+
delete(sd.ListenTo, subjectDid)
+
} else {
+
sd.ListenTo[subjectDid] = struct{}{}
+
startListeningTo(sid, sd, subjectDid)
+
}
+
return true
+
})
}
return nil
+19 -11
xrpc.go
···
return nil
}
-
func fetchRecords[v any](ctx context.Context, xrpcClient *xrpc.Client, cb func(syntax.ATURI, v), collection string, did syntax.DID) error {
+
func fetchRecords[v any](ctx context.Context, xrpcClient *xrpc.Client, cb func(syntax.ATURI, v), cursor *string, collection string, did syntax.DID) error {
if xrpcClient == nil {
pdsURI, err := findUserPDS(ctx, did)
if err != nil {
···
}
}
-
cursor := ""
+
var cur string = ""
+
if cursor != nil {
+
cur = *cursor
+
}
for {
// todo: ratelimits?? idk what this does for those
-
out, err := atproto.RepoListRecords(ctx, xrpcClient, collection, cursor, 100, string(did), false)
+
out, err := atproto.RepoListRecords(ctx, xrpcClient, collection, cur, 100, string(did), true)
if err != nil {
return err
}
···
if out.Cursor == nil || *out.Cursor == "" {
break
}
-
cursor = *out.Cursor
-
-
break
+
cur = *out.Cursor
}
return nil
}
-
func fetchFollows(ctx context.Context, xrpcClient *xrpc.Client, did syntax.DID) (map[syntax.RecordKey]bsky.GraphFollow, error) {
-
out := make(map[syntax.RecordKey]bsky.GraphFollow)
-
fetchRecords(ctx, xrpcClient, func(uri syntax.ATURI, f bsky.GraphFollow) { out[uri.RecordKey()] = f }, "app.bsky.graph.follow", did)
+
type FetchFollowItem struct {
+
rkey syntax.RecordKey
+
follow bsky.GraphFollow
+
}
+
+
func fetchFollows(ctx context.Context, xrpcClient *xrpc.Client, cursor *string, did syntax.DID) ([]FetchFollowItem, error) {
+
out := make([]FetchFollowItem, 0)
+
fetchRecords(ctx, xrpcClient, func(uri syntax.ATURI, f bsky.GraphFollow) {
+
out = append(out, FetchFollowItem{rkey: uri.RecordKey(), follow: f})
+
}, cursor, "app.bsky.graph.follow", did)
return out, nil
}
-
func fetchRepostLikes(ctx context.Context, xrpcClient *xrpc.Client, did syntax.DID) (map[syntax.RecordKey]bsky.FeedLike, error) {
+
func fetchRepostLikes(ctx context.Context, xrpcClient *xrpc.Client, cursor *string, did syntax.DID) (map[syntax.RecordKey]bsky.FeedLike, error) {
out := make(map[syntax.RecordKey]bsky.FeedLike)
fetchRecords(ctx, xrpcClient, func(uri syntax.ATURI, f bsky.FeedLike) {
if f.Via != nil && syntax.ATURI(f.Via.Uri).Collection() == "app.bsky.feed.repost" {
out[uri.RecordKey()] = f
}
-
}, "app.bsky.feed.like", did)
+
}, cursor, "app.bsky.feed.like", did)
return out, nil
}