Social media crossposting tool. Third time's the charm.
mastodon
misskey
crossposting
bluesky
import asyncio
import functools
import json
import re
import uuid
from dataclasses import dataclass, field
from typing import Any, cast, override

import websockets

from cross.attachments import (
    LabelsAttachment,
    MediaAttachment,
    QuoteAttachment,
    RemoteUrlAttachment,
    SensitiveAttachment,
)
from cross.media import Blob, download_blob
from cross.post import Post
from cross.service import InputService
from database.connection import DatabasePool
from misskey.info import MisskeyService
from util.markdown import MarkdownParser
from util.util import normalize_service_url
24
# Note visibilities that may be crossposted; configured ``allowed_visibility``
# entries are validated against this list by ``MisskeyInputOptions.from_dict``.
ALLOWED_VISIBILITY = ["public", "home"]


@dataclass
class MisskeyInputOptions:
    """Configuration for reading the authenticated user's notes from Misskey.

    Attributes:
        token: API access token for the account.
        instance: Base URL of the instance (normalized by ``from_dict``).
        allowed_visibility: Note visibilities that will be crossposted;
            defaults to a fresh copy of ``ALLOWED_VISIBILITY``.
        filters: Compiled regular expressions taken from the ``filters``
            config key (not consulted within this module).
    """

    token: str
    instance: str
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MisskeyInputOptions":
        """Build options from a raw config mapping without modifying it.

        Args:
            data: Config with ``token`` and ``instance`` keys and optional
                ``allowed_visibility`` and ``filters`` lists.

        Returns:
            A new instance of ``cls`` built from a normalized copy of ``data``.

        Raises:
            ValueError: If ``allowed_visibility`` holds a value that is not in
                ``ALLOWED_VISIBILITY``.
            KeyError: If ``instance`` is missing.
            re.error: If a filter is not a valid regular expression.
            TypeError: If ``data`` contains keys that are not option fields.
        """
        # Validate first, before doing any work on the config.
        for vis in data.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        # Work on a shallow copy so the caller's config dict is left untouched.
        options = dict(data)
        options["instance"] = normalize_service_url(options["instance"])
        if "filters" in options:
            options["filters"] = [re.compile(r) for r in options["filters"]]

        return cls(**options)
