Files
calminer/tests/unit/test_simulation.py
zwitschi 97b1c0360b
Some checks failed
Run Tests / e2e tests (push) Failing after 1m27s
Run Tests / lint tests (push) Failing after 6s
Run Tests / unit tests (push) Failing after 7s
Refactor test cases for improved readability and consistency
- Updated test functions in various test files to enhance code clarity by formatting long lines and improving indentation.
- Adjusted assertions to use multi-line formatting for better readability.
- Added new test cases for theme settings API to ensure proper functionality.
- Ensured consistent use of line breaks and spacing across test files for uniformity.
2025-10-27 10:32:55 +01:00

233 lines
6.9 KiB
Python

from math import isclose
from random import Random
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from typing import Any, Dict, List
from models.simulation_result import SimulationResult
from services.simulation import DEFAULT_UNIFORM_SPAN_RATIO, run_simulation
@pytest.fixture
def client(api_client: TestClient) -> TestClient:
    """Expose the ``api_client`` fixture under the shorter name ``client``.

    ``api_client`` is not defined in this module; presumably it comes from a
    shared conftest -- confirm there if its setup matters.
    """
    return api_client
def test_run_simulation_function_generates_samples():
    """Run a seeded five-iteration simulation over two parameters.

    One parameter is normal (with an explicit ``std_dev``), the other uniform
    (with explicit ``min``/``max``). The test checks that a list with one
    entry per iteration comes back and that the first entry is numbered 1.
    """
    grade: Dict[str, Any] = {
        "name": "grade",
        "value": 1.8,
        "distribution": "normal",
        "std_dev": 0.2,
    }
    recovery: Dict[str, Any] = {
        "name": "recovery",
        "value": 0.9,
        "distribution": "uniform",
        "min": 0.8,
        "max": 0.95,
    }
    parameters: List[Dict[str, Any]] = [grade, recovery]

    samples = run_simulation(parameters, iterations=5, seed=123)

    assert isinstance(samples, list)
    assert len(samples) == 5
    first_entry = samples[0]
    assert first_entry["iteration"] == 1
def test_run_simulation_with_zero_iterations_returns_empty():
    """Requesting zero iterations yields an empty result list."""
    single_normal: Dict[str, Any] = {
        "name": "grade",
        "value": 1.2,
        "distribution": "normal",
    }
    parameters: List[Dict[str, Any]] = [single_normal]

    outcome = run_simulation(parameters, iterations=0)

    assert outcome == []
@pytest.mark.parametrize(
    ("parameter_payload", "error_message"),
    [
        # No "value" key at all.
        (
            {"name": "missing-value"},
            "Parameter at index 0 must include 'value'",
        ),
        # Distribution name that the simulation does not recognise.
        (
            {
                "name": "bad-dist",
                "value": 1.0,
                "distribution": "unsupported",
            },
            "Parameter 'bad-dist' has unsupported distribution 'unsupported'",
        ),
        # Uniform bounds where min equals max.
        (
            {
                "name": "uniform-range",
                "value": 1.0,
                "distribution": "uniform",
                "min": 5,
                "max": 5,
            },
            "Parameter 'uniform-range' requires 'min' < 'max' for uniform distribution",
        ),
        # Triangular mode lying above the max bound.
        (
            {
                "name": "triangular-mode",
                "value": 5.0,
                "distribution": "triangular",
                "min": 1,
                "max": 3,
                "mode": 5,
            },
            "Parameter 'triangular-mode' mode must be within min/max bounds for triangular distribution",
        ),
    ],
)
def test_run_simulation_parameter_validation_errors(
    parameter_payload: Dict[str, Any], error_message: str
) -> None:
    """Each malformed parameter makes ``run_simulation`` raise ``ValueError``.

    The exception text must equal ``error_message`` exactly, not merely
    contain it.
    """
    with pytest.raises(ValueError) as raised:
        run_simulation([parameter_payload])

    # Checked after the context manager exits; inside it the line would be
    # skipped once the call above raises.
    assert str(raised.value) == error_message
def test_run_simulation_normal_std_dev_fallback():
    """A normal parameter with ``std_dev`` of 0 still yields usable results.

    Only the shape is checked: three entries come back for three iterations
    and every entry carries a ``"result"`` key.
    """
    zero_spread: Dict[str, Any] = {
        "name": "std-dev-fallback",
        "value": 10.0,
        "distribution": "normal",
        "std_dev": 0,
    }
    parameters: List[Dict[str, Any]] = [zero_spread]

    outcome = run_simulation(parameters, iterations=3, seed=99)

    assert len(outcome) == 3
    for row in outcome:
        assert "result" in row
def test_run_simulation_triangular_sampling_path():
    """Triangular sampling without explicit bounds matches a reference RNG.

    The test expects the samples to equal what ``Random(seed).triangular``
    produces with bounds ``value -/+ value * DEFAULT_UNIFORM_SPAN_RATIO`` and
    the parameter value as the mode.
    """
    base_value = 10.0
    run_seed = 21
    run_count = 4
    parameters: List[Dict[str, Any]] = [
        {"name": "tri", "value": base_value, "distribution": "triangular"}
    ]

    outcome = run_simulation(parameters, iterations=run_count, seed=run_seed)
    assert len(outcome) == run_count

    # Rebuild the expected draws from an independent generator with the same seed.
    half_width = base_value * DEFAULT_UNIFORM_SPAN_RATIO
    lower = base_value - half_width
    upper = base_value + half_width
    reference_rng = Random(run_seed)
    wanted = [
        reference_rng.triangular(lower, upper, base_value)
        for _ in range(run_count)
    ]

    observed = [row["result"] for row in outcome]
    for got, want in zip(observed, wanted):
        assert isclose(got, want, rel_tol=1e-9)
