# Tests for ``scripts.init_db``.
#
# The seeding entry point is run against an in-memory fake engine whose
# connection interprets each SQL statement by its leading text and records the
# effect in a ``FakeState``.  No real database is involved.
from __future__ import annotations

from dataclasses import dataclass, field
from decimal import Decimal
import re
from types import SimpleNamespace
from typing import Any, Callable, Dict, Iterable, Tuple

import pytest

from scripts import init_db


# Compiled once; both are applied to the lower-cased statement text.
_ENUM_NAME_RE = re.compile(r"create type\s+(\w+)\s+as enum")
_TABLE_NAME_RE = re.compile(r"create table if not exists\s+(\w+)")


@pytest.fixture(autouse=True)
def clear_seed_admin_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Remove environment overrides so defaults are deterministic during tests."""
    for name in (
        "CALMINER_SEED_ADMIN_EMAIL",
        "CALMINER_SEED_ADMIN_USERNAME",
        "CALMINER_SEED_ADMIN_PASSWORD",
    ):
        monkeypatch.delenv(name, raising=False)


@dataclass
class FakeState:
    """In-memory record of everything the fake connection has been asked to do."""

    # Names captured from ``CREATE TYPE ... AS ENUM`` / ``CREATE TABLE`` DDL.
    enums: set[str] = field(default_factory=set)
    tables: set[str] = field(default_factory=set)
    # Keyed by role id.
    roles: dict[int, Dict[str, Any]] = field(default_factory=dict)
    # Keyed by username.
    users: dict[str, Dict[str, Any]] = field(default_factory=dict)
    # (user_id, role_id) pairs.
    user_roles: set[Tuple[int, int]] = field(default_factory=set)
    # Keyed by slug.
    pricing_settings: dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Keyed by project name.
    projects: dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Keyed by (project_id, scenario name).
    scenarios: dict[Tuple[int, str], Dict[str, Any]] = field(
        default_factory=dict)
    # Keyed by (scenario_id, input name).
    financial_inputs: dict[Tuple[int, str], Dict[str, Any]] = field(
        default_factory=dict)
    # Last id handed out per table; incremented before each new row is built.
    sequences: Dict[str, int] = field(default_factory=lambda: {
        "users": 0,
        "projects": 0,
        "scenarios": 0,
        "financial_inputs": 0,
    })


class FakeResult:
    """Minimal result object exposing only ``fetchone``."""

    def __init__(self, rows: Iterable[Any]) -> None:
        self._rows = list(rows)

    def fetchone(self) -> Any | None:
        """Return the first row, or ``None`` when there are no rows.

        The row is not consumed: repeated calls return the same row.
        """
        return self._rows[0] if self._rows else None


# Signature shared by every statement handler: (lower-cased SQL, bind params).
_Handler = Callable[[str, Dict[str, Any]], FakeResult]


class FakeConnection:
    """Connection double that applies recognised SQL statements to a FakeState.

    Statements are identified by the leading text of their stripped,
    lower-cased string form.  Inserts behave as upserts on the key noted on
    the matching ``FakeState`` field: an existing record keeps its id and has
    its other columns overwritten.
    """

    def __init__(self, state: FakeState) -> None:
        self.state = state
        # Ordered (prefix, handler) pairs; the first matching prefix wins.
        self._handlers: Tuple[Tuple[str, _Handler], ...] = (
            ("do $$ begin", self._create_enum),
            ("create table if not exists", self._create_table),
            ("insert into roles", self._upsert_role),
            ("insert into users", self._upsert_user),
            ("select id from users where username", self._select_user_id),
            ("insert into user_roles", self._insert_user_role),
            ("insert into pricing_settings", self._upsert_pricing_settings),
            ("insert into projects", self._upsert_project),
            ("select id from projects where name", self._select_project_id),
            ("insert into scenarios", self._upsert_scenario),
            ("select id from scenarios where project_id",
             self._select_scenario_id),
            ("insert into financial_inputs", self._upsert_financial_input),
        )

    def execute(
        self, statement: Any, params: dict[str, Any] | None = None
    ) -> FakeResult:
        """Apply *statement* to the fake state and return its result.

        Args:
            statement: Anything whose ``str()`` is the SQL text.
            params: Bind parameters; ``None`` is treated as an empty mapping.

        Raises:
            NotImplementedError: If the statement matches no known prefix, so
                that new SQL in the script under test fails loudly here.
        """
        params = params or {}
        sql = str(statement).strip()
        lower_sql = sql.lower()

        for prefix, handler in self._handlers:
            if lower_sql.startswith(prefix):
                return handler(lower_sql, params)

        raise NotImplementedError(
            f"Unhandled SQL during test execution: {sql}")

    # -- shared helpers ---------------------------------------------------

    def _next_id(self, table: str) -> int:
        """Advance and return the id sequence for *table*."""
        self.state.sequences[table] += 1
        return self.state.sequences[table]

    @staticmethod
    def _id_result(record: Dict[str, Any] | None) -> FakeResult:
        """Wrap *record*'s id in a one-row result, or return an empty result."""
        rows = [SimpleNamespace(id=record["id"])] if record else []
        return FakeResult(rows)

    # -- DDL --------------------------------------------------------------

    def _create_enum(self, lower_sql: str, _params: Dict[str, Any]) -> FakeResult:
        """Record the enum name found inside a ``DO $$ BEGIN`` block, if any."""
        found = _ENUM_NAME_RE.search(lower_sql)
        if found:
            self.state.enums.add(found.group(1))
        return FakeResult([])

    def _create_table(self, lower_sql: str, _params: Dict[str, Any]) -> FakeResult:
        """Record the table name from ``CREATE TABLE IF NOT EXISTS``."""
        found = _TABLE_NAME_RE.search(lower_sql)
        if found:
            self.state.tables.add(found.group(1))
        return FakeResult([])

    # -- roles and users --------------------------------------------------

    def _upsert_role(self, _lower_sql: str, params: Dict[str, Any]) -> FakeResult:
        """Insert or overwrite the role identified by ``params["id"]``."""
        role_id = params["id"]
        record = self.state.roles.get(role_id, {"id": role_id})
        record.update(
            name=params["name"],
            display_name=params["display_name"],
            description=params.get("description"),
        )
        self.state.roles[role_id] = record
        return FakeResult([])

    def _upsert_user(self, _lower_sql: str, params: Dict[str, Any]) -> FakeResult:
        """Insert or overwrite the user identified by ``params["username"]``."""
        username = params["username"]
        record = self.state.users.get(username)
        if record is None:
            record = {"id": self._next_id("users"), "username": username}
        record.update(
            email=params["email"],
            password_hash=params["password_hash"],
            is_active=params["is_active"],
            is_superuser=params["is_superuser"],
        )
        self.state.users[username] = record
        return FakeResult([])

    def _select_user_id(self, _lower_sql: str, params: Dict[str, Any]) -> FakeResult:
        """Look up a user's id by username."""
        return self._id_result(self.state.users.get(params["username"]))

    def _insert_user_role(
        self, _lower_sql: str, params: Dict[str, Any]
    ) -> FakeResult:
        """Record a (user_id, role_id) link; duplicates collapse in the set."""
        link = (int(params["user_id"]), int(params["role_id"]))
        self.state.user_roles.add(link)
        return FakeResult([])

    # -- pricing and projects ---------------------------------------------

    def _upsert_pricing_settings(
        self, _lower_sql: str, params: Dict[str, Any]
    ) -> FakeResult:
        """Insert or overwrite the pricing settings row for ``params["slug"]``."""
        slug = params["slug"]
        record = self.state.pricing_settings.get(slug, {"slug": slug})
        record.update(
            name=params["name"],
            description=params.get("description"),
            default_currency=params.get("default_currency"),
            default_payable_pct=float(params["default_payable_pct"]),
            moisture_threshold_pct=float(params["moisture_threshold_pct"]),
            moisture_penalty_per_pct=float(
                params["moisture_penalty_per_pct"]),
        )
        self.state.pricing_settings[slug] = record
        return FakeResult([])

    def _upsert_project(self, _lower_sql: str, params: Dict[str, Any]) -> FakeResult:
        """Insert or overwrite the project identified by ``params["name"]``."""
        name = params["name"]
        record = self.state.projects.get(name)
        if record is None:
            record = {"id": self._next_id("projects"), "name": name}
        record.update(
            location=params.get("location"),
            operation_type=params["operation_type"],
            description=params.get("description"),
            pricing_settings_id=params.get("pricing_settings_id"),
        )
        self.state.projects[name] = record
        return FakeResult([])

    def _select_project_id(
        self, _lower_sql: str, params: Dict[str, Any]
    ) -> FakeResult:
        """Look up a project's id by name."""
        return self._id_result(self.state.projects.get(params["name"]))

    # -- scenarios and financial inputs -----------------------------------

    def _upsert_scenario(self, _lower_sql: str, params: Dict[str, Any]) -> FakeResult:
        """Insert or overwrite the scenario keyed by (project_id, name)."""
        key = (int(params["project_id"]), params["name"])
        record = self.state.scenarios.get(key)
        if record is None:
            record = {
                "id": self._next_id("scenarios"),
                "project_id": int(params["project_id"]),
                "name": params["name"],
            }
        record.update(
            description=params.get("description"),
            status=params.get("status"),
            discount_rate=params.get("discount_rate"),
            currency=params.get("currency"),
            primary_resource=params.get("primary_resource"),
        )
        self.state.scenarios[key] = record
        return FakeResult([])

    def _select_scenario_id(
        self, _lower_sql: str, params: Dict[str, Any]
    ) -> FakeResult:
        """Look up a scenario's id by (project_id, name)."""
        key = (int(params["project_id"]), params["name"])
        return self._id_result(self.state.scenarios.get(key))

    def _upsert_financial_input(
        self, _lower_sql: str, params: Dict[str, Any]
    ) -> FakeResult:
        """Insert or overwrite the financial input keyed by (scenario_id, name)."""
        key = (int(params["scenario_id"]), params["name"])
        record = self.state.financial_inputs.get(key)
        if record is None:
            record = {
                "id": self._next_id("financial_inputs"),
                "scenario_id": int(params["scenario_id"]),
                "name": params["name"],
            }
        amount = params["amount"]
        if not isinstance(amount, Decimal):
            # Go through str() so a float is converted from its printed form
            # rather than its exact binary expansion.
            amount = Decimal(str(amount))
        record.update(
            category=params["category"],
            cost_bucket=params.get("cost_bucket"),
            amount=amount,
            currency=params.get("currency"),
            notes=params.get("notes"),
        )
        self.state.financial_inputs[key] = record
        return FakeResult([])


