social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from abc import ABC
5from dataclasses import dataclass, field
6from typing import Any, cast, override
7
8import websockets
9
10from atproto.util import AtUri
11from bluesky.info import SERVICE, BlueskyService, validate_and_transform
12from cross.attachments import (
13 LabelsAttachment,
14 LanguagesAttachment,
15 MediaAttachment,
16 QuoteAttachment,
17 RemoteUrlAttachment,
18)
19from cross.media import Blob, download_blob
20from cross.post import Post
21from cross.service import InputService
22from database.connection import DatabasePool
23from util.util import normalize_service_url
24
25
26@dataclass(kw_only=True)
27class BlueskyInputOptions:
28 handle: str | None = None
29 did: str | None = None
30 pds: str | None = None
31 filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
32
33 @classmethod
34 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions":
35 validate_and_transform(data)
36
37 if "filters" in data:
38 data["filters"] = [re.compile(r) for r in data["filters"]]
39
40 return BlueskyInputOptions(**data)
41
42
@dataclass(kw_only=True)
class BlueskyJetstreamInputOptions(BlueskyInputOptions):
    """Bluesky input options extended with the Jetstream endpoint to use."""

    # Endpoint used when the configuration does not name one.
    jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe"

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions":
        """Build Jetstream options from a raw mapping.

        The ``"jetstream"`` key is removed from ``data`` before the rest is
        parsed by ``BlueskyInputOptions.from_dict``. A non-empty value is
        passed through ``normalize_service_url``; otherwise the class default
        endpoint is kept.
        """
        endpoint = data.pop("jetstream", None)

        # Copy the parsed base fields so the extra key can be added safely.
        parsed = BlueskyInputOptions.from_dict(data)
        kwargs = dict(vars(parsed))
        if endpoint:
            kwargs["jetstream"] = normalize_service_url(endpoint)

        return BlueskyJetstreamInputOptions(**kwargs)
