social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4import uuid
5from dataclasses import dataclass, field
6from typing import Any, cast, override
7
8import websockets
9
10from cross.attachments import (
11 LabelsAttachment,
12 MediaAttachment,
13 QuoteAttachment,
14 RemoteUrlAttachment,
15 SensitiveAttachment,
16)
17from cross.media import Blob, download_blob
18from cross.post import Post
19from cross.service import InputService
20from database.connection import DatabasePool
21from misskey.info import MisskeyService
22from util.markdown import MarkdownParser
23from util.util import normalize_service_url
24
# Note visibilities accepted in the `allowed_visibility` option; any other
# value is rejected by MisskeyInputOptions.from_dict.
ALLOWED_VISIBILITY = ["public", "home"]


@dataclass
class MisskeyInputOptions:
    """Settings for a Misskey input service."""

    # Access token for the Misskey account.
    token: str
    # Instance URL; from_dict passes it through normalize_service_url.
    instance: str
    # Visibilities of notes to pick up. Each instance gets its own copy of
    # ALLOWED_VISIBILITY so the module-level list is never shared or mutated.
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    # Compiled regular expressions.
    # NOTE(review): nothing in the visible part of this file consults them.
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MisskeyInputOptions":
        """Build options from a plain mapping (e.g. parsed configuration).

        The mapping is copied first, so the caller's dict is left untouched.

        Args:
            data: Keyword values for the dataclass fields. ``instance`` is
                required; ``filters`` may hold regex source strings.

        Returns:
            A new options object (of ``cls``, so subclasses work too).

        Raises:
            KeyError: If ``instance`` is missing.
            ValueError: If ``allowed_visibility`` contains a value that is
                not in ``ALLOWED_VISIBILITY``.
            re.error: If a filter is not a valid regular expression.
        """
        # Shallow copy: the keys replaced below must not leak to the caller.
        options = dict(data)
        options["instance"] = normalize_service_url(options["instance"])

        for vis in options.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        if "filters" in options:
            options["filters"] = [re.compile(r) for r in options["filters"]]

        return cls(**options)
