social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky
import asyncio
import json
import re
from dataclasses import dataclass, field
from typing import Any, cast, override

import websockets

from cross.attachments import (
    LabelsAttachment,
    LanguagesAttachment,
    MediaAttachment,
    RemoteUrlAttachment,
    SensitiveAttachment,
)
from cross.media import Blob, download_blob
from cross.post import Post
from cross.service import InputService
from database.connection import DatabasePool
from mastodon.info import MastodonService, validate_and_transform
from mastodon.parser import StatusParser

# Status visibilities that may be cross-posted; also the default option value.
ALLOWED_VISIBILITY: list[str] = ["public", "unlisted"]


@dataclass(kw_only=True)
class MastodonInputOptions:
    """Configuration for a Mastodon input service."""

    # Access token sent as a Bearer token when opening the streaming socket.
    token: str
    # Instance identifier handed to the MastodonService base class.
    instance: str
    # Visibilities accepted by the service; statuses with any other are ignored.
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    # Compiled filter patterns.
    # NOTE(review): nothing in this module consults `filters` -- confirm where
    # (or whether) they are applied.
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MastodonInputOptions":
        """Build options from a raw configuration mapping.

        ``data`` is passed through ``validate_and_transform`` and its
        ``"filters"`` entry (if any) is replaced in place by compiled patterns.

        Raises:
            ValueError: if ``allowed_visibility`` contains a value that is not
                listed in ``ALLOWED_VISIBILITY``.
            re.error: if a filter is not a valid regular expression.
        """
        validate_and_transform(data)

        if "allowed_visibility" in data:
            for vis in data["allowed_visibility"]:
                if vis not in ALLOWED_VISIBILITY:
                    raise ValueError(f"Invalid visibility option {vis}!")

        if "filters" in data:
            data["filters"] = [re.compile(r) for r in data["filters"]]

        # Use `cls` so subclasses get an instance of their own type.
        return cls(**data)


class MastodonInputService(MastodonService, InputService):
    """Input service that follows the user's streaming feed of one instance.

    Statuses authored by the authenticated account are turned into ``Post``
    objects (or reposts / deletions) and handed to every registered output.
    """

    def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
        """Verify the credentials and look up the instance's streaming URL."""
        super().__init__(options.instance, db)
        self.options: MastodonInputOptions = options

        self.log.info("Verifying %s credentails...", self.url)
        response = self.verify_credentials()
        self.user_id: str = response["id"]

        self.log.info("Getting %s configuration...", self.url)
        response = self.fetch_instance_info()
        self.streaming_url: str = response["urls"]["streaming_api"]

    @override
    def _get_token(self) -> str:
        """Return the access token from the configured options."""
        return self.options.token

    def _on_create_post(self, status: dict[str, Any]) -> None:
        """Handle an ``update`` event carrying a new status.

        The status is ignored unless it was written by the configured account
        with an allowed visibility. Reblogs are delegated to ``_on_reblog``;
        polls, quotes, replies to other accounts and replies whose parent is
        not in the database are skipped. Otherwise the post is recorded and
        submitted to every output.
        """
        if status["account"]["id"] != self.user_id:
            return

        if status["visibility"] not in self.options.allowed_visibility:
            return

        reblog: dict[str, Any] | None = status.get("reblog")
        if reblog:
            # Only reblogs of the user's own statuses are mirrored.
            if reblog["account"]["id"] != self.user_id:
                return
            self._on_reblog(status, reblog)
            return

        if status.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", status["id"])
            return

        if status.get("quote"):
            self.log.info("Skipping '%s'! Quote..", status["id"])
            return

        in_reply: str | None = status.get("in_reply_to_id")
        in_reply_to: str | None = status.get("in_reply_to_account_id")
        if in_reply_to and in_reply_to != self.user_id:
            return

        parent = None
        if in_reply:
            parent = self._get_post(self.url, self.user_id, in_reply)
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", status["id"], in_reply
                )
                return

        parser = StatusParser()
        parser.feed(status["content"])
        text, fragments = parser.get_result()

        post = Post(id=status["id"], parent_id=in_reply, text=text)
        post.fragments.extend(fragments)

        if status.get("url"):
            post.attachments.put(RemoteUrlAttachment(url=status["url"]))
        if status.get("sensitive"):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if status.get("language"):
            post.attachments.put(LanguagesAttachment(langs=[status["language"]]))

        # Mastodon's Status entity names the content-warning field
        # `spoiler_text`; the previously used `spoiler` key is kept as a
        # fallback so any payload that does carry it behaves as before.
        spoiler = status.get("spoiler_text") or status.get("spoiler")
        if spoiler:
            post.attachments.put(LabelsAttachment(labels=[spoiler]))

        blobs: list[Blob] = []
        for media in status.get("media_attachments", []):
            self.log.info("Downloading %s...", media["url"])
            blob: Blob | None = download_blob(media["url"], media.get("alt"))
            if not blob:
                # A post with missing media is dropped entirely rather than
                # cross-posted incomplete.
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    status["id"],
                    media["url"],
                )
                return
            blobs.append(blob)

        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": status["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # A parent without a root is itself the root of the thread.
            record["root"] = parent["root"] or parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind `out` now: a plain closure would see only the last output
            # if the submitter runs the callable after the loop has advanced.
            self.submitter(lambda out=out: out.accept_post(post))

    def _on_reblog(self, status: dict[str, Any], reblog: dict[str, Any]) -> None:
        """Record a reblog of a known post and forward it to every output."""
        reposted = self._get_post(self.url, self.user_id, reblog["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                status["id"],
                reblog["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": status["id"],
                "reposted": reposted["id"],
            }
        )

        for out in self.outputs:
            # Default argument pins the current output (see _on_create_post).
            self.submitter(
                lambda out=out: out.accept_repost(status["id"], reposted["id"])
            )

    def _on_delete_post(self, status_id: str) -> None:
        """Propagate the deletion of a stored post or repost, then drop its row."""
        post = self._get_post(self.url, self.user_id, status_id)
        if not post:
            return

        if post["reposted_id"]:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_repost(status_id))
        else:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_post(status_id))

        # NOTE(review): placed after the if/else so both posts and reposts are
        # removed from the database -- confirm against the intended nesting.
        self._delete_post_by_id(post["id"])

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode one streaming message and dispatch it by event type.

        ``update`` payloads are JSON-encoded statuses; ``delete`` payloads are
        passed on as the status id. Other events are ignored.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        event: str = cast(str, data["event"])
        payload: str = cast(str, data["payload"])

        if event == "update":
            self._on_create_post(json.loads(payload))
        elif event == "delete":
            self._on_delete_post(payload)

    @override
    async def listen(self) -> None:
        """Consume the user stream, reconnecting when the connection drops."""
        url = f"{self.streaming_url}/api/v1/streaming?stream=user"

        # Iterating over `connect(...)` yields a fresh connection each time
        # the loop body finishes or `continue`s.
        async for ws in websockets.connect(
            url, additional_headers={"Authorization": f"Bearer {self.options.token}"}
        ):
            try:
                self.log.info("Listening to %s...", self.streaming_url)

                async def listen_for_messages():
                    async for msg in ws:
                        # Bind `msg` now; otherwise a deferred callable would
                        # read whichever message arrived most recently.
                        self.submitter(lambda msg=msg: self._accept_msg(msg))

                listener = asyncio.create_task(listen_for_messages())

                _ = await asyncio.gather(listener)
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", self.streaming_url)
                continue