import asyncio import json import re from abc import ABC from dataclasses import dataclass, field from typing import Any, cast, override import websockets from atproto.util import AtUri from bluesky.info import SERVICE, BlueskyService, validate_and_transform from cross.attachments import ( LabelsAttachment, LanguagesAttachment, MediaAttachment, RemoteUrlAttachment, ) from cross.media import Blob, download_blob from cross.post import Post from cross.service import InputService from database.connection import DatabasePool from util.util import normalize_service_url @dataclass(kw_only=True) class BlueskyInputOptions: handle: str | None = None did: str | None = None pds: str | None = None filters: list[re.Pattern[str]] = field(default_factory=lambda: []) @classmethod def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions": validate_and_transform(data) if "filters" in data: data["filters"] = [re.compile(r) for r in data["filters"]] return BlueskyInputOptions(**data) @dataclass(kw_only=True) class BlueskyJetstreamInputOptions(BlueskyInputOptions): jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe" @classmethod def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions": jetstream = data.pop("jetstream", None) base = BlueskyInputOptions.from_dict(data).__dict__.copy() if jetstream: base["jetstream"] = normalize_service_url(jetstream) return BlueskyJetstreamInputOptions(**base) class BlueskyBaseInputService(BlueskyService, InputService, ABC): def __init__(self, db: DatabasePool) -> None: super().__init__(SERVICE, db) def _on_post(self, record: dict[str, Any]): post_uri = cast(str, record["$xpost.strongRef"]["uri"]) post_cid = cast(str, record["$xpost.strongRef"]["cid"]) parent_uri = cast( str, None if not record.get("reply") else record["reply"]["parent"]["uri"] ) parent = None if parent_uri: parent = self._get_post(self.url, self.did, parent_uri) if not parent: self.log.info( "Skipping %s, parent %s not found in db", post_uri, parent_uri ) return # TODO FRAGMENTS post = Post(id=post_uri, parent_id=parent_uri, text=record["text"]) did, _, rid = AtUri.record_uri(post_uri) post.attachments.put( RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}") ) embed = record.get("embed", {}) if embed: match cast(str, embed["$type"]): case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia": _, collection, _ = AtUri.record_uri( cast(str, embed["record"]["uri"]) ) if collection == "app.bsky.feed.post": self.log.info("Skipping '%s'! Quote..", post_uri) return case "app.bsky.embed.images": blobs: list[Blob] = [] for image in embed["images"]: blob_cid = image["image"]["ref"]["$link"] url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" self.log.info("Downloading %s...", blob_cid) blob: Blob | None = download_blob(url, image.get("alt")) if not blob: self.log.error( "Skipping %s! Failed to download blob %s.", post_uri, blob_cid, ) return blobs.append(blob) post.attachments.put(MediaAttachment(blobs=blobs)) case "app.bsky.embed.video": blob_cid = embed["video"]["ref"]["$link"] url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" self.log.info("Downloading %s...", blob_cid) blob: Blob | None = download_blob(url, embed.get("alt")) if not blob: self.log.error( "Skipping %s! Failed to download blob %s.", post_uri, blob_cid, ) return post.attachments.put(MediaAttachment(blobs=[blob])) case _: self.log.warning(f"Unhandled embedd type {embed['$type']}") pass if "langs" in record: post.attachments.put(LanguagesAttachment(langs=record["langs"])) if "labels" in record: post.attachments.put( LabelsAttachment( labels=[ label["val"].replace("-", " ") for label in record["values"] ] ), ) if parent: self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "parent": parent["id"], "root": parent["id"] if not parent["root"] else parent["root"], "extra_data": json.dumps({"cid": post_cid}), } ) else: self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "extra_data": json.dumps({"cid": post_cid}), } ) for out in self.outputs: self.submitter(lambda: out.accept_post(post)) def _on_repost(self, record: dict[str, Any]): post_uri = cast(str, record["$xpost.strongRef"]["uri"]) post_cid = cast(str, record["$xpost.strongRef"]["cid"]) reposted_uri = cast(str, record["subject"]["uri"]) reposted = self._get_post(self.url, self.did, reposted_uri) if not reposted: self.log.info( "Skipping repost '%s' as reposted post '%s' was not found in the db.", post_uri, reposted_uri, ) return self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "reposted": reposted["id"], "extra_data": json.dumps({"cid": post_cid}), } ) for out in self.outputs: self.submitter(lambda: out.accept_repost(post_uri, reposted_uri)) def _on_delete_post(self, post_id: str, repost: bool): post = self._get_post(self.url, self.did, post_id) if not post: return if repost: for output in self.outputs: self.submitter(lambda: output.delete_repost(post_id)) else: for output in self.outputs: self.submitter(lambda: output.delete_post(post_id)) self._delete_post_by_id(post["id"]) class BlueskyJetstreamInputService(BlueskyBaseInputService): def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None: super().__init__(db) self.options: BlueskyJetstreamInputOptions = options self._init_identity() @override def get_identity_options(self) -> tuple[str | None, str | None, str | None]: return (self.options.handle, self.options.did, self.options.pds) def _accept_msg(self, msg: websockets.Data) -> None: data: dict[str, Any] = cast(dict[str, Any], json.loads(msg)) if data.get("did") != self.did: return commit: dict[str, Any] | None = data.get("commit") if not commit: return commit_type: str = cast(str, commit["operation"]) match commit_type: case "create": record: dict[str, Any] = cast(dict[str, Any], commit["record"]) record["$xpost.strongRef"] = { "cid": commit["cid"], "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}", } match cast(str, commit["collection"]): case "app.bsky.feed.post": self._on_post(record) case "app.bsky.feed.repost": self._on_repost(record) case _: pass case "delete": post_id: str = ( f"at://{self.did}/{commit['collection']}/{commit['rkey']}" ) match cast(str, commit["collection"]): case "app.bsky.feed.post": self._on_delete_post(post_id, False) case "app.bsky.feed.repost": self._on_delete_post(post_id, True) case _: pass case _: pass @override async def listen(self): url = self.options.jetstream + "?" url += "wantedCollections=app.bsky.feed.post" url += "&wantedCollections=app.bsky.feed.repost" url += f"&wantedDids={self.did}" async for ws in websockets.connect(url): try: self.log.info("Listening to %s...", self.options.jetstream) async def listen_for_messages(): async for msg in ws: self.submitter(lambda: self._accept_msg(msg)) listen = asyncio.create_task(listen_for_messages()) _ = await asyncio.gather(listen) except websockets.ConnectionClosedError as e: self.log.error(e, stack_info=True, exc_info=True) self.log.info("Reconnecting to %s...", self.options.jetstream) continue