"""Integration tests for the repository layer and ``UnitOfWork``.

Every test runs against a fresh in-memory SQLite database whose schema is
created from ``Base.metadata`` by the ``engine`` fixture.
"""

from __future__ import annotations

from collections.abc import Iterator

import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker

from config.database import Base
from models import (
    DistributionType,
    FinancialCategory,
    FinancialInput,
    MiningOperationType,
    Project,
    Scenario,
    ScenarioStatus,
    SimulationParameter,
    StochasticVariable,
)
from services.repositories import (
    FinancialInputRepository,
    ProjectRepository,
    ScenarioRepository,
    SimulationParameterRepository,
)
from services.unit_of_work import UnitOfWork


def _session_factory(engine: Engine) -> sessionmaker:
    """Return the session factory shared by the fixtures and tests.

    ``expire_on_commit=False`` keeps loaded attributes readable after a commit.
    """
    return sessionmaker(bind=engine, expire_on_commit=False, future=True)


@pytest.fixture()
def engine() -> Iterator[Engine]:
    """Yield an in-memory SQLite engine with the full schema created.

    On teardown the schema is dropped and the engine's connection pool is
    disposed, so no DBAPI connection outlives the test.
    """
    engine = create_engine("sqlite:///:memory:", future=True)
    Base.metadata.create_all(bind=engine)
    try:
        yield engine
    finally:
        try:
            Base.metadata.drop_all(bind=engine)
        finally:
            # Release pooled connections even if drop_all fails.
            engine.dispose()


@pytest.fixture()
def session(engine: Engine) -> Iterator[Session]:
    """Yield a session bound to ``engine``; it is closed on teardown."""
    TestingSession = _session_factory(engine)
    with TestingSession() as session:
        yield session


def test_project_repository_create_and_list(session: Session) -> None:
    """A project created through the repository is returned by ``list()``."""
    repo = ProjectRepository(session)
    project = Project(
        name="Project Alpha", operation_type=MiningOperationType.OPEN_PIT
    )
    repo.create(project)

    projects = repo.list()
    assert len(projects) == 1
    assert projects[0].name == "Project Alpha"


def test_scenario_repository_get_with_children(session: Session) -> None:
    """``get(..., with_children=True)`` exposes project, inputs and parameters."""
    project = Project(
        name="Project Beta", operation_type=MiningOperationType.UNDERGROUND
    )
    scenario = Scenario(
        name="Scenario 1", project=project, status=ScenarioStatus.ACTIVE
    )
    scenario.financial_inputs.append(
        FinancialInput(
            name="Lease Payment",
            category=FinancialCategory.OPERATING_EXPENDITURE,
            amount=10000,
            currency="usd",
        )
    )
    scenario.simulation_parameters.append(
        SimulationParameter(
            name="Copper Price",
            distribution=DistributionType.NORMAL,
            mean_value=3.5,
            variable=StochasticVariable.METAL_PRICE,
        )
    )
    # Only the project is added explicitly; the test relies on the scenario
    # and its children being flushed along with it (scenario.id is used below).
    session.add(project)
    session.flush()

    repo = ScenarioRepository(session)
    retrieved = repo.get(scenario.id, with_children=True)
    # Fail with a clear assertion instead of an AttributeError on None.
    assert retrieved is not None
    assert retrieved.project.name == "Project Beta"
    assert len(retrieved.financial_inputs) == 1
    # Lower-case input is expected to come back normalized to upper case.
    assert retrieved.financial_inputs[0].currency == "USD"
    assert len(retrieved.simulation_parameters) == 1
    assert (
        retrieved.simulation_parameters[0].variable
        == StochasticVariable.METAL_PRICE
    )

    param_repo = SimulationParameterRepository(session)
    params = param_repo.list_for_scenario(scenario.id)
    assert len(params) == 1


def test_financial_input_repository_bulk_upsert(session: Session) -> None:
    """``bulk_upsert`` stores every input and normalizes the currency code."""
    project = Project(
        name="Project Gamma", operation_type=MiningOperationType.QUARRY
    )
    scenario = Scenario(name="Scenario Bulk", project=project)
    session.add(project)
    session.flush()

    repo = FinancialInputRepository(session)
    created = repo.bulk_upsert(
        [
            FinancialInput(
                scenario_id=scenario.id,
                name="Explosives",
                category=FinancialCategory.OPERATING_EXPENDITURE,
                amount=5000,
                currency="cad",
            ),
            FinancialInput(
                scenario_id=scenario.id,
                name="Equipment Lease",
                category=FinancialCategory.OPERATING_EXPENDITURE,
                amount=12000,
                currency="cad",
            ),
        ]
    )
    assert len(created) == 2

    stored = repo.list_for_scenario(scenario.id)
    assert len(stored) == 2
    assert all(item.currency == "CAD" for item in stored)


def test_unit_of_work_commit_and_rollback(engine: Engine) -> None:
    """A clean exit commits; an exception inside the block rolls back.

    NOTE: separate sessions observe each other's committed data here only
    because they share the same in-memory SQLite engine.
    """
    TestingSession = _session_factory(engine)

    # Commit path: leaving the block normally persists the project.
    with UnitOfWork(session_factory=TestingSession) as uow:
        uow.projects.create(
            Project(name="Project Delta", operation_type=MiningOperationType.PLACER)
        )

    with TestingSession() as session:
        projects = ProjectRepository(session).list()
        assert len(projects) == 1
        assert projects[0].name == "Project Delta"

    # Rollback path: only our deliberate error may escape, and the project
    # created inside the failed unit of work must not be persisted.
    with pytest.raises(RuntimeError, match="trigger rollback"):
        with UnitOfWork(session_factory=TestingSession) as uow:
            uow.projects.create(
                Project(
                    name="Project Epsilon",
                    operation_type=MiningOperationType.OTHER,
                )
            )
            raise RuntimeError("trigger rollback")

    with TestingSession() as session:
        projects = ProjectRepository(session).list()
        assert len(projects) == 1
        assert projects[0].name == "Project Delta"