"""Persistence tests for the project/scenario models on in-memory SQLite."""

from __future__ import annotations

from collections.abc import Iterator
from datetime import datetime, timedelta, timezone

import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, sessionmaker

from config.database import Base
from models import (
    MiningOperationType,
    Project,
    ProjectProfitability,
    Scenario,
    ScenarioProfitability,
    ScenarioStatus,
)


@pytest.fixture()
def engine() -> Iterator[Engine]:
    """Yield an in-memory SQLite engine with every ``Base`` table created.

    On teardown the tables are dropped and the engine is disposed so the
    pooled connection is closed deterministically instead of lingering
    until garbage collection.
    """
    db_engine = create_engine("sqlite:///:memory:", future=True)
    Base.metadata.create_all(bind=db_engine)
    try:
        yield db_engine
    finally:
        try:
            Base.metadata.drop_all(bind=db_engine)
        finally:
            # Dispose even when drop_all fails, otherwise the pool keeps
            # its SQLite connection open for the rest of the test run.
            db_engine.dispose()


@pytest.fixture()
def session(engine) -> Iterator[Session]:
    """Yield a ``Session`` bound to the test engine and close it afterwards.

    ``expire_on_commit=False`` keeps attributes loaded after ``commit()``,
    so the tests can keep reading objects they have just persisted.
    """
    session_factory = sessionmaker(
        bind=engine,
        expire_on_commit=False,
        future=True,
    )
    db_session = session_factory()
    try:
        yield db_session
    finally:
        db_session.close()


def test_project_scenario_cascade_deletes(session: Session) -> None:
    """Deleting a project must leave no scenario rows behind."""
    project = Project(
        name="Cascade Mine",
        operation_type=MiningOperationType.OTHER,
    )
    # The scenarios are attached only through ``project=``; presumably the
    # relationship cascade pulls them into the session when the project is
    # added -- confirm against the model definitions.
    Scenario(name="Base Case", status=ScenarioStatus.DRAFT, project=project)
    Scenario(name="Expansion", status=ScenarioStatus.DRAFT, project=project)
    session.add(project)
    session.commit()

    assert session.scalar(
        select(Project).where(Project.id == project.id)
    ) is not None
    assert len(
        session.scalars(
            select(Scenario).where(Scenario.project_id == project.id)
        ).all()
    ) == 2

    session.delete(project)
    session.commit()

    assert session.scalar(
        select(Project).where(Project.id == project.id)
    ) is None
    assert session.scalars(select(Scenario)).first() is None


def test_scenario_unique_name_per_project(session: Session) -> None:
    """A second scenario with the same name in one project is rejected."""
    project = Project(
        name="Uniqueness Mine",
        operation_type=MiningOperationType.OTHER,
    )
    scenario = Scenario(
        name="Duplicated",
        status=ScenarioStatus.DRAFT,
        project=project,
    )
    session.add_all([project, scenario])
    session.commit()

    duplicate = Scenario(
        name="Duplicated",
        status=ScenarioStatus.DRAFT,
        project=project,
    )
    session.add(duplicate)
    with pytest.raises(IntegrityError):
        session.commit()
    # A failed flush leaves the session's transaction unusable until it
    # is rolled back.
    session.rollback()


def test_latest_profitability_helpers(session: Session) -> None:
    """``latest_profitability`` is expected to return the newest snapshot."""
    project = Project(
        name="Hierarchy Mine",
        operation_type=MiningOperationType.OTHER,
    )
    scenario = Scenario(
        name="Economic Model",
        status=ScenarioStatus.DRAFT,
        project=project,
    )
    session.add_all([project, scenario])
    session.commit()

    base_time = datetime.now(timezone.utc)

    scenario_snapshot_old = ScenarioProfitability(
        scenario=scenario,
        npv=1_000_000,
        calculated_at=base_time,
    )
    scenario_snapshot_new = ScenarioProfitability(
        scenario=scenario,
        npv=2_500_000,
        calculated_at=base_time + timedelta(hours=6),
    )
    project_snapshot_old = ProjectProfitability(
        project=project,
        npv=5_000_000,
        calculated_at=base_time,
    )
    project_snapshot_new = ProjectProfitability(
        project=project,
        npv=7_500_000,
        calculated_at=base_time + timedelta(hours=12),
    )

    session.add_all(
        [
            scenario_snapshot_old,
            scenario_snapshot_new,
            project_snapshot_old,
            project_snapshot_new,
        ]
    )
    session.commit()

    session.refresh(scenario)
    session.refresh(project)

    # Identity checks: the helper must hand back the very objects that
    # carry the later ``calculated_at`` values.
    assert scenario.latest_profitability is scenario_snapshot_new
    assert project.latest_profitability is project_snapshot_new


def test_currency_normalisation_on_model(session: Session) -> None:
    """A lower-case currency code reads back upper-cased after a refresh."""
    project = Project(
        name="Forex Mine",
        operation_type=MiningOperationType.OTHER,
    )
    scenario = Scenario(
        name="Currency Case",
        status=ScenarioStatus.DRAFT,
        project=project,
        currency="usd",
    )
    session.add_all([project, scenario])
    session.commit()
    session.refresh(scenario)

    assert scenario.currency == "USD"