social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky
1import asyncio 2import json 3import re 4from abc import ABC 5from dataclasses import dataclass, field 6from typing import Any, cast, override 7 8import websockets 9 10from atproto.util import AtUri 11from bluesky.info import SERVICE, BlueskyService, validate_and_transform 12from cross.attachments import ( 13 LabelsAttachment, 14 LanguagesAttachment, 15 MediaAttachment, 16 QuoteAttachment, 17 RemoteUrlAttachment, 18) 19from cross.media import Blob, download_blob 20from cross.post import Post 21from cross.service import InputService 22from database.connection import DatabasePool 23from util.util import normalize_service_url 24 25 26@dataclass(kw_only=True) 27class BlueskyInputOptions: 28 handle: str | None = None 29 did: str | None = None 30 pds: str | None = None 31 filters: list[re.Pattern[str]] = field(default_factory=lambda: []) 32 33 @classmethod 34 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions": 35 validate_and_transform(data) 36 37 if "filters" in data: 38 data["filters"] = [re.compile(r) for r in data["filters"]] 39 40 return BlueskyInputOptions(**data) 41 42 43@dataclass(kw_only=True) 44class BlueskyJetstreamInputOptions(BlueskyInputOptions): 45 jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe" 46 47 @classmethod 48 def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions": 49 jetstream = data.pop("jetstream", None) 50 51 base = BlueskyInputOptions.from_dict(data).__dict__.copy() 52 if jetstream: 53 base["jetstream"] = normalize_service_url(jetstream) 54 55 return BlueskyJetstreamInputOptions(**base) 56 57 58class BlueskyBaseInputService(BlueskyService, InputService, ABC): 59 def __init__(self, db: DatabasePool) -> None: 60 super().__init__(SERVICE, db) 61 62 def _on_post(self, record: dict[str, Any]): 63 post_uri = cast(str, record["$xpost.strongRef"]["uri"]) 64 post_cid = cast(str, record["$xpost.strongRef"]["cid"]) 65 66 parent_uri = cast( 67 str, None if not record.get("reply") else record["reply"]["parent"]["uri"] 68 ) 69 parent = None 70 if parent_uri: 71 parent = self._get_post(self.url, self.did, parent_uri) 72 if not parent: 73 self.log.info( 74 "Skipping %s, parent %s not found in db", post_uri, parent_uri 75 ) 76 return 77 78 # TODO FRAGMENTS 79 post = Post(id=post_uri, parent_id=parent_uri, text=record["text"]) 80 did, _, rid = AtUri.record_uri(post_uri) 81 post.attachments.put( 82 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}") 83 ) 84 85 embed: dict[str, Any] = record.get("embed", {}) 86 blob_urls: list[tuple[str, str, str | None]] = [] 87 def handle_embeds(embed: dict[str, Any]) -> str | None: 88 nonlocal blob_urls, post 89 match cast(str, embed["$type"]): 90 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia": 91 rcrd = embed['record']['record'] if embed['record'].get('record') else embed['record'] 92 did, collection, _ = AtUri.record_uri(rcrd["uri"]) 93 if collection != "app.bsky.feed.post": 94 return f"Unhandled record collection {collection}" 95 if did != self.did: 96 return "" 97 98 rquote = self._get_post(self.url, did, rcrd["uri"]) 99 if not rquote: 100 return f"Quote {rcrd["uri"]} not found in the db" 101 post.attachments.put(QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did)) 102 103 if embed.get('media'): 104 return handle_embeds(embed["media"]) 105 case "app.bsky.embed.images": 106 for image in embed["images"]: 107 blob_cid = image["image"]["ref"]["$link"] 108 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" 109 blob_urls.append((url, blob_cid, image.get("alt"))) 110 case "app.bsky.embed.video": 111 blob_cid = embed["video"]["ref"]["$link"] 112 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" 113 blob_urls.append((url, blob_cid, embed.get("alt"))) 114 case _: 115 self.log.warning(f"Unhandled embed type {embed['$type']}") 116 117 if embed: 118 fexit = handle_embeds(embed) 119 if fexit is not None: 120 self.log.info("Skipping %s! %s", post_uri, fexit) 121 return 122 123 if blob_urls: 124 blobs: list[Blob] = [] 125 for url, cid, alt in blob_urls: 126 self.log.info("Downloading %s...", cid) 127 blob: Blob | None = download_blob(url, alt) 128 if not blob: 129 self.log.error( 130 "Skipping %s! Failed to download blob %s.", post_uri, cid 131 ) 132 return 133 blobs.append(blob) 134 post.attachments.put(MediaAttachment(blobs=blobs)) 135 136 if "langs" in record: 137 post.attachments.put(LanguagesAttachment(langs=record["langs"])) 138 if "labels" in record: 139 post.attachments.put( 140 LabelsAttachment( 141 labels=[ 142 label["val"].replace("-", " ") for label in record["values"] 143 ] 144 ), 145 ) 146 147 if parent: 148 self._insert_post( 149 { 150 "user": self.did, 151 "service": self.url, 152 "identifier": post_uri, 153 "parent": parent["id"], 154 "root": parent["id"] if not parent["root"] else parent["root"], 155 "extra_data": json.dumps({"cid": post_cid}), 156 } 157 ) 158 else: 159 self._insert_post( 160 { 161 "user": self.did, 162 "service": self.url, 163 "identifier": post_uri, 164 "extra_data": json.dumps({"cid": post_cid}), 165 } 166 ) 167 168 for out in self.outputs: 169 self.submitter(lambda: out.accept_post(post)) 170 171 def _on_repost(self, record: dict[str, Any]): 172 post_uri = cast(str, record["$xpost.strongRef"]["uri"]) 173 post_cid = cast(str, record["$xpost.strongRef"]["cid"]) 174 175 reposted_uri = cast(str, record["subject"]["uri"]) 176 reposted = self._get_post(self.url, self.did, reposted_uri) 177 if not reposted: 178 self.log.info( 179 "Skipping repost '%s' as reposted post '%s' was not found in the db.", 180 post_uri, 181 reposted_uri, 182 ) 183 return 184 185 self._insert_post( 186 { 187 "user": self.did, 188 "service": self.url, 189 "identifier": post_uri, 190 "reposted": reposted["id"], 191 "extra_data": json.dumps({"cid": post_cid}), 192 } 193 ) 194 195 for out in self.outputs: 196 self.submitter(lambda: out.accept_repost(post_uri, reposted_uri)) 197 198 def _on_delete_post(self, post_id: str, repost: bool): 199 post = self._get_post(self.url, self.did, post_id) 200 if not post: 201 return 202 203 if repost: 204 for output in self.outputs: 205 self.submitter(lambda: output.delete_repost(post_id)) 206 else: 207 for output in self.outputs: 208 self.submitter(lambda: output.delete_post(post_id)) 209 self._delete_post_by_id(post["id"]) 210 211 212class BlueskyJetstreamInputService(BlueskyBaseInputService): 213 def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None: 214 super().__init__(db) 215 self.options: BlueskyJetstreamInputOptions = options 216 self._init_identity() 217 218 @override 219 def get_identity_options(self) -> tuple[str | None, str | None, str | None]: 220 return (self.options.handle, self.options.did, self.options.pds) 221 222 def _accept_msg(self, msg: websockets.Data) -> None: 223 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg)) 224 if data.get("did") != self.did: 225 return 226 commit: dict[str, Any] | None = data.get("commit") 227 if not commit: 228 return 229 230 commit_type: str = cast(str, commit["operation"]) 231 match commit_type: 232 case "create": 233 record: dict[str, Any] = cast(dict[str, Any], commit["record"]) 234 record["$xpost.strongRef"] = { 235 "cid": commit["cid"], 236 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}", 237 } 238 239 match cast(str, commit["collection"]): 240 case "app.bsky.feed.post": 241 self._on_post(record) 242 case "app.bsky.feed.repost": 243 self._on_repost(record) 244 case _: 245 pass 246 case "delete": 247 post_id: str = ( 248 f"at://{self.did}/{commit['collection']}/{commit['rkey']}" 249 ) 250 match cast(str, commit["collection"]): 251 case "app.bsky.feed.post": 252 self._on_delete_post(post_id, False) 253 case "app.bsky.feed.repost": 254 self._on_delete_post(post_id, True) 255 case _: 256 pass 257 case _: 258 pass 259 260 @override 261 async def listen(self): 262 url = self.options.jetstream + "?" 263 url += "wantedCollections=app.bsky.feed.post" 264 url += "&wantedCollections=app.bsky.feed.repost" 265 url += f"&wantedDids={self.did}" 266 267 async for ws in websockets.connect(url): 268 try: 269 self.log.info("Listening to %s...", self.options.jetstream) 270 271 async def listen_for_messages(): 272 async for msg in ws: 273 self.submitter(lambda: self._accept_msg(msg)) 274 275 listen = asyncio.create_task(listen_for_messages()) 276 277 _ = await asyncio.gather(listen) 278 except websockets.ConnectionClosedError as e: 279 self.log.error(e, stack_info=True, exc_info=True) 280 self.log.info("Reconnecting to %s...", self.options.jetstream) 281 continue