social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from abc import ABC
5from dataclasses import dataclass, field
6from typing import Any, cast, override
7
8import websockets
9
10from atproto.util import AtUri
11from bluesky.info import SERVICE, BlueskyService, validate_and_transform
12from cross.attachments import (
13 LabelsAttachment,
14 LanguagesAttachment,
15 MediaAttachment,
16 RemoteUrlAttachment,
17)
18from cross.media import Blob, download_blob
19from cross.post import Post
20from cross.service import InputService
21from database.connection import DatabasePool
22from util.util import normalize_service_url
23
24
25@dataclass(kw_only=True)
26class BlueskyInputOptions:
27 handle: str | None = None
28 did: str | None = None
29 pds: str | None = None
30 filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
31
32 @classmethod
33 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions":
34 validate_and_transform(data)
35
36 if "filters" in data:
37 data["filters"] = [re.compile(r) for r in data["filters"]]
38
39 return BlueskyInputOptions(**data)
40
41
@dataclass(kw_only=True)
class BlueskyJetstreamInputOptions(BlueskyInputOptions):
    """Bluesky input options extended with the Jetstream endpoint to use."""

    # WebSocket endpoint used when the configuration does not provide one.
    jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe"

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions":
        """Build Jetstream options from a plain configuration mapping.

        The ``"jetstream"`` entry is removed from ``data`` before the
        remaining entries are parsed by ``BlueskyInputOptions.from_dict``.
        A non-empty endpoint is passed through ``normalize_service_url``;
        otherwise the class default is kept.

        Args:
            data: Configuration mapping; it is modified in place.

        Returns:
            The populated ``BlueskyJetstreamInputOptions``.
        """
        endpoint = data.pop("jetstream", None)

        kwargs = dict(vars(BlueskyInputOptions.from_dict(data)))
        if endpoint:
            kwargs["jetstream"] = normalize_service_url(endpoint)

        return BlueskyJetstreamInputOptions(**kwargs)
