Social media crossposting tool. Third time's the charm.
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from abc import ABC
5from dataclasses import dataclass, field
6from typing import Any, cast, override
7
8import websockets
9
10from atproto.util import AtUri
11from bluesky.tokens import tokenize_post
12from bluesky.info import SERVICE, BlueskyService, validate_and_transform
13from cross.attachments import (
14 LabelsAttachment,
15 LanguagesAttachment,
16 MediaAttachment,
17 QuoteAttachment,
18 RemoteUrlAttachment,
19)
20from cross.media import Blob, download_blob
21from cross.post import Post
22from cross.service import InputService
23from database.connection import DatabasePool
24from util.util import normalize_service_url
25
26
27@dataclass(kw_only=True)
28class BlueskyInputOptions:
29 handle: str | None = None
30 did: str | None = None
31 pds: str | None = None
32 filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
33
34 @classmethod
35 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions":
36 validate_and_transform(data)
37
38 if "filters" in data:
39 data["filters"] = [re.compile(r) for r in data["filters"]]
40
41 return BlueskyInputOptions(**data)
42
43
@dataclass(kw_only=True)
class BlueskyJetstreamInputOptions(BlueskyInputOptions):
    """Bluesky input options plus the Jetstream websocket endpoint to use."""

    jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe"

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions":
        """Build Jetstream options from a raw configuration mapping.

        The ``"jetstream"`` key is removed from ``data`` (in place) before the
        remaining keys are parsed by ``BlueskyInputOptions.from_dict``; a
        non-empty endpoint is normalized with ``normalize_service_url``,
        otherwise the field default is kept.
        """
        # Take the Jetstream-only key out first: the base options class has
        # no field for it.
        endpoint = data.pop("jetstream", None)

        kwargs = dict(vars(BlueskyInputOptions.from_dict(data)))
        if endpoint:
            kwargs["jetstream"] = normalize_service_url(endpoint)

        return BlueskyJetstreamInputOptions(**kwargs)