class FakeTransaction:
    """Context manager that yields a FakeConnection bound to shared state."""

    def __init__(self, state: FakeState) -> None:
        self.state = state

    def __enter__(self) -> FakeConnection:  # noqa: D401 - simple context helper
        return FakeConnection(self.state)

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Returning False lets any exception raised in the block propagate.
        return False


class FakeEngine:
    """Engine double: one FakeState shared by every transaction it opens."""

    def __init__(self) -> None:
        self.state = FakeState()
        self.begin_calls = 0

    def begin(self) -> FakeTransaction:  # noqa: D401 - simple context helper
        self.begin_calls += 1
        return FakeTransaction(self.state)


@pytest.fixture()
def fake_engine(monkeypatch: pytest.MonkeyPatch) -> FakeEngine:
    """Patch ``init_db._create_engine`` to hand back a single FakeEngine."""
    engine = FakeEngine()

    def _fake_create_engine(database_url: str | None = None) -> FakeEngine:  # noqa: ARG001 - signature parity
        return engine

    monkeypatch.setattr(init_db, "_create_engine", _fake_create_engine)
    return engine


def test_init_db_seeds_demo_data_idempotently(fake_engine: FakeEngine) -> None:
    """Seeding twice yields the same roles, users, projects and children."""
    init_db.init_db(database_url="postgresql://fake")

    state = fake_engine.state
    expected_enum_names = set(init_db.ENUM_DEFINITIONS.keys())
    assert state.enums == expected_enum_names

    expected_role_ids = {role["id"] for role in init_db.DEFAULT_ROLES}
    assert set(state.roles.keys()) == expected_role_ids

    assert "admin" in state.users
    admin_record = state.users["admin"]
    assert admin_record["email"] == init_db.DEFAULT_ADMIN["email"]
    assert state.user_roles == {(admin_record["id"], 1)}

    assert set(state.pricing_settings.keys()) == {
        init_db.DEFAULT_PRICING["slug"]}

    expected_project_names = {
        project.name for project in init_db.DEFAULT_PROJECTS}
    assert set(state.projects.keys()) == expected_project_names

    assert len(state.scenarios) == len(init_db.DEFAULT_SCENARIOS)
    assert len(state.financial_inputs) == len(init_db.DEFAULT_FINANCIAL_INPUTS)

    # The engine keeps one FakeState across runs, so copy the key sets now;
    # otherwise the comparisons below would be against the same live objects.
    snapshot = {
        "project_names": set(state.projects.keys()),
        "scenario_keys": set(state.scenarios.keys()),
        "financial_keys": set(state.financial_inputs.keys()),
        "user_roles": set(state.user_roles),
        "admin_id": admin_record["id"],
    }

    init_db.init_db(database_url="postgresql://fake")

    state_after = fake_engine.state
    assert set(state_after.roles.keys()) == expected_role_ids
    assert len(state_after.users) == 1
    assert state_after.users["admin"]["id"] == snapshot["admin_id"]
    assert set(state_after.projects.keys()) == snapshot["project_names"]
    assert set(state_after.scenarios.keys()) == snapshot["scenario_keys"]
    assert set(state_after.financial_inputs.keys()
               ) == snapshot["financial_keys"]
    assert state_after.user_roles == snapshot["user_roles"]


