Social media crossposting tool. Third time's the charm.
mastodon misskey crossposting bluesky
at next 8.0 kB view raw
import asyncio
import json
import re
from dataclasses import dataclass, field
from typing import Any, cast, override

import websockets

from cross.attachments import (
    LabelsAttachment,
    LanguagesAttachment,
    MediaAttachment,
    QuoteAttachment,
    RemoteUrlAttachment,
    SensitiveAttachment,
)
from cross.media import Blob, download_blob
from cross.post import Post
from cross.service import InputService
from database.connection import DatabasePool
from mastodon.info import MastodonService, validate_and_transform
from mastodon.parser import StatusParser

# Visibilities a status may have to be crossposted; also the default allow-list.
ALLOWED_VISIBILITY: list[str] = ["public", "unlisted"]


@dataclass(kw_only=True)
class MastodonInputOptions:
    """Configuration for a Mastodon input service."""

    token: str  # access token, sent as a Bearer credential
    instance: str  # instance identifier handed to MastodonService
    # Subset of ALLOWED_VISIBILITY; copied so instances never share the list.
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    # Compiled from the raw pattern strings in the config by from_dict.
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MastodonInputOptions":
        """Build options from a raw configuration mapping.

        The mapping is shallow-copied first, so the caller's dict is not
        modified by validation or regex compilation.

        Raises:
            ValueError: if ``allowed_visibility`` contains a value that is not
                in ``ALLOWED_VISIBILITY``.
        """
        data = dict(data)
        validate_and_transform(data)

        for vis in data.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        if "filters" in data:
            data["filters"] = [re.compile(r) for r in data["filters"]]

        return cls(**data)


