Files
calminer/scripts/setup_database.py
zwitschi 5b1322ddbc
Some checks failed
Run Tests / test (push) Failing after 1m51s
feat: Add application-level settings for CSS color management
- Introduced a new table `application_setting` to store configurable application options.
- Implemented functions to manage CSS color settings, including loading, updating, and reading environment overrides.
- Added a new settings view to render and manage theme colors.
- Updated UI to include a settings page with theme color management and environment overrides display.
- Enhanced CSS styles for the settings page and sidebar navigation.
- Created unit and end-to-end tests for the new settings functionality and CSS management.
2025-10-25 19:20:52 +02:00

1189 lines
45 KiB
Python

"""Utilities to bootstrap the CalMiner PostgreSQL database.
This script is designed to be idempotent. Each step checks the existing
state before attempting to modify it so repeated executions are safe.
Environment variables (with defaults) used when establishing connections:
* ``DATABASE_DRIVER`` (``postgresql``)
* ``DATABASE_HOST`` (required)
* ``DATABASE_PORT`` (``5432``)
* ``DATABASE_NAME`` (required)
* ``DATABASE_USER`` (required)
* ``DATABASE_PASSWORD`` (optional, required for password auth)
* ``DATABASE_SCHEMA`` (``public``)
* ``DATABASE_ADMIN_URL`` (overrides individual admin settings)
* ``DATABASE_SUPERUSER`` (falls back to ``DATABASE_USER`` or ``postgres``)
* ``DATABASE_SUPERUSER_PASSWORD`` (falls back to ``DATABASE_PASSWORD``)
* ``DATABASE_SUPERUSER_DB`` (``postgres``)
Set ``DATABASE_URL`` if other parts of the application rely on a single
connection string; this script will still honor the granular inputs above.
"""
from __future__ import annotations
import argparse
import importlib
import logging
import os
import pkgutil
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional, cast
from urllib.parse import quote_plus, urlencode
import psycopg2
from psycopg2 import errors
from psycopg2 import sql
from psycopg2 import extensions
from psycopg2.extensions import connection as PGConnection, parse_dsn
from dotenv import load_dotenv
from sqlalchemy import create_engine, inspect
# Make the project root importable so ``config``/``models`` resolve when this
# script is executed directly (e.g. ``python scripts/setup_database.py``).
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

# Imported after the sys.path tweak so the project package can be found.
from config.database import Base

logger = logging.getLogger(__name__)

# Location of the bundled SQL migrations and the history table recording them.
SCRIPTS_DIR = Path(__file__).resolve().parent
DEFAULT_MIGRATIONS_DIR = SCRIPTS_DIR / "migrations"
MIGRATIONS_TABLE = "schema_migrations"
@dataclass(slots=True)
class DatabaseConfig:
    """Configuration required to manage the application database."""

    # Application connection settings.
    driver: str
    host: str
    port: int
    database: str
    user: str
    password: Optional[str]  # None when the role uses trust/peer auth
    schema: Optional[str]  # None or "public" means the default schema
    # Administrative (superuser) settings used only for bootstrap tasks.
    admin_user: str
    admin_password: Optional[str]
    admin_database: str = "postgres"
    @classmethod
    def from_env(
        cls,
        overrides: Optional[dict[str, Optional[str]]] = None,
    ) -> "DatabaseConfig":
        """Build a config from environment variables.

        *overrides* take precedence over the environment; a ``.env`` file is
        loaded first via :func:`dotenv.load_dotenv`.

        Raises:
            RuntimeError: if ``DATABASE_PORT`` is not an integer or any of
                ``DATABASE_HOST``/``DATABASE_NAME``/``DATABASE_USER`` is unset.
        """
        load_dotenv()
        override_map: dict[str, Optional[str]] = dict(overrides or {})

        def _get(name: str, default: Optional[str] = None) -> Optional[str]:
            # Priority: explicit override -> environment -> default.
            if name in override_map and override_map[name] is not None:
                return override_map[name]
            env_value = os.getenv(name)
            if env_value is not None:
                return env_value
            return default

        driver = _get("DATABASE_DRIVER", "postgresql")
        host = _get("DATABASE_HOST")
        port_value = _get("DATABASE_PORT", "5432")
        database = _get("DATABASE_NAME")
        user = _get("DATABASE_USER")
        password = _get("DATABASE_PASSWORD")
        schema = _get("DATABASE_SCHEMA", "public")
        try:
            port = int(port_value) if port_value is not None else 5432
        except ValueError as exc:
            raise RuntimeError(
                "Invalid DATABASE_PORT value: expected integer, got"
                f" '{port_value}'"
            ) from exc
        admin_url = _get("DATABASE_ADMIN_URL")
        if admin_url:
            # A full admin DSN wins over the individual superuser settings.
            admin_conninfo = parse_dsn(admin_url)
            admin_user = admin_conninfo.get("user") or user or "postgres"
            admin_password = admin_conninfo.get("password")
            admin_database = admin_conninfo.get("dbname") or "postgres"
            # NOTE(review): the admin URL's host/port also override the
            # application host/port here — confirm this is intentional.
            host = admin_conninfo.get("host") or host
            port = int(admin_conninfo.get("port") or port)
        else:
            admin_user = _get("DATABASE_SUPERUSER", user or "postgres")
            admin_password = _get("DATABASE_SUPERUSER_PASSWORD", password)
            admin_database = _get("DATABASE_SUPERUSER_DB", "postgres")
        missing = [
            name
            for name, value in (
                ("DATABASE_HOST", host),
                ("DATABASE_NAME", database),
                ("DATABASE_USER", user),
            )
            if not value
        ]
        if missing:
            raise RuntimeError(
                "Missing required database configuration: " +
                ", ".join(missing)
            )
        # The checks above guarantee these are non-None; narrow for typing.
        host = cast(str, host)
        database = cast(str, database)
        user = cast(str, user)
        driver = cast(str, driver)
        admin_user = cast(str, admin_user)
        admin_database = cast(str, admin_database)
        return cls(
            driver=driver,
            host=host,
            port=port,
            database=database,
            user=user,
            password=password,
            schema=schema,
            admin_user=admin_user,
            admin_password=admin_password,
            admin_database=admin_database,
        )
