"""Utilities to bootstrap the CalMiner PostgreSQL database. This script is designed to be idempotent. Each step checks the existing state before attempting to modify it so repeated executions are safe. Environment variables (with defaults) used when establishing connections: * ``DATABASE_DRIVER`` (``postgresql``) * ``DATABASE_HOST`` (required) * ``DATABASE_PORT`` (``5432``) * ``DATABASE_NAME`` (required) * ``DATABASE_USER`` (required) * ``DATABASE_PASSWORD`` (optional, required for password auth) * ``DATABASE_SCHEMA`` (``public``) * ``DATABASE_ADMIN_URL`` (overrides individual admin settings) * ``DATABASE_SUPERUSER`` (falls back to ``DATABASE_USER`` or ``postgres``) * ``DATABASE_SUPERUSER_PASSWORD`` (falls back to ``DATABASE_PASSWORD``) * ``DATABASE_SUPERUSER_DB`` (``postgres``) Set ``DATABASE_URL`` if other parts of the application rely on a single connection string; this script will still honor the granular inputs above. """ from __future__ import annotations import argparse import importlib import logging import os import pkgutil import sys from dataclasses import dataclass from pathlib import Path from typing import Callable, Optional, cast from urllib.parse import quote_plus, urlencode import psycopg2 from psycopg2 import errors from psycopg2 import sql from psycopg2 import extensions from psycopg2.extensions import connection as PGConnection, parse_dsn from dotenv import load_dotenv from sqlalchemy import create_engine, inspect ROOT_DIR = Path(__file__).resolve().parents[1] if str(ROOT_DIR) not in sys.path: sys.path.insert(0, str(ROOT_DIR)) from config.database import Base logger = logging.getLogger(__name__) SCRIPTS_DIR = Path(__file__).resolve().parent DEFAULT_MIGRATIONS_DIR = SCRIPTS_DIR / "migrations" MIGRATIONS_TABLE = "schema_migrations" @dataclass(slots=True) class DatabaseConfig: """Configuration required to manage the application database.""" driver: str host: str port: int database: str user: str password: Optional[str] schema: Optional[str] 
admin_user: str admin_password: Optional[str] admin_database: str = "postgres" @classmethod def from_env( cls, overrides: Optional[dict[str, Optional[str]]] = None, ) -> "DatabaseConfig": load_dotenv() override_map: dict[str, Optional[str]] = dict(overrides or {}) def _get(name: str, default: Optional[str] = None) -> Optional[str]: if name in override_map and override_map[name] is not None: return override_map[name] env_value = os.getenv(name) if env_value is not None: return env_value return default driver = _get("DATABASE_DRIVER", "postgresql") host = _get("DATABASE_HOST") port_value = _get("DATABASE_PORT", "5432") database = _get("DATABASE_NAME") user = _get("DATABASE_USER") password = _get("DATABASE_PASSWORD") schema = _get("DATABASE_SCHEMA", "public") try: port = int(port_value) if port_value is not None else 5432 except ValueError as exc: raise RuntimeError( "Invalid DATABASE_PORT value: expected integer, got" f" '{port_value}'" ) from exc admin_url = _get("DATABASE_ADMIN_URL") if admin_url: admin_conninfo = parse_dsn(admin_url) admin_user = admin_conninfo.get("user") or user or "postgres" admin_password = admin_conninfo.get("password") admin_database = admin_conninfo.get("dbname") or "postgres" host = admin_conninfo.get("host") or host port = int(admin_conninfo.get("port") or port) else: admin_user = _get("DATABASE_SUPERUSER", user or "postgres") admin_password = _get("DATABASE_SUPERUSER_PASSWORD", password) admin_database = _get("DATABASE_SUPERUSER_DB", "postgres") missing = [ name for name, value in ( ("DATABASE_HOST", host), ("DATABASE_NAME", database), ("DATABASE_USER", user), ) if not value ] if missing: raise RuntimeError( "Missing required database configuration: " + ", ".join(missing) ) host = cast(str, host) database = cast(str, database) user = cast(str, user) driver = cast(str, driver) admin_user = cast(str, admin_user) admin_database = cast(str, admin_database) return cls( driver=driver, host=host, port=port, database=database, user=user, 
password=password, schema=schema, admin_user=admin_user, admin_password=admin_password, admin_database=admin_database, ) def admin_dsn(self, database: Optional[str] = None) -> str: target_db = database or self.admin_database return self._compose_url( user=self.admin_user, password=self.admin_password, database=target_db, schema=None, ) def application_dsn(self) -> str: """Return a SQLAlchemy URL for connecting as the application role.""" return self._compose_url( user=self.user, password=self.password, database=self.database, schema=self.schema, ) def _compose_url( self, *, user: Optional[str], password: Optional[str], database: str, schema: Optional[str], ) -> str: auth = "" if user: encoded_user = quote_plus(user) if password: encoded_pass = quote_plus(password) auth = f"{encoded_user}:{encoded_pass}@" else: auth = f"{encoded_user}@" host = self.host if ":" in host and not host.startswith("["): host = f"[{host}]" host_port = host if self.port: host_port = f"{host}:{self.port}" url = f"{self.driver}://{auth}{host_port}/{database}" params = {} if schema and schema.strip() and schema != "public": params["options"] = f"-csearch_path={schema}" if params: url = f"{url}?{urlencode(params, quote_via=quote_plus)}" return url class DatabaseSetup: """Encapsulates the full setup workflow.""" def __init__( self, config: DatabaseConfig, *, dry_run: bool = False ) -> None: self.config = config self.dry_run = dry_run self._models_loaded = False self._rollback_actions: list[tuple[str, Callable[[], None]]] = [] def _register_rollback( self, label: str, action: Callable[[], None] ) -> None: if self.dry_run: return self._rollback_actions.append((label, action)) def execute_rollbacks(self) -> None: if not self._rollback_actions: logger.info("No rollback actions registered; nothing to undo.") return logger.warning( "Attempting rollback of %d action(s)", len(self._rollback_actions) ) for label, action in reversed(self._rollback_actions): try: logger.warning("Rollback step: %s", label) 
action() except Exception: logger.exception("Rollback action '%s' failed", label) self._rollback_actions.clear() def clear_rollbacks(self) -> None: self._rollback_actions.clear() def _describe_connection(self, user: str, database: str) -> str: return f"{user}@{self.config.host}:{self.config.port}/{database}" def validate_admin_connection(self) -> None: descriptor = self._describe_connection( self.config.admin_user, self.config.admin_database ) logger.info("Validating admin connection (%s)", descriptor) try: with self._admin_connection(self.config.admin_database) as conn: with conn.cursor() as cursor: cursor.execute("SELECT 1") except psycopg2.Error as exc: raise RuntimeError( "Unable to connect with admin credentials. " "Check DATABASE_ADMIN_URL or DATABASE_SUPERUSER settings." f" Target: {descriptor}" ) from exc logger.info("Admin connection verified (%s)", descriptor) def validate_application_connection(self) -> None: descriptor = self._describe_connection( self.config.user, self.config.database ) logger.info("Validating application connection (%s)", descriptor) try: with self._application_connection() as conn: with conn.cursor() as cursor: cursor.execute("SELECT 1") except psycopg2.Error as exc: raise RuntimeError( "Unable to connect using application credentials. " "Ensure the role exists and credentials are correct. 
" f"Target: {descriptor}" ) from exc logger.info("Application connection verified (%s)", descriptor) def ensure_database(self) -> None: """Create the target database when it does not already exist.""" logger.info("Ensuring database '%s' exists", self.config.database) try: conn = self._admin_connection(self.config.admin_database) except RuntimeError: logger.error( "Could not connect to admin database '%s' while creating '%s'.", self.config.admin_database, self.config.database, ) raise try: conn.autocommit = True conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) cursor = conn.cursor() try: try: cursor.execute( "SELECT 1 FROM pg_database WHERE datname = %s", (self.config.database,), ) except psycopg2.Error as exc: message = ( "Unable to inspect existing databases while ensuring '%s'." " Verify admin permissions." ) % self.config.database logger.error(message) raise RuntimeError(message) from exc exists = cursor.fetchone() is not None if exists: logger.info( "Database '%s' already present", self.config.database ) return if self.dry_run: logger.info( "Dry run: would create database '%s'. Run without --dry-run to proceed.", self.config.database, ) return try: cursor.execute( sql.SQL("CREATE DATABASE {} ENCODING 'UTF8'").format( sql.Identifier(self.config.database) ) ) except psycopg2.Error as exc: message = ( "Failed to create database '%s'. 
Rerun with --dry-run for diagnostics" ) % self.config.database logger.error(message) raise RuntimeError(message) from exc else: rollback_label = f"drop database {self.config.database}" self._register_rollback( rollback_label, lambda db=self.config.database: self._drop_database(db), ) logger.info("Created database '%s'", self.config.database) finally: cursor.close() finally: conn.close() def ensure_role(self) -> None: """Create the application role and assign privileges when missing.""" logger.info("Ensuring role '%s' exists", self.config.user) try: admin_conn = self._admin_connection(self.config.admin_database) except RuntimeError: logger.error( "Unable to connect with admin credentials while ensuring role '%s'", self.config.user, ) raise with admin_conn as conn: conn.autocommit = True with conn.cursor() as cursor: try: cursor.execute( "SELECT 1 FROM pg_roles WHERE rolname = %s", (self.config.user,), ) except psycopg2.Error as exc: message = ( "Unable to inspect existing roles while ensuring role '%s'." " Verify admin permissions." ) % self.config.user logger.error(message) raise RuntimeError(message) from exc role_exists = cursor.fetchone() is not None if not role_exists: logger.info("Creating role '%s'", self.config.user) if self.dry_run: logger.info( "Dry run: would create role '%s'. Run without --dry-run to apply.", self.config.user, ) return try: if self.config.password: cursor.execute( sql.SQL( "CREATE ROLE {} WITH LOGIN PASSWORD %s" ).format(sql.Identifier(self.config.user)), (self.config.password,), ) else: cursor.execute( sql.SQL("CREATE ROLE {} WITH LOGIN").format( sql.Identifier(self.config.user) ) ) except psycopg2.Error as exc: message = ( "Failed to create role '%s'. Review admin privileges and rerun." 
) % self.config.user logger.error(message) raise RuntimeError(message) from exc else: rollback_label = f"drop role {self.config.user}" self._register_rollback( rollback_label, lambda role=self.config.user: self._drop_role(role), ) else: logger.info("Role '%s' already present", self.config.user) try: role_conn = self._admin_connection(self.config.database) except RuntimeError: logger.error( "Unable to connect to application database '%s' while granting privileges to role '%s'", self.config.database, self.config.user, ) raise if self.dry_run: logger.info( "Dry run: would grant privileges on schema/database to role '%s'.", self.config.user, ) return with role_conn as conn: conn.autocommit = True with conn.cursor() as cursor: schema_name = self.config.schema or "public" schema_identifier = sql.Identifier(schema_name) role_identifier = sql.Identifier(self.config.user) try: cursor.execute( sql.SQL("GRANT CONNECT ON DATABASE {} TO {}").format( sql.Identifier(self.config.database), role_identifier, ) ) cursor.execute( sql.SQL("GRANT USAGE ON SCHEMA {} TO {}").format( schema_identifier, role_identifier, ) ) cursor.execute( sql.SQL("GRANT CREATE ON SCHEMA {} TO {}").format( schema_identifier, role_identifier, ) ) cursor.execute( sql.SQL( "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} TO {}" ).format( schema_identifier, role_identifier, ) ) cursor.execute( sql.SQL( "GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} TO {}" ).format( schema_identifier, role_identifier, ) ) cursor.execute( sql.SQL( "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO {}" ).format( schema_identifier, role_identifier, ) ) cursor.execute( sql.SQL( "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT USAGE, SELECT ON SEQUENCES TO {}" ).format( schema_identifier, role_identifier, ) ) except psycopg2.Error as exc: message = ( "Failed to grant privileges to role '%s' in schema '%s'." " Rerun with --dry-run for more context." 
) % (self.config.user, schema_name) logger.error(message) raise RuntimeError(message) from exc logger.info( "Granted privileges on schema '%s' to role '%s'", schema_name, self.config.user, ) rollback_label = f"revoke privileges for {self.config.user}" self._register_rollback( rollback_label, lambda schema=schema_name: self._revoke_role_privileges( schema_name=schema ), ) def ensure_schema(self) -> None: """Create the configured schema when it does not exist.""" schema_name = self.config.schema if not schema_name or schema_name == "public": logger.info("Using default schema 'public'; nothing to ensure") return logger.info("Ensuring schema '%s' exists", schema_name) with self._admin_connection(self.config.database) as conn: conn.autocommit = True with conn.cursor() as cursor: cursor.execute( sql.SQL( "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s" ), (schema_name,), ) exists = cursor.fetchone() is not None if not exists: if self.dry_run: logger.info( "Dry run: would create schema '%s'", schema_name, ) else: cursor.execute( sql.SQL("CREATE SCHEMA {}").format( sql.Identifier(schema_name) ) ) logger.info("Created schema '%s'", schema_name) try: if self.dry_run: logger.info( "Dry run: would set schema '%s' owner to '%s'", schema_name, self.config.user, ) else: cursor.execute( sql.SQL("ALTER SCHEMA {} OWNER TO {}").format( sql.Identifier(schema_name), sql.Identifier(self.config.user), ) ) except errors.UndefinedObject: logger.warning( "Role '%s' not found when assigning ownership to schema '%s'." " Run --ensure-role after creating the schema.", self.config.user, schema_name, ) def application_role_exists(self) -> bool: try: with self._admin_connection(self.config.admin_database) as conn: with conn.cursor() as cursor: try: cursor.execute( "SELECT 1 FROM pg_roles WHERE rolname = %s", (self.config.user,), ) except psycopg2.Error as exc: message = ( "Unable to inspect existing roles while checking for role '%s'." " Verify admin permissions." 
) % self.config.user logger.error(message) raise RuntimeError(message) from exc return cursor.fetchone() is not None except RuntimeError: raise def _admin_connection(self, database: Optional[str] = None) -> PGConnection: target_db = database or self.config.admin_database dsn = self.config.admin_dsn(database) descriptor = self._describe_connection( self.config.admin_user, target_db ) try: return psycopg2.connect(dsn) except psycopg2.Error as exc: raise RuntimeError( "Unable to establish admin connection. " f"Target: {descriptor}" ) from exc def _application_connection(self) -> PGConnection: dsn = self.config.application_dsn() descriptor = self._describe_connection( self.config.user, self.config.database ) try: return psycopg2.connect(dsn) except psycopg2.Error as exc: raise RuntimeError( "Unable to establish application connection. " f"Target: {descriptor}" ) from exc def initialize_schema(self) -> None: """Create database objects from SQLAlchemy metadata if missing.""" self._ensure_models_loaded() logger.info("Ensuring SQLAlchemy metadata is reflected in database") engine = create_engine(self.config.application_dsn(), future=True) try: inspector = inspect(engine) existing_tables = set( inspector.get_table_names(schema=self.config.schema) ) metadata_tables = set(Base.metadata.tables.keys()) missing_tables = sorted(metadata_tables - existing_tables) if missing_tables: logger.info("Pending tables: %s", ", ".join(missing_tables)) else: logger.info("All tables already exist") if self.dry_run: if missing_tables: logger.info("Dry run: skipping creation of pending tables") return Base.metadata.create_all(bind=engine, checkfirst=True) finally: engine.dispose() logger.info("Schema initialization complete") def _ensure_models_loaded(self) -> None: if self._models_loaded: return package = importlib.import_module("models") for module_info in pkgutil.iter_modules(package.__path__): importlib.import_module(f"{package.__name__}.{module_info.name}") self._models_loaded = True def 
run_migrations( self, migrations_dir: Optional[Path | str] = None ) -> None: """Execute pending SQL migrations in chronological order.""" directory = ( Path(migrations_dir) if migrations_dir is not None else DEFAULT_MIGRATIONS_DIR ) directory = directory.resolve() if not directory.exists(): logger.warning("Migrations directory '%s' not found", directory) return migration_files = sorted(directory.glob("*.sql")) if not migration_files: logger.info("No migration scripts found in '%s'", directory) return baseline_name = "000_base.sql" baseline_path = directory / baseline_name schema_name = self.config.schema or "public" with self._application_connection() as conn: conn.autocommit = True with conn.cursor() as cursor: table_exists = self._migrations_table_exists( cursor, schema_name ) if not table_exists: if self.dry_run: logger.info( "Dry run: would create migration history table %s.%s", schema_name, MIGRATIONS_TABLE, ) applied: set[str] = set() else: self._create_migrations_table(cursor, schema_name) logger.info( "Created migration history table %s.%s", schema_name, MIGRATIONS_TABLE, ) applied = set() else: applied = self._fetch_applied_migrations( cursor, schema_name ) if baseline_path.exists() and baseline_name not in applied: if self.dry_run: logger.info( "Dry run: baseline migration '%s' pending; would apply and mark legacy files", baseline_name, ) else: logger.info( "Baseline migration '%s' pending; applying and marking older migrations", baseline_name, ) try: baseline_applied = self._apply_migration_file( cursor, schema_name, baseline_path ) except Exception: logger.error( "Failed while applying baseline migration '%s'." 
" Review the migration contents and rerun with --dry-run for diagnostics.", baseline_name, exc_info=True, ) raise applied.add(baseline_applied) legacy_files = [ path for path in migration_files if path.name != baseline_name ] for legacy in legacy_files: if legacy.name not in applied: try: cursor.execute( sql.SQL( "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" ).format( sql.Identifier( schema_name, MIGRATIONS_TABLE, ) ), (legacy.name,), ) except Exception: logger.error( "Unable to record legacy migration '%s' after baseline application." " Check schema_migrations table in schema '%s' for partial state.", legacy.name, schema_name, exc_info=True, ) raise applied.add(legacy.name) logger.info( "Marked legacy migration '%s' as applied via baseline", legacy.name, ) pending = [ path for path in migration_files if path.name not in applied ] if not pending: logger.info("No pending migrations") return logger.info( "Pending migrations: %s", ", ".join(path.name for path in pending), ) if self.dry_run: logger.info("Dry run: skipping migration execution") return for path in pending: self._apply_migration_file(cursor, schema_name, path) logger.info("Applied %d migrations", len(pending)) def _apply_migration_file( self, cursor, schema_name: str, path: Path, ) -> str: logger.info("Applying migration '%s'", path.name) sql_text = path.read_text(encoding="utf-8") try: cursor.execute(sql_text) cursor.execute( sql.SQL( "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)), (path.name,), ) return path.name except Exception: logger.exception("Failed to apply migration '%s'", path.name) raise def _migrations_table_exists(self, cursor, schema_name: str) -> bool: cursor.execute( """ SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s """, (schema_name, MIGRATIONS_TABLE), ) return cursor.fetchone() is not None def _create_migrations_table(self, cursor, schema_name: str) -> None: 
cursor.execute( sql.SQL( "CREATE TABLE IF NOT EXISTS {} (" "filename TEXT PRIMARY KEY," "applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()" ")" ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)) ) def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]: cursor.execute( sql.SQL("SELECT filename FROM {} ORDER BY filename").format( sql.Identifier(schema_name, MIGRATIONS_TABLE) ) ) return {row[0] for row in cursor.fetchall()} def seed_baseline_data(self, *, dry_run: bool) -> None: """Seed reference data such as currencies.""" from scripts import seed_data seed_args = argparse.Namespace( currencies=True, units=True, defaults=False, dry_run=dry_run, verbose=0, ) seed_data.run_with_namespace(seed_args, config=self.config) if dry_run: logger.info("Dry run: skipped seed verification") return expected_currencies = { code for code, *_ in getattr(seed_data, "CURRENCY_SEEDS", ()) } expected_units = { code for code, *_ in getattr(seed_data, "MEASUREMENT_UNIT_SEEDS", ()) } self._verify_seeded_data( expected_currency_codes=expected_currencies, expected_unit_codes=expected_units, ) def _verify_seeded_data( self, *, expected_currency_codes: set[str], expected_unit_codes: set[str], ) -> None: if not expected_currency_codes and not expected_unit_codes: logger.info("No seed datasets configured for verification") return with self._application_connection() as conn: with conn.cursor() as cursor: if expected_currency_codes: cursor.execute( "SELECT code, is_active FROM currency WHERE code = ANY(%s)", (list(expected_currency_codes),), ) rows = cursor.fetchall() found_codes = {row[0] for row in rows} missing_codes = sorted( expected_currency_codes - found_codes ) if missing_codes: message = ( "Missing expected currencies after seeding: %s. " "Run scripts/seed_data.py --currencies to restore them." 
) % ", ".join(missing_codes) logger.error(message) raise RuntimeError(message) logger.info( "Verified %d seeded currencies present", len(found_codes), ) default_status = next( (row[1] for row in rows if row[0] == "USD"), None ) if default_status is False: message = ( "Default currency 'USD' is inactive after seeding. " "Reactivate it or rerun the seeding command." ) logger.error(message) raise RuntimeError(message) elif default_status is None: message = ( "Default currency 'USD' not found after seeding. " "Ensure baseline migration 000_base.sql ran successfully." ) logger.error(message) raise RuntimeError(message) else: logger.info("Verified default currency 'USD' active") if expected_unit_codes: try: cursor.execute( "SELECT code, is_active FROM measurement_unit WHERE code = ANY(%s)", (list(expected_unit_codes),), ) except errors.UndefinedTable: conn.rollback() message = ( "measurement_unit table not found during seed verification. " "Ensure baseline migration 000_base.sql has been applied." ) logger.error(message) raise RuntimeError(message) else: rows = cursor.fetchall() found_units = {row[0] for row in rows} missing_units = sorted( expected_unit_codes - found_units ) if missing_units: message = ( "Missing expected measurement units after seeding: %s. " "Run scripts/seed_data.py --units to restore them." ) % ", ".join(missing_units) logger.error(message) raise RuntimeError(message) inactive_units = sorted( row[0] for row in rows if not bool(row[1]) ) if inactive_units: message = ( "Measurement units inactive after seeding: %s. " "Reactivate them or rerun unit seeding." 
) % ", ".join(inactive_units) logger.error(message) raise RuntimeError(message) logger.info( "Verified %d measurement units present", len(found_units), ) logger.info("Seed verification complete") def _drop_database(self, database: str) -> None: logger.warning("Rollback: dropping database '%s'", database) with self._admin_connection(self.config.admin_database) as conn: conn.autocommit = True with conn.cursor() as cursor: cursor.execute( "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s", (database,), ) cursor.execute( sql.SQL("DROP DATABASE IF EXISTS {}").format( sql.Identifier(database) ) ) def _drop_role(self, role: str) -> None: logger.warning("Rollback: dropping role '%s'", role) with self._admin_connection(self.config.admin_database) as conn: conn.autocommit = True with conn.cursor() as cursor: cursor.execute( sql.SQL("DROP ROLE IF EXISTS {}").format( sql.Identifier(role) ) ) def _revoke_role_privileges(self, *, schema_name: str) -> None: logger.warning( "Rollback: revoking privileges on schema '%s' for role '%s'", schema_name, self.config.user, ) with self._admin_connection(self.config.database) as conn: conn.autocommit = True with conn.cursor() as cursor: cursor.execute( sql.SQL( "REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" ).format( sql.Identifier(schema_name), sql.Identifier(self.config.user), ) ) cursor.execute( sql.SQL( "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" ).format( sql.Identifier(schema_name), sql.Identifier(self.config.user), ) ) cursor.execute( sql.SQL( "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" ).format( sql.Identifier(schema_name), sql.Identifier(self.config.user), ) ) cursor.execute( sql.SQL( "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" ).format( sql.Identifier(schema_name), sql.Identifier(self.config.user), ) ) def parse_args() -> argparse.Namespace: parser = 
argparse.ArgumentParser(description="Bootstrap CalMiner database") parser.add_argument( "--ensure-database", action="store_true", help="Create the application database when it does not already exist.", ) parser.add_argument( "--ensure-role", action="store_true", help="Create the application role and grant necessary privileges.", ) parser.add_argument( "--ensure-schema", action="store_true", help="Create the configured schema if it does not exist.", ) parser.add_argument( "--initialize-schema", action="store_true", help="Create missing tables based on SQLAlchemy models.", ) parser.add_argument( "--run-migrations", action="store_true", help="Execute pending SQL migrations.", ) parser.add_argument( "--seed-data", action="store_true", help="Seed baseline reference data (currencies, etc.).", ) parser.add_argument( "--migrations-dir", default=None, help="Override the default migrations directory.", ) parser.add_argument("--db-driver", help="Override DATABASE_DRIVER") parser.add_argument("--db-host", help="Override DATABASE_HOST") parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT") parser.add_argument("--db-name", help="Override DATABASE_NAME") parser.add_argument("--db-user", help="Override DATABASE_USER") parser.add_argument("--db-password", help="Override DATABASE_PASSWORD") parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA") parser.add_argument( "--admin-url", help="Override DATABASE_ADMIN_URL for administrative operations", ) parser.add_argument( "--admin-user", help="Override DATABASE_SUPERUSER for admin ops" ) parser.add_argument( "--admin-password", help="Override DATABASE_SUPERUSER_PASSWORD for admin ops", ) parser.add_argument( "--admin-db", help="Override DATABASE_SUPERUSER_DB for admin ops", ) parser.add_argument( "--dry-run", action="store_true", help="Log actions without applying changes.", ) parser.add_argument( "--verbose", "-v", action="count", default=0, help="Increase logging verbosity", ) return parser.parse_args() 
def main() -> None:
    """CLI entry point: run the requested setup steps in dependency order.

    Steps execute in the order: database, role, schema, SQLAlchemy schema,
    migrations, seed data.  On failure, registered rollback actions are
    executed (unless --dry-run) and the exception is re-raised.
    """
    args = parse_args()
    # -v => INFO, -vv => DEBUG.  FIX: the previous floor of logging.INFO made
    # the second -v a no-op; only floor at DEBUG so -vv actually enables it.
    level = logging.WARNING - (10 * min(args.verbose, 2))
    logging.basicConfig(
        level=max(level, logging.DEBUG), format="%(levelname)s %(message)s"
    )
    override_args: dict[str, Optional[str]] = {
        "DATABASE_DRIVER": args.db_driver,
        "DATABASE_HOST": args.db_host,
        "DATABASE_NAME": args.db_name,
        "DATABASE_USER": args.db_user,
        "DATABASE_PASSWORD": args.db_password,
        "DATABASE_SCHEMA": args.db_schema,
        "DATABASE_ADMIN_URL": args.admin_url,
        "DATABASE_SUPERUSER": args.admin_user,
        "DATABASE_SUPERUSER_PASSWORD": args.admin_password,
        "DATABASE_SUPERUSER_DB": args.admin_db,
    }
    if args.db_port is not None:
        override_args["DATABASE_PORT"] = str(args.db_port)

    config = DatabaseConfig.from_env(overrides=override_args)
    setup = DatabaseSetup(config, dry_run=args.dry_run)

    admin_tasks_requested = (
        args.ensure_database or args.ensure_role or args.ensure_schema
    )
    if admin_tasks_requested:
        setup.validate_admin_connection()

    # Validate the application connection lazily, at most once; in dry runs
    # skip it entirely when the role does not exist yet.
    app_validated = False

    def ensure_application_connection_for(operation: str) -> bool:
        nonlocal app_validated
        if app_validated:
            return True
        if setup.dry_run and not setup.application_role_exists():
            logger.info(
                "Dry run: skipping %s because application role '%s' does not exist yet.",
                operation,
                setup.config.user,
            )
            return False
        setup.validate_application_connection()
        app_validated = True
        return True

    # Seeding depends on the schema being migrated; imply --run-migrations.
    should_run_migrations = args.run_migrations
    auto_run_migrations_reason: Optional[str] = None
    if args.seed_data and not should_run_migrations:
        should_run_migrations = True
        auto_run_migrations_reason = "Seed data requested without explicit --run-migrations; applying migrations first."

    try:
        if args.ensure_database:
            setup.ensure_database()
        if args.ensure_role:
            setup.ensure_role()
        if args.ensure_schema:
            setup.ensure_schema()
        if args.initialize_schema:
            if ensure_application_connection_for(
                "SQLAlchemy schema initialization"
            ):
                setup.initialize_schema()
        if should_run_migrations:
            if ensure_application_connection_for("migration execution"):
                if auto_run_migrations_reason:
                    logger.info(auto_run_migrations_reason)
                migrations_path = (
                    Path(args.migrations_dir) if args.migrations_dir else None
                )
                setup.run_migrations(migrations_path)
        if args.seed_data:
            if ensure_application_connection_for("baseline data seeding"):
                setup.seed_baseline_data(dry_run=args.dry_run)
    except Exception:
        if not setup.dry_run:
            setup.execute_rollbacks()
        raise
    finally:
        if not setup.dry_run:
            setup.clear_rollbacks()


if __name__ == "__main__":
    main()