social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4import uuid
5from dataclasses import dataclass, field
6from typing import Any, cast, override
7
8import websockets
9
10from cross.attachments import (
11 LabelsAttachment,
12 MediaAttachment,
13 QuoteAttachment,
14 RemoteUrlAttachment,
15 SensitiveAttachment,
16)
17from cross.media import Blob, download_blob
18from cross.post import Post
19from cross.service import InputService
20from database.connection import DatabasePool
21from misskey.info import MisskeyService
22from util.markdown import MarkdownParser
23from util.util import normalize_service_url
24
# Note visibilities that may be configured for this input. It is both the
# default for MisskeyInputOptions.allowed_visibility and the whitelist that
# MisskeyInputOptions.from_dict validates configured values against.
# Kept as a list (not a tuple) because the dataclass default calls `.copy()`.
ALLOWED_VISIBILITY: list[str] = ["public", "home"]
26
27
@dataclass
class MisskeyInputOptions:
    """Configuration for a Misskey input service."""

    # API token; returned by the service's `_get_token` and appended to the
    # streaming URL as the `i` query parameter.
    token: str
    # Base URL of the instance; `from_dict` passes it through
    # `normalize_service_url` before storing it.
    instance: str
    # Note visibilities that are accepted for crossposting. Each instance
    # gets its own copy of ALLOWED_VISIBILITY as the default.
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    # Compiled regular expressions built from the "filters" config entry.
    # NOTE(review): not consumed anywhere in this module — presumably used
    # by shared filtering code elsewhere; confirm.
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MisskeyInputOptions":
        """Build options from a plain configuration mapping.

        Args:
            data: Mapping whose keys match the dataclass fields. "instance"
                is required; "filters", when present, is an iterable of
                regex pattern strings.

        Returns:
            A new options instance (of ``cls``, so subclasses work).

        Raises:
            KeyError: If "instance" is missing.
            ValueError: If "allowed_visibility" contains a value that is not
                in ALLOWED_VISIBILITY.
            re.error: If a filter pattern does not compile.
        """
        # Work on a shallow copy so the caller's mapping is left untouched.
        options = dict(data)
        options["instance"] = normalize_service_url(options["instance"])

        for vis in options.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        if "filters" in options:
            options["filters"] = [re.compile(r) for r in options["filters"]]

        return cls(**options)
