Social media crossposting tool. Third time's the charm.
Topics: mastodon, misskey, bluesky, crossposting
import asyncio
import json
import re
from dataclasses import dataclass, field
from typing import Any, cast, override

import websockets

from cross.attachments import (
    LabelsAttachment,
    LanguagesAttachment,
    MediaAttachment,
    QuoteAttachment,
    RemoteUrlAttachment,
    SensitiveAttachment,
)
from cross.media import Blob, download_blob
from cross.post import Post
from cross.service import InputService
from database.connection import DatabasePool
from mastodon.info import MastodonService, validate_and_transform
from mastodon.parser import StatusParser

# Visibilities a user may opt into crossposting; also the default selection.
ALLOWED_VISIBILITY: list[str] = ["public", "unlisted"]


@dataclass(kw_only=True)
class MastodonInputOptions:
    """Configuration for reading the user's own statuses from a Mastodon instance."""

    # Access token, also sent as a Bearer token to the streaming API.
    token: str
    # Passed to MastodonService; presumably the instance base URL after
    # validate_and_transform has normalized it -- TODO confirm.
    instance: str
    # Only statuses whose visibility is listed here are crossposted.
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    # Regexes compiled from the "filters" config key.
    # NOTE(review): not referenced anywhere in this module; confirm they are
    # applied elsewhere.
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MastodonInputOptions":
        """Build options from a raw config mapping.

        The mapping is shallow-copied first, so the caller's dict is not
        modified by validation or by compiling the filter regexes.

        Raises:
            ValueError: if ``allowed_visibility`` contains a value outside
                ``ALLOWED_VISIBILITY``.
            re.error: if an entry of ``filters`` is not a valid regex.
        """
        data = dict(data)
        validate_and_transform(data)

        for vis in data.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        if "filters" in data:
            data["filters"] = [re.compile(r) for r in data["filters"]]

        return cls(**data)


