"""Tests for ``services.simulation.run_simulation`` and the ``/api/simulations/run`` endpoint."""

from typing import Any, Dict, List
from uuid import uuid4

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from models.simulation_result import SimulationResult
from services.simulation import run_simulation


@pytest.fixture
def client(api_client: TestClient) -> TestClient:
    """Expose the shared ``api_client`` fixture under the shorter name ``client``."""
    # NOTE(review): ``api_client`` is not defined in this module; presumably it
    # comes from conftest.py.
    return api_client


def _create_scenario(client: TestClient, name_prefix: str, description: str) -> Any:
    """Create a scenario via ``POST /api/scenarios/`` and return its ``id``.

    The name gets a UUID suffix so each test run creates a distinct scenario.

    Args:
        client: Test client used to issue the request.
        name_prefix: Human-readable prefix for the scenario name.
        description: Scenario description sent in the payload.

    Returns:
        The ``id`` field of the created scenario's JSON response.
    """
    scenario_payload: Dict[str, Any] = {
        "name": f"{name_prefix}-{uuid4()}",
        "description": description,
    }
    scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
    assert scenario_resp.status_code == 200, scenario_resp.text
    return scenario_resp.json()["id"]


def test_run_simulation_function_generates_samples() -> None:
    """run_simulation returns a list with one entry per iteration, first numbered 1."""
    params: List[Dict[str, Any]] = [
        {"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2},
        {
            "name": "recovery",
            "value": 0.9,
            "distribution": "uniform",
            "min": 0.8,
            "max": 0.95,
        },
    ]

    results = run_simulation(params, iterations=5, seed=123)

    assert isinstance(results, list)
    assert len(results) == 5
    assert results[0]["iteration"] == 1


def test_run_simulation_with_zero_iterations_returns_empty() -> None:
    """Requesting zero iterations yields an empty list rather than an error."""
    params: List[Dict[str, Any]] = [
        {"name": "grade", "value": 1.2, "distribution": "normal"}
    ]

    results = run_simulation(params, iterations=0)

    assert results == []


@pytest.mark.parametrize(
    "parameter_payload,error_message",
    [
        ({"name": "missing-value"}, "Parameter at index 0 must include 'value'"),
        (
            {
                "name": "bad-dist",
                "value": 1.0,
                "distribution": "unsupported",
            },
            "Parameter 'bad-dist' has unsupported distribution 'unsupported'",
        ),
        (
            {
                "name": "uniform-range",
                "value": 1.0,
                "distribution": "uniform",
                "min": 5,
                "max": 5,
            },
            "Parameter 'uniform-range' requires 'min' < 'max' for uniform distribution",
        ),
        (
            {
                "name": "triangular-mode",
                "value": 5.0,
                "distribution": "triangular",
                "min": 1,
                "max": 3,
                "mode": 5,
            },
            "Parameter 'triangular-mode' mode must be within min/max bounds for triangular distribution",
        ),
    ],
)
def test_run_simulation_parameter_validation_errors(
    parameter_payload: Dict[str, Any], error_message: str
) -> None:
    """Each invalid parameter spec makes run_simulation raise ValueError with an exact message."""
    with pytest.raises(ValueError) as exc:
        run_simulation([parameter_payload])

    # Checked after the ``with`` block: code after the raising call inside the
    # block would never run.
    assert str(exc.value) == error_message


def test_run_simulation_normal_std_dev_fallback() -> None:
    """A normal parameter with ``std_dev`` of 0 still yields a result per iteration."""
    # NOTE(review): the test name implies run_simulation substitutes a fallback
    # spread when std_dev is 0; this test only pins that the run succeeds.
    params: List[Dict[str, Any]] = [
        {
            "name": "std-dev-fallback",
            "value": 10.0,
            "distribution": "normal",
            "std_dev": 0,
        }
    ]

    results = run_simulation(params, iterations=3, seed=99)

    assert len(results) == 3
    assert all("result" in entry for entry in results)


def test_simulation_endpoint_no_params(client: TestClient) -> None:
    """An explicit empty ``parameters`` list is rejected with HTTP 400."""
    scenario_id = _create_scenario(client, "NoParamScenario", "No parameters run")

    resp = client.post(
        "/api/simulations/run",
        json={"scenario_id": scenario_id, "parameters": [], "iterations": 10},
    )

    assert resp.status_code == 400, resp.text
    assert resp.json()["detail"] == "No parameters provided"


def test_simulation_endpoint_success(client: TestClient, db_session: Session) -> None:
    """A valid run returns per-iteration results and a summary, and persists each result."""
    scenario_id = _create_scenario(client, "SimScenario", "Simulation test")

    params: List[Dict[str, Any]] = [
        {"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5}
    ]
    payload: Dict[str, Any] = {
        "scenario_id": scenario_id,
        "parameters": params,
        "iterations": 10,
        "seed": 42,
    }

    resp = client.post("/api/simulations/run", json=payload)

    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["scenario_id"] == scenario_id
    assert len(data["results"]) == 10
    assert data["summary"]["count"] == 10

    # Drop any cached ORM state so the query below reflects what the request wrote.
    db_session.expire_all()
    persisted = (
        db_session.query(SimulationResult)
        .filter(SimulationResult.scenario_id == scenario_id)
        .all()
    )
    assert len(persisted) == 10


def test_simulation_endpoint_uses_stored_parameters(client: TestClient) -> None:
    """Omitting ``parameters`` runs the simulation from the scenario's stored parameters."""
    scenario_id = _create_scenario(
        client, "StoredParams", "Stored parameter simulation"
    )

    parameter_payload: Dict[str, Any] = {
        "scenario_id": scenario_id,
        "name": "grade",
        "value": 1.5,
    }
    param_resp = client.post("/api/parameters/", json=parameter_payload)
    assert param_resp.status_code == 200, param_resp.text

    # No "parameters" key: the endpoint must fall back to what was stored above.
    resp = client.post(
        "/api/simulations/run",
        json={"scenario_id": scenario_id, "iterations": 3, "seed": 7},
    )

    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["summary"]["count"] == 3
    assert len(data["results"]) == 3