Social media crossposting tool. Third time's the charm.
mastodon misskey crossposting bluesky
import asyncio
import json
import re
import uuid
from dataclasses import dataclass, field
from typing import Any, cast, override

import websockets

from cross.attachments import (
    LabelsAttachment,
    MediaAttachment,
    QuoteAttachment,
    RemoteUrlAttachment,
    SensitiveAttachment,
)
from cross.media import Blob, download_blob
from cross.post import Post
from cross.service import InputService
from database.connection import DatabasePool
from misskey.info import MisskeyService
from util.markdown import MarkdownParser
from util.util import normalize_service_url

# Visibility values accepted in `allowed_visibility`; also the default set
# used when the option is omitted.
ALLOWED_VISIBILITY = ["public", "home"]


@dataclass
class MisskeyInputOptions:
    """Configuration for `MisskeyInputService`.

    Attributes:
        token: Access token. It is returned by `_get_token` and sent as the
            `i` query parameter of the streaming URL.
        instance: Base URL of the Misskey instance (normalized by `from_dict`).
        allowed_visibility: Note visibilities that will be crossposted; every
            entry must appear in `ALLOWED_VISIBILITY`.
        filters: Compiled regular expressions from the configuration.
            NOTE(review): not applied anywhere in this module -- confirm whether
            a base class consumes them.
    """

    token: str
    instance: str
    allowed_visibility: list[str] = field(
        default_factory=lambda: list(ALLOWED_VISIBILITY)
    )
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MisskeyInputOptions":
        """Build options from a raw configuration mapping.

        `data` itself is left untouched; a shallow copy is normalized instead,
        so callers can safely reuse their config dict.

        Raises:
            KeyError: if `instance` is missing.
            ValueError: if `allowed_visibility` contains an unsupported value.
            re.error: if a filter is not a valid regular expression.
        """
        options = dict(data)
        options["instance"] = normalize_service_url(options["instance"])

        for vis in options.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        if "filters" in options:
            options["filters"] = [re.compile(r) for r in options["filters"]]

        return cls(**options)


class MisskeyInputService(MisskeyService, InputService):
    """Stream the authenticated user's notes from a Misskey instance and
    forward eligible ones to every configured output."""

    def __init__(self, db: DatabasePool, options: MisskeyInputOptions) -> None:
        super().__init__(options.instance, db)
        self.options: MisskeyInputOptions = options

        self.log.info("Verifying %s credentials...", self.url)
        response = self.verify_credentials()
        # Only notes authored by this account are crossposted.
        self.user_id: str = response["id"]

    @override
    def _get_token(self) -> str:
        return self.options.token

    def _on_note(self, note: dict[str, Any]) -> None:
        """Crosspost one of the user's own notes.

        The note is skipped (nothing recorded, nothing submitted) when:
        - it is authored by someone else,
        - its visibility is not in `allowed_visibility`,
        - it has a poll,
        - it quotes another user's note,
        - it replies to another user, or its parent is not in the db,
        - any attached file fails to download.

        A renote without text of its own is delegated to `_on_renote`.
        """
        if note["userId"] != self.user_id:
            return

        if note["visibility"] not in self.options.allowed_visibility:
            return

        if note.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", note["id"])
            return

        renote: dict[str, Any] | None = note.get("renote")
        if renote:
            if note.get("text") is None:
                # No text of its own: a plain renote rather than a quote.
                self._on_renote(note, renote)
                return
            if renote["userId"] != self.user_id:
                # Only quotes of the user's own notes are crossposted.
                return

        reply: dict[str, Any] | None = note.get("reply")
        parent = None
        if reply:
            if reply.get("userId") != self.user_id:
                self.log.info("Skipping '%s'! Reply to other user..", note["id"])
                return
            parent = self._get_post(self.url, self.user_id, reply["id"])
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", note["id"], reply["id"]
                )
                return

        # `text`, `files` and a file's `comment` may be present but null (the
        # renote check above relies on a null `text`), so fall back explicitly
        # instead of relying on `.get(key, default)`.
        files: list[dict[str, Any]] = note.get("files") or []

        parser = MarkdownParser()  # TODO MFM parser
        text, fragments = parser.parse(note.get("text") or "")
        post = Post(id=note["id"], parent_id=reply["id"] if reply else None, text=text)
        post.fragments.extend(fragments)

        post.attachments.put(RemoteUrlAttachment(url=self.url + "/notes/" + note["id"]))
        if renote:
            post.attachments.put(
                QuoteAttachment(quoted_id=renote["id"], quoted_user=self.user_id)
            )
        if any(f.get("isSensitive", False) for f in files):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if note.get("cw"):
            post.attachments.put(LabelsAttachment(labels=[note["cw"]]))

        blobs: list[Blob] = []
        for media in files:
            self.log.info("Downloading %s...", media["url"])
            blob: Blob | None = download_blob(media["url"], media.get("comment") or "")
            if not blob:
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    note["id"],
                    media["url"],
                )
                return
            blobs.append(blob)

        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": note["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # Inherit the parent's thread root; a parent without one is the root.
            record["root"] = parent["root"] or parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind `out` now: a submitter that defers the call would otherwise
            # see only the last output of the loop.
            self.submitter(lambda out=out: out.accept_post(post))

    def _on_renote(self, note: dict[str, Any], renote: dict[str, Any]) -> None:
        """Forward a plain renote, provided the renoted note is in the db."""
        reposted = self._get_post(self.url, self.user_id, renote["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                note["id"],
                renote["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": note["id"],
                "reposted": reposted["id"],
            }
        )

        note_id: str = note["id"]
        renote_id: str = renote["id"]
        for out in self.outputs:
            # Bind `out` now in case the submitter defers the call.
            self.submitter(lambda out=out: out.accept_repost(note_id, renote_id))

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode one streaming message and dispatch note events to `_on_note`."""
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        if data["type"] != "channel":
            return

        event_type: str = cast(str, data["body"]["type"])
        if event_type in ("note", "reply"):
            self._on_note(data["body"]["body"])

    async def _subscribe_to_home(self, ws: websockets.ClientConnection) -> None:
        """Ask the server to stream the `homeTimeline` channel on `ws`."""
        await ws.send(
            json.dumps(
                {
                    "type": "connect",
                    "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
                }
            )
        )
        self.log.info("Subscribed to 'homeTimeline' channel...")

    @override
    async def listen(self) -> None:
        """Stream the home timeline, reconnecting after abnormal closures."""
        scheme = "wss" if self.url.startswith("https") else "ws"
        streaming: str = f"{scheme}://{self.url.split('://', 1)[1]}"
        # The token travels in the query string, so only `streaming` is logged.
        url: str = f"{streaming}/streaming?i={self.options.token}"

        # Each iteration of the `connect` iterator opens a fresh connection.
        async for ws in websockets.connect(url):
            try:
                self.log.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                async for msg in ws:
                    # Bind `msg` now: a deferred callback would otherwise
                    # process whichever message arrived last.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", streaming)
                continue