Social media crossposting tool. Third time's the charm.
mastodon
misskey
crossposting
bluesky
1import requests, websockets
2import asyncio
3import json, uuid
4import re
5
6from misskey.common import MisskeyPost
7
8import cross, util.database as database
9from util.media import MediaInfo, download_media
10from util.util import LOGGER, as_envvar
11
12from typing import Callable, Any
13
# Misskey note visibilities that may be crossposted; the 'allowed_visibility'
# option may only contain values from this list.
ALLOWED_VISIBILITY = ['public', 'home']


class MisskeyInputOptions:
    """Parsed 'options' section of a Misskey input's settings.

    Attributes:
        allowed_visibility: visibilities whose notes are crossposted; a subset
            of ``ALLOWED_VISIBILITY`` (all of it by default). Always a fresh
            list owned by this instance.
        filters: compiled patterns from the 'regex_filters' option.
    """

    def __init__(self, o: dict) -> None:
        """Parse and validate the raw options dict ``o``.

        Raises:
            ValueError: if 'allowed_visibility' is a bare string or contains a
                value not in ``ALLOWED_VISIBILITY``.
            re.error: if a 'regex_filters' entry is not a valid pattern.
        """
        # `or []` also covers an explicit null in the config, which
        # `o.get(key, [])` would pass through as None.
        self.filters = [re.compile(f) for f in o.get('regex_filters') or []]

        allowed_visibility = o.get('allowed_visibility')
        if allowed_visibility is None:
            # Copy so mutating this instance can never alter the module constant.
            self.allowed_visibility: list[str] = list(ALLOWED_VISIBILITY)
            return

        # A bare string would otherwise be iterated character by character
        # (and "" would silently disable every visibility).
        if isinstance(allowed_visibility, str) or any(
            v not in ALLOWED_VISIBILITY for v in allowed_visibility
        ):
            raise ValueError(f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}")
        self.allowed_visibility = list(allowed_visibility)
