import sqlite3
from pathlib import Path

from util.util import LOGGER
from database.connection import get_conn


class DatabaseMigrator:
    """Apply numbered ``.sql`` migration files to a SQLite database.

    Progress is tracked in SQLite's ``PRAGMA user_version``: a migration file is
    applied only when its version number is greater than the stored value.
    Files are expected to be named ``<version>_<description>.sql`` (e.g.
    ``001_create_users.sql``); the integer before the first underscore is the
    version. SQLite initialises ``user_version`` to 0, so versions should start
    at 1 — a file numbered 0 is never considered pending.
    """

    def __init__(self, db_path: Path, migrations_folder: Path) -> None:
        """Open a connection to *db_path* via ``get_conn``.

        Args:
            db_path: Path of the SQLite database file to migrate.
            migrations_folder: Directory scanned for ``*.sql`` migration files.
        """
        self.db_path: Path = db_path
        self.migrations_folder: Path = migrations_folder
        # NOTE(review): apply_migration assumes get_conn returns a connection in
        # the default (legacy) transaction mode, i.e. not opened with
        # ``autocommit=False`` (Python 3.12+) — confirm against get_conn.
        self.conn: sqlite3.Connection = get_conn(db_path)

    def close(self) -> None:
        """Close the underlying database connection."""
        self.conn.close()

    def get_version(self) -> int:
        """Return the schema version stored in ``PRAGMA user_version``."""
        row = self.conn.execute("PRAGMA user_version").fetchone()
        return int(row[0])

    def set_version(self, version: int) -> None:
        """Store *version* in ``PRAGMA user_version`` and commit."""
        # PRAGMA statements cannot take bound parameters, so the value is
        # interpolated; int() guarantees only an integer literal reaches SQL.
        _ = self.conn.execute(f"PRAGMA user_version = {int(version)}")
        self.conn.commit()

    def get_migrations(self) -> list[tuple[int, Path]]:
        """Return ``(version, path)`` for every migration file, oldest first.

        Files whose name does not start with an integer prefix are skipped
        with a warning. Returns an empty list when the migrations folder does
        not exist. Ties between equal version numbers are broken by file name
        so the order does not depend on directory-listing order.
        """
        if not self.migrations_folder.exists():
            return []

        files: list[tuple[int, Path]] = []
        for f in self.migrations_folder.glob("*.sql"):
            # glob() also matches directories named "*.sql"; those can't be read.
            if not f.is_file():
                continue
            try:
                version = int(f.stem.split("_")[0])
            except ValueError:
                LOGGER.warning("Warning: Skipping invalid migration file: %s", f.name)
                continue
            files.append((version, f))

        return sorted(files, key=lambda item: (item[0], item[1].name))

    def apply_migration(self, version: int, path: Path) -> None:
        """Run the SQL script at *path* and record *version*, atomically.

        ``Cursor.executescript`` commits any pending transaction and then does
        no implicit transaction control, so without an explicit transaction
        every statement would autocommit and a failure halfway through could
        not be rolled back. The script and the ``user_version`` bump are
        therefore wrapped in one ``BEGIN … COMMIT``: on failure neither the
        partial schema change nor the new version is kept. As a consequence,
        migration files must not contain their own ``BEGIN``/``COMMIT``.

        Args:
            version: Version number to store once the script succeeds.
            path: Path of the ``.sql`` file to execute (read as UTF-8).

        Raises:
            Exception: If SQLite rejects the script. The original
                ``sqlite3.Error`` is chained as ``__cause__``.
        """
        sql = path.read_text(encoding="utf-8")
        # The newline before ";" terminates a final statement that lacks a
        # semicolon or ends in a "--" line comment; empty statements are no-ops.
        script = (
            "BEGIN;\n"
            f"{sql}\n;\n"
            f"PRAGMA user_version = {int(version)};\n"
            "COMMIT;"
        )
        try:
            _ = self.conn.executescript(script)
        except sqlite3.Error as e:
            # executescript stops at the failing statement, leaving our BEGIN open.
            if self.conn.in_transaction:
                _ = self.conn.execute("ROLLBACK")
            raise Exception(f"Error applying migration {path.name}: {e}") from e
        LOGGER.info("Applied migration: %s", path.name)

    def migrate(self) -> None:
        """Apply every pending migration in ascending version order.

        A migration is pending when its version exceeds the current
        ``user_version``. Migrations run one at a time. If one fails, its
        exception propagates and the remaining ones are not attempted, while
        the migrations that already succeeded stay committed.
        """
        current_version = self.get_version()
        migrations = self.get_migrations()
        if not migrations:
            LOGGER.warning("No migration files found.")
            return

        pending = [(v, p) for v, p in migrations if v > current_version]
        if not pending:
            LOGGER.info("No pending migrations.")
            return

        for version, filepath in pending:
            self.apply_migration(version, filepath)