Files
calminer/tests/unit/test_simulation.py
zwitschi 139ae04538 Enhance UI rendering and add unit tests for simulation functionality
- Updated the `_render` function in `ui.py` to correctly pass the request object to `TemplateResponse` (see the sketch below).
- Initialized `upcoming_maintenance` as a typed list in `_load_dashboard` for better type safety.
- Added new unit tests in `test_simulation.py` to cover triangular sampling and uniform distribution defaults.
- Implemented a test to ensure that running the simulation without parameters returns an empty result.
- Created a parameterized test in `test_ui_routes.py` to verify that additional UI routes render the correct templates and context.
2025-10-21 09:26:39 +02:00
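
A minimal sketch of the `_render` and `_load_dashboard` changes described above, assuming the app wraps Starlette's `Jinja2Templates` via FastAPI; the helper bodies, template directory, and return shape are illustrative, not the exact code in `ui.py`:

```python
# Illustrative sketch only -- the real helpers live in ui.py and may differ.
from typing import Any, Dict, List

from fastapi import Request
from fastapi.templating import Jinja2Templates

templates = Jinja2Templates(directory="templates")


def _render(request: Request, template_name: str, context: Dict[str, Any]):
    # Newer Starlette releases expect the request as the first positional
    # argument rather than as a "request" key inside the context dict.
    return templates.TemplateResponse(request, template_name, context)


def _load_dashboard() -> Dict[str, Any]:
    # Typed initialisation so the element type is known even while the list
    # is still empty.
    upcoming_maintenance: List[Dict[str, Any]] = []
    return {"upcoming_maintenance": upcoming_maintenance}
```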

221 lines
6.8 KiB
Python

from math import isclose
from random import Random
from typing import Any, Dict, List
from uuid import uuid4

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from models.simulation_result import SimulationResult
from services.simulation import DEFAULT_UNIFORM_SPAN_RATIO, run_simulation
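
# The api_client fixture is assumed to be provided by the suite's conftest.py;
# this alias keeps the endpoint tests below concise.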
@pytest.fixture
def client(api_client: TestClient) -> TestClient:
return api_client
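
# Basic smoke test: one result record per iteration, with 1-based iteration
# numbering in the returned payload.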
def test_run_simulation_function_generates_samples():
params: List[Dict[str, Any]] = [
{"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2},
{
"name": "recovery",
"value": 0.9,
"distribution": "uniform",
"min": 0.8,
"max": 0.95,
},
]
results = run_simulation(params, iterations=5, seed=123)
assert isinstance(results, list)
assert len(results) == 5
assert results[0]["iteration"] == 1
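
# iterations=0 is treated as "nothing to do" rather than an error.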
def test_run_simulation_with_zero_iterations_returns_empty():
params: List[Dict[str, Any]] = [
{"name": "grade", "value": 1.2, "distribution": "normal"}
]
results = run_simulation(params, iterations=0)
assert results == []
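
# Each payload below violates one validation rule; the expected messages are
# the exact ValueError texts raised by run_simulation.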
@pytest.mark.parametrize(
"parameter_payload,error_message",
[
({"name": "missing-value"}, "Parameter at index 0 must include 'value'"),
(
{
"name": "bad-dist",
"value": 1.0,
"distribution": "unsupported",
},
"Parameter 'bad-dist' has unsupported distribution 'unsupported'",
),
(
{
"name": "uniform-range",
"value": 1.0,
"distribution": "uniform",
"min": 5,
"max": 5,
},
"Parameter 'uniform-range' requires 'min' < 'max' for uniform distribution",
),
(
{
"name": "triangular-mode",
"value": 5.0,
"distribution": "triangular",
"min": 1,
"max": 3,
"mode": 5,
},
"Parameter 'triangular-mode' mode must be within min/max bounds for triangular distribution",
),
],
)
def test_run_simulation_parameter_validation_errors(
parameter_payload: Dict[str, Any], error_message: str
) -> None:
with pytest.raises(ValueError) as exc:
run_simulation([parameter_payload])
assert str(exc.value) == error_message
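
# A std_dev of 0 is assumed to fall back to the service's default spread
# rather than degenerate sampling; the test only checks that results are produced.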
def test_run_simulation_normal_std_dev_fallback():
params: List[Dict[str, Any]] = [
{
"name": "std-dev-fallback",
"value": 10.0,
"distribution": "normal",
"std_dev": 0,
}
]
results = run_simulation(params, iterations=3, seed=99)
assert len(results) == 3
assert all("result" in entry for entry in results)
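
# Reproduces the expected triangular samples by re-seeding random.Random with
# the same seed the service receives; this assumes run_simulation draws one
# triangular sample per iteration over value +/- value * DEFAULT_UNIFORM_SPAN_RATIO.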
def test_run_simulation_triangular_sampling_path():
params: List[Dict[str, Any]] = [
{"name": "tri", "value": 10.0, "distribution": "triangular"}
]
seed = 21
iterations = 4
results = run_simulation(params, iterations=iterations, seed=seed)
assert len(results) == iterations
span = 10.0 * DEFAULT_UNIFORM_SPAN_RATIO
rng = Random(seed)
expected_samples = [
rng.triangular(10.0 - span, 10.0 + span, 10.0) for _ in range(iterations)
]
actual_samples = [entry["result"] for entry in results]
for actual, expected in zip(actual_samples, expected_samples):
assert isclose(actual, expected, rel_tol=1e-9)
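
# Same re-seeding approach as above, covering the uniform branch when no
# explicit min/max bounds are supplied.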
def test_run_simulation_uniform_defaults_apply_bounds():
params: List[Dict[str, Any]] = [
{"name": "uniform-auto", "value": 200.0, "distribution": "uniform"}
]
seed = 17
iterations = 3
results = run_simulation(params, iterations=iterations, seed=seed)
assert len(results) == iterations
span = 200.0 * DEFAULT_UNIFORM_SPAN_RATIO
rng = Random(seed)
expected_samples = [
rng.uniform(200.0 - span, 200.0 + span) for _ in range(iterations)
]
actual_samples = [entry["result"] for entry in results]
for actual, expected in zip(actual_samples, expected_samples):
assert isclose(actual, expected, rel_tol=1e-9)
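
# An empty parameter list short-circuits to an empty result, regardless of the
# requested iteration count.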
def test_run_simulation_without_parameters_returns_empty():
assert run_simulation([], iterations=5) == []
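
# Endpoint-level counterpart: with no stored or inline parameters the API is
# expected to reject the run with a 400 rather than return an empty result.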
def test_simulation_endpoint_no_params(client: TestClient):
scenario_payload: Dict[str, Any] = {
"name": f"NoParamScenario-{uuid4()}",
"description": "No parameters run",
}
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
assert scenario_resp.status_code == 200
scenario_id = scenario_resp.json()["id"]
resp = client.post(
"/api/simulations/run",
json={"scenario_id": scenario_id, "parameters": [], "iterations": 10},
)
assert resp.status_code == 400
assert resp.json()["detail"] == "No parameters provided"
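
# Full round trip: run a simulation through the API, then verify the results
# were persisted by querying SimulationResult through the test session.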
def test_simulation_endpoint_success(
client: TestClient, db_session: Session
):
scenario_payload: Dict[str, Any] = {
"name": f"SimScenario-{uuid4()}",
"description": "Simulation test",
}
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
assert scenario_resp.status_code == 200
scenario_id = scenario_resp.json()["id"]
params: List[Dict[str, Any]] = [
{"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5}
]
payload: Dict[str, Any] = {
"scenario_id": scenario_id,
"parameters": params,
"iterations": 10,
"seed": 42,
}
resp = client.post("/api/simulations/run", json=payload)
assert resp.status_code == 200
data = resp.json()
assert data["scenario_id"] == scenario_id
assert len(data["results"]) == 10
assert data["summary"]["count"] == 10
db_session.expire_all()
persisted = (
db_session.query(SimulationResult)
.filter(SimulationResult.scenario_id == scenario_id)
.all()
)
assert len(persisted) == 10
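
# When the request omits inline parameters, the endpoint is expected to fall
# back to the parameters stored for the scenario via /api/parameters/.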
def test_simulation_endpoint_uses_stored_parameters(client: TestClient):
scenario_payload: Dict[str, Any] = {
"name": f"StoredParams-{uuid4()}",
"description": "Stored parameter simulation",
}
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
assert scenario_resp.status_code == 200
scenario_id = scenario_resp.json()["id"]
parameter_payload: Dict[str, Any] = {
"scenario_id": scenario_id,
"name": "grade",
"value": 1.5,
}
param_resp = client.post("/api/parameters/", json=parameter_payload)
assert param_resp.status_code == 200
resp = client.post(
"/api/simulations/run",
json={"scenario_id": scenario_id, "iterations": 3, "seed": 7},
)
assert resp.status_code == 200
data = resp.json()
assert data["summary"]["count"] == 3
assert len(data["results"]) == 3