50
51
class MisskeyInputService(MisskeyService, InputService):
    """Input service that forwards the authenticated user's Misskey notes.

    Messages are received over the instance's streaming websocket (see
    `listen`), converted into `Post` objects and handed to every entry in
    `self.outputs` through `self.submitter`.
    """

    def __init__(self, db: DatabasePool, options: MisskeyInputOptions) -> None:
        """Store the options and resolve the account id for the token.

        Args:
            db: Database pool passed on to the base class.
            options: Instance URL, token and filtering configuration.
        """
        super().__init__(options.instance, db)
        self.options: MisskeyInputOptions = options

        self.log.info("Verifying %s credentails...", self.url)
        response = self.verify_credentials()
        # Only notes whose "userId" equals this id are crossposted.
        self.user_id: str = response["id"]

    @override
    def _get_token(self) -> str:
        """Return the API token from the configured options."""
        return self.options.token

    def _on_note(self, note: dict[str, Any]) -> None:
        """Convert a streamed note into a `Post` and dispatch it to outputs.

        The note is silently or verbosely skipped when it is not authored by
        the configured user, has a disallowed visibility, contains a poll,
        quotes or replies to another user, replies to a post that is not in
        the database, or when one of its files cannot be downloaded. A
        renote without text is delegated to `_on_renote`.

        Args:
            note: The note payload taken from a streaming "note"/"reply"
                event body.
        """
        if note["userId"] != self.user_id:
            return

        if note["visibility"] not in self.options.allowed_visibility:
            return

        if note.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", note["id"])
            return

        renote: dict[str, Any] | None = note.get("renote")
        if renote:
            # A renote with no text is a plain repost, not a quote.
            if note.get("text") is None:
                self._on_renote(note, renote)
                return

            # Quotes are only supported when quoting the user's own notes.
            if renote["userId"] != self.user_id:
                return

        reply: dict[str, Any] | None = note.get("reply")
        parent = None
        if reply:
            if reply.get("userId") != self.user_id:
                self.log.info("Skipping '%s'! Reply to other user..", note["id"])
                return

            parent = self._get_post(self.url, self.user_id, reply["id"])
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db", note["id"], reply["id"]
                )
                return

        parser = MarkdownParser()  # TODO MFM parser
        # "text" may be present but null (e.g. a note that only carries
        # files), in which case .get()'s default would not apply.
        text, fragments = parser.parse(note.get("text") or "")
        post = Post(id=note["id"], parent_id=reply["id"] if reply else None, text=text)
        post.fragments.extend(fragments)

        # Same null-safety as for "text".
        files: list[dict[str, Any]] = note.get("files") or []

        post.attachments.put(RemoteUrlAttachment(url=self.url + "/notes/" + note["id"]))
        if renote:
            post.attachments.put(
                QuoteAttachment(quoted_id=renote["id"], quoted_user=self.user_id)
            )
        if any(f.get("isSensitive", False) for f in files):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if note.get("cw"):
            post.attachments.put(LabelsAttachment(labels=[note["cw"]]))

        blobs: list[Blob] = []
        for media in files:
            self.log.info("Downloading %s...", media["url"])
            blob: Blob | None = download_blob(media["url"], media.get("comment", ""))
            if not blob:
                # All-or-nothing: a single failed download drops the note.
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    note["id"],
                    media["url"],
                )
                return
            blobs.append(blob)

        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        record: dict[str, Any] = {
            "user": self.user_id,
            "service": self.url,
            "identifier": note["id"],
        }
        if parent:
            record["parent"] = parent["id"]
            # A parent without a root is itself the root of the thread.
            record["root"] = parent["root"] if parent["root"] else parent["id"]
        self._insert_post(record)

        for out in self.outputs:
            # Bind `out` now: if the submitter runs the callable later, a
            # plain closure would see whatever `out` the loop ended on.
            self.submitter(lambda out=out: out.accept_post(post))

    def _on_renote(self, note: dict[str, Any], renote: dict[str, Any]) -> None:
        """Record a plain repost and forward it to every output.

        Skipped (with a log line) when the reposted note is not found in the
        database for this service and user.

        Args:
            note: The renote itself.
            renote: The note that was reposted.
        """
        reposted = self._get_post(self.url, self.user_id, renote["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                note["id"],
                renote["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": note["id"],
                "reposted": reposted["id"],
            }
        )

        for out in self.outputs:
            # Bind `out` per iteration (see the late-binding note in _on_note).
            self.submitter(lambda out=out: out.accept_repost(note["id"], renote["id"]))

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode one websocket message and route note events to `_on_note`.

        Only "channel" messages whose body type is "note" or "reply" are
        handled; everything else is ignored.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))

        if data["type"] != "channel":
            return

        event_type: str = cast(str, data["body"]["type"])
        if event_type in ("note", "reply"):
            self._on_note(data["body"]["body"])

    async def _subscribe_to_home(self, ws: websockets.ClientConnection) -> None:
        """Send a "connect" request for the "homeTimeline" channel.

        A fresh random UUID is used as the channel connection id.
        """
        await ws.send(
            json.dumps(
                {
                    "type": "connect",
                    "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
                }
            )
        )
        self.log.info("Subscribed to 'homeTimeline' channel...")

    @override
    async def listen(self) -> None:
        """Stream the home timeline and submit each message for handling.

        Iterating `websockets.connect(url)` yields a new connection each
        time the loop body finishes or `continue`s, so the service keeps
        reconnecting after the connection is closed.
        """
        scheme = "wss" if self.url.startswith("https") else "ws"
        streaming: str = f"{scheme}://{self.url.split('://', 1)[1]}"
        url: str = f"{streaming}/streaming?i={self.options.token}"

        async for ws in websockets.connect(url):
            try:
                self.log.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                async for msg in ws:
                    # Bind `msg` now; the next iteration rebinds the name
                    # and a deferred closure would process the wrong message.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", streaming)
                continue