50
51
class MisskeyInputService(MisskeyService, InputService):
    """Stream the authenticated user's own notes from Misskey to the outputs.

    Notes arrive over the instance's streaming API (``homeTimeline`` channel).
    A note is crossposted only if it is the user's own and has an allowed
    visibility. Replies and quotes are crossposted only when they target the
    user's own notes that are already recorded in the database. Pure renotes
    are forwarded to the outputs as reposts.
    """

    def __init__(self, db: DatabasePool, options: MisskeyInputOptions) -> None:
        """Initialize the service and resolve the account's user id.

        Args:
            db: Database pool passed to the base ``MisskeyService``.
            options: Parsed input options (token, instance, visibility, ...).
        """
        super().__init__(options.instance, db)
        self.options: MisskeyInputOptions = options

        self.log.info("Verifying %s credentials...", self.url)
        response = self.verify_credentials()
        self.user_id: str = response["id"]

    @override
    def _get_token(self) -> str:
        """Return the API token from the input options."""
        return self.options.token

    def _on_note(self, note: dict[str, Any]):
        """Crosspost a note received from the stream, if it qualifies.

        A note is skipped when any of these hold:
        - it belongs to another user;
        - its visibility is not in ``allowed_visibility``;
        - it contains a poll;
        - it quotes or replies to a note that is not the user's own, or is not
          in the database;
        - one of its files fails to download.

        Pure renotes are delegated to ``_on_renote``.
        """
        if note["userId"] != self.user_id:
            return

        if note["visibility"] not in self.options.allowed_visibility:
            return

        if note.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", note["id"])
            return

        renote: dict[str, Any] | None = note.get("renote")
        if renote:
            if self._is_pure_renote(note):
                self._on_renote(note, renote)
                return

            # Anything else that renotes is a quote; only quotes of the user's
            # own, already-crossposted notes can be linked on the outputs.
            if renote["userId"] != self.user_id:
                self.log.info("Skipping '%s'! Quote of other user..", note["id"])
                return

            if not self._get_post(self.url, self.user_id, renote["id"]):
                self.log.info(
                    "Skipping %s, quote %s not found in db", note["id"], renote["id"]
                )
                return

        reply: dict[str, Any] | None = note.get("reply")
        parent = None
        if reply:
            if reply.get("userId") != self.user_id:
                self.log.info("Skipping '%s'! Reply to other user..", note["id"])
                return

            parent = self._get_post(self.url, self.user_id, reply["id"])
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", note["id"], reply["id"]
                )
                return

        post = self._build_post(note, renote, reply)

        blobs = self._download_media(note)
        if blobs is None:
            return
        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        self._record_note(note, parent)

        for out in self.outputs:
            # Bind ``out`` now: ``submitter`` may run the callback later, and a
            # bare ``lambda: out...`` would then see only the last output.
            self.submitter(functools.partial(out.accept_post, post))

    @staticmethod
    def _is_pure_renote(note: dict[str, Any]) -> bool:
        """Return True if ``note`` renotes without adding text, a CW, or files.

        Polls are not checked here because ``_on_note`` rejects them earlier.
        """
        return (
            note.get("text") is None
            and not note.get("cw")
            and not note.get("files")
        )

    def _build_post(
        self,
        note: dict[str, Any],
        renote: dict[str, Any] | None,
        reply: dict[str, Any] | None,
    ) -> Post:
        """Convert ``note`` into a ``Post`` carrying all non-media attachments."""
        mention_handles: dict = note.get("mentionHandles") or {}
        tags: list[str] = note.get("tags") or []
        # Each mentioned handle is passed to the parser as a (handle, handle) pair.
        handles: list[tuple[str, str]] = [(h, h) for h in mention_handles.values()]

        parser = MarkdownParser()  # TODO MFM parser
        # ``text`` can be null (e.g. file-only notes), so fall back to "".
        tokens = parser.parse(note.get("text") or "", tags, handles)
        post = Post(
            id=note["id"], parent_id=reply["id"] if reply else None, tokens=tokens
        )

        post.attachments.put(RemoteUrlAttachment(url=self.url + "/notes/" + note["id"]))
        if renote:
            post.attachments.put(
                QuoteAttachment(quoted_id=renote["id"], quoted_user=self.user_id)
            )
        if any(f.get("isSensitive", False) for f in note.get("files") or []):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if note.get("cw"):
            post.attachments.put(LabelsAttachment(labels=[note["cw"]]))
        return post

    def _download_media(self, note: dict[str, Any]) -> list[Blob] | None:
        """Download every file attached to ``note``.

        Returns:
            The downloaded blobs (possibly empty), or ``None`` if any download
            failed, in which case the caller skips the whole note.
        """
        blobs: list[Blob] = []
        for media in note.get("files") or []:
            self.log.info("Downloading %s...", media["url"])
            # ``comment`` may be null in the payload; pass "" instead of None.
            blob: Blob | None = download_blob(media["url"], media.get("comment") or "")
            if not blob:
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    note["id"],
                    media["url"],
                )
                return None
            blobs.append(blob)
        return blobs

    def _record_note(self, note: dict[str, Any], parent: Any) -> None:
        """Store ``note`` in the database, linking it to its thread if a reply."""
        row: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": note["id"],
        }
        if parent:
            row["parent"] = parent["id"]
            # A parent with no recorded root is itself the root of the thread.
            row["root"] = parent["root"] or parent["id"]
        self._insert_post(row)

    def _on_renote(self, note: dict[str, Any], renote: dict[str, Any]):
        """Forward a pure renote as a repost if the renoted note is in the db."""
        reposted = self._get_post(self.url, self.user_id, renote["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                note["id"],
                renote["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": note["id"],
                "reposted": reposted["id"],
            }
        )

        for out in self.outputs:
            # Bind ``out`` eagerly; see the note in ``_on_note``.
            self.submitter(
                functools.partial(out.accept_repost, note["id"], renote["id"])
            )

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode a streaming message and pass note/reply events to ``_on_note``.

        Messages that are not ``channel`` events are ignored.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        if data.get("type") != "channel":
            return

        body: dict[str, Any] = data.get("body") or {}
        event_type = body.get("type")
        if event_type in ("note", "reply"):
            self._on_note(body["body"])

    async def _subscribe_to_home(self, ws: websockets.ClientConnection) -> None:
        """Ask the server to stream the ``homeTimeline`` channel over ``ws``."""
        await ws.send(
            json.dumps(
                {
                    "type": "connect",
                    "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
                }
            )
        )
        self.log.info("Subscribed to 'homeTimeline' channel...")

    @override
    async def listen(self):
        """Stream the home timeline, opening a new connection whenever one ends.

        ``websockets.connect`` used as an async iterator yields a fresh
        connection each time the loop body finishes. A
        ``ConnectionClosedError`` is logged before reconnecting.
        """
        scheme = "wss" if self.url.startswith("https") else "ws"
        streaming: str = f"{scheme}://{self.url.split('://', 1)[1]}"
        url: str = f"{streaming}/streaming?i={self.options.token}"

        async for ws in websockets.connect(url):
            try:
                self.log.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                async for msg in ws:
                    # Bind ``msg`` now; a bare lambda would read whichever
                    # message the loop holds when ``submitter`` runs it.
                    self.submitter(functools.partial(self._accept_msg, msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", streaming)