57
58
59class BlueskyBaseInputService(BlueskyService, InputService, ABC):
60 def __init__(self, db: DatabasePool) -> None:
61 super().__init__(SERVICE, db)
62
63 def _on_post(self, record: dict[str, Any]):
64 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
65 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
66
67 parent_uri = cast(
68 str, None if not record.get("reply") else record["reply"]["parent"]["uri"]
69 )
70 parent = None
71 if parent_uri:
72 parent = self._get_post(self.url, self.did, parent_uri)
73 if not parent:
74 self.log.info(
75 "Skipping %s, parent %s not found in db", post_uri, parent_uri
76 )
77 return
78
79 tokens = tokenize_post(record["text"], record.get('facets', {}))
80 post = Post(id=post_uri, parent_id=parent_uri, tokens=tokens)
81
82 did, _, rid = AtUri.record_uri(post_uri)
83 post.attachments.put(
84 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}")
85 )
86
87 embed: dict[str, Any] = record.get("embed", {})
88 blob_urls: list[tuple[str, str, str | None]] = []
89 def handle_embeds(embed: dict[str, Any]) -> str | None:
90 nonlocal blob_urls, post
91 match cast(str, embed["$type"]):
92 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia":
93 rcrd = embed['record']['record'] if embed['record'].get('record') else embed['record']
94 did, collection, _ = AtUri.record_uri(rcrd["uri"])
95 if collection != "app.bsky.feed.post":
96 return f"Unhandled record collection {collection}"
97 if did != self.did:
98 return ""
99
100 rquote = self._get_post(self.url, did, rcrd["uri"])
101 if not rquote:
102 return f"Quote {rcrd["uri"]} not found in the db"
103 post.attachments.put(QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did))
104
105 if embed.get('media'):
106 return handle_embeds(embed["media"])
107 case "app.bsky.embed.images":
108 for image in embed["images"]:
109 blob_cid = image["image"]["ref"]["$link"]
110 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
111 blob_urls.append((url, blob_cid, image.get("alt")))
112 case "app.bsky.embed.video":
113 blob_cid = embed["video"]["ref"]["$link"]
114 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
115 blob_urls.append((url, blob_cid, embed.get("alt")))
116 case _:
117 self.log.warning(f"Unhandled embed type {embed['$type']}")
118
119 if embed:
120 fexit = handle_embeds(embed)
121 if fexit is not None:
122 self.log.info("Skipping %s! %s", post_uri, fexit)
123 return
124
125 if blob_urls:
126 blobs: list[Blob] = []
127 for url, cid, alt in blob_urls:
128 self.log.info("Downloading %s...", cid)
129 blob: Blob | None = download_blob(url, alt)
130 if not blob:
131 self.log.error(
132 "Skipping %s! Failed to download blob %s.", post_uri, cid
133 )
134 return
135 blobs.append(blob)
136 post.attachments.put(MediaAttachment(blobs=blobs))
137
138 if "langs" in record:
139 post.attachments.put(LanguagesAttachment(langs=record["langs"]))
140 if "labels" in record:
141 post.attachments.put(
142 LabelsAttachment(
143 labels=[
144 label["val"].replace("-", " ") for label in record["values"]
145 ]
146 ),
147 )
148
149 if parent:
150 self._insert_post(
151 {
152 "user": self.did,
153 "service": self.url,
154 "identifier": post_uri,
155 "parent": parent["id"],
156 "root": parent["id"] if not parent["root"] else parent["root"],
157 "extra_data": json.dumps({"cid": post_cid}),
158 }
159 )
160 else:
161 self._insert_post(
162 {
163 "user": self.did,
164 "service": self.url,
165 "identifier": post_uri,
166 "extra_data": json.dumps({"cid": post_cid}),
167 }
168 )
169
170 for out in self.outputs:
171 self.submitter(lambda: out.accept_post(post))
172
173 def _on_repost(self, record: dict[str, Any]):
174 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
175 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
176
177 reposted_uri = cast(str, record["subject"]["uri"])
178 reposted = self._get_post(self.url, self.did, reposted_uri)
179 if not reposted:
180 self.log.info(
181 "Skipping repost '%s' as reposted post '%s' was not found in the db.",
182 post_uri,
183 reposted_uri,
184 )
185 return
186
187 self._insert_post(
188 {
189 "user": self.did,
190 "service": self.url,
191 "identifier": post_uri,
192 "reposted": reposted["id"],
193 "extra_data": json.dumps({"cid": post_cid}),
194 }
195 )
196
197 for out in self.outputs:
198 self.submitter(lambda: out.accept_repost(post_uri, reposted_uri))
199
200 def _on_delete_post(self, post_id: str, repost: bool):
201 post = self._get_post(self.url, self.did, post_id)
202 if not post:
203 return
204
205 if repost:
206 for output in self.outputs:
207 self.submitter(lambda: output.delete_repost(post_id))
208 else:
209 for output in self.outputs:
210 self.submitter(lambda: output.delete_post(post_id))
211 self._delete_post_by_id(post["id"])
212
213
214class BlueskyJetstreamInputService(BlueskyBaseInputService):
215 def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None:
216 super().__init__(db)
217 self.options: BlueskyJetstreamInputOptions = options
218 self._init_identity()
219
220 @override
221 def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
222 return (self.options.handle, self.options.did, self.options.pds)
223
224 def _accept_msg(self, msg: websockets.Data) -> None:
225 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
226 if data.get("did") != self.did:
227 return
228 commit: dict[str, Any] | None = data.get("commit")
229 if not commit:
230 return
231
232 commit_type: str = cast(str, commit["operation"])
233 match commit_type:
234 case "create":
235 record: dict[str, Any] = cast(dict[str, Any], commit["record"])
236 record["$xpost.strongRef"] = {
237 "cid": commit["cid"],
238 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}",
239 }
240
241 match cast(str, commit["collection"]):
242 case "app.bsky.feed.post":
243 self._on_post(record)
244 case "app.bsky.feed.repost":
245 self._on_repost(record)
246 case _:
247 pass
248 case "delete":
249 post_id: str = (
250 f"at://{self.did}/{commit['collection']}/{commit['rkey']}"
251 )
252 match cast(str, commit["collection"]):
253 case "app.bsky.feed.post":
254 self._on_delete_post(post_id, False)
255 case "app.bsky.feed.repost":
256 self._on_delete_post(post_id, True)
257 case _:
258 pass
259 case _:
260 pass
261
262 @override
263 async def listen(self):
264 url = self.options.jetstream + "?"
265 url += "wantedCollections=app.bsky.feed.post"
266 url += "&wantedCollections=app.bsky.feed.repost"
267 url += f"&wantedDids={self.did}"
268
269 async for ws in websockets.connect(url):
270 try:
271 self.log.info("Listening to %s...", self.options.jetstream)
272
273 async def listen_for_messages():
274 async for msg in ws:
275 self.submitter(lambda: self._accept_msg(msg))
276
277 listen = asyncio.create_task(listen_for_messages())
278
279 _ = await asyncio.gather(listen)
280 except websockets.ConnectionClosedError as e:
281 self.log.error(e, stack_info=True, exc_info=True)
282 self.log.info("Reconnecting to %s...", self.options.jetstream)
283 continue