social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import sqlite3
2from pathlib import Path
3from typing import Callable
4
5from database.connection import get_conn
6from util.util import LOGGER
7
class DatabaseMigrator:
    """Apply versioned migrations to a SQLite database.

    The schema version is tracked in SQLite's ``PRAGMA user_version``.
    :meth:`migrate` loads the available migrations and applies, in ascending
    version order, every one whose version is greater than the stored value.

    The instance can be used as a context manager; the connection is closed
    on exit.
    """

    def __init__(self, db_path: Path, migrations_folder: Path) -> None:
        """Open the database and prepare the connection for migrating.

        Args:
            db_path: Database location, passed to ``get_conn``.
            migrations_folder: Folder later handed to ``load_migrations``.
        """
        self.db_path: Path = db_path
        self.migrations_folder: Path = migrations_folder
        self.conn: sqlite3.Connection = get_conn(db_path)
        # Issued before autocommit is disabled: SQLite documents this pragma
        # as a no-op while a transaction is open, and with autocommit=False
        # the sqlite3 module keeps a transaction open at all times.
        _ = self.conn.execute("PRAGMA foreign_keys = OFF;")
        self.conn.autocommit = False

    def __enter__(self) -> "DatabaseMigrator":
        """Return the migrator itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection; exceptions from the body are not suppressed."""
        self.close()

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def get_version(self) -> int:
        """Return the value currently stored in ``PRAGMA user_version``."""
        row = self.conn.execute("PRAGMA user_version").fetchone()
        return int(row[0])

    def set_version(self, version: int):
        """Store *version* in ``PRAGMA user_version`` and commit.

        Raises:
            ValueError: If *version* cannot be converted to an integer.
        """
        # PRAGMA statements cannot take bound parameters, so the value is
        # interpolated; coercing to int guarantees only a number is inserted.
        version = int(version)
        _ = self.conn.execute(f"PRAGMA user_version = {version}")
        self.conn.commit()

    def apply_migration(self, version: int, filename: str, migration: Callable[[sqlite3.Connection], None]):
        """Run one migration and record its version.

        Args:
            version: Version number stored after the migration succeeds.
            filename: Name used in log and error messages.
            migration: Callable invoked with the open connection.

        Raises:
            Exception: If a ``sqlite3.Error`` occurs; the original error is
                attached as ``__cause__``.
            BaseException: Any other exception raised by the migration is
                re-raised unchanged.

        In both failure cases the pending transaction is rolled back first.
        """
        try:
            _ = migration(self.conn)
            self.set_version(version)
            self.conn.commit()
        except sqlite3.Error as e:
            self.conn.rollback()
            raise Exception(f"Error applying migration {filename}: {e}") from e
        except BaseException:
            # A non-SQLite failure must not leave partial changes pending,
            # otherwise a later commit would persist a half-applied migration.
            self.conn.rollback()
            raise
        LOGGER.info("Applied migration: %s..", filename)

    def migrate(self):
        """Apply every migration newer than the stored version.

        Migrations come from ``load_migrations(self.migrations_folder)`` and
        are handled as ``(version, filename, migration)`` tuples. Logs and
        returns without changes when none are found or none are pending.
        """
        current_version = self.get_version()
        from migrations._registry import load_migrations
        migrations = load_migrations(self.migrations_folder)

        if not migrations:
            LOGGER.warning("No migration files found.")
            return

        # Sort explicitly so the stored version only ever increases,
        # whatever order the registry returns its entries in.
        pending = sorted(
            (m for m in migrations if m[0] > current_version),
            key=lambda m: m[0],
        )
        if not pending:
            LOGGER.info("No pending migrations.")
            return

        for version, filename, migration in pending:
            self.apply_migration(version, filename, migration)