social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky

Pool database connections; add ATProto identity resolution and a Bluesky Jetstream input.

zenfyr.dev d58e914e b74b024b

verified
Changed files
+187 -49
atproto
bluesky
cross
database
mastodon
misskey
+49 -14
atproto/identity.py
···
from typing import Any
+
import dns.resolver
import requests
+
+
import env
from util.cache import TTLCache
-
from util.util import LOGGER
+
from util.util import LOGGER, normalize_service_url
+
+
+
class DidDocument:
+
def __init__(self, raw_doc: dict[str, Any]) -> None:
+
self.raw: dict[str, Any] = raw_doc
+
self.atproto_pds: str | None = None
+
+
def get_atproto_pds(self) -> str | None:
+
if self.atproto_pds:
+
return self.atproto_pds
+
+
services = self.raw.get("service")
+
if not services:
+
return None
+
+
for service in services:
+
if (
+
service.get("id") == "#atproto_pds"
+
and service.get("type") == "AtprotoPersonalDataServer"
+
):
+
endpoint = service.get("serviceEndpoint")
+
if endpoint:
+
url = normalize_service_url(endpoint)
+
self.atproto_pds = url
+
return url
+
self.atproto_pds = ""
+
return None
+
class DidResolver:
    """Resolve ``did:plc:`` and ``did:web:`` identifiers to DID documents.

    Successfully resolved documents are kept in a cache with a 12-hour TTL.
    """

    def __init__(self, plc_host: str) -> None:
        self.plc_host: str = plc_host
        self.__cache: TTLCache[str, DidDocument] = TTLCache(ttl_seconds=12 * 60 * 60)

    @staticmethod
    def _fetch_document(url: str) -> DidDocument | None:
        """Fetch *url* and wrap a 200 response body as a ``DidDocument``.

        Returns ``None`` for 404/410 (not registered, tombstoned or gone) and
        for any other non-error status; raises ``requests.HTTPError`` for
        4xx/5xx responses other than 404/410.
        """
        response = requests.get(url, timeout=10, allow_redirects=True)
        if response.status_code == 200:
            return DidDocument(response.json())
        if response.status_code in (404, 410):
            return None  # tombstone, gone or not registered
        response.raise_for_status()
        # raise_for_status() is a no-op for non-error codes (e.g. 204);
        # such a response carries no usable document.
        return None

    def try_resolve_plc(self, did: str) -> DidDocument | None:
        """Look up *did* on the configured PLC host; ``None`` if it is unknown."""
        return self._fetch_document(f"{self.plc_host}/{did}")

    def try_resolve_web(self, did: str) -> DidDocument | None:
        """Fetch ``/.well-known/did.json`` from the host named by a ``did:web:`` DID.

        The document is requested over HTTPS: an identity document fetched
        over plain HTTP could be tampered with in transit.
        """
        host = did.removeprefix("did:web:")
        return self._fetch_document(f"https://{host}/.well-known/did.json")

    def resolve_did(self, did: str) -> DidDocument:
        """Return the DID document for *did*, consulting the cache first.

        Raises:
            Exception: If the DID method is unsupported or no document exists.
        """
        cached = self.__cache.get(did)
        if cached:
            return cached

        document: DidDocument | None = None
        if did.startswith("did:plc:"):
            document = self.try_resolve_plc(did)
        elif did.startswith("did:web:"):
            document = self.try_resolve_web(did)

        if document:
            self.__cache.set(did, document)
            return document
        raise Exception(f"Failed to resolve {did}!")
