social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from typing import Any, Callable
5
6import requests
7import websockets
8
9import cross
10import util.database as database
11import util.html_util as html_util
12import util.md_util as md_util
13from mastodon.common import MastodonPost
14from util.database import DataBaseWorker
15from util.media import MediaInfo, download_media
16from util.util import LOGGER, as_envvar
17
# Visibilities a post may have to be crossposted. This is both the default
# for the 'allowed_visibility' option and the set of values it accepts.
ALLOWED_VISIBILITY = [
    "public",
    "unlisted",
]

# Content types whose source text is tokenized as markdown instead of being
# parsed from the rendered HTML.
MARKDOWNY = [
    "text/x.misskeymarkdown",
    "text/markdown",
    "text/plain",
]
20
21
class MastodonInputOptions:
    """Parsed ``options`` section of a Mastodon input configuration.

    ``filters`` holds the compiled ``regex_filters`` patterns and
    ``allowed_visibility`` holds the visibilities accepted for crossposting
    (``ALLOWED_VISIBILITY`` unless the options override it).
    """

    def __init__(self, o: dict) -> None:
        """Read and validate the raw options mapping *o*.

        Raises:
            ValueError: if ``allowed_visibility`` contains a value that is
                not listed in ``ALLOWED_VISIBILITY``.
        """
        self.allowed_visibility = ALLOWED_VISIBILITY
        self.filters = list(map(re.compile, o.get("regex_filters", [])))

        requested = o.get("allowed_visibility")
        if requested is None:
            # Option not given: keep the default assigned above.
            return

        for visibility in requested:
            if visibility not in ALLOWED_VISIBILITY:
                raise ValueError(
                    f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {requested}"
                )
        self.allowed_visibility = requested
