"""Tests for ``run_simulation`` and the ``/api/simulations/run`` endpoint."""

from uuid import uuid4

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from models.simulation_result import SimulationResult
from services.simulation import run_simulation


@pytest.fixture
def client(api_client: TestClient) -> TestClient:
    """Expose the ``api_client`` fixture under the shorter name ``client``."""
    return api_client


def _create_scenario(client: TestClient, name: str, description: str):
    """POST a scenario, assert the response is 200, and return its ``id``."""
    created = client.post(
        "/api/scenarios/",
        json={"name": name, "description": description},
    )
    assert created.status_code == 200
    return created.json()["id"]


def test_run_simulation_function_generates_samples():
    """Call ``run_simulation`` directly and check the shape of its result.

    Requests 5 iterations with a fixed seed and expects a list of 5 entries
    whose first entry carries ``iteration == 1``.
    """
    normal_param = {
        "name": "grade",
        "value": 1.8,
        "distribution": "normal",
        "std_dev": 0.2,
    }
    uniform_param = {
        "name": "recovery",
        "value": 0.9,
        "distribution": "uniform",
        "min": 0.8,
        "max": 0.95,
    }

    samples = run_simulation(
        [normal_param, uniform_param], iterations=5, seed=123
    )

    assert isinstance(samples, list)
    assert len(samples) == 5
    assert samples[0]["iteration"] == 1


def test_simulation_endpoint_no_params(client: TestClient):
    """An explicit empty ``parameters`` list is rejected with HTTP 400."""
    scenario_id = _create_scenario(
        client, f"NoParamScenario-{uuid4()}", "No parameters run"
    )

    request_body = {
        "scenario_id": scenario_id,
        "parameters": [],
        "iterations": 10,
    }
    response = client.post("/api/simulations/run", json=request_body)

    assert response.status_code == 400
    assert response.json()["detail"] == "No parameters provided"


def test_simulation_endpoint_success(
    client: TestClient, db_session: Session
):
    """A run with one inline parameter returns and persists 10 results."""
    scenario_id = _create_scenario(
        client, f"SimScenario-{uuid4()}", "Simulation test"
    )

    request_body = {
        "scenario_id": scenario_id,
        "parameters": [
            {
                "name": "param1",
                "value": 2.5,
                "distribution": "normal",
                "std_dev": 0.5,
            }
        ],
        "iterations": 10,
        "seed": 42,
    }
    response = client.post("/api/simulations/run", json=request_body)
    assert response.status_code == 200

    body = response.json()
    assert body["scenario_id"] == scenario_id
    assert len(body["results"]) == 10
    assert body["summary"]["count"] == 10

    # Expire the session's loaded ORM state before querying so the rows are
    # read fresh rather than served from previously loaded instances.
    db_session.expire_all()
    stored_rows_query = db_session.query(SimulationResult).filter(
        SimulationResult.scenario_id == scenario_id
    )
    stored_rows = stored_rows_query.all()
    assert len(stored_rows) == 10


def test_simulation_endpoint_uses_stored_parameters(client: TestClient):
    """A run request that omits ``parameters`` succeeds for a scenario that
    already has a parameter registered through ``/api/parameters/``.

    As the test name indicates, the endpoint is expected to use that stored
    parameter; the assertions check only the status code and result counts.
    """
    scenario_id = _create_scenario(
        client, f"StoredParams-{uuid4()}", "Stored parameter simulation"
    )

    stored_parameter = {
        "scenario_id": scenario_id,
        "name": "grade",
        "value": 1.5,
    }
    parameter_response = client.post(
        "/api/parameters/", json=stored_parameter
    )
    assert parameter_response.status_code == 200

    # No "parameters" key in this request body.
    run_request = {"scenario_id": scenario_id, "iterations": 3, "seed": 7}
    response = client.post("/api/simulations/run", json=run_request)
    assert response.status_code == 200

    body = response.json()
    assert body["summary"]["count"] == 3
    assert len(body["results"]) == 3