class MastodonInputService(MastodonService, InputService):
    """Stream the authenticated user's own statuses and hand them to outputs."""

    def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
        super().__init__(options.instance, db)
        self.options: MastodonInputOptions = options

        self.log.info("Verifying %s credentials...", self.url)
        response = self.verify_credentials()
        self.user_id: str = response["id"]

        self.log.info("Getting %s configuration...", self.url)
        response = self.fetch_instance_info()
        self.streaming_url: str = response["urls"]["streaming_api"]

    @override
    def _get_token(self) -> str:
        return self.options.token

    def _on_create_post(self, status: dict[str, Any]) -> None:
        """Crosspost a newly created status if it passes every filter.

        The status is skipped when any of these holds:
        - it belongs to another account or has a disallowed visibility;
        - it contains a poll;
        - it quotes or replies to another account;
        - its quoted or parent post is not tracked in the database;
        - any media attachment fails to download.
        Boosts are delegated to ``_on_reblog``.
        """
        if status["account"]["id"] != self.user_id:
            return

        if status["visibility"] not in self.options.allowed_visibility:
            return

        reblog: dict[str, Any] | None = status.get("reblog")
        if reblog:
            if reblog["account"]["id"] != self.user_id:
                return
            self._on_reblog(status, reblog)
            return

        if status.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", status["id"])
            return

        quote: dict[str, Any] | None = status.get("quote")
        if quote:
            # `quote` arrives in two shapes: a wrapper whose `quoted_status`
            # may be empty (Mastodon's Quote entity), or the quoted status
            # itself. An empty wrapper means there is no usable quote.
            if "quoted_status" in quote:
                quote = quote["quoted_status"]
            if not quote or (quote.get("account") or {}).get("id") != self.user_id:
                return

            rquote = self._get_post(self.url, self.user_id, quote["id"])
            if not rquote:
                self.log.info(
                    "Skipping %s, parent %s not found in db", status["id"], quote["id"]
                )
                return

        in_reply: str | None = status.get("in_reply_to_id")
        in_reply_to: str | None = status.get("in_reply_to_account_id")
        if in_reply_to and in_reply_to != self.user_id:
            return

        parent = None
        if in_reply:
            parent = self._get_post(self.url, self.user_id, in_reply)
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", status["id"], in_reply
                )
                return

        parser = StatusParser(status)
        parser.feed(status["content"])
        tokens = parser.get_result()

        post = Post(id=status["id"], parent_id=in_reply, tokens=tokens)

        if quote:
            post.attachments.put(
                QuoteAttachment(quoted_id=quote["id"], quoted_user=self.user_id)
            )
        if status.get("url"):
            post.attachments.put(RemoteUrlAttachment(url=status["url"]))
        if status.get("sensitive"):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if status.get("language"):
            post.attachments.put(LanguagesAttachment(langs=[status["language"]]))
        # Mastodon's content-warning field is `spoiler_text`; the old `spoiler`
        # lookup is kept as a fallback.
        spoiler = status.get("spoiler_text") or status.get("spoiler")
        if spoiler:
            post.attachments.put(LabelsAttachment(labels=[spoiler]))

        blobs = self._download_media(status)
        if blobs is None:
            return
        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": status["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # A parent without a root is itself the thread root.
            record["root"] = parent["root"] or parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind `out` now: if the submitter defers the call, a plain
            # closure would see only the last output of the loop.
            self.submitter(lambda out=out: out.accept_post(post))

    def _download_media(self, status: dict[str, Any]) -> list[Blob] | None:
        """Download every media attachment of ``status`` in order.

        Returns the downloaded blobs (empty if there are none). Returns None,
        after logging, as soon as one download fails; the caller skips the
        whole status in that case.
        """
        blobs: list[Blob] = []
        for media in status.get("media_attachments", []):
            self.log.info("Downloading %s...", media["url"])
            # Mastodon stores alt text in `description`; `alt` is a fallback.
            alt = media.get("description") or media.get("alt")
            blob: Blob | None = download_blob(media["url"], alt)
            if not blob:
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    status["id"],
                    media["url"],
                )
                return None
            blobs.append(blob)
        return blobs

    def _on_reblog(self, status: dict[str, Any], reblog: dict[str, Any]) -> None:
        """Crosspost a boost of one of the user's own, already-tracked posts."""
        reposted = self._get_post(self.url, self.user_id, reblog["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                status["id"],
                reblog["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": status["id"],
                "reposted": reposted["id"],
            }
        )

        for out in self.outputs:
            # Bind `out` per iteration (see _on_create_post).
            self.submitter(
                lambda out=out: out.accept_repost(status["id"], reblog["id"])
            )

    def _on_delete_post(self, status_id: str) -> None:
        """Propagate deletion of a tracked post or repost, then forget it."""
        post = self._get_post(self.url, self.user_id, status_id)
        if not post:
            return

        # NOTE(review): rows are inserted with a "reposted" key but read back
        # as "reposted_id"; presumably the DB layer maps them — confirm.
        if post["reposted_id"]:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_repost(status_id))
        else:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_post(status_id))
        self._delete_post_by_id(post["id"])

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Dispatch one streaming-API message to its handler.

        Only ``update`` (payload is a JSON-encoded status) and ``delete``
        (payload is a status id) are handled; all other events are ignored.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        event: str | None = data.get("event")
        # Not every stream event carries a payload (e.g. Mastodon's
        # `filters_changed`), so don't index it unconditionally.
        payload: str | None = data.get("payload")
        if payload is None:
            return

        if event == "update":
            self._on_create_post(json.loads(payload))
        elif event == "delete":
            self._on_delete_post(payload)

    @override
    async def listen(self) -> None:
        """Consume the user stream forever.

        ``websockets.connect`` used as an async iterator yields a fresh
        connection each time the loop body finishes, so a dropped stream
        leads to a reconnect.
        """
        url = f"{self.streaming_url}/api/v1/streaming?stream=user"

        async for ws in websockets.connect(
            url, additional_headers={"Authorization": f"Bearer {self.options.token}"}
        ):
            try:
                self.log.info("Listening to %s...", self.streaming_url)
                async for msg in ws:
                    # Bind `msg` now: the submitter may run the callable after
                    # this loop has already received the next message.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", self.streaming_url)
                continue