import asyncio import json import re from abc import ABC from dataclasses import dataclass, field from typing import Any, cast, override import websockets from atproto.util import AtUri from bluesky.tokens import tokenize_post from bluesky.info import SERVICE, BlueskyService, validate_and_transform from cross.attachments import ( LabelsAttachment, LanguagesAttachment, MediaAttachment, QuoteAttachment, RemoteUrlAttachment, ) from cross.media import Blob, download_blob from cross.post import Post from cross.service import InputService from database.connection import DatabasePool from util.util import normalize_service_url @dataclass(kw_only=True) class BlueskyInputOptions: handle: str | None = None did: str | None = None pds: str | None = None filters: list[re.Pattern[str]] = field(default_factory=lambda: []) @classmethod def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions": validate_and_transform(data) if "filters" in data: data["filters"] = [re.compile(r) for r in data["filters"]] return BlueskyInputOptions(**data) @dataclass(kw_only=True) class BlueskyJetstreamInputOptions(BlueskyInputOptions): jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe" @classmethod def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions": jetstream = data.pop("jetstream", None) base = BlueskyInputOptions.from_dict(data).__dict__.copy() if jetstream: base["jetstream"] = normalize_service_url(jetstream) return BlueskyJetstreamInputOptions(**base) class BlueskyBaseInputService(BlueskyService, InputService, ABC): def __init__(self, db: DatabasePool) -> None: super().__init__(SERVICE, db) def _on_post(self, record: dict[str, Any]): post_uri = cast(str, record["$xpost.strongRef"]["uri"]) post_cid = cast(str, record["$xpost.strongRef"]["cid"]) parent_uri = cast( str, None if not record.get("reply") else record["reply"]["parent"]["uri"] ) parent = None if parent_uri: parent = self._get_post(self.url, self.did, parent_uri) if not parent: self.log.info( "Skipping %s, parent %s not found in db", post_uri, parent_uri ) return tokens = tokenize_post(record["text"], record.get('facets', {})) post = Post(id=post_uri, parent_id=parent_uri, tokens=tokens) did, _, rid = AtUri.record_uri(post_uri) post.attachments.put( RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}") ) embed: dict[str, Any] = record.get("embed", {}) blob_urls: list[tuple[str, str, str | None]] = [] def handle_embeds(embed: dict[str, Any]) -> str | None: nonlocal blob_urls, post match cast(str, embed["$type"]): case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia": rcrd = embed['record']['record'] if embed['record'].get('record') else embed['record'] did, collection, _ = AtUri.record_uri(rcrd["uri"]) if collection != "app.bsky.feed.post": return f"Unhandled record collection {collection}" if did != self.did: return "" rquote = self._get_post(self.url, did, rcrd["uri"]) if not rquote: return f"Quote {rcrd["uri"]} not found in the db" post.attachments.put(QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did)) if embed.get('media'): return handle_embeds(embed["media"]) case "app.bsky.embed.images": for image in embed["images"]: blob_cid = image["image"]["ref"]["$link"] url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" blob_urls.append((url, blob_cid, image.get("alt"))) case "app.bsky.embed.video": blob_cid = embed["video"]["ref"]["$link"] url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" blob_urls.append((url, blob_cid, embed.get("alt"))) case _: self.log.warning(f"Unhandled embed type {embed['$type']}") if embed: fexit = handle_embeds(embed) if fexit is not None: self.log.info("Skipping %s! %s", post_uri, fexit) return if blob_urls: blobs: list[Blob] = [] for url, cid, alt in blob_urls: self.log.info("Downloading %s...", cid) blob: Blob | None = download_blob(url, alt) if not blob: self.log.error( "Skipping %s! Failed to download blob %s.", post_uri, cid ) return blobs.append(blob) post.attachments.put(MediaAttachment(blobs=blobs)) if "langs" in record: post.attachments.put(LanguagesAttachment(langs=record["langs"])) if "labels" in record: post.attachments.put( LabelsAttachment( labels=[ label["val"].replace("-", " ") for label in record["values"] ] ), ) if parent: self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "parent": parent["id"], "root": parent["id"] if not parent["root"] else parent["root"], "extra_data": json.dumps({"cid": post_cid}), } ) else: self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "extra_data": json.dumps({"cid": post_cid}), } ) for out in self.outputs: self.submitter(lambda: out.accept_post(post)) def _on_repost(self, record: dict[str, Any]): post_uri = cast(str, record["$xpost.strongRef"]["uri"]) post_cid = cast(str, record["$xpost.strongRef"]["cid"]) reposted_uri = cast(str, record["subject"]["uri"]) reposted = self._get_post(self.url, self.did, reposted_uri) if not reposted: self.log.info( "Skipping repost '%s' as reposted post '%s' was not found in the db.", post_uri, reposted_uri, ) return self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "reposted": reposted["id"], "extra_data": json.dumps({"cid": post_cid}), } ) for out in self.outputs: self.submitter(lambda: out.accept_repost(post_uri, reposted_uri)) def _on_delete_post(self, post_id: str, repost: bool): post = self._get_post(self.url, self.did, post_id) if not post: return if repost: for output in self.outputs: self.submitter(lambda: output.delete_repost(post_id)) else: for output in self.outputs: self.submitter(lambda: output.delete_post(post_id)) self._delete_post_by_id(post["id"]) class BlueskyJetstreamInputService(BlueskyBaseInputService): def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None: super().__init__(db) self.options: BlueskyJetstreamInputOptions = options self._init_identity() @override def get_identity_options(self) -> tuple[str | None, str | None, str | None]: return (self.options.handle, self.options.did, self.options.pds) def _accept_msg(self, msg: websockets.Data) -> None: data: dict[str, Any] = cast(dict[str, Any], json.loads(msg)) if data.get("did") != self.did: return commit: dict[str, Any] | None = data.get("commit") if not commit: return commit_type: str = cast(str, commit["operation"]) match commit_type: case "create": record: dict[str, Any] = cast(dict[str, Any], commit["record"]) record["$xpost.strongRef"] = { "cid": commit["cid"], "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}", } match cast(str, commit["collection"]): case "app.bsky.feed.post": self._on_post(record) case "app.bsky.feed.repost": self._on_repost(record) case _: pass case "delete": post_id: str = ( f"at://{self.did}/{commit['collection']}/{commit['rkey']}" ) match cast(str, commit["collection"]): case "app.bsky.feed.post": self._on_delete_post(post_id, False) case "app.bsky.feed.repost": self._on_delete_post(post_id, True) case _: pass case _: pass @override async def listen(self): url = self.options.jetstream + "?" url += "wantedCollections=app.bsky.feed.post" url += "&wantedCollections=app.bsky.feed.repost" url += f"&wantedDids={self.did}" async for ws in websockets.connect(url): try: self.log.info("Listening to %s...", self.options.jetstream) async def listen_for_messages(): async for msg in ws: self.submitter(lambda: self._accept_msg(msg)) listen = asyncio.create_task(listen_for_messages()) _ = await asyncio.gather(listen) except websockets.ConnectionClosedError as e: self.log.error(e, stack_info=True, exc_info=True) self.log.info("Reconnecting to %s...", self.options.jetstream) continue