Social media crossposting tool. Third time's the charm.
mastodon
misskey
crossposting
bluesky
1import sqlite3
2from pathlib import Path
3
4from util.util import LOGGER
5from database.connection import get_conn
6
7
class DatabaseMigrator:
    """Apply numbered ``.sql`` migration files to a SQLite database.

    Migration files live in ``migrations_folder`` and are named
    ``<version>_<description>.sql`` (e.g. ``001_init.sql``); the version is
    the integer before the first ``_`` in the file stem. The version the
    database has reached is stored in SQLite's ``user_version`` pragma, and
    :meth:`migrate` applies every file whose version is greater than it, in
    ascending order.

    The instance holds the connection returned by ``get_conn``; call
    :meth:`close` when finished, or use the instance as a context manager.
    """

    def __init__(self, db_path: Path, migrations_folder: Path) -> None:
        self.db_path: Path = db_path
        self.migrations_folder: Path = migrations_folder
        self.conn: sqlite3.Connection = get_conn(db_path)

    def __enter__(self) -> "DatabaseMigrator":
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Always release the connection, whether or not the body raised.
        self.close()

    def close(self) -> None:
        """Close the underlying database connection."""
        self.conn.close()

    def get_version(self) -> int:
        """Return the schema version stored in ``PRAGMA user_version``."""
        cursor = self.conn.cursor()
        _ = cursor.execute("PRAGMA user_version")
        return int(cursor.fetchone()[0])

    def set_version(self, version: int) -> None:
        """Store *version* in ``PRAGMA user_version`` and commit.

        PRAGMA values cannot be supplied as bound parameters, so the value is
        interpolated into the SQL text; coercing through ``int()`` guarantees
        that only an integer literal ever reaches it.
        """
        cursor = self.conn.cursor()
        _ = cursor.execute(f"PRAGMA user_version = {int(version)}")
        self.conn.commit()

    def get_migrations(self) -> list[tuple[int, Path]]:
        """Return ``(version, path)`` pairs for every migration file.

        Files whose name prefix is not an integer are skipped with a warning.
        The result is sorted by version, then by file name, so the order does
        not depend on the filesystem's listing order.

        Returns:
            The sorted pairs, or an empty list if ``migrations_folder`` does
            not exist.
        """
        if not self.migrations_folder.exists():
            return []

        files: list[tuple[int, Path]] = []
        for f in self.migrations_folder.glob("*.sql"):
            if not f.is_file():
                continue
            try:
                version = int(f.stem.partition("_")[0])
            except ValueError:
                LOGGER.warning("Warning: Skipping invalid migration file: %s", f.name)
                continue
            files.append((version, f))

        files.sort(key=lambda item: (item[0], item[1].name))

        # Two files sharing a version are both applied only if neither has run
        # yet; a duplicate added after the first was applied is never picked up
        # (its version is not greater than user_version), so surface it.
        seen: set[int] = set()
        for version, f in files:
            if version in seen:
                LOGGER.warning("Duplicate migration version %d: %s", version, f.name)
            seen.add(version)

        return files

    def apply_migration(self, version: int, path: Path) -> None:
        """Run the SQL script at *path*, then record *version* as applied.

        Raises:
            Exception: if SQLite reports an error while running the script or
                updating the version; the original ``sqlite3.Error`` is
                attached as ``__cause__``.
        """
        sql = path.read_text(encoding="utf-8")

        cursor = self.conn.cursor()
        try:
            # NOTE(review): executescript() COMMITs any pending transaction
            # first and adds no transaction control of its own. Unless the
            # migration file wraps itself in BEGIN/COMMIT, statements that ran
            # before a failure are NOT undone by the rollback() below.
            _ = cursor.executescript(sql)
            self.set_version(version)
        except sqlite3.Error as e:
            self.conn.rollback()
            raise Exception(f"Error applying migration {path.name}: {e}") from e
        LOGGER.info("Applied migration: %s", path.name)

    def migrate(self) -> None:
        """Apply every migration newer than the database's current version.

        Migrations run one at a time in ascending version order. The first
        failure propagates and stops the run, so later migrations stay
        pending (``user_version`` is only advanced after each success).
        """
        current_version = self.get_version()
        migrations = self.get_migrations()

        if not migrations:
            LOGGER.warning("No migration files found.")
            return

        pending = [m for m in migrations if m[0] > current_version]
        if not pending:
            LOGGER.info("No pending migrations.")
            return

        for version, filepath in pending:
            self.apply_migration(version, filepath)