Files
calminer/tests/test_repositories.py

244 lines
7.5 KiB
Python

from __future__ import annotations
from collections.abc import Iterator
from datetime import datetime, timezone
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from config.database import Base
from models import (
DistributionType,
FinancialCategory,
FinancialInput,
MiningOperationType,
Project,
Scenario,
ScenarioStatus,
SimulationParameter,
StochasticVariable,
)
from services.repositories import (
FinancialInputRepository,
ProjectRepository,
ScenarioRepository,
SimulationParameterRepository,
)
from services.unit_of_work import UnitOfWork
@pytest.fixture()
def engine():
    """Yield an engine bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the test
    runs and dropped again afterwards.
    """
    engine = create_engine("sqlite:///:memory:", future=True)
    Base.metadata.create_all(bind=engine)
    try:
        yield engine
    finally:
        try:
            Base.metadata.drop_all(bind=engine)
        finally:
            # Close the pooled connection(s) explicitly; otherwise they (and
            # the in-memory database they hold) are only released whenever
            # the engine happens to be garbage-collected.
            engine.dispose()
@pytest.fixture()
def session(engine) -> Iterator[Session]:
    """Provide an ORM session bound to the per-test ``engine`` fixture.

    ``expire_on_commit=False`` keeps loaded attribute values readable after a
    commit. The session is closed once the test finishes.
    """
    factory = sessionmaker(bind=engine, expire_on_commit=False, future=True)
    with factory() as db_session:
        yield db_session
def test_project_repository_create_and_list(session: Session) -> None:
    """A project stored via ``ProjectRepository.create`` is returned by ``list``."""
    repository = ProjectRepository(session)
    repository.create(
        Project(name="Project Alpha", operation_type=MiningOperationType.OPEN_PIT)
    )

    listed = repository.list()

    assert len(listed) == 1
    assert listed[0].name == "Project Alpha"
def test_scenario_repository_get_with_children(session: Session) -> None:
    """``ScenarioRepository.get(..., with_children=True)`` returns the scenario
    together with its project, financial inputs and simulation parameters.

    The session's identity map is cleared before the lookup. That forces the
    repository to load the scenario graph from the database instead of
    returning the Python objects built in this test.
    """
    project = Project(name="Project Beta", operation_type=MiningOperationType.UNDERGROUND)
    scenario = Scenario(name="Scenario 1", project=project, status=ScenarioStatus.ACTIVE)
    scenario.financial_inputs.append(
        FinancialInput(
            name="Lease Payment",
            category=FinancialCategory.OPERATING_EXPENDITURE,
            amount=10000,
            currency="usd",
        )
    )
    scenario.simulation_parameters.append(
        SimulationParameter(
            name="Copper Price",
            distribution=DistributionType.NORMAL,
            mean_value=3.5,
            variable=StochasticVariable.METAL_PRICE,
        )
    )
    session.add(project)
    session.flush()

    # Capture the primary key before detaching the in-memory objects.
    scenario_id = scenario.id
    # Without this, ``get`` would just hand back the already-populated objects
    # from the identity map and the child-loading path would go untested.
    session.expunge_all()

    repo = ScenarioRepository(session)
    retrieved = repo.get(scenario_id, with_children=True)

    assert retrieved is not None
    assert retrieved.project.name == "Project Beta"
    assert len(retrieved.financial_inputs) == 1
    # The input was created with "usd"; it must come back upper-cased.
    assert retrieved.financial_inputs[0].currency == "USD"
    assert len(retrieved.simulation_parameters) == 1
    assert (
        retrieved.simulation_parameters[0].variable
        == StochasticVariable.METAL_PRICE
    )

    param_repo = SimulationParameterRepository(session)
    params = param_repo.list_for_scenario(scenario_id)
    assert len(params) == 1
def test_financial_input_repository_bulk_upsert(session: Session) -> None:
    """``bulk_upsert`` stores every given input, and lower-case currency codes
    are read back upper-cased."""
    project = Project(name="Project Gamma", operation_type=MiningOperationType.QUARRY)
    scenario = Scenario(name="Scenario Bulk", project=project)
    session.add(project)
    session.flush()

    def build_input(label: str, amount: int) -> FinancialInput:
        # Both inputs share scenario, category and (lower-case) currency.
        return FinancialInput(
            scenario_id=scenario.id,
            name=label,
            category=FinancialCategory.OPERATING_EXPENDITURE,
            amount=amount,
            currency="cad",
        )

    repo = FinancialInputRepository(session)
    created = repo.bulk_upsert(
        [build_input("Explosives", 5000), build_input("Equipment Lease", 12000)]
    )
    assert len(created) == 2

    stored = repo.list_for_scenario(scenario.id)
    assert len(stored) == 2
    assert all(row.currency == "CAD" for row in stored)
