social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky
at master 7.5 kB view raw
import asyncio
import json
import re
import uuid
from typing import Any, Callable

import requests
import websockets

import cross
import util.database as database
import util.md_util as md_util
from misskey.common import MisskeyPost
from util.media import MediaInfo, download_media
from util.util import LOGGER, as_envvar

# Note visibilities that may be configured/crossposted; notes whose visibility
# is not in the configured list are skipped by MisskeyInput._on_note.
ALLOWED_VISIBILITY = ["public", "home"]

# Seconds to wait for the credential-verification request before giving up.
_VERIFY_TIMEOUT = 30


class MisskeyInputOptions:
    """Parsed ``options`` section of a Misskey input's settings.

    Args:
        o: the ``options`` dict from the input settings.

    Raises:
        ValueError: if ``allowed_visibility`` contains a value not in
            ``ALLOWED_VISIBILITY``.
        re.error: if one of ``regex_filters`` is not a valid pattern.
    """

    def __init__(self, o: dict) -> None:
        # Copy so this instance never shares (and can never mutate) the module constant.
        self.allowed_visibility: list[str] = list(ALLOWED_VISIBILITY)
        # Compiled filters; handed to cross.test_filters, notes that fail it are skipped.
        self.filters: list[re.Pattern[str]] = [
            re.compile(f) for f in (o.get("regex_filters") or [])
        ]

        allowed_visibility = o.get("allowed_visibility")
        if allowed_visibility is not None:
            if any(v not in ALLOWED_VISIBILITY for v in allowed_visibility):
                raise ValueError(
                    f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}"
                )
            self.allowed_visibility = allowed_visibility


