"""Utilities to bootstrap the CalMiner PostgreSQL database.
|
|
|
|
This script is designed to be idempotent. Each step checks the existing
|
|
state before attempting to modify it so repeated executions are safe.
|
|
|
|
Environment variables (with defaults) used when establishing connections:
|
|
|
|
* ``DATABASE_DRIVER`` (``postgresql``)
|
|
* ``DATABASE_HOST`` (required)
|
|
* ``DATABASE_PORT`` (``5432``)
|
|
* ``DATABASE_NAME`` (required)
|
|
* ``DATABASE_USER`` (required)
|
|
* ``DATABASE_PASSWORD`` (optional, required for password auth)
|
|
* ``DATABASE_SCHEMA`` (``public``)
|
|
* ``DATABASE_ADMIN_URL`` (overrides individual admin settings)
|
|
* ``DATABASE_SUPERUSER`` (falls back to ``DATABASE_USER`` or ``postgres``)
|
|
* ``DATABASE_SUPERUSER_PASSWORD`` (falls back to ``DATABASE_PASSWORD``)
|
|
* ``DATABASE_SUPERUSER_DB`` (``postgres``)
|
|
|
|
Set ``DATABASE_URL`` if other parts of the application rely on a single
|
|
connection string; this script will still honor the granular inputs above.
|
|
"""

from __future__ import annotations

import argparse
import importlib
import logging
import os
import pkgutil
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional, cast
from urllib.parse import quote_plus, urlencode

import psycopg2
from psycopg2 import errors
from psycopg2 import sql
from psycopg2 import extensions
from psycopg2.extensions import connection as PGConnection, parse_dsn
from dotenv import load_dotenv
from sqlalchemy import create_engine, inspect

ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

# Imported after the sys.path tweak above so the project package resolves
# when the script is executed directly.
from config.database import Base


logger = logging.getLogger(__name__)

SCRIPTS_DIR = Path(__file__).resolve().parent
DEFAULT_MIGRATIONS_DIR = SCRIPTS_DIR / "migrations"
MIGRATIONS_TABLE = "schema_migrations"


@dataclass(slots=True)
class DatabaseConfig:
    """Configuration required to manage the application database."""

    driver: str
    host: str
    port: int
    database: str
    user: str
    password: Optional[str]
    schema: Optional[str]

    admin_user: str
    admin_password: Optional[str]
    admin_database: str = "postgres"

    @classmethod
    def from_env(
        cls,
        overrides: Optional[dict[str, Optional[str]]] = None,
    ) -> "DatabaseConfig":
        load_dotenv()

        override_map: dict[str, Optional[str]] = dict(overrides or {})

        def _get(name: str, default: Optional[str] = None) -> Optional[str]:
            # Explicit overrides win, then the environment, then the default.
            if name in override_map and override_map[name] is not None:
                return override_map[name]
            env_value = os.getenv(name)
            if env_value is not None:
                return env_value
            return default

        driver = _get("DATABASE_DRIVER", "postgresql")
        host = _get("DATABASE_HOST")
        port_value = _get("DATABASE_PORT", "5432")
        database = _get("DATABASE_NAME")
        user = _get("DATABASE_USER")
        password = _get("DATABASE_PASSWORD")
        schema = _get("DATABASE_SCHEMA", "public")

        try:
            port = int(port_value) if port_value is not None else 5432
        except ValueError as exc:
            raise RuntimeError(
                "Invalid DATABASE_PORT value: expected integer, got"
                f" '{port_value}'"
            ) from exc

        admin_url = _get("DATABASE_ADMIN_URL")
        if admin_url:
            # A full admin DSN overrides the individual superuser settings.
            admin_conninfo = parse_dsn(admin_url)
            admin_user = admin_conninfo.get("user") or user or "postgres"
            admin_password = admin_conninfo.get("password")
            admin_database = admin_conninfo.get("dbname") or "postgres"
            host = admin_conninfo.get("host") or host
            port = int(admin_conninfo.get("port") or port)
        else:
            admin_user = _get("DATABASE_SUPERUSER", user or "postgres")
            admin_password = _get("DATABASE_SUPERUSER_PASSWORD", password)
            admin_database = _get("DATABASE_SUPERUSER_DB", "postgres")

        missing = [
            name
            for name, value in (
                ("DATABASE_HOST", host),
                ("DATABASE_NAME", database),
                ("DATABASE_USER", user),
            )
            if not value
        ]
        if missing:
            raise RuntimeError(
                "Missing required database configuration: " + ", ".join(missing)
            )

        # The checks above guarantee these are set; narrow the types.
        host = cast(str, host)
        database = cast(str, database)
        user = cast(str, user)
        driver = cast(str, driver)
        admin_user = cast(str, admin_user)
        admin_database = cast(str, admin_database)

        return cls(
            driver=driver,
            host=host,
            port=port,
            database=database,
            user=user,
            password=password,
            schema=schema,
            admin_user=admin_user,
            admin_password=admin_password,
            admin_database=admin_database,
        )
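
    # A minimal environment sketch (values illustrative, not project defaults
    # beyond those documented in the module docstring):
    #
    #   DATABASE_HOST=localhost
    #   DATABASE_NAME=calminer
    #   DATABASE_USER=calminer_app
    #   DATABASE_PASSWORD=change-me
    #
    # from_env() then fills in driver="postgresql", port=5432,
    # schema="public", admin_user=DATABASE_USER (or "postgres"), and
    # admin_database="postgres".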

    def admin_dsn(self, database: Optional[str] = None) -> str:
        target_db = database or self.admin_database
        return self._compose_url(
            user=self.admin_user,
            password=self.admin_password,
            database=target_db,
            schema=None,
        )

    def application_dsn(self) -> str:
        """Return a SQLAlchemy URL for connecting as the application role."""

        return self._compose_url(
            user=self.user,
            password=self.password,
            database=self.database,
            schema=self.schema,
        )

    def _compose_url(
        self,
        *,
        user: Optional[str],
        password: Optional[str],
        database: str,
        schema: Optional[str],
    ) -> str:
        auth = ""
        if user:
            encoded_user = quote_plus(user)
            if password:
                encoded_pass = quote_plus(password)
                auth = f"{encoded_user}:{encoded_pass}@"
            else:
                auth = f"{encoded_user}@"

        # Bracket bare IPv6 addresses so the URL parses correctly.
        host = self.host
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"

        host_port = host
        if self.port:
            host_port = f"{host}:{self.port}"

        url = f"{self.driver}://{auth}{host_port}/{database}"

        params = {}
        if schema and schema.strip() and schema != "public":
            # Non-default schemas are selected via the search_path option.
            params["options"] = f"-csearch_path={schema}"

        if params:
            url = f"{url}?{urlencode(params, quote_via=quote_plus)}"

        return url
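
    # For example (hypothetical values): host "db.internal", user "app",
    # password "s3cret", database "calminer", and schema "analytics"
    # compose to
    #
    #   postgresql://app:s3cret@db.internal:5432/calminer?options=-csearch_path%3Danalytics
    #
    # whereas the default "public" schema omits the query string.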


class DatabaseSetup:
    """Encapsulates the full setup workflow."""

    def __init__(
        self, config: DatabaseConfig, *, dry_run: bool = False
    ) -> None:
        self.config = config
        self.dry_run = dry_run
        self._models_loaded = False
        self._rollback_actions: list[tuple[str, Callable[[], None]]] = []

    def _register_rollback(
        self, label: str, action: Callable[[], None]
    ) -> None:
        if self.dry_run:
            return
        self._rollback_actions.append((label, action))

    def execute_rollbacks(self) -> None:
        if not self._rollback_actions:
            logger.info("No rollback actions registered; nothing to undo.")
            return

        logger.warning(
            "Attempting rollback of %d action(s)", len(self._rollback_actions)
        )
        # Undo in reverse order so dependent objects are removed first.
        for label, action in reversed(self._rollback_actions):
            try:
                logger.warning("Rollback step: %s", label)
                action()
            except Exception:
                logger.exception("Rollback action '%s' failed", label)
        self._rollback_actions.clear()

    def clear_rollbacks(self) -> None:
        self._rollback_actions.clear()

    def _describe_connection(self, user: str, database: str) -> str:
        return f"{user}@{self.config.host}:{self.config.port}/{database}"

    def validate_admin_connection(self) -> None:
        descriptor = self._describe_connection(
            self.config.admin_user, self.config.admin_database
        )
        logger.info("Validating admin connection (%s)", descriptor)
        try:
            with self._admin_connection(self.config.admin_database) as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
        except psycopg2.Error as exc:
            raise RuntimeError(
                "Unable to connect with admin credentials. "
                "Check DATABASE_ADMIN_URL or DATABASE_SUPERUSER settings."
                f" Target: {descriptor}"
            ) from exc
        logger.info("Admin connection verified (%s)", descriptor)

    def validate_application_connection(self) -> None:
        descriptor = self._describe_connection(
            self.config.user, self.config.database
        )
        logger.info("Validating application connection (%s)", descriptor)
        try:
            with self._application_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
        except psycopg2.Error as exc:
            raise RuntimeError(
                "Unable to connect using application credentials. "
                "Ensure the role exists and credentials are correct. "
                f"Target: {descriptor}"
            ) from exc
        logger.info("Application connection verified (%s)", descriptor)

    def ensure_database(self) -> None:
        """Create the target database when it does not already exist."""

        logger.info("Ensuring database '%s' exists", self.config.database)
        try:
            conn = self._admin_connection(self.config.admin_database)
        except RuntimeError:
            logger.error(
                "Could not connect to admin database '%s' while creating '%s'.",
                self.config.admin_database,
                self.config.database,
            )
            raise
        try:
            # CREATE DATABASE cannot run inside a transaction block, so the
            # connection must be in autocommit mode.
            conn.autocommit = True
            conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)
            cursor = conn.cursor()
            try:
                try:
                    cursor.execute(
                        "SELECT 1 FROM pg_database WHERE datname = %s",
                        (self.config.database,),
                    )
                except psycopg2.Error as exc:
                    message = (
                        "Unable to inspect existing databases while ensuring '%s'."
                        " Verify admin permissions."
                    ) % self.config.database
                    logger.error(message)
                    raise RuntimeError(message) from exc

                exists = cursor.fetchone() is not None
                if exists:
                    logger.info(
                        "Database '%s' already present", self.config.database
                    )
                    return

                if self.dry_run:
                    logger.info(
                        "Dry run: would create database '%s'. Run without --dry-run to proceed.",
                        self.config.database,
                    )
                    return

                try:
                    cursor.execute(
                        sql.SQL("CREATE DATABASE {} ENCODING 'UTF8'").format(
                            sql.Identifier(self.config.database)
                        )
                    )
                except psycopg2.Error as exc:
                    message = (
                        "Failed to create database '%s'. Rerun with --dry-run for diagnostics."
                    ) % self.config.database
                    logger.error(message)
                    raise RuntimeError(message) from exc
                else:
                    rollback_label = f"drop database {self.config.database}"
                    self._register_rollback(
                        rollback_label,
                        lambda db=self.config.database: self._drop_database(db),
                    )
                    logger.info("Created database '%s'", self.config.database)
            finally:
                cursor.close()
        finally:
            conn.close()

    def ensure_role(self) -> None:
        """Create the application role and assign privileges when missing."""

        logger.info("Ensuring role '%s' exists", self.config.user)
        try:
            admin_conn = self._admin_connection(self.config.admin_database)
        except RuntimeError:
            logger.error(
                "Unable to connect with admin credentials while ensuring role '%s'",
                self.config.user,
            )
            raise

        with admin_conn as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                try:
                    cursor.execute(
                        "SELECT 1 FROM pg_roles WHERE rolname = %s",
                        (self.config.user,),
                    )
                except psycopg2.Error as exc:
                    message = (
                        "Unable to inspect existing roles while ensuring role '%s'."
                        " Verify admin permissions."
                    ) % self.config.user
                    logger.error(message)
                    raise RuntimeError(message) from exc
                role_exists = cursor.fetchone() is not None
                if not role_exists:
                    logger.info("Creating role '%s'", self.config.user)
                    if self.dry_run:
                        logger.info(
                            "Dry run: would create role '%s'. Run without --dry-run to apply.",
                            self.config.user,
                        )
                        return
                    try:
                        if self.config.password:
                            cursor.execute(
                                sql.SQL(
                                    "CREATE ROLE {} WITH LOGIN PASSWORD %s"
                                ).format(sql.Identifier(self.config.user)),
                                (self.config.password,),
                            )
                        else:
                            cursor.execute(
                                sql.SQL("CREATE ROLE {} WITH LOGIN").format(
                                    sql.Identifier(self.config.user)
                                )
                            )
                    except psycopg2.Error as exc:
                        message = (
                            "Failed to create role '%s'. Review admin privileges and rerun."
                        ) % self.config.user
                        logger.error(message)
                        raise RuntimeError(message) from exc
                    else:
                        rollback_label = f"drop role {self.config.user}"
                        self._register_rollback(
                            rollback_label,
                            lambda role=self.config.user: self._drop_role(role),
                        )
                else:
                    logger.info("Role '%s' already present", self.config.user)

        try:
            role_conn = self._admin_connection(self.config.database)
        except RuntimeError:
            logger.error(
                "Unable to connect to application database '%s' while granting privileges to role '%s'",
                self.config.database,
                self.config.user,
            )
            raise

        if self.dry_run:
            role_conn.close()
            logger.info(
                "Dry run: would grant privileges on schema/database to role '%s'.",
                self.config.user,
            )
            return

        with role_conn as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                schema_name = self.config.schema or "public"
                schema_identifier = sql.Identifier(schema_name)
                role_identifier = sql.Identifier(self.config.user)

                try:
                    cursor.execute(
                        sql.SQL("GRANT CONNECT ON DATABASE {} TO {}").format(
                            sql.Identifier(self.config.database),
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL("GRANT USAGE ON SCHEMA {} TO {}").format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL("GRANT CREATE ON SCHEMA {} TO {}").format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL(
                            "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL(
                            "GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    # Default privileges cover tables and sequences created
                    # later by other roles in this schema.
                    cursor.execute(
                        sql.SQL(
                            "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL(
                            "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT USAGE, SELECT ON SEQUENCES TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                except psycopg2.Error as exc:
                    message = (
                        "Failed to grant privileges to role '%s' in schema '%s'."
                        " Rerun with --dry-run for more context."
                    ) % (self.config.user, schema_name)
                    logger.error(message)
                    raise RuntimeError(message) from exc
                logger.info(
                    "Granted privileges on schema '%s' to role '%s'",
                    schema_name,
                    self.config.user,
                )
                rollback_label = f"revoke privileges for {self.config.user}"
                self._register_rollback(
                    rollback_label,
                    lambda schema=schema_name: self._revoke_role_privileges(
                        schema_name=schema
                    ),
                )

    def ensure_schema(self) -> None:
        """Create the configured schema when it does not exist."""

        schema_name = self.config.schema
        if not schema_name or schema_name == "public":
            logger.info("Using default schema 'public'; nothing to ensure")
            return

        logger.info("Ensuring schema '%s' exists", schema_name)
        with self._admin_connection(self.config.database) as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                cursor.execute(
                    sql.SQL(
                        "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s"
                    ),
                    (schema_name,),
                )
                exists = cursor.fetchone() is not None
                if not exists:
                    if self.dry_run:
                        logger.info(
                            "Dry run: would create schema '%s'",
                            schema_name,
                        )
                    else:
                        cursor.execute(
                            sql.SQL("CREATE SCHEMA {}").format(
                                sql.Identifier(schema_name)
                            )
                        )
                        logger.info("Created schema '%s'", schema_name)
                try:
                    if self.dry_run:
                        logger.info(
                            "Dry run: would set schema '%s' owner to '%s'",
                            schema_name,
                            self.config.user,
                        )
                    else:
                        cursor.execute(
                            sql.SQL("ALTER SCHEMA {} OWNER TO {}").format(
                                sql.Identifier(schema_name),
                                sql.Identifier(self.config.user),
                            )
                        )
                except errors.UndefinedObject:
                    logger.warning(
                        "Role '%s' not found when assigning ownership to schema '%s'."
                        " Run --ensure-role after creating the schema.",
                        self.config.user,
                        schema_name,
                    )

    def application_role_exists(self) -> bool:
        with self._admin_connection(self.config.admin_database) as conn:
            with conn.cursor() as cursor:
                try:
                    cursor.execute(
                        "SELECT 1 FROM pg_roles WHERE rolname = %s",
                        (self.config.user,),
                    )
                except psycopg2.Error as exc:
                    message = (
                        "Unable to inspect existing roles while checking for role '%s'."
                        " Verify admin permissions."
                    ) % self.config.user
                    logger.error(message)
                    raise RuntimeError(message) from exc
                return cursor.fetchone() is not None

    def _admin_connection(self, database: Optional[str] = None) -> PGConnection:
        target_db = database or self.config.admin_database
        dsn = self.config.admin_dsn(database)
        descriptor = self._describe_connection(
            self.config.admin_user, target_db
        )
        try:
            return psycopg2.connect(dsn)
        except psycopg2.Error as exc:
            raise RuntimeError(
                f"Unable to establish admin connection. Target: {descriptor}"
            ) from exc

    def _application_connection(self) -> PGConnection:
        dsn = self.config.application_dsn()
        descriptor = self._describe_connection(
            self.config.user, self.config.database
        )
        try:
            return psycopg2.connect(dsn)
        except psycopg2.Error as exc:
            raise RuntimeError(
                "Unable to establish application connection. "
                f"Target: {descriptor}"
            ) from exc

    def initialize_schema(self) -> None:
        """Create database objects from SQLAlchemy metadata if missing."""

        self._ensure_models_loaded()
        logger.info("Ensuring SQLAlchemy metadata is reflected in database")
        engine = create_engine(self.config.application_dsn(), future=True)
        try:
            inspector = inspect(engine)
            existing_tables = set(
                inspector.get_table_names(schema=self.config.schema)
            )
            metadata_tables = set(Base.metadata.tables.keys())
            missing_tables = sorted(metadata_tables - existing_tables)

            if missing_tables:
                logger.info("Pending tables: %s", ", ".join(missing_tables))
            else:
                logger.info("All tables already exist")

            if self.dry_run:
                if missing_tables:
                    logger.info("Dry run: skipping creation of pending tables")
                return

            Base.metadata.create_all(bind=engine, checkfirst=True)
        finally:
            engine.dispose()

        logger.info("Schema initialization complete")

    def _ensure_models_loaded(self) -> None:
        if self._models_loaded:
            return

        # Import every module in the models package so all tables register
        # themselves on Base.metadata before create_all runs.
        package = importlib.import_module("models")
        for module_info in pkgutil.iter_modules(package.__path__):
            importlib.import_module(f"{package.__name__}.{module_info.name}")
        self._models_loaded = True

    def run_migrations(
        self, migrations_dir: Optional[Path | str] = None
    ) -> None:
        """Execute pending SQL migrations in chronological order."""

        directory = (
            Path(migrations_dir)
            if migrations_dir is not None
            else DEFAULT_MIGRATIONS_DIR
        )
        directory = directory.resolve()

        if not directory.exists():
            logger.warning("Migrations directory '%s' not found", directory)
            return

        migration_files = sorted(directory.glob("*.sql"))
        if not migration_files:
            logger.info("No migration scripts found in '%s'", directory)
            return

        baseline_name = "000_base.sql"
        baseline_path = directory / baseline_name

        schema_name = self.config.schema or "public"

        with self._application_connection() as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                table_exists = self._migrations_table_exists(
                    cursor, schema_name
                )
                if not table_exists:
                    if self.dry_run:
                        logger.info(
                            "Dry run: would create migration history table %s.%s",
                            schema_name,
                            MIGRATIONS_TABLE,
                        )
                        applied: set[str] = set()
                    else:
                        self._create_migrations_table(cursor, schema_name)
                        logger.info(
                            "Created migration history table %s.%s",
                            schema_name,
                            MIGRATIONS_TABLE,
                        )
                        applied = set()
                else:
                    applied = self._fetch_applied_migrations(
                        cursor, schema_name
                    )

                # The baseline supersedes the older scripts: once it runs,
                # the legacy files are recorded as applied without being
                # executed again.
                if baseline_path.exists() and baseline_name not in applied:
                    if self.dry_run:
                        logger.info(
                            "Dry run: baseline migration '%s' pending; would apply and mark legacy files",
                            baseline_name,
                        )
                    else:
                        logger.info(
                            "Baseline migration '%s' pending; applying and marking older migrations",
                            baseline_name,
                        )
                        try:
                            baseline_applied = self._apply_migration_file(
                                cursor, schema_name, baseline_path
                            )
                        except Exception:
                            logger.error(
                                "Failed while applying baseline migration '%s'."
                                " Review the migration contents and rerun with --dry-run for diagnostics.",
                                baseline_name,
                                exc_info=True,
                            )
                            raise
                        applied.add(baseline_applied)
                        legacy_files = [
                            path
                            for path in migration_files
                            if path.name != baseline_name
                        ]
                        for legacy in legacy_files:
                            if legacy.name not in applied:
                                try:
                                    cursor.execute(
                                        sql.SQL(
                                            "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())"
                                        ).format(
                                            sql.Identifier(
                                                schema_name,
                                                MIGRATIONS_TABLE,
                                            )
                                        ),
                                        (legacy.name,),
                                    )
                                except Exception:
                                    logger.error(
                                        "Unable to record legacy migration '%s' after baseline application."
                                        " Check schema_migrations table in schema '%s' for partial state.",
                                        legacy.name,
                                        schema_name,
                                        exc_info=True,
                                    )
                                    raise
                                applied.add(legacy.name)
                                logger.info(
                                    "Marked legacy migration '%s' as applied via baseline",
                                    legacy.name,
                                )

                pending = [
                    path for path in migration_files if path.name not in applied
                ]

                if not pending:
                    logger.info("No pending migrations")
                    return

                logger.info(
                    "Pending migrations: %s",
                    ", ".join(path.name for path in pending),
                )

                if self.dry_run:
                    logger.info("Dry run: skipping migration execution")
                    return

                for path in pending:
                    self._apply_migration_file(cursor, schema_name, path)

                logger.info("Applied %d migrations", len(pending))

    def _apply_migration_file(
        self,
        cursor,
        schema_name: str,
        path: Path,
    ) -> str:
        logger.info("Applying migration '%s'", path.name)
        sql_text = path.read_text(encoding="utf-8")
        try:
            cursor.execute(sql_text)
            cursor.execute(
                sql.SQL(
                    "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())"
                ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)),
                (path.name,),
            )
            return path.name
        except Exception:
            logger.exception("Failed to apply migration '%s'", path.name)
            raise

    def _migrations_table_exists(self, cursor, schema_name: str) -> bool:
        cursor.execute(
            """
            SELECT 1
            FROM information_schema.tables
            WHERE table_schema = %s AND table_name = %s
            """,
            (schema_name, MIGRATIONS_TABLE),
        )
        return cursor.fetchone() is not None

    def _create_migrations_table(self, cursor, schema_name: str) -> None:
        cursor.execute(
            sql.SQL(
                "CREATE TABLE IF NOT EXISTS {} ("
                "filename TEXT PRIMARY KEY,"
                "applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()"
                ")"
            ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE))
        )
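
    # With the default "public" schema the statement above renders roughly as:
    #
    #   CREATE TABLE IF NOT EXISTS "public"."schema_migrations" (
    #       filename TEXT PRIMARY KEY,
    #       applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
    #   )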

    def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]:
        cursor.execute(
            sql.SQL("SELECT filename FROM {} ORDER BY filename").format(
                sql.Identifier(schema_name, MIGRATIONS_TABLE)
            )
        )
        return {row[0] for row in cursor.fetchall()}

    def seed_baseline_data(self, *, dry_run: bool) -> None:
        """Seed reference data such as currencies."""

        from scripts import seed_data

        seed_args = argparse.Namespace(
            currencies=True,
            units=True,
            defaults=False,
            dry_run=dry_run,
            verbose=0,
        )
        seed_data.run_with_namespace(seed_args, config=self.config)

        if dry_run:
            logger.info("Dry run: skipped seed verification")
            return

        expected_currencies = {
            code for code, *_ in getattr(seed_data, "CURRENCY_SEEDS", ())
        }
        expected_units = {
            code
            for code, *_ in getattr(seed_data, "MEASUREMENT_UNIT_SEEDS", ())
        }
        self._verify_seeded_data(
            expected_currency_codes=expected_currencies,
            expected_unit_codes=expected_units,
        )
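
    # The seeding step is equivalent to invoking the standalone seeder with
    # currencies and units enabled (defaults and extra verbosity off), e.g.:
    #
    #   python scripts/seed_data.py --currencies --units
    #
    # The verification pass below then cross-checks the seeded rows.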

    def _verify_seeded_data(
        self,
        *,
        expected_currency_codes: set[str],
        expected_unit_codes: set[str],
    ) -> None:
        if not expected_currency_codes and not expected_unit_codes:
            logger.info("No seed datasets configured for verification")
            return

        with self._application_connection() as conn:
            with conn.cursor() as cursor:
                if expected_currency_codes:
                    cursor.execute(
                        "SELECT code, is_active FROM currency WHERE code = ANY(%s)",
                        (list(expected_currency_codes),),
                    )
                    rows = cursor.fetchall()
                    found_codes = {row[0] for row in rows}
                    missing_codes = sorted(
                        expected_currency_codes - found_codes
                    )
                    if missing_codes:
                        message = (
                            "Missing expected currencies after seeding: %s. "
                            "Run scripts/seed_data.py --currencies to restore them."
                        ) % ", ".join(missing_codes)
                        logger.error(message)
                        raise RuntimeError(message)

                    logger.info(
                        "Verified %d seeded currencies present",
                        len(found_codes),
                    )

                    default_status = next(
                        (row[1] for row in rows if row[0] == "USD"), None
                    )
                    if default_status is False:
                        message = (
                            "Default currency 'USD' is inactive after seeding. "
                            "Reactivate it or rerun the seeding command."
                        )
                        logger.error(message)
                        raise RuntimeError(message)
                    elif default_status is None:
                        message = (
                            "Default currency 'USD' not found after seeding. "
                            "Ensure baseline migration 000_base.sql ran successfully."
                        )
                        logger.error(message)
                        raise RuntimeError(message)
                    else:
                        logger.info("Verified default currency 'USD' active")

                if expected_unit_codes:
                    try:
                        cursor.execute(
                            "SELECT code, is_active FROM measurement_unit WHERE code = ANY(%s)",
                            (list(expected_unit_codes),),
                        )
                    except errors.UndefinedTable:
                        conn.rollback()
                        message = (
                            "measurement_unit table not found during seed verification. "
                            "Ensure baseline migration 000_base.sql has been applied."
                        )
                        logger.error(message)
                        raise RuntimeError(message)
                    else:
                        rows = cursor.fetchall()
                        found_units = {row[0] for row in rows}
                        missing_units = sorted(
                            expected_unit_codes - found_units
                        )
                        if missing_units:
                            message = (
                                "Missing expected measurement units after seeding: %s. "
                                "Run scripts/seed_data.py --units to restore them."
                            ) % ", ".join(missing_units)
                            logger.error(message)
                            raise RuntimeError(message)

                        inactive_units = sorted(
                            row[0] for row in rows if not bool(row[1])
                        )
                        if inactive_units:
                            message = (
                                "Measurement units inactive after seeding: %s. "
                                "Reactivate them or rerun unit seeding."
                            ) % ", ".join(inactive_units)
                            logger.error(message)
                            raise RuntimeError(message)

                        logger.info(
                            "Verified %d measurement units present",
                            len(found_units),
                        )

        logger.info("Seed verification complete")

    def _drop_database(self, database: str) -> None:
        logger.warning("Rollback: dropping database '%s'", database)
        with self._admin_connection(self.config.admin_database) as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                # Terminate active sessions first or DROP DATABASE will fail.
                cursor.execute(
                    "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s",
                    (database,),
                )
                cursor.execute(
                    sql.SQL("DROP DATABASE IF EXISTS {}").format(
                        sql.Identifier(database)
                    )
                )

    def _drop_role(self, role: str) -> None:
        logger.warning("Rollback: dropping role '%s'", role)
        with self._admin_connection(self.config.admin_database) as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                cursor.execute(
                    sql.SQL("DROP ROLE IF EXISTS {}").format(
                        sql.Identifier(role)
                    )
                )

    def _revoke_role_privileges(self, *, schema_name: str) -> None:
        logger.warning(
            "Rollback: revoking privileges on schema '%s' for role '%s'",
            schema_name,
            self.config.user,
        )
        with self._admin_connection(self.config.database) as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                cursor.execute(
                    sql.SQL(
                        "REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}"
                    ).format(
                        sql.Identifier(schema_name),
                        sql.Identifier(self.config.user),
                    )
                )
                cursor.execute(
                    sql.SQL(
                        "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}"
                    ).format(
                        sql.Identifier(schema_name),
                        sql.Identifier(self.config.user),
                    )
                )
                cursor.execute(
                    sql.SQL(
                        "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}"
                    ).format(
                        sql.Identifier(schema_name),
                        sql.Identifier(self.config.user),
                    )
                )
                cursor.execute(
                    sql.SQL(
                        "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}"
                    ).format(
                        sql.Identifier(schema_name),
                        sql.Identifier(self.config.user),
                    )
                )


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Bootstrap CalMiner database")
    parser.add_argument(
        "--ensure-database",
        action="store_true",
        help="Create the application database when it does not already exist.",
    )
    parser.add_argument(
        "--ensure-role",
        action="store_true",
        help="Create the application role and grant necessary privileges.",
    )
    parser.add_argument(
        "--ensure-schema",
        action="store_true",
        help="Create the configured schema if it does not exist.",
    )
    parser.add_argument(
        "--initialize-schema",
        action="store_true",
        help="Create missing tables based on SQLAlchemy models.",
    )
    parser.add_argument(
        "--run-migrations",
        action="store_true",
        help="Execute pending SQL migrations.",
    )
    parser.add_argument(
        "--seed-data",
        action="store_true",
        help="Seed baseline reference data (currencies, etc.).",
    )
    parser.add_argument(
        "--migrations-dir",
        default=None,
        help="Override the default migrations directory.",
    )
    parser.add_argument("--db-driver", help="Override DATABASE_DRIVER")
    parser.add_argument("--db-host", help="Override DATABASE_HOST")
    parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT")
    parser.add_argument("--db-name", help="Override DATABASE_NAME")
    parser.add_argument("--db-user", help="Override DATABASE_USER")
    parser.add_argument("--db-password", help="Override DATABASE_PASSWORD")
    parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA")
    parser.add_argument(
        "--admin-url",
        help="Override DATABASE_ADMIN_URL for administrative operations",
    )
    parser.add_argument(
        "--admin-user", help="Override DATABASE_SUPERUSER for admin ops"
    )
    parser.add_argument(
        "--admin-password",
        help="Override DATABASE_SUPERUSER_PASSWORD for admin ops",
    )
    parser.add_argument(
        "--admin-db",
        help="Override DATABASE_SUPERUSER_DB for admin ops",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Log actions without applying changes.",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="count",
        default=0,
        help="Increase logging verbosity",
    )
    return parser.parse_args()
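

# Per-step runs are also possible, e.g. applying only pending migrations from
# a custom directory (paths illustrative):
#
#   python scripts/setup_database.py --run-migrations \
#       --migrations-dir scripts/migrations -v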


def main() -> None:
    args = parse_args()
    # -v selects INFO and -vv selects DEBUG; the default is WARNING.
    level = logging.WARNING - (10 * min(args.verbose, 2))
    logging.basicConfig(level=level, format="%(levelname)s %(message)s")

    override_args: dict[str, Optional[str]] = {
        "DATABASE_DRIVER": args.db_driver,
        "DATABASE_HOST": args.db_host,
        "DATABASE_NAME": args.db_name,
        "DATABASE_USER": args.db_user,
        "DATABASE_PASSWORD": args.db_password,
        "DATABASE_SCHEMA": args.db_schema,
        "DATABASE_ADMIN_URL": args.admin_url,
        "DATABASE_SUPERUSER": args.admin_user,
        "DATABASE_SUPERUSER_PASSWORD": args.admin_password,
        "DATABASE_SUPERUSER_DB": args.admin_db,
    }
    if args.db_port is not None:
        override_args["DATABASE_PORT"] = str(args.db_port)

    config = DatabaseConfig.from_env(overrides=override_args)
    setup = DatabaseSetup(config, dry_run=args.dry_run)

    admin_tasks_requested = (
        args.ensure_database or args.ensure_role or args.ensure_schema
    )
    if admin_tasks_requested:
        setup.validate_admin_connection()

    app_validated = False

    def ensure_application_connection_for(operation: str) -> bool:
        nonlocal app_validated
        if app_validated:
            return True
        if setup.dry_run and not setup.application_role_exists():
            logger.info(
                "Dry run: skipping %s because application role '%s' does not exist yet.",
                operation,
                setup.config.user,
            )
            return False
        setup.validate_application_connection()
        app_validated = True
        return True

    should_run_migrations = args.run_migrations
    auto_run_migrations_reason: Optional[str] = None
    if args.seed_data and not should_run_migrations:
        should_run_migrations = True
        auto_run_migrations_reason = (
            "Seed data requested without explicit --run-migrations;"
            " applying migrations first."
        )

    try:
        if args.ensure_database:
            setup.ensure_database()
        if args.ensure_role:
            setup.ensure_role()
        if args.ensure_schema:
            setup.ensure_schema()

        if args.initialize_schema:
            if ensure_application_connection_for(
                "SQLAlchemy schema initialization"
            ):
                setup.initialize_schema()
        if should_run_migrations:
            if ensure_application_connection_for("migration execution"):
                if auto_run_migrations_reason:
                    logger.info(auto_run_migrations_reason)
                migrations_path = (
                    Path(args.migrations_dir) if args.migrations_dir else None
                )
                setup.run_migrations(migrations_path)
        if args.seed_data:
            if ensure_application_connection_for("baseline data seeding"):
                setup.seed_baseline_data(dry_run=args.dry_run)
    except Exception:
        if not setup.dry_run:
            setup.execute_rollbacks()
        raise
    finally:
        if not setup.dry_run:
            setup.clear_rollbacks()


if __name__ == "__main__":
    main()
|