social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from dataclasses import dataclass, field
5from typing import Any, cast, override
6
7import websockets
8
9from cross.attachments import (
10 LabelsAttachment,
11 LanguagesAttachment,
12 MediaAttachment,
13 RemoteUrlAttachment,
14 SensitiveAttachment,
15)
16from cross.media import Blob, download_blob
17from cross.post import Post
18from cross.service import InputService
19from database.connection import DatabasePool
20from mastodon.info import MastodonService, validate_and_transform
21from mastodon.parser import StatusParser
22
# Visibility values that may appear in ``MastodonInputOptions.allowed_visibility``;
# a copy of this list is also that option's default.
ALLOWED_VISIBILITY: list[str] = [
    "public",
    "unlisted",
]
24
25
26@dataclass(kw_only=True)
27class MastodonInputOptions:
28 token: str
29 instance: str
30 allowed_visibility: list[str] = field(
31 default_factory=lambda: ALLOWED_VISIBILITY.copy()
32 )
33 filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
34
35 @classmethod
36 def from_dict(cls, data: dict[str, Any]) -> "MastodonInputOptions":
37 validate_and_transform(data)
38
39 if "allowed_visibility" in data:
40 for vis in data.get("allowed_visibility", []):
41 if vis not in ALLOWED_VISIBILITY:
42 raise ValueError(f"Invalid visibility option {vis}!")
43
44 if "filters" in data:
45 data["filters"] = [re.compile(r) for r in data["filters"]]
46
47 return MastodonInputOptions(**data)
48
49
class MastodonInputService(MastodonService, InputService):
    """Input service that follows one Mastodon account's user stream.

    Statuses, reblogs and deletions authored by the authenticated account are
    recorded through the post-store helpers provided by the base classes and
    handed to every entry of ``self.outputs`` via ``self.submitter``.
    """

    def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
        """Verify the token and look up the instance's streaming endpoint.

        Args:
            db: Connection pool passed on to the base service.
            options: Token, instance and selection settings for this input.
        """
        super().__init__(options.instance, db)
        self.options: MastodonInputOptions = options

        self.log.info("Verifying %s credentails...", self.url)
        account = self.verify_credentials()
        self.user_id: str = account["id"]

        self.log.info("Getting %s configuration...", self.url)
        instance_info = self.fetch_instance_info()
        self.streaming_url: str = instance_info["urls"]["streaming_api"]

    @override
    def _get_token(self) -> str:
        """Return the access token configured for this input."""
        return self.options.token

    def _on_create_post(self, status: dict[str, Any]):
        """Handle a newly streamed status.

        The status is ignored when it was written by another account, has a
        visibility outside ``options.allowed_visibility``, contains a poll or
        a quote, replies to another account, or replies to a post that is not
        in the database. Reblogs are delegated to ``_on_reblog``. Otherwise
        the status is converted to a ``Post``, stored, and submitted to every
        output.

        Args:
            status: Decoded status object from the streaming payload.
        """
        if status["account"]["id"] != self.user_id:
            return

        if status["visibility"] not in self.options.allowed_visibility:
            return

        reblog: dict[str, Any] | None = status.get("reblog")
        if reblog:
            # Only reblogs of the account's own statuses are mirrored.
            if reblog["account"]["id"] != self.user_id:
                return
            self._on_reblog(status, reblog)
            return

        if status.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", status["id"])
            return

        if status.get("quote"):
            self.log.info("Skipping '%s'! Quote..", status["id"])
            return

        in_reply: str | None = status.get("in_reply_to_id")
        in_reply_to: str | None = status.get("in_reply_to_account_id")
        if in_reply_to and in_reply_to != self.user_id:
            return

        parent = None
        if in_reply:
            parent = self._get_post(self.url, self.user_id, in_reply)
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", status["id"], in_reply
                )
                return

        parser = StatusParser()
        parser.feed(status["content"])
        text, fragments = parser.get_result()

        post = Post(id=status["id"], parent_id=in_reply, text=text)
        post.fragments.extend(fragments)

        if status.get("url"):
            post.attachments.put(RemoteUrlAttachment(url=status["url"]))
        if status.get("sensitive"):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if status.get("language"):
            post.attachments.put(LanguagesAttachment(langs=[status["language"]]))

        # The Mastodon Status entity exposes the content warning as
        # ``spoiler_text``; ``spoiler`` is kept as a fallback for any payload
        # that used the key this code previously looked for.
        spoiler = status.get("spoiler_text") or status.get("spoiler")
        if spoiler:
            post.attachments.put(LabelsAttachment(labels=[spoiler]))

        blobs: list[Blob] = []
        for media in status.get("media_attachments", []):
            self.log.info("Downloading %s...", media["url"])
            # Mastodon stores alt text under ``description``; ``alt`` is kept
            # as a fallback. NOTE(review): assumes the second argument of
            # ``download_blob`` is the alt text — confirm against its signature.
            alt = media.get("description") or media.get("alt")
            blob: Blob | None = download_blob(media["url"], alt)
            if not blob:
                # One failed download drops the whole status.
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    status["id"],
                    media["url"],
                )
                return
            blobs.append(blob)

        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": status["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # A parent without a root is itself the root of the thread.
            record["root"] = parent["root"] or parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind ``out`` as a default argument: the callable may run after
            # the loop has advanced, and a plain closure would then see only
            # the last output.
            self.submitter(lambda out=out: out.accept_post(post))

    def _on_reblog(self, status: dict[str, Any], reblog: dict[str, Any]):
        """Record a reblog of a known post and submit it to every output.

        Args:
            status: The reblogging status from the stream.
            reblog: The reblogged status nested inside ``status``.
        """
        reposted = self._get_post(self.url, self.user_id, reblog["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                status["id"],
                reblog["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": status["id"],
                "reposted": reposted["id"],
            }
        )

        repost_id: str = status["id"]
        reposted_id: str = reblog["id"]
        for out in self.outputs:
            # Default-argument binding avoids the late-binding closure pitfall.
            self.submitter(
                lambda out=out: out.accept_repost(repost_id, reposted_id)
            )

    def _on_delete_post(self, status_id: str):
        """Submit a deletion to every output and remove the stored post.

        Nothing happens when ``status_id`` is not found in the database.

        Args:
            status_id: Identifier of the deleted status.
        """
        post = self._get_post(self.url, self.user_id, status_id)
        if not post:
            return

        is_repost = bool(post["reposted_id"])
        for output in self.outputs:
            # Default-argument binding avoids the late-binding closure pitfall.
            if is_repost:
                self.submitter(
                    lambda output=output: output.delete_repost(status_id)
                )
            else:
                self.submitter(
                    lambda output=output: output.delete_post(status_id)
                )
        self._delete_post_by_id(post["id"])

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode one streaming frame and dispatch it by event type.

        ``update`` events carry a JSON-encoded status and go to
        ``_on_create_post``; ``delete`` events carry a status id and go to
        ``_on_delete_post``. Other events are ignored.

        Args:
            msg: Raw websocket frame containing a JSON object.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        event: str | None = cast("str | None", data.get("event"))
        payload: str | None = cast("str | None", data.get("payload"))

        if event is None or payload is None:
            # Frames without an event/payload pair are not status events;
            # skip them instead of raising KeyError.
            self.log.debug("Ignoring streaming message without event/payload.")
            return

        if event == "update":
            self._on_create_post(json.loads(payload))
        elif event == "delete":
            self._on_delete_post(payload)

    @override
    async def listen(self):
        """Stream the user timeline, reconnecting after error closures.

        Each received frame is handed to ``self.submitter`` for processing by
        ``_accept_msg``. ``websockets.connect`` is used as an async iterator,
        so moving to the next iteration opens a fresh connection.
        """
        url = f"{self.streaming_url}/api/v1/streaming?stream=user"

        async for ws in websockets.connect(
            url, additional_headers={"Authorization": f"Bearer {self.options.token}"}
        ):
            try:
                self.log.info("Listening to %s...", self.streaming_url)

                async for msg in ws:
                    # Bind ``msg`` now: the callable may run after the next
                    # frame has already replaced the loop variable.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", self.streaming_url)
                continue