Social media crossposting tool. Third time's the charm.
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from dataclasses import dataclass, field
5from typing import Any, cast, override
6
7import websockets
8
9from cross.attachments import (
10 LabelsAttachment,
11 LanguagesAttachment,
12 MediaAttachment,
13 QuoteAttachment,
14 RemoteUrlAttachment,
15 SensitiveAttachment,
16)
17from cross.media import Blob, download_blob
18from cross.post import Post
19from cross.service import InputService
20from database.connection import DatabasePool
21from mastodon.info import MastodonService, validate_and_transform
22from mastodon.parser import StatusParser
23
24ALLOWED_VISIBILITY: list[str] = ["public", "unlisted"]
25
26
27@dataclass(kw_only=True)
28class MastodonInputOptions:
29 token: str
30 instance: str
31 allowed_visibility: list[str] = field(
32 default_factory=lambda: ALLOWED_VISIBILITY.copy()
33 )
34 filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
35
36 @classmethod
37 def from_dict(cls, data: dict[str, Any]) -> "MastodonInputOptions":
38 validate_and_transform(data)
39
40 if "allowed_visibility" in data:
41 for vis in data.get("allowed_visibility", []):
42 if vis not in ALLOWED_VISIBILITY:
43 raise ValueError(f"Invalid visibility option {vis}!")
44
45 if "filters" in data:
46 data["filters"] = [re.compile(r) for r in data["filters"]]
47
48 return MastodonInputOptions(**data)
49
50
class MastodonInputService(MastodonService, InputService):
    """Crossposts the authenticated Mastodon account's own activity.

    Listens on the instance's ``user`` streaming endpoint and forwards the
    account's new statuses, boosts and deletions to every configured output
    through ``self.submitter``.
    """

    def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
        super().__init__(options.instance, db)
        self.options: MastodonInputOptions = options

        self.log.info("Verifying %s credentials...", self.url)
        account = self.verify_credentials()
        self.user_id: str = account["id"]

        self.log.info("Getting %s configuration...", self.url)
        instance_info = self.fetch_instance_info()
        self.streaming_url: str = instance_info["urls"]["streaming_api"]

    @override
    def _get_token(self) -> str:
        return self.options.token

    def _on_create_post(self, status: dict[str, Any]) -> None:
        """Crosspost a newly created status if it qualifies.

        A status is forwarded only if all of these hold:
        - it was written by this account and has an allowed visibility;
        - it has no poll;
        - a reply is only to this account, and its parent is in the DB;
        - a quote is only of this account's post, and that post is in the DB.
        Boosts of this account's own posts are delegated to `_on_reblog`.
        """
        if status["account"]["id"] != self.user_id:
            return

        if status["visibility"] not in self.options.allowed_visibility:
            return

        reblog: dict[str, Any] | None = status.get("reblog")
        if reblog:
            if reblog["account"]["id"] != self.user_id:
                return
            self._on_reblog(status, reblog)
            return

        if status.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", status["id"])
            return

        # Replies to other accounts are never crossposted; check this cheaply
        # before doing any database lookups.
        in_reply: str | None = status.get("in_reply_to_id")
        in_reply_to: str | None = status.get("in_reply_to_account_id")
        if in_reply_to and in_reply_to != self.user_id:
            return

        quote: dict[str, Any] | None = status.get("quote")
        if quote:
            # The quoted status may be wrapped under "quoted_status" or embedded
            # directly. The wrapper's "quoted_status" can be null, in which case
            # there is no account to check, so the status is skipped instead of
            # raising KeyError.
            quote = quote.get("quoted_status") or quote
            quoted_account: dict[str, Any] = quote.get("account") or {}
            if quoted_account.get("id") != self.user_id:
                return

            if not self._get_post(self.url, self.user_id, quote["id"]):
                self.log.info(
                    "Skipping %s, parent %s not found in db", status["id"], quote["id"]
                )
                return

        parent = None
        if in_reply:
            parent = self._get_post(self.url, self.user_id, in_reply)
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", status["id"], in_reply
                )
                return

        parser = StatusParser()
        parser.feed(status["content"])
        text, fragments = parser.get_result()

        post = Post(id=status["id"], parent_id=in_reply, text=text)
        post.fragments.extend(fragments)

        if quote:
            post.attachments.put(
                QuoteAttachment(quoted_id=quote["id"], quoted_user=self.user_id)
            )
        if status.get("url"):
            post.attachments.put(RemoteUrlAttachment(url=status["url"]))
        if status.get("sensitive"):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if status.get("language"):
            post.attachments.put(LanguagesAttachment(langs=[status["language"]]))
        # Mastodon exposes the content warning as "spoiler_text"; the old
        # "spoiler" key is kept as a fallback.
        spoiler: str | None = status.get("spoiler_text") or status.get("spoiler")
        if spoiler:
            post.attachments.put(LabelsAttachment(labels=[spoiler]))

        blobs = self._download_media(status)
        if blobs is None:
            return
        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": status["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # A reply inherits its parent's thread root; a parent without a
            # root is itself the root.
            record["root"] = parent["root"] or parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind `out` now: the submitted callable may run after the loop has
            # advanced, and a plain closure would then see the last output.
            self.submitter(lambda out=out: out.accept_post(post))

    def _download_media(self, status: dict[str, Any]) -> list[Blob] | None:
        """Download every media attachment of *status*, in order.

        Returns the downloaded blobs (possibly empty), or None if any download
        failed. In that case the caller skips the whole status.
        """
        blobs: list[Blob] = []
        for media in status.get("media_attachments") or []:
            self.log.info("Downloading %s...", media["url"])
            # Mastodon stores alt text in "description"; the old "alt" key is
            # kept as a fallback.
            alt: str | None = media.get("description") or media.get("alt")
            blob: Blob | None = download_blob(media["url"], alt)
            if not blob:
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    status["id"],
                    media["url"],
                )
                return None
            blobs.append(blob)
        return blobs

    def _on_reblog(self, status: dict[str, Any], reblog: dict[str, Any]) -> None:
        """Forward a boost of one of this account's posts.

        The boost is only forwarded if the boosted post is already in the DB.
        """
        reposted = self._get_post(self.url, self.user_id, reblog["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                status["id"],
                reblog["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": status["id"],
                "reposted": reposted["id"],
            }
        )

        for out in self.outputs:
            # Bind `out` now; see _on_create_post.
            self.submitter(
                lambda out=out: out.accept_repost(status["id"], reblog["id"])
            )

    def _on_delete_post(self, status_id: str) -> None:
        """Propagate the deletion of a tracked status or boost to every output.

        Also removes the status's row from the DB.
        """
        post = self._get_post(self.url, self.user_id, status_id)
        if not post:
            return

        # NOTE(review): rows are inserted with a "reposted" key but read back
        # as "reposted_id"; confirm that _insert_post maps between the two.
        if post["reposted_id"]:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_repost(status_id))
        else:
            for output in self.outputs:
                self.submitter(lambda output=output: output.delete_post(status_id))
        # NOTE(review): the row is deleted right away, while the submitted
        # deletions may run later. Confirm that outputs do not need this row.
        self._delete_post_by_id(post["id"])

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Dispatch one raw streaming-API message.

        "update" payloads are JSON-encoded statuses; "delete" payloads are bare
        status ids. Other events are ignored. A message without a payload is
        dropped instead of raising KeyError.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        event: str | None = data.get("event")
        payload: str | None = data.get("payload")
        if payload is None:
            return

        if event == "update":
            self._on_create_post(json.loads(payload))
        elif event == "delete":
            self._on_delete_post(payload)

    @override
    async def listen(self) -> None:
        """Stream the account's timeline events forever.

        Reconnects when the connection drops.
        """
        url = f"{self.streaming_url.rstrip('/')}/api/v1/streaming?stream=user"

        # Iterating websockets.connect yields a fresh connection after each
        # disconnect.
        async for ws in websockets.connect(
            url, additional_headers={"Authorization": f"Bearer {self.options.token}"}
        ):
            try:
                self.log.info("Listening to %s...", self.streaming_url)
                async for msg in ws:
                    # Bind `msg` now: the submitted callable may run after the
                    # next message has already been received.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", self.streaming_url)