diff --git a/tests/test_project_scenario_models.py b/tests/test_project_scenario_models.py new file mode 100644 index 0000000..e90aa19 --- /dev/null +++ b/tests/test_project_scenario_models.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from collections.abc import Iterator +from datetime import datetime, timedelta, timezone + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, sessionmaker + +from config.database import Base +from models import ( + MiningOperationType, + Project, + ProjectProfitability, + Scenario, + ScenarioProfitability, + ScenarioStatus, +) + + +@pytest.fixture() +def engine() -> Iterator: + engine = create_engine("sqlite:///:memory:", future=True) + Base.metadata.create_all(bind=engine) + try: + yield engine + finally: + Base.metadata.drop_all(bind=engine) + + +@pytest.fixture() +def session(engine) -> Iterator[Session]: + TestingSession = sessionmaker( + bind=engine, expire_on_commit=False, future=True) + session = TestingSession() + try: + yield session + finally: + session.close() + + +def test_project_scenario_cascade_deletes(session: Session) -> None: + project = Project(name="Cascade Mine", + operation_type=MiningOperationType.OTHER) + scenario_a = Scenario( + name="Base Case", status=ScenarioStatus.DRAFT, project=project) + scenario_b = Scenario( + name="Expansion", status=ScenarioStatus.DRAFT, project=project) + + session.add(project) + session.commit() + + assert session.scalar(select(Project).where( + Project.id == project.id)) is not None + assert len(session.scalars(select(Scenario).where( + Scenario.project_id == project.id)).all()) == 2 + + session.delete(project) + session.commit() + + assert session.scalar(select(Project).where( + Project.id == project.id)) is None + assert session.scalars(select(Scenario)).first() is None + + +def test_scenario_unique_name_per_project(session: Session) -> None: + project = Project(name="Uniqueness Mine", + operation_type=MiningOperationType.OTHER) + scenario = Scenario(name="Duplicated", + status=ScenarioStatus.DRAFT, project=project) + session.add_all([project, scenario]) + session.commit() + + duplicate = Scenario( + name="Duplicated", status=ScenarioStatus.DRAFT, project=project) + session.add(duplicate) + + with pytest.raises(IntegrityError): + session.commit() + session.rollback() + + +def test_latest_profitability_helpers(session: Session) -> None: + project = Project(name="Hierarchy Mine", + operation_type=MiningOperationType.OTHER) + scenario = Scenario(name="Economic Model", + status=ScenarioStatus.DRAFT, project=project) + session.add_all([project, scenario]) + session.commit() + + base_time = datetime.now(timezone.utc) + scenario_snapshot_old = ScenarioProfitability( + scenario=scenario, + npv=1_000_000, + calculated_at=base_time, + ) + scenario_snapshot_new = ScenarioProfitability( + scenario=scenario, + npv=2_500_000, + calculated_at=base_time + timedelta(hours=6), + ) + project_snapshot_old = ProjectProfitability( + project=project, + npv=5_000_000, + calculated_at=base_time, + ) + project_snapshot_new = ProjectProfitability( + project=project, + npv=7_500_000, + calculated_at=base_time + timedelta(hours=12), + ) + + session.add_all( + [ + scenario_snapshot_old, + scenario_snapshot_new, + project_snapshot_old, + project_snapshot_new, + ] + ) + session.commit() + + session.refresh(scenario) + session.refresh(project) + + assert scenario.latest_profitability is scenario_snapshot_new + assert project.latest_profitability is project_snapshot_new + + +def test_currency_normalisation_on_model(session: Session) -> None: + project = Project(name="Forex Mine", + operation_type=MiningOperationType.OTHER) + scenario = Scenario( + name="Currency Case", status=ScenarioStatus.DRAFT, project=project, currency="usd") + + session.add_all([project, scenario]) + session.commit() + + session.refresh(scenario) + + assert scenario.currency == "USD"