Files
calminer/tests/unit/conftest.py
zwitschi 75f533b87b
All checks were successful
Run Tests / test (push) Successful in 1m51s
fix: Update HTTP status code for unprocessable entity and improve test database setup
2025-10-25 19:26:43 +02:00

251 lines
7.3 KiB
Python

from datetime import date
from typing import Any, Dict, Generator
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from config.database import Base
from main import app
from models.capex import Capex
from models.consumption import Consumption
from models.equipment import Equipment
from models.maintenance import Maintenance
from models.opex import Opex
from models.parameters import Parameter
from models.production_output import ProductionOutput
from models.scenario import Scenario
from models.simulation_result import SimulationResult
# In-memory SQLite database used by the whole test process.
SQLALCHEMY_TEST_URL = "sqlite:///:memory:"
# StaticPool keeps a single connection for every checkout, so all sessions see
# the same in-memory database (a new connection would get an empty one).
# check_same_thread=False lets sqlite3 use that connection from threads other
# than its creator — presumably needed for the TestClient's worker thread;
# confirm.
engine = create_engine(
    SQLALCHEMY_TEST_URL,
    poolclass=StaticPool,
    connect_args=dict(check_same_thread=False),
)
# Session factory for tests: explicit commits only, and no implicit flush
# before queries.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
@pytest.fixture(scope="session", autouse=True)
def setup_database() -> Generator[None, None, None]:
    """Create all tables once for the test session and drop them afterwards.

    Every ``models`` submodule is imported first so that all model metadata
    is registered on ``Base.metadata`` before ``create_all`` runs.
    """
    from models import (  # noqa: F401 - imported for side effects
        application_setting,
        capex,
        consumption,
        distribution,
        equipment,
        maintenance,
        opex,
        parameters,
        production_output,
        scenario,
        simulation_result,
    )

    # Reference the side-effect-only imports so they are not reported unused.
    _ = (
        application_setting,
        capex,
        consumption,
        distribution,
        equipment,
        maintenance,
        opex,
        parameters,
        production_output,
        scenario,
        simulation_result,
    )

    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield
    metadata.drop_all(bind=engine)
@pytest.fixture()
def db_session() -> Generator[Session, None, None]:
    """Yield a new ``Session`` on a freshly rebuilt schema.

    All tables are dropped and recreated before each test, so rows from one
    test are never visible to the next. On teardown, uncommitted work is
    rolled back and the session is closed.
    """
    schema = Base.metadata
    schema.drop_all(bind=engine)
    schema.create_all(bind=engine)

    test_session = TestingSessionLocal()
    try:
        yield test_session
    finally:
        test_session.rollback()
        test_session.close()