+
class HandleResolver:
def __init__(self) -> None:
···
for rdata in answers:
for txt_data in rdata.strings:
-
did = txt_data.decode('utf-8').strip()
+
did = txt_data.decode("utf-8").strip()
if did.startswith("did="):
return did[4:]
except dns.resolver.NXDOMAIN:
···
else:
response.raise_for_status()
-
def resolve_handle(self, handle: str) -> str:
cached = self.__cache.get(handle)
if cached:
···
return from_http
raise Exception(f"Failed to resolve handle {handle}!")
+
+
+
# Shared module-level resolver instances; the DID resolver is pointed at the
# PLC host configured in env.PLC_HOST.
handle_resolver = HandleResolver()
+
did_resolver = DidResolver(env.PLC_HOST)
+35 -3
bluesky/info.py
···
-
from abc import ABC
+
from abc import ABC, abstractmethod
from typing import Any
+
+
from atproto.identity import did_resolver, handle_resolver
from cross.service import Service
-
from util.util import normalize_service_url
+
from util.util import LOGGER, normalize_service_url
+
+
SERVICE = "https://bsky.app"
+
def validate_and_transform(data: dict[str, Any]):
if not data["handle"] and not data["did"]:
···
if "pds" in data:
data["pds"] = normalize_service_url(data["pds"])
+
class BlueskyService(ABC, Service):
    """Base for Bluesky services that need a resolved account DID and PDS URL.

    Subclasses supply the configured values through ``get_identity_options``
    and call ``_init_identity`` to populate ``did`` and ``pds``.
    """

    pds: str
    did: str

    def _init_identity(self) -> None:
        """Populate ``self.did`` and ``self.pds``, resolving whatever is missing.

        A missing DID is resolved from the handle; a missing PDS is read from
        the DID document. Values that were configured explicitly are used
        as-is.

        Raises:
            KeyError: If neither a DID nor a handle is configured.
            Exception: If the DID document declares no atproto PDS.
        """
        handle, did, pds = self.get_identity_options()

        if not did:
            if not handle:
                raise KeyError("No did: or atproto handle provided!")
            LOGGER.info("Resolving ATP identity for %s...", handle)
            did = handle_resolver.resolve_handle(handle)
        # Assign unconditionally so a configured DID is stored as well.
        self.did = did

        if not pds:
            LOGGER.info("Resolving PDS from %s DID document...", did)
            pds = did_resolver.resolve_did(did).get_atproto_pds()
            if not pds:
                raise Exception(f"Failed to resolve atproto pds for {did}")
        # Assign unconditionally so a configured PDS is stored as well.
        self.pds = pds

    @abstractmethod
    def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
        """Return the configured ``(handle, did, pds)``; any element may be ``None``."""
        pass
+56 -7
bluesky/input.py
···
+
import asyncio
+
import re
from abc import ABC
from dataclasses import dataclass, field
-
import re
from typing import Any, Callable, override
-
from bluesky.info import BlueskyService, validate_and_transform
+
import websockets
+
+
from bluesky.info import SERVICE, BlueskyService, validate_and_transform
from cross.service import InputService, OutputService
+
from database.connection import DatabasePool
+
from util.util import LOGGER, normalize_service_url
@dataclass(kw_only=True)
class BlueskyInputOptions:
-
handle: str | None
-
did: str | None
-
pds: str | None
+
handle: str | None = None
+
did: str | None = None
+
pds: str | None = None
filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
@classmethod
···
return BlueskyInputOptions(**data)
+
@dataclass(kw_only=True)
class BlueskyJetstreamInputOptions(BlueskyInputOptions):
    """Input options extended with the Jetstream websocket endpoint."""

    jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe"

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions":
        """Build the options from a settings mapping.

        The ``jetstream`` key is removed from *data* before the remaining
        keys are handed to ``BlueskyInputOptions.from_dict``; when it holds a
        truthy value it is normalized and overrides the default endpoint.
        """
        endpoint = data.pop("jetstream", None)

        # Shallow copy of the parsed base options' fields.
        fields = dict(vars(BlueskyInputOptions.from_dict(data)))
        if endpoint:
            fields["jetstream"] = normalize_service_url(endpoint)

        return BlueskyJetstreamInputOptions(**fields)
+
+
class BlueskyBaseInputService(BlueskyService, InputService, ABC):
    """Common base for Bluesky inputs; registers the service under the ``SERVICE`` URL."""

    def __init__(self, db: DatabasePool) -> None:
        # Every Bluesky input shares the same service URL constant.
        super().__init__(SERVICE, db)