def _seeded_values(enum_name: str) -> list[Any]:
    """Return the seed values that must be members of enum *enum_name*.

    Optional columns (``primary_resource``, ``cost_bucket``) contribute only
    their non-``None`` values.  Enums with no seed rows in this module, such
    as ``distributiontype`` and ``stochasticvariable``, yield an empty list.
    """
    if enum_name == "miningoperationtype":
        return [project.operation_type for project in init_db.DEFAULT_PROJECTS]
    if enum_name == "scenariostatus":
        return [scenario.status for scenario in init_db.DEFAULT_SCENARIOS]
    if enum_name == "resourcetype":
        return [
            scenario.primary_resource
            for scenario in init_db.DEFAULT_SCENARIOS
            if scenario.primary_resource is not None
        ]
    if enum_name == "financialcategory":
        return [item.category for item in init_db.DEFAULT_FINANCIAL_INPUTS]
    if enum_name == "costbucket":
        return [
            item.cost_bucket
            for item in init_db.DEFAULT_FINANCIAL_INPUTS
            if item.cost_bucket is not None
        ]
    return []


def test_enum_seed_values_align_with_definitions() -> None:
    """Every enum is referenced by the DDL and accepts the seeded values."""
    ddl_blob = " ".join(init_db.TABLE_DDLS).lower()
    for enum_name, values in init_db.ENUM_DEFINITIONS.items():
        # Applies to every enum, including those without seed rows.
        assert enum_name in ddl_blob
        for seeded in _seeded_values(enum_name):
            assert seeded in values