import asyncio
import json
import re
import uuid
from typing import Any, Callable

import requests
import websockets

import cross
import util.database as database
import util.md_util as md_util
from misskey.common import MisskeyPost
from util.media import MediaInfo, download_media
from util.util import LOGGER, as_envvar

ALLOWED_VISIBILITY = ["public", "home"]


class MisskeyInputOptions:
    """Parsed ``options`` section of a Misskey input's settings."""

    def __init__(self, o: dict) -> None:
        """Build options from the raw options dict *o*.

        Recognized keys:
          * ``regex_filters`` -- list of regex strings, compiled into ``self.filters``.
          * ``allowed_visibility`` -- list of visibilities to crosspost; every
            entry must be in ``ALLOWED_VISIBILITY``. Defaults to all of them.

        Raises:
            ValueError: if ``allowed_visibility`` contains an unsupported value.
            re.error: if one of ``regex_filters`` is not a valid pattern.
        """
        # Copy so that mutating this instance's list can never alter the module default.
        self.allowed_visibility: list[str] = list(ALLOWED_VISIBILITY)
        self.filters: list[re.Pattern[str]] = [
            re.compile(f) for f in o.get("regex_filters", [])
        ]

        allowed_visibility = o.get("allowed_visibility")
        if allowed_visibility is not None:
            if any(v not in ALLOWED_VISIBILITY for v in allowed_visibility):
                raise ValueError(
                    f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}"
                )
            self.allowed_visibility = allowed_visibility