class BlueskyJetstreamInputService(BlueskyBaseInputService):
    """Bluesky input that subscribes to a Jetstream websocket for one account.

    The subscription is limited to the post and repost collections of the
    resolved DID. Received messages are currently only logged.
    """

    def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None:
        super().__init__(db)
        self.options: BlueskyJetstreamInputOptions = options
        # Resolves self.did / self.pds from the options.
        self._init_identity()

    @override
    def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
        """Return the configured ``(handle, did, pds)`` triple."""
        opts = self.options
        return (opts.handle, opts.did, opts.pds)

    @override
    async def listen(
        self,
        outputs: list[OutputService],
        submitter: Callable[[Callable[[], None]], None],
    ):
        """Connect to Jetstream and log every message, reconnecting on error.

        ``outputs`` and ``submitter`` are not used yet (message handling is
        still a TODO).
        """
        endpoint = "".join(
            (
                self.options.jetstream + "?",
                "wantedCollections=app.bsky.feed.post",
                "&wantedCollections=app.bsky.feed.repost",
                f"&wantedDids={self.did}",
            )
        )

        async def pump(socket):
            # Drain the socket until the connection ends.
            async for frame in socket:
                LOGGER.info(frame)  # TODO

        # Iterating over connect() yields a fresh connection after each drop.
        async for connection in websockets.connect(endpoint):
            try:
                LOGGER.info("Listening to %s...", self.options.jetstream)

                reader = asyncio.create_task(pump(connection))

                _ = await asyncio.gather(reader)
            except websockets.ConnectionClosedError as err:
                LOGGER.error(err, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", self.options.jetstream)
                continue
+6 -9
cross/service.py
···
import sqlite3
from abc import ABC, abstractmethod
-
from pathlib import Path
from typing import Callable, cast
from cross.post import Post
-
from database.connection import get_conn
+
from database.connection import DatabasePool
from util.util import LOGGER
class Service:
-
def __init__(self, url: str, db: Path) -> None:
+
def __init__(self, url: str, db: DatabasePool) -> None:
self.url: str = url
-
self.conn: sqlite3.Connection = get_conn(db)
+
self.db: DatabasePool = db
+
#self._lock: threading.Lock = threading.Lock()
def get_post(self, url: str, user: str, identifier: str) -> sqlite3.Row | None:
-
cursor = self.conn.cursor()
+
cursor = self.db.get_conn().cursor()
_ = cursor.execute(
"""
SELECT * FROM posts
···
return cast(sqlite3.Row, cursor.fetchone())
def get_post_by_id(self, id: int) -> sqlite3.Row | None:
-
cursor = self.conn.cursor()
+
cursor = self.db.get_conn().cursor()
_ = cursor.execute("SELECT * FROM posts WHERE id = ?", (id,))
return cast(sqlite3.Row, cursor.fetchone())
-
-
def close(self):
-
self.conn.close()
class OutputService(Service):
+17
database/connection.py
···
import sqlite3
+
import threading
from pathlib import Path
+
+
class DatabasePool:
    """Hands out one SQLite connection per thread and closes them all together.

    Connections are created on first use in each thread via the module-level
    ``get_conn`` factory and tracked so ``close`` can release every one.
    """

    def __init__(self, db: Path) -> None:
        self.db: Path = db
        self._local: threading.local = threading.local()
        self._conns: list[sqlite3.Connection] = []
        # Guards _conns and _generation, which several threads touch.
        self._lock: threading.Lock = threading.Lock()
        # Bumped by close() so handles cached in other threads are seen as stale.
        self._generation: int = 0

    def get_conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening one if needed.

        A connection cached before the last ``close()`` is discarded and
        replaced, so a closed handle is never returned.
        """
        conn = getattr(self._local, "conn", None)
        if conn is None or getattr(self._local, "generation", None) != self._generation:
            conn = get_conn(self.db)
            with self._lock:
                self._conns.append(conn)
                self._local.generation = self._generation
            self._local.conn = conn
        return conn

    def close(self):
        """Close every connection handed out so far; safe to call repeatedly."""
        with self._lock:
            conns, self._conns = self._conns, []
            self._generation += 1
        for c in conns:
            c.close()
def get_conn(db: Path) -> sqlite3.Connection:
conn = sqlite3.connect(db, autocommit=True, check_same_thread=False)
+1 -1
env.py
···
DEV = bool(os.environ.get("DEV")) or False
DATA_DIR = os.environ.get("DATA_DIR") or "./data"
MIGRATIONS_DIR = os.environ.get("MIGRATIONS_DIR") or "./migrations"
-
PLC_HOST = os.environ.get("PLC_HOST") or "http://plc.directory"
+
PLC_HOST = os.environ.get("PLC_HOST") or "https://plc.wtf"
+7 -3
main.py
···
from pathlib import Path
from typing import Callable
+
from database.connection import DatabasePool
import env
from database.migrations import DatabaseMigrator
from registry import create_input_service, create_output_service
···
finally:
migrator.close()
+
db_pool = DatabasePool(database_path)
+
LOGGER.info("Bootstrapping registries...")
bootstrap()
···
if "outputs" not in settings:
raise KeyError("No `outputs` spicified in settings!")
-
input = create_input_service(database_path, settings["input"])
+
input = create_input_service(db_pool, settings["input"])
outputs = [
-
create_output_service(database_path, data) for data in settings["outputs"]
+
create_output_service(db_pool, data) for data in settings["outputs"]
]
LOGGER.info("Starting task worker...")
···
thread = threading.Thread(target=worker, args=(task_queue,), daemon=True)
thread.start()
-
LOGGER.info("Connecting to %s...", "TODO") # TODO
+
LOGGER.info("Connecting to %s...", input.url)
try:
asyncio.run(input.listen(outputs, lambda c: task_queue.put(c)))
except KeyboardInterrupt:
···
task_queue.join()
task_queue.put(None)
thread.join()
+
db_pool.close()
if __name__ == "__main__":
+2 -2
mastodon/input.py
···
import asyncio
import re
from dataclasses import dataclass, field
-
from pathlib import Path
from typing import Any, Callable, override
import websockets
from cross.service import InputService, OutputService
+
from database.connection import DatabasePool
from mastodon.info import MastodonService, validate_and_transform
from util.util import LOGGER
···
class MastodonInputService(MastodonService, InputService):
-
def __init__(self, db: Path, options: MastodonInputOptions) -> None:
+
def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
super().__init__(options.instance, db)
self.options: MastodonInputOptions = options
+2 -2
mastodon/output.py
···
from dataclasses import dataclass
-
from pathlib import Path
from typing import Any, override
from cross.service import OutputService
+
from database.connection import DatabasePool
from mastodon.info import InstanceInfo, MastodonService, validate_and_transform
from util.util import LOGGER
···
# TODO
class MastodonOutputService(MastodonService, OutputService):
-
def __init__(self, db: Path, options: MastodonOutputOptions) -> None:
+
def __init__(self, db: DatabasePool, options: MastodonOutputOptions) -> None:
super().__init__(options.instance, db)
self.options: MastodonOutputOptions = options
+2 -2
misskey/input.py
···
import re
import uuid
from dataclasses import dataclass, field
-
from pathlib import Path
from typing import Any, Callable, override
import websockets
from cross.service import InputService, OutputService
+
from database.connection import DatabasePool
from misskey.info import MisskeyService
from util.util import LOGGER, normalize_service_url
···
class MisskeyInputService(MisskeyService, InputService):
-
def __init__(self, db: Path, options: MisskeyInputOptions) -> None:
+
def __init__(self, db: DatabasePool, options: MisskeyInputOptions) -> None:
super().__init__(options.instance, db)
self.options: MisskeyInputOptions = options
+5 -4
registry.py
···
from typing import Any, Callable
from cross.service import InputService, OutputService
+
from database.connection import DatabasePool
-
input_factories: dict[str, Callable[[Path, dict[str, Any]], InputService]] = {}
-
output_factories: dict[str, Callable[[Path, dict[str, Any]], OutputService]] = {}
+
input_factories: dict[str, Callable[[DatabasePool, dict[str, Any]], InputService]] = {}
+
output_factories: dict[str, Callable[[DatabasePool, dict[str, Any]], OutputService]] = {}
-
def create_input_service(db: Path, data: dict[str, Any]) -> InputService:
+
def create_input_service(db: DatabasePool, data: dict[str, Any]) -> InputService:
if "type" not in data:
raise ValueError("No `type` field in input data!")
type: str = str(data["type"])
···
return factory(db, data)
-
def create_output_service(db: Path, data: dict[str, Any]) -> OutputService:
+
def create_output_service(db: DatabasePool, data: dict[str, Any]) -> OutputService:
if "type" not in data:
raise ValueError("No `type` field in input data!")
type: str = str(data["type"])
+5 -2
registry_bootstrap.py
···
-
from pathlib import Path
from typing import Any
+
from database.connection import DatabasePool
from registry import input_factories, output_factories
···
self.class_name: str = class_name
self.options_class_name: str = options_class_name
-
def __call__(self, db: Path, d: dict[str, Any]):
+
def __call__(self, db: DatabasePool, d: dict[str, Any]):
module = __import__(
self.module_path, fromlist=[self.class_name, self.options_class_name]
)
···
input_factories["misskey-wss"] = LazyFactory(
"misskey.input", "MisskeyInputService", "MisskeyInputOptions"
)
+
input_factories["bluesky-jetstream"] = LazyFactory(
+
"bluesky.input", "BlueskyJetstreamInputService", "BlueskyJetstreamInputOptions"
+
)