Social media crossposting tool. Third time's the charm.
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4import uuid
5from typing import Any, Callable
6
7import requests
8import websockets
9
10import cross
11import util.database as database
12import util.md_util as md_util
13from misskey.common import MisskeyPost
14from util.media import MediaInfo, download_media
15from util.util import LOGGER, as_envvar
16
# Visibility values this input supports; also the default for
# MisskeyInputOptions.allowed_visibility, which may only use values from here.
ALLOWED_VISIBILITY = ["public", "home"]


class MisskeyInputOptions:
    """Parsed ``options`` block of a Misskey input.

    Attributes:
        allowed_visibility: Note visibilities that may be crossposted. Defaults
            to a copy of ``ALLOWED_VISIBILITY``; a configured value may only
            contain entries from ``ALLOWED_VISIBILITY``.
        filters: Compiled regular expressions from ``regex_filters``.
    """

    def __init__(self, o: dict) -> None:
        """Parse and validate the options dict ``o``.

        Raises:
            ValueError: If ``allowed_visibility`` contains a value outside
                ``ALLOWED_VISIBILITY``.
            re.error: If a ``regex_filters`` entry is not a valid pattern.
        """
        # Copy so mutating one instance's list can't alter the module default.
        self.allowed_visibility = list(ALLOWED_VISIBILITY)
        # `or []` also covers an explicit null in the config.
        self.filters = [re.compile(f) for f in o.get("regex_filters") or []]

        allowed_visibility = o.get("allowed_visibility")
        if allowed_visibility is not None:
            if any(v not in ALLOWED_VISIBILITY for v in allowed_visibility):
                raise ValueError(
                    f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}"
                )
            self.allowed_visibility = list(allowed_visibility)
