import asyncio
import json
import re
from typing import Any, Callable

import websockets
from atproto_client import models
from atproto_client.models.utils import get_or_create as get_model_or_create

import cross
import util.database as database
from bluesky.atproto2 import resolve_identity
from bluesky.common import SERVICE, BlueskyPost, tokenize_post
from util.database import DataBaseWorker
from util.media import MediaInfo, download_media
from util.util import LOGGER, as_envvar

# Record collections this input reacts to.
_POST_COLLECTION = "app.bsky.feed.post"
_REPOST_COLLECTION = "app.bsky.feed.repost"

# Embed types that reference another record.
_RECORD_EMBED = "app.bsky.embed.record"
_RECORD_WITH_MEDIA_EMBED = "app.bsky.embed.recordWithMedia"


class BlueskyInputOptions:
    """Options for a Bluesky input, read from the ``options`` settings dict."""

    def __init__(self, o: dict) -> None:
        # Patterns from "regex_filters" are compiled once, up front.
        self.filters = [re.compile(f) for f in o.get("regex_filters", [])]


class BlueskyInput(cross.Input):
    """Base Bluesky input: turns repo records into calls on the outputs."""

    def __init__(self, settings: dict, db: DataBaseWorker) -> None:
        """Resolve the account identity from *settings* and init the input.

        ``handle``, ``did`` and ``pds`` are read from *settings* through
        ``as_envvar`` and passed to ``resolve_identity``.
        """
        self.options = BlueskyInputOptions(settings.get("options", {}))
        did, pds = resolve_identity(
            handle=as_envvar(settings.get("handle")),
            did=as_envvar(settings.get("did")),
            pds=as_envvar(settings.get("pds")),
        )
        self.pds = pds

        # The PDS is not a service: the lexicon and rids are the same
        # across PDSes.
        super().__init__(SERVICE, did, settings, db)

    def _blob_url(self, cid: str) -> str:
        """Return the ``com.atproto.sync.getBlob`` URL for *cid* on our PDS."""
        return (
            f"{self.pds}/xrpc/com.atproto.sync.getBlob"
            f"?did={self.user_id}&cid={cid}"
        )

    def _download_attachments(
        self, media: dict[str, Any]
    ) -> list[MediaInfo] | None:
        """Download the blobs of an images/video embed.

        Returns the downloaded attachments (empty for any other embed type),
        or ``None`` as soon as one download fails.
        """
        media_type = media.get("$type")

        # (cid, alt text) pairs to fetch, in embed order.
        blobs: list[tuple[str, str]] = []
        if media_type == "app.bsky.embed.images":
            model = get_model_or_create(media, model=models.AppBskyEmbedImages.Main)
            assert isinstance(model, models.AppBskyEmbedImages.Main)
            blobs = [(image.image.cid.encode(), image.alt) for image in model.images]
        elif media_type == "app.bsky.embed.video":
            model = get_model_or_create(media, model=models.AppBskyEmbedVideo.Main)
            assert isinstance(model, models.AppBskyEmbedVideo.Main)
            blobs = [(model.video.cid.encode(), model.alt if model.alt else "")]

        attachments: list[MediaInfo] = []
        for cid, alt in blobs:
            url = self._blob_url(cid)
            LOGGER.info("Downloading %s...", url)
            downloaded = download_media(url, alt)
            if not downloaded:
                return None
            attachments.append(downloaded)
        return attachments

    def _on_post(self, outputs: list[cross.Output], post: dict[str, Any]) -> None:
        """Handle a newly created post record and forward it to *outputs*.

        *post* is the record dict with an extra ``$xpost.strongRef`` entry
        holding the record's ``uri`` and ``cid``. The post is skipped when it
        quotes another post, when the database insert fails, when it does not
        pass the configured filters, or when a media download fails.
        """
        post_uri = post["$xpost.strongRef"]["uri"]
        post_cid = post["$xpost.strongRef"]["cid"]

        parent_uri = None
        if post.get("reply"):
            parent_uri = post["reply"]["parent"]["uri"]

        embed = post.get("embed") or {}
        embed_type = embed.get("$type")

        # Embed that may carry images/video; differs from `embed` only for
        # recordWithMedia.
        media = embed
        if embed_type in (_RECORD_EMBED, _RECORD_WITH_MEDIA_EMBED):
            record = embed["record"]
            if embed_type == _RECORD_WITH_MEDIA_EMBED:
                # recordWithMedia wraps an app.bsky.embed.record, so the
                # strong ref sits one level deeper; media is under "media".
                record = record.get("record", record)
                media = embed.get("media") or {}

            _did, collection, _rkey = str(record["uri"][len("at://") :]).split("/")
            if collection == _POST_COLLECTION:
                LOGGER.info("Skipping '%s'! Quote..", post_uri)
                return

        success = database.try_insert_post(
            self.db, post_uri, parent_uri, self.user_id, self.service
        )
        if not success:
            LOGGER.info("Skipping '%s' as parent post was not found in db!", post_uri)
            return

        database.store_data(
            self.db, post_uri, self.user_id, self.service, {"cid": post_cid}
        )

        tokens = tokenize_post(post)
        if not cross.test_filters(tokens, self.options.filters):
            LOGGER.info("Skipping '%s'. Matched a filter!", post_uri)
            return

        LOGGER.info("Crossposting '%s'...", post_uri)

        attachments = self._download_attachments(media)
        if attachments is None:
            LOGGER.error("Skipping '%s'. Failed to download media!", post_uri)
            return

        cross_post = BlueskyPost(post, tokens, attachments)
        for output in outputs:
            output.accept_post(cross_post)

    def _on_delete_post(
        self, outputs: list[cross.Output], post_id: str, repost: bool
    ) -> None:
        """Propagate the deletion of a known post (or repost) to *outputs*.

        Does nothing when *post_id* is not found in the database. Otherwise
        every output is notified and the database entry is removed.
        """
        post = database.find_post(self.db, post_id, self.user_id, self.service)
        if not post:
            return

        LOGGER.info("Deleting '%s'...", post_id)
        for output in outputs:
            if repost:
                output.delete_repost(post_id)
            else:
                output.delete_post(post_id)

        database.delete_post(self.db, post_id, self.user_id, self.service)

    def _on_repost(self, outputs: list[cross.Output], post: dict[str, Any]) -> None:
        """Handle a newly created repost record and forward it to *outputs*.

        Skipped when the database insert fails (logged as the reposted post
        not being found in the db).
        """
        post_uri = post["$xpost.strongRef"]["uri"]
        post_cid = post["$xpost.strongRef"]["cid"]
        reposted_uri = post["subject"]["uri"]

        success = database.try_insert_repost(
            self.db, post_uri, reposted_uri, self.user_id, self.service
        )
        if not success:
            LOGGER.info("Skipping '%s' as reposted post was not found in db!", post_uri)
            return

        database.store_data(
            self.db, post_uri, self.user_id, self.service, {"cid": post_cid}
        )

        LOGGER.info("Crossposting '%s'...", post_uri)
        for output in outputs:
            output.accept_repost(post_uri, reposted_uri)


