feat: Add comprehensive tests for database initialization and seeding

This commit is contained in:
2025-11-12 16:38:20 +01:00
parent 8ef6724960
commit 23523f70f1

View File

@@ -0,0 +1,304 @@
from __future__ import annotations
from dataclasses import dataclass, field
from decimal import Decimal
import re
from types import SimpleNamespace
from typing import Any, Dict, Iterable, Tuple
import pytest
from scripts import init_db
@pytest.fixture(autouse=True)
def clear_seed_admin_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Drop the seed-admin environment overrides before every test.

    The fixture is autouse, so each test starts without these variables set
    and the seeding defaults stay deterministic.
    """
    seed_admin_overrides = (
        "CALMINER_SEED_ADMIN_EMAIL",
        "CALMINER_SEED_ADMIN_USERNAME",
        "CALMINER_SEED_ADMIN_PASSWORD",
    )
    for variable in seed_admin_overrides:
        # raising=False: a variable that is already absent is not an error.
        monkeypatch.delenv(variable, raising=False)
@dataclass
class FakeState:
    """Mutable in-memory record of what the fake database currently holds.

    One instance is shared by every fake connection opened from the same
    fake engine, so data written in one transaction is visible in the next.
    """

    # Enum type names captured from ``CREATE TYPE ... AS ENUM`` statements.
    enums: set[str] = field(default_factory=set)
    # Table names captured from ``CREATE TABLE IF NOT EXISTS`` statements.
    tables: set[str] = field(default_factory=set)
    # Role records keyed by role id.
    roles: dict[int, dict[str, Any]] = field(default_factory=dict)
    # User records keyed by username.
    users: dict[str, dict[str, Any]] = field(default_factory=dict)
    # (user_id, role_id) pairs.
    user_roles: set[tuple[int, int]] = field(default_factory=set)
    # Pricing settings records keyed by slug.
    pricing_settings: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Project records keyed by project name.
    projects: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Scenario records keyed by (project_id, scenario name).
    scenarios: dict[tuple[int, str], dict[str, Any]] = field(
        default_factory=dict
    )
    # Financial input records keyed by (scenario_id, input name).
    financial_inputs: dict[tuple[int, str], dict[str, Any]] = field(
        default_factory=dict
    )
    # Last id handed out per table; every counter starts at zero.
    sequences: dict[str, int] = field(
        default_factory=lambda: dict.fromkeys(
            ("users", "projects", "scenarios", "financial_inputs"), 0
        )
    )
class FakeResult:
    """Minimal stand-in for a query result that only supports ``fetchone``."""

    def __init__(self, rows: Iterable[Any]) -> None:
        """Materialise *rows* so the result can be read more than once."""
        self._rows = [*rows]

    def fetchone(self) -> Any | None:
        """Return the first row, or ``None`` when there are no rows.

        The row is not consumed: repeated calls return the same first row.
        """
        return next(iter(self._rows), None)
class FakeConnection:
    """In-memory stand-in for the connection that ``init_db`` runs SQL on.

    Statements are recognised by their leading keywords, compared
    case-insensitively and with whitespace runs collapsed, and are applied
    to the shared state object. Inserts act as upserts keyed on each
    table's natural key; ``SELECT id`` lookups yield at most one row that
    exposes ``.id``. Anything else raises ``NotImplementedError`` so a new
    statement in the seeding script is noticed instead of silently ignored.
    """

    def __init__(self, state: FakeState) -> None:
        """Bind the connection to the shared *state* it reads and mutates."""
        self.state = state

    def execute(self, statement: Any, params: dict[str, Any] | None = None) -> FakeResult:
        """Apply *statement* to the fake state and return its result.

        Args:
            statement: Anything whose ``str()`` is the SQL text.
            params: Bound parameters; ``None`` is treated as an empty dict.

        Returns:
            A ``FakeResult`` holding the rows produced by the statement
            (empty for DDL and inserts).

        Raises:
            NotImplementedError: If no handler recognises the statement.
            KeyError: If a handler's required parameter is missing.
        """
        params = params or {}
        sql = str(statement).strip()
        # Collapse whitespace so matching does not depend on how the SQL
        # literal happens to be wrapped or indented in the seeding script.
        normalized = " ".join(sql.lower().split())
        for prefix, handler in self._HANDLERS:
            if normalized.startswith(prefix):
                return FakeResult(handler(self, normalized, params))
        raise NotImplementedError(
            f"Unhandled SQL during test execution: {sql}")

    # Every handler takes ``(sql, params)`` and returns the result rows.
    # ``sql`` is the lower-cased, whitespace-normalised statement; only the
    # DDL handlers need it.

    def _create_enum(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Record the enum name from a ``DO $$ BEGIN ... CREATE TYPE`` block."""
        match = re.search(r"create type\s+(\w+)\s+as enum", sql)
        if match:
            self.state.enums.add(match.group(1))
        return []

    def _create_table(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Record the table name from ``CREATE TABLE IF NOT EXISTS``."""
        match = re.search(r"create table if not exists\s+(\w+)", sql)
        if match:
            self.state.tables.add(match.group(1))
        return []

    def _upsert_role(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Insert or update a role, keyed by the caller-supplied id."""
        role_id = params["id"]
        record = self.state.roles.get(role_id, {"id": role_id})
        record.update(
            name=params["name"],
            display_name=params["display_name"],
            description=params.get("description"),
        )
        self.state.roles[role_id] = record
        return []

    def _upsert_user(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Insert or update a user by username; new users get the next id."""
        username = params["username"]
        record = self.state.users.get(username)
        if record is None:
            self.state.sequences["users"] += 1
            record = {
                "id": self.state.sequences["users"], "username": username}
        record.update(
            email=params["email"],
            password_hash=params["password_hash"],
            is_active=params["is_active"],
            is_superuser=params["is_superuser"],
        )
        self.state.users[username] = record
        return []

    def _select_user_id(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Look up a user's id by username."""
        record = self.state.users.get(params["username"])
        return [SimpleNamespace(id=record["id"])] if record else []

    def _insert_user_role(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Record a (user_id, role_id) link; duplicates collapse in the set."""
        key = (int(params["user_id"]), int(params["role_id"]))
        self.state.user_roles.add(key)
        return []

    def _upsert_pricing_settings(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Insert or update pricing settings by slug, storing rates as floats."""
        slug = params["slug"]
        record = self.state.pricing_settings.get(slug, {"slug": slug})
        record.update(
            name=params["name"],
            description=params.get("description"),
            default_currency=params.get("default_currency"),
            default_payable_pct=float(params["default_payable_pct"]),
            moisture_threshold_pct=float(params["moisture_threshold_pct"]),
            moisture_penalty_per_pct=float(
                params["moisture_penalty_per_pct"]),
        )
        self.state.pricing_settings[slug] = record
        return []

    def _upsert_project(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Insert or update a project by name; new projects get the next id."""
        name = params["name"]
        record = self.state.projects.get(name)
        if record is None:
            self.state.sequences["projects"] += 1
            record = {"id": self.state.sequences["projects"], "name": name}
        record.update(
            location=params.get("location"),
            operation_type=params["operation_type"],
            description=params.get("description"),
            pricing_settings_id=params.get("pricing_settings_id"),
        )
        self.state.projects[name] = record
        return []

    def _select_project_id(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Look up a project's id by name."""
        project = self.state.projects.get(params["name"])
        return [SimpleNamespace(id=project["id"])] if project else []

    def _upsert_scenario(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Insert or update a scenario keyed by (project_id, name)."""
        key = (int(params["project_id"]), params["name"])
        record = self.state.scenarios.get(key)
        if record is None:
            self.state.sequences["scenarios"] += 1
            record = {
                "id": self.state.sequences["scenarios"],
                "project_id": int(params["project_id"]),
                "name": params["name"],
            }
        record.update(
            description=params.get("description"),
            status=params.get("status"),
            discount_rate=params.get("discount_rate"),
            currency=params.get("currency"),
            primary_resource=params.get("primary_resource"),
        )
        self.state.scenarios[key] = record
        return []

    def _select_scenario_id(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Look up a scenario's id by (project_id, name)."""
        key = (int(params["project_id"]), params["name"])
        scenario = self.state.scenarios.get(key)
        return [SimpleNamespace(id=scenario["id"])] if scenario else []

    def _upsert_financial_input(self, sql: str, params: dict[str, Any]) -> list[Any]:
        """Insert or update a financial input keyed by (scenario_id, name)."""
        key = (int(params["scenario_id"]), params["name"])
        record = self.state.financial_inputs.get(key)
        if record is None:
            self.state.sequences["financial_inputs"] += 1
            record = {
                "id": self.state.sequences["financial_inputs"],
                "scenario_id": int(params["scenario_id"]),
                "name": params["name"],
            }
        amount = params["amount"]
        if not isinstance(amount, Decimal):
            # Go through str() so a float such as 0.1 becomes Decimal("0.1")
            # rather than its full binary expansion.
            amount = Decimal(str(amount))
        record.update(
            category=params["category"],
            cost_bucket=params.get("cost_bucket"),
            amount=amount,
            currency=params.get("currency"),
            notes=params.get("notes"),
        )
        self.state.financial_inputs[key] = record
        return []

    # Statement prefix -> handler, checked in order; the first match wins.
    _HANDLERS = (
        ("do $$ begin", _create_enum),
        ("create table if not exists", _create_table),
        ("insert into roles", _upsert_role),
        ("insert into users", _upsert_user),
        ("select id from users where username", _select_user_id),
        ("insert into user_roles", _insert_user_role),
        ("insert into pricing_settings", _upsert_pricing_settings),
        ("insert into projects", _upsert_project),
        ("select id from projects where name", _select_project_id),
        ("insert into scenarios", _upsert_scenario),
        ("select id from scenarios where project_id", _select_scenario_id),
        ("insert into financial_inputs", _upsert_financial_input),
    )
class FakeTransaction:
    """Context manager that hands out a fake connection over shared state."""

    def __init__(self, state: FakeState) -> None:
        """Remember the *state* that connections from this transaction use."""
        self.state = state

    def __enter__(self) -> FakeConnection:
        """Return a fresh ``FakeConnection`` bound to the shared state."""
        connection = FakeConnection(self.state)
        return connection

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: Any,
    ) -> bool:
        """Leave the block; ``False`` lets any in-flight exception propagate."""
        return False
class FakeEngine:
    """Engine double: one shared ``FakeState`` plus a ``begin()`` call counter."""

    def __init__(self) -> None:
        """Start with empty state and no transactions opened."""
        self.begin_calls = 0
        self.state = FakeState()

    def begin(self) -> FakeTransaction:
        """Count the call and open a transaction over the shared state."""
        self.begin_calls += 1
        transaction = FakeTransaction(self.state)
        return transaction
@pytest.fixture()
def fake_engine(monkeypatch: pytest.MonkeyPatch) -> FakeEngine:
    """Patch ``init_db._create_engine`` to return one shared ``FakeEngine``.

    The same engine instance is returned to the test, so it can inspect the
    state accumulated by ``init_db`` afterwards.
    """
    stub_engine = FakeEngine()
    # The replacement keeps the optional ``database_url`` parameter for
    # signature parity with the patched function, but ignores its value.
    monkeypatch.setattr(
        init_db,
        "_create_engine",
        lambda database_url=None: stub_engine,  # noqa: ARG005
    )
    return stub_engine
def test_init_db_seeds_demo_data_idempotently(fake_engine: FakeEngine) -> None:
    """Seed twice; the second run must not add rows or reassign ids."""
    # ---- First run: everything declared in init_db is present. ----
    init_db.init_db(database_url="postgresql://fake")
    seeded = fake_engine.state

    assert seeded.enums == set(init_db.ENUM_DEFINITIONS)

    role_ids = {role["id"] for role in init_db.DEFAULT_ROLES}
    assert set(seeded.roles) == role_ids

    assert "admin" in seeded.users
    admin = seeded.users["admin"]
    assert admin["email"] == init_db.DEFAULT_ADMIN["email"]
    # The admin user is linked to role id 1 and to nothing else.
    assert seeded.user_roles == {(admin["id"], 1)}

    assert set(seeded.pricing_settings) == {init_db.DEFAULT_PRICING["slug"]}

    project_names = {project.name for project in init_db.DEFAULT_PROJECTS}
    assert set(seeded.projects) == project_names

    assert len(seeded.scenarios) == len(init_db.DEFAULT_SCENARIOS)
    assert len(seeded.financial_inputs) == len(
        init_db.DEFAULT_FINANCIAL_INPUTS
    )

    # Copy the identifying keys now: the state object is mutated in place by
    # the second run, so live references would compare equal to themselves.
    admin_id_before = admin["id"]
    project_names_before = set(seeded.projects)
    scenario_keys_before = set(seeded.scenarios)
    financial_keys_before = set(seeded.financial_inputs)
    user_roles_before = set(seeded.user_roles)

    # ---- Second run: same keys, same ids, no duplicates. ----
    init_db.init_db(database_url="postgresql://fake")
    reseeded = fake_engine.state

    assert set(reseeded.roles) == role_ids
    assert len(reseeded.users) == 1
    assert reseeded.users["admin"]["id"] == admin_id_before
    assert set(reseeded.projects) == project_names_before
    assert set(reseeded.scenarios) == scenario_keys_before
    assert set(reseeded.financial_inputs) == financial_keys_before
    assert reseeded.user_roles == user_roles_before
def test_enum_seed_values_align_with_definitions() -> None:
    """Check that seed data only uses values its enum type declares.

    Every enum in ``ENUM_DEFINITIONS`` must be mentioned in the table DDL,
    and each seeded value must be one of that enum's declared values.
    """
    ddl_blob = " ".join(init_db.TABLE_DDLS).lower()

    # Enum name -> lazy producer of the seeded values that must be members.
    # Producers are only invoked for enums that are actually defined, so the
    # seed records are not touched for an enum that is absent.
    seeded_values = {
        "miningoperationtype": lambda: (
            project.operation_type for project in init_db.DEFAULT_PROJECTS
        ),
        "scenariostatus": lambda: (
            scenario.status for scenario in init_db.DEFAULT_SCENARIOS
        ),
        "resourcetype": lambda: (
            scenario.primary_resource
            for scenario in init_db.DEFAULT_SCENARIOS
            if scenario.primary_resource is not None
        ),
        "financialcategory": lambda: (
            item.category for item in init_db.DEFAULT_FINANCIAL_INPUTS
        ),
        "costbucket": lambda: (
            item.cost_bucket
            for item in init_db.DEFAULT_FINANCIAL_INPUTS
            if item.cost_bucket is not None
        ),
    }
    # These enums have no seed rows; they only need to appear in the schema.
    schema_only = ("distributiontype", "stochasticvariable")

    for enum_name, values in init_db.ENUM_DEFINITIONS.items():
        assert enum_name in ddl_blob

        produce = seeded_values.get(enum_name)
        if produce is not None:
            for seeded in produce():
                assert seeded in values
        elif enum_name in schema_only:
            assert enum_name in ddl_blob