social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from dataclasses import dataclass, field
5from typing import Any, cast, override
6
7import websockets
8
9from cross.service import InputService
10from database.connection import DatabasePool
11from mastodon.info import MastodonService, validate_and_transform
12from util.util import LOGGER
13
# Visibility values accepted for the ``allowed_visibility`` option; also the
# default value of that option (copied per instance).
ALLOWED_VISIBILITY: list[str] = ["public", "unlisted"]


@dataclass(kw_only=True)
class MastodonInputOptions:
    """Options for a Mastodon input service.

    All fields are keyword-only. Build instances from untrusted/config
    dictionaries with :meth:`from_dict`, which validates the values first.
    """

    # Access token for the account.
    token: str
    # Instance the account lives on (exact format is whatever
    # ``validate_and_transform`` leaves in place -- TODO confirm).
    instance: str
    # Visibility values selected for this input; defaults to a fresh copy of
    # ALLOWED_VISIBILITY so instances never share one list.
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    # Compiled regular expressions; how they are applied is not visible in
    # this block.
    filters: list[re.Pattern[str]] = field(default_factory=lambda: [])

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MastodonInputOptions":
        """Build options from a plain dictionary.

        Args:
            data: Mapping of option names to raw values. ``filters`` entries
                are regular expression sources (or already compiled
                patterns).

        Returns:
            A new instance of ``cls`` populated from ``data``.

        Raises:
            ValueError: If ``allowed_visibility`` contains a value that is
                not listed in ``ALLOWED_VISIBILITY``.
            re.error: If a ``filters`` entry is not a valid pattern.
            TypeError: If ``data`` holds keys that are not option fields.
        """
        # Project helper; its return value is unused, so it presumably
        # validates/normalises ``data`` in place -- TODO confirm.
        validate_and_transform(data)

        for vis in data.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        # Compile filters on a shallow copy so the caller's dictionary keeps
        # its original pattern strings.
        options = dict(data)
        if "filters" in options:
            options["filters"] = [re.compile(r) for r in options["filters"]]

        # ``cls`` rather than the hard-coded class name, so subclasses get an
        # instance of their own type.
        return cls(**options)
39
40
class MastodonInputService(MastodonService, InputService):
    """Input service that follows a Mastodon account's ``user`` stream.

    On construction the service verifies the configured credentials and reads
    the streaming endpoint from the instance information. :meth:`listen` then
    keeps a WebSocket open and hands every received frame to
    ``self.submitter`` for processing.
    """

    def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
        """Verify credentials and discover the streaming endpoint.

        Args:
            db: Database pool forwarded to the base service.
            options: Token, instance and filtering options for this input.
        """
        super().__init__(options.instance, db)
        self.options: MastodonInputOptions = options

        LOGGER.info("Verifying %s credentails...", self.url)
        response = self.verify_credentials()
        self.user_id: str = response["id"]

        LOGGER.info("Getting %s configuration...", self.url)
        response = self.fetch_instance_info()
        self.streaming_url: str = response["urls"]["streaming_api"]

    @override
    def _get_token(self) -> str:
        """Return the access token from the configured options."""
        return self.options.token

    def _on_create_post(self, status: dict[str, Any]) -> None:
        """Handle a newly created status (currently only logs it)."""
        LOGGER.info(status)  # TODO

    def _on_delete_post(self, status_id: str) -> None:
        """Handle a deleted status id (currently only logs it)."""
        LOGGER.info(status_id)  # TODO

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode one streaming frame and dispatch it by event type.

        ``update`` frames carry a JSON-encoded status in ``payload``;
        ``delete`` frames carry the status id. Every other event is ignored.

        Args:
            msg: Raw text or bytes frame received from the WebSocket.
        """
        data: Any = json.loads(msg)
        if not isinstance(data, dict):
            # Not an event object; nothing to dispatch.
            return

        # Read the payload only inside the branches that need it: some events
        # (e.g. ``filters_changed``) arrive without one, and an unconditional
        # lookup would raise KeyError for frames this service ignores anyway.
        event: Any = data.get("event")

        if event == "update":
            self._on_create_post(json.loads(cast(str, data["payload"])))
        elif event == "delete":
            self._on_delete_post(cast(str, data["payload"]))

    @override
    async def listen(self) -> None:
        """Stream user events, reconnecting whenever the connection drops.

        Iterating ``websockets.connect`` yields a fresh connection each time
        the loop body finishes, so this coroutine keeps running until it is
        cancelled or an unexpected exception escapes.
        """
        url = f"{self.streaming_url}/api/v1/streaming?stream=user"

        async for ws in websockets.connect(
            url, additional_headers={"Authorization": f"Bearer {self.options.token}"}
        ):
            try:
                LOGGER.info("Listening to %s...", self.streaming_url)

                async for msg in ws:
                    # Bind the current frame as a default argument: a plain
                    # closure would look ``msg`` up when the callback runs,
                    # by which time the loop may already hold a later frame.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                LOGGER.error(e, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", self.streaming_url)
                continue