56
57
58class BlueskyBaseInputService(BlueskyService, InputService, ABC):
59 def __init__(self, db: DatabasePool) -> None:
60 super().__init__(SERVICE, db)
61
62 def _on_post(self, record: dict[str, Any]):
63 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
64 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
65
66 parent_uri = cast(
67 str, None if not record.get("reply") else record["reply"]["parent"]["uri"]
68 )
69 parent = None
70 if parent_uri:
71 parent = self._get_post(self.url, self.did, parent_uri)
72 if not parent:
73 self.log.info(
74 "Skipping %s, parent %s not found in db", post_uri, parent_uri
75 )
76 return
77
78 # TODO FRAGMENTS
79 post = Post(id=post_uri, parent_id=parent_uri, text=record["text"])
80 did, _, rid = AtUri.record_uri(post_uri)
81 post.attachments.put(
82 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}")
83 )
84
85 embed: dict[str, Any] = record.get("embed", {})
86 blob_urls: list[tuple[str, str, str | None]] = []
87 def handle_embeds(embed: dict[str, Any]):
88 nonlocal blob_urls, post
89 match cast(str, embed["$type"]):
90 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia":
91 rcrd = embed['record']['record'] if embed['record'].get('record') else embed['record']
92 did, collection, _ = AtUri.record_uri(rcrd["uri"])
93 if collection != "app.bsky.feed.post":
94 return
95 if did != self.did:
96 return
97 post.attachments.put(QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did))
98
99 if embed.get('media'):
100 handle_embeds(embed["media"])
101 case "app.bsky.embed.images":
102 for image in embed["images"]:
103 blob_cid = image["image"]["ref"]["$link"]
104 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
105 blob_urls.append((url, blob_cid, image.get("alt")))
106 case "app.bsky.embed.video":
107 blob_cid = embed["video"]["ref"]["$link"]
108 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
109 blob_urls.append((url, blob_cid, embed.get("alt")))
110 case _:
111 self.log.warning(f"Unhandled embedd type {embed['$type']}")
112 pass
113 if embed:
114 handle_embeds(embed)
115
116 if blob_urls:
117 blobs: list[Blob] = []
118 for url, cid, alt in blob_urls:
119 self.log.info("Downloading %s...", cid)
120 blob: Blob | None = download_blob(url, alt)
121 if not blob:
122 self.log.error(
123 "Skipping %s! Failed to download blob %s.", post_uri, cid
124 )
125 return
126 blobs.append(blob)
127 post.attachments.put(MediaAttachment(blobs=blobs))
128
129 if "langs" in record:
130 post.attachments.put(LanguagesAttachment(langs=record["langs"]))
131 if "labels" in record:
132 post.attachments.put(
133 LabelsAttachment(
134 labels=[
135 label["val"].replace("-", " ") for label in record["values"]
136 ]
137 ),
138 )
139
140 if parent:
141 self._insert_post(
142 {
143 "user": self.did,
144 "service": self.url,
145 "identifier": post_uri,
146 "parent": parent["id"],
147 "root": parent["id"] if not parent["root"] else parent["root"],
148 "extra_data": json.dumps({"cid": post_cid}),
149 }
150 )
151 else:
152 self._insert_post(
153 {
154 "user": self.did,
155 "service": self.url,
156 "identifier": post_uri,
157 "extra_data": json.dumps({"cid": post_cid}),
158 }
159 )
160
161 for out in self.outputs:
162 self.submitter(lambda: out.accept_post(post))
163
164 def _on_repost(self, record: dict[str, Any]):
165 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
166 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
167
168 reposted_uri = cast(str, record["subject"]["uri"])
169 reposted = self._get_post(self.url, self.did, reposted_uri)
170 if not reposted:
171 self.log.info(
172 "Skipping repost '%s' as reposted post '%s' was not found in the db.",
173 post_uri,
174 reposted_uri,
175 )
176 return
177
178 self._insert_post(
179 {
180 "user": self.did,
181 "service": self.url,
182 "identifier": post_uri,
183 "reposted": reposted["id"],
184 "extra_data": json.dumps({"cid": post_cid}),
185 }
186 )
187
188 for out in self.outputs:
189 self.submitter(lambda: out.accept_repost(post_uri, reposted_uri))
190
191 def _on_delete_post(self, post_id: str, repost: bool):
192 post = self._get_post(self.url, self.did, post_id)
193 if not post:
194 return
195
196 if repost:
197 for output in self.outputs:
198 self.submitter(lambda: output.delete_repost(post_id))
199 else:
200 for output in self.outputs:
201 self.submitter(lambda: output.delete_post(post_id))
202 self._delete_post_by_id(post["id"])
203
204
205class BlueskyJetstreamInputService(BlueskyBaseInputService):
206 def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None:
207 super().__init__(db)
208 self.options: BlueskyJetstreamInputOptions = options
209 self._init_identity()
210
211 @override
212 def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
213 return (self.options.handle, self.options.did, self.options.pds)
214
215 def _accept_msg(self, msg: websockets.Data) -> None:
216 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
217 if data.get("did") != self.did:
218 return
219 commit: dict[str, Any] | None = data.get("commit")
220 if not commit:
221 return
222
223 commit_type: str = cast(str, commit["operation"])
224 match commit_type:
225 case "create":
226 record: dict[str, Any] = cast(dict[str, Any], commit["record"])
227 record["$xpost.strongRef"] = {
228 "cid": commit["cid"],
229 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}",
230 }
231
232 match cast(str, commit["collection"]):
233 case "app.bsky.feed.post":
234 self._on_post(record)
235 case "app.bsky.feed.repost":
236 self._on_repost(record)
237 case _:
238 pass
239 case "delete":
240 post_id: str = (
241 f"at://{self.did}/{commit['collection']}/{commit['rkey']}"
242 )
243 match cast(str, commit["collection"]):
244 case "app.bsky.feed.post":
245 self._on_delete_post(post_id, False)
246 case "app.bsky.feed.repost":
247 self._on_delete_post(post_id, True)
248 case _:
249 pass
250 case _:
251 pass
252
253 @override
254 async def listen(self):
255 url = self.options.jetstream + "?"
256 url += "wantedCollections=app.bsky.feed.post"
257 url += "&wantedCollections=app.bsky.feed.repost"
258 url += f"&wantedDids={self.did}"
259
260 async for ws in websockets.connect(url):
261 try:
262 self.log.info("Listening to %s...", self.options.jetstream)
263
264 async def listen_for_messages():
265 async for msg in ws:
266 self.submitter(lambda: self._accept_msg(msg))
267
268 listen = asyncio.create_task(listen_for_messages())
269
270 _ = await asyncio.gather(listen)
271 except websockets.ConnectionClosedError as e:
272 self.log.error(e, stack_info=True, exc_info=True)
273 self.log.info("Reconnecting to %s...", self.options.jetstream)
274 continue