def test_run_simulation_uniform_defaults_apply_bounds():
    """Uniform sampling without ``min``/``max`` matches a reference RNG.

    The test expects the samples to equal what ``Random(seed).uniform``
    produces with bounds ``value -/+ value * DEFAULT_UNIFORM_SPAN_RATIO``.
    """
    base_value = 200.0
    run_seed = 17
    run_count = 3
    parameters: List[Dict[str, Any]] = [
        {"name": "uniform-auto", "value": base_value, "distribution": "uniform"}
    ]

    outcome = run_simulation(parameters, iterations=run_count, seed=run_seed)
    assert len(outcome) == run_count

    # Rebuild the expected draws from an independent generator with the same seed.
    half_width = base_value * DEFAULT_UNIFORM_SPAN_RATIO
    lower = base_value - half_width
    upper = base_value + half_width
    reference_rng = Random(run_seed)
    wanted = [reference_rng.uniform(lower, upper) for _ in range(run_count)]

    observed = [row["result"] for row in outcome]
    for got, want in zip(observed, wanted):
        assert isclose(got, want, rel_tol=1e-9)
def test_run_simulation_without_parameters_returns_empty():
    """An empty parameter list yields no results even with five iterations requested."""
    outcome = run_simulation([], iterations=5)
    assert outcome == []
def test_simulation_endpoint_no_params(client: TestClient):
    """A run request with an empty ``parameters`` list is rejected with HTTP 400.

    A fresh scenario is created first so the request refers to a real
    ``scenario_id``; the response detail must read "No parameters provided".
    """
    created = client.post(
        "/api/scenarios/",
        json={
            "name": f"NoParamScenario-{uuid4()}",
            "description": "No parameters run",
        },
    )
    assert created.status_code == 200

    run_body: Dict[str, Any] = {
        "scenario_id": created.json()["id"],
        "parameters": [],
        "iterations": 10,
    }
    rejected = client.post("/api/simulations/run", json=run_body)

    assert rejected.status_code == 400
    assert rejected.json()["detail"] == "No parameters provided"
def test_simulation_endpoint_success(client: TestClient, db_session: Session):
    """A valid run request returns ten results and stores ten rows.

    The response must echo the scenario id, contain ten result entries and
    report a summary count of ten; afterwards ten ``SimulationResult`` rows
    must exist for that scenario in the database session.
    """
    created = client.post(
        "/api/scenarios/",
        json={
            "name": f"SimScenario-{uuid4()}",
            "description": "Simulation test",
        },
    )
    assert created.status_code == 200
    scenario_id = created.json()["id"]

    run_request: Dict[str, Any] = {
        "scenario_id": scenario_id,
        "parameters": [
            {
                "name": "param1",
                "value": 2.5,
                "distribution": "normal",
                "std_dev": 0.5,
            }
        ],
        "iterations": 10,
        "seed": 42,
    }
    run_resp = client.post("/api/simulations/run", json=run_request)
    assert run_resp.status_code == 200

    body = run_resp.json()
    assert body["scenario_id"] == scenario_id
    assert len(body["results"]) == 10
    assert body["summary"]["count"] == 10

    # Expire cached ORM state before querying -- presumably so the rows the
    # endpoint wrote are re-read from the database rather than from stale objects.
    db_session.expire_all()
    stored_query = db_session.query(SimulationResult).filter(
        SimulationResult.scenario_id == scenario_id
    )
    stored_rows = stored_query.all()
    assert len(stored_rows) == 10
def test_simulation_endpoint_uses_stored_parameters(client: TestClient):
    """A run request that omits ``parameters`` still succeeds.

    A scenario and one parameter for it are created through the API first;
    the run request then carries only the scenario id, iteration count and
    seed. Presumably the endpoint falls back to the stored parameter -- the
    test itself only checks the 200 status and that three results come back.
    """
    created = client.post(
        "/api/scenarios/",
        json={
            "name": f"StoredParams-{uuid4()}",
            "description": "Stored parameter simulation",
        },
    )
    assert created.status_code == 200
    scenario_id = created.json()["id"]

    stored_parameter: Dict[str, Any] = {
        "scenario_id": scenario_id,
        "name": "grade",
        "value": 1.5,
    }
    stored = client.post("/api/parameters/", json=stored_parameter)
    assert stored.status_code == 200

    run_request: Dict[str, Any] = {
        "scenario_id": scenario_id,
        "iterations": 3,
        "seed": 7,
    }
    run_resp = client.post("/api/simulations/run", json=run_request)
    assert run_resp.status_code == 200

    body = run_resp.json()
    assert body["summary"]["count"] == 3
    assert len(body["results"]) == 3