import asyncio import json import re from typing import Any, Callable import requests import websockets import cross import util.database as database import util.html_util as html_util import util.md_util as md_util from mastodon.common import MastodonPost from util.database import DataBaseWorker from util.media import MediaInfo, download_media from util.util import LOGGER, as_envvar ALLOWED_VISIBILITY = ["public", "unlisted"] MARKDOWNY = ["text/x.misskeymarkdown", "text/markdown", "text/plain"] class MastodonInputOptions: def __init__(self, o: dict) -> None: self.allowed_visibility = ALLOWED_VISIBILITY self.filters = [re.compile(f) for f in o.get("regex_filters", [])] allowed_visibility = o.get("allowed_visibility") if allowed_visibility is not None: if any([v not in ALLOWED_VISIBILITY for v in allowed_visibility]): raise ValueError( f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}" ) self.allowed_visibility = allowed_visibility class MastodonInput(cross.Input): def __init__(self, settings: dict, db: DataBaseWorker) -> None: self.options = MastodonInputOptions(settings.get("options", {})) self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw( ValueError("'token' is required") ) instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw( ValueError("'instance' is required") ) service = instance[:-1] if instance.endswith("/") else instance LOGGER.info("Verifying %s credentails...", service) responce = requests.get( f"{service}/api/v1/accounts/verify_credentials", headers={"Authorization": f"Bearer {self.token}"}, ) if responce.status_code != 200: LOGGER.error("Failed to validate user credentials!") responce.raise_for_status() return super().__init__(service, responce.json()["id"], settings, db) self.streaming = self._get_streaming_url() if not self.streaming: raise Exception("Instance %s does not support streaming!", service) def _get_streaming_url(self): response = requests.get(f"{self.service}/api/v1/instance") response.raise_for_status() data: dict = response.json() return (data.get("urls") or {}).get("streaming_api") def __to_tokens(self, status: dict): content_type = status.get("content_type", "text/plain") raw_text = status.get("text") tags: list[str] = [] for tag in status.get("tags", []): tags.append(tag["name"]) mentions: list[tuple[str, str]] = [] for mention in status.get("mentions", []): mentions.append(("@" + mention["username"], "@" + mention["acct"])) if raw_text and content_type in MARKDOWNY: return md_util.tokenize_markdown(raw_text, tags, mentions) akkoma_ext: dict | None = status.get("akkoma", {}).get("source") if akkoma_ext: if akkoma_ext.get("mediaType") in MARKDOWNY: return md_util.tokenize_markdown(akkoma_ext["content"], tags, mentions) tokenizer = html_util.HTMLPostTokenizer() tokenizer.mentions = mentions tokenizer.tags = tags tokenizer.feed(status.get("content", "")) return tokenizer.get_tokens() def _on_create_post(self, outputs: list[cross.Output], status: dict): # skip events from other users if (status.get("account") or {})["id"] != self.user_id: return if status.get("visibility") not in self.options.allowed_visibility: # Skip f/o and direct posts LOGGER.info( "Skipping '%s'! '%s' visibility..", status["id"], status.get("visibility"), ) return # TODO polls not supported on bsky. maybe 3rd party? skip for now # we don't handle reblogs. possible with bridgy(?) and self # we don't handle quotes. if status.get("poll"): LOGGER.info("Skipping '%s'! Contains a poll..", status["id"]) return if status.get("quote_id") or status.get("quote"): LOGGER.info("Skipping '%s'! Quote..", status["id"]) return reblog: dict | None = status.get("reblog") if reblog: if (reblog.get("account") or {})["id"] != self.user_id: LOGGER.info("Skipping '%s'! Reblog of other user..", status["id"]) return success = database.try_insert_repost( self.db, status["id"], reblog["id"], self.user_id, self.service ) if not success: LOGGER.info( "Skipping '%s' as reblogged post was not found in db!", status["id"] ) return for output in outputs: output.accept_repost(status["id"], reblog["id"]) return in_reply: str | None = status.get("in_reply_to_id") in_reply_to: str | None = status.get("in_reply_to_account_id") if in_reply_to and in_reply_to != self.user_id: # We don't support replies. LOGGER.info("Skipping '%s'! Reply to other user..", status["id"]) return success = database.try_insert_post( self.db, status["id"], in_reply, self.user_id, self.service ) if not success: LOGGER.info( "Skipping '%s' as parent post was not found in db!", status["id"] ) return tokens = self.__to_tokens(status) if not cross.test_filters(tokens, self.options.filters): LOGGER.info("Skipping '%s'. Matched a filter!", status["id"]) return LOGGER.info("Crossposting '%s'...", status["id"]) media_attachments: list[MediaInfo] = [] for attachment in status.get("media_attachments", []): LOGGER.info("Downloading %s...", attachment["url"]) info = download_media( attachment["url"], attachment.get("description") or "" ) if not info: LOGGER.error("Skipping '%s'. Failed to download media!", status["id"]) return media_attachments.append(info) cross_post = MastodonPost(status, tokens, media_attachments) for output in outputs: output.accept_post(cross_post) def _on_delete_post(self, outputs: list[cross.Output], identifier: str): post = database.find_post(self.db, identifier, self.user_id, self.service) if not post: return LOGGER.info("Deleting '%s'...", identifier) if post["reposted_id"]: for output in outputs: output.delete_repost(identifier) else: for output in outputs: output.delete_post(identifier) database.delete_post(self.db, identifier, self.user_id, self.service) def _on_post(self, outputs: list[cross.Output], event: str, payload: str): match event: case "update": self._on_create_post(outputs, json.loads(payload)) case "delete": self._on_delete_post(outputs, payload) async def listen( self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any] ): uri = f"{self.streaming}/api/v1/streaming?stream=user&access_token={self.token}" async for ws in websockets.connect( uri, extra_headers={"User-Agent": "XPost/0.0.3"} ): try: LOGGER.info("Listening to %s...", self.streaming) async def listen_for_messages(): async for msg in ws: data = json.loads(msg) event: str = data.get("event") payload: str = data.get("payload") submit(lambda: self._on_post(outputs, str(event), str(payload))) listen = asyncio.create_task(listen_for_messages()) await asyncio.gather(listen) except websockets.ConnectionClosedError as e: LOGGER.error(e, stack_info=True, exc_info=True) LOGGER.info("Reconnecting to %s...", self.streaming) continue