from __future__ import annotations from dataclasses import dataclass, field from decimal import Decimal import re from types import SimpleNamespace from typing import Any, Dict, Iterable, Tuple import pytest from scripts import init_db @pytest.fixture(autouse=True) def clear_seed_admin_env(monkeypatch: pytest.MonkeyPatch) -> None: """Remove environment overrides so defaults are deterministic during tests.""" for name in ( "CALMINER_SEED_ADMIN_EMAIL", "CALMINER_SEED_ADMIN_USERNAME", "CALMINER_SEED_ADMIN_PASSWORD", ): monkeypatch.delenv(name, raising=False) @dataclass class FakeState: enums: set[str] = field(default_factory=set) tables: set[str] = field(default_factory=set) roles: dict[int, Dict[str, Any]] = field(default_factory=dict) users: dict[str, Dict[str, Any]] = field(default_factory=dict) user_roles: set[Tuple[int, int]] = field(default_factory=set) pricing_settings: dict[str, Dict[str, Any]] = field(default_factory=dict) projects: dict[str, Dict[str, Any]] = field(default_factory=dict) scenarios: dict[Tuple[int, str], Dict[str, Any] ] = field(default_factory=dict) financial_inputs: dict[Tuple[int, str], Dict[str, Any]] = field(default_factory=dict) sequences: Dict[str, int] = field(default_factory=lambda: { "users": 0, "projects": 0, "scenarios": 0, "financial_inputs": 0, }) class FakeResult: def __init__(self, rows: Iterable[Any]) -> None: self._rows = list(rows) def fetchone(self) -> Any | None: return self._rows[0] if self._rows else None class FakeConnection: def __init__(self, state: FakeState) -> None: self.state = state def execute(self, statement: Any, params: dict[str, Any] | None = None) -> FakeResult: params = params or {} sql = str(statement).strip() lower_sql = sql.lower() if lower_sql.startswith("do $$ begin"): match = re.search(r"create type\s+(\w+)\s+as enum", lower_sql) if match: self.state.enums.add(match.group(1)) return FakeResult([]) if lower_sql.startswith("create table if not exists"): match = re.search(r"create table if not 
exists\s+(\w+)", lower_sql) if match: self.state.tables.add(match.group(1)) return FakeResult([]) if lower_sql.startswith("insert into roles"): role_id = params["id"] record = self.state.roles.get(role_id, {"id": role_id}) record.update( name=params["name"], display_name=params["display_name"], description=params.get("description"), ) self.state.roles[role_id] = record return FakeResult([]) if lower_sql.startswith("insert into users"): username = params["username"] record = self.state.users.get(username) if record is None: self.state.sequences["users"] += 1 record = { "id": self.state.sequences["users"], "username": username} record.update( email=params["email"], password_hash=params["password_hash"], is_active=params["is_active"], is_superuser=params["is_superuser"], ) self.state.users[username] = record return FakeResult([]) if lower_sql.startswith("select id from users where username"): username = params["username"] record = self.state.users.get(username) rows = [SimpleNamespace(id=record["id"])] if record else [] return FakeResult(rows) if lower_sql.startswith("insert into user_roles"): key = (int(params["user_id"]), int(params["role_id"])) self.state.user_roles.add(key) return FakeResult([]) if lower_sql.startswith("insert into pricing_settings"): slug = params["slug"] record = self.state.pricing_settings.get(slug, {"slug": slug}) record.update( name=params["name"], description=params.get("description"), default_currency=params.get("default_currency"), default_payable_pct=float(params["default_payable_pct"]), moisture_threshold_pct=float(params["moisture_threshold_pct"]), moisture_penalty_per_pct=float( params["moisture_penalty_per_pct"]), ) self.state.pricing_settings[slug] = record return FakeResult([]) if lower_sql.startswith("insert into projects"): name = params["name"] record = self.state.projects.get(name) if record is None: self.state.sequences["projects"] += 1 record = {"id": self.state.sequences["projects"], "name": name} record.update( 
location=params.get("location"), operation_type=params["operation_type"], description=params.get("description"), pricing_settings_id=params.get("pricing_settings_id"), ) self.state.projects[name] = record return FakeResult([]) if lower_sql.startswith("select id from projects where name"): project = self.state.projects.get(params["name"]) rows = [SimpleNamespace(id=project["id"])] if project else [] return FakeResult(rows) if lower_sql.startswith("insert into scenarios"): key = (int(params["project_id"]), params["name"]) record = self.state.scenarios.get(key) if record is None: self.state.sequences["scenarios"] += 1 record = { "id": self.state.sequences["scenarios"], "project_id": int(params["project_id"]), "name": params["name"], } record.update( description=params.get("description"), status=params.get("status"), discount_rate=params.get("discount_rate"), currency=params.get("currency"), primary_resource=params.get("primary_resource"), ) self.state.scenarios[key] = record return FakeResult([]) if lower_sql.startswith("select id from scenarios where project_id"): key = (int(params["project_id"]), params["name"]) scenario = self.state.scenarios.get(key) rows = [SimpleNamespace(id=scenario["id"])] if scenario else [] return FakeResult(rows) if lower_sql.startswith("insert into financial_inputs"): key = (int(params["scenario_id"]), params["name"]) record = self.state.financial_inputs.get(key) if record is None: self.state.sequences["financial_inputs"] += 1 record = { "id": self.state.sequences["financial_inputs"], "scenario_id": int(params["scenario_id"]), "name": params["name"], } amount = params["amount"] if not isinstance(amount, Decimal): amount = Decimal(str(amount)) record.update( category=params["category"], cost_bucket=params.get("cost_bucket"), amount=amount, currency=params.get("currency"), notes=params.get("notes"), ) self.state.financial_inputs[key] = record return FakeResult([]) raise NotImplementedError( f"Unhandled SQL during test execution: {sql}") 
class FakeTransaction:
    """Context manager returned by ``FakeEngine.begin``; yields a FakeConnection."""

    def __init__(self, state: FakeState) -> None:
        self.state = state

    def __enter__(self) -> FakeConnection:  # noqa: D401 - simple context helper
        return FakeConnection(self.state)

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Returning False lets any exception raised inside the ``with`` propagate.
        return False