def admin_dsn(self, database: Optional[str] = None) -> str:
target_db = database or self.admin_database
return self._compose_url(
user=self.admin_user,
password=self.admin_password,
database=target_db,
schema=None,
)
def application_dsn(self) -> str:
"""Return a SQLAlchemy URL for connecting as the application role."""
return self._compose_url(
user=self.user,
password=self.password,
database=self.database,
schema=self.schema,
)
def _compose_url(
self,
*,
user: Optional[str],
password: Optional[str],
database: str,
schema: Optional[str],
) -> str:
auth = ""
if user:
encoded_user = quote_plus(user)
if password:
encoded_pass = quote_plus(password)
auth = f"{encoded_user}:{encoded_pass}@"
else:
auth = f"{encoded_user}@"
host = self.host
if ":" in host and not host.startswith("["):
host = f"[{host}]"
host_port = host
if self.port:
host_port = f"{host}:{self.port}"
url = f"{self.driver}://{auth}{host_port}/{database}"
params = {}
if schema and schema.strip() and schema != "public":
params["options"] = f"-csearch_path={schema}"
if params:
url = f"{url}?{urlencode(params, quote_via=quote_plus)}"
return url
class DatabaseSetup:
    """Encapsulates the full setup workflow."""

    def __init__(self, config: DatabaseConfig, *, dry_run: bool = False) -> None:
        self.config = config
        # When True, every step only logs what it would do.
        self.dry_run = dry_run
        # Tracks whether the ``models`` package has been imported yet.
        self._models_loaded = False
        # LIFO list of (label, callable) undo steps executed on failure.
        self._rollback_actions: list[tuple[str, Callable[[], None]]] = []
def _register_rollback(self, label: str, action: Callable[[], None]) -> None:
if self.dry_run:
return
self._rollback_actions.append((label, action))
def execute_rollbacks(self) -> None:
if not self._rollback_actions:
logger.info("No rollback actions registered; nothing to undo.")
return
logger.warning(
"Attempting rollback of %d action(s)", len(self._rollback_actions)
)
for label, action in reversed(self._rollback_actions):
try:
logger.warning("Rollback step: %s", label)
action()
except Exception:
logger.exception("Rollback action '%s' failed", label)
self._rollback_actions.clear()
    def clear_rollbacks(self) -> None:
        # Forget all registered undo steps (called after a finished run).
        self._rollback_actions.clear()
