317 lines
12 KiB
Python
317 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from decimal import Decimal
|
|
import re
|
|
from types import SimpleNamespace
|
|
from typing import Any, Dict, Iterable, Tuple
|
|
|
|
import pytest
|
|
|
|
from scripts import init_db
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_seed_admin_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Unset the seed-admin environment overrides for every test.

    Running each test without these variables keeps the seeded admin values
    deterministic regardless of the developer's shell environment.
    """
    overrides = (
        "CALMINER_SEED_ADMIN_EMAIL",
        "CALMINER_SEED_ADMIN_USERNAME",
        "CALMINER_SEED_ADMIN_PASSWORD",
    )
    for variable in overrides:
        # raising=False: tolerate variables that are not set at all.
        monkeypatch.delenv(variable, raising=False)
|
|
|
|
|
|
@dataclass
class FakeState:
    """In-memory record of everything the fake connection has written.

    Each collection stands in for one table or catalogue touched by the
    seeding code; ``sequences`` holds the last auto-increment id handed out
    per table.
    """

    # Enum type names registered from ``DO $$`` blocks.
    enums: set[str] = field(default_factory=set)
    # Table names registered from ``CREATE TABLE IF NOT EXISTS`` statements.
    tables: set[str] = field(default_factory=set)
    # Role rows keyed by role id.
    roles: dict[int, dict[str, Any]] = field(default_factory=dict)
    # User rows keyed by username.
    users: dict[str, dict[str, Any]] = field(default_factory=dict)
    # (user_id, role_id) association pairs.
    user_roles: set[tuple[int, int]] = field(default_factory=set)
    # Pricing settings rows keyed by slug.
    pricing_settings: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Project rows keyed by project name.
    projects: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Scenario rows keyed by (project_id, scenario name).
    scenarios: dict[tuple[int, str], dict[str, Any]] = field(
        default_factory=dict)
    # Financial input rows keyed by (scenario_id, input name).
    financial_inputs: dict[tuple[int, str], dict[str, Any]] = field(
        default_factory=dict)
    # Last id issued per auto-increment table; every counter starts at 0.
    sequences: dict[str, int] = field(
        default_factory=lambda: dict.fromkeys(
            ("users", "projects", "scenarios", "financial_inputs"), 0
        )
    )
|
|
|
|
|
|
class FakeResult:
    """Result object returned by the fake connection's ``execute``.

    Rows are materialised up front. Only ``fetchone`` is provided, and it does
    not advance a cursor: every call returns the first row again.
    """

    def __init__(self, rows: Iterable[Any]) -> None:
        # Copy into a list so one-shot iterables (e.g. generators) are safe.
        self._rows = [*rows]

    def fetchone(self) -> Any | None:
        """Return the first row, or ``None`` if there are no rows."""
        return next(iter(self._rows), None)
|
|
|
|
|
|
class FakeConnection:
    """Connection double that interprets the SQL issued by ``init_db``.

    Statements are matched on their stripped, lower-cased text. Recognised
    DDL and inserts mutate the shared :class:`FakeState`. Recognised selects
    return :class:`FakeResult` rows exposing an ``id`` (or ``enumlabel``)
    attribute. Any other statement raises ``NotImplementedError``, so new SQL
    in ``init_db`` fails loudly instead of being silently ignored.
    """

    # Ordered (prefix, handler method name) pairs. The first matching prefix
    # wins; the order mirrors the original if-chain and must be kept.
    _PREFIX_HANDLERS: tuple[tuple[str, str], ...] = (
        ("do $$", "_create_enum_types"),
        ("create table if not exists", "_create_table"),
        ("insert into roles", "_upsert_role"),
        ("insert into users", "_upsert_user"),
        ("select id from users where username", "_select_user_id"),
        ("insert into user_roles", "_insert_user_role"),
        ("insert into pricing_settings", "_upsert_pricing_settings"),
        ("insert into projects", "_upsert_project"),
        ("select id from projects where name", "_select_project_id"),
        ("insert into scenarios", "_upsert_scenario"),
        ("select id from scenarios where project_id", "_select_scenario_id"),
        ("insert into financial_inputs", "_upsert_financial_input"),
    )

    def __init__(self, state: FakeState) -> None:
        self.state = state

    def execute(self, statement: Any, params: dict[str, Any] | None = None) -> FakeResult:
        """Apply ``statement`` to the fake state and return its result.

        Args:
            statement: Object whose ``str()`` is the SQL text.
            params: Bound parameters; ``None`` is treated as no parameters.

        Returns:
            A ``FakeResult``. It is empty for DDL, inserts and no-op
            statements.

        Raises:
            NotImplementedError: If the statement is not recognised.
        """
        params = params or {}
        sql = str(statement).strip()
        lower_sql = sql.lower()

        for prefix, handler_name in self._PREFIX_HANDLERS:
            if lower_sql.startswith(prefix):
                handler = getattr(self, handler_name)
                return handler(lower_sql, params)

        if "from pg_enum" in lower_sql and "enumlabel" in lower_sql:
            return self._select_enum_labels(params)

        if lower_sql.startswith("alter type") and "rename value" in lower_sql:
            # Enum value renames are accepted but do not touch the fake state.
            return FakeResult([])

        raise NotImplementedError(
            f"Unhandled SQL during test execution: {sql}")

    # -- helpers -------------------------------------------------------------

    def _next_id(self, sequence: str) -> int:
        """Advance the auto-increment counter for ``sequence`` and return it."""
        self.state.sequences[sequence] += 1
        return self.state.sequences[sequence]

    @staticmethod
    def _id_rows(record: dict[str, Any] | None) -> FakeResult:
        """Wrap ``record``'s id in a one-row result, or return an empty result."""
        return FakeResult([SimpleNamespace(id=record["id"])] if record else [])

    # -- DDL -----------------------------------------------------------------

    def _create_enum_types(self, lower_sql: str, _params: dict[str, Any]) -> FakeResult:
        """Register every enum type created inside a ``DO $$ ... $$`` block."""
        # findall rather than search: a single DO block may create several
        # types, and every one of them must be recorded.
        self.state.enums.update(
            re.findall(r"create type\s+(\w+)\s+as enum", lower_sql))
        return FakeResult([])

    def _create_table(self, lower_sql: str, _params: dict[str, Any]) -> FakeResult:
        """Record the table named in a ``CREATE TABLE IF NOT EXISTS``."""
        match = re.search(r"create table if not exists\s+(\w+)", lower_sql)
        if match:
            self.state.tables.add(match.group(1))
        return FakeResult([])

    # -- roles / users -------------------------------------------------------

    def _upsert_role(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Insert or update the role row keyed by ``params["id"]``."""
        role_id = params["id"]
        record = self.state.roles.get(role_id, {"id": role_id})
        record.update(
            name=params["name"],
            display_name=params["display_name"],
            description=params.get("description"),
        )
        self.state.roles[role_id] = record
        return FakeResult([])

    def _upsert_user(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Insert or update a user keyed by username; new users get a fresh id."""
        username = params["username"]
        record = self.state.users.get(username)
        if record is None:
            record = {"id": self._next_id("users"), "username": username}
        record.update(
            email=params["email"],
            password_hash=params["password_hash"],
            is_active=params["is_active"],
            is_superuser=params["is_superuser"],
        )
        self.state.users[username] = record
        return FakeResult([])

    def _select_user_id(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Return the id of the user named ``params["username"]``, if any."""
        return self._id_rows(self.state.users.get(params["username"]))

    def _insert_user_role(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Record a (user_id, role_id) pair; duplicates collapse in the set."""
        key = (int(params["user_id"]), int(params["role_id"]))
        self.state.user_roles.add(key)
        return FakeResult([])

    # -- pricing / projects --------------------------------------------------

    def _upsert_pricing_settings(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Insert or update pricing settings keyed by slug."""
        slug = params["slug"]
        record = self.state.pricing_settings.get(slug, {"slug": slug})
        record.update(
            name=params["name"],
            description=params.get("description"),
            default_currency=params.get("default_currency"),
            default_payable_pct=float(params["default_payable_pct"]),
            moisture_threshold_pct=float(params["moisture_threshold_pct"]),
            moisture_penalty_per_pct=float(
                params["moisture_penalty_per_pct"]),
        )
        self.state.pricing_settings[slug] = record
        return FakeResult([])

    def _upsert_project(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Insert or update a project keyed by name; new projects get a fresh id."""
        name = params["name"]
        record = self.state.projects.get(name)
        if record is None:
            record = {"id": self._next_id("projects"), "name": name}
        record.update(
            location=params.get("location"),
            operation_type=params["operation_type"],
            description=params.get("description"),
            pricing_settings_id=params.get("pricing_settings_id"),
        )
        self.state.projects[name] = record
        return FakeResult([])

    def _select_project_id(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Return the id of the project named ``params["name"]``, if any."""
        return self._id_rows(self.state.projects.get(params["name"]))

    # -- scenarios / financial inputs ----------------------------------------

    def _upsert_scenario(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Insert or update a scenario keyed by (project_id, name)."""
        project_id = int(params["project_id"])
        key = (project_id, params["name"])
        record = self.state.scenarios.get(key)
        if record is None:
            record = {
                "id": self._next_id("scenarios"),
                "project_id": project_id,
                "name": params["name"],
            }
        record.update(
            description=params.get("description"),
            status=params.get("status"),
            discount_rate=params.get("discount_rate"),
            currency=params.get("currency"),
            primary_resource=params.get("primary_resource"),
        )
        self.state.scenarios[key] = record
        return FakeResult([])

    def _select_scenario_id(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Return the id of the scenario keyed by (project_id, name), if any."""
        key = (int(params["project_id"]), params["name"])
        return self._id_rows(self.state.scenarios.get(key))

    def _upsert_financial_input(self, _lower_sql: str, params: dict[str, Any]) -> FakeResult:
        """Insert or update a financial input keyed by (scenario_id, name)."""
        scenario_id = int(params["scenario_id"])
        key = (scenario_id, params["name"])
        record = self.state.financial_inputs.get(key)
        if record is None:
            record = {
                "id": self._next_id("financial_inputs"),
                "scenario_id": scenario_id,
                "name": params["name"],
            }
        amount = params["amount"]
        if not isinstance(amount, Decimal):
            # Go through str() so a float keeps its printed value rather than
            # its exact binary expansion.
            amount = Decimal(str(amount))
        record.update(
            category=params["category"],
            cost_bucket=params.get("cost_bucket"),
            amount=amount,
            currency=params.get("currency"),
            notes=params.get("notes"),
        )
        self.state.financial_inputs[key] = record
        return FakeResult([])

    # -- catalogue queries ---------------------------------------------------

    def _select_enum_labels(self, params: dict[str, Any]) -> FakeResult:
        """Answer a ``pg_enum`` label query from ``init_db.ENUM_DEFINITIONS``.

        Returns no rows when ``type_name`` is missing or unknown.
        """
        type_name_param = params.get("type_name")
        if type_name_param is None:
            return FakeResult([])
        values = init_db.ENUM_DEFINITIONS.get(str(type_name_param), [])
        return FakeResult(SimpleNamespace(enumlabel=value) for value in values)
|
|
|
|
|
|
class FakeTransaction:
    """Context manager produced by ``FakeEngine.begin``.

    Entering yields a new FakeConnection bound to the shared state. Exiting
    does nothing: it performs no rollback and never suppresses exceptions.
    """

    def __init__(self, state: FakeState) -> None:
        self.state = state

    def __enter__(self) -> FakeConnection:  # noqa: D401 - simple context helper
        connection = FakeConnection(self.state)
        return connection

    def __exit__(self, exc_type, exc, tb) -> bool:
        # A falsy return lets any exception from the with-body propagate.
        suppress = False
        return suppress
|
|
|
|
|
|
class FakeEngine:
    """Engine double whose transactions all share a single FakeState.

    Data written inside one ``begin()`` block is therefore visible to later
    ones. ``begin_calls`` counts how many transactions were opened.
    """

    def __init__(self) -> None:
        self.state = FakeState()
        self.begin_calls = 0

    def begin(self) -> FakeTransaction:  # noqa: D401 - simple context helper
        self.begin_calls = self.begin_calls + 1
        transaction = FakeTransaction(self.state)
        return transaction
|
|
|
|
|
|
@pytest.fixture()
def fake_engine(monkeypatch: pytest.MonkeyPatch) -> FakeEngine:
    """Make ``init_db._create_engine`` always hand back one shared FakeEngine.

    The same instance is returned to the test, so it can inspect the state
    written by ``init_db.init_db``.
    """
    shared_engine = FakeEngine()

    def _return_shared_engine(database_url: str | None = None) -> FakeEngine:  # noqa: ARG001 - signature parity
        return shared_engine

    monkeypatch.setattr(init_db, "_create_engine", _return_shared_engine)
    return shared_engine
|
|
|
|
|
|
def test_init_db_seeds_demo_data_idempotently(fake_engine: FakeEngine) -> None:
    """Seeding twice creates the demo data once and leaves it unchanged.

    The first ``init_db`` run is checked against the module's default
    definitions. The second run must not add, remove or alter the seeded rows.
    """
    init_db.init_db(database_url="postgresql://fake")

    state = fake_engine.state
    expected_enum_names = set(init_db.ENUM_DEFINITIONS.keys())
    assert state.enums == expected_enum_names

    expected_role_ids = {role["id"] for role in init_db.DEFAULT_ROLES}
    assert set(state.roles.keys()) == expected_role_ids

    assert "admin" in state.users
    admin_record = state.users["admin"]
    assert admin_record["email"] == init_db.DEFAULT_ADMIN["email"]
    # NOTE(review): role id 1 is assumed to be the admin role in
    # DEFAULT_ROLES — confirm against init_db.
    assert state.user_roles == {(admin_record["id"], 1)}

    expected_pricing_slugs = {init_db.DEFAULT_PRICING["slug"]}
    assert set(state.pricing_settings.keys()) == expected_pricing_slugs

    expected_project_names = {
        project.name for project in init_db.DEFAULT_PROJECTS}
    assert set(state.projects.keys()) == expected_project_names

    assert len(state.scenarios) == len(init_db.DEFAULT_SCENARIOS)
    assert len(state.financial_inputs) == len(init_db.DEFAULT_FINANCIAL_INPUTS)

    # FakeEngine reuses one FakeState object across runs, so copy everything
    # needed for the comparison before seeding again.
    snapshot = {
        "projects": {name: data.copy() for name, data in state.projects.items()},
        "scenario_keys": set(state.scenarios.keys()),
        "financial_keys": set(state.financial_inputs.keys()),
        "user_roles": set(state.user_roles),
        "admin_id": admin_record["id"],
    }

    init_db.init_db(database_url="postgresql://fake")

    state_after = fake_engine.state
    assert state_after.enums == expected_enum_names
    assert set(state_after.roles.keys()) == expected_role_ids
    assert len(state_after.users) == 1
    assert state_after.users["admin"]["id"] == snapshot["admin_id"]
    assert set(state_after.pricing_settings.keys()) == expected_pricing_slugs
    # Compare whole project rows, not only their names: a re-run must not
    # rewrite project data or reassign ids.
    assert state_after.projects == snapshot["projects"]
    assert set(state_after.scenarios.keys()) == snapshot["scenario_keys"]
    assert set(state_after.financial_inputs.keys()
               ) == snapshot["financial_keys"]
    assert state_after.user_roles == snapshot["user_roles"]
|
|
|
|
|
|
def test_enum_seed_values_align_with_definitions() -> None:
    """Every enum appears in the table DDL, and every seeded value is a legal label.

    Seed values are checked against explicitly named enums. A missing
    ``ENUM_DEFINITIONS`` entry therefore fails the test instead of silently
    skipping validation.
    """
    ddl_blob = " ".join(init_db.TABLE_DDLS).lower()
    # This also covers schema-only enums (e.g. distributiontype,
    # stochasticvariable) that have no seeded values to validate.
    for enum_name in init_db.ENUM_DEFINITIONS:
        assert enum_name in ddl_blob, (
            f"enum {enum_name!r} does not appear in TABLE_DDLS")

    seeded_values: dict[str, list[Any]] = {
        "miningoperationtype": [
            project.operation_type for project in init_db.DEFAULT_PROJECTS
        ],
        "scenariostatus": [
            scenario.status for scenario in init_db.DEFAULT_SCENARIOS
        ],
        "resourcetype": [
            scenario.primary_resource
            for scenario in init_db.DEFAULT_SCENARIOS
            if scenario.primary_resource is not None
        ],
        "financialcategory": [
            item.category for item in init_db.DEFAULT_FINANCIAL_INPUTS
        ],
        "costbucket": [
            item.cost_bucket
            for item in init_db.DEFAULT_FINANCIAL_INPUTS
            if item.cost_bucket is not None
        ],
    }
    for enum_name, values in seeded_values.items():
        assert enum_name in init_db.ENUM_DEFINITIONS, (
            f"missing enum definition for {enum_name!r}")
        allowed = init_db.ENUM_DEFINITIONS[enum_name]
        for value in values:
            assert value in allowed, (
                f"{value!r} is not a valid {enum_name} label")
|