from __future__ import annotations from collections.abc import Iterator from datetime import datetime, timezone import pytest from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker from config.database import Base from models import ( DistributionType, FinancialCategory, FinancialInput, MiningOperationType, Project, Scenario, ScenarioStatus, SimulationParameter, StochasticVariable, ) from services.repositories import ( FinancialInputRepository, ProjectRepository, ScenarioRepository, SimulationParameterRepository, ) from services.unit_of_work import UnitOfWork @pytest.fixture() def engine(): engine = create_engine("sqlite:///:memory:", future=True) Base.metadata.create_all(bind=engine) try: yield engine finally: Base.metadata.drop_all(bind=engine) @pytest.fixture() def session(engine) -> Iterator[Session]: TestingSession = sessionmaker(bind=engine, expire_on_commit=False, future=True) session = TestingSession() try: yield session finally: session.close() def test_project_repository_create_and_list(session: Session) -> None: repo = ProjectRepository(session) project = Project(name="Project Alpha", operation_type=MiningOperationType.OPEN_PIT) repo.create(project) projects = repo.list() assert len(projects) == 1 assert projects[0].name == "Project Alpha" def test_scenario_repository_get_with_children(session: Session) -> None: project = Project(name="Project Beta", operation_type=MiningOperationType.UNDERGROUND) scenario = Scenario(name="Scenario 1", project=project, status=ScenarioStatus.ACTIVE) scenario.financial_inputs.append( FinancialInput( name="Lease Payment", category=FinancialCategory.OPERATING_EXPENDITURE, amount=10000, currency="usd", ) ) scenario.simulation_parameters.append( SimulationParameter( name="Copper Price", distribution=DistributionType.NORMAL, mean_value=3.5, variable=StochasticVariable.METAL_PRICE, ) ) session.add(project) session.flush() repo = ScenarioRepository(session) retrieved = repo.get(scenario.id, with_children=True) assert retrieved.project.name == "Project Beta" assert len(retrieved.financial_inputs) == 1 assert retrieved.financial_inputs[0].currency == "USD" assert len(retrieved.simulation_parameters) == 1 assert ( retrieved.simulation_parameters[0].variable == StochasticVariable.METAL_PRICE ) param_repo = SimulationParameterRepository(session) params = param_repo.list_for_scenario(scenario.id) assert len(params) == 1 def test_financial_input_repository_bulk_upsert(session: Session) -> None: project = Project(name="Project Gamma", operation_type=MiningOperationType.QUARRY) scenario = Scenario(name="Scenario Bulk", project=project) session.add(project) session.flush() repo = FinancialInputRepository(session) created = repo.bulk_upsert( [ FinancialInput( scenario_id=scenario.id, name="Explosives", category=FinancialCategory.OPERATING_EXPENDITURE, amount=5000, currency="cad", ), FinancialInput( scenario_id=scenario.id, name="Equipment Lease", category=FinancialCategory.OPERATING_EXPENDITURE, amount=12000, currency="cad", ), ] ) assert len(created) == 2 stored = repo.list_for_scenario(scenario.id) assert len(stored) == 2 assert all(item.currency == "CAD" for item in stored) def test_unit_of_work_commit_and_rollback(engine) -> None: TestingSession = sessionmaker(bind=engine, expire_on_commit=False, future=True) # Commit path with UnitOfWork(session_factory=TestingSession) as uow: uow.projects.create( Project(name="Project Delta", operation_type=MiningOperationType.PLACER) ) with TestingSession() as session: projects = ProjectRepository(session).list() assert len(projects) == 1 # Rollback path with pytest.raises(RuntimeError): with UnitOfWork(session_factory=TestingSession) as uow: uow.projects.create( Project(name="Project Epsilon", operation_type=MiningOperationType.OTHER) ) raise RuntimeError("trigger rollback") with TestingSession() as session: projects = ProjectRepository(session).list() assert len(projects) == 1 def test_project_repository_count_and_recent(session: Session) -> None: repo = ProjectRepository(session) project_alpha = Project(name="Alpha", operation_type=MiningOperationType.OPEN_PIT) project_bravo = Project(name="Bravo", operation_type=MiningOperationType.UNDERGROUND) repo.create(project_alpha) repo.create(project_bravo) project_alpha.updated_at = datetime(2025, 1, 1, tzinfo=timezone.utc) project_bravo.updated_at = datetime(2025, 1, 2, tzinfo=timezone.utc) session.flush() assert repo.count() == 2 recent = repo.recent(limit=1) assert len(recent) == 1 assert recent[0].name == "Bravo" def test_scenario_repository_counts_and_filters(session: Session) -> None: project = Project(name="Project", operation_type=MiningOperationType.PLACER) session.add(project) session.flush() repo = ScenarioRepository(session) draft = Scenario(name="Draft", project_id=project.id, status=ScenarioStatus.DRAFT) active = Scenario(name="Active", project_id=project.id, status=ScenarioStatus.ACTIVE) repo.create(draft) repo.create(active) draft.updated_at = datetime(2025, 1, 1, tzinfo=timezone.utc) active.updated_at = datetime(2025, 1, 3, tzinfo=timezone.utc) session.flush() assert repo.count() == 2 assert repo.count_by_status(ScenarioStatus.ACTIVE) == 1 recent = repo.recent(limit=1, with_project=True) assert len(recent) == 1 assert recent[0].name == "Active" assert recent[0].project.name == "Project" drafts = repo.list_by_status(ScenarioStatus.DRAFT, limit=2, with_project=True) assert len(drafts) == 1 assert drafts[0].name == "Draft" assert drafts[0].project_id == project.id def test_financial_input_repository_latest_created_at(session: Session) -> None: project = Project(name="Project FI", operation_type=MiningOperationType.OTHER) scenario = Scenario(name="Scenario FI", project=project) session.add(project) session.flush() repo = FinancialInputRepository(session) old_timestamp = datetime(2024, 12, 31, 15, 0) new_timestamp = datetime(2025, 1, 2, 8, 30) repo.bulk_upsert( [ FinancialInput( scenario_id=scenario.id, name="Legacy Entry", category=FinancialCategory.OPERATING_EXPENDITURE, amount=1000, currency="usd", created_at=old_timestamp, updated_at=old_timestamp, ), FinancialInput( scenario_id=scenario.id, name="New Entry", category=FinancialCategory.OPERATING_EXPENDITURE, amount=2000, currency="usd", created_at=new_timestamp, updated_at=new_timestamp, ), ] ) latest = repo.latest_created_at() assert latest == new_timestamp