55
56
57class BlueskyBaseInputService(BlueskyService, InputService, ABC):
58 def __init__(self, db: DatabasePool) -> None:
59 super().__init__(SERVICE, db)
60
61 def _on_post(self, record: dict[str, Any]):
62 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
63 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
64
65 parent_uri = cast(
66 str, None if not record.get("reply") else record["reply"]["parent"]["uri"]
67 )
68 parent = None
69 if parent_uri:
70 parent = self._get_post(self.url, self.did, parent_uri)
71 if not parent:
72 self.log.info(
73 "Skipping %s, parent %s not found in db", post_uri, parent_uri
74 )
75 return
76
77 # TODO FRAGMENTS
78 post = Post(id=post_uri, parent_id=parent_uri, text=record["text"])
79 did, _, rid = AtUri.record_uri(post_uri)
80 post.attachments.put(
81 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}")
82 )
83
84 embed = record.get("embed", {})
85 if embed:
86 match cast(str, embed["$type"]):
87 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia":
88 _, collection, _ = AtUri.record_uri(
89 cast(str, embed["record"]["uri"])
90 )
91 if collection == "app.bsky.feed.post":
92 self.log.info("Skipping '%s'! Quote..", post_uri)
93 return
94 case "app.bsky.embed.images":
95 blobs: list[Blob] = []
96 for image in embed["images"]:
97 blob_cid = image["image"]["ref"]["$link"]
98 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
99 self.log.info("Downloading %s...", blob_cid)
100 blob: Blob | None = download_blob(url, image.get("alt"))
101 if not blob:
102 self.log.error(
103 "Skipping %s! Failed to download blob %s.",
104 post_uri,
105 blob_cid,
106 )
107 return
108 blobs.append(blob)
109 post.attachments.put(MediaAttachment(blobs=blobs))
110 case "app.bsky.embed.video":
111 blob_cid = embed["video"]["ref"]["$link"]
112 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
113 self.log.info("Downloading %s...", blob_cid)
114 blob: Blob | None = download_blob(url, embed.get("alt"))
115 if not blob:
116 self.log.error(
117 "Skipping %s! Failed to download blob %s.",
118 post_uri,
119 blob_cid,
120 )
121 return
122 post.attachments.put(MediaAttachment(blobs=[blob]))
123 case _:
124 self.log.warning(f"Unhandled embedd type {embed['$type']}")
125 pass
126
127 if "langs" in record:
128 post.attachments.put(LanguagesAttachment(langs=record["langs"]))
129 if "labels" in record:
130 post.attachments.put(
131 LabelsAttachment(
132 labels=[
133 label["val"].replace("-", " ") for label in record["values"]
134 ]
135 ),
136 )
137
138 if parent:
139 self._insert_post(
140 {
141 "user": self.did,
142 "service": self.url,
143 "identifier": post_uri,
144 "parent": parent["id"],
145 "root": parent["id"] if not parent["root"] else parent["root"],
146 "extra_data": json.dumps({"cid": post_cid}),
147 }
148 )
149 else:
150 self._insert_post(
151 {
152 "user": self.did,
153 "service": self.url,
154 "identifier": post_uri,
155 "extra_data": json.dumps({"cid": post_cid}),
156 }
157 )
158
159 for out in self.outputs:
160 self.submitter(lambda: out.accept_post(post))
161
162 def _on_repost(self, record: dict[str, Any]):
163 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
164 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
165
166 reposted_uri = cast(str, record["subject"]["uri"])
167 reposted = self._get_post(self.url, self.did, reposted_uri)
168 if not reposted:
169 self.log.info(
170 "Skipping repost '%s' as reposted post '%s' was not found in the db."
171 )
172 return
173
174 self._insert_post(
175 {
176 "user": self.did,
177 "service": self.url,
178 "identifier": post_uri,
179 "reposted": reposted["id"],
180 "extra_data": json.dumps({"cid": post_cid}),
181 }
182 )
183
184 for out in self.outputs:
185 self.submitter(lambda: out.accept_repost(post_uri, reposted_uri))
186
187 def _on_delete_post(self, post_id: str, repost: bool):
188 post = self._get_post(self.url, self.did, post_id)
189 if not post:
190 return
191
192 if repost:
193 for output in self.outputs:
194 self.submitter(lambda: output.delete_repost(post_id))
195 else:
196 for output in self.outputs:
197 self.submitter(lambda: output.delete_post(post_id))
198 self._delete_post_by_id(post["id"])
199
200
201class BlueskyJetstreamInputService(BlueskyBaseInputService):
202 def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None:
203 super().__init__(db)
204 self.options: BlueskyJetstreamInputOptions = options
205 self._init_identity()
206
207 @override
208 def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
209 return (self.options.handle, self.options.did, self.options.pds)
210
211 def _accept_msg(self, msg: websockets.Data) -> None:
212 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
213 if data.get("did") != self.did:
214 return
215 commit: dict[str, Any] | None = data.get("commit")
216 if not commit:
217 return
218
219 commit_type: str = cast(str, commit["operation"])
220 match commit_type:
221 case "create":
222 record: dict[str, Any] = cast(dict[str, Any], commit["record"])
223 record["$xpost.strongRef"] = {
224 "cid": commit["cid"],
225 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}",
226 }
227
228 match cast(str, commit["collection"]):
229 case "app.bsky.feed.post":
230 self._on_post(record)
231 case "app.bsky.feed.repost":
232 self._on_repost(record)
233 case _:
234 pass
235 case "delete":
236 post_id: str = (
237 f"at://{self.did}/{commit['collection']}/{commit['rkey']}"
238 )
239 match cast(str, commit["collection"]):
240 case "app.bsky.feed.post":
241 self._on_delete_post(post_id, False)
242 case "app.bsky.feed.repost":
243 self._on_delete_post(post_id, True)
244 case _:
245 pass
246 case _:
247 pass
248
249 @override
250 async def listen(self):
251 url = self.options.jetstream + "?"
252 url += "wantedCollections=app.bsky.feed.post"
253 url += "&wantedCollections=app.bsky.feed.repost"
254 url += f"&wantedDids={self.did}"
255
256 async for ws in websockets.connect(url):
257 try:
258 self.log.info("Listening to %s...", self.options.jetstream)
259
260 async def listen_for_messages():
261 async for msg in ws:
262 self.submitter(lambda: self._accept_msg(msg))
263
264 listen = asyncio.create_task(listen_for_messages())
265
266 _ = await asyncio.gather(listen)
267 except websockets.ConnectionClosedError as e:
268 self.log.error(e, stack_info=True, exc_info=True)
269 self.log.info("Reconnecting to %s...", self.options.jetstream)
270 continue