@pytest.fixture()
def api_client(db_session: Session) -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` whose ``get_db`` dependency returns ``db_session``.

    Routes called through the client share the same session as the test and
    the other fixtures. The dependency override is removed on teardown even if
    entering or exiting the client raises, so it cannot leak into later tests.
    """
    from routes import dependencies as route_dependencies

    def override_get_db() -> Generator[Session, None, None]:
        # db_session owns the session's lifecycle (rollback/close), so there
        # is nothing to clean up here.
        yield db_session

    app.dependency_overrides[route_dependencies.get_db] = override_get_db
    try:
        with TestClient(app) as client:
            yield client
    finally:
        app.dependency_overrides.pop(route_dependencies.get_db, None)
@pytest.fixture()
def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]:
    """Populate a scenario with representative related records for UI tests.

    Creates and commits one uniquely named ``Scenario``. Linked to it via
    ``scenario_id`` are one ``Parameter``, ``Capex``, ``Opex``,
    ``Consumption``, ``ProductionOutput``, ``Equipment`` and ``Maintenance``
    row, plus three ``SimulationResult`` rows.

    Yields:
        Dict with the ``"scenario"`` and ``"equipment"`` ORM objects and the
        ``"simulation_results"`` list.

    On teardown, any failed or pending transaction left on ``db_session`` is
    rolled back. The seeded rows are then deleted, dependents first, and the
    deletion is committed.
    """
    scenario = Scenario(
        name=f"Scenario Alpha {uuid4()}",
        description="Seeded UI scenario",
    )
    db_session.add(scenario)
    # Flush so scenario.id is assigned before building the dependent rows.
    db_session.flush()
    # Plain int, so teardown never needs to refresh an expired ORM instance.
    scenario_id = scenario.id

    parameter = Parameter(
        scenario_id=scenario_id,
        name="Ore Grade",
        value=1.5,
        distribution_type="normal",
        distribution_parameters={"mean": 1.5, "std_dev": 0.1},
    )
    capex = Capex(
        scenario_id=scenario_id,
        amount=1_000_000.0,
        description="Drill purchase",
        currency_code="USD",
    )
    opex = Opex(
        scenario_id=scenario_id,
        amount=250_000.0,
        description="Fuel spend",
        currency_code="USD",
    )
    consumption = Consumption(
        scenario_id=scenario_id,
        amount=1_200.0,
        description="Diesel (L)",
        unit_name="Liters",
        unit_symbol="L",
    )
    production = ProductionOutput(
        scenario_id=scenario_id,
        amount=800.0,
        description="Ore (tonnes)",
        unit_name="Tonnes",
        unit_symbol="t",
    )
    equipment = Equipment(
        scenario_id=scenario_id,
        name="Excavator 42",
        description="Primary loader",
    )
    db_session.add_all(
        [parameter, capex, opex, consumption, production, equipment]
    )
    # Second flush: the maintenance record needs equipment.id.
    db_session.flush()
    equipment_id = equipment.id

    maintenance = Maintenance(
        scenario_id=scenario_id,
        equipment_id=equipment_id,
        maintenance_date=date(2025, 1, 15),
        description="Hydraulic service",
        cost=15_000.0,
    )
    simulation_results = [
        SimulationResult(
            scenario_id=scenario_id,
            iteration=index,
            result=value,
        )
        for index, value in enumerate((950_000.0, 975_000.0, 990_000.0), start=1)
    ]
    db_session.add(maintenance)
    db_session.add_all(simulation_results)
    db_session.commit()

    try:
        yield {
            "scenario": scenario,
            "equipment": equipment,
            "simulation_results": simulation_results,
        }
    finally:
        # The test, or an API route sharing this session, may have left a
        # failed transaction behind. Without this rollback, the cleanup
        # queries below would raise instead of running.
        db_session.rollback()
        # Dependents before parents: maintenance carries equipment_id, and
        # every seeded row carries scenario_id.
        cleanup_order = (
            (SimulationResult, {"scenario_id": scenario_id}),
            (Maintenance, {"scenario_id": scenario_id}),
            (Equipment, {"id": equipment_id}),
            (ProductionOutput, {"scenario_id": scenario_id}),
            (Consumption, {"scenario_id": scenario_id}),
            (Opex, {"scenario_id": scenario_id}),
            (Capex, {"scenario_id": scenario_id}),
            (Parameter, {"scenario_id": scenario_id}),
            (Scenario, {"id": scenario_id}),
        )
        for model, criteria in cleanup_order:
            db_session.query(model).filter_by(**criteria).delete()
        db_session.commit()
@pytest.fixture()
def invalid_request_payloads(db_session: Session) -> Generator[Dict[str, Any], None, None]:
    """Provide reusable invalid request bodies for exercising validation branches.

    A real ``Scenario`` is committed first. Payloads can then target both its
    id and an id offset from it by 99. The offset id is presumed not to exist,
    since ``db_session`` starts from freshly recreated tables.

    Yields:
        Dict mapping a case name to its request body. The
        ``"existing_scenario"`` key holds the committed ``Scenario`` itself.

    On teardown, any failed transaction left on ``db_session`` is rolled back
    (for instance if a rejected duplicate insert was not rolled back by the
    route). The seeded scenario is then deleted and the deletion committed.
    """
    duplicate_name = f"Scenario Duplicate {uuid4()}"
    existing = Scenario(
        name=duplicate_name,
        description="Existing scenario for duplicate checks",
    )
    db_session.add(existing)
    db_session.commit()
    # Read the id once as a plain int, so teardown does not depend on the
    # (possibly expired) ORM instance.
    existing_id = existing.id
    missing_scenario_id = existing_id + 99

    payloads: Dict[str, Any] = {
        "existing_scenario": existing,
        "scenario_duplicate": {
            "name": duplicate_name,
            "description": "Second scenario should fail with duplicate name",
        },
        "parameter_missing_scenario": {
            "scenario_id": missing_scenario_id,
            "name": "Invalid Parameter",
            "value": 1.0,
        },
        "parameter_invalid_distribution": {
            "scenario_id": existing_id,
            "name": "Weird Dist",
            "value": 2.5,
            "distribution_type": "invalid",
        },
        "simulation_unknown_scenario": {
            "scenario_id": missing_scenario_id,
            "iterations": 10,
            "parameters": [
                {"name": "grade", "value": 1.2, "distribution": "normal"}
            ],
        },
        "simulation_missing_parameters": {
            "scenario_id": existing_id,
            "iterations": 5,
            "parameters": [],
        },
        "reporting_non_list_payload": {"result": 10.0},
        "reporting_missing_result": [{"value": 12.0}],
        "maintenance_negative_cost": {
            "equipment_id": 1,
            "scenario_id": existing_id,
            "maintenance_date": "2025-01-15",
            "cost": -500.0,
        },
    }
    try:
        yield payloads
    finally:
        # Clear any failed transaction so the cleanup below can execute.
        db_session.rollback()
        db_session.query(Scenario).filter_by(id=existing_id).delete()
        db_session.commit()