social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from abc import ABC
5from dataclasses import dataclass, field
6from typing import Any, cast, override
7
8import websockets
9
10from atproto.util import AtUri
11from bluesky.info import SERVICE, BlueskyService, validate_and_transform
12from cross.attachments import (
13 LabelsAttachment,
14 LanguagesAttachment,
15 MediaAttachment,
16 QuoteAttachment,
17 RemoteUrlAttachment,
18)
19from cross.media import Blob, download_blob
20from cross.post import Post
21from cross.service import InputService
22from database.connection import DatabasePool
23from util.util import normalize_service_url
24
25
26@dataclass(kw_only=True)
27class BlueskyInputOptions:
28 handle: str | None = None
29 did: str | None = None
30 pds: str | None = None
31 filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
32
33 @classmethod
34 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions":
35 validate_and_transform(data)
36
37 if "filters" in data:
38 data["filters"] = [re.compile(r) for r in data["filters"]]
39
40 return BlueskyInputOptions(**data)
41
42
@dataclass(kw_only=True)
class BlueskyJetstreamInputOptions(BlueskyInputOptions):
    """Bluesky input options extended with the Jetstream endpoint to use."""

    jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe"

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions":
        """Build Jetstream options from a raw mapping.

        The mapping is shallow-copied first, so ``"jetstream"`` is not popped
        out of the caller's dict.

        Args:
            data: Raw option values. A truthy ``"jetstream"`` entry is passed
                through ``normalize_service_url``; when it is missing or
                falsy the class default endpoint is kept. All remaining keys
                are handled by ``BlueskyInputOptions.from_dict``.

        Returns:
            A new instance of ``cls``.
        """
        options = dict(data)
        jetstream = options.pop("jetstream", None)

        # Reuse the base parsing, then lift its fields into this subclass.
        fields = vars(BlueskyInputOptions.from_dict(options)).copy()
        if jetstream:
            fields["jetstream"] = normalize_service_url(jetstream)

        return cls(**fields)