class MisskeyInput(cross.Input):
    """Crossposting input that streams the authenticated user's notes from Misskey."""

    def __init__(self, settings: dict, db: cross.DataBaseWorker) -> None:
        """Validate the token against the instance and initialize the input.

        Args:
            settings: input settings. ``token`` and ``instance`` are required
                (both resolved through ``as_envvar``); ``options`` is optional
                and parsed by ``MisskeyInputOptions``.
            db: database worker forwarded to ``cross.Input``.

        Raises:
            ValueError: if ``token`` or ``instance`` is missing or empty.
            requests.HTTPError: if the ``/api/i`` credential check does not
                return HTTP 200.
        """
        self.options = MisskeyInputOptions(settings.get("options", {}))

        self.token = as_envvar(settings.get("token"))
        if not self.token:
            raise ValueError("'token' is required")

        instance: str = as_envvar(settings.get("instance"))
        if not instance:
            raise ValueError("'instance' is required")

        # Drop one trailing slash so every URL built from `service` has a single "/".
        service = instance[:-1] if instance.endswith("/") else instance

        LOGGER.info("Verifying %s credentails...", service)
        response = requests.post(
            f"{service}/api/i",
            json={"i": self.token},
            headers={"Content-Type": "application/json"},
        )
        if response.status_code != 200:
            LOGGER.error("Failed to validate user credentials!")
            response.raise_for_status()
            # raise_for_status() only raises for 4xx/5xx. Any other unexpected
            # status must not fall through and leave the input half-initialized
            # (previously this path returned without calling super().__init__).
            raise requests.HTTPError(
                f"Unexpected status {response.status_code} from {service}/api/i",
                response=response,
            )

        super().__init__(service, response.json()["id"], settings, db)

    def _on_note(self, outputs: list[cross.Output], note: dict):
        """Crosspost a single note to every output, if it qualifies.

        Only notes by the authenticated user with an allowed visibility are
        handled. Polls and quotes are skipped. A renote of the user's own note
        is forwarded as a repost once it is recorded in the database. A reply is
        only crossposted when it answers the user's own note and the parent is
        already in the database. Notes matching a regex filter, or whose media
        fails to download, are skipped.
        """
        if note["userId"] != self.user_id:
            return

        if note.get("visibility") not in self.options.allowed_visibility:
            LOGGER.info(
                "Skipping '%s'! '%s' visibility..", note["id"], note.get("visibility")
            )
            return

        # TODO polls not supported on bsky. maybe 3rd party? skip for now
        # we don't handle reblogs. possible with bridgy(?) and self
        if note.get("poll"):
            LOGGER.info("Skipping '%s'! Contains a poll..", note["id"])
            return

        renote: dict | None = note.get("renote")
        if renote:
            # A renote that carries its own text or attachments is a quote, not a
            # plain repost; treating it as a repost would silently drop the files.
            if note.get("text") is not None or note.get("files"):
                LOGGER.info("Skipping '%s'! Quote..", note["id"])
                return

            if renote.get("userId") != self.user_id:
                LOGGER.info("Skipping '%s'! Reblog of other user..", note["id"])
                return

            success = database.try_insert_repost(
                self.db, note["id"], renote["id"], self.user_id, self.service
            )
            if not success:
                LOGGER.info(
                    "Skipping '%s' as renoted note was not found in db!", note["id"]
                )
                return

            for output in outputs:
                output.accept_repost(note["id"], renote["id"])
            return

        reply_id: str | None = note.get("replyId")
        if reply_id:
            # `reply` may be present but null, so fall back to {} in both cases.
            reply: dict = note.get("reply") or {}
            if reply.get("userId") != self.user_id:
                LOGGER.info("Skipping '%s'! Reply to other user..", note["id"])
                return

        success = database.try_insert_post(
            self.db, note["id"], reply_id, self.user_id, self.service
        )
        if not success:
            LOGGER.info("Skipping '%s' as parent note was not found in db!", note["id"])
            return

        mention_handles: dict = note.get("mentionHandles") or {}
        tags: list[str] = note.get("tags") or []
        # Only the handle values are used; each one fills both slots of the pair.
        handles: list[tuple[str, str]] = [(h, h) for h in mention_handles.values()]

        # `text` can be present but None (the quote check above relies on that),
        # so normalize it to "" rather than relying on the .get() default.
        tokens = md_util.tokenize_markdown(note.get("text") or "", tags, handles)
        if not cross.test_filters(tokens, self.options.filters):
            LOGGER.info("Skipping '%s'. Matched a filter!", note["id"])
            return

        LOGGER.info("Crossposting '%s'...", note["id"])

        media_attachments: list[MediaInfo] = []
        for attachment in note.get("files") or []:
            LOGGER.info("Downloading %s...", attachment["url"])
            info = download_media(attachment["url"], attachment.get("comment") or "")
            if not info:
                LOGGER.error("Skipping '%s'. Failed to download media!", note["id"])
                return
            media_attachments.append(info)

        cross_post = MisskeyPost(self.service, note, tokens, media_attachments)
        for output in outputs:
            output.accept_post(cross_post)

    def _on_delete(self, outputs: list[cross.Output], note: dict):
        """Handle a deleted note (not implemented yet)."""
        # TODO handle deletes
        pass

    def _on_message(self, outputs: list[cross.Output], data: dict):
        """Dispatch one decoded streaming message.

        Only ``channel`` messages whose inner event type is ``note`` or
        ``reply`` are handled; everything else is ignored.
        """
        if data["type"] != "channel":
            return

        body = data["body"]
        # Named event_type to avoid shadowing the builtin `type`.
        event_type: str = body["type"]
        if event_type in ("note", "reply"):
            self._on_note(outputs, body["body"])

    async def _send_keepalive(self, ws: websockets.WebSocketClientProtocol):
        """Send an ``h`` frame every 120 seconds while the socket stays open.

        Stops once the socket is closed or a send fails.
        """
        while ws.open:
            try:
                await asyncio.sleep(120)
                if ws.open:
                    await ws.send("h")
                    LOGGER.debug("Sent keepalive h..")
                else:
                    LOGGER.info("WebSocket is closed, stopping keepalive task.")
                    break
            except Exception as e:
                LOGGER.error("Error sending keepalive: %s", e)
                break

    async def _subscribe_to_home(self, ws: websockets.WebSocketClientProtocol):
        """Connect *ws* to the ``homeTimeline`` channel under a fresh random id."""
        await ws.send(
            json.dumps(
                {
                    "type": "connect",
                    "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
                }
            )
        )
        LOGGER.info("Subscribed to 'homeTimeline' channel...")

    async def listen(
        self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
    ):
        """Stream the home timeline, reconnecting when the connection closes.

        Each received message is passed to *submit* as a zero-argument callable
        that decodes the JSON and runs ``_on_message``.
        """
        streaming: str = f"wss://{self.service.split('://', 1)[1]}"
        url: str = f"{streaming}/streaming?i={self.token}"

        async for ws in websockets.connect(
            url, extra_headers={"User-Agent": "XPost/0.0.3"}
        ):
            keepalive: asyncio.Task | None = None
            try:
                LOGGER.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                async def listen_for_messages():
                    async for msg in ws:
                        # TODO listen to deletes somehow
                        # Bind `msg` as a default argument: `submit` may run the
                        # callable after this loop has already rebound `msg`.
                        submit(
                            lambda msg=msg: self._on_message(outputs, json.loads(msg))
                        )

                keepalive = asyncio.create_task(self._send_keepalive(ws))
                listener = asyncio.create_task(listen_for_messages())
                await asyncio.gather(keepalive, listener)
            except websockets.ConnectionClosedError as e:
                LOGGER.error(e, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", streaming)
                continue
            finally:
                # gather() does not cancel the sibling task when one of them fails,
                # so stop the keepalive explicitly instead of leaving it sleeping.
                if keepalive is not None and not keepalive.done():
                    keepalive.cancel()