import json import queue import sqlite3 import threading from concurrent.futures import Future class DataBaseWorker: def __init__(self, database: str) -> None: super(DataBaseWorker, self).__init__() self.database = database self.queue = queue.Queue() self.thread = threading.Thread(target=self._run, daemon=True) self.shutdown_event = threading.Event() self.conn = sqlite3.connect(self.database, check_same_thread=False) self.lock = threading.Lock() self.thread.start() def _run(self): while not self.shutdown_event.is_set(): try: task, future = self.queue.get(timeout=1) try: with self.lock: result = task(self.conn) future.set_result(result) except Exception as e: future.set_exception(e) finally: self.queue.task_done() except queue.Empty: continue def execute(self, sql: str, params=()): def task(conn: sqlite3.Connection): cursor = conn.execute(sql, params) conn.commit() return cursor.fetchall() future = Future() self.queue.put((task, future)) return future.result() def close(self): self.shutdown_event.set() self.thread.join() with self.lock: self.conn.close() def try_insert_repost( db: DataBaseWorker, post_id: str, reposted_id: str, input_user: str, input_service: str, ) -> bool: reposted = find_post(db, reposted_id, input_user, input_service) if not reposted: return False insert_repost(db, post_id, reposted["id"], input_user, input_service) return True def try_insert_post( db: DataBaseWorker, post_id: str, in_reply: str | None, input_user: str, input_service: str, ) -> bool: root_id = None parent_id = None if in_reply: parent_post = find_post(db, in_reply, input_user, input_service) if not parent_post: return False root_id = parent_post["id"] parent_id = root_id if parent_post["root_id"]: root_id = parent_post["root_id"] if root_id and parent_id: insert_reply(db, post_id, input_user, input_service, parent_id, root_id) else: insert_post(db, post_id, input_user, input_service) return True def insert_repost( db: DataBaseWorker, identifier: str, reposted_id: int, user_id: str, serivce: str ) -> int: db.execute( """ INSERT INTO posts (user_id, service, identifier, reposted_id) VALUES (?, ?, ?, ?); """, (user_id, serivce, identifier, reposted_id), ) return db.execute("SELECT last_insert_rowid();", ())[0][0] def insert_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str) -> int: db.execute( """ INSERT INTO posts (user_id, service, identifier) VALUES (?, ?, ?); """, (user_id, serivce, identifier), ) return db.execute("SELECT last_insert_rowid();", ())[0][0] def insert_reply( db: DataBaseWorker, identifier: str, user_id: str, serivce: str, parent: int, root: int, ) -> int: db.execute( """ INSERT INTO posts (user_id, service, identifier, parent_id, root_id) VALUES (?, ?, ?, ?, ?); """, (user_id, serivce, identifier, parent, root), ) return db.execute("SELECT last_insert_rowid();", ())[0][0] def insert_mapping(db: DataBaseWorker, original: int, mapped: int): db.execute( """ INSERT INTO mappings (original_post_id, mapped_post_id) VALUES (?, ?); """, (original, mapped), ) def delete_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str): db.execute( """ DELETE FROM posts WHERE identifier = ? AND service = ? AND user_id = ? """, (identifier, serivce, user_id), ) def fetch_data(db: DataBaseWorker, identifier: str, user_id: str, service: str) -> dict: result = db.execute( """ SELECT extra_data FROM posts WHERE identifier = ? AND user_id = ? AND service = ? """, (identifier, user_id, service), ) if not result or not result[0]: return {} return json.loads(result[0][0]) def store_data( db: DataBaseWorker, identifier: str, user_id: str, service: str, extra_data: dict ) -> None: db.execute( """ UPDATE posts SET extra_data = ? WHERE identifier = ? AND user_id = ? AND service = ? """, (json.dumps(extra_data), identifier, user_id, service), ) def find_mappings( db: DataBaseWorker, original_post: int, service: str, user_id: str ) -> list[str]: return db.execute( """ SELECT p.identifier FROM posts AS p JOIN mappings AS m ON p.id = m.mapped_post_id WHERE m.original_post_id = ? AND p.service = ? AND p.user_id = ? ORDER BY p.id; """, (original_post, service, user_id), ) def find_post_by_id(db: DataBaseWorker, id: int) -> dict | None: result = db.execute( """ SELECT user_id, service, identifier, parent_id, root_id, reposted_id FROM posts WHERE id = ? """, (id,), ) if not result: return None user_id, service, identifier, parent_id, root_id, reposted_id = result[0] return { "user_id": user_id, "service": service, "identifier": identifier, "parent_id": parent_id, "root_id": root_id, "reposted_id": reposted_id, } def find_post( db: DataBaseWorker, identifier: str, user_id: str, service: str ) -> dict | None: result = db.execute( """ SELECT id, parent_id, root_id, reposted_id FROM posts WHERE identifier = ? AND user_id = ? AND service = ? """, (identifier, user_id, service), ) if not result: return None id, parent_id, root_id, reposted_id = result[0] return { "id": id, "parent_id": parent_id, "root_id": root_id, "reposted_id": reposted_id, } def find_mapped_thread( db: DataBaseWorker, parent_id: str, input_user: str, input_service: str, output_user: str, output_service: str, ): reply_data: dict | None = find_post(db, parent_id, input_user, input_service) if not reply_data: return None reply_mappings: list[str] | None = find_mappings( db, reply_data["id"], output_service, output_user ) if not reply_mappings: return None reply_identifier: str = reply_mappings[-1] root_identifier: str = reply_mappings[0] if reply_data["root_id"]: root_data = find_post_by_id(db, reply_data["root_id"]) if not root_data: return None root_mappings = find_mappings( db, reply_data["root_id"], output_service, output_user ) if not root_mappings: return None root_identifier = root_mappings[0] return ( root_identifier[0], # real ids reply_identifier[0], reply_data["root_id"], # db ids reply_data["id"], )