class FakeEngine:
    """Engine double whose every transaction writes into one shared FakeState."""

    def __init__(self) -> None:
        self.state = FakeState()
        self.begin_calls = 0

    def begin(self) -> FakeTransaction:  # noqa: D401 - simple context helper
        self.begin_calls += 1
        return FakeTransaction(self.state)


@pytest.fixture()
def fake_engine(monkeypatch: pytest.MonkeyPatch) -> FakeEngine:
    """Patch ``init_db._create_engine`` to always return a single FakeEngine."""
    engine = FakeEngine()

    def _fake_create_engine(database_url: str | None = None) -> FakeEngine:  # noqa: ARG001 - signature parity
        return engine

    monkeypatch.setattr(init_db, "_create_engine", _fake_create_engine)
    return engine


def test_init_db_seeds_demo_data_idempotently(fake_engine: FakeEngine) -> None:
    """Seed the fake database twice; the second run must not change seeded rows."""
    init_db.init_db(database_url="postgresql://fake")

    state = fake_engine.state
    expected_enum_names = set(init_db.ENUM_DEFINITIONS.keys())
    assert state.enums == expected_enum_names

    expected_role_ids = {role["id"] for role in init_db.DEFAULT_ROLES}
    assert set(state.roles.keys()) == expected_role_ids

    assert "admin" in state.users
    admin_record = state.users["admin"]
    assert admin_record["email"] == init_db.DEFAULT_ADMIN["email"]
    # NOTE(review): assumes the admin role is seeded with id 1 - confirm
    # against init_db.DEFAULT_ROLES.
    assert state.user_roles == {(admin_record["id"], 1)}

    assert set(state.pricing_settings.keys()) == {init_db.DEFAULT_PRICING["slug"]}

    expected_project_names = {project.name for project in init_db.DEFAULT_PROJECTS}
    assert set(state.projects.keys()) == expected_project_names
    assert len(state.scenarios) == len(init_db.DEFAULT_SCENARIOS)
    assert len(state.financial_inputs) == len(init_db.DEFAULT_FINANCIAL_INPUTS)

    # FakeEngine hands the same FakeState object to every transaction, so the
    # second run mutates ``state`` in place; copy what we compare beforehand.
    snapshot = {
        "projects": {name: data.copy() for name, data in state.projects.items()},
        "scenario_keys": set(state.scenarios.keys()),
        "financial_keys": set(state.financial_inputs.keys()),
        "user_roles": set(state.user_roles),
        "admin_id": admin_record["id"],
    }

    init_db.init_db(database_url="postgresql://fake")

    state_after = fake_engine.state
    assert set(state_after.roles.keys()) == expected_role_ids
    assert len(state_after.users) == 1
    assert state_after.users["admin"]["id"] == snapshot["admin_id"]
    # Compare whole project records, not just names, so a re-run that
    # reassigned ids or rewrote project fields is caught.
    assert state_after.projects == snapshot["projects"]
    assert set(state_after.scenarios.keys()) == snapshot["scenario_keys"]
    assert set(state_after.financial_inputs.keys()) == snapshot["financial_keys"]
    assert state_after.user_roles == snapshot["user_roles"]


def test_enum_seed_values_align_with_definitions() -> None:
    """Every enum is referenced by the DDL and seeded values are valid members."""
    ddl_blob = " ".join(init_db.TABLE_DDLS).lower()

    # Seed values that must be members of the named enum. Optional columns
    # (primary_resource, cost_bucket) are only checked when set.
    seeded_values = {
        "miningoperationtype": [
            project.operation_type for project in init_db.DEFAULT_PROJECTS
        ],
        "scenariostatus": [
            scenario.status for scenario in init_db.DEFAULT_SCENARIOS
        ],
        "resourcetype": [
            scenario.primary_resource
            for scenario in init_db.DEFAULT_SCENARIOS
            if scenario.primary_resource is not None
        ],
        "financialcategory": [
            item.category for item in init_db.DEFAULT_FINANCIAL_INPUTS
        ],
        "costbucket": [
            item.cost_bucket
            for item in init_db.DEFAULT_FINANCIAL_INPUTS
            if item.cost_bucket is not None
        ],
    }

    for enum_name, values in init_db.ENUM_DEFINITIONS.items():
        # Enums with no seed data (e.g. distributiontype, stochasticvariable,
        # which the original comment says simulation parameters use) are only
        # checked for presence in the schema.
        assert enum_name in ddl_blob, f"enum {enum_name!r} not referenced by TABLE_DDLS"
        for value in seeded_values.get(enum_name, ()):
            assert value in values, f"seed value {value!r} not a member of {enum_name!r}"