class BlueskyJetstreamInput(BlueskyInput):
    """Bluesky input fed by a Jetstream websocket subscription."""

    def __init__(self, settings: dict, db: DataBaseWorker) -> None:
        """Read the ``jetstream`` endpoint from *settings* (with a default)."""
        super().__init__(settings, db)
        self.jetstream = settings.get(
            "jetstream", "wss://jetstream2.us-east.bsky.network/subscribe"
        )

    def __on_commit(self, outputs: list[cross.Output], msg: dict) -> None:
        """Dispatch one decoded Jetstream message to the matching handler.

        Messages for other DIDs, messages without a ``commit``, and
        operations other than ``create``/``delete`` are ignored.
        """
        if msg.get("did") != self.user_id:
            return

        commit: dict = msg.get("commit", {})
        if not commit:
            return

        operation = commit["operation"]
        if operation not in ("create", "delete"):
            return

        collection = commit["collection"]
        record_uri = f"at://{self.user_id}/{collection}/{commit['rkey']}"

        if operation == "create":
            # Copy the record so the strong ref can be attached without
            # mutating the incoming message.
            record = dict(commit.get("record", {}))
            record["$xpost.strongRef"] = {"cid": commit["cid"], "uri": record_uri}
            if collection == _POST_COLLECTION:
                self._on_post(outputs, record)
            elif collection == _REPOST_COLLECTION:
                self._on_repost(outputs, record)
        elif collection == _POST_COLLECTION:
            self._on_delete_post(outputs, record_uri, False)
        elif collection == _REPOST_COLLECTION:
            self._on_delete_post(outputs, record_uri, True)

    async def listen(
        self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
    ):
        """Subscribe to Jetstream and hand every frame to *submit*.

        Each received frame is wrapped in a zero-argument callable that
        decodes it and dispatches the commit; *submit* decides when to run it.
        When the connection closes with an error, the error is logged and the
        loop moves on to the next connection.
        """
        uri = (
            f"{self.jetstream}?"
            "wantedCollections=app.bsky.feed.post"
            "&wantedCollections=app.bsky.feed.repost"
            f"&wantedDids={self.user_id}"
        )

        async for ws in websockets.connect(
            uri, extra_headers={"User-Agent": "XPost/0.0.3"}
        ):
            try:
                LOGGER.info("Listening to %s...", self.jetstream)
                async for msg in ws:
                    # Bind the frame as a default argument: `submit` may run
                    # the callable later, after `msg` has been rebound to a
                    # newer frame.
                    submit(
                        lambda raw=msg: self.__on_commit(outputs, json.loads(raw))
                    )
            except websockets.ConnectionClosedError as e:
                LOGGER.error(e, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", self.jetstream)
                continue