34
35
36class MastodonInput(cross.Input):
37 def __init__(self, settings: dict, db: DataBaseWorker) -> None:
38 self.options = MastodonInputOptions(settings.get("options", {}))
39 self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw(
40 ValueError("'token' is required")
41 )
42 instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw(
43 ValueError("'instance' is required")
44 )
45
46 service = instance[:-1] if instance.endswith("/") else instance
47
48 LOGGER.info("Verifying %s credentails...", service)
49 responce = requests.get(
50 f"{service}/api/v1/accounts/verify_credentials",
51 headers={"Authorization": f"Bearer {self.token}"},
52 )
53 if responce.status_code != 200:
54 LOGGER.error("Failed to validate user credentials!")
55 responce.raise_for_status()
56 return
57
58 super().__init__(service, responce.json()["id"], settings, db)
59 self.streaming = self._get_streaming_url()
60
61 if not self.streaming:
62 raise Exception("Instance %s does not support streaming!", service)
63
64 def _get_streaming_url(self):
65 response = requests.get(f"{self.service}/api/v1/instance")
66 response.raise_for_status()
67 data: dict = response.json()
68 return (data.get("urls") or {}).get("streaming_api")
69
70 def __to_tokens(self, status: dict):
71 content_type = status.get("content_type", "text/plain")
72 raw_text = status.get("text")
73
74 tags: list[str] = []
75 for tag in status.get("tags", []):
76 tags.append(tag["name"])
77
78 mentions: list[tuple[str, str]] = []
79 for mention in status.get("mentions", []):
80 mentions.append(("@" + mention["username"], "@" + mention["acct"]))
81
82 if raw_text and content_type in MARKDOWNY:
83 return md_util.tokenize_markdown(raw_text, tags, mentions)
84
85 akkoma_ext: dict | None = status.get("akkoma", {}).get("source")
86 if akkoma_ext:
87 if akkoma_ext.get("mediaType") in MARKDOWNY:
88 return md_util.tokenize_markdown(akkoma_ext["content"], tags, mentions)
89
90 tokenizer = html_util.HTMLPostTokenizer()
91 tokenizer.mentions = mentions
92 tokenizer.tags = tags
93 tokenizer.feed(status.get("content", ""))
94 return tokenizer.get_tokens()
95
96 def _on_create_post(self, outputs: list[cross.Output], status: dict):
97 # skip events from other users
98 if (status.get("account") or {})["id"] != self.user_id:
99 return
100
101 if status.get("visibility") not in self.options.allowed_visibility:
102 # Skip f/o and direct posts
103 LOGGER.info(
104 "Skipping '%s'! '%s' visibility..",
105 status["id"],
106 status.get("visibility"),
107 )
108 return
109
110 # TODO polls not supported on bsky. maybe 3rd party? skip for now
111 # we don't handle reblogs. possible with bridgy(?) and self
112 # we don't handle quotes.
113 if status.get("poll"):
114 LOGGER.info("Skipping '%s'! Contains a poll..", status["id"])
115 return
116
117 if status.get("quote_id") or status.get("quote"):
118 LOGGER.info("Skipping '%s'! Quote..", status["id"])
119 return
120
121 reblog: dict | None = status.get("reblog")
122 if reblog:
123 if (reblog.get("account") or {})["id"] != self.user_id:
124 LOGGER.info("Skipping '%s'! Reblog of other user..", status["id"])
125 return
126
127 success = database.try_insert_repost(
128 self.db, status["id"], reblog["id"], self.user_id, self.service
129 )
130 if not success:
131 LOGGER.info(
132 "Skipping '%s' as reblogged post was not found in db!", status["id"]
133 )
134 return
135
136 for output in outputs:
137 output.accept_repost(status["id"], reblog["id"])
138 return
139
140 in_reply: str | None = status.get("in_reply_to_id")
141 in_reply_to: str | None = status.get("in_reply_to_account_id")
142 if in_reply_to and in_reply_to != self.user_id:
143 # We don't support replies.
144 LOGGER.info("Skipping '%s'! Reply to other user..", status["id"])
145 return
146
147 success = database.try_insert_post(
148 self.db, status["id"], in_reply, self.user_id, self.service
149 )
150 if not success:
151 LOGGER.info(
152 "Skipping '%s' as parent post was not found in db!", status["id"]
153 )
154 return
155
156 tokens = self.__to_tokens(status)
157 if not cross.test_filters(tokens, self.options.filters):
158 LOGGER.info("Skipping '%s'. Matched a filter!", status["id"])
159 return
160
161 LOGGER.info("Crossposting '%s'...", status["id"])
162
163 media_attachments: list[MediaInfo] = []
164 for attachment in status.get("media_attachments", []):
165 LOGGER.info("Downloading %s...", attachment["url"])
166 info = download_media(
167 attachment["url"], attachment.get("description") or ""
168 )
169 if not info:
170 LOGGER.error("Skipping '%s'. Failed to download media!", status["id"])
171 return
172 media_attachments.append(info)
173
174 cross_post = MastodonPost(status, tokens, media_attachments)
175 for output in outputs:
176 output.accept_post(cross_post)
177
178 def _on_delete_post(self, outputs: list[cross.Output], identifier: str):
179 post = database.find_post(self.db, identifier, self.user_id, self.service)
180 if not post:
181 return
182
183 LOGGER.info("Deleting '%s'...", identifier)
184 if post["reposted_id"]:
185 for output in outputs:
186 output.delete_repost(identifier)
187 else:
188 for output in outputs:
189 output.delete_post(identifier)
190
191 database.delete_post(self.db, identifier, self.user_id, self.service)
192
193 def _on_post(self, outputs: list[cross.Output], event: str, payload: str):
194 match event:
195 case "update":
196 self._on_create_post(outputs, json.loads(payload))
197 case "delete":
198 self._on_delete_post(outputs, payload)
199
200 async def listen(
201 self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
202 ):
203 uri = f"{self.streaming}/api/v1/streaming?stream=user&access_token={self.token}"
204
205 async for ws in websockets.connect(
206 uri, extra_headers={"User-Agent": "XPost/0.0.3"}
207 ):
208 try:
209 LOGGER.info("Listening to %s...", self.streaming)
210
211 async def listen_for_messages():
212 async for msg in ws:
213 data = json.loads(msg)
214 event: str = data.get("event")
215 payload: str = data.get("payload")
216
217 submit(lambda: self._on_post(outputs, str(event), str(payload)))
218
219 listen = asyncio.create_task(listen_for_messages())
220
221 await asyncio.gather(listen)
222 except websockets.ConnectionClosedError as e:
223 LOGGER.error(e, stack_info=True, exc_info=True)
224 LOGGER.info("Reconnecting to %s...", self.streaming)
225 continue