diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_parameter.py b/tests/unit/test_parameter.py new file mode 100644 index 0000000..862344d --- /dev/null +++ b/tests/unit/test_parameter.py @@ -0,0 +1,46 @@ +from models.scenario import Scenario +from main import app +from config.database import Base, engine +from fastapi.testclient import TestClient +import pytest +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + +# Setup and teardown + + +def setup_module(module): + Base.metadata.create_all(bind=engine) + + +def teardown_module(module): + Base.metadata.drop_all(bind=engine) + + +client = TestClient(app) + +# Helper to create a scenario + + +def create_test_scenario(): + resp = client.post("/api/scenarios/", + json={"name": "ParamTest", "description": "Desc"}) + assert resp.status_code == 200 + return resp.json()["id"] + + +def test_create_and_list_parameter(): + # Ensure scenario exists + scen_id = create_test_scenario() + # Create a parameter + resp = client.post( + "/api/parameters/", json={"scenario_id": scen_id, "name": "param1", "value": 3.14}) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "param1" + # List parameters + resp2 = client.get("/api/parameters/") + assert resp2.status_code == 200 + data2 = resp2.json() + assert any(p["name"] == "param1" for p in data2) diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py new file mode 100644 index 0000000..bb858a1 --- /dev/null +++ b/tests/unit/test_reporting.py @@ -0,0 +1,27 @@ +from services.reporting import generate_report +from routes.reporting import router +from fastapi.testclient import TestClient +from main import app +import pytest + +# Function test +def test_generate_report_empty(): + report = generate_report([]) + assert isinstance(report, dict) + +# Endpoint test +def test_reporting_endpoint_invalid_input(): + client = TestClient(app) + resp = client.post("/api/reporting/summary", json={}) + assert resp.status_code == 400 + assert resp.json()["detail"] == "Invalid input format" + + +def test_reporting_endpoint_success(): + client = TestClient(app) + # Minimal input: list of dicts + input_data = [{"iteration": 1, "result": 10.0}] + resp = client.post("/api/reporting/summary", json=input_data) + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, dict) diff --git a/tests/unit/test_scenario.py b/tests/unit/test_scenario.py new file mode 100644 index 0000000..b4ed93f --- /dev/null +++ b/tests/unit/test_scenario.py @@ -0,0 +1,36 @@ +# ensure project root is on sys.path for module imports +from main import app +from routes.scenarios import router +from config.database import Base, engine +from fastapi.testclient import TestClient +import pytest +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + +# Create tables for testing + + +def setup_module(module): + Base.metadata.create_all(bind=engine) + + +def teardown_module(module): + Base.metadata.drop_all(bind=engine) + + +client = TestClient(app) + + +def test_create_and_list_scenario(): + # Create a scenario + response = client.post( + "/api/scenarios/", json={"name": "Test", "description": "Desc"}) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Test" + # List scenarios + response2 = client.get("/api/scenarios/") + assert response2.status_code == 200 + data2 = response2.json() + assert any(s["name"] == "Test" for s in data2) diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py new file mode 100644 index 0000000..2899e20 --- /dev/null +++ b/tests/unit/test_simulation.py @@ -0,0 +1,42 @@ +from services.simulation import run_simulation +from main import app +from config.database import Base, engine +from fastapi.testclient import TestClient +import pytest + +# Setup and teardown + + +def setup_module(module): + Base.metadata.create_all(bind=engine) + + +def teardown_module(module): + Base.metadata.drop_all(bind=engine) + + +client = TestClient(app) + +# Direct function test + + +def test_run_simulation_function_returns_list(): + results = run_simulation([], iterations=10) + assert isinstance(results, list) + assert results == [] + +# Endpoint tests + + +def test_simulation_endpoint_no_params(): + resp = client.post("/api/simulations/run", json=[]) + assert resp.status_code == 400 + assert resp.json()["detail"] == "No parameters provided" + + +def test_simulation_endpoint_success(): + params = [{"name": "param1", "value": 2.5}] + resp = client.post("/api/simulations/run", json=params) + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, list)