32
33
class MisskeyInput(cross.Input):
    """Input that streams a Misskey account's home timeline and forwards its own notes.

    The token is verified against ``/api/i`` at construction time and the
    returned ``id`` is passed to ``cross.Input`` as the user id.
    """

    def __init__(self, settings: dict, db: cross.DataBaseWorker) -> None:
        """Build the input from its settings and verify the credentials.

        Args:
            settings: Input settings. ``token`` and ``instance`` are required
                (both are passed through ``as_envvar``); ``options`` is optional.
            db: Database worker handed to ``cross.Input``.

        Raises:
            ValueError: If ``token`` or ``instance`` is missing/empty, or the
                options are invalid.
            requests.HTTPError: If the credential check does not return 200.
        """
        self.options = MisskeyInputOptions(settings.get("options", {}))
        self.token = self._require(as_envvar(settings.get("token")), "token")
        instance: str = self._require(as_envvar(settings.get("instance")), "instance")

        # Strip a trailing slash so the URLs built from `service` contain no "//".
        service = instance[:-1] if instance.endswith("/") else instance

        LOGGER.info("Verifying %s credentails...", service)
        response = requests.post(
            f"{service}/api/i",
            json={"i": self.token},
            headers={"Content-Type": "application/json"},
        )
        if response.status_code != 200:
            LOGGER.error("Failed to validate user credentials!")
            response.raise_for_status()
            # raise_for_status() only raises for 4xx/5xx; any other non-200
            # status must not leave a half-initialised input behind.
            raise requests.HTTPError(
                f"Unexpected status {response.status_code} while verifying credentials",
                response=response,
            )

        super().__init__(service, response.json()["id"], settings, db)

    @staticmethod
    def _require(value, name: str):
        """Return ``value`` unchanged, raising ``ValueError`` if it is falsy."""
        if not value:
            raise ValueError(f"'{name}' is required")
        return value

    def _on_note(self, outputs: list[cross.Output], note: dict):
        """Filter one streamed note and forward it to ``outputs`` if eligible.

        The following are skipped:
            - notes by other users;
            - notes with a visibility that is not allowed;
            - polls and quotes;
            - renotes of other users' notes;
            - replies to other users;
            - notes matching a filter.

        A pure renote of an own note is forwarded through ``accept_repost``
        once ``database.try_insert_repost`` succeeds.

        Any other note is tokenized, has its files downloaded, and is forwarded
        through ``accept_post`` as a ``MisskeyPost``.
        """
        # self.user_id / self.service / self.db are presumably set by cross.Input.
        if note["userId"] != self.user_id:
            return

        if note.get("visibility") not in self.options.allowed_visibility:
            LOGGER.info(
                "Skipping '%s'! '%s' visibility..", note["id"], note.get("visibility")
            )
            return

        # TODO polls not supported on bsky. maybe 3rd party? skip for now
        # we don't handle reblogs. possible with bridgy(?) and self
        if note.get("poll"):
            LOGGER.info("Skipping '%s'! Contains a poll..", note["id"])
            return

        renote: dict | None = note.get("renote")
        if renote:
            # A renote that carries its own text is a quote, which isn't handled.
            if note.get("text") is not None:
                LOGGER.info("Skipping '%s'! Quote..", note["id"])
                return

            if renote.get("userId") != self.user_id:
                LOGGER.info("Skipping '%s'! Reblog of other user..", note["id"])
                return

            success = database.try_insert_repost(
                self.db, note["id"], renote["id"], self.user_id, self.service
            )
            if not success:
                LOGGER.info(
                    "Skipping '%s' as renoted note was not found in db!", note["id"]
                )
                return

            for output in outputs:
                output.accept_repost(note["id"], renote["id"])
            return

        reply_id: str | None = note.get("replyId")
        if reply_id:
            # If the payload carries "reply": null, .get("reply", {}) would
            # return None; `or {}` guards against that.
            reply: dict = note.get("reply") or {}
            if reply.get("userId") != self.user_id:
                LOGGER.info("Skipping '%s'! Reply to other user..", note["id"])
                return

            success = database.try_insert_post(
                self.db, note["id"], reply_id, self.user_id, self.service
            )
            if not success:
                LOGGER.info("Skipping '%s' as parent note was not found in db!", note["id"])
                return

        mention_handles: dict = note.get("mentionHandles") or {}
        tags: list[str] = note.get("tags") or []

        # Each handle is passed as both elements of the pair; the pair's
        # meaning is defined by md_util.tokenize_markdown.
        handles: list[tuple[str, str]] = [
            (handle, handle) for handle in mention_handles.values()
        ]

        # `or ""` also covers an explicit "text": null (e.g. attachment-only notes).
        tokens = md_util.tokenize_markdown(note.get("text") or "", tags, handles)
        if not cross.test_filters(tokens, self.options.filters):
            LOGGER.info("Skipping '%s'. Matched a filter!", note["id"])
            return

        LOGGER.info("Crossposting '%s'...", note["id"])

        media_attachments: list[MediaInfo] = []
        for attachment in note.get("files") or []:
            LOGGER.info("Downloading %s...", attachment["url"])
            info = download_media(attachment["url"], attachment.get("comment") or "")
            if not info:
                LOGGER.error("Skipping '%s'. Failed to download media!", note["id"])
                return
            media_attachments.append(info)

        cross_post = MisskeyPost(self.service, note, tokens, media_attachments)
        for output in outputs:
            output.accept_post(cross_post)

    def _on_delete(self, outputs: list[cross.Output], note: dict):
        # TODO handle deletes
        pass

    def _on_message(self, outputs: list[cross.Output], data: dict):
        """Dispatch one decoded streaming message.

        Only ``channel`` messages whose body type is ``note`` or ``reply`` are
        handled; they are passed to ``_on_note``. Everything else is ignored.
        """
        if data.get("type") != "channel":
            return

        event_type: str = data["body"]["type"]
        if event_type in ("note", "reply"):
            self._on_note(outputs, data["body"]["body"])

    async def _send_keepalive(self, ws: websockets.WebSocketClientProtocol):
        """Send the text frame "h" every 120 seconds while ``ws`` is open.

        Stops when the socket is closed or a send fails. Errors are logged,
        not raised.
        """
        while ws.open:
            try:
                await asyncio.sleep(120)
                if ws.open:
                    await ws.send("h")
                    LOGGER.debug("Sent keepalive h..")
                else:
                    LOGGER.info("WebSocket is closed, stopping keepalive task.")
                    break
            except Exception as e:
                LOGGER.error("Error sending keepalive: %s", e)
                break

    async def _subscribe_to_home(self, ws: websockets.WebSocketClientProtocol):
        """Ask the server to connect this socket to the ``homeTimeline`` channel."""
        await ws.send(
            json.dumps(
                {
                    "type": "connect",
                    "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
                }
            )
        )
        LOGGER.info("Subscribed to 'homeTimeline' channel...")

    async def listen(
        self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
    ):
        """Stream the home timeline and hand each message to ``submit``.

        ``submit`` receives a zero-argument callable that decodes the message
        and dispatches it through ``_on_message``.

        Connections come from the ``websockets.connect`` async iterator. A
        ``ConnectionClosedError`` is logged and the loop moves on to the next
        connection.
        """
        streaming: str = f"wss://{self.service.split('://', 1)[1]}"
        url: str = f"{streaming}/streaming?i={self.token}"

        async for ws in websockets.connect(
            url, extra_headers={"User-Agent": "XPost/0.0.3"}
        ):
            try:
                LOGGER.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                keepalive = asyncio.create_task(self._send_keepalive(ws))
                try:
                    async for msg in ws:
                        # TODO listen to deletes somehow
                        # Bind the current message now (raw=msg). If submit()
                        # defers the call, a bare closure would see a later msg.
                        submit(lambda raw=msg: self._on_message(outputs, json.loads(raw)))
                finally:
                    # The message loop has ended, so the socket is done. Stop
                    # the keepalive now instead of waiting up to 120 s for it
                    # to notice, and don't leave it orphaned on errors.
                    keepalive.cancel()
            except websockets.ConnectionClosedError as e:
                LOGGER.error(e, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", streaming)
                continue