56
57
58class BlueskyBaseInputService(BlueskyService, InputService, ABC):
59 def __init__(self, db: DatabasePool) -> None:
60 super().__init__(SERVICE, db)
61
62 def _on_post(self, record: dict[str, Any]):
63 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
64 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
65
66 parent_uri = cast(
67 str, None if not record.get("reply") else record["reply"]["parent"]["uri"]
68 )
69 parent = None
70 if parent_uri:
71 parent = self._get_post(self.url, self.did, parent_uri)
72 if not parent:
73 self.log.info(
74 "Skipping %s, parent %s not found in db", post_uri, parent_uri
75 )
76 return
77
78 # TODO FRAGMENTS
79 post = Post(id=post_uri, parent_id=parent_uri, text=record["text"])
80 did, _, rid = AtUri.record_uri(post_uri)
81 post.attachments.put(
82 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}")
83 )
84
85 embed: dict[str, Any] = record.get("embed", {})
86 blob_urls: list[tuple[str, str, str | None]] = []
87 def handle_embeds(embed: dict[str, Any]) -> str | None:
88 nonlocal blob_urls, post
89 match cast(str, embed["$type"]):
90 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia":
91 rcrd = embed['record']['record'] if embed['record'].get('record') else embed['record']
92 did, collection, _ = AtUri.record_uri(rcrd["uri"])
93 if collection != "app.bsky.feed.post":
94 return f"Unhandled record collection {collection}"
95 if did != self.did:
96 return ""
97
98 rquote = self._get_post(self.url, did, rcrd["uri"])
99 if not rquote:
100 return f"Quote {rcrd["uri"]} not found in the db"
101 post.attachments.put(QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did))
102
103 if embed.get('media'):
104 return handle_embeds(embed["media"])
105 case "app.bsky.embed.images":
106 for image in embed["images"]:
107 blob_cid = image["image"]["ref"]["$link"]
108 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
109 blob_urls.append((url, blob_cid, image.get("alt")))
110 case "app.bsky.embed.video":
111 blob_cid = embed["video"]["ref"]["$link"]
112 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
113 blob_urls.append((url, blob_cid, embed.get("alt")))
114 case _:
115 self.log.warning(f"Unhandled embed type {embed['$type']}")
116
117 if embed:
118 fexit = handle_embeds(embed)
119 if fexit is not None:
120 self.log.info("Skipping %s! %s", post_uri, fexit)
121 return
122
123 if blob_urls:
124 blobs: list[Blob] = []
125 for url, cid, alt in blob_urls:
126 self.log.info("Downloading %s...", cid)
127 blob: Blob | None = download_blob(url, alt)
128 if not blob:
129 self.log.error(
130 "Skipping %s! Failed to download blob %s.", post_uri, cid
131 )
132 return
133 blobs.append(blob)
134 post.attachments.put(MediaAttachment(blobs=blobs))
135
136 if "langs" in record:
137 post.attachments.put(LanguagesAttachment(langs=record["langs"]))
138 if "labels" in record:
139 post.attachments.put(
140 LabelsAttachment(
141 labels=[
142 label["val"].replace("-", " ") for label in record["values"]
143 ]
144 ),
145 )
146
147 if parent:
148 self._insert_post(
149 {
150 "user": self.did,
151 "service": self.url,
152 "identifier": post_uri,
153 "parent": parent["id"],
154 "root": parent["id"] if not parent["root"] else parent["root"],
155 "extra_data": json.dumps({"cid": post_cid}),
156 }
157 )
158 else:
159 self._insert_post(
160 {
161 "user": self.did,
162 "service": self.url,
163 "identifier": post_uri,
164 "extra_data": json.dumps({"cid": post_cid}),
165 }
166 )
167
168 for out in self.outputs:
169 self.submitter(lambda: out.accept_post(post))
170
171 def _on_repost(self, record: dict[str, Any]):
172 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
173 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
174
175 reposted_uri = cast(str, record["subject"]["uri"])
176 reposted = self._get_post(self.url, self.did, reposted_uri)
177 if not reposted:
178 self.log.info(
179 "Skipping repost '%s' as reposted post '%s' was not found in the db.",
180 post_uri,
181 reposted_uri,
182 )
183 return
184
185 self._insert_post(
186 {
187 "user": self.did,
188 "service": self.url,
189 "identifier": post_uri,
190 "reposted": reposted["id"],
191 "extra_data": json.dumps({"cid": post_cid}),
192 }
193 )
194
195 for out in self.outputs:
196 self.submitter(lambda: out.accept_repost(post_uri, reposted_uri))
197
198 def _on_delete_post(self, post_id: str, repost: bool):
199 post = self._get_post(self.url, self.did, post_id)
200 if not post:
201 return
202
203 if repost:
204 for output in self.outputs:
205 self.submitter(lambda: output.delete_repost(post_id))
206 else:
207 for output in self.outputs:
208 self.submitter(lambda: output.delete_post(post_id))
209 self._delete_post_by_id(post["id"])
210
211
212class BlueskyJetstreamInputService(BlueskyBaseInputService):
213 def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None:
214 super().__init__(db)
215 self.options: BlueskyJetstreamInputOptions = options
216 self._init_identity()
217
218 @override
219 def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
220 return (self.options.handle, self.options.did, self.options.pds)
221
222 def _accept_msg(self, msg: websockets.Data) -> None:
223 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
224 if data.get("did") != self.did:
225 return
226 commit: dict[str, Any] | None = data.get("commit")
227 if not commit:
228 return
229
230 commit_type: str = cast(str, commit["operation"])
231 match commit_type:
232 case "create":
233 record: dict[str, Any] = cast(dict[str, Any], commit["record"])
234 record["$xpost.strongRef"] = {
235 "cid": commit["cid"],
236 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}",
237 }
238
239 match cast(str, commit["collection"]):
240 case "app.bsky.feed.post":
241 self._on_post(record)
242 case "app.bsky.feed.repost":
243 self._on_repost(record)
244 case _:
245 pass
246 case "delete":
247 post_id: str = (
248 f"at://{self.did}/{commit['collection']}/{commit['rkey']}"
249 )
250 match cast(str, commit["collection"]):
251 case "app.bsky.feed.post":
252 self._on_delete_post(post_id, False)
253 case "app.bsky.feed.repost":
254 self._on_delete_post(post_id, True)
255 case _:
256 pass
257 case _:
258 pass
259
260 @override
261 async def listen(self):
262 url = self.options.jetstream + "?"
263 url += "wantedCollections=app.bsky.feed.post"
264 url += "&wantedCollections=app.bsky.feed.repost"
265 url += f"&wantedDids={self.did}"
266
267 async for ws in websockets.connect(url):
268 try:
269 self.log.info("Listening to %s...", self.options.jetstream)
270
271 async def listen_for_messages():
272 async for msg in ws:
273 self.submitter(lambda: self._accept_msg(msg))
274
275 listen = asyncio.create_task(listen_for_messages())
276
277 _ = await asyncio.gather(listen)
278 except websockets.ConnectionClosedError as e:
279 self.log.error(e, stack_info=True, exc_info=True)
280 self.log.info("Reconnecting to %s...", self.options.jetstream)
281 continue