def _describe_connection(self, user: str, database: str) -> str:
return f"{user}@{self.config.host}:{self.config.port}/{database}"
def validate_admin_connection(self) -> None:
descriptor = self._describe_connection(
self.config.admin_user, self.config.admin_database
)
logger.info("Validating admin connection (%s)", descriptor)
try:
with self._admin_connection(self.config.admin_database) as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
except psycopg2.Error as exc:
raise RuntimeError(
"Unable to connect with admin credentials. "
"Check DATABASE_ADMIN_URL or DATABASE_SUPERUSER settings."
f" Target: {descriptor}"
) from exc
logger.info("Admin connection verified (%s)", descriptor)
def validate_application_connection(self) -> None:
descriptor = self._describe_connection(
self.config.user, self.config.database
)
logger.info("Validating application connection (%s)", descriptor)
try:
with self._application_connection() as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
except psycopg2.Error as exc:
raise RuntimeError(
"Unable to connect using application credentials. "
"Ensure the role exists and credentials are correct. "
f"Target: {descriptor}"
) from exc
logger.info("Application connection verified (%s)", descriptor)
def ensure_database(self) -> None:
"""Create the target database when it does not already exist."""
logger.info("Ensuring database '%s' exists", self.config.database)
try:
conn = self._admin_connection(self.config.admin_database)
except RuntimeError:
logger.error(
"Could not connect to admin database '%s' while creating '%s'.",
self.config.admin_database,
self.config.database,
)
raise
try:
conn.autocommit = True
conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)
cursor = conn.cursor()
try:
try:
cursor.execute(
"SELECT 1 FROM pg_database WHERE datname = %s",
(self.config.database,),
)
except psycopg2.Error as exc:
message = (
"Unable to inspect existing databases while ensuring '%s'."
" Verify admin permissions."
) % self.config.database
logger.error(message)
raise RuntimeError(message) from exc
exists = cursor.fetchone() is not None
if exists:
logger.info(
"Database '%s' already present", self.config.database
)
return
if self.dry_run:
logger.info(
"Dry run: would create database '%s'. Run without --dry-run to proceed.",
self.config.database,
)
return
try:
cursor.execute(
sql.SQL("CREATE DATABASE {} ENCODING 'UTF8'").format(
sql.Identifier(self.config.database)
)
)
except psycopg2.Error as exc:
message = (
"Failed to create database '%s'. Rerun with --dry-run for diagnostics"
) % self.config.database
logger.error(message)
raise RuntimeError(message) from exc
else:
rollback_label = f"drop database {self.config.database}"
self._register_rollback(
rollback_label,
lambda db=self.config.database: self._drop_database(db),
)
logger.info("Created database '%s'", self.config.database)
finally:
cursor.close()
finally:
conn.close()
    def ensure_role(self) -> None:
        """Create the application role and assign privileges when missing.

        Two phases: (1) create the login role via the admin database when it
        is absent; (2) grant connect/usage/create plus current and default
        table/sequence privileges in the target schema.

        Note: in dry-run mode, when the role does not exist yet, this returns
        before the grant phase entirely.
        """
        logger.info("Ensuring role '%s' exists", self.config.user)
        try:
            admin_conn = self._admin_connection(self.config.admin_database)
        except RuntimeError:
            logger.error(
                "Unable to connect with admin credentials while ensuring role '%s'",
                self.config.user,
            )
            raise
        # NOTE(review): psycopg2's connection context manager ends the
        # transaction but does not close the connection — confirm the leak
        # here is acceptable for a short-lived script.
        with admin_conn as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                try:
                    cursor.execute(
                        "SELECT 1 FROM pg_roles WHERE rolname = %s",
                        (self.config.user,),
                    )
                except psycopg2.Error as exc:
                    message = (
                        "Unable to inspect existing roles while ensuring role '%s'."
                        " Verify admin permissions."
                    ) % self.config.user
                    logger.error(message)
                    raise RuntimeError(message) from exc
                role_exists = cursor.fetchone() is not None
                if not role_exists:
                    logger.info("Creating role '%s'", self.config.user)
                    if self.dry_run:
                        logger.info(
                            "Dry run: would create role '%s'. Run without --dry-run to apply.",
                            self.config.user,
                        )
                        return
                    try:
                        if self.config.password:
                            # Role name is identifier-quoted; the password is
                            # passed as a bind parameter, never interpolated.
                            cursor.execute(
                                sql.SQL("CREATE ROLE {} WITH LOGIN PASSWORD %s").format(
                                    sql.Identifier(self.config.user)
                                ),
                                (self.config.password,),
                            )
                        else:
                            cursor.execute(
                                sql.SQL("CREATE ROLE {} WITH LOGIN").format(
                                    sql.Identifier(self.config.user)
                                )
                            )
                    except psycopg2.Error as exc:
                        message = (
                            "Failed to create role '%s'. Review admin privileges and rerun."
                        ) % self.config.user
                        logger.error(message)
                        raise RuntimeError(message) from exc
                    else:
                        rollback_label = f"drop role {self.config.user}"
                        self._register_rollback(
                            rollback_label,
                            lambda role=self.config.user: self._drop_role(role),
                        )
                else:
                    logger.info("Role '%s' already present", self.config.user)
        # Phase 2: grants are issued while connected to the application DB.
        try:
            role_conn = self._admin_connection(self.config.database)
        except RuntimeError:
            logger.error(
                "Unable to connect to application database '%s' while granting privileges to role '%s'",
                self.config.database,
                self.config.user,
            )
            raise
        if self.dry_run:
            logger.info(
                "Dry run: would grant privileges on schema/database to role '%s'.",
                self.config.user,
            )
            return
        with role_conn as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                schema_name = self.config.schema or "public"
                schema_identifier = sql.Identifier(schema_name)
                role_identifier = sql.Identifier(self.config.user)
                try:
                    cursor.execute(
                        sql.SQL("GRANT CONNECT ON DATABASE {} TO {}").format(
                            sql.Identifier(self.config.database),
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL("GRANT USAGE ON SCHEMA {} TO {}").format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL("GRANT CREATE ON SCHEMA {} TO {}").format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    # Privileges on objects that already exist.
                    cursor.execute(
                        sql.SQL(
                            "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL(
                            "GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    # Default privileges cover objects created in the future.
                    cursor.execute(
                        sql.SQL(
                            "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                    cursor.execute(
                        sql.SQL(
                            "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT USAGE, SELECT ON SEQUENCES TO {}"
                        ).format(
                            schema_identifier,
                            role_identifier,
                        )
                    )
                except psycopg2.Error as exc:
                    message = (
                        "Failed to grant privileges to role '%s' in schema '%s'."
                        " Rerun with --dry-run for more context."
                    ) % (self.config.user, schema_name)
                    logger.error(message)
                    raise RuntimeError(message) from exc
                logger.info(
                    "Granted privileges on schema '%s' to role '%s'",
                    schema_name,
                    self.config.user,
                )
                rollback_label = f"revoke privileges for {self.config.user}"
                self._register_rollback(
                    rollback_label,
                    lambda schema=schema_name: self._revoke_role_privileges(
                        schema_name=schema
                    ),
                )
    def ensure_schema(self) -> None:
        """Create the configured schema when it does not exist.

        Also reassigns schema ownership to the application role; when that
        role is missing, a warning is logged instead of failing.
        """
        schema_name = self.config.schema
        if not schema_name or schema_name == "public":
            # 'public' always exists, so there is nothing to create.
            logger.info("Using default schema 'public'; nothing to ensure")
            return
        logger.info("Ensuring schema '%s' exists", schema_name)
        with self._admin_connection(self.config.database) as conn:
            conn.autocommit = True
            with conn.cursor() as cursor:
                cursor.execute(
                    sql.SQL(
                        "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s"
                    ),
                    (schema_name,),
                )
                exists = cursor.fetchone() is not None
                if not exists:
                    if self.dry_run:
                        logger.info(
                            "Dry run: would create schema '%s'",
                            schema_name,
                        )
                    else:
                        cursor.execute(
                            sql.SQL("CREATE SCHEMA {}").format(
                                sql.Identifier(schema_name)
                            )
                        )
                        logger.info("Created schema '%s'", schema_name)
                # Ownership transfer runs even when the schema already existed.
                try:
                    if self.dry_run:
                        logger.info(
                            "Dry run: would set schema '%s' owner to '%s'",
                            schema_name,
                            self.config.user,
                        )
                    else:
                        cursor.execute(
                            sql.SQL("ALTER SCHEMA {} OWNER TO {}").format(
                                sql.Identifier(schema_name),
                                sql.Identifier(self.config.user),
                            )
                        )
                except errors.UndefinedObject:
                    # Raised when the application role does not exist yet.
                    logger.warning(
                        "Role '%s' not found when assigning ownership to schema '%s'."
                        " Run --ensure-role after creating the schema.",
                        self.config.user,
                        schema_name,
                    )
def application_role_exists(self) -> bool:
try:
with self._admin_connection(self.config.admin_database) as conn:
with conn.cursor() as cursor:
try:
cursor.execute(
"SELECT 1 FROM pg_roles WHERE rolname = %s",
(self.config.user,),
)
except psycopg2.Error as exc:
message = (
"Unable to inspect existing roles while checking for role '%s'."
" Verify admin permissions."
) % self.config.user
logger.error(message)
raise RuntimeError(message) from exc
return cursor.fetchone() is not None
except RuntimeError:
raise
def _admin_connection(self, database: Optional[str] = None) -> PGConnection:
target_db = database or self.config.admin_database
dsn = self.config.admin_dsn(database)
descriptor = self._describe_connection(
self.config.admin_user, target_db
)
try:
return psycopg2.connect(dsn)
except psycopg2.Error as exc:
raise RuntimeError(
"Unable to establish admin connection. "
f"Target: {descriptor}"
) from exc
def _application_connection(self) -> PGConnection:
dsn = self.config.application_dsn()
descriptor = self._describe_connection(
self.config.user, self.config.database
)
try:
return psycopg2.connect(dsn)
except psycopg2.Error as exc:
raise RuntimeError(
"Unable to establish application connection. "
f"Target: {descriptor}"
) from exc
    def initialize_schema(self) -> None:
        """Create database objects from SQLAlchemy metadata if missing."""
        self._ensure_models_loaded()
        logger.info("Ensuring SQLAlchemy metadata is reflected in database")
        engine = create_engine(self.config.application_dsn(), future=True)
        try:
            inspector = inspect(engine)
            existing_tables = set(
                inspector.get_table_names(schema=self.config.schema)
            )
            # NOTE(review): metadata keys can be schema-qualified while
            # ``get_table_names`` returns bare names — confirm this diff does
            # not over-report pending tables for non-default schemas.
            metadata_tables = set(Base.metadata.tables.keys())
            missing_tables = sorted(metadata_tables - existing_tables)
            if missing_tables:
                logger.info("Pending tables: %s", ", ".join(missing_tables))
            else:
                logger.info("All tables already exist")
            if self.dry_run:
                if missing_tables:
                    logger.info("Dry run: skipping creation of pending tables")
                return
            # checkfirst=True keeps repeated executions idempotent.
            Base.metadata.create_all(bind=engine, checkfirst=True)
        finally:
            engine.dispose()
        logger.info("Schema initialization complete")
def _ensure_models_loaded(self) -> None:
if self._models_loaded:
return
package = importlib.import_module("models")
for module_info in pkgutil.iter_modules(package.__path__):
importlib.import_module(f"{package.__name__}.{module_info.name}")
self._models_loaded = True
    def run_migrations(self, migrations_dir: Optional[Path | str] = None) -> None:
        """Execute pending SQL migrations in chronological order.

        Behaviour:
        * Creates the ``schema_migrations`` history table when absent.
        * If ``000_base.sql`` exists and is unapplied, it runs first and all
          other migration files are then marked as applied without running
          (the baseline is assumed to contain their cumulative effect).
        * Remaining files are applied in sorted filename order.
        * In dry-run mode every step is only logged.
        """
        directory = (
            Path(migrations_dir)
            if migrations_dir is not None
            else DEFAULT_MIGRATIONS_DIR
        )
        directory = directory.resolve()
        if not directory.exists():
            logger.warning("Migrations directory '%s' not found", directory)
            return
        migration_files = sorted(directory.glob("*.sql"))
        if not migration_files:
            logger.info("No migration scripts found in '%s'", directory)
            return
        baseline_name = "000_base.sql"
        baseline_path = directory / baseline_name
        schema_name = self.config.schema or "public"
        with self._application_connection() as conn:
            # Autocommit: each migration takes effect immediately.
            conn.autocommit = True
            with conn.cursor() as cursor:
                table_exists = self._migrations_table_exists(
                    cursor, schema_name)
                if not table_exists:
                    if self.dry_run:
                        logger.info(
                            "Dry run: would create migration history table %s.%s",
                            schema_name,
                            MIGRATIONS_TABLE,
                        )
                        applied: set[str] = set()
                    else:
                        self._create_migrations_table(cursor, schema_name)
                        logger.info(
                            "Created migration history table %s.%s",
                            schema_name,
                            MIGRATIONS_TABLE,
                        )
                        applied = set()
                else:
                    applied = self._fetch_applied_migrations(
                        cursor, schema_name)
                if (
                    baseline_path.exists()
                    and baseline_name not in applied
                ):
                    if self.dry_run:
                        logger.info(
                            "Dry run: baseline migration '%s' pending; would apply and mark legacy files",
                            baseline_name,
                        )
                    else:
                        logger.info(
                            "Baseline migration '%s' pending; applying and marking older migrations",
                            baseline_name,
                        )
                        try:
                            baseline_applied = self._apply_migration_file(
                                cursor, schema_name, baseline_path
                            )
                        except Exception:
                            logger.error(
                                "Failed while applying baseline migration '%s'."
                                " Review the migration contents and rerun with --dry-run for diagnostics.",
                                baseline_name,
                                exc_info=True,
                            )
                            raise
                        applied.add(baseline_applied)
                        # Record every other file as applied without running
                        # it; the baseline snapshot supersedes them.
                        legacy_files = [
                            path
                            for path in migration_files
                            if path.name != baseline_name
                        ]
                        for legacy in legacy_files:
                            if legacy.name not in applied:
                                try:
                                    cursor.execute(
                                        sql.SQL(
                                            "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())"
                                        ).format(
                                            sql.Identifier(
                                                schema_name,
                                                MIGRATIONS_TABLE,
                                            )
                                        ),
                                        (legacy.name,),
                                    )
                                except Exception:
                                    logger.error(
                                        "Unable to record legacy migration '%s' after baseline application."
                                        " Check schema_migrations table in schema '%s' for partial state.",
                                        legacy.name,
                                        schema_name,
                                        exc_info=True,
                                    )
                                    raise
                                applied.add(legacy.name)
                                logger.info(
                                    "Marked legacy migration '%s' as applied via baseline",
                                    legacy.name,
                                )
                pending = [
                    path
                    for path in migration_files
                    if path.name not in applied
                ]
                if not pending:
                    logger.info("No pending migrations")
                    return
                logger.info(
                    "Pending migrations: %s",
                    ", ".join(path.name for path in pending),
                )
                if self.dry_run:
                    logger.info("Dry run: skipping migration execution")
                    return
                for path in pending:
                    self._apply_migration_file(cursor, schema_name, path)
                logger.info("Applied %d migrations", len(pending))
def _apply_migration_file(
self,
cursor,
schema_name: str,
path: Path,
) -> str:
logger.info("Applying migration '%s'", path.name)
sql_text = path.read_text(encoding="utf-8")
try:
cursor.execute(sql_text)
cursor.execute(
sql.SQL(
"INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())"
).format(
sql.Identifier(schema_name, MIGRATIONS_TABLE)
),
(path.name,),
)
return path.name
except Exception:
logger.exception("Failed to apply migration '%s'", path.name)
raise
def _migrations_table_exists(self, cursor, schema_name: str) -> bool:
cursor.execute(
"""
SELECT 1
FROM information_schema.tables
WHERE table_schema = %s AND table_name = %s
""",
(schema_name, MIGRATIONS_TABLE),
)
return cursor.fetchone() is not None
def _create_migrations_table(self, cursor, schema_name: str) -> None:
cursor.execute(
sql.SQL(
"CREATE TABLE IF NOT EXISTS {} ("
"filename TEXT PRIMARY KEY,"
"applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()"
")"
).format(
sql.Identifier(schema_name, MIGRATIONS_TABLE)
)
)
def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]:
cursor.execute(
sql.SQL("SELECT filename FROM {} ORDER BY filename").format(
sql.Identifier(schema_name, MIGRATIONS_TABLE)
)
)
return {row[0] for row in cursor.fetchall()}
def seed_baseline_data(self, *, dry_run: bool) -> None:
"""Seed reference data such as currencies."""
from scripts import seed_data
seed_args = argparse.Namespace(
currencies=True,
units=True,
defaults=False,
dry_run=dry_run,
verbose=0,
)
seed_data.run_with_namespace(seed_args, config=self.config)
if dry_run:
logger.info("Dry run: skipped seed verification")
return
expected_currencies = {
code for code, *_ in getattr(seed_data, "CURRENCY_SEEDS", ())
}
expected_units = {
code
for code, *_ in getattr(seed_data, "MEASUREMENT_UNIT_SEEDS", ())
}
self._verify_seeded_data(
expected_currency_codes=expected_currencies,
expected_unit_codes=expected_units,
)
    def _verify_seeded_data(
        self,
        *,
        expected_currency_codes: set[str],
        expected_unit_codes: set[str],
    ) -> None:
        """Confirm that seeded currencies and units exist and are active.

        Raises:
            RuntimeError: when an expected row is missing, the default
                currency 'USD' is absent or inactive, a seeded unit is
                inactive, or the measurement_unit table does not exist.
        """
        if not expected_currency_codes and not expected_unit_codes:
            logger.info("No seed datasets configured for verification")
            return
        with self._application_connection() as conn:
            with conn.cursor() as cursor:
                if expected_currency_codes:
                    cursor.execute(
                        "SELECT code, is_active FROM currency WHERE code = ANY(%s)",
                        (list(expected_currency_codes),),
                    )
                    rows = cursor.fetchall()
                    found_codes = {row[0] for row in rows}
                    missing_codes = sorted(
                        expected_currency_codes - found_codes
                    )
                    if missing_codes:
                        message = (
                            "Missing expected currencies after seeding: %s. "
                            "Run scripts/seed_data.py --currencies to restore them."
                        ) % ", ".join(missing_codes)
                        logger.error(message)
                        raise RuntimeError(message)
                    logger.info(
                        "Verified %d seeded currencies present",
                        len(found_codes),
                    )
                    # USD is treated as the application default: it must be
                    # present (not None) and active (not False).
                    default_status = next(
                        (row[1] for row in rows if row[0] == "USD"), None
                    )
                    if default_status is False:
                        message = (
                            "Default currency 'USD' is inactive after seeding. "
                            "Reactivate it or rerun the seeding command."
                        )
                        logger.error(message)
                        raise RuntimeError(message)
                    elif default_status is None:
                        message = (
                            "Default currency 'USD' not found after seeding. "
                            "Ensure baseline migration 000_base.sql ran successfully."
                        )
                        logger.error(message)
                        raise RuntimeError(message)
                    else:
                        logger.info("Verified default currency 'USD' active")
                if expected_unit_codes:
                    try:
                        cursor.execute(
                            "SELECT code, is_active FROM measurement_unit WHERE code = ANY(%s)",
                            (list(expected_unit_codes),),
                        )
                    except errors.UndefinedTable:
                        # Roll back so the connection remains usable after
                        # the failed statement.
                        conn.rollback()
                        message = (
                            "measurement_unit table not found during seed verification. "
                            "Ensure baseline migration 000_base.sql has been applied."
                        )
                        logger.error(message)
                        raise RuntimeError(message)
                    else:
                        rows = cursor.fetchall()
                        found_units = {row[0] for row in rows}
                        missing_units = sorted(
                            expected_unit_codes - found_units
                        )
                        if missing_units:
                            message = (
                                "Missing expected measurement units after seeding: %s. "
                                "Run scripts/seed_data.py --units to restore them."
                            ) % ", ".join(missing_units)
                            logger.error(message)
                            raise RuntimeError(message)
                        inactive_units = sorted(
                            row[0] for row in rows if not bool(row[1])
                        )
                        if inactive_units:
                            message = (
                                "Measurement units inactive after seeding: %s. "
                                "Reactivate them or rerun unit seeding."
                            ) % ", ".join(inactive_units)
                            logger.error(message)
                            raise RuntimeError(message)
                        logger.info(
                            "Verified %d measurement units present",
                            len(found_units),
                        )
        logger.info("Seed verification complete")
def _drop_database(self, database: str) -> None:
logger.warning("Rollback: dropping database '%s'", database)
with self._admin_connection(self.config.admin_database) as conn:
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute(
"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s",
(database,),
)
cursor.execute(
sql.SQL("DROP DATABASE IF EXISTS {}" ).format(
sql.Identifier(database)
)
)
def _drop_role(self, role: str) -> None:
logger.warning("Rollback: dropping role '%s'", role)
with self._admin_connection(self.config.admin_database) as conn:
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute(
sql.SQL("DROP ROLE IF EXISTS {}" ).format(
sql.Identifier(role)
)
)
def _revoke_role_privileges(self, *, schema_name: str) -> None:
logger.warning(
"Rollback: revoking privileges on schema '%s' for role '%s'",
schema_name,
self.config.user,
)
with self._admin_connection(self.config.database) as conn:
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" ).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
)
)
cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" ).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
)
)
cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" ).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
)
)
cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" ).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
)
)
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for the setup script."""
    parser = argparse.ArgumentParser(description="Bootstrap CalMiner database")
    # Boolean action flags, all defaulting to False.
    action_flags = (
        ("--ensure-database",
         "Create the application database when it does not already exist."),
        ("--ensure-role",
         "Create the application role and grant necessary privileges."),
        ("--ensure-schema",
         "Create the configured schema if it does not exist."),
        ("--initialize-schema",
         "Create missing tables based on SQLAlchemy models."),
        ("--run-migrations", "Execute pending SQL migrations."),
        ("--seed-data", "Seed baseline reference data (currencies, etc.)."),
    )
    for flag, help_text in action_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument(
        "--migrations-dir",
        default=None,
        help="Override the default migrations directory.",
    )
    # Connection overrides; --db-port is the only typed one.
    parser.add_argument("--db-driver", help="Override DATABASE_DRIVER")
    parser.add_argument("--db-host", help="Override DATABASE_HOST")
    parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT")
    parser.add_argument("--db-name", help="Override DATABASE_NAME")
    parser.add_argument("--db-user", help="Override DATABASE_USER")
    parser.add_argument("--db-password", help="Override DATABASE_PASSWORD")
    parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA")
    parser.add_argument(
        "--admin-url",
        help="Override DATABASE_ADMIN_URL for administrative operations",
    )
    parser.add_argument(
        "--admin-user", help="Override DATABASE_SUPERUSER for admin ops"
    )
    parser.add_argument(
        "--admin-password",
        help="Override DATABASE_SUPERUSER_PASSWORD for admin ops",
    )
    parser.add_argument(
        "--admin-db",
        help="Override DATABASE_SUPERUSER_DB for admin ops",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Log actions without applying changes.",
    )
    parser.add_argument(
        "--verbose", "-v", action="count", default=0,
        help="Increase logging verbosity",
    )
    return parser.parse_args()
def main() -> None:
    """CLI entry point: build the config and run the requested setup steps.

    On failure outside dry-run mode, registered rollback actions are executed
    before the exception propagates; rollbacks are always cleared at the end.

    Fix: the logging floor was ``max(level, logging.INFO)``, which clamped
    ``-vv`` back to INFO and made the second ``-v`` a no-op even though
    ``min(args.verbose, 2)`` clearly intends two verbosity steps. The floor
    is now DEBUG so ``-v`` -> INFO and ``-vv`` -> DEBUG.
    """
    args = parse_args()
    level = logging.WARNING - (10 * min(args.verbose, 2))
    logging.basicConfig(level=max(level, logging.DEBUG),
                        format="%(levelname)s %(message)s")
    # CLI overrides are mapped onto the environment-variable names that
    # DatabaseConfig.from_env understands; None values are ignored there.
    override_args: dict[str, Optional[str]] = {
        "DATABASE_DRIVER": args.db_driver,
        "DATABASE_HOST": args.db_host,
        "DATABASE_NAME": args.db_name,
        "DATABASE_USER": args.db_user,
        "DATABASE_PASSWORD": args.db_password,
        "DATABASE_SCHEMA": args.db_schema,
        "DATABASE_ADMIN_URL": args.admin_url,
        "DATABASE_SUPERUSER": args.admin_user,
        "DATABASE_SUPERUSER_PASSWORD": args.admin_password,
        "DATABASE_SUPERUSER_DB": args.admin_db,
    }
    if args.db_port is not None:
        override_args["DATABASE_PORT"] = str(args.db_port)
    config = DatabaseConfig.from_env(overrides=override_args)
    setup = DatabaseSetup(config, dry_run=args.dry_run)
    admin_tasks_requested = args.ensure_database or args.ensure_role or args.ensure_schema
    if admin_tasks_requested:
        setup.validate_admin_connection()
    app_validated = False

    def ensure_application_connection_for(operation: str) -> bool:
        # Validate the application connection at most once; in dry-run mode,
        # skip operations that need a role which was never created.
        nonlocal app_validated
        if app_validated:
            return True
        if setup.dry_run and not setup.application_role_exists():
            logger.info(
                "Dry run: skipping %s because application role '%s' does not exist yet.",
                operation,
                setup.config.user,
            )
            return False
        setup.validate_application_connection()
        app_validated = True
        return True

    should_run_migrations = args.run_migrations
    auto_run_migrations_reason: Optional[str] = None
    if args.seed_data and not should_run_migrations:
        # Seeding depends on the schema the migrations create.
        should_run_migrations = True
        auto_run_migrations_reason = (
            "Seed data requested without explicit --run-migrations; applying migrations first."
        )
    try:
        if args.ensure_database:
            setup.ensure_database()
        if args.ensure_role:
            setup.ensure_role()
        if args.ensure_schema:
            setup.ensure_schema()
        if args.initialize_schema:
            if ensure_application_connection_for(
                "SQLAlchemy schema initialization"
            ):
                setup.initialize_schema()
        if should_run_migrations:
            if ensure_application_connection_for("migration execution"):
                if auto_run_migrations_reason:
                    logger.info(auto_run_migrations_reason)
                migrations_path = (
                    Path(args.migrations_dir)
                    if args.migrations_dir
                    else None
                )
                setup.run_migrations(migrations_path)
        if args.seed_data:
            if ensure_application_connection_for("baseline data seeding"):
                setup.seed_baseline_data(dry_run=args.dry_run)
    except Exception:
        if not setup.dry_run:
            setup.execute_rollbacks()
        raise
    finally:
        if not setup.dry_run:
            setup.clear_rollbacks()
# Allow direct execution: ``python scripts/setup_database.py [flags]``.
if __name__ == "__main__":
    main()