social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky

run a formatter

zenfyr.dev dd2bea73 c0211a8f

verified
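
Note: the commit doesn't name the tool, but the changes follow a Black-compatible style — double quotes, magic trailing commas, ~88-column wrapping, and imports sorted into stdlib / third-party / first-party groups — so something like ruff format plus import sorting (or black + isort) is assumed, not confirmed. A minimal before/after sketch of those rules, using resolve_identity from bluesky/atproto2.py below (the tool choice is a guess; the output style matches the diff):

    # Hypothetical sketch: what a Black-compatible formatter does to the old code.
    # Before (as removed in the diff below):
    #
    #     def resolve_identity(
    #         handle: str | None = None,
    #         did: str | None = None,
    #         pds: str | None = None):
    #         return did, pds[:-1] if pds.endswith('/') else pds
    #
    # After: double quotes, parameters collapsed because they fit the line
    # limit, and the closing paren dedented onto its own line.
    def resolve_identity(
        handle: str | None = None, did: str | None = None, pds: str | None = None
    ):
        return did, pds[:-1] if pds.endswith("/") else pds

The same rules explain nearly every hunk below: quote normalization, trailing commas added to multi-line calls, long calls exploded, short ones collapsed, and whitespace-only blank lines rewritten (shown as paired bare `-`/`+` lines).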
+82 -72
bluesky/atproto2.py
···
 from typing import Any
-from atproto import client_utils, Client, AtUri, IdResolver
+
+from atproto import AtUri, Client, IdResolver, client_utils
 from atproto_client import models
+
 from util.util import LOGGER
+
 def resolve_identity(
-    handle: str | None = None,
-    did: str | None = None,
-    pds: str | None = None):
+    handle: str | None = None, did: str | None = None, pds: str | None = None
+):
     """helper to try and resolve identity from provided parameters, a valid handle is enough"""
-
+
     if did and pds:
-        return did, pds[:-1] if pds.endswith('/') else pds
-
+        return did, pds[:-1] if pds.endswith("/") else pds
+
     resolver = IdResolver()
     if not did:
         if not handle:
···
         did = resolver.handle.resolve(handle)
         if not did:
             raise Exception("Failed to resolve DID!")
-
+
     if not pds:
         LOGGER.info("Resolving PDS from DID document...")
         did_doc = resolver.did.resolve(did)
···
         pds = did_doc.get_pds_endpoint()
         if not pds:
             raise Exception("Failed to resolve PDS!")
-
-    return did, pds[:-1] if pds.endswith('/') else pds
+
+    return did, pds[:-1] if pds.endswith("/") else pds
+
 class Client2(Client):
     def __init__(self, base_url: str | None = None, *args: Any, **kwargs: Any) -> None:
         super().__init__(base_url, *args, **kwargs)
-
+
     def send_video(
-        self,
-        text: str | client_utils.TextBuilder,
+        self,
+        text: str | client_utils.TextBuilder,
         video: bytes,
         video_alt: str | None = None,
         video_aspect_ratio: models.AppBskyEmbedDefs.AspectRatio | None = None,
···
         langs: list[str] | None = None,
         facets: list[models.AppBskyRichtextFacet.Main] | None = None,
         labels: models.ComAtprotoLabelDefs.SelfLabels | None = None,
-        time_iso: str | None = None
-    ) -> models.AppBskyFeedPost.CreateRecordResponse:
+        time_iso: str | None = None,
+    ) -> models.AppBskyFeedPost.CreateRecordResponse:
         """same as send_video, but with labels"""
-
+
         if video_alt is None:
-            video_alt = ''
+            video_alt = ""
         upload = self.upload_blob(video)
-
+
         return self.send_post(
             text,
             reply_to=reply_to,
-            embed=models.AppBskyEmbedVideo.Main(video=upload.blob, alt=video_alt, aspect_ratio=video_aspect_ratio),
+            embed=models.AppBskyEmbedVideo.Main(
+                video=upload.blob, alt=video_alt, aspect_ratio=video_aspect_ratio
+            ),
             langs=langs,
             facets=facets,
             labels=labels,
-            time_iso=time_iso
+            time_iso=time_iso,
         )
-
+
     def send_images(
-        self,
-        text: str | client_utils.TextBuilder,
+        self,
+        text: str | client_utils.TextBuilder,
         images: list[bytes],
         image_alts: list[str] | None = None,
         image_aspect_ratios: list[models.AppBskyEmbedDefs.AspectRatio] | None = None,
···
         langs: list[str] | None = None,
         facets: list[models.AppBskyRichtextFacet.Main] | None = None,
         labels: models.ComAtprotoLabelDefs.SelfLabels | None = None,
-        time_iso: str | None = None
-    ) -> models.AppBskyFeedPost.CreateRecordResponse:
+        time_iso: str | None = None,
+    ) -> models.AppBskyFeedPost.CreateRecordResponse:
         """same as send_images, but with labels"""
-
+
         if image_alts is None:
-            image_alts = [''] * len(images)
+            image_alts = [""] * len(images)
         else:
             diff = len(images) - len(image_alts)
-            image_alts = image_alts + [''] * diff
-
+            image_alts = image_alts + [""] * diff
+
         if image_aspect_ratios is None:
             aligned_image_aspect_ratios = [None] * len(images)
         else:
             diff = len(images) - len(image_aspect_ratios)
             aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff
-
+
         uploads = [self.upload_blob(image) for image in images]
-
+
         embed_images = [
-            models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
-            for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
+            models.AppBskyEmbedImages.Image(
+                alt=alt, image=upload.blob, aspect_ratio=aspect_ratio
+            )
+            for alt, upload, aspect_ratio in zip(
+                image_alts, uploads, aligned_image_aspect_ratios
+            )
         ]
-
+
         return self.send_post(
             text,
             reply_to=reply_to,
···
             langs=langs,
             facets=facets,
             labels=labels,
-            time_iso=time_iso
+            time_iso=time_iso,
         )
-
+
     def send_post(
-        self,
-        text: str | client_utils.TextBuilder,
+        self,
+        text: str | client_utils.TextBuilder,
         reply_to: models.AppBskyFeedPost.ReplyRef | None = None,
-        embed:
-            None |
-            models.AppBskyEmbedImages.Main |
-            models.AppBskyEmbedExternal.Main |
-            models.AppBskyEmbedRecord.Main |
-            models.AppBskyEmbedRecordWithMedia.Main |
-            models.AppBskyEmbedVideo.Main = None,
+        embed: None
+        | models.AppBskyEmbedImages.Main
+        | models.AppBskyEmbedExternal.Main
+        | models.AppBskyEmbedRecord.Main
+        | models.AppBskyEmbedRecordWithMedia.Main
+        | models.AppBskyEmbedVideo.Main = None,
         langs: list[str] | None = None,
         facets: list[models.AppBskyRichtextFacet.Main] | None = None,
         labels: models.ComAtprotoLabelDefs.SelfLabels | None = None,
-        time_iso: str | None = None
-    ) -> models.AppBskyFeedPost.CreateRecordResponse:
+        time_iso: str | None = None,
+    ) -> models.AppBskyFeedPost.CreateRecordResponse:
         """same as send_post, but with labels"""
-
+
         if isinstance(text, client_utils.TextBuilder):
             facets = text.build_facets()
             text = text.build_text()
-
+
         repo = self.me and self.me.did
         if not repo:
             raise Exception("Client not logged in!")
-
+
         if not langs:
-            langs = ['en']
-
+            langs = ["en"]
+
         record = models.AppBskyFeedPost.Record(
             created_at=time_iso or self.get_current_time_iso(),
             text=text,
···
             embed=embed or None,
             langs=langs,
             facets=facets or None,
-            labels=labels or None
+            labels=labels or None,
         )
         return self.app.bsky.feed.post.create(repo, record)
-
-    def create_gates(self, thread_gate_opts: list[str], quote_gate: bool, post_uri: str, time_iso: str | None = None):
+
+    def create_gates(
+        self,
+        thread_gate_opts: list[str],
+        quote_gate: bool,
+        post_uri: str,
+        time_iso: str | None = None,
+    ):
         account = self.me
         if not account:
             raise Exception("Client not logged in!")
-
+
         rkey = AtUri.from_str(post_uri).rkey
         time_iso = time_iso or self.get_current_time_iso()
-
-        if 'everybody' not in thread_gate_opts:
+
+        if "everybody" not in thread_gate_opts:
             allow = []
             if thread_gate_opts:
-                if 'following' in thread_gate_opts:
+                if "following" in thread_gate_opts:
                     allow.append(models.AppBskyFeedThreadgate.FollowingRule())
-                if 'followers' in thread_gate_opts:
+                if "followers" in thread_gate_opts:
                     allow.append(models.AppBskyFeedThreadgate.FollowerRule())
-                if 'mentioned' in thread_gate_opts:
+                if "mentioned" in thread_gate_opts:
                     allow.append(models.AppBskyFeedThreadgate.MentionRule())
-
+
             thread_gate = models.AppBskyFeedThreadgate.Record(
-                post=post_uri,
-                created_at=time_iso,
-                allow=allow
+                post=post_uri, created_at=time_iso, allow=allow
             )
-
+
             self.app.bsky.feed.threadgate.create(account.did, thread_gate, rkey)
-
+
         if quote_gate:
             post_gate = models.AppBskyFeedPostgate.Record(
                 post=post_uri,
                 created_at=time_iso,
-                embedding_rules=[
-                    models.AppBskyFeedPostgate.DisableRule()
-                ]
+                embedding_rules=[models.AppBskyFeedPostgate.DisableRule()],
             )
-
-            self.app.bsky.feed.postgate.create(account.did, post_gate, rkey)
+
+            self.app.bsky.feed.postgate.create(account.did, post_gate, rkey)
+90 -71
bluesky/common.py
···
-import re, json
+import re
 from atproto import client_utils
···
 from util.util import canonical_label
 # only for lexicon reference
-SERVICE = 'https://bsky.app'
+SERVICE = "https://bsky.app"
 # TODO this is terrible and stupid
-ADULT_PATTERN = re.compile(r"\b(sexual content|nsfw|erotic|adult only|18\+)\b", re.IGNORECASE)
-PORN_PATTERN = re.compile(r"\b(porn|yiff|hentai|pornographic|fetish)\b", re.IGNORECASE)
+ADULT_PATTERN = re.compile(
+    r"\b(sexual content|nsfw|erotic|adult only|18\+)\b", re.IGNORECASE
+)
+PORN_PATTERN = re.compile(r"\b(porn|yiff|hentai|pornographic|fetish)\b", re.IGNORECASE)
+
 class BlueskyPost(cross.Post):
-    def __init__(self, record: dict, tokens: list[cross.Token], attachments: list[MediaInfo]) -> None:
+    def __init__(
+        self, record: dict, tokens: list[cross.Token], attachments: list[MediaInfo]
+    ) -> None:
         super().__init__()
-        self.uri = record['$xpost.strongRef']['uri']
+        self.uri = record["$xpost.strongRef"]["uri"]
         self.parent_uri = None
-        if record.get('reply'):
-            self.parent_uri = record['reply']['parent']['uri']
-
+        if record.get("reply"):
+            self.parent_uri = record["reply"]["parent"]["uri"]
+
         self.tokens = tokens
-        self.timestamp = record['createdAt']
-        labels = record.get('labels', {}).get('values')
+        self.timestamp = record["createdAt"]
+        labels = record.get("labels", {}).get("values")
         self.spoiler = None
         if labels:
-            self.spoiler = ', '.join([str(label['val']).replace('-', ' ') for label in labels])
-
+            self.spoiler = ", ".join(
+                [str(label["val"]).replace("-", " ") for label in labels]
+            )
+
         self.attachments = attachments
-        self.languages = record.get('langs', [])
-
+        self.languages = record.get("langs", [])
+
     # at:// of the post record
     def get_id(self) -> str:
         return self.uri
-
+
     def get_parent_id(self) -> str | None:
         return self.parent_uri
-
+
     def get_tokens(self) -> list[cross.Token]:
         return self.tokens
-
+
     def get_text_type(self) -> str:
         return "text/plain"
-
+
     def get_timestamp(self) -> str:
         return self.timestamp
     def get_attachments(self) -> list[MediaInfo]:
         return self.attachments
-
+
     def get_spoiler(self) -> str | None:
         return self.spoiler
     def get_languages(self) -> list[str]:
         return self.languages
-
+
     def is_sensitive(self) -> bool:
         return self.spoiler is not None
     def get_post_url(self) -> str | None:
-        did, _, post_id = str(self.uri[len("at://"):]).split("/")
-
+        did, _, post_id = str(self.uri[len("at://") :]).split("/")
+
         return f"https://bsky.app/profile/{did}/post/{post_id}"
+
 def tokenize_post(post: dict) -> list[cross.Token]:
-    text: str = post.get('text', '')
+    text: str = post.get("text", "")
     if not text:
         return []
-    ut8_text = text.encode(encoding='utf-8')
-
+    ut8_text = text.encode(encoding="utf-8")
+
     def decode(ut8: bytes) -> str:
-        return ut8.decode(encoding='utf-8')
-
-    facets: list[dict] = post.get('facets', [])
+        return ut8.decode(encoding="utf-8")
+
+    facets: list[dict] = post.get("facets", [])
     if not facets:
         return [cross.TextToken(decode(ut8_text))]
-
+
     slices: list[tuple[int, int, str, str]] = []
-
+
     for facet in facets:
-        features: list[dict] = facet.get('features', [])
+        features: list[dict] = facet.get("features", [])
         if not features:
             continue
-
+
         # we don't support overlapping facets/features
         feature = features[0]
-        feature_type = feature['$type']
-        index = facet['index']
+        feature_type = feature["$type"]
+        index = facet["index"]
         match feature_type:
-            case 'app.bsky.richtext.facet#tag':
-                slices.append((index['byteStart'], index['byteEnd'], 'tag', feature['tag']))
-            case 'app.bsky.richtext.facet#link':
-                slices.append((index['byteStart'], index['byteEnd'], 'link', feature['uri']))
-            case 'app.bsky.richtext.facet#mention':
-                slices.append((index['byteStart'], index['byteEnd'], 'mention', feature['did']))
-
+            case "app.bsky.richtext.facet#tag":
+                slices.append(
+                    (index["byteStart"], index["byteEnd"], "tag", feature["tag"])
+                )
+            case "app.bsky.richtext.facet#link":
+                slices.append(
+                    (index["byteStart"], index["byteEnd"], "link", feature["uri"])
+                )
+            case "app.bsky.richtext.facet#mention":
+                slices.append(
+                    (index["byteStart"], index["byteEnd"], "mention", feature["did"])
+                )
+
     if not slices:
         return [cross.TextToken(decode(ut8_text))]
-
+
     slices.sort(key=lambda s: s[0])
     unique: list[tuple[int, int, str, str]] = []
     current_end = 0
···
         if start >= current_end:
             unique.append((start, end, ttype, val))
             current_end = end
-
+
     if not unique:
         return [cross.TextToken(decode(ut8_text))]
-
+
     tokens: list[cross.Token] = []
     prev = 0
-
+
     for start, end, ttype, val in unique:
         if start > prev:
             # text between facets
             tokens.append(cross.TextToken(decode(ut8_text[prev:start])))
         # facet token
         match ttype:
-            case 'link':
+            case "link":
                 label = decode(ut8_text[start:end])
-
+
                 # try to unflatten links
-                split = val.split('://', 1)
+                split = val.split("://", 1)
                 if len(split) > 1:
                     if split[1].startswith(label):
-                        tokens.append(cross.LinkToken(val, ''))
+                        tokens.append(cross.LinkToken(val, ""))
                         prev = end
                         continue
-
-                    if label.endswith('...') and split[1].startswith(label[:-3]):
-                        tokens.append(cross.LinkToken(val, ''))
+
+                    if label.endswith("...") and split[1].startswith(label[:-3]):
+                        tokens.append(cross.LinkToken(val, ""))
                         prev = end
-                        continue
-
+                        continue
+
                 tokens.append(cross.LinkToken(val, label))
-            case 'tag':
+            case "tag":
                 tag = decode(ut8_text[start:end])
-                tokens.append(cross.TagToken(tag[1:] if tag.startswith('#') else tag))
-            case 'mention':
+                tokens.append(cross.TagToken(tag[1:] if tag.startswith("#") else tag))
+            case "mention":
                 mention = decode(ut8_text[start:end])
-                tokens.append(cross.MentionToken(mention[1:] if mention.startswith('@') else mention, val))
+                tokens.append(
+                    cross.MentionToken(
+                        mention[1:] if mention.startswith("@") else mention, val
+                    )
+                )
         prev = end
     if prev < len(ut8_text):
         tokens.append(cross.TextToken(decode(ut8_text[prev:])))
-
-    return tokens
+
+    return tokens
+
 def tokens_to_richtext(tokens: list[cross.Token]) -> client_utils.TextBuilder | None:
     builder = client_utils.TextBuilder()
-
+
     def flatten_link(href: str):
-        split = href.split('://', 1)
+        split = href.split("://", 1)
         if len(split) > 1:
             href = split[1]
-
+
         if len(href) > 32:
-            href = href[:32] + '...'
-
+            href = href[:32] + "..."
+
         return href
-
+
     for token in tokens:
         if isinstance(token, cross.TextToken):
             builder.text(token.text)
···
             if canonical_label(token.label, token.href):
                 builder.link(flatten_link(token.href), token.href)
                 continue
-
+
             builder.link(token.label, token.href)
         elif isinstance(token, cross.TagToken):
-            builder.tag('#' + token.tag, token.tag.lower())
+            builder.tag("#" + token.tag, token.tag.lower())
         else:
             # fail on unsupported tokens
             return None
-
-    return builder
+
+    return builder
+105 -79
bluesky/input.py
···
-import re, json, websockets, asyncio
+import asyncio
+import json
+import re
+from typing import Any, Callable
+import websockets
 from atproto_client import models
 from atproto_client.models.utils import get_or_create as get_model_or_create
+
+import cross
+import util.database as database
 from bluesky.atproto2 import resolve_identity
-
-from bluesky.common import BlueskyPost, SERVICE, tokenize_post
-
-import cross, util.database as database
+from bluesky.common import SERVICE, BlueskyPost, tokenize_post
+from util.database import DataBaseWorker
+from util.media import MediaInfo, download_media
 from util.util import LOGGER, as_envvar
-from util.media import MediaInfo, download_media
-from util.database import DataBaseWorker
-from typing import Callable, Any
-class BlueskyInputOptions():
+class BlueskyInputOptions:
     def __init__(self, o: dict) -> None:
-        self.filters = [re.compile(f) for f in o.get('regex_filters', [])]
+        self.filters = [re.compile(f) for f in o.get("regex_filters", [])]
+
 class BlueskyInput(cross.Input):
     def __init__(self, settings: dict, db: DataBaseWorker) -> None:
-        self.options = BlueskyInputOptions(settings.get('options', {}))
+        self.options = BlueskyInputOptions(settings.get("options", {}))
         did, pds = resolve_identity(
-            handle=as_envvar(settings.get('handle')),
-            did=as_envvar(settings.get('did')),
-            pds=as_envvar(settings.get('pds'))
+            handle=as_envvar(settings.get("handle")),
+            did=as_envvar(settings.get("did")),
+            pds=as_envvar(settings.get("pds")),
         )
         self.pds = pds
-
+
         # PDS is Not a service, the lexicon and rids are the same across pds
         super().__init__(SERVICE, did, settings, db)
-
+
     def _on_post(self, outputs: list[cross.Output], post: dict[str, Any]):
-        post_uri = post['$xpost.strongRef']['uri']
-        post_cid = post['$xpost.strongRef']['cid']
-
+        post_uri = post["$xpost.strongRef"]["uri"]
+        post_cid = post["$xpost.strongRef"]["cid"]
+
         parent_uri = None
-        if post.get('reply'):
-            parent_uri = post['reply']['parent']['uri']
-
-        embed = post.get('embed', {})
-        if embed.get('$type') in ('app.bsky.embed.record', 'app.bsky.embed.recordWithMedia'):
-            did, collection, rid = str(embed['record']['uri'][len('at://'):]).split('/')
-            if collection == 'app.bsky.feed.post':
+        if post.get("reply"):
+            parent_uri = post["reply"]["parent"]["uri"]
+
+        embed = post.get("embed", {})
+        if embed.get("$type") in (
+            "app.bsky.embed.record",
+            "app.bsky.embed.recordWithMedia",
+        ):
+            did, collection, rid = str(embed["record"]["uri"][len("at://") :]).split(
+                "/"
+            )
+            if collection == "app.bsky.feed.post":
                 LOGGER.info("Skipping '%s'! Quote..", post_uri)
                 return
-
-        success = database.try_insert_post(self.db, post_uri, parent_uri, self.user_id, self.service)
+
+        success = database.try_insert_post(
+            self.db, post_uri, parent_uri, self.user_id, self.service
+        )
         if not success:
             LOGGER.info("Skipping '%s' as parent post was not found in db!", post_uri)
             return
-        database.store_data(self.db, post_uri, self.user_id, self.service, {'cid': post_cid})
-
+        database.store_data(
+            self.db, post_uri, self.user_id, self.service, {"cid": post_cid}
+        )
+
         tokens = tokenize_post(post)
         if not cross.test_filters(tokens, self.options.filters):
             LOGGER.info("Skipping '%s'. Matched a filter!", post_uri)
             return
-
+
         LOGGER.info("Crossposting '%s'...", post_uri)
-
+
         def get_blob_url(blob: str):
-            return f'{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.user_id}&cid={blob}'
-
+            return f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.user_id}&cid={blob}"
+
         attachments: list[MediaInfo] = []
-        if embed.get('$type') == 'app.bsky.embed.images':
+        if embed.get("$type") == "app.bsky.embed.images":
             model = get_model_or_create(embed, model=models.AppBskyEmbedImages.Main)
             assert isinstance(model, models.AppBskyEmbedImages.Main)
-
+
             for image in model.images:
                 url = get_blob_url(image.image.cid.encode())
                 LOGGER.info("Downloading %s...", url)
···
                     LOGGER.error("Skipping '%s'. Failed to download media!", post_uri)
                     return
                 attachments.append(io)
-        elif embed.get('$type') == 'app.bsky.embed.video':
+        elif embed.get("$type") == "app.bsky.embed.video":
            model = get_model_or_create(embed, model=models.AppBskyEmbedVideo.Main)
            assert isinstance(model, models.AppBskyEmbedVideo.Main)
            url = get_blob_url(model.video.cid.encode())
            LOGGER.info("Downloading %s...", url)
-            io = download_media(url, model.alt if model.alt else '')
+            io = download_media(url, model.alt if model.alt else "")
            if not io:
                LOGGER.error("Skipping '%s'. Failed to download media!", post_uri)
                return
            attachments.append(io)
-
+
         cross_post = BlueskyPost(post, tokens, attachments)
         for output in outputs:
             output.accept_post(cross_post)
···
         post = database.find_post(self.db, post_id, self.user_id, self.service)
         if not post:
             return
-
+
         LOGGER.info("Deleting '%s'...", post_id)
         if repost:
             for output in outputs:
···
         for output in outputs:
             output.delete_post(post_id)
         database.delete_post(self.db, post_id, self.user_id, self.service)
-
+
     def _on_repost(self, outputs: list[cross.Output], post: dict[str, Any]):
-        post_uri = post['$xpost.strongRef']['uri']
-        post_cid = post['$xpost.strongRef']['cid']
-
-        reposted_uri = post['subject']['uri']
-
-        success = database.try_insert_repost(self.db, post_uri, reposted_uri, self.user_id, self.service)
+        post_uri = post["$xpost.strongRef"]["uri"]
+        post_cid = post["$xpost.strongRef"]["cid"]
+
+        reposted_uri = post["subject"]["uri"]
+
+        success = database.try_insert_repost(
+            self.db, post_uri, reposted_uri, self.user_id, self.service
+        )
         if not success:
             LOGGER.info("Skipping '%s' as reposted post was not found in db!", post_uri)
             return
-        database.store_data(self.db, post_uri, self.user_id, self.service, {'cid': post_cid})
-
+        database.store_data(
+            self.db, post_uri, self.user_id, self.service, {"cid": post_cid}
+        )
+
         LOGGER.info("Crossposting '%s'...", post_uri)
         for output in outputs:
             output.accept_repost(post_uri, reposted_uri)
+
 class BlueskyJetstreamInput(BlueskyInput):
     def __init__(self, settings: dict, db: DataBaseWorker) -> None:
         super().__init__(settings, db)
-        self.jetstream = settings.get("jetstream", "wss://jetstream2.us-east.bsky.network/subscribe")
-
+        self.jetstream = settings.get(
+            "jetstream", "wss://jetstream2.us-east.bsky.network/subscribe"
+        )
+
     def __on_commit(self, outputs: list[cross.Output], msg: dict):
-        if msg.get('did') != self.user_id:
+        if msg.get("did") != self.user_id:
             return
-
-        commit: dict = msg.get('commit', {})
+
+        commit: dict = msg.get("commit", {})
         if not commit:
             return
-
-        commit_type = commit['operation']
+
+        commit_type = commit["operation"]
         match commit_type:
-            case 'create':
-                record = dict(commit.get('record', {}))
-                record['$xpost.strongRef'] = {
-                    'cid': commit['cid'],
-                    'uri': f'at://{self.user_id}/{commit['collection']}/{commit['rkey']}'
+            case "create":
+                record = dict(commit.get("record", {}))
+                record["$xpost.strongRef"] = {
+                    "cid": commit["cid"],
+                    "uri": f"at://{self.user_id}/{commit['collection']}/{commit['rkey']}",
                 }
-
-                match commit['collection']:
-                    case 'app.bsky.feed.post':
+
+                match commit["collection"]:
+                    case "app.bsky.feed.post":
                         self._on_post(outputs, record)
-                    case 'app.bsky.feed.repost':
+                    case "app.bsky.feed.repost":
                         self._on_repost(outputs, record)
-            case 'delete':
-                post_id: str = f'at://{self.user_id}/{commit['collection']}/{commit['rkey']}'
-                match commit['collection']:
-                    case 'app.bsky.feed.post':
+            case "delete":
+                post_id: str = (
+                    f"at://{self.user_id}/{commit['collection']}/{commit['rkey']}"
+                )
+                match commit["collection"]:
+                    case "app.bsky.feed.post":
                         self._on_delete_post(outputs, post_id, False)
-                    case 'app.bsky.feed.repost':
+                    case "app.bsky.feed.repost":
                         self._on_delete_post(outputs, post_id, True)
-
-    async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]):
-        uri = self.jetstream + '?'
+
+    async def listen(
+        self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
+    ):
+        uri = self.jetstream + "?"
         uri += "wantedCollections=app.bsky.feed.post"
         uri += "&wantedCollections=app.bsky.feed.repost"
         uri += f"&wantedDids={self.user_id}"
-
-        async for ws in websockets.connect(uri, extra_headers={"User-Agent": "XPost/0.0.3"}):
+
+        async for ws in websockets.connect(
+            uri, extra_headers={"User-Agent": "XPost/0.0.3"}
+        ):
             try:
                 LOGGER.info("Listening to %s...", self.jetstream)
-
+
                 async def listen_for_messages():
                     async for msg in ws:
                         submit(lambda: self.__on_commit(outputs, json.loads(msg)))
-
+
                 listen = asyncio.create_task(listen_for_messages())
-
+
                 await asyncio.gather(listen)
             except websockets.ConnectionClosedError as e:
                 LOGGER.error(e, stack_info=True, exc_info=True)
                 LOGGER.info("Reconnecting to %s...", self.jetstream)
-                continue
+                continue
+238 -209
bluesky/output.py
···
-import json
-from httpx import Timeout
-
-from atproto import client_utils, Request
+from atproto import Request, client_utils
 from atproto_client import models
-from bluesky.atproto2 import Client2, resolve_identity
-
-from bluesky.common import SERVICE, ADULT_PATTERN, PORN_PATTERN, tokens_to_richtext
+from httpx import Timeout
-import cross, util.database as database
+import cross
 import misskey.mfm_util as mfm_util
-from util.util import LOGGER, as_envvar
-from util.media import MediaInfo, get_filename_from_url, get_media_meta, compress_image, convert_to_mp4
+import util.database as database
+from bluesky.atproto2 import Client2, resolve_identity
+from bluesky.common import ADULT_PATTERN, PORN_PATTERN, SERVICE, tokens_to_richtext
 from util.database import DataBaseWorker
+from util.media import (
+    MediaInfo,
+    compress_image,
+    convert_to_mp4,
+    get_filename_from_url,
+    get_media_meta,
+)
+from util.util import LOGGER, as_envvar
-ALLOWED_GATES = ['mentioned', 'following', 'followers', 'everybody']
+ALLOWED_GATES = ["mentioned", "following", "followers", "everybody"]
+
 class BlueskyOutputOptions:
     def __init__(self, o: dict) -> None:
         self.quote_gate: bool = False
-        self.thread_gate: list[str] = ['everybody']
+        self.thread_gate: list[str] = ["everybody"]
         self.encode_videos: bool = True
-
-        quote_gate = o.get('quote_gate')
+
+        quote_gate = o.get("quote_gate")
         if quote_gate is not None:
             self.quote_gate = bool(quote_gate)
-
-        thread_gate = o.get('thread_gate')
+
+        thread_gate = o.get("thread_gate")
         if thread_gate is not None:
             if any([v not in ALLOWED_GATES for v in thread_gate]):
-                raise ValueError(f"'thread_gate' only accepts {', '.join(ALLOWED_GATES)} or [], got: {thread_gate}")
+                raise ValueError(
+                    f"'thread_gate' only accepts {', '.join(ALLOWED_GATES)} or [], got: {thread_gate}"
+                )
             self.thread_gate = thread_gate
-
-        encode_videos = o.get('encode_videos')
+
+        encode_videos = o.get("encode_videos")
         if encode_videos is not None:
             self.encode_videos = bool(encode_videos)
+
 class BlueskyOutput(cross.Output):
     def __init__(self, input: cross.Input, settings: dict, db: DataBaseWorker) -> None:
         super().__init__(input, settings, db)
-        self.options = BlueskyOutputOptions(settings.get('options') or {})
-
-        if not as_envvar(settings.get('app-password')):
+        self.options = BlueskyOutputOptions(settings.get("options") or {})
+
+        if not as_envvar(settings.get("app-password")):
             raise Exception("Account app password not provided!")
-
+
         did, pds = resolve_identity(
-            handle=as_envvar(settings.get('handle')),
-            did=as_envvar(settings.get('did')),
-            pds=as_envvar(settings.get('pds'))
+            handle=as_envvar(settings.get("handle")),
+            did=as_envvar(settings.get("did")),
+            pds=as_envvar(settings.get("pds")),
         )
-
+
         reqs = Request(timeout=Timeout(None, connect=30.0))
-
+
         self.bsky = Client2(pds, request=reqs)
         self.bsky.configure_proxy_header(
-            service_type='bsky_appview',
-            did=as_envvar(settings.get('bsky_appview')) or 'did:web:api.bsky.app'
+            service_type="bsky_appview",
+            did=as_envvar(settings.get("bsky_appview")) or "did:web:api.bsky.app",
         )
-        self.bsky.login(did, as_envvar(settings.get('app-password')))
-
+        self.bsky.login(did, as_envvar(settings.get("app-password")))
+
     def __check_login(self):
         login = self.bsky.me
         if not login:
             raise Exception("Client not logged in!")
         return login
-
+
     def _find_parent(self, parent_id: str):
         login = self.__check_login()
-
+
         thread_tuple = database.find_mapped_thread(
             self.db,
             parent_id,
             self.input.user_id,
             self.input.service,
             login.did,
-            SERVICE
+            SERVICE,
         )
-
+
         if not thread_tuple:
             LOGGER.error("Failed to find thread tuple in the database!")
             return None
-
+
         root_uri: str = thread_tuple[0]
         reply_uri: str = thread_tuple[1]
-
-        root_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)['cid']
-        reply_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)['cid']
-
-        root_record = models.AppBskyFeedPost.CreateRecordResponse(uri=root_uri, cid=root_cid)
-        reply_record = models.AppBskyFeedPost.CreateRecordResponse(uri=reply_uri, cid=reply_cid)
-
+
+        root_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)["cid"]
+        reply_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)["cid"]
+
+        root_record = models.AppBskyFeedPost.CreateRecordResponse(
+            uri=root_uri, cid=root_cid
+        )
+        reply_record = models.AppBskyFeedPost.CreateRecordResponse(
+            uri=reply_uri, cid=reply_cid
+        )
+
         return (
             models.create_strong_ref(root_record),
             models.create_strong_ref(reply_record),
             thread_tuple[2],
-            thread_tuple[3]
+            thread_tuple[3],
         )
-
+
     def _split_attachments(self, attachments: list[MediaInfo]):
         sup_media: list[MediaInfo] = []
         unsup_media: list[MediaInfo] = []
-
+
         for a in attachments:
-            if a.mime.startswith('image/') or a.mime.startswith('video/'): # TODO convert gifs to videos
+            if a.mime.startswith("image/") or a.mime.startswith(
+                "video/"
+            ):  # TODO convert gifs to videos
                 sup_media.append(a)
             else:
                 unsup_media.append(a)
-
+
         return (sup_media, unsup_media)
     def _split_media_per_post(
-        self,
-        tokens: list[client_utils.TextBuilder],
-        media: list[MediaInfo]):
-
+        self, tokens: list[client_utils.TextBuilder], media: list[MediaInfo]
+    ):
         posts: list[dict] = [{"tokens": tokens, "attachments": []} for tokens in tokens]
         available_indices: list[int] = list(range(len(posts)))
-
+
         current_image_post_idx: int | None = None
         def make_blank_post() -> dict:
-            return {
-                "tokens": [client_utils.TextBuilder().text('')],
-                "attachments": []
-            }
-
+            return {"tokens": [client_utils.TextBuilder().text("")], "attachments": []}
+
         def pop_next_empty_index() -> int:
             if available_indices:
                 return available_indices.pop(0)
···
             new_idx = len(posts)
             posts.append(make_blank_post())
             return new_idx
-
+
         for att in media:
-            if att.mime.startswith('video/'):
+            if att.mime.startswith("video/"):
                 current_image_post_idx = None
                 idx = pop_next_empty_index()
                 posts[idx]["attachments"].append(att)
-            elif att.mime.startswith('image/'):
+            elif att.mime.startswith("image/"):
                 if (
                     current_image_post_idx is not None
                     and len(posts[current_image_post_idx]["attachments"]) < 4
···
                     idx = pop_next_empty_index()
                     posts[idx]["attachments"].append(att)
                     current_image_post_idx = idx
-
+
         result: list[tuple[client_utils.TextBuilder, list[MediaInfo]]] = []
         for p in posts:
             result.append((p["tokens"], p["attachments"]))
         return result
-
+
     def accept_post(self, post: cross.Post):
         login = self.__check_login()
-
+
         parent_id = post.get_parent_id()
-
+
         # used for db insertion
         new_root_id = None
         new_parent_id = None
-
+
         root_ref = None
         reply_ref = None
         if parent_id:
···
             if not parents:
                 return
             root_ref, reply_ref, new_root_id, new_parent_id = parents
-
+
         tokens = post.get_tokens().copy()
-
+
         unique_labels: set[str] = set()
         cw = post.get_spoiler()
         if cw:
             tokens.insert(0, cross.TextToken("CW: " + cw + "\n\n"))
-            unique_labels.add('graphic-media')
-
+            unique_labels.add("graphic-media")
+
             # from bsky.app, a post can only have one of those labels
             if PORN_PATTERN.search(cw):
-                unique_labels.add('porn')
+                unique_labels.add("porn")
             elif ADULT_PATTERN.search(cw):
-                unique_labels.add('sexual')
-
+                unique_labels.add("sexual")
+
         if post.is_sensitive():
-            unique_labels.add('graphic-media')
-
-        labels = models.ComAtprotoLabelDefs.SelfLabels(values=[models.ComAtprotoLabelDefs.SelfLabel(val=label) for label in unique_labels])
+            unique_labels.add("graphic-media")
+
+        labels = models.ComAtprotoLabelDefs.SelfLabels(
+            values=[
+                models.ComAtprotoLabelDefs.SelfLabel(val=label)
+                for label in unique_labels
+            ]
+        )
         sup_media, unsup_media = self._split_attachments(post.get_attachments())
         if unsup_media:
             if tokens:
-                tokens.append(cross.TextToken('\n'))
+                tokens.append(cross.TextToken("\n"))
             for i, attachment in enumerate(unsup_media):
-                tokens.append(cross.LinkToken(
-                    attachment.url,
-                    f"[{get_filename_from_url(attachment.url)}]"
-                ))
-                tokens.append(cross.TextToken(' '))
-
+                tokens.append(
+                    cross.LinkToken(
+                        attachment.url, f"[{get_filename_from_url(attachment.url)}]"
+                    )
+                )
+                tokens.append(cross.TextToken(" "))
+
         if post.get_text_type() == "text/x.misskeymarkdown":
             tokens, status = mfm_util.strip_mfm(tokens)
             post_url = post.get_post_url()
             if status and post_url:
-                tokens.append(cross.TextToken('\n'))
-                tokens.append(cross.LinkToken(post_url, "[Post contains MFM, see original]"))
-
+                tokens.append(cross.TextToken("\n"))
+                tokens.append(
+                    cross.LinkToken(post_url, "[Post contains MFM, see original]")
+                )
+
         split_tokens: list[list[cross.Token]] = cross.split_tokens(tokens, 300)
         post_text: list[client_utils.TextBuilder] = []
-
+
         # convert tokens into rich text. skip post if contains unsupported tokens
         for block in split_tokens:
             rich_text = tokens_to_richtext(block)
-
+
             if not rich_text:
-                LOGGER.error("Skipping '%s' as it contains invalid rich text types!", post.get_id())
+                LOGGER.error(
+                    "Skipping '%s' as it contains invalid rich text types!",
+                    post.get_id(),
+                )
                 return
             post_text.append(rich_text)
-
+
         if not post_text:
-            post_text = [client_utils.TextBuilder().text('')]
-
+            post_text = [client_utils.TextBuilder().text("")]
+
         for m in sup_media:
-            if m.mime.startswith('image/'):
+            if m.mime.startswith("image/"):
                 if len(m.io) > 2_000_000:
-                    LOGGER.error("Skipping post_id '%s', failed to download attachment! File too large.", post.get_id())
+                    LOGGER.error(
+                        "Skipping post_id '%s', failed to download attachment! File too large.",
+                        post.get_id(),
+                    )
                     return
-
-            if m.mime.startswith('video/'):
-                if m.mime != 'video/mp4' and not self.options.encode_videos:
-                    LOGGER.info("Video is not mp4, but encoding is disabled. Skipping '%s'...", post.get_id())
+
+            if m.mime.startswith("video/"):
+                if m.mime != "video/mp4" and not self.options.encode_videos:
+                    LOGGER.info(
+                        "Video is not mp4, but encoding is disabled. Skipping '%s'...",
+                        post.get_id(),
+                    )
                     return
-
+
                 if len(m.io) > 100_000_000:
-                    LOGGER.error("Skipping post_id '%s', failed to download attachment! File too large?", post.get_id())
+                    LOGGER.error(
+                        "Skipping post_id '%s', failed to download attachment! File too large?",
+                        post.get_id(),
+                    )
                     return
-
+
         created_records: list[models.AppBskyFeedPost.CreateRecordResponse] = []
         baked_media = self._split_media_per_post(post_text, sup_media)
-
+
         for text, attachments in baked_media:
             if not attachments:
                 if reply_ref and root_ref:
-                    new_post = self.bsky.send_post(text, reply_to=models.AppBskyFeedPost.ReplyRef(
-                        parent=reply_ref,
-                        root=root_ref
-                    ), labels=labels, time_iso=post.get_timestamp())
+                    new_post = self.bsky.send_post(
+                        text,
+                        reply_to=models.AppBskyFeedPost.ReplyRef(
+                            parent=reply_ref, root=root_ref
+                        ),
+                        labels=labels,
+                        time_iso=post.get_timestamp(),
+                    )
                 else:
-                    new_post = self.bsky.send_post(text, labels=labels, time_iso=post.get_timestamp())
+                    new_post = self.bsky.send_post(
+                        text, labels=labels, time_iso=post.get_timestamp()
+                    )
                     root_ref = models.create_strong_ref(new_post)
-
+
                 self.bsky.create_gates(
-                    self.options.thread_gate,
-                    self.options.quote_gate,
-                    new_post.uri,
-                    time_iso=post.get_timestamp()
+                    self.options.thread_gate,
+                    self.options.quote_gate,
+                    new_post.uri,
+                    time_iso=post.get_timestamp(),
                 )
                 reply_ref = models.create_strong_ref(new_post)
                 created_records.append(new_post)
             else:
                 # if a single post is an image - everything else is an image
-                if attachments[0].mime.startswith('image/'):
+                if attachments[0].mime.startswith("image/"):
                     images: list[bytes] = []
                     image_alts: list[str] = []
                     image_aspect_ratios: list[models.AppBskyEmbedDefs.AspectRatio] = []
-
+
                     for attachment in attachments:
                         image_io = compress_image(attachment.io, quality=100)
                         metadata = get_media_meta(image_io)
-
+
                         if len(image_io) > 1_000_000:
                             LOGGER.info("Compressing %s...", attachment.name)
                             image_io = compress_image(image_io)
-
+
                         images.append(image_io)
                         image_alts.append(attachment.alt)
-                        image_aspect_ratios.append(models.AppBskyEmbedDefs.AspectRatio(
-                            width=metadata['width'],
-                            height=metadata['height']
-                        ))
-
+                        image_aspect_ratios.append(
+                            models.AppBskyEmbedDefs.AspectRatio(
+                                width=metadata["width"], height=metadata["height"]
+                            )
+                        )
+
                     new_post = self.bsky.send_images(
                         text=post_text[0],
                         images=images,
                         image_alts=image_alts,
                         image_aspect_ratios=image_aspect_ratios,
-                        reply_to= models.AppBskyFeedPost.ReplyRef(
-                            parent=reply_ref,
-                            root=root_ref
-                        ) if root_ref and reply_ref else None,
-                        labels=labels,
-                        time_iso=post.get_timestamp()
+                        reply_to=models.AppBskyFeedPost.ReplyRef(
+                            parent=reply_ref, root=root_ref
+                        )
+                        if root_ref and reply_ref
+                        else None,
+                        labels=labels,
+                        time_iso=post.get_timestamp(),
                     )
                     if not root_ref:
                         root_ref = models.create_strong_ref(new_post)
-
+
                     self.bsky.create_gates(
-                        self.options.thread_gate,
+                        self.options.thread_gate,
                         self.options.quote_gate,
-                        new_post.uri,
-                        time_iso=post.get_timestamp()
+                        new_post.uri,
+                        time_iso=post.get_timestamp(),
                     )
                     reply_ref = models.create_strong_ref(new_post)
                     created_records.append(new_post)
-                else: # video is guarantedd to be one
+                else:  # video is guarantedd to be one
                     metadata = get_media_meta(attachments[0].io)
-                    if metadata['duration'] > 180:
-                        LOGGER.info("Skipping post_id '%s', video attachment too long!", post.get_id())
+                    if metadata["duration"] > 180:
+                        LOGGER.info(
+                            "Skipping post_id '%s', video attachment too long!",
+                            post.get_id(),
+                        )
                         return
-
+
                     video_io = attachments[0].io
-                    if attachments[0].mime != 'video/mp4':
+                    if attachments[0].mime != "video/mp4":
                         LOGGER.info("Converting %s to mp4...", attachments[0].name)
                         video_io = convert_to_mp4(video_io)
-
+
                     aspect_ratio = models.AppBskyEmbedDefs.AspectRatio(
-                        width=metadata['width'],
-                        height=metadata['height']
+                        width=metadata["width"], height=metadata["height"]
                     )
-
+
                     new_post = self.bsky.send_video(
                         text=post_text[0],
                         video=video_io,
                         video_aspect_ratio=aspect_ratio,
                         video_alt=attachments[0].alt,
-                        reply_to= models.AppBskyFeedPost.ReplyRef(
-                            parent=reply_ref,
-                            root=root_ref
-                        ) if root_ref and reply_ref else None,
+                        reply_to=models.AppBskyFeedPost.ReplyRef(
+                            parent=reply_ref, root=root_ref
+                        )
+                        if root_ref and reply_ref
+                        else None,
                         labels=labels,
-                        time_iso=post.get_timestamp()
+                        time_iso=post.get_timestamp(),
                     )
                     if not root_ref:
                         root_ref = models.create_strong_ref(new_post)
-
+
                     self.bsky.create_gates(
                         self.options.thread_gate,
-                        self.options.quote_gate,
-                        new_post.uri,
-                        time_iso=post.get_timestamp()
+                        self.options.quote_gate,
+                        new_post.uri,
+                        time_iso=post.get_timestamp(),
                     )
                     reply_ref = models.create_strong_ref(new_post)
                     created_records.append(new_post)
-
-        db_post = database.find_post(self.db, post.get_id(), self.input.user_id, self.input.service)
+
+        db_post = database.find_post(
+            self.db, post.get_id(), self.input.user_id, self.input.service
+        )
         assert db_post, "ghghghhhhh"
-
-        if new_root_id is None or new_parent_id is None:
+
+        if new_root_id is None or new_parent_id is None:
             new_root_id = database.insert_post(
+                self.db, created_records[0].uri, login.did, SERVICE
+            )
+            database.store_data(
                 self.db,
                 created_records[0].uri,
                 login.did,
-                SERVICE
-            )
-            database.store_data(
-                self.db,
-                created_records[0].uri,
-                login.did,
                 SERVICE,
-                {'cid': created_records[0].cid}
+                {"cid": created_records[0].cid},
             )
-
+
             new_parent_id = new_root_id
-            database.insert_mapping(self.db, db_post['id'], new_parent_id)
+            database.insert_mapping(self.db, db_post["id"], new_parent_id)
             created_records = created_records[1:]
-
+
         for record in created_records:
             new_parent_id = database.insert_reply(
-                self.db,
-                record.uri,
-                login.did,
-                SERVICE,
-                new_parent_id,
-                new_root_id
+                self.db, record.uri, login.did, SERVICE, new_parent_id, new_root_id
            )
            database.store_data(
-                self.db,
-                record.uri,
-                login.did,
-                SERVICE,
-                {'cid': record.cid}
+                self.db, record.uri, login.did, SERVICE, {"cid": record.cid}
            )
-            database.insert_mapping(self.db, db_post['id'], new_parent_id)
-
+            database.insert_mapping(self.db, db_post["id"], new_parent_id)
+
     def delete_post(self, identifier: str):
         login = self.__check_login()
-
-        post = database.find_post(self.db, identifier, self.input.user_id, self.input.service)
+
+        post = database.find_post(
+            self.db, identifier, self.input.user_id, self.input.service
+        )
         if not post:
             return
-
-        mappings = database.find_mappings(self.db, post['id'], SERVICE, login.did)
+
+        mappings = database.find_mappings(self.db, post["id"], SERVICE, login.did)
         for mapping in mappings[::-1]:
             LOGGER.info("Deleting '%s'...", mapping[0])
             self.bsky.delete_post(mapping[0])
             database.delete_post(self.db, mapping[0], SERVICE, login.did)
-
+
     def accept_repost(self, repost_id: str, reposted_id: str):
         login, repost = self.__delete_repost(repost_id)
         if not (login and repost):
             return
-
-        reposted = database.find_post(self.db, reposted_id, self.input.user_id, self.input.service)
+
+        reposted = database.find_post(
+            self.db, reposted_id, self.input.user_id, self.input.service
+        )
         if not reposted:
             return
-
+
         # mappings of the reposted post
-        mappings = database.find_mappings(self.db, reposted['id'], SERVICE, login.did)
+        mappings = database.find_mappings(self.db, reposted["id"], SERVICE, login.did)
         if mappings:
-            cid = database.fetch_data(self.db, mappings[0][0], login.did, SERVICE)['cid']
+            cid = database.fetch_data(self.db, mappings[0][0], login.did, SERVICE)[
+                "cid"
+            ]
             rsp = self.bsky.repost(mappings[0][0], cid)
-
+
             internal_id = database.insert_repost(
-                self.db,
-                rsp.uri,
-                reposted['id'],
-                login.did,
-                SERVICE)
-            database.store_data(
-                self.db,
-                rsp.uri,
-                login.did,
-                SERVICE,
-                {'cid': rsp.cid}
+                self.db, rsp.uri, reposted["id"], login.did, SERVICE
            )
-            database.insert_mapping(self.db, repost['id'], internal_id)
-
-    def __delete_repost(self, repost_id: str) -> tuple[models.AppBskyActorDefs.ProfileViewDetailed | None, dict | None]:
+            database.store_data(self.db, rsp.uri, login.did, SERVICE, {"cid": rsp.cid})
+            database.insert_mapping(self.db, repost["id"], internal_id)
+
+    def __delete_repost(
+        self, repost_id: str
+    ) -> tuple[models.AppBskyActorDefs.ProfileViewDetailed | None, dict | None]:
         login = self.__check_login()
-
-        repost = database.find_post(self.db, repost_id, self.input.user_id, self.input.service)
+
+        repost = database.find_post(
+            self.db, repost_id, self.input.user_id, self.input.service
+        )
         if not repost:
             return None, None
-
-        mappings = database.find_mappings(self.db, repost['id'], SERVICE, login.did)
+
+        mappings = database.find_mappings(self.db, repost["id"], SERVICE, login.did)
         if mappings:
             LOGGER.info("Deleting '%s'...", mappings[0][0])
             self.bsky.unrepost(mappings[0][0])
             database.delete_post(self.db, mappings[0][0], login.did, SERVICE)
         return login, repost
-
+
     def delete_repost(self, repost_id: str):
         self.__delete_repost(repost_id)
-
-
+79 -61
cross.py
···
+import re
 from abc import ABC, abstractmethod
-from typing import Callable, Any
+from datetime import datetime, timezone
+from typing import Any, Callable
+
 from util.database import DataBaseWorker
-from datetime import datetime, timezone
 from util.media import MediaInfo
 from util.util import LOGGER, canonical_label
-import re
-ALTERNATE = re.compile(r'\S+|\s+')
+ALTERNATE = re.compile(r"\S+|\s+")
+
 # generic token
-class Token():
+class Token:
     def __init__(self, type: str) -> None:
         self.type = type
+
 class TextToken(Token):
     def __init__(self, text: str) -> None:
-        super().__init__('text')
+        super().__init__("text")
         self.text = text
+
 # token that represents a link to a website. e.g. [link](https://google.com/)
 class LinkToken(Token):
     def __init__(self, href: str, label: str) -> None:
-        super().__init__('link')
+        super().__init__("link")
         self.href = href
         self.label = label
-
-# token that represents a hashtag. e.g. #SocialMedia
+
+
+# token that represents a hashtag. e.g. #SocialMedia
 class TagToken(Token):
     def __init__(self, tag: str) -> None:
-        super().__init__('tag')
+        super().__init__("tag")
         self.tag = tag
+
 # token that represents a mention of a user.
 class MentionToken(Token):
     def __init__(self, username: str, uri: str) -> None:
-        super().__init__('mention')
+        super().__init__("mention")
         self.username = username
         self.uri = uri
-
-class MediaMeta():
+
+
+class MediaMeta:
     def __init__(self, width: int, height: int, duration: float) -> None:
         self.width = width
         self.height = height
         self.duration = duration
-
+
     def get_width(self) -> int:
         return self.width
-
+
     def get_height(self) -> int:
         return self.height
-
+
     def get_duration(self) -> float:
         return self.duration
-
+
+
 class Post(ABC):
     @abstractmethod
     def get_id(self) -> str:
-        return ''
-
+        return ""
+
     @abstractmethod
     def get_parent_id(self) -> str | None:
         pass
-
+
     @abstractmethod
     def get_tokens(self) -> list[Token]:
         pass
···
     @abstractmethod
     def get_text_type(self) -> str:
         pass
-
+
     # post iso timestamp
     @abstractmethod
     def get_timestamp(self) -> str:
         pass
-
+
     def get_attachments(self) -> list[MediaInfo]:
         return []
-
+
     def get_spoiler(self) -> str | None:
         return None
-
+
     def get_languages(self) -> list[str]:
         return []
-
+
     def is_sensitive(self) -> bool:
         return False
-
+
     def get_post_url(self) -> str | None:
         return None
+
 # generic input service.
 # user and service for db queries
-class Input():
-    def __init__(self, service: str, user_id: str, settings: dict, db: DataBaseWorker) -> None:
+class Input:
+    def __init__(
+        self, service: str, user_id: str, settings: dict, db: DataBaseWorker
+    ) -> None:
         self.service = service
         self.user_id = user_id
         self.settings = settings
         self.db = db
-
+
     async def listen(self, outputs: list, handler: Callable[[Post], Any]):
         pass
-class Output():
+
+class Output:
     def __init__(self, input: Input, settings: dict, db: DataBaseWorker) -> None:
         self.input = input
         self.settings = settings
         self.db = db
-
+
     def accept_post(self, post: Post):
         LOGGER.warning('Not Implemented.. "posted" %s', post.get_id())
-
+
     def delete_post(self, identifier: str):
         LOGGER.warning('Not Implemented.. "deleted" %s', identifier)
-
+
     def accept_repost(self, repost_id: str, reposted_id: str):
         LOGGER.warning('Not Implemented.. "reblogged" %s, %s', repost_id, reposted_id)
-
+
     def delete_repost(self, repost_id: str):
         LOGGER.warning('Not Implemented.. "removed reblog" %s', repost_id)
+
 def test_filters(tokens: list[Token], filters: list[re.Pattern[str]]):
     if not tokens or not filters:
         return True
-
-    markdown = ''
-
+
+    markdown = ""
+
     for token in tokens:
         if isinstance(token, TextToken):
             markdown += token.text
         elif isinstance(token, LinkToken):
-            markdown += f'[{token.label}]({token.href})'
+            markdown += f"[{token.label}]({token.href})"
         elif isinstance(token, TagToken):
-            markdown += '#' + token.tag
+            markdown += "#" + token.tag
         elif isinstance(token, MentionToken):
             markdown += token.username
-
+
     for filter in filters:
         if filter.search(markdown):
             return False
-
+
     return True
-def split_tokens(tokens: list[Token], max_chars: int, max_link_len: int = 35) -> list[list[Token]]:
+
+def split_tokens(
+    tokens: list[Token], max_chars: int, max_link_len: int = 35
+) -> list[list[Token]]:
     def new_block():
         nonlocal blocks, block, length
         if block:
             blocks.append(block)
         block = []
         length = 0
-
+
     def append_text(text_segment):
         nonlocal block
         # if the last element in the current block is also text, just append to it
···
             block[-1].text += text_segment
         else:
             block.append(TextToken(text_segment))
-
+
     blocks: list[list[Token]] = []
     block: list[Token] = []
     length = 0
-
+
     for tk in tokens:
         if isinstance(tk, TagToken):
-            tag_len = 1 + len(tk.tag) # (#) + tag
+            tag_len = 1 + len(tk.tag)  # (#) + tag
             if length + tag_len > max_chars:
-                new_block() # create new block if the current one is too large
-
+                new_block()  # create new block if the current one is too large
+
             block.append(tk)
             length += tag_len
-        elif isinstance(tk, LinkToken): # TODO labels should proably be split too
+        elif isinstance(tk, LinkToken):  # TODO labels should proably be split too
             link_len = len(tk.label)
-            if canonical_label(tk.label, tk.href): # cut down the link if the label is canonical
+            if canonical_label(
+                tk.label, tk.href
+            ):  # cut down the link if the label is canonical
                 link_len = min(link_len, max_link_len)
-
+
             if length + link_len > max_chars:
                 new_block()
             block.append(tk)
             length += link_len
         elif isinstance(tk, TextToken):
             segments: list[str] = ALTERNATE.findall(tk.text)
-
+
             for seg in segments:
                 seg_len: int = len(seg)
                 if length + seg_len <= max_chars - (0 if seg.isspace() else 1):
                     append_text(seg)
                     length += seg_len
                     continue
-
+
                 if length > 0:
                     new_block()
-
+
                 if not seg.isspace():
                     while len(seg) > max_chars - 1:
                         chunk = seg[: max_chars - 1] + "-"
···
                         seg = seg[max_chars - 1 :]
                 else:
                     while len(seg) > max_chars:
-                        chunk = seg[: max_chars]
+                        chunk = seg[:max_chars]
                         append_text(chunk)
                         new_block()
-                        seg = seg[max_chars :]
-
+                        seg = seg[max_chars:]
+
                 if seg:
                     append_text(seg)
                     length = len(seg)
-        else: #TODO fix mentions
+        else:  # TODO fix mentions
             block.append(tk)
-
+
     if block:
         blocks.append(block)
-
-    return blocks
+
+    return blocks
+59 -54
main.py
···
-import os
+import asyncio
 import json
-import asyncio, threading, queue, traceback
+import os
+import queue
+import threading
+import traceback
-from util.util import LOGGER, as_json
-import cross, util.database as database
-
+import cross
+import util.database as database
 from bluesky.input import BlueskyJetstreamInput
-from bluesky.output import BlueskyOutputOptions, BlueskyOutput
-
-from mastodon.input import MastodonInputOptions, MastodonInput
+from bluesky.output import BlueskyOutput, BlueskyOutputOptions
+from mastodon.input import MastodonInput, MastodonInputOptions
 from mastodon.output import MastodonOutput
-
 from misskey.input import MisskeyInput
+from util.util import LOGGER, as_json
 DEFAULT_SETTINGS: dict = {
-    'input': {
-        'type': 'mastodon-wss',
-        'instance': 'env:MASTODON_INSTANCE',
-        'token': 'env:MASTODON_TOKEN',
-        "options": MastodonInputOptions({})
+    "input": {
+        "type": "mastodon-wss",
+        "instance": "env:MASTODON_INSTANCE",
+        "token": "env:MASTODON_TOKEN",
+        "options": MastodonInputOptions({}),
     },
-    'outputs': [
+    "outputs": [
         {
-            'type': 'bluesky',
-            'handle': 'env:BLUESKY_HANDLE',
-            'app-password': 'env:BLUESKY_APP_PASSWORD',
-            'options': BlueskyOutputOptions({})
+            "type": "bluesky",
+            "handle": "env:BLUESKY_HANDLE",
+            "app-password": "env:BLUESKY_APP_PASSWORD",
+            "options": BlueskyOutputOptions({}),
         }
-    ]
+    ],
 }
 INPUTS = {
     "mastodon-wss": lambda settings, db: MastodonInput(settings, db),
     "misskey-wss": lambda settigs, db: MisskeyInput(settigs, db),
-    "bluesky-jetstream-wss": lambda settings, db: BlueskyJetstreamInput(settings, db)
+    "bluesky-jetstream-wss": lambda settings, db: BlueskyJetstreamInput(settings, db),
 }
 OUTPUTS = {
     "bluesky": lambda input, settings, db: BlueskyOutput(input, settings, db),
-    "mastodon": lambda input, settings, db: MastodonOutput(input, settings, db)
+    "mastodon": lambda input, settings, db: MastodonOutput(input, settings, db),
 }
+
 def execute(data_dir):
     if not os.path.exists(data_dir):
         os.makedirs(data_dir)
-
-    settings_path = os.path.join(data_dir, 'settings.json')
-    database_path = os.path.join(data_dir, 'data.db')
-
+
+    settings_path = os.path.join(data_dir, "settings.json")
+    database_path = os.path.join(data_dir, "data.db")
+
     if not os.path.exists(settings_path):
         LOGGER.info("First launch detected! Creating %s and exiting!", settings_path)
-
-        with open(settings_path, 'w') as f:
+
+        with open(settings_path, "w") as f:
             f.write(as_json(DEFAULT_SETTINGS, indent=2))
         return 0
-    LOGGER.info('Loading settings...')
-    with open(settings_path, 'rb') as f:
+    LOGGER.info("Loading settings...")
+    with open(settings_path, "rb") as f:
         settings = json.load(f)
-
-    LOGGER.info('Starting database worker...')
+
+    LOGGER.info("Starting database worker...")
     db_worker = database.DataBaseWorker(os.path.abspath(database_path))
-
-    db_worker.execute('PRAGMA foreign_keys = ON;')
-
+
+    db_worker.execute("PRAGMA foreign_keys = ON;")
+
     # create the posts table
     # id - internal id of the post
     # user_id - user id on the service (e.g. a724sknj5y9ydk0w)
···
         );
         """
     )
-
+
     columns = db_worker.execute("PRAGMA table_info(posts)")
     column_names = [col[1] for col in columns]
     if "reposted_id" not in column_names:
···
         ALTER TABLE posts
         ADD COLUMN extra_data TEXT NULL
     """)
-
+
     # create the mappings table
     # original_post_id - the post this was mapped from
     # mapped_post_id - the post this was mapped to
···
         );
         """
     )
-
-    input_settings = settings.get('input')
+
+    input_settings = settings.get("input")
     if not input_settings:
         raise Exception("No input specified!")
-    outputs_settings = settings.get('outputs', [])
-
-    input = INPUTS[input_settings['type']](input_settings, db_worker)
-
+    outputs_settings = settings.get("outputs", [])
+
+    input = INPUTS[input_settings["type"]](input_settings, db_worker)
+
     if not outputs_settings:
         LOGGER.warning("No outputs specified! Check the config!")
-
+
     outputs: list[cross.Output] = []
     for output_settings in outputs_settings:
-        outputs.append(OUTPUTS[output_settings['type']](input, output_settings, db_worker))
-
-    LOGGER.info('Starting task worker...')
+        outputs.append(
+            OUTPUTS[output_settings["type"]](input, output_settings, db_worker)
+        )
+
+    LOGGER.info("Starting task worker...")
+
     def worker(queue: queue.Queue):
         while True:
             task = queue.get()
             if task is None:
                 break
-
+
             try:
                 task()
             except Exception as e:
···
                 traceback.print_exc()
             finally:
                 queue.task_done()
-
+
     task_queue = queue.Queue()
     thread = threading.Thread(target=worker, args=(task_queue,), daemon=True)
     thread.start()
-
-    LOGGER.info('Connecting to %s...', input.service)
+
+    LOGGER.info("Connecting to %s...", input.service)
     try:
         asyncio.run(input.listen(outputs, lambda x: task_queue.put(x)))
     except KeyboardInterrupt:
         LOGGER.info("Stopping...")
-
+
     task_queue.join()
     task_queue.put(None)
     thread.join()
-
+
 if __name__ == "__main__":
-    execute('./data')
+    execute("./data")
+26 -20
mastodon/common.py
···
 import cross
 from util.media import MediaInfo
+
 class MastodonPost(cross.Post):
-    def __init__(self, status: dict, tokens: list[cross.Token], media_attachments: list[MediaInfo]) -> None:
+    def __init__(
+        self,
+        status: dict,
+        tokens: list[cross.Token],
+        media_attachments: list[MediaInfo],
+    ) -> None:
         super().__init__()
-        self.id = status['id']
-        self.parent_id = status.get('in_reply_to_id')
+        self.id = status["id"]
+        self.parent_id = status.get("in_reply_to_id")
         self.tokens = tokens
-        self.content_type = status.get('content_type', 'text/plain')
-        self.timestamp = status['created_at']
+        self.content_type = status.get("content_type", "text/plain")
+        self.timestamp = status["created_at"]
         self.media_attachments = media_attachments
-        self.spoiler = status.get('spoiler_text')
-        self.language = [status['language']] if status.get('language') else []
-        self.sensitive = status.get('sensitive', False)
-        self.url = status.get('url')
-
+        self.spoiler = status.get("spoiler_text")
+        self.language = [status["language"]] if status.get("language") else []
+        self.sensitive = status.get("sensitive", False)
+        self.url = status.get("url")
+
     def get_id(self) -> str:
         return self.id
-
+
     def get_parent_id(self) -> str | None:
         return self.parent_id
-
+
     def get_tokens(self) -> list[cross.Token]:
         return self.tokens
-
+
     def get_text_type(self) -> str:
         return self.content_type
-
+
     def get_timestamp(self) -> str:
         return self.timestamp
-
+
     def get_attachments(self) -> list[MediaInfo]:
         return self.media_attachments
-
+
     def get_spoiler(self) -> str | None:
         return self.spoiler
-
+
     def get_languages(self) -> list[str]:
         return self.language
-
+
     def is_sensitive(self) -> bool:
         return self.sensitive or (self.spoiler is not None)
-
+
     def get_post_url(self) -> str | None:
-        return self.url
+        return self.url
+124 -96
mastodon/input.py
···
-
import requests, websockets
+
import asyncio
import json
import re
-
import asyncio
+
from typing import Any, Callable
-
from mastodon.common import MastodonPost
+
import requests
+
import websockets
+
+
import cross
+
import util.database as database
import util.html_util as html_util
import util.md_util as md_util
-
-
import cross, util.database as database
-
from util.util import LOGGER, as_envvar
-
from util.media import MediaInfo, download_media
+
from mastodon.common import MastodonPost
from util.database import DataBaseWorker
+
from util.media import MediaInfo, download_media
+
from util.util import LOGGER, as_envvar
-
from typing import Callable, Any
+
ALLOWED_VISIBILITY = ["public", "unlisted"]
+
MARKDOWNY = ["text/x.misskeymarkdown", "text/markdown", "text/plain"]
-
ALLOWED_VISIBILITY = ['public', 'unlisted']
-
MARKDOWNY = ['text/x.misskeymarkdown', 'text/markdown', 'text/plain']
-
class MastodonInputOptions():
+
class MastodonInputOptions:
def __init__(self, o: dict) -> None:
self.allowed_visibility = ALLOWED_VISIBILITY
-
self.filters = [re.compile(f) for f in o.get('regex_filters', [])]
-
-
allowed_visibility = o.get('allowed_visibility')
+
self.filters = [re.compile(f) for f in o.get("regex_filters", [])]
+
+
allowed_visibility = o.get("allowed_visibility")
if allowed_visibility is not None:
if any([v not in ALLOWED_VISIBILITY for v in allowed_visibility]):
-
raise ValueError(f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}")
+
raise ValueError(
+
f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}"
+
)
self.allowed_visibility = allowed_visibility
+
class MastodonInput(cross.Input):
def __init__(self, settings: dict, db: DataBaseWorker) -> None:
-
self.options = MastodonInputOptions(settings.get('options', {}))
-
self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required"))
-
instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required"))
-
-
service = instance[:-1] if instance.endswith('/') else instance
-
+
self.options = MastodonInputOptions(settings.get("options", {}))
+
self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw(
+
ValueError("'token' is required")
+
)
+
instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw(
+
ValueError("'instance' is required")
+
)
+
+
service = instance[:-1] if instance.endswith("/") else instance
+
LOGGER.info("Verifying %s credentails...", service)
-
responce = requests.get(f"{service}/api/v1/accounts/verify_credentials", headers={
-
'Authorization': f'Bearer {self.token}'
-
})
+
responce = requests.get(
+
f"{service}/api/v1/accounts/verify_credentials",
+
headers={"Authorization": f"Bearer {self.token}"},
+
)
if responce.status_code != 200:
LOGGER.error("Failed to validate user credentials!")
responce.raise_for_status()
return
-
+
super().__init__(service, responce.json()["id"], settings, db)
self.streaming = self._get_streaming_url()
-
+
if not self.streaming:
raise Exception("Instance %s does not support streaming!", service)
···
response = requests.get(f"{self.service}/api/v1/instance")
response.raise_for_status()
data: dict = response.json()
-
return (data.get('urls') or {}).get('streaming_api')
+
return (data.get("urls") or {}).get("streaming_api")
def __to_tokens(self, status: dict):
-
content_type = status.get('content_type', 'text/plain')
-
raw_text = status.get('text')
-
+
content_type = status.get("content_type", "text/plain")
+
raw_text = status.get("text")
+
tags: list[str] = []
-
for tag in status.get('tags', []):
-
tags.append(tag['name'])
-
+
for tag in status.get("tags", []):
+
tags.append(tag["name"])
+
mentions: list[tuple[str, str]] = []
-
for mention in status.get('mentions', []):
-
mentions.append(('@' + mention['username'], '@' + mention['acct']))
-
+
for mention in status.get("mentions", []):
+
mentions.append(("@" + mention["username"], "@" + mention["acct"]))
+
if raw_text and content_type in MARKDOWNY:
return md_util.tokenize_markdown(raw_text, tags, mentions)
-
-
akkoma_ext: dict | None = status.get('akkoma', {}).get('source')
+
+
akkoma_ext: dict | None = status.get("akkoma", {}).get("source")
if akkoma_ext:
-
if akkoma_ext.get('mediaType') in MARKDOWNY:
+
if akkoma_ext.get("mediaType") in MARKDOWNY:
return md_util.tokenize_markdown(akkoma_ext["content"], tags, mentions)
-
+
tokenizer = html_util.HTMLPostTokenizer()
tokenizer.mentions = mentions
tokenizer.tags = tags
-
tokenizer.feed(status.get('content', ""))
+
tokenizer.feed(status.get("content", ""))
return tokenizer.get_tokens()
-
+
def _on_create_post(self, outputs: list[cross.Output], status: dict):
# skip events from other users
-
if (status.get('account') or {})['id'] != self.user_id:
+
if (status.get("account") or {})["id"] != self.user_id:
return
-
-
if status.get('visibility') not in self.options.allowed_visibility:
+
+
if status.get("visibility") not in self.options.allowed_visibility:
# Skip followers-only and direct posts
-
LOGGER.info("Skipping '%s'! '%s' visibility..", status['id'], status.get('visibility'))
+
LOGGER.info(
+
"Skipping '%s'! '%s' visibility..",
+
status["id"],
+
status.get("visibility"),
+
)
return
-
+
# TODO polls not supported on bsky. maybe 3rd party? skip for now
# we don't handle reblogs. possible with bridgy(?) and self
# we don't handle quotes.
-
if status.get('poll'):
-
LOGGER.info("Skipping '%s'! Contains a poll..", status['id'])
+
if status.get("poll"):
+
LOGGER.info("Skipping '%s'! Contains a poll..", status["id"])
return
-
-
if status.get('quote_id') or status.get('quote'):
-
LOGGER.info("Skipping '%s'! Quote..", status['id'])
+
+
if status.get("quote_id") or status.get("quote"):
+
LOGGER.info("Skipping '%s'! Quote..", status["id"])
return
-
-
reblog: dict | None = status.get('reblog')
+
+
reblog: dict | None = status.get("reblog")
if reblog:
-
if (reblog.get('account') or {})['id'] != self.user_id:
-
LOGGER.info("Skipping '%s'! Reblog of other user..", status['id'])
+
if (reblog.get("account") or {})["id"] != self.user_id:
+
LOGGER.info("Skipping '%s'! Reblog of other user..", status["id"])
return
-
-
success = database.try_insert_repost(self.db, status['id'], reblog['id'], self.user_id, self.service)
+
+
success = database.try_insert_repost(
+
self.db, status["id"], reblog["id"], self.user_id, self.service
+
)
if not success:
-
LOGGER.info("Skipping '%s' as reblogged post was not found in db!", status['id'])
+
LOGGER.info(
+
"Skipping '%s' as reblogged post was not found in db!", status["id"]
+
)
return
-
+
for output in outputs:
-
output.accept_repost(status['id'], reblog['id'])
+
output.accept_repost(status["id"], reblog["id"])
return
-
-
in_reply: str | None = status.get('in_reply_to_id')
-
in_reply_to: str | None = status.get('in_reply_to_account_id')
+
+
in_reply: str | None = status.get("in_reply_to_id")
+
in_reply_to: str | None = status.get("in_reply_to_account_id")
if in_reply_to and in_reply_to != self.user_id:
# We don't support replies.
-
LOGGER.info("Skipping '%s'! Reply to other user..", status['id'])
+
LOGGER.info("Skipping '%s'! Reply to other user..", status["id"])
return
-
-
success = database.try_insert_post(self.db, status['id'], in_reply, self.user_id, self.service)
+
+
success = database.try_insert_post(
+
self.db, status["id"], in_reply, self.user_id, self.service
+
)
if not success:
-
LOGGER.info("Skipping '%s' as parent post was not found in db!", status['id'])
+
LOGGER.info(
+
"Skipping '%s' as parent post was not found in db!", status["id"]
+
)
return
-
+
tokens = self.__to_tokens(status)
if not cross.test_filters(tokens, self.options.filters):
-
LOGGER.info("Skipping '%s'. Matched a filter!", status['id'])
+
LOGGER.info("Skipping '%s'. Matched a filter!", status["id"])
return
-
-
LOGGER.info("Crossposting '%s'...", status['id'])
-
+
+
LOGGER.info("Crossposting '%s'...", status["id"])
+
media_attachments: list[MediaInfo] = []
-
for attachment in status.get('media_attachments', []):
-
LOGGER.info("Downloading %s...", attachment['url'])
-
info = download_media(attachment['url'], attachment.get('description') or '')
+
for attachment in status.get("media_attachments", []):
+
LOGGER.info("Downloading %s...", attachment["url"])
+
info = download_media(
+
attachment["url"], attachment.get("description") or ""
+
)
if not info:
-
LOGGER.error("Skipping '%s'. Failed to download media!", status['id'])
+
LOGGER.error("Skipping '%s'. Failed to download media!", status["id"])
return
media_attachments.append(info)
-
+
cross_post = MastodonPost(status, tokens, media_attachments)
for output in outputs:
output.accept_post(cross_post)
-
+
def _on_delete_post(self, outputs: list[cross.Output], identifier: str):
post = database.find_post(self.db, identifier, self.user_id, self.service)
if not post:
return
-
+
LOGGER.info("Deleting '%s'...", identifier)
-
if post['reposted_id']:
+
if post["reposted_id"]:
for output in outputs:
output.delete_repost(identifier)
else:
for output in outputs:
output.delete_post(identifier)
-
+
database.delete_post(self.db, identifier, self.user_id, self.service)
-
+
def _on_post(self, outputs: list[cross.Output], event: str, payload: str):
match event:
-
case 'update':
+
case "update":
self._on_create_post(outputs, json.loads(payload))
-
case 'delete':
+
case "delete":
self._on_delete_post(outputs, payload)
-
-
async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]):
+
+
async def listen(
+
self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
+
):
uri = f"{self.streaming}/api/v1/streaming?stream=user&access_token={self.token}"
-
-
async for ws in websockets.connect(uri, extra_headers={"User-Agent": "XPost/0.0.3"}):
+
+
async for ws in websockets.connect(
+
uri, extra_headers={"User-Agent": "XPost/0.0.3"}
+
):
try:
LOGGER.info("Listening to %s...", self.streaming)
-
+
async def listen_for_messages():
async for msg in ws:
data = json.loads(msg)
-
event: str = data.get('event')
-
payload: str = data.get('payload')
-
+
event: str = data.get("event")
+
payload: str = data.get("payload")
+
submit(lambda event=event, payload=payload: self._on_post(outputs, str(event), str(payload)))  # bind per-message values eagerly
-
+
listen = asyncio.create_task(listen_for_messages())
-
+
await asyncio.gather(listen)
except websockets.ConnectionClosedError as e:
LOGGER.error(e, stack_info=True, exc_info=True)
LOGGER.info("Reconnecting to %s...", self.streaming)
-
continue
+
continue
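A configuration sketch for MastodonInput (values invented; the surrounding config-file format and the env-var syntax accepted by as_envvar are not shown in this diff):

settings = {
    "instance": "https://mastodon.example",  # trailing slash is stripped
    "token": "TOKEN",                        # resolved through as_envvar()
    "options": {
        "allowed_visibility": ["public"],    # subset of ["public", "unlisted"]
        "regex_filters": [r"#nocrosspost"],  # matching posts are skipped
    },
}
mastodon_in = MastodonInput(settings, db)  # db: util.database.DataBaseWorker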
+252 -214
mastodon/output.py
···
-
import requests, time
+
import time
-
import cross, util.database as database
+
import requests
+
+
import cross
import misskey.mfm_util as mfm_util
-
from util.util import LOGGER, as_envvar, canonical_label
+
import util.database as database
+
from util.database import DataBaseWorker
from util.media import MediaInfo
-
from util.database import DataBaseWorker
+
from util.util import LOGGER, as_envvar, canonical_label
POSSIBLE_MIMES = [
-
'audio/ogg',
-
'audio/mp3',
-
'image/webp',
-
'image/jpeg',
-
'image/png',
-
'video/mp4',
-
'video/quicktime',
-
'video/webm'
+
"audio/ogg",
+
"audio/mp3",
+
"image/webp",
+
"image/jpeg",
+
"image/png",
+
"video/mp4",
+
"video/quicktime",
+
"video/webm",
]
-
TEXT_MIMES = [
-
'text/x.misskeymarkdown',
-
'text/markdown',
-
'text/plain'
-
]
+
TEXT_MIMES = ["text/x.misskeymarkdown", "text/markdown", "text/plain"]
+
+
ALLOWED_POSTING_VISIBILITY = ["public", "unlisted", "private"]
-
ALLOWED_POSTING_VISIBILITY = ['public', 'unlisted', 'private']
-
class MastodonOutputOptions():
+
class MastodonOutputOptions:
def __init__(self, o: dict) -> None:
-
self.visibility = 'public'
-
-
visibility = o.get('visibility')
+
self.visibility = "public"
+
+
visibility = o.get("visibility")
if visibility is not None:
if visibility not in ALLOWED_POSTING_VISIBILITY:
-
raise ValueError(f"'visibility' only accepts {', '.join(ALLOWED_POSTING_VISIBILITY)}, got: {visibility}")
+
raise ValueError(
+
f"'visibility' only accepts {', '.join(ALLOWED_POSTING_VISIBILITY)}, got: {visibility}"
+
)
self.visibility = visibility
+
class MastodonOutput(cross.Output):
def __init__(self, input: cross.Input, settings: dict, db: DataBaseWorker) -> None:
super().__init__(input, settings, db)
-
self.options = settings.get('options') or {}
-
self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required"))
-
instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required"))
-
-
self.service = instance[:-1] if instance.endswith('/') else instance
-
+
self.options = settings.get("options") or {}
+
self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw(
+
ValueError("'token' is required")
+
)
+
instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw(
+
ValueError("'instance' is required")
+
)
+
+
self.service = instance[:-1] if instance.endswith("/") else instance
+
LOGGER.info("Verifying %s credentails...", self.service)
-
responce = requests.get(f"{self.service}/api/v1/accounts/verify_credentials", headers={
-
'Authorization': f'Bearer {self.token}'
-
})
+
responce = requests.get(
+
f"{self.service}/api/v1/accounts/verify_credentials",
+
headers={"Authorization": f"Bearer {self.token}"},
+
)
if responce.status_code != 200:
LOGGER.error("Failed to validate user credentials!")
responce.raise_for_status()
···
self.user_id: str = responce.json()["id"]
LOGGER.info("Getting %s configuration...", self.service)
-
responce = requests.get(f"{self.service}/api/v1/instance", headers={
-
'Authorization': f'Bearer {self.token}'
-
})
+
responce = requests.get(
+
f"{self.service}/api/v1/instance",
+
headers={"Authorization": f"Bearer {self.token}"},
+
)
if responce.status_code != 200:
LOGGER.error("Failed to get instance info!")
responce.raise_for_status()
return
-
+
instance_info: dict = responce.json()
-
configuration: dict = instance_info['configuration']
-
-
statuses_config: dict = configuration.get('statuses', {})
-
self.max_characters: int = statuses_config.get('max_characters', 500)
-
self.max_media_attachments: int = statuses_config.get('max_media_attachments', 4)
-
self.characters_reserved_per_url: int = statuses_config.get('characters_reserved_per_url', 23)
-
-
media_config: dict = configuration.get('media_attachments', {})
-
self.image_size_limit: int = media_config.get('image_size_limit', 16777216)
-
self.video_size_limit: int = media_config.get('video_size_limit', 103809024)
-
self.supported_mime_types: list[str] = media_config.get('supported_mime_types', POSSIBLE_MIMES)
-
+
configuration: dict = instance_info["configuration"]
+
+
statuses_config: dict = configuration.get("statuses", {})
+
self.max_characters: int = statuses_config.get("max_characters", 500)
+
self.max_media_attachments: int = statuses_config.get(
+
"max_media_attachments", 4
+
)
+
self.characters_reserved_per_url: int = statuses_config.get(
+
"characters_reserved_per_url", 23
+
)
+
+
media_config: dict = configuration.get("media_attachments", {})
+
self.image_size_limit: int = media_config.get("image_size_limit", 16777216)
+
self.video_size_limit: int = media_config.get("video_size_limit", 103809024)
+
self.supported_mime_types: list[str] = media_config.get(
+
"supported_mime_types", POSSIBLE_MIMES
+
)
+
# *oma: max post chars
-
max_toot_chars = instance_info.get('max_toot_chars')
+
max_toot_chars = instance_info.get("max_toot_chars")
if max_toot_chars:
self.max_characters: int = max_toot_chars
-
+
# *oma: max upload limit
-
upload_limit = instance_info.get('upload_limit')
+
upload_limit = instance_info.get("upload_limit")
if upload_limit:
self.image_size_limit: int = upload_limit
self.video_size_limit: int = upload_limit
-
+
# chuckya: supported text types
-
chuckya_text_mimes: list[str] = statuses_config.get('supported_mime_types', [])
+
chuckya_text_mimes: list[str] = statuses_config.get("supported_mime_types", [])
self.text_format = next(
-
(mime for mime in TEXT_MIMES if mime in (chuckya_text_mimes)),
-
'text/plain'
+
(mime for mime in TEXT_MIMES if mime in (chuckya_text_mimes)), "text/plain"
)
-
+
# *oma ext: supported text types
-
pleroma = instance_info.get('pleroma')
+
pleroma = instance_info.get("pleroma")
if pleroma:
-
post_formats: list[str] = pleroma.get('metadata', {}).get('post_formats', [])
+
post_formats: list[str] = pleroma.get("metadata", {}).get(
+
"post_formats", []
+
)
self.text_format = next(
-
(mime for mime in TEXT_MIMES if mime in post_formats),
-
self.text_format
+
(mime for mime in TEXT_MIMES if mime in post_formats), self.text_format
)
-
+
def upload_media(self, attachments: list[MediaInfo]) -> list[str] | None:
for a in attachments:
-
if a.mime.startswith('image/') and len(a.io) > self.image_size_limit:
+
if a.mime.startswith("image/") and len(a.io) > self.image_size_limit:
return None
-
-
if a.mime.startswith('video/') and len(a.io) > self.video_size_limit:
+
+
if a.mime.startswith("video/") and len(a.io) > self.video_size_limit:
return None
-
-
if not a.mime.startswith('image/') and not a.mime.startswith('video/'):
+
+
if not a.mime.startswith("image/") and not a.mime.startswith("video/"):
if len(a.io) > 7_000_000:
return None
-
+
uploads: list[dict] = []
for a in attachments:
data = {}
if a.alt:
-
data['description'] = a.alt
-
-
req = requests.post(f"{self.service}/api/v2/media", headers= {
-
'Authorization': f'Bearer {self.token}'
-
}, files={'file': (a.name, a.io, a.mime)}, data=data)
-
+
data["description"] = a.alt
+
+
req = requests.post(
+
f"{self.service}/api/v2/media",
+
headers={"Authorization": f"Bearer {self.token}"},
+
files={"file": (a.name, a.io, a.mime)},
+
data=data,
+
)
+
if req.status_code == 200:
-
LOGGER.info("Uploaded %s! (%s)", a.name, req.json()['id'])
-
uploads.append({
-
'done': True,
-
'id': req.json()['id']
-
})
+
LOGGER.info("Uploaded %s! (%s)", a.name, req.json()["id"])
+
uploads.append({"done": True, "id": req.json()["id"]})
elif req.status_code == 202:
LOGGER.info("Waiting for %s to process!", a.name)
-
uploads.append({
-
'done': False,
-
'id': req.json()['id']
-
})
+
uploads.append({"done": False, "id": req.json()["id"]})
else:
LOGGER.error("Failed to upload %s! %s", a.name, req.text)
req.raise_for_status()
-
-
while any([not val['done'] for val in uploads]):
+
+
while any([not val["done"] for val in uploads]):
LOGGER.info("Waiting for media to process...")
time.sleep(3)
for media in uploads:
-
if media['done']:
+
if media["done"]:
continue
-
-
reqs = requests.get(f'{self.service}/api/v1/media/{media['id']}', headers={
-
'Authorization': f'Bearer {self.token}'
-
})
-
+
+
reqs = requests.get(
+
f"{self.service}/api/v1/media/{media['id']}",
+
headers={"Authorization": f"Bearer {self.token}"},
+
)
+
if reqs.status_code == 206:
continue
-
+
if reqs.status_code == 200:
-
media['done'] = True
+
media["done"] = True
continue
reqs.raise_for_status()
-
-
return [val['id'] for val in uploads]
+
+
return [val["id"] for val in uploads]
def token_to_string(self, tokens: list[cross.Token]) -> str | None:
-
p_text: str = ''
-
+
p_text: str = ""
+
for token in tokens:
if isinstance(token, cross.TextToken):
p_text += token.text
elif isinstance(token, cross.TagToken):
-
p_text += '#' + token.tag
+
p_text += "#" + token.tag
elif isinstance(token, cross.LinkToken):
if canonical_label(token.label, token.href):
p_text += token.href
else:
-
if self.text_format == 'text/plain':
-
p_text += f'{token.label} ({token.href})'
-
elif self.text_format in {'text/x.misskeymarkdown', 'text/markdown'}:
-
p_text += f'[{token.label}]({token.href})'
+
if self.text_format == "text/plain":
+
p_text += f"{token.label} ({token.href})"
+
elif self.text_format in {
+
"text/x.misskeymarkdown",
+
"text/markdown",
+
}:
+
p_text += f"[{token.label}]({token.href})"
else:
return None
-
+
return p_text
def split_tokens_media(self, tokens: list[cross.Token], media: list[MediaInfo]):
-
split_tokens = cross.split_tokens(tokens, self.max_characters, self.characters_reserved_per_url)
+
split_tokens = cross.split_tokens(
+
tokens, self.max_characters, self.characters_reserved_per_url
+
)
post_text: list[str] = []
-
+
for block in split_tokens:
baked_text = self.token_to_string(block)
-
+
if baked_text is None:
return None
post_text.append(baked_text)
-
+
if not post_text:
-
post_text = ['']
-
-
posts: list[dict] = [{"text": post_text, "attachments": []} for post_text in post_text]
+
post_text = [""]
+
+
posts: list[dict] = [
+
{"text": post_text, "attachments": []} for post_text in post_text
+
]
available_indices: list[int] = list(range(len(posts)))
-
+
current_image_post_idx: int | None = None
-
+
def make_blank_post() -> dict:
-
return {
-
"text": '',
-
"attachments": []
-
}
-
+
return {"text": "", "attachments": []}
+
def pop_next_empty_index() -> int:
if available_indices:
return available_indices.pop(0)
···
new_idx = len(posts)
posts.append(make_blank_post())
return new_idx
-
+
for att in media:
if (
current_image_post_idx is not None
-
and len(posts[current_image_post_idx]["attachments"]) < self.max_media_attachments
+
and len(posts[current_image_post_idx]["attachments"])
+
< self.max_media_attachments
):
posts[current_image_post_idx]["attachments"].append(att)
else:
idx = pop_next_empty_index()
posts[idx]["attachments"].append(att)
current_image_post_idx = idx
-
+
result: list[tuple[str, list[MediaInfo]]] = []
-
+
for p in posts:
-
result.append((p['text'], p["attachments"]))
-
+
result.append((p["text"], p["attachments"]))
+
return result
-
+
def accept_post(self, post: cross.Post):
parent_id = post.get_parent_id()
-
+
new_root_id: int | None = None
new_parent_id: int | None = None
-
+
reply_ref: str | None = None
if parent_id:
thread_tuple = database.find_mapped_thread(
···
self.input.user_id,
self.input.service,
self.user_id,
-
self.service
+
self.service,
)
-
+
if not thread_tuple:
LOGGER.error("Failed to find thread tuple in the database!")
return None
-
+
_, reply_ref, new_root_id, new_parent_id = thread_tuple
-
+
lang: str
if post.get_languages():
lang = post.get_languages()[0]
else:
-
lang = 'en'
-
+
lang = "en"
+
post_tokens = post.get_tokens()
if post.get_text_type() == "text/x.misskeymarkdown":
post_tokens, had_mfm = mfm_util.strip_mfm(post_tokens)
post_url = post.get_post_url()
if had_mfm and post_url:
-
post_tokens.append(cross.TextToken('\n'))
-
post_tokens.append(cross.LinkToken(post_url, "[Post contains MFM, see original]"))
-
+
post_tokens.append(cross.TextToken("\n"))
+
post_tokens.append(
+
cross.LinkToken(post_url, "[Post contains MFM, see original]")
+
)
+
raw_statuses = self.split_tokens_media(post_tokens, post.get_attachments())
if not raw_statuses:
LOGGER.error("Failed to split post into statuses?")
return None
baked_statuses = []
-
+
for status, raw_media in raw_statuses:
media: list[str] | None = None
if raw_media:
···
return None
baked_statuses.append((status, media))
continue
-
baked_statuses.append((status,[]))
-
+
baked_statuses.append((status, []))
+
created_statuses: list[str] = []
-
+
for status, media in baked_statuses:
payload = {
-
'status': status,
-
'media_ids': media or [],
-
'spoiler_text': post.get_spoiler() or '',
-
'visibility': self.options.get('visibility', 'public'),
-
'content_type': self.text_format,
-
'language': lang
+
"status": status,
+
"media_ids": media or [],
+
"spoiler_text": post.get_spoiler() or "",
+
"visibility": self.options.get("visibility", "public"),
+
"content_type": self.text_format,
+
"language": lang,
}
-
+
if media:
-
payload['sensitive'] = post.is_sensitive()
-
+
payload["sensitive"] = post.is_sensitive()
+
if post.get_spoiler():
-
payload['sensitive'] = True
-
+
payload["sensitive"] = True
+
if not status:
-
payload['status'] = '🖼️'
-
+
payload["status"] = "🖼️"
+
if reply_ref:
-
payload['in_reply_to_id'] = reply_ref
-
-
reqs = requests.post(f'{self.service}/api/v1/statuses', headers={
-
'Authorization': f'Bearer {self.token}',
-
'Content-Type': 'application/json'
-
}, json=payload)
-
+
payload["in_reply_to_id"] = reply_ref
+
+
reqs = requests.post(
+
f"{self.service}/api/v1/statuses",
+
headers={
+
"Authorization": f"Bearer {self.token}",
+
"Content-Type": "application/json",
+
},
+
json=payload,
+
)
+
if reqs.status_code != 200:
-
LOGGER.info("Failed to post status! %s - %s", reqs.status_code, reqs.text)
+
LOGGER.info(
+
"Failed to post status! %s - %s", reqs.status_code, reqs.text
+
)
reqs.raise_for_status()
-
-
reply_ref = reqs.json()['id']
+
+
reply_ref = reqs.json()["id"]
LOGGER.info("Created new status %s!", reply_ref)
-
-
created_statuses.append(reqs.json()['id'])
-
-
db_post = database.find_post(self.db, post.get_id(), self.input.user_id, self.input.service)
+
+
created_statuses.append(reqs.json()["id"])
+
+
db_post = database.find_post(
+
self.db, post.get_id(), self.input.user_id, self.input.service
+
)
assert db_post, "ghghghhhhh"
-
-
if new_root_id is None or new_parent_id is None:
+
+
if new_root_id is None or new_parent_id is None:
new_root_id = database.insert_post(
-
self.db,
-
created_statuses[0],
-
self.user_id,
-
self.service
+
self.db, created_statuses[0], self.user_id, self.service
)
new_parent_id = new_root_id
-
database.insert_mapping(self.db, db_post['id'], new_parent_id)
+
database.insert_mapping(self.db, db_post["id"], new_parent_id)
created_statuses = created_statuses[1:]
-
+
for db_id in created_statuses:
new_parent_id = database.insert_reply(
-
self.db,
-
db_id,
-
self.user_id,
-
self.service,
-
new_parent_id,
-
new_root_id
+
self.db, db_id, self.user_id, self.service, new_parent_id, new_root_id
)
-
database.insert_mapping(self.db, db_post['id'], new_parent_id)
-
+
database.insert_mapping(self.db, db_post["id"], new_parent_id)
+
def delete_post(self, identifier: str):
-
post = database.find_post(self.db, identifier, self.input.user_id, self.input.service)
+
post = database.find_post(
+
self.db, identifier, self.input.user_id, self.input.service
+
)
if not post:
return
-
-
mappings = database.find_mappings(self.db, post['id'], self.service, self.user_id)
+
+
mappings = database.find_mappings(
+
self.db, post["id"], self.service, self.user_id
+
)
for mapping in mappings[::-1]:
LOGGER.info("Deleting '%s'...", mapping[0])
-
requests.delete(f'{self.service}/api/v1/statuses/{mapping[0]}', headers={
-
'Authorization': f'Bearer {self.token}'
-
})
+
requests.delete(
+
f"{self.service}/api/v1/statuses/{mapping[0]}",
+
headers={"Authorization": f"Bearer {self.token}"},
+
)
database.delete_post(self.db, mapping[0], self.service, self.user_id)
-
+
def accept_repost(self, repost_id: str, reposted_id: str):
repost = self.__delete_repost(repost_id)
if not repost:
return None
-
-
reposted = database.find_post(self.db, reposted_id, self.input.user_id, self.input.service)
+
+
reposted = database.find_post(
+
self.db, reposted_id, self.input.user_id, self.input.service
+
)
if not reposted:
return
-
-
mappings = database.find_mappings(self.db, reposted['id'], self.service, self.user_id)
+
+
mappings = database.find_mappings(
+
self.db, reposted["id"], self.service, self.user_id
+
)
if mappings:
-
rsp = requests.post(f'{self.service}/api/v1/statuses/{mappings[0][0]}/reblog', headers={
-
'Authorization': f'Bearer {self.token}'
-
})
-
+
rsp = requests.post(
+
f"{self.service}/api/v1/statuses/{mappings[0][0]}/reblog",
+
headers={"Authorization": f"Bearer {self.token}"},
+
)
+
if rsp.status_code != 200:
-
LOGGER.error("Failed to boost status! status_code: %s, msg: %s", rsp.status_code, rsp.content)
+
LOGGER.error(
+
"Failed to boost status! status_code: %s, msg: %s",
+
rsp.status_code,
+
rsp.content,
+
)
return
-
+
internal_id = database.insert_repost(
-
self.db,
-
rsp.json()['id'],
-
reposted['id'],
-
self.user_id,
-
self.service)
-
database.insert_mapping(self.db, repost['id'], internal_id)
-
+
self.db, rsp.json()["id"], reposted["id"], self.user_id, self.service
+
)
+
database.insert_mapping(self.db, repost["id"], internal_id)
+
def __delete_repost(self, repost_id: str) -> dict | None:
-
repost = database.find_post(self.db, repost_id, self.input.user_id, self.input.service)
+
repost = database.find_post(
+
self.db, repost_id, self.input.user_id, self.input.service
+
)
if not repost:
return None
-
-
mappings = database.find_mappings(self.db, repost['id'], self.service, self.user_id)
-
reposted_mappings = database.find_mappings(self.db, repost['reposted_id'], self.service, self.user_id)
+
+
mappings = database.find_mappings(
+
self.db, repost["id"], self.service, self.user_id
+
)
+
reposted_mappings = database.find_mappings(
+
self.db, repost["reposted_id"], self.service, self.user_id
+
)
if mappings and reposted_mappings:
LOGGER.info("Deleting '%s'...", mappings[0][0])
-
requests.post(f'{self.service}/api/v1/statuses/{reposted_mappings[0][0]}/unreblog', headers={
-
'Authorization': f'Bearer {self.token}'
-
})
+
requests.post(
+
f"{self.service}/api/v1/statuses/{reposted_mappings[0][0]}/unreblog",
+
headers={"Authorization": f"Bearer {self.token}"},
+
)
database.delete_post(self.db, mappings[0][0], self.user_id, self.service)
return repost
-
+
def delete_repost(self, repost_id: str):
-
self.__delete_repost(repost_id)
+
self.__delete_repost(repost_id)
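The matching output-side sketch (values invented). As shown above, MastodonOutput stores its options as a plain dict, so a missing 'visibility' falls back to "public" without passing through MastodonOutputOptions:

settings = {
    "instance": "https://other.example",
    "token": "TOKEN",
    "options": {"visibility": "unlisted"},  # one of: public, unlisted, private
}
mastodon_out = MastodonOutput(mastodon_in, settings, db)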
+26 -17
misskey/common.py
···
import cross
from util.media import MediaInfo
+
class MisskeyPost(cross.Post):
-
def __init__(self, instance_url: str, note: dict, tokens: list[cross.Token], files: list[MediaInfo]) -> None:
+
def __init__(
+
self,
+
instance_url: str,
+
note: dict,
+
tokens: list[cross.Token],
+
files: list[MediaInfo],
+
) -> None:
super().__init__()
self.note = note
-
self.id = note['id']
-
self.parent_id = note.get('replyId')
+
self.id = note["id"]
+
self.parent_id = note.get("replyId")
self.tokens = tokens
-
self.timestamp = note['createdAt']
+
self.timestamp = note["createdAt"]
self.media_attachments = files
-
self.spoiler = note.get('cw')
-
self.sensitive = any([a.get('isSensitive', False) for a in note.get('files', [])])
-
self.url = instance_url + '/notes/' + note['id']
-
+
self.spoiler = note.get("cw")
+
self.sensitive = any(
+
[a.get("isSensitive", False) for a in note.get("files", [])]
+
)
+
self.url = instance_url + "/notes/" + note["id"]
+
def get_id(self) -> str:
return self.id
-
+
def get_parent_id(self) -> str | None:
return self.parent_id
-
+
def get_tokens(self) -> list[cross.Token]:
return self.tokens
def get_text_type(self) -> str:
return "text/x.misskeymarkdown"
-
+
def get_timestamp(self) -> str:
return self.timestamp
-
+
def get_attachments(self) -> list[MediaInfo]:
return self.media_attachments
-
+
def get_spoiler(self) -> str | None:
return self.spoiler
-
+
def get_languages(self) -> list[str]:
return []
-
+
def is_sensitive(self) -> bool:
return self.sensitive or (self.spoiler is not None)
-
+
def get_post_url(self) -> str | None:
-
return self.url
+
return self.url
+115 -92
misskey/input.py
···
-
import requests, websockets
import asyncio
-
import json, uuid
+
import json
import re
+
import uuid
+
from typing import Any, Callable
-
from misskey.common import MisskeyPost
+
import requests
+
import websockets
-
import cross, util.database as database
+
import cross
+
import util.database as database
import util.md_util as md_util
+
from misskey.common import MisskeyPost
from util.media import MediaInfo, download_media
from util.util import LOGGER, as_envvar
-
from typing import Callable, Any
-
-
ALLOWED_VISIBILITY = ['public', 'home']
-
-
class MisskeyInputOptions():
+
ALLOWED_VISIBILITY = ["public", "home"]
+
+
+
class MisskeyInputOptions:
def __init__(self, o: dict) -> None:
self.allowed_visibility = ALLOWED_VISIBILITY
-
self.filters = [re.compile(f) for f in o.get('regex_filters', [])]
-
-
allowed_visibility = o.get('allowed_visibility')
+
self.filters = [re.compile(f) for f in o.get("regex_filters", [])]
+
+
allowed_visibility = o.get("allowed_visibility")
if allowed_visibility is not None:
if any([v not in ALLOWED_VISIBILITY for v in allowed_visibility]):
-
raise ValueError(f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}")
+
raise ValueError(
+
f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}"
+
)
self.allowed_visibility = allowed_visibility
+
class MisskeyInput(cross.Input):
def __init__(self, settings: dict, db: cross.DataBaseWorker) -> None:
-
self.options = MisskeyInputOptions(settings.get('options', {}))
-
self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required"))
-
instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required"))
-
-
service = instance[:-1] if instance.endswith('/') else instance
-
+
self.options = MisskeyInputOptions(settings.get("options", {}))
+
self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw(
+
ValueError("'token' is required")
+
)
+
instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw(
+
ValueError("'instance' is required")
+
)
+
+
service = instance[:-1] if instance.endswith("/") else instance
+
LOGGER.info("Verifying %s credentails...", service)
-
responce = requests.post(f"{instance}/api/i", json={ 'i': self.token }, headers={
-
"Content-Type": "application/json"
-
})
+
responce = requests.post(
+
f"{instance}/api/i",
+
json={"i": self.token},
+
headers={"Content-Type": "application/json"},
+
)
if responce.status_code != 200:
LOGGER.error("Failed to validate user credentials!")
responce.raise_for_status()
return
-
+
super().__init__(service, responce.json()["id"], settings, db)
-
+
def _on_note(self, outputs: list[cross.Output], note: dict):
-
if note['userId'] != self.user_id:
+
if note["userId"] != self.user_id:
return
-
-
if note.get('visibility') not in self.options.allowed_visibility:
-
LOGGER.info("Skipping '%s'! '%s' visibility..", note['id'], note.get('visibility'))
+
+
if note.get("visibility") not in self.options.allowed_visibility:
+
LOGGER.info(
+
"Skipping '%s'! '%s' visibility..", note["id"], note.get("visibility")
+
)
return
-
+
# TODO polls not supported on bsky. maybe 3rd party? skip for now
# we don't handle reblogs. possible with bridgy(?) and self
-
if note.get('poll'):
-
LOGGER.info("Skipping '%s'! Contains a poll..", note['id'])
+
if note.get("poll"):
+
LOGGER.info("Skipping '%s'! Contains a poll..", note["id"])
return
-
-
renote: dict | None = note.get('renote')
+
+
renote: dict | None = note.get("renote")
if renote:
-
if note.get('text') is not None:
-
LOGGER.info("Skipping '%s'! Quote..", note['id'])
+
if note.get("text") is not None:
+
LOGGER.info("Skipping '%s'! Quote..", note["id"])
return
-
-
if renote.get('userId') != self.user_id:
-
LOGGER.info("Skipping '%s'! Reblog of other user..", note['id'])
+
+
if renote.get("userId") != self.user_id:
+
LOGGER.info("Skipping '%s'! Reblog of other user..", note["id"])
return
-
-
success = database.try_insert_repost(self.db, note['id'], renote['id'], self.user_id, self.service)
+
+
success = database.try_insert_repost(
+
self.db, note["id"], renote["id"], self.user_id, self.service
+
)
if not success:
-
LOGGER.info("Skipping '%s' as renoted note was not found in db!", note['id'])
+
LOGGER.info(
+
"Skipping '%s' as renoted note was not found in db!", note["id"]
+
)
return
-
+
for output in outputs:
-
output.accept_repost(note['id'], renote['id'])
+
output.accept_repost(note["id"], renote["id"])
return
-
-
reply_id: str | None = note.get('replyId')
+
+
reply_id: str | None = note.get("replyId")
if reply_id:
-
if note.get('reply', {}).get('userId') != self.user_id:
-
LOGGER.info("Skipping '%s'! Reply to other user..", note['id'])
+
if note.get("reply", {}).get("userId") != self.user_id:
+
LOGGER.info("Skipping '%s'! Reply to other user..", note["id"])
return
-
-
success = database.try_insert_post(self.db, note['id'], reply_id, self.user_id, self.service)
+
+
success = database.try_insert_post(
+
self.db, note["id"], reply_id, self.user_id, self.service
+
)
if not success:
-
LOGGER.info("Skipping '%s' as parent note was not found in db!", note['id'])
+
LOGGER.info("Skipping '%s' as parent note was not found in db!", note["id"])
return
-
-
mention_handles: dict = note.get('mentionHandles') or {}
-
tags: list[str] = note.get('tags') or []
-
+
+
mention_handles: dict = note.get("mentionHandles") or {}
+
tags: list[str] = note.get("tags") or []
+
handles: list[tuple[str, str]] = []
for value in mention_handles.values():
handles.append((value, value))
-
-
tokens = md_util.tokenize_markdown(note.get('text', ''), tags, handles)
+
+
tokens = md_util.tokenize_markdown(note.get("text", ""), tags, handles)
if not cross.test_filters(tokens, self.options.filters):
-
LOGGER.info("Skipping '%s'. Matched a filter!", note['id'])
+
LOGGER.info("Skipping '%s'. Matched a filter!", note["id"])
return
-
-
LOGGER.info("Crossposting '%s'...", note['id'])
-
+
+
LOGGER.info("Crossposting '%s'...", note["id"])
+
media_attachments: list[MediaInfo] = []
-
for attachment in note.get('files', []):
-
LOGGER.info("Downloading %s...", attachment['url'])
-
info = download_media(attachment['url'], attachment.get('comment') or '')
+
for attachment in note.get("files", []):
+
LOGGER.info("Downloading %s...", attachment["url"])
+
info = download_media(attachment["url"], attachment.get("comment") or "")
if not info:
-
LOGGER.error("Skipping '%s'. Failed to download media!", note['id'])
+
LOGGER.error("Skipping '%s'. Failed to download media!", note["id"])
return
media_attachments.append(info)
-
+
cross_post = MisskeyPost(self.service, note, tokens, media_attachments)
for output in outputs:
output.accept_post(cross_post)
-
+
def _on_delete(self, outputs: list[cross.Output], note: dict):
# TODO handle deletes
pass
-
+
def _on_message(self, outputs: list[cross.Output], data: dict):
-
-
if data['type'] == 'channel':
-
type: str = data['body']['type']
-
if type == 'note' or type == 'reply':
-
note_body = data['body']['body']
+
if data["type"] == "channel":
+
type: str = data["body"]["type"]
+
if type == "note" or type == "reply":
+
note_body = data["body"]["body"]
self._on_note(outputs, note_body)
return
-
+
pass
-
+
async def _send_keepalive(self, ws: websockets.WebSocketClientProtocol):
while ws.open:
try:
···
except Exception as e:
LOGGER.error(f"Error sending keepalive: {e}")
break
-
+
async def _subscribe_to_home(self, ws: websockets.WebSocketClientProtocol):
-
await ws.send(json.dumps({
-
"type": "connect",
-
"body": {
-
"channel": "homeTimeline",
-
"id": str(uuid.uuid4())
-
}
-
}))
+
await ws.send(
+
json.dumps(
+
{
+
"type": "connect",
+
"body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
+
}
+
)
+
)
LOGGER.info("Subscribed to 'homeTimeline' channel...")
-
-
-
async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]):
-
streaming: str = f"wss://{self.service.split("://", 1)[1]}"
+
+
async def listen(
+
self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
+
):
+
streaming: str = f"wss://{self.service.split('://', 1)[1]}"
url: str = f"{streaming}/streaming?i={self.token}"
-
-
async for ws in websockets.connect(url, extra_headers={"User-Agent": "XPost/0.0.3"}):
+
+
async for ws in websockets.connect(
+
url, extra_headers={"User-Agent": "XPost/0.0.3"}
+
):
try:
LOGGER.info("Listening to %s...", streaming)
await self._subscribe_to_home(ws)
-
+
async def listen_for_messages():
async for msg in ws:
# TODO listen to deletes somehow
submit(lambda msg=msg: self._on_message(outputs, json.loads(msg)))  # bind msg eagerly
-
+
keepalive = asyncio.create_task(self._send_keepalive(ws))
listen = asyncio.create_task(listen_for_messages())
-
+
await asyncio.gather(keepalive, listen)
except websockets.ConnectionClosedError as e:
LOGGER.error(e, stack_info=True, exc_info=True)
LOGGER.info("Reconnecting to %s...", streaming)
-
continue
+
continue
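A Misskey-side configuration sketch (values invented); listen() derives the websocket endpoint from the instance host, and only "public" and "home" notes pass the visibility filter:

settings = {
    "instance": "https://misskey.example",
    "token": "TOKEN",
    "options": {"allowed_visibility": ["public", "home"]},
}
misskey_in = MisskeyInput(settings, db)
# listen() connects to wss://misskey.example/streaming?i=<token>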
+8 -5
misskey/mfm_util.py
···
-
import re, cross
+
import re
+
+
import cross
+
+
MFM_PATTERN = re.compile(r"\$\[([^\[\]]+)\]")
-
MFM_PATTERN = re.compile(r'\$\[([^\[\]]+)\]')
def strip_mfm(tokens: list[cross.Token]) -> tuple[list[cross.Token], bool]:
modified = False
···
return tokens, modified
+
def __strip_mfm(text: str) -> str:
def match_contents(match: re.Match[str]):
content = match.group(1).strip()
-
parts = content.split(' ', 1)
-
return parts[1] if len(parts) > 1 else ''
+
parts = content.split(" ", 1)
+
return parts[1] if len(parts) > 1 else ""
while MFM_PATTERN.search(text):
text = MFM_PATTERN.sub(match_contents, text)
return text
-
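A worked example (not in the source): MFM_PATTERN only matches a bracket-free $[effect body] span and match_contents drops the effect name, so the while loop unwraps nested decorations from the inside out:

# __strip_mfm("$[tada $[shake hi]] there")
#   pass 1: "$[shake hi]" -> "hi"   => "$[tada hi] there"
#   pass 2: "$[tada hi]"  -> "hi"   => "hi there"
# A span with no body, e.g. "$[blur]", strips to "".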
+118 -67
util/database.py
···
+
import json
+
import queue
import sqlite3
-
from concurrent.futures import Future
import threading
-
import queue
-
import json
+
from concurrent.futures import Future
-
class DataBaseWorker():
+
+
class DataBaseWorker:
def __init__(self, database: str) -> None:
super(DataBaseWorker, self).__init__()
self.database = database
···
self.conn = sqlite3.connect(self.database, check_same_thread=False)
self.lock = threading.Lock()
self.thread.start()
-
+
def _run(self):
while not self.shutdown_event.is_set():
try:
···
self.queue.task_done()
except queue.Empty:
continue
-
-
def execute(self, sql: str, params = ()):
+
+
def execute(self, sql: str, params=()):
def task(conn: sqlite3.Connection):
cursor = conn.execute(sql, params)
conn.commit()
return cursor.fetchall()
-
+
future = Future()
self.queue.put((task, future))
return future.result()
-
+
def close(self):
self.shutdown_event.set()
self.thread.join()
with self.lock:
self.conn.close()
+
def try_insert_repost(
db: DataBaseWorker,
post_id: str,
reposted_id: str,
input_user: str,
-
input_service: str) -> bool:
-
+
input_service: str,
+
) -> bool:
reposted = find_post(db, reposted_id, input_user, input_service)
if not reposted:
return False
-
-
insert_repost(db, post_id, reposted['id'], input_user, input_service)
+
+
insert_repost(db, post_id, reposted["id"], input_user, input_service)
return True
-
+
def try_insert_post(
-
db: DataBaseWorker,
+
db: DataBaseWorker,
post_id: str,
in_reply: str | None,
input_user: str,
-
input_service: str) -> bool:
+
input_service: str,
+
) -> bool:
root_id = None
parent_id = None
-
+
if in_reply:
parent_post = find_post(db, in_reply, input_user, input_service)
if not parent_post:
return False
-
-
root_id = parent_post['id']
+
+
root_id = parent_post["id"]
parent_id = root_id
-
if parent_post['root_id']:
-
root_id = parent_post['root_id']
-
+
if parent_post["root_id"]:
+
root_id = parent_post["root_id"]
+
if root_id and parent_id:
-
insert_reply(db,post_id, input_user, input_service, parent_id, root_id)
+
insert_reply(db, post_id, input_user, input_service, parent_id, root_id)
else:
insert_post(db, post_id, input_user, input_service)
-
+
return True
-
def insert_repost(db: DataBaseWorker, identifier: str, reposted_id: int, user_id: str, serivce: str) -> int:
+
+
def insert_repost(
+
db: DataBaseWorker, identifier: str, reposted_id: int, user_id: str, serivce: str
+
) -> int:
db.execute(
"""
INSERT INTO posts (user_id, service, identifier, reposted_id)
VALUES (?, ?, ?, ?);
-
""", (user_id, serivce, identifier, reposted_id))
+
""",
+
(user_id, serivce, identifier, reposted_id),
+
)
return db.execute("SELECT last_insert_rowid();", ())[0][0]
+
def insert_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str) -> int:
db.execute(
"""
INSERT INTO posts (user_id, service, identifier)
VALUES (?, ?, ?);
-
""", (user_id, serivce, identifier))
+
""",
+
(user_id, serivce, identifier),
+
)
return db.execute("SELECT last_insert_rowid();", ())[0][0]
-
def insert_reply(db: DataBaseWorker, identifier: str, user_id: str, serivce: str, parent: int, root: int) -> int:
+
+
def insert_reply(
+
db: DataBaseWorker,
+
identifier: str,
+
user_id: str,
+
serivce: str,
+
parent: int,
+
root: int,
+
) -> int:
db.execute(
"""
INSERT INTO posts (user_id, service, identifier, parent_id, root_id)
VALUES (?, ?, ?, ?, ?);
-
""", (user_id, serivce, identifier, parent, root))
+
""",
+
(user_id, serivce, identifier, parent, root),
+
)
return db.execute("SELECT last_insert_rowid();", ())[0][0]
+
def insert_mapping(db: DataBaseWorker, original: int, mapped: int):
-
db.execute("""
+
db.execute(
+
"""
INSERT INTO mappings (original_post_id, mapped_post_id)
VALUES (?, ?);
-
""", (original, mapped))
+
""",
+
(original, mapped),
+
)
+
def delete_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str):
db.execute(
···
WHERE identifier = ?
AND service = ?
AND user_id = ?
-
""", (identifier, serivce, user_id))
-
+
""",
+
(identifier, serivce, user_id),
+
)
+
+
def fetch_data(db: DataBaseWorker, identifier: str, user_id: str, service: str) -> dict:
result = db.execute(
"""
···
WHERE identifier = ?
AND user_id = ?
AND service = ?
-
""", (identifier, user_id, service))
+
""",
+
(identifier, user_id, service),
+
)
if not result or not result[0]:
return {}
return json.loads(result[0][0])
-
def store_data(db: DataBaseWorker, identifier: str, user_id: str, service: str, extra_data: dict) -> None:
+
+
def store_data(
+
db: DataBaseWorker, identifier: str, user_id: str, service: str, extra_data: dict
+
) -> None:
db.execute(
"""
UPDATE posts
···
AND user_id = ?
AND service = ?
""",
-
(json.dumps(extra_data), identifier, user_id, service)
+
(json.dumps(extra_data), identifier, user_id, service),
)
-
def find_mappings(db: DataBaseWorker, original_post: int, service: str, user_id: str) -> list[str]:
+
+
def find_mappings(
+
db: DataBaseWorker, original_post: int, service: str, user_id: str
+
) -> list[str]:
return db.execute(
"""
SELECT p.identifier
···
AND p.user_id = ?
ORDER BY p.id;
""",
-
(original_post, service, user_id))
-
+
(original_post, service, user_id),
+
)
+
+
def find_post_by_id(db: DataBaseWorker, id: int) -> dict | None:
result = db.execute(
"""
SELECT user_id, service, identifier, parent_id, root_id, reposted_id
FROM posts
WHERE id = ?
-
""", (id,))
+
""",
+
(id,),
+
)
if not result:
return None
user_id, service, identifier, parent_id, root_id, reposted_id = result[0]
return {
-
'user_id': user_id,
-
'service': service,
-
'identifier': identifier,
-
'parent_id': parent_id,
-
'root_id': root_id,
-
'reposted_id': reposted_id
+
"user_id": user_id,
+
"service": service,
+
"identifier": identifier,
+
"parent_id": parent_id,
+
"root_id": root_id,
+
"reposted_id": reposted_id,
}
-
def find_post(db: DataBaseWorker, identifier: str, user_id: str, service: str) -> dict | None:
+
+
def find_post(
+
db: DataBaseWorker, identifier: str, user_id: str, service: str
+
) -> dict | None:
result = db.execute(
"""
SELECT id, parent_id, root_id, reposted_id
···
WHERE identifier = ?
AND user_id = ?
AND service = ?
-
""", (identifier, user_id, service))
+
""",
+
(identifier, user_id, service),
+
)
if not result:
return None
id, parent_id, root_id, reposted_id = result[0]
return {
-
'id': id,
-
'parent_id': parent_id,
-
'root_id': root_id,
-
'reposted_id': reposted_id
+
"id": id,
+
"parent_id": parent_id,
+
"root_id": root_id,
+
"reposted_id": reposted_id,
}
+
def find_mapped_thread(
-
db: DataBaseWorker,
+
db: DataBaseWorker,
parent_id: str,
input_user: str,
input_service: str,
output_user: str,
-
output_service: str):
-
+
output_service: str,
+
):
reply_data: dict | None = find_post(db, parent_id, input_user, input_service)
if not reply_data:
return None
-
-
reply_mappings: list[str] | None = find_mappings(db, reply_data['id'], output_service, output_user)
+
+
reply_mappings: list[str] | None = find_mappings(
+
db, reply_data["id"], output_service, output_user
+
)
if not reply_mappings:
return None
-
+
reply_identifier: str = reply_mappings[-1]
root_identifier: str = reply_mappings[0]
-
if reply_data['root_id']:
-
root_data = find_post_by_id(db, reply_data['root_id'])
+
if reply_data["root_id"]:
+
root_data = find_post_by_id(db, reply_data["root_id"])
if not root_data:
return None
-
-
root_mappings = find_mappings(db, reply_data['root_id'], output_service, output_user)
+
+
root_mappings = find_mappings(
+
db, reply_data["root_id"], output_service, output_user
+
)
if not root_mappings:
return None
root_identifier = root_mappings[0]
-
+
return (
-
root_identifier[0], # real ids
+
root_identifier[0], # real ids
reply_identifier[0],
-
reply_data['root_id'], # db ids
-
reply_data['id']
-
)
+
reply_data["root_id"], # db ids
+
reply_data["id"],
+
)
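A usage sketch (filename invented): every statement funnels through the single worker thread that owns the sqlite3 connection, so one DataBaseWorker can be shared across threads. One caveat: the insert_* helpers read last_insert_rowid() in a second queued task, which assumes no other writer slips in between the two.

db = DataBaseWorker("xpost.db")
internal_id = insert_post(db, "note123", "42", "https://mastodon.example")
row = find_post(db, "note123", "42", "https://mastodon.example")
# row -> {"id": internal_id, "parent_id": None, "root_id": None, "reposted_id": None}
db.close()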
+82 -79
util/html_util.py
···
from html.parser import HTMLParser
+
import cross
+
class HTMLPostTokenizer(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.tokens: list[cross.Token] = []
-
+
self.mentions: list[tuple[str, str]]
self.tags: list[str]
-
+
self.in_pre = False
self.in_code = False
-
+
self.current_tag_stack = []
self.list_stack = []
-
+
self.anchor_stack = []
self.anchor_data = []
-
+
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
attrs_dict = dict(attrs)
-
+
def append_newline():
if self.tokens:
last_token = self.tokens[-1]
-
if isinstance(last_token, cross.TextToken) and not last_token.text.endswith('\n'):
-
self.tokens.append(cross.TextToken('\n'))
-
+
if isinstance(
+
last_token, cross.TextToken
+
) and not last_token.text.endswith("\n"):
+
self.tokens.append(cross.TextToken("\n"))
+
match tag:
-
case 'br':
-
self.tokens.append(cross.TextToken(' \n'))
-
case 'a':
-
href = attrs_dict.get('href', '')
+
case "br":
+
self.tokens.append(cross.TextToken(" \n"))
+
case "a":
+
href = attrs_dict.get("href", "")
self.anchor_stack.append(href)
-
case 'strong', 'b':
-
self.tokens.append(cross.TextToken('**'))
-
case 'em', 'i':
-
self.tokens.append(cross.TextToken('*'))
-
case 'del', 's':
-
self.tokens.append(cross.TextToken('~~'))
-
case 'code':
+
case "strong", "b":
+
self.tokens.append(cross.TextToken("**"))
+
case "em", "i":
+
self.tokens.append(cross.TextToken("*"))
+
case "del", "s":
+
self.tokens.append(cross.TextToken("~~"))
+
case "code":
if not self.in_pre:
-
self.tokens.append(cross.TextToken('`'))
+
self.tokens.append(cross.TextToken("`"))
self.in_code = True
-
case 'pre':
+
case "pre":
append_newline()
-
self.tokens.append(cross.TextToken('```\n'))
+
self.tokens.append(cross.TextToken("```\n"))
self.in_pre = True
-
case 'blockquote':
+
case "blockquote":
append_newline()
-
self.tokens.append(cross.TextToken('> '))
-
case 'ul', 'ol':
+
self.tokens.append(cross.TextToken("> "))
+
case "ul", "ol":
self.list_stack.append(tag)
append_newline()
-
case 'li':
-
indent = ' ' * (len(self.list_stack) - 1)
-
if self.list_stack and self.list_stack[-1] == 'ul':
-
self.tokens.append(cross.TextToken(f'{indent}- '))
-
elif self.list_stack and self.list_stack[-1] == 'ol':
-
self.tokens.append(cross.TextToken(f'{indent}1. '))
+
case "li":
+
indent = " " * (len(self.list_stack) - 1)
+
if self.list_stack and self.list_stack[-1] == "ul":
+
self.tokens.append(cross.TextToken(f"{indent}- "))
+
elif self.list_stack and self.list_stack[-1] == "ol":
+
self.tokens.append(cross.TextToken(f"{indent}1. "))
case _:
-
if tag in {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}:
+
if tag in {"h1", "h2", "h3", "h4", "h5", "h6"}:
level = int(tag[1])
self.tokens.append(cross.TextToken("\n" + "#" * level + " "))
-
+
self.current_tag_stack.append(tag)
-
+
def handle_data(self, data: str) -> None:
if self.anchor_stack:
self.anchor_data.append(data)
else:
self.tokens.append(cross.TextToken(data))
-
+
def handle_endtag(self, tag: str) -> None:
if not self.current_tag_stack:
return
-
+
if tag in self.current_tag_stack:
self.current_tag_stack.remove(tag)
-
+
match tag:
-
case 'p':
-
self.tokens.append(cross.TextToken('\n\n'))
-
case 'a':
+
case "p":
+
self.tokens.append(cross.TextToken("\n\n"))
+
case "a":
href = self.anchor_stack.pop()
-
anchor_data = ''.join(self.anchor_data)
+
anchor_data = "".join(self.anchor_data)
self.anchor_data = []
-
-
if anchor_data.startswith('#'):
+
+
if anchor_data.startswith("#"):
as_tag = anchor_data[1:].lower()
if any(as_tag == block for block in self.tags):
self.tokens.append(cross.TagToken(anchor_data[1:]))
-
elif anchor_data.startswith('@'):
+
elif anchor_data.startswith("@"):
match = next(
-
(pair for pair in self.mentions if anchor_data in pair),
-
None
+
(pair for pair in self.mentions if anchor_data in pair), None
)
-
+
if match:
-
self.tokens.append(cross.MentionToken(match[1], ''))
+
self.tokens.append(cross.MentionToken(match[1], ""))
else:
self.tokens.append(cross.LinkToken(href, anchor_data))
-
case 'strong', 'b':
-
self.tokens.append(cross.TextToken('**'))
-
case 'em', 'i':
-
self.tokens.append(cross.TextToken('*'))
-
case 'del', 's':
-
self.tokens.append(cross.TextToken('~~'))
-
case 'code':
+
case "strong", "b":
+
self.tokens.append(cross.TextToken("**"))
+
case "em", "i":
+
self.tokens.append(cross.TextToken("*"))
+
case "del", "s":
+
self.tokens.append(cross.TextToken("~~"))
+
case "code":
if not self.in_pre and self.in_code:
-
self.tokens.append(cross.TextToken('`'))
+
self.tokens.append(cross.TextToken("`"))
self.in_code = False
-
case 'pre':
-
self.tokens.append(cross.TextToken('\n```\n'))
+
case "pre":
+
self.tokens.append(cross.TextToken("\n```\n"))
self.in_pre = False
-
case 'blockquote':
-
self.tokens.append(cross.TextToken('\n'))
-
case 'ul', 'ol':
+
case "blockquote":
+
self.tokens.append(cross.TextToken("\n"))
+
case "ul", "ol":
if self.list_stack:
self.list_stack.pop()
-
self.tokens.append(cross.TextToken('\n'))
-
case 'li':
-
self.tokens.append(cross.TextToken('\n'))
+
self.tokens.append(cross.TextToken("\n"))
+
case "li":
+
self.tokens.append(cross.TextToken("\n"))
case _:
-
if tag in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
-
self.tokens.append(cross.TextToken('\n'))
-
+
if tag in ["h1", "h2", "h3", "h4", "h5", "h6"]:
+
self.tokens.append(cross.TextToken("\n"))
+
def get_tokens(self) -> list[cross.Token]:
if not self.tokens:
return []
-
+
combined: list[cross.Token] = []
buffer: list[str] = []
-
+
def flush_buffer():
if buffer:
-
merged = ''.join(buffer)
+
merged = "".join(buffer)
combined.append(cross.TextToken(text=merged))
buffer.clear()
···
else:
flush_buffer()
combined.append(token)
-
+
flush_buffer()
-
+
if combined and isinstance(combined[-1], cross.TextToken):
-
if combined[-1].text.endswith('\n\n'):
+
if combined[-1].text.endswith("\n\n"):
combined[-1] = cross.TextToken(combined[-1].text[:-2])
return combined
-
+
def reset(self):
"""Reset the parser state for reuse."""
super().reset()
self.tokens = []
-
+
self.mentions = []
self.tags = []
-
+
self.in_pre = False
self.in_code = False
-
+
self.current_tag_stack = []
self.anchor_stack = []
-
self.list_stack = []
+
self.list_stack = []
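A usage sketch (input HTML invented): callers assign mentions and tags before feeding, and get_tokens() merges adjacent text runs into single TextTokens:

tok = HTMLPostTokenizer()
tok.mentions = [("@alice", "@alice@remote.example")]
tok.tags = ["fedi"]
tok.feed('hi <a href="https://a.example/@alice">@alice</a>')
tok.get_tokens()
# roughly: [TextToken("hi "), MentionToken("@alice@remote.example", "")]
tok.reset()  # clear state before reusing the parser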
+44 -33
util/md_util.py
···
import util.html_util as html_util
import util.util as util
-
URL = re.compile(r'(?:(?:[A-Za-z][A-Za-z0-9+.-]*://)|mailto:)[^\s]+', re.IGNORECASE)
-
MD_INLINE_LINK = re.compile(r"\[([^\]]+)\]\(\s*((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s\)]+)\s*\)", re.IGNORECASE)
-
MD_AUTOLINK = re.compile(r"<((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s>]+)>", re.IGNORECASE)
-
HASHTAG = re.compile(r'(?<!\w)\#([\w]+)')
-
FEDIVERSE_HANDLE = re.compile(r'(?<![\w@])@([\w\.-]+)(?:@([\w\.-]+\.[\w\.-]+))?')
+
URL = re.compile(r"(?:(?:[A-Za-z][A-Za-z0-9+.-]*://)|mailto:)[^\s]+", re.IGNORECASE)
+
MD_INLINE_LINK = re.compile(
+
r"\[([^\]]+)\]\(\s*((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s\)]+)\s*\)",
+
re.IGNORECASE,
+
)
+
MD_AUTOLINK = re.compile(
+
r"<((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s>]+)>", re.IGNORECASE
+
)
+
HASHTAG = re.compile(r"(?<!\w)\#([\w]+)")
+
FEDIVERSE_HANDLE = re.compile(r"(?<![\w@])@([\w\.-]+)(?:@([\w\.-]+\.[\w\.-]+))?")
-
def tokenize_markdown(text: str, tags: list[str], handles: list[tuple[str, str]]) -> list[cross.Token]:
+
+
def tokenize_markdown(
+
text: str, tags: list[str], handles: list[tuple[str, str]]
+
) -> list[cross.Token]:
if not text:
return []
-
+
tokenizer = html_util.HTMLPostTokenizer()
tokenizer.mentions = handles
tokenizer.tags = tags
tokenizer.feed(text)
html_tokens = tokenizer.get_tokens()
-
+
tokens: list[cross.Token] = []
-
+
for tk in html_tokens:
if isinstance(tk, cross.TextToken):
tokens.extend(__tokenize_md(tk.text, tags, handles))
···
if not tk.label or util.canonical_label(tk.label, tk.href):
tokens.append(tk)
continue
-
+
tokens.extend(__tokenize_md(f"[{tk.label}]({tk.href})", tags, handles))
else:
tokens.append(tk)
-
+
return tokens
-
+
-
def __tokenize_md(text: str, tags: list[str], handles: list[tuple[str, str]]) -> list[cross.Token]:
+
def __tokenize_md(
+
text: str, tags: list[str], handles: list[tuple[str, str]]
+
) -> list[cross.Token]:
index: int = 0
total: int = len(text)
buffer: list[str] = []
-
+
tokens: list[cross.Token] = []
-
+
def flush():
nonlocal buffer
if buffer:
-
tokens.append(cross.TextToken(''.join(buffer)))
+
tokens.append(cross.TextToken("".join(buffer)))
buffer = []
-
+
while index < total:
-
if text[index] == '[':
+
if text[index] == "[":
md_inline = MD_INLINE_LINK.match(text, index)
if md_inline:
flush()
···
tokens.append(cross.LinkToken(href, label))
index = md_inline.end()
continue
-
-
if text[index] == '<':
+
+
if text[index] == "<":
md_auto = MD_AUTOLINK.match(text, index)
if md_auto:
flush()
···
tokens.append(cross.LinkToken(href, href))
index = md_auto.end()
continue
-
-
if text[index] == '#':
+
+
if text[index] == "#":
tag = HASHTAG.match(text, index)
if tag:
tag_text = tag.group(1)
···
tokens.append(cross.TagToken(tag_text))
index = tag.end()
continue
-
-
if text[index] == '@':
+
+
if text[index] == "@":
handle = FEDIVERSE_HANDLE.match(text, index)
if handle:
handle_text = handle.group(0)
stripped_handle = handle_text.strip()
-
+
match = next(
-
(pair for pair in handles if stripped_handle in pair),
-
None
+
(pair for pair in handles if stripped_handle in pair), None
)
-
+
if match:
flush()
-
tokens.append(cross.MentionToken(match[1], '')) # TODO: misskey doesn’t provide a uri
+
tokens.append(
+
cross.MentionToken(match[1], "")
+
) # TODO: misskey doesn’t provide a uri
index = handle.end()
continue
-
+
url = URL.match(text, index)
if url:
flush()
···
tokens.append(cross.LinkToken(href, href))
index = url.end()
continue
-
+
buffer.append(text[index])
index += 1
-
+
flush()
-
return tokens
+
return tokens
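A worked example (input invented): tokenize_markdown first runs the HTML tokenizer, then re-scans the plain-text runs for inline links, autolinks, hashtags, handles, and bare URLs:

tokenize_markdown("read [this](https://a.example) #fedi", ["fedi"], [])
# roughly: [TextToken("read "), LinkToken("https://a.example", "this"),
#           TextToken(" "), TagToken("fedi")]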
+73 -56
util/media.py
···
+
import json
+
import os
+
import re
+
import subprocess
+
import urllib.parse
+
+
import magic
import requests
-
import subprocess
-
import json
-
import re, urllib.parse, os
+
from util.util import LOGGER
-
import magic
FILENAME = re.compile(r'filename="?([^\";]*)"?')
MAGIC = magic.Magic(mime=True)
-
class MediaInfo():
+
+
class MediaInfo:
def __init__(self, url: str, name: str, mime: str, alt: str, io: bytes) -> None:
self.url = url
self.name = name
···
self.alt = alt
self.io = io
+
def download_media(url: str, alt: str) -> MediaInfo | None:
name = get_filename_from_url(url)
io = download_blob(url, max_bytes=100_000_000)
···
return None
mime = MAGIC.from_buffer(io)
if not mime:
-
mime = 'application/octet-stream'
+
mime = "application/octet-stream"
return MediaInfo(url, name, mime, alt, io)
+
def get_filename_from_url(url):
try:
response = requests.head(url, allow_redirects=True)
-
disposition = response.headers.get('Content-Disposition')
+
disposition = response.headers.get("Content-Disposition")
if disposition:
filename = FILENAME.findall(disposition)
if filename:
···
parsed_url = urllib.parse.urlparse(url)
base_name = os.path.basename(parsed_url.path)
-
+
# hardcoded fix to return the cid for pds
-
if base_name == 'com.atproto.sync.getBlob':
+
if base_name == "com.atproto.sync.getBlob":
qs = urllib.parse.parse_qs(parsed_url.query)
-
if qs and qs.get('cid'):
-
return qs['cid'][0]
+
if qs and qs.get("cid"):
+
return qs["cid"][0]
return base_name
+
def probe_bytes(bytes: bytes) -> dict:
cmd = [
-
'ffprobe',
-
'-v', 'error',
-
'-show_format',
-
'-show_streams',
-
'-print_format', 'json',
-
'pipe:0'
+
"ffprobe",
+
"-v", "error",
+
"-show_format",
+
"-show_streams",
+
"-print_format", "json",
+
"pipe:0",
]
-
proc = subprocess.run(cmd, input=bytes, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
proc = subprocess.run(
+
cmd, input=bytes, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+
)
if proc.returncode != 0:
raise RuntimeError(f"ffprobe failed: {proc.stderr.decode()}")
return json.loads(proc.stdout)
+
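A hypothetical call, assuming ffprobe is on PATH; the result shape is what -show_format/-show_streams with JSON printing emit. Two caveats: the parameter name shadows the builtin bytes, and probing via pipe:0 can fail on MP4s whose moov atom sits at the end of the file, since stdin is not seekable.

with open("clip.mp4", "rb") as f:  # hypothetical input
    info = probe_bytes(f.read())
# info["streams"] -> per-stream dicts: "codec_type", "codec_name", "width", "height", ...
# info["format"]  -> container-level dict: "duration", "size", "bit_rate", ...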
def convert_to_mp4(video_bytes: bytes) -> bytes:
cmd = [
-
'ffmpeg',
-
'-i', 'pipe:0',
-
'-c:v', 'libx264',
-
'-crf', '30',
-
'-preset', 'slow',
-
'-c:a', 'aac',
-
'-b:a', '128k',
-
'-movflags', 'frag_keyframe+empty_moov+default_base_moof',
-
'-f', 'mp4',
-
'pipe:1'
+
"ffmpeg",
+
"-i", "pipe:0",
+
"-c:v", "libx264",
+
"-crf", "30",
+
"-preset", "slow",
+
"-c:a", "aac",
+
"-b:a", "128k",
+
"-movflags", "frag_keyframe+empty_moov+default_base_moof",
+
"-f", "mp4",
+
"pipe:1",
]
-
-
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+
proc = subprocess.Popen(
+
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+
)
out_bytes, err = proc.communicate(input=video_bytes)
-
+
if proc.returncode != 0:
raise RuntimeError(f"ffmpeg compress failed: {err.decode()}")
-
+
return out_bytes
+
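The -movflags value is the load-bearing detail here: ffmpeg's regular MP4 muxer needs seekable output and refuses to write to a pipe, while frag_keyframe+empty_moov+default_base_moof produces a fragmented MP4 that streams through pipe:1 and keeps its moov up front, so the result can also be probed from a pipe. A hypothetical round trip:

with open("in.webm", "rb") as f:  # hypothetical source clip
    mp4 = convert_to_mp4(f.read())
meta = get_media_meta(mp4)        # works because the output is fragmented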
def compress_image(image_bytes: bytes, quality: int = 90):
cmd = [
-
'ffmpeg',
-
'-f', 'image2pipe',
-
'-i', 'pipe:0',
-
'-c:v', 'webp',
-
'-q:v', str(quality),
-
'-f', 'image2pipe',
-
'pipe:1'
-
]
+
"ffmpeg",
+
"-f", "image2pipe",
+
"-i", "pipe:0",
+
"-c:v", "webp",
+
"-q:v", str(quality),
+
"-f", "image2pipe",
+
"pipe:1",
+
]
-
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
proc = subprocess.Popen(
+
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+
)
out_bytes, err = proc.communicate(input=image_bytes)
-
+
if proc.returncode != 0:
raise RuntimeError(f"ffmpeg compress failed: {err.decode()}")
-
+
return out_bytes
+
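A hypothetical before/after size check; -q:v is intended here as the webp quality knob (roughly 0-100, higher meaning better quality and larger output).

with open("photo.png", "rb") as f:  # hypothetical input image
    raw = f.read()
webp = compress_image(raw, quality=80)
print(f"{len(raw)} -> {len(webp)} bytes")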
def download_blob(url: str, max_bytes: int = 5_000_000) -> bytes | None:
response = requests.get(url, stream=True, timeout=20)
if response.status_code != 200:
LOGGER.info("Failed to download %s! %s", url, response.text)
return None
-
+
downloaded_bytes = b""
current_size = 0
-
+
for chunk in response.iter_content(chunk_size=8192):
-
if not chunk:
+
if not chunk:
continue
-
+
current_size += len(chunk)
if current_size > max_bytes:
response.close()
return None
-
+
downloaded_bytes += chunk
-
+
return downloaded_bytes
-
+
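One design note: downloaded_bytes += chunk re-copies the whole buffer on every iteration, which turns quadratic on large files. A sketch of an equivalent accumulator (download_blob_fast is my name, not the project's) that collects chunks in a list and joins once:

def download_blob_fast(url: str, max_bytes: int = 5_000_000) -> bytes | None:
    response = requests.get(url, stream=True, timeout=20)
    if response.status_code != 200:
        return None
    chunks: list[bytes] = []
    size = 0
    for chunk in response.iter_content(chunk_size=8192):
        size += len(chunk)
        if size > max_bytes:
            response.close()
            return None
        chunks.append(chunk)
    return b"".join(chunks)  # one final allocation instead of repeated re-copies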
def get_media_meta(bytes: bytes):
probe = probe_bytes(bytes)
-
streams = [s for s in probe['streams'] if s['codec_type'] == 'video']
+
streams = [s for s in probe["streams"] if s["codec_type"] == "video"]
if not streams:
raise ValueError("No video stream found")
-
+
media = streams[0]
return {
-
'width': int(media['width']),
-
'height': int(media['height']),
-
'duration': float(media.get('duration', probe['format'].get('duration', -1)))
-
}
+
"width": int(media["width"]),
+
"height": int(media["height"]),
+
"duration": float(media.get("duration", probe["format"].get("duration", -1))),
+
}
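Worth noting for callers: duration falls back from the stream to the container and finally to -1, so a negative duration means "unknown", not an error. A hypothetical result:

meta = get_media_meta(mp4)  # mp4 from the convert_to_mp4 sketch above
# e.g. {"width": 1280, "height": 720, "duration": 12.48}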
+20 -13
util/util.py
···
-
import logging, sys, os
import json
+
import logging
+
import os
+
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
LOGGER = logging.getLogger("XPost")
-
def as_json(obj, indent=None,sort_keys=False) -> str:
+
+
def as_json(obj, indent=None, sort_keys=False) -> str:
return json.dumps(
-
obj.__dict__ if not isinstance(obj, dict) else obj,
-
default=lambda o: o.__json__() if hasattr(o, '__json__') else o.__dict__,
+
obj.__dict__ if not isinstance(obj, dict) else obj,
+
default=lambda o: o.__json__() if hasattr(o, "__json__") else o.__dict__,
indent=indent,
-
sort_keys=sort_keys)
+
sort_keys=sort_keys,
+
)
+
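as_json serializes plain objects through __dict__ and defers to a __json__ hook when one exists; a minimal illustration with a made-up class:

class Post:  # hypothetical object
    def __init__(self) -> None:
        self.id = 1
        self.tags = ["art"]

as_json(Post())  # '{"id": 1, "tags": ["art"]}'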
def canonical_label(label: str | None, href: str):
if not label or label == href:
return True
-
-
split = href.split('://', 1)
+
+
split = href.split("://", 1)
if len(split) > 1:
if split[1] == label:
return True
-
+
return False
+
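canonical_label answers "does this label add information beyond the href?"; empty labels, exact matches, and scheme-stripped matches all count as canonical. A quick truth table, with zenfyr.dev as a stand-in URL:

canonical_label(None, "https://zenfyr.dev")                  # True: nothing to display
canonical_label("https://zenfyr.dev", "https://zenfyr.dev")  # True: exact match
canonical_label("zenfyr.dev", "https://zenfyr.dev")          # True: scheme stripped
canonical_label("my site", "https://zenfyr.dev")             # False: custom label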
def safe_get(obj: dict, key: str, default):
val = obj.get(key, default)
return val if val else default
+
def as_envvar(text: str | None) -> str | None:
if not text:
return None
-
-
if text.startswith('env:'):
-
return os.environ.get(text[4:], '')
-
-
return text
+
+
if text.startswith("env:"):
+
return os.environ.get(text[4:], "")
+
+
return text
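The env: prefix lets config values be indirected through the environment instead of stored inline; XPOST_TOKEN below is a made-up variable name.

import os

os.environ["XPOST_TOKEN"] = "hunter2"
as_envvar("env:XPOST_TOKEN")  # "hunter2"
as_envvar("plain value")      # "plain value", passed through unchanged
as_envvar(None)               # None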