50
51
class MisskeyInputService(MisskeyService, InputService):
    """Input service that streams the authenticated user's notes from Misskey.

    Frames received on the ``homeTimeline`` streaming channel are decoded,
    filtered, recorded through ``_insert_post`` and handed to every entry in
    ``self.outputs``.

    ``self.log``, ``self.url``, ``self.outputs``, ``self.submitter``,
    ``verify_credentials``, ``_get_post`` and ``_insert_post`` are not defined
    in this class; they are presumably supplied by the base classes.
    """

    def __init__(self, db: DatabasePool, options: MisskeyInputOptions) -> None:
        """Store the options and look up the id of the token's account.

        Args:
            db: Database pool forwarded to the base-class initialiser.
            options: Token, instance URL and filtering options.
        """
        super().__init__(options.instance, db)
        self.options: MisskeyInputOptions = options

        self.log.info("Verifying %s credentials...", self.url)
        response = self.verify_credentials()
        # Only notes authored by this account are crossposted.
        self.user_id: str = response["id"]

    @override
    def _get_token(self) -> str:
        """Return the configured access token."""
        return self.options.token

    def _on_note(self, note: dict[str, Any]) -> None:
        """Turn a streamed note into a ``Post`` and submit it to the outputs.

        The note is silently or verbosely skipped when it is by another user,
        has a visibility that is not allowed, contains a poll, quotes or
        replies to another user's note, refers to a quote/parent that is not
        in the database, or has media that cannot be downloaded.

        Args:
            note: Note object decoded from the streaming API.
        """
        if note["userId"] != self.user_id:
            return

        if note["visibility"] not in self.options.allowed_visibility:
            return

        if note.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", note["id"])
            return

        renote: dict[str, Any] | None = note.get("renote")
        if renote:
            # NOTE(review): a renote without text is treated as a plain
            # repost even if it carries files - confirm that is intended.
            if note.get("text") is None:
                self._on_renote(note, renote)
                return

            # Quotes of other users' notes are not crossposted.
            if renote["userId"] != self.user_id:
                return

            if not self._get_post(self.url, self.user_id, renote["id"]):
                self.log.info(
                    "Skipping %s, quote %s not found in db", note["id"], renote["id"]
                )
                return

        reply: dict[str, Any] | None = note.get("reply")
        parent = None
        if reply:
            if reply.get("userId") != self.user_id:
                self.log.info("Skipping '%s'! Reply to other user..", note["id"])
                return

            parent = self._get_post(self.url, self.user_id, reply["id"])
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", note["id"], reply["id"]
                )
                return

        parser = MarkdownParser()  # TODO MFM parser
        # "text" may be present but null (e.g. a media-only note); a plain
        # .get default would not apply in that case, so coalesce explicitly.
        text, fragments = parser.parse(note.get("text") or "")
        post = Post(id=note["id"], parent_id=reply["id"] if reply else None, text=text)
        post.fragments.extend(fragments)

        files: list[dict[str, Any]] = note.get("files") or []

        post.attachments.put(RemoteUrlAttachment(url=self.url + "/notes/" + note["id"]))
        if renote:
            post.attachments.put(
                QuoteAttachment(quoted_id=renote["id"], quoted_user=self.user_id)
            )
        if any(media.get("isSensitive", False) for media in files):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if note.get("cw"):
            post.attachments.put(LabelsAttachment(labels=[note["cw"]]))

        blobs: list[Blob] = []
        for media in files:
            self.log.info("Downloading %s...", media["url"])
            # The file description may be null as well; pass "" instead.
            blob: Blob | None = download_blob(media["url"], media.get("comment") or "")
            if not blob:
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    note["id"],
                    media["url"],
                )
                return
            blobs.append(blob)

        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": note["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # A parent without a root is itself the root of the thread.
            record["root"] = parent["root"] or parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind `out` now: if the submitter runs the callable later, a
            # plain closure would see whatever `out` is at that time.
            self.submitter(lambda out=out: out.accept_post(post))

    def _on_renote(self, note: dict[str, Any], renote: dict[str, Any]) -> None:
        """Record a plain repost and forward it to the outputs.

        Skipped (with a log line) when the reposted note is not in the
        database.

        Args:
            note: The renote itself.
            renote: The note being reposted.
        """
        reposted = self._get_post(self.url, self.user_id, renote["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                note["id"],
                renote["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": note["id"],
                "reposted": reposted["id"],
            }
        )

        note_id, renote_id = note["id"], renote["id"]
        for out in self.outputs:
            # Bind `out` per iteration (see _on_note for the rationale).
            self.submitter(lambda out=out: out.accept_repost(note_id, renote_id))

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode one streaming frame and dispatch note events.

        Only ``channel`` frames whose body type is ``note`` or ``reply`` are
        handled; everything else is ignored.

        Args:
            msg: Raw frame received from the websocket (JSON).
        """
        data = json.loads(msg)
        if not isinstance(data, dict) or data.get("type") != "channel":
            return

        body = cast(dict[str, Any], data["body"])
        if body.get("type") in ("note", "reply"):
            self._on_note(body["body"])

    async def _subscribe_to_home(self, ws: websockets.ClientConnection) -> None:
        """Send the ``connect`` request for the ``homeTimeline`` channel."""
        request = {
            "type": "connect",
            # A fresh id is generated for every subscription.
            "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
        }
        await ws.send(json.dumps(request))
        self.log.info("Subscribed to 'homeTimeline' channel...")

    @override
    async def listen(self):
        """Stream the home timeline, reconnecting when the socket drops.

        Every received frame is handed to ``self.submitter`` for processing
        by ``_accept_msg``.
        """
        scheme = "wss" if self.url.startswith("https") else "ws"
        streaming: str = f"{scheme}://{self.url.split('://', 1)[1]}"
        url: str = f"{streaming}/streaming?i={self.options.token}"

        # Iterating over connect() yields a new connection after each drop.
        async for ws in websockets.connect(url):
            try:
                self.log.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                async for msg in ws:
                    # Bind `msg` now; the loop rebinds it on the next frame,
                    # possibly before the submitted callable has run.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", streaming)
                continue
217 _ = await asyncio.gather(listen)
218 except websockets.ConnectionClosedError as e:
219 self.log.error(e, stack_info=True, exc_info=True)
220 self.log.info("Reconnecting to %s...", streaming)
221 continue