import sqlite3
from pathlib import Path
from typing import Callable

from database.connection import get_conn
from util.util import LOGGER


class DatabaseMigrator:
    """Apply versioned schema migrations to a SQLite database.

    The current schema version is stored in SQLite's ``PRAGMA user_version``.
    Each migration is a callable that receives the open connection; once it
    has run, ``user_version`` is bumped to the migration's version and the
    transaction is committed.
    """

    def __init__(self, db_path: Path, migrations_folder: Path) -> None:
        """Open a connection to *db_path* and prepare it for migrating.

        Args:
            db_path: Path of the SQLite database file, passed to ``get_conn``.
            migrations_folder: Folder handed to ``load_migrations`` when
                ``migrate`` discovers migrations.
        """
        self.db_path: Path = db_path
        self.migrations_folder: Path = migrations_folder
        self.conn: sqlite3.Connection = get_conn(db_path)
        # PRAGMA foreign_keys is a no-op inside a transaction, so it must run
        # before autocommit is switched off (which opens a transaction).
        # NOTE(review): presumably disabled so migrations can rebuild tables
        # without tripping FK checks — confirm.
        _ = self.conn.execute("PRAGMA foreign_keys = OFF;")
        # PEP 249-style transaction control (Python 3.12+): a transaction is
        # always open and must be closed with commit() or rollback().
        self.conn.autocommit = False

    def close(self) -> None:
        """Close the connection; any uncommitted transaction is rolled back."""
        self.conn.close()

    def get_version(self) -> int:
        """Return the schema version stored in ``PRAGMA user_version``."""
        cursor = self.conn.cursor()
        _ = cursor.execute("PRAGMA user_version")
        return int(cursor.fetchone()[0])

    def set_version(self, version: int) -> None:
        """Store *version* in ``PRAGMA user_version`` and commit.

        Args:
            version: New schema version.

        Raises:
            ValueError / TypeError: if *version* cannot be converted to int.
        """
        cursor = self.conn.cursor()
        # PRAGMA statements do not accept bound parameters, so the value is
        # interpolated; forcing it through int() keeps arbitrary text out of
        # the SQL.
        _ = cursor.execute(f"PRAGMA user_version = {int(version)}")
        self.conn.commit()

    def apply_migration(
        self,
        version: int,
        filename: str,
        migration: Callable[[sqlite3.Connection], None],
    ) -> None:
        """Run one migration and record *version*, atomically.

        Args:
            version: Schema version this migration brings the database to.
            filename: Name used in log and error messages.
            migration: Callable that performs the schema changes on the
                connection.

        Raises:
            Exception: wrapping (and chained from) any ``sqlite3.Error``
                raised while migrating. Any other exception is re-raised
                unchanged. In both cases the transaction is rolled back first.
        """
        try:
            _ = migration(self.conn)
            # set_version() commits, so the migration's changes and the
            # version bump are committed together.
            self.set_version(version)
        except sqlite3.Error as e:
            self.conn.rollback()
            raise Exception(f"Error applying migration {filename}: {e}") from e
        except BaseException:
            # Non-sqlite failures (bugs in the migration, KeyboardInterrupt)
            # must not leave a half-applied transaction open.
            self.conn.rollback()
            raise
        LOGGER.info("Applied migration: %s..", filename)

    def migrate(self) -> None:
        """Apply every migration newer than the current schema version.

        Pending migrations are applied in ascending version order.
        """
        current_version = self.get_version()
        # Imported lazily; NOTE(review): presumably to avoid an import
        # cycle — confirm.
        from migrations._registry import load_migrations

        # Each entry is unpacked below as (version, filename, migration).
        migrations = load_migrations(self.migrations_folder)
        if not migrations:
            LOGGER.warning("No migration files found.")
            return

        # Sort explicitly rather than trusting the registry's order: applying
        # a lower version after a higher one would move user_version backwards.
        pending = sorted(
            (m for m in migrations if m[0] > current_version),
            key=lambda m: m[0],
        )
        if not pending:
            LOGGER.info("No pending migrations.")
            return

        for version, filename, migration in pending:
            self.apply_migration(version, filename, migration)