26
class MisskeyInput(cross.Input):
    """Crossposting input that streams a Misskey account's home timeline and
    forwards the account's own eligible notes to every output.
    """

    def __init__(self, settings: dict, db: cross.DataBaseWorker) -> None:
        """Read credentials from ``settings`` and verify them with the instance.

        Args:
            settings: input configuration. 'token' and 'instance' are required
                (both resolved through ``as_envvar``); 'options' is optional.
            db: database worker, passed through to ``cross.Input``.

        Raises:
            ValueError: if 'token' or 'instance' is missing.
            requests.HTTPError: if verification fails with a 4xx/5xx status.
            RuntimeError: if verification returns any other non-200 status.
        """
        self.options = MisskeyInputOptions(settings.get('options') or {})

        token = as_envvar(settings.get('token'))
        if not token:
            raise ValueError("'token' is required")
        self.token = token

        instance: str = as_envvar(settings.get('instance'))
        if not instance:
            raise ValueError("'instance' is required")

        # Normalized base URL, also used for all API/streaming URLs below so a
        # trailing '/' in the config cannot produce '//api/i'.
        service = instance.rstrip('/')

        LOGGER.info("Verifying %s credentials...", service)
        response = requests.post(f"{service}/api/i", json={'i': self.token}, headers={
            "Content-Type": "application/json"
        }, timeout=30)  # never hang startup forever on an unresponsive instance
        if response.status_code != 200:
            LOGGER.error("Failed to validate user credentials!")
            response.raise_for_status()
            # raise_for_status() only raises for 4xx/5xx; refuse to continue
            # with an object whose base class was never initialized.
            raise RuntimeError(
                f"Unexpected status {response.status_code} while verifying {service} credentials"
            )

        super().__init__(service, response.json()["id"], settings, db)

    def _on_note(self, outputs: list[cross.Output], note: dict):
        """Crosspost one of the user's own notes to every output.

        The note is skipped (with a log line) when it belongs to another user,
        is a renote or poll, replies to another user's note, has a visibility
        outside ``options.allowed_visibility``, is rejected by
        ``database.try_insert_post``, matches a regex filter, or has an
        attachment that fails to download.
        """
        if note['userId'] != self.user_id:
            return

        if note.get('renoteId') or note.get('poll'):
            # TODO polls not supported on bsky. maybe 3rd party? skip for now
            # we don't handle reblogs. possible with bridgy(?) and self
            LOGGER.info("Skipping '%s'! Renote or poll..", note['id'])
            return

        reply_id: str | None = note.get('replyId')
        # 'reply' can be present but null, hence `or {}` instead of a .get default.
        if reply_id and (note.get('reply') or {}).get('userId') != self.user_id:
            LOGGER.info("Skipping '%s'! Reply to other user..", note['id'])
            return

        if note.get('visibility') not in self.options.allowed_visibility:
            LOGGER.info("Skipping '%s'! '%s' visibility..", note['id'], note.get('visibility'))
            return

        # NOTE(review): the note is recorded before the filter and media checks
        # below, so notes skipped by those checks are still stored — confirm
        # this is intended.
        if not database.try_insert_post(self.db, note['id'], reply_id, self.user_id, self.service):
            LOGGER.info("Skipping '%s' as parent note was not found in db!", note['id'])
            return

        tags: list[str] = note.get('tags') or []
        # Only the handle strings (the dict values) are used; each is paired
        # with itself for the tokenizer.
        mention_handles: dict = note.get('mentionHandles') or {}
        handles: list[tuple[str, str]] = [(handle, handle) for handle in mention_handles.values()]

        # 'text' may be null (e.g. a file-only note); tokenize '' instead of None.
        tokens = cross.tokenize_markdown(note.get('text') or '', tags, handles)
        if not cross.test_filters(tokens, self.options.filters):
            LOGGER.info("Skipping '%s'. Matched a filter!", note['id'])
            return

        LOGGER.info("Crossposting '%s'...", note['id'])

        media_attachments: list[MediaInfo] = []
        for attachment in note.get('files') or []:
            LOGGER.info("Downloading %s...", attachment['url'])
            info = download_media(attachment['url'], attachment.get('comment') or '')
            if not info:
                LOGGER.error("Skipping '%s'. Failed to download media!", note['id'])
                return
            media_attachments.append(info)

        cross_post = MisskeyPost(note, tokens, media_attachments)
        for output in outputs:
            output.accept_post(cross_post)

    def _on_delete(self, outputs: list[cross.Output], note: dict):
        # TODO handle deletes
        pass

    def _on_message(self, outputs: list[cross.Output], data: dict):
        """Dispatch one decoded streaming frame.

        Only 'channel' frames whose body type is 'note' or 'reply' are
        handled; everything else is ignored.
        """
        if data.get('type') != 'channel':
            return

        body = data.get('body') or {}
        if body.get('type') in ('note', 'reply'):
            self._on_note(outputs, body['body'])

    async def _send_keepalive(self, ws: websockets.WebSocketClientProtocol):
        """Send an "h" text frame every 120 seconds while ``ws`` is open.

        Stops when the socket is closed or a send fails. Task cancellation
        propagates, since CancelledError is not an ``Exception`` subclass.
        """
        while ws.open:
            try:
                await asyncio.sleep(120)
                if not ws.open:
                    LOGGER.info("WebSocket is closed, stopping keepalive task.")
                    break
                await ws.send("h")
                LOGGER.debug("Sent keepalive h..")
            except Exception as e:
                LOGGER.error("Error sending keepalive: %s", e)
                break

    async def _subscribe_to_home(self, ws: websockets.WebSocketClientProtocol):
        """Ask the server to stream the 'homeTimeline' channel over ``ws``."""
        await ws.send(json.dumps({
            "type": "connect",
            "body": {
                "channel": "homeTimeline",
                "id": str(uuid.uuid4())
            }
        }))
        LOGGER.info("Subscribed to 'homeTimeline' channel...")

    async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]):
        """Stream the home timeline and hand every frame to ``submit``.

        Each frame is wrapped in a zero-argument callable that decodes it and
        calls ``_on_message``. Reconnection is driven by iterating
        ``websockets.connect``.
        """
        # Map http -> ws and https (or anything else) -> wss. Single quotes
        # inside the f-strings keep this valid before Python 3.12.
        scheme, _, host = self.service.partition('://')
        ws_scheme = 'ws' if scheme.lower() == 'http' else 'wss'
        streaming: str = f"{ws_scheme}://{host}"
        url: str = f"{streaming}/streaming?i={self.token}"

        async for ws in websockets.connect(url, extra_headers={"User-Agent": "XPost/0.0.3"}):
            keepalive = None
            try:
                LOGGER.info("Listening to %s...", streaming)
                await self._subscribe_to_home(ws)

                keepalive = asyncio.create_task(self._send_keepalive(ws))
                async for msg in ws:
                    # TODO listen to deletes somehow
                    # Bind `msg` now: submit() may run the callback after this
                    # loop has already moved on to the next frame.
                    submit(lambda raw=msg: self._on_message(outputs, json.loads(raw)))
            except websockets.ConnectionClosedError as e:
                LOGGER.error(e, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", streaming)
                continue
            finally:
                # Don't leave the keepalive sleeping on a dead socket (up to
                # 120 s) before reconnecting.
                if keepalive is not None:
                    keepalive.cancel()