def test_unit_of_work_commit_and_rollback(engine) -> None:
    """Leaving a ``UnitOfWork`` normally keeps its changes; leaving it via an
    exception discards them."""
    session_factory = sessionmaker(bind=engine, expire_on_commit=False, future=True)

    def stored_project_count() -> int:
        # Count projects through a fresh session, outside any unit of work.
        with session_factory() as check_session:
            return len(ProjectRepository(check_session).list())

    # Commit path: the block exits cleanly, so the project must persist.
    with UnitOfWork(session_factory=session_factory) as uow:
        uow.projects.create(
            Project(name="Project Delta", operation_type=MiningOperationType.PLACER)
        )
    assert stored_project_count() == 1

    # Rollback path: the block raises, so the second project must not persist.
    with pytest.raises(RuntimeError):
        with UnitOfWork(session_factory=session_factory) as uow:
            uow.projects.create(
                Project(name="Project Epsilon", operation_type=MiningOperationType.OTHER)
            )
            raise RuntimeError("trigger rollback")
    assert stored_project_count() == 1
def test_project_repository_count_and_recent(session: Session) -> None:
    """``count`` sees both projects; ``recent`` puts the most recently updated first."""
    repo = ProjectRepository(session)
    alpha = Project(name="Alpha", operation_type=MiningOperationType.OPEN_PIT)
    bravo = Project(name="Bravo", operation_type=MiningOperationType.UNDERGROUND)
    for project in (alpha, bravo):
        repo.create(project)

    # Pin explicit timestamps so the expected ordering is deterministic.
    alpha.updated_at = datetime(2025, 1, 1, tzinfo=timezone.utc)
    bravo.updated_at = datetime(2025, 1, 2, tzinfo=timezone.utc)
    session.flush()

    assert repo.count() == 2

    newest = repo.recent(limit=1)
    assert len(newest) == 1
    assert newest[0].name == "Bravo"
def test_scenario_repository_counts_and_filters(session: Session) -> None:
    """Exercise ``ScenarioRepository`` counting, ``recent`` and ``list_by_status``."""
    project = Project(name="Project", operation_type=MiningOperationType.PLACER)
    session.add(project)
    session.flush()

    repo = ScenarioRepository(session)
    draft = Scenario(name="Draft", project_id=project.id, status=ScenarioStatus.DRAFT)
    active = Scenario(name="Active", project_id=project.id, status=ScenarioStatus.ACTIVE)
    for scenario in (draft, active):
        repo.create(scenario)

    # Explicit timestamps make "Active" the most recently updated scenario.
    stamped = [
        (draft, datetime(2025, 1, 1, tzinfo=timezone.utc)),
        (active, datetime(2025, 1, 3, tzinfo=timezone.utc)),
    ]
    for scenario, stamp in stamped:
        scenario.updated_at = stamp
    session.flush()

    assert repo.count() == 2
    assert repo.count_by_status(ScenarioStatus.ACTIVE) == 1

    latest = repo.recent(limit=1, with_project=True)
    assert len(latest) == 1
    top = latest[0]
    assert top.name == "Active"
    assert top.project.name == "Project"

    drafts = repo.list_by_status(ScenarioStatus.DRAFT, limit=2, with_project=True)
    assert len(drafts) == 1
    only_draft = drafts[0]
    assert only_draft.name == "Draft"
    assert only_draft.project_id == project.id
def test_financial_input_repository_latest_created_at(session: Session) -> None:
    """``latest_created_at`` returns the most recent ``created_at`` of the stored inputs."""
    project = Project(name="Project FI", operation_type=MiningOperationType.OTHER)
    scenario = Scenario(name="Scenario FI", project=project)
    session.add(project)
    session.flush()

    older = datetime(2024, 12, 31, 15, 0)
    newer = datetime(2025, 1, 2, 8, 30)

    def entry(label: str, amount: int, stamp: datetime) -> FinancialInput:
        # created_at and updated_at are both pinned to the same explicit value.
        return FinancialInput(
            scenario_id=scenario.id,
            name=label,
            category=FinancialCategory.OPERATING_EXPENDITURE,
            amount=amount,
            currency="usd",
            created_at=stamp,
            updated_at=stamp,
        )

    repo = FinancialInputRepository(session)
    repo.bulk_upsert(
        [entry("Legacy Entry", 1000, older), entry("New Entry", 2000, newer)]
    )

    assert repo.latest_created_at() == newer