class MastodonInputService(MastodonService, InputService):
    """Listen to the user's Mastodon stream and forward their own posts.

    Creations, reblogs and deletions by the account returned from
    ``verify_credentials`` are recorded in the database and handed to every
    output through ``self.submitter``.
    """

    def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
        super().__init__(options.instance, db)
        self.options: MastodonInputOptions = options

        self.log.info("Verifying %s credentials...", self.url)
        response = self.verify_credentials()
        self.user_id: str = response["id"]

        self.log.info("Getting %s configuration...", self.url)
        response = self.fetch_instance_info()
        self.streaming_url: str = response["urls"]["streaming_api"]

    @override
    def _get_token(self) -> str:
        return self.options.token

    def _on_create_post(self, status: dict[str, Any]) -> None:
        """Handle an ``update`` event carrying a newly created status.

        The status is skipped unless all of these hold:

        - this user authored it and its visibility is allowed;
        - it has no poll;
        - any quoted status is this user's and is already in the db;
        - any reply parent is this user's and is already in the db.

        Reblogs are delegated to ``_on_reblog``. Accepted posts are recorded
        in the db and submitted to every output.
        """
        if status["account"]["id"] != self.user_id:
            return

        if status["visibility"] not in self.options.allowed_visibility:
            return

        reblog: dict[str, Any] | None = status.get("reblog")
        if reblog:
            if reblog["account"]["id"] != self.user_id:
                return
            self._on_reblog(status, reblog)
            return

        if status.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", status["id"])
            return

        quote: dict[str, Any] | None = status.get("quote")
        if quote:
            # The quoted status may arrive wrapped as {"quoted_status": {...}}
            # or embedded directly.
            quote = quote["quoted_status"] if quote.get("quoted_status") else quote
            # A wrapper with no embedded status has no "account" (presumably a
            # quote that is not accepted/visible); skip instead of KeyError.
            quoted_author = (quote.get("account") or {}).get("id")
            if quoted_author != self.user_id:
                return

            if not self._get_post(self.url, self.user_id, quote["id"]):
                self.log.info(
                    "Skipping %s, quoted post %s not found in db",
                    status["id"],
                    quote["id"],
                )
                return

        in_reply: str | None = status.get("in_reply_to_id")
        in_reply_to: str | None = status.get("in_reply_to_account_id")
        # Replies to other accounts are never crossposted.
        if in_reply_to and in_reply_to != self.user_id:
            return

        parent = None
        if in_reply:
            parent = self._get_post(self.url, self.user_id, in_reply)
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", status["id"], in_reply
                )
                return

        parser = StatusParser()
        parser.feed(status["content"])
        text, fragments = parser.get_result()

        post = Post(id=status["id"], parent_id=in_reply, text=text)
        post.fragments.extend(fragments)

        if quote:
            post.attachments.put(
                QuoteAttachment(quoted_id=quote["id"], quoted_user=self.user_id)
            )
        if status.get("url"):
            post.attachments.put(RemoteUrlAttachment(url=status["url"]))
        if status.get("sensitive"):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if status.get("language"):
            post.attachments.put(LanguagesAttachment(langs=[status["language"]]))
        # Mastodon's Status entity names the content warning "spoiler_text";
        # "spoiler" is kept as a fallback for the previously used key.
        spoiler = status.get("spoiler_text") or status.get("spoiler")
        if spoiler:
            post.attachments.put(LabelsAttachment(labels=[spoiler]))

        blobs = self._download_media(status)
        if blobs is None:
            return
        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": status["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # A parent without a root is itself the root of the thread.
            record["root"] = parent["root"] or parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind `out` now, in case the submitter defers execution; a plain
            # closure would see only the last output.
            self.submitter(lambda out=out: out.accept_post(post))

    def _download_media(self, status: dict[str, Any]) -> list[Blob] | None:
        """Download every media attachment of *status*, in order.

        Returns the blobs (an empty list when there is no media), or None
        after logging an error as soon as one download fails, so the caller
        can skip the whole status.
        """
        blobs: list[Blob] = []
        for media in status.get("media_attachments", []):
            self.log.info("Downloading %s...", media["url"])
            # Mastodon's MediaAttachment stores alt text in "description";
            # "alt" is kept as a fallback for the previously used key.
            alt = media.get("description") or media.get("alt")
            blob: Blob | None = download_blob(media["url"], alt)
            if not blob:
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    status["id"],
                    media["url"],
                )
                return None
            blobs.append(blob)
        return blobs

    def _on_reblog(self, status: dict[str, Any], reblog: dict[str, Any]) -> None:
        """Record a reblog of this user's own post and submit it to outputs.

        Skipped (with an info log) when the reblogged status is not in the db.
        """
        reposted = self._get_post(self.url, self.user_id, reblog["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                status["id"],
                reblog["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": status["id"],
                "reposted": reposted["id"],
            }
        )

        for out in self.outputs:
            # Bind `out` now; see _on_create_post.
            self.submitter(
                lambda out=out: out.accept_repost(status["id"], reblog["id"])
            )

    def _on_delete_post(self, status_id: str) -> None:
        """Propagate deletion of a tracked status, then drop its db record.

        Unknown ids are ignored. Rows with a ``reposted_id`` are treated as
        reposts, all others as posts.
        """
        post = self._get_post(self.url, self.user_id, status_id)
        if not post:
            return

        # Bind `output` now in case the submitter defers execution.
        if post["reposted_id"]:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_repost(status_id))
        else:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_post(status_id))
        # NOTE(review): the row is removed right after the deletions are
        # *submitted*; confirm outputs do not need it if they run later.
        self._delete_post_by_id(post["id"])

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Dispatch one raw streaming message by its ``event`` type.

        Only these events are handled:

        - ``update``: the payload is a JSON-encoded status.
        - ``delete``: the payload is the id of the deleted status.

        Other events, and messages without a payload, are ignored rather than
        raising. Mastodon's ``filters_changed``, for example, has no payload.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        event = data.get("event")
        payload = data.get("payload")
        if payload is None:
            return

        if event == "update":
            self._on_create_post(json.loads(cast(str, payload)))
        elif event == "delete":
            self._on_delete_post(cast(str, payload))

    @override
    async def listen(self):
        """Consume the ``user`` stream, handing each message to the submitter.

        The ``websockets.connect`` iterator yields a new connection after the
        previous one ends. ``ConnectionClosedError`` is logged instead of
        propagated, so the loop reconnects.
        """
        url = f"{self.streaming_url}/api/v1/streaming?stream=user"

        async for ws in websockets.connect(
            url, additional_headers={"Authorization": f"Bearer {self.options.token}"}
        ):
            try:
                self.log.info("Listening to %s...", self.streaming_url)
                async for msg in ws:
                    # Bind `msg` now: the loop rebinds it before a deferred
                    # submitter would run the callable.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(
                    "Connection to %s closed: %s", self.streaming_url, e, exc_info=True
                )
                self.log.info("Reconnecting to %s...", self.streaming_url)