class MisskeyInput(cross.Input):
    """Crossposting input that streams the user's Misskey home timeline.

    On construction the token is verified against ``<instance>/api/i``; the
    returned account id is passed to ``cross.Input.__init__`` (presumably
    stored as ``self.user_id`` — the methods below read that attribute).
    """

    def __init__(self, settings: dict, db: cross.DataBaseWorker) -> None:
        """Validate settings and credentials.

        Args:
            settings: input settings; ``token`` and ``instance`` are required
                (resolved through ``as_envvar``), ``options`` is optional.
            db: database worker passed through to ``cross.Input``.

        Raises:
            ValueError: if ``token`` or ``instance`` is missing/empty, or the
                options are invalid.
            requests.HTTPError: if the instance rejects the credentials or
                answers with any status other than 200.
        """
        self.options = MisskeyInputOptions(settings.get("options", {}))

        token = as_envvar(settings.get("token"))
        if not token:
            raise ValueError("'token' is required")
        self.token = token

        instance: str = as_envvar(settings.get("instance"))
        if not instance:
            raise ValueError("'instance' is required")

        # Drop one trailing slash so URLs built from `service` never contain "//".
        service = instance.removesuffix("/")

        LOGGER.info("Verifying %s credentials...", service)
        response = requests.post(
            f"{service}/api/i",
            json={"i": self.token},
            headers={"Content-Type": "application/json"},
            timeout=_VERIFY_TIMEOUT,
        )
        if response.status_code != 200:
            LOGGER.error("Failed to validate user credentials!")
            response.raise_for_status()
            # raise_for_status() only raises for 4xx/5xx; any other unexpected
            # status must not leave behind a half-initialised input either.
            raise requests.HTTPError(
                f"Unexpected status {response.status_code} from {service}/api/i",
                response=response,
            )

        super().__init__(service, response.json()["id"], settings, db)

    def _on_note(self, outputs: list[cross.Output], note: dict):
        """Crosspost one of the user's own notes to every output, if eligible.

        Notes are skipped (and logged) when they belong to another user, have
        a visibility outside ``options.allowed_visibility``, contain a poll,
        are quotes or renotes of other users, reply to another user, refer to
        a parent/renoted note the database does not know, match a regex
        filter, or have media that fails to download. A pure renote of the
        user's own note is forwarded via ``Output.accept_repost``; everything
        else is wrapped in a ``MisskeyPost`` and sent via ``accept_post``.
        """
        if note["userId"] != self.user_id:
            return

        visibility = note.get("visibility")
        if visibility not in self.options.allowed_visibility:
            LOGGER.info("Skipping '%s'! '%s' visibility..", note["id"], visibility)
            return

        # TODO polls not supported on bsky. maybe 3rd party? skip for now
        # we don't handle reblogs. possible with bridgy(?) and self
        if note.get("poll"):
            LOGGER.info("Skipping '%s'! Contains a poll..", note["id"])
            return

        renote: dict | None = note.get("renote")
        if renote:
            # A renote carrying its own text is a quote.
            if note.get("text") is not None:
                LOGGER.info("Skipping '%s'! Quote..", note["id"])
                return

            if renote.get("userId") != self.user_id:
                LOGGER.info("Skipping '%s'! Reblog of other user..", note["id"])
                return

            success = database.try_insert_repost(
                self.db, note["id"], renote["id"], self.user_id, self.service
            )
            if not success:
                LOGGER.info(
                    "Skipping '%s' as renoted note was not found in db!", note["id"]
                )
                return

            for output in outputs:
                output.accept_repost(note["id"], renote["id"])
            return

        reply_id: str | None = note.get("replyId")
        if reply_id:
            # "reply" may be absent or null; treat both as "not the user's note".
            reply: dict = note.get("reply") or {}
            if reply.get("userId") != self.user_id:
                LOGGER.info("Skipping '%s'! Reply to other user..", note["id"])
                return

            success = database.try_insert_post(
                self.db, note["id"], reply_id, self.user_id, self.service
            )
            if not success:
                LOGGER.info("Skipping '%s' as parent note was not found in db!", note["id"])
                return

        mention_handles: dict = note.get("mentionHandles") or {}
        tags: list[str] = note.get("tags") or []
        # tokenize_markdown takes (text, link) pairs; each handle is used for both.
        handles: list[tuple[str, str]] = [
            (handle, handle) for handle in mention_handles.values()
        ]

        # "text" can be present but null (the quote check above relies on that).
        text: str = note.get("text") or ""
        tokens = md_util.tokenize_markdown(text, tags, handles)
        if not cross.test_filters(tokens, self.options.filters):
            LOGGER.info("Skipping '%s'. Matched a filter!", note["id"])
            return

        LOGGER.info("Crossposting '%s'...", note["id"])

        media_attachments: list[MediaInfo] = []
        for attachment in note.get("files") or []:
            LOGGER.info("Downloading %s...", attachment["url"])
            info = download_media(attachment["url"], attachment.get("comment") or "")
            if not info:
                LOGGER.error("Skipping '%s'. Failed to download media!", note["id"])
                return
            media_attachments.append(info)

        cross_post = MisskeyPost(self.service, note, tokens, media_attachments)
        for output in outputs:
            output.accept_post(cross_post)

    def _on_delete(self, outputs: list[cross.Output], note: dict):
        """Placeholder for note deletions; currently does nothing."""
        # TODO handle deletes
        pass

    def _on_message(self, outputs: list[cross.Output], data: dict):
        """Handle one decoded streaming message.

        Only ``channel`` messages whose inner type is ``note`` or ``reply``
        are acted on; everything else is ignored.
        """
        if data.get("type") != "channel":
            return

        body: dict = data["body"]
        if body["type"] in ("note", "reply"):
            self._on_note(outputs, body["body"])

    async def _send_keepalive(self, ws: websockets.WebSocketClientProtocol):
        """Send an ``"h"`` frame every 120 seconds while *ws* stays open.

        Stops on the first send error. Cancellation is not caught here
        (``asyncio.CancelledError`` is not an ``Exception`` subclass), so
        ``listen`` can cancel this task.
        """
        while ws.open:
            try:
                await asyncio.sleep(120)
                if not ws.open:
                    LOGGER.info("WebSocket is closed, stopping keepalive task.")
                    break
                await ws.send("h")
                LOGGER.debug("Sent keepalive h..")
            except Exception as e:
                LOGGER.error("Error sending keepalive: %s", e)
                break

    async def _subscribe_to_home(self, ws: websockets.WebSocketClientProtocol):
        """Connect *ws* to the ``homeTimeline`` channel under a random channel id."""
        await ws.send(
            json.dumps(
                {
                    "type": "connect",
                    "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
                }
            )
        )
        LOGGER.info("Subscribed to 'homeTimeline' channel...")

    async def listen(
        self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
    ):
        """Stream the home timeline, reconnecting whenever the socket drops.

        Each raw message is handed to *submit* as a zero-argument callable
        that JSON-decodes it and runs ``_on_message``.

        NOTE(review): ``extra_headers``, ``WebSocketClientProtocol`` and
        ``ws.open`` belong to the legacy ``websockets`` API — confirm the
        pinned version still provides them.
        """
        streaming: str = f"wss://{self.service.split('://', 1)[1]}"
        url: str = f"{streaming}/streaming?i={self.token}"

        async for ws in websockets.connect(
            url, extra_headers={"User-Agent": "XPost/0.0.3"}
        ):
            keepalive: asyncio.Task | None = None
            try:
                LOGGER.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                keepalive = asyncio.create_task(self._send_keepalive(ws))
                async for msg in ws:
                    # TODO listen to deletes somehow
                    # Bind `msg` now: *submit* may run the callable after this
                    # loop has already advanced to a later message.
                    submit(lambda msg=msg: self._on_message(outputs, json.loads(msg)))
            except websockets.ConnectionClosedError as e:
                LOGGER.error(e, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", streaming)
            finally:
                # Never leave a keepalive task running for a dead connection,
                # and don't wait up to 120 s for it before reconnecting.
                if keepalive is not None:
                    keepalive.cancel()