import asyncio
import json
import re
from dataclasses import dataclass, field
from typing import Any, cast, override

import websockets

from cross.attachments import (
    LabelsAttachment,
    LanguagesAttachment,
    MediaAttachment,
    QuoteAttachment,
    RemoteUrlAttachment,
    SensitiveAttachment,
)
from cross.media import Blob, download_blob
from cross.post import Post
from cross.service import InputService
from database.connection import DatabasePool
from mastodon.info import MastodonService, validate_and_transform
from mastodon.parser import StatusParser

# Visibility values a user may list in ``allowed_visibility``; also the default.
ALLOWED_VISIBILITY: list[str] = ["public", "unlisted"]


@dataclass(kw_only=True)
class MastodonInputOptions:
    """Configuration for :class:`MastodonInputService`."""

    # Access token; sent as a Bearer token when opening the streaming socket.
    token: str
    # Instance identifier handed to the ``MastodonService`` base constructor.
    instance: str
    # Statuses whose ``visibility`` is not in this list are ignored.
    allowed_visibility: list[str] = field(
        default_factory=lambda: ALLOWED_VISIBILITY.copy()
    )
    # Compiled regex filters. NOTE(review): not consulted anywhere in this
    # file -- presumably applied by a base class or another module; confirm.
    filters: list[re.Pattern[str]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MastodonInputOptions":
        """Build options from a raw configuration mapping.

        ``data`` is passed through ``validate_and_transform`` and its
        ``"filters"`` entry (if any) is replaced in place by compiled
        patterns, so the caller's dict is modified.

        Raises:
            ValueError: if ``allowed_visibility`` contains a value that is
                not in ``ALLOWED_VISIBILITY``.
        """
        validate_and_transform(data)

        for vis in data.get("allowed_visibility", []):
            if vis not in ALLOWED_VISIBILITY:
                raise ValueError(f"Invalid visibility option {vis}!")

        if "filters" in data:
            data["filters"] = [re.compile(r) for r in data["filters"]]

        # ``cls`` rather than the hard-coded class so subclasses get their
        # own type back from this alternate constructor.
        return cls(**data)


class MastodonInputService(MastodonService, InputService):
    """Listen to a Mastodon user stream and forward the user's own activity.

    New statuses, reposts and deletions authored by the authenticated account
    are recorded through the inherited post-storage helpers and handed to
    every entry in ``self.outputs`` via ``self.submitter``.
    """

    def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
        """Verify the credentials and look up the instance's streaming URL."""
        super().__init__(options.instance, db)
        self.options: MastodonInputOptions = options

        self.log.info("Verifying %s credentials...", self.url)
        response = self.verify_credentials()
        self.user_id: str = response["id"]

        self.log.info("Getting %s configuration...", self.url)
        response = self.fetch_instance_info()
        self.streaming_url: str = response["urls"]["streaming_api"]

    @override
    def _get_token(self) -> str:
        """Return the access token from the configured options."""
        return self.options.token

    def _on_create_post(self, status: dict[str, Any]) -> None:
        """Handle an ``update`` event carrying a new status.

        The status is dropped (nothing stored, nothing forwarded) when:

        * it was not authored by ``self.user_id``;
        * its visibility is not in ``options.allowed_visibility``;
        * it contains a poll;
        * it quotes or replies to another account, or to a post that is not
          found through ``_get_post``;
        * any of its media attachments fails to download.

        Reblogs are delegated to :meth:`_on_reblog`. Everything else is
        parsed into a :class:`Post`, recorded with ``_insert_post`` and
        submitted to every output.
        """
        if status["account"]["id"] != self.user_id:
            return

        if status["visibility"] not in self.options.allowed_visibility:
            return

        reblog: dict[str, Any] | None = status.get("reblog")
        if reblog:
            # Only reposts of the user's own posts are mirrored.
            if reblog["account"]["id"] != self.user_id:
                return
            self._on_reblog(status, reblog)
            return

        if status.get("poll"):
            self.log.info("Skipping '%s'! Contains a poll..", status["id"])
            return

        quote: dict[str, Any] | None = status.get("quote")
        if quote:
            # The quoted status may be nested under "quoted_status" or be the
            # object itself; fall back to the object when the key is empty.
            quote = quote.get("quoted_status") or quote
            # A quote object without an embedded account cannot be attributed
            # to this user, so it is skipped instead of raising KeyError.
            quote_author = (quote.get("account") or {}).get("id")
            if quote_author != self.user_id:
                return

            rquote = self._get_post(self.url, self.user_id, quote["id"])
            if not rquote:
                self.log.info(
                    "Skipping %s, parent %s not found in db",
                    status["id"],
                    quote["id"],
                )
                return

        in_reply: str | None = status.get("in_reply_to_id")
        in_reply_to: str | None = status.get("in_reply_to_account_id")
        # Replies to other accounts are not mirrored.
        if in_reply_to and in_reply_to != self.user_id:
            return

        parent = None
        if in_reply:
            parent = self._get_post(self.url, self.user_id, in_reply)
            if not parent:
                self.log.info(
                    "Skipping %s, parent %s not found in db",
                    status["id"],
                    in_reply,
                )
                return

        parser = StatusParser(status)
        parser.feed(status["content"])
        tokens = parser.get_result()

        post = Post(id=status["id"], parent_id=in_reply, tokens=tokens)

        if quote:
            post.attachments.put(
                QuoteAttachment(quoted_id=quote["id"], quoted_user=self.user_id)
            )
        if status.get("url"):
            post.attachments.put(RemoteUrlAttachment(url=status["url"]))
        if status.get("sensitive"):
            post.attachments.put(SensitiveAttachment(sensitive=True))
        if status.get("language"):
            post.attachments.put(LanguagesAttachment(langs=[status["language"]]))

        # The Mastodon Status entity names the content warning "spoiler_text";
        # the legacy "spoiler" key is still honoured as a fallback.
        spoiler = status.get("spoiler_text") or status.get("spoiler")
        if spoiler:
            post.attachments.put(LabelsAttachment(labels=[spoiler]))

        blobs: list[Blob] = []
        for media in status.get("media_attachments", []):
            self.log.info("Downloading %s...", media["url"])
            blob: Blob | None = download_blob(media["url"], media.get("alt"))
            if not blob:
                # All-or-nothing: a post with missing media is not forwarded.
                self.log.error(
                    "Skipping %s! Failed to download media %s.",
                    status["id"],
                    media["url"],
                )
                return
            blobs.append(blob)

        if blobs:
            post.attachments.put(MediaAttachment(blobs=blobs))

        if parent:
            self._insert_post(
                {
                    "user": self.user_id,
                    "service": self.url,
                    "identifier": status["id"],
                    "parent": parent["id"],
                    # Inherit the thread root; a parent without one is the root.
                    "root": parent["root"] if parent["root"] else parent["id"],
                }
            )
        else:
            self._insert_post(
                {
                    "user": self.user_id,
                    "service": self.url,
                    "identifier": status["id"],
                }
            )

        for out in self.outputs:
            # Bind ``out`` now: the submitter may run the callable after the
            # loop variable has moved on to another output.
            self.submitter(lambda out=out: out.accept_post(post))

    def _on_reblog(self, status: dict[str, Any], reblog: dict[str, Any]) -> None:
        """Record a repost of one of the user's own posts and forward it.

        Skipped when the reposted status is not found through ``_get_post``.
        """
        reposted = self._get_post(self.url, self.user_id, reblog["id"])
        if not reposted:
            self.log.info(
                "Skipping repost '%s' as reposted post '%s' was not found in the db.",
                status["id"],
                reblog["id"],
            )
            return

        self._insert_post(
            {
                "user": self.user_id,
                "service": self.url,
                "identifier": status["id"],
                "reposted": reposted["id"],
            }
        )

        repost_id: str = status["id"]
        original_id: str = reblog["id"]
        for out in self.outputs:
            # Default argument pins the current output (late-binding closure).
            self.submitter(
                lambda out=out: out.accept_repost(repost_id, original_id)
            )

    def _on_delete_post(self, status_id: str) -> None:
        """Handle a ``delete`` event: notify outputs, then drop the record.

        Unknown status ids are ignored. A stored post with a truthy
        ``reposted_id`` is treated as a repost.
        """
        post = self._get_post(self.url, self.user_id, status_id)
        if not post:
            return

        is_repost = bool(post["reposted_id"])
        for output in self.outputs:
            # Default argument pins the current output (late-binding closure).
            if is_repost:
                self.submitter(
                    lambda output=output: output.delete_repost(status_id)
                )
            else:
                self.submitter(
                    lambda output=output: output.delete_post(status_id)
                )

        self._delete_post_by_id(post["id"])

    def _accept_msg(self, msg: websockets.Data) -> None:
        """Decode one streaming message and dispatch it by event type.

        ``update`` payloads are JSON-encoded statuses; ``delete`` payloads
        are passed on as the status id. Other events, and messages lacking
        an ``event`` or ``payload`` key, are ignored.
        """
        data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
        event: str | None = data.get("event")
        payload: str | None = data.get("payload")
        if payload is None:
            return

        if event == "update":
            self._on_create_post(json.loads(payload))
        elif event == "delete":
            self._on_delete_post(payload)

    @override
    async def listen(self) -> None:
        """Consume the user stream, reconnecting whenever the socket drops.

        Each received message is handed to ``self.submitter`` for processing
        by :meth:`_accept_msg`.
        """
        url = f"{self.streaming_url}/api/v1/streaming?stream=user"

        # Iterating ``connect`` yields a fresh connection after each drop.
        async for ws in websockets.connect(
            url, additional_headers={"Authorization": f"Bearer {self.options.token}"}
        ):
            try:
                self.log.info("Listening to %s...", self.streaming_url)
                async for msg in ws:
                    # Bind ``msg`` now: the submitter may run the callable
                    # after the next message has replaced the loop variable.
                    self.submitter(lambda msg=msg: self._accept_msg(msg))
            except websockets.ConnectionClosedError as e:
                self.log.error(e, stack_info=True, exc_info=True)
                self.log.info("Reconnecting to %s...", self.streaming_url)