social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky
1import asyncio 2import json 3import re 4from abc import ABC 5from dataclasses import dataclass, field 6from typing import Any, cast, override 7 8import websockets 9 10from atproto.util import AtUri 11from bluesky.info import SERVICE, BlueskyService, validate_and_transform 12from cross.attachments import ( 13 LabelsAttachment, 14 LanguagesAttachment, 15 MediaAttachment, 16 QuoteAttachment, 17 RemoteUrlAttachment, 18) 19from cross.media import Blob, download_blob 20from cross.post import Post 21from cross.service import InputService 22from database.connection import DatabasePool 23from util.util import normalize_service_url 24 25 26@dataclass(kw_only=True) 27class BlueskyInputOptions: 28 handle: str | None = None 29 did: str | None = None 30 pds: str | None = None 31 filters: list[re.Pattern[str]] = field(default_factory=lambda: []) 32 33 @classmethod 34 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions": 35 validate_and_transform(data) 36 37 if "filters" in data: 38 data["filters"] = [re.compile(r) for r in data["filters"]] 39 40 return BlueskyInputOptions(**data) 41 42 43@dataclass(kw_only=True) 44class BlueskyJetstreamInputOptions(BlueskyInputOptions): 45 jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe" 46 47 @classmethod 48 def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions": 49 jetstream = data.pop("jetstream", None) 50 51 base = BlueskyInputOptions.from_dict(data).__dict__.copy() 52 if jetstream: 53 base["jetstream"] = normalize_service_url(jetstream) 54 55 return BlueskyJetstreamInputOptions(**base) 56 57 58class BlueskyBaseInputService(BlueskyService, InputService, ABC): 59 def __init__(self, db: DatabasePool) -> None: 60 super().__init__(SERVICE, db) 61 62 def _on_post(self, record: dict[str, Any]): 63 post_uri = cast(str, record["$xpost.strongRef"]["uri"]) 64 post_cid = cast(str, record["$xpost.strongRef"]["cid"]) 65 66 parent_uri = cast( 67 str, None if not record.get("reply") else record["reply"]["parent"]["uri"] 68 ) 69 parent = None 70 if parent_uri: 71 parent = self._get_post(self.url, self.did, parent_uri) 72 if not parent: 73 self.log.info( 74 "Skipping %s, parent %s not found in db", post_uri, parent_uri 75 ) 76 return 77 78 # TODO FRAGMENTS 79 post = Post(id=post_uri, parent_id=parent_uri, text=record["text"]) 80 did, _, rid = AtUri.record_uri(post_uri) 81 post.attachments.put( 82 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}") 83 ) 84 85 embed: dict[str, Any] = record.get("embed", {}) 86 blob_urls: list[tuple[str, str, str | None]] = [] 87 def handle_embeds(embed: dict[str, Any]): 88 nonlocal blob_urls, post 89 match cast(str, embed["$type"]): 90 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia": 91 rcrd = embed['record']['record'] if embed['record'].get('record') else embed['record'] 92 did, collection, _ = AtUri.record_uri(rcrd["uri"]) 93 if collection != "app.bsky.feed.post": 94 return 95 if did != self.did: 96 return 97 post.attachments.put(QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did)) 98 99 if embed.get('media'): 100 handle_embeds(embed["media"]) 101 case "app.bsky.embed.images": 102 for image in embed["images"]: 103 blob_cid = image["image"]["ref"]["$link"] 104 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" 105 blob_urls.append((url, blob_cid, image.get("alt"))) 106 case "app.bsky.embed.video": 107 blob_cid = embed["video"]["ref"]["$link"] 108 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" 109 blob_urls.append((url, blob_cid, embed.get("alt"))) 110 case _: 111 self.log.warning(f"Unhandled embedd type {embed['$type']}") 112 pass 113 if embed: 114 handle_embeds(embed) 115 116 if blob_urls: 117 blobs: list[Blob] = [] 118 for url, cid, alt in blob_urls: 119 self.log.info("Downloading %s...", cid) 120 blob: Blob | None = download_blob(url, alt) 121 if not blob: 122 self.log.error( 123 "Skipping %s! Failed to download blob %s.", post_uri, cid 124 ) 125 return 126 blobs.append(blob) 127 post.attachments.put(MediaAttachment(blobs=blobs)) 128 129 if "langs" in record: 130 post.attachments.put(LanguagesAttachment(langs=record["langs"])) 131 if "labels" in record: 132 post.attachments.put( 133 LabelsAttachment( 134 labels=[ 135 label["val"].replace("-", " ") for label in record["values"] 136 ] 137 ), 138 ) 139 140 if parent: 141 self._insert_post( 142 { 143 "user": self.did, 144 "service": self.url, 145 "identifier": post_uri, 146 "parent": parent["id"], 147 "root": parent["id"] if not parent["root"] else parent["root"], 148 "extra_data": json.dumps({"cid": post_cid}), 149 } 150 ) 151 else: 152 self._insert_post( 153 { 154 "user": self.did, 155 "service": self.url, 156 "identifier": post_uri, 157 "extra_data": json.dumps({"cid": post_cid}), 158 } 159 ) 160 161 for out in self.outputs: 162 self.submitter(lambda: out.accept_post(post)) 163 164 def _on_repost(self, record: dict[str, Any]): 165 post_uri = cast(str, record["$xpost.strongRef"]["uri"]) 166 post_cid = cast(str, record["$xpost.strongRef"]["cid"]) 167 168 reposted_uri = cast(str, record["subject"]["uri"]) 169 reposted = self._get_post(self.url, self.did, reposted_uri) 170 if not reposted: 171 self.log.info( 172 "Skipping repost '%s' as reposted post '%s' was not found in the db.", 173 post_uri, 174 reposted_uri, 175 ) 176 return 177 178 self._insert_post( 179 { 180 "user": self.did, 181 "service": self.url, 182 "identifier": post_uri, 183 "reposted": reposted["id"], 184 "extra_data": json.dumps({"cid": post_cid}), 185 } 186 ) 187 188 for out in self.outputs: 189 self.submitter(lambda: out.accept_repost(post_uri, reposted_uri)) 190 191 def _on_delete_post(self, post_id: str, repost: bool): 192 post = self._get_post(self.url, self.did, post_id) 193 if not post: 194 return 195 196 if repost: 197 for output in self.outputs: 198 self.submitter(lambda: output.delete_repost(post_id)) 199 else: 200 for output in self.outputs: 201 self.submitter(lambda: output.delete_post(post_id)) 202 self._delete_post_by_id(post["id"]) 203 204 205class BlueskyJetstreamInputService(BlueskyBaseInputService): 206 def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None: 207 super().__init__(db) 208 self.options: BlueskyJetstreamInputOptions = options 209 self._init_identity() 210 211 @override 212 def get_identity_options(self) -> tuple[str | None, str | None, str | None]: 213 return (self.options.handle, self.options.did, self.options.pds) 214 215 def _accept_msg(self, msg: websockets.Data) -> None: 216 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg)) 217 if data.get("did") != self.did: 218 return 219 commit: dict[str, Any] | None = data.get("commit") 220 if not commit: 221 return 222 223 commit_type: str = cast(str, commit["operation"]) 224 match commit_type: 225 case "create": 226 record: dict[str, Any] = cast(dict[str, Any], commit["record"]) 227 record["$xpost.strongRef"] = { 228 "cid": commit["cid"], 229 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}", 230 } 231 232 match cast(str, commit["collection"]): 233 case "app.bsky.feed.post": 234 self._on_post(record) 235 case "app.bsky.feed.repost": 236 self._on_repost(record) 237 case _: 238 pass 239 case "delete": 240 post_id: str = ( 241 f"at://{self.did}/{commit['collection']}/{commit['rkey']}" 242 ) 243 match cast(str, commit["collection"]): 244 case "app.bsky.feed.post": 245 self._on_delete_post(post_id, False) 246 case "app.bsky.feed.repost": 247 self._on_delete_post(post_id, True) 248 case _: 249 pass 250 case _: 251 pass 252 253 @override 254 async def listen(self): 255 url = self.options.jetstream + "?" 256 url += "wantedCollections=app.bsky.feed.post" 257 url += "&wantedCollections=app.bsky.feed.repost" 258 url += f"&wantedDids={self.did}" 259 260 async for ws in websockets.connect(url): 261 try: 262 self.log.info("Listening to %s...", self.options.jetstream) 263 264 async def listen_for_messages(): 265 async for msg in ws: 266 self.submitter(lambda: self._accept_msg(msg)) 267 268 listen = asyncio.create_task(listen_for_messages()) 269 270 _ = await asyncio.gather(listen) 271 except websockets.ConnectionClosedError as e: 272 self.log.error(e, stack_info=True, exc_info=True) 273 self.log.info("Reconnecting to %s...", self.options.jetstream) 274 continue