diff --git a/README.md b/README.md index 1218f85..4a8b747 100644 --- a/README.md +++ b/README.md @@ -71,20 +71,21 @@ uvicorn main:app --reload - **API base URL**: `http://localhost:8000/api` - **Key routes**: - `POST /api/scenarios/` create scenarios - - `POST /api/parameters/` manage process parameters + - `POST /api/parameters/` manage process parameters; payload supports optional `distribution_id` or inline `distribution_type`/`distribution_parameters` fields for simulation metadata - `POST /api/costs/capex` and `POST /api/costs/opex` capture project costs - `POST /api/consumption/` add consumption entries - `POST /api/production/` register production output - `POST /api/equipment/` create equipment records - `POST /api/maintenance/` log maintenance events - - `POST /api/reporting/summary` aggregate simulation results + - `POST /api/reporting/summary` aggregate simulation results, returning count, mean/median, min/max, standard deviation, variance, percentile bands (5/10/90/95), value-at-risk (95%) and expected shortfall (95%) ### Dashboard Preview -1. Start the FastAPI server and navigate to `/dashboard` (served by `templates/Dashboard.html`). +1. Start the FastAPI server and navigate to `/ui/dashboard` (ensure `routes/ui.py` exposes this template or add a router that serves `templates/Dashboard.html`). 2. Use the "Load Sample Data" button to populate the JSON textarea with demo results. -3. Select "Refresh Dashboard" to view calculated statistics and a distribution chart sourced from `/api/reporting/summary`. -4. Paste your own simulation outputs (array of objects containing a numeric `result` property) to visualize custom runs. +3. Select "Refresh Dashboard" to post the dataset to `/api/reporting/summary` and render the returned statistics and distribution chart. +4. Paste your own simulation outputs (array of objects containing a numeric `result` property) to visualize custom runs; the endpoint expects the same schema used by the reporting service. +5. If the summary endpoint is unavailable, the dashboard displays an inline error—refresh once the API is reachable. ## Testing @@ -96,6 +97,13 @@ To execute the unit test suite: pytest ``` +### Coverage Snapshot (2025-10-20) + +- `pytest --cov=. --cov-report=term-missing` reports **95%** overall coverage across the project. +- Lower coverage hotspots to target next: `services/simulation.py` (79%), `middleware/validation.py` (78%), `routes/ui.py` (82%), and several API routers around lines 12-22 that create database sessions only. +- Deprecation cleanup migrated routes to Pydantic v2 patterns (`model_config = ConfigDict(...)`, `model_dump()`) and updated SQLAlchemy's `declarative_base`; reran `pytest` to confirm the suite passes without warnings. +- Coverage for route-heavy modules is primarily limited by error paths (e.g., bad request branches) that still need explicit tests. + ## Database Objects The database is composed of several tables that store different types of information. diff --git a/config/database.py b/config/database.py index 9a3328e..6421095 100644 --- a/config/database.py +++ b/config/database.py @@ -1,6 +1,5 @@ from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import declarative_base, sessionmaker import os from dotenv import load_dotenv diff --git a/docs/architecture.md b/docs/architecture.md index 4928935..2a5d739 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -4,6 +4,10 @@ CalMiner is a FastAPI application that collects mining project inputs, persists scenario-specific records, and surfaces aggregated insights. The platform targets Monte Carlo driven planning, with deterministic CRUD features in place and simulation logic staged for future work. +Frontend components are server-rendered Jinja2 templates, with Chart.js powering the dashboard visualization. + +The backend leverages SQLAlchemy for ORM mapping to a PostgreSQL database. + ## System Components - **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns. @@ -18,14 +22,54 @@ CalMiner is a FastAPI application that collects mining project inputs, persists 1. Users navigate to form templates or API clients to manage scenarios, parameters, and operational data. 2. FastAPI routers validate payloads with Pydantic models, then delegate to SQLAlchemy sessions for persistence. 3. Simulation runs (placeholder `services/simulation.py`) will consume stored parameters to emit iteration results via `/api/simulations/run`. -4. Reporting requests POST simulation outputs to `/api/reporting/summary`; the reporting service calculates aggregates (count, min/max, mean, median, percentiles, standard deviation). +4. Reporting requests POST simulation outputs to `/api/reporting/summary`; the reporting service calculates aggregates (count, min/max, mean, median, percentiles, standard deviation, variance, and tail-risk metrics at the 95% confidence level). 5. `templates/Dashboard.html` fetches summaries, renders metric cards, and plots distribution charts with Chart.js for stakeholder review. +### Dashboard Flow Review — 2025-10-20 + +- The dashboard template depends on a future-facing HTML endpoint (e.g., `/dashboard`) that the current `routes/ui.py` router does not expose; wiring an explicit route is required before the page is reachable from the FastAPI app. +- Client-side logic calls `/api/reporting/summary` with raw simulation outputs and expects `result` fields, so any upstream changes to the reporting contract must maintain this schema. +- Initialization always loads the bundled sample data first, which is useful for demos but masks server errors—consider adding a failure banner when `/api/reporting/summary` is unavailable. +- No persistent storage backs the dashboard yet; users must paste or load JSON manually, aligning with the current MVP scope but highlighting an integration gap with the simulation results table. + +### Reporting Pipeline and UI Integration + +1. **Data Sources** + + - Scenario-linked calculations (costs, consumption, production) produce raw figures stored in dedicated tables (`capex`, `opex`, `consumption`, `production_output`). + - Monte Carlo simulations (currently transient) generate arrays of `{ "result": float }` tuples that the dashboard or downstream tooling passes directly to reporting endpoints. + +2. **API Contract** + + - `POST /api/reporting/summary` accepts a JSON array of result objects and validates shape through `_validate_payload` in `routes/reporting.py`. + - On success it returns a structured payload (`ReportSummary`) containing count, mean, median, min/max, standard deviation, and percentile values, all as floats. + +3. **Service Layer** + + - `services/reporting.generate_report` converts the sanitized payload into descriptive statistics using Python’s standard library (`statistics` module) to avoid external dependencies. + - The service remains stateless; no database read/write occurs, which keeps summary calculations deterministic and idempotent. + - Extended KPIs (surfaced in the API and dashboard): + - `variance`: population variance computed as the square of the population standard deviation. + - `percentile_5` and `percentile_95`: lower and upper tail interpolated percentiles for sensitivity bounds. + - `value_at_risk_95`: 5th percentile threshold representing the minimum outcome within a 95% confidence band. + - `expected_shortfall_95`: mean of all outcomes at or below the `value_at_risk_95`, highlighting tail exposure. + +4. **UI Consumption** + + - `templates/Dashboard.html` posts the user-provided dataset to the summary endpoint, renders metric cards for each field, and charts the distribution using Chart.js. + - `SUMMARY_FIELDS` now includes variance, 5th/10th/90th/95th percentiles, and tail-risk metrics (VaR/Expected Shortfall at 95%); tooltip annotations surface the tail metrics alongside the percentile line chart. + - Error handling surfaces HTTP failures inline so users can address malformed JSON or backend availability issues without leaving the page. + +5. **Future Integration Points** + - Once `/api/simulations/run` persists to `simulation_result`, the dashboard can fetch precalculated runs per scenario, removing the manual JSON step. + - Additional reporting endpoints (e.g., scenario comparisons) can reuse the same service layer, ensuring consistency across UI and API consumers. + ## Data Model Highlights - `scenario`: central entity describing a mining scenario; owns relationships to cost, consumption, production, equipment, and maintenance tables. - `capex`, `opex`: monetary tracking linked to scenarios. - `consumption`: resource usage entries parameterized by scenario and description. +- `parameter`: scenario inputs with base `value` and optional distribution linkage via `distribution_id`, `distribution_type`, and JSON `distribution_parameters` to support simulation sampling. - `production_output`: production metrics per scenario. - `equipment` and `maintenance`: equipment inventory and maintenance events with dates/costs. - `simulation_result`: staging table for future Monte Carlo outputs (not yet populated by `run_simulation`). diff --git a/models/parameters.py b/models/parameters.py index 3de252a..5182a74 100644 --- a/models/parameters.py +++ b/models/parameters.py @@ -1,17 +1,26 @@ -from sqlalchemy import Column, Integer, String, Float, ForeignKey -from sqlalchemy.orm import relationship +from typing import Any, Dict, Optional + +from sqlalchemy import ForeignKey, JSON +from sqlalchemy.orm import Mapped, mapped_column, relationship from config.database import Base class Parameter(Base): __tablename__ = "parameter" - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - name = Column(String, nullable=False) - value = Column(Float, nullable=False) + id: Mapped[int] = mapped_column(primary_key=True, index=True) + scenario_id: Mapped[int] = mapped_column( + ForeignKey("scenario.id"), nullable=False) + name: Mapped[str] = mapped_column(nullable=False) + value: Mapped[float] = mapped_column(nullable=False) + distribution_id: Mapped[Optional[int]] = mapped_column( + ForeignKey("distribution.id"), nullable=True) + distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True) + distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column( + JSON, nullable=True) scenario = relationship("Scenario", back_populates="parameters") + distribution = relationship("Distribution") def __repr__(self): return f"" diff --git a/routes/consumption.py b/routes/consumption.py index 824e6e4..f9e3db3 100644 --- a/routes/consumption.py +++ b/routes/consumption.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, status -from pydantic import BaseModel, PositiveFloat +from pydantic import BaseModel, ConfigDict, PositiveFloat from sqlalchemy.orm import Session from config.database import SessionLocal @@ -31,14 +31,12 @@ class ConsumptionCreate(ConsumptionBase): class ConsumptionRead(ConsumptionBase): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) @router.post("/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED) def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)): - db_item = Consumption(**item.dict()) + db_item = Consumption(**item.model_dump()) db.add(db_item) db.commit() db.refresh(db_item) diff --git a/routes/costs.py b/routes/costs.py index 5d85ca9..0c8a226 100644 --- a/routes/costs.py +++ b/routes/costs.py @@ -1,7 +1,7 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from config.database import SessionLocal from models.capex import Capex from models.opex import Opex @@ -26,9 +26,7 @@ class CapexCreate(BaseModel): class CapexRead(CapexCreate): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) # Pydantic schemas for Opex @@ -40,15 +38,13 @@ class OpexCreate(BaseModel): class OpexRead(OpexCreate): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) # Capex endpoints @router.post("/capex", response_model=CapexRead) def create_capex(item: CapexCreate, db: Session = Depends(get_db)): - db_item = Capex(**item.dict()) + db_item = Capex(**item.model_dump()) db.add(db_item) db.commit() db.refresh(db_item) @@ -63,7 +59,7 @@ def list_capex(db: Session = Depends(get_db)): # Opex endpoints @router.post("/opex", response_model=OpexRead) def create_opex(item: OpexCreate, db: Session = Depends(get_db)): - db_item = Opex(**item.dict()) + db_item = Opex(**item.model_dump()) db.add(db_item) db.commit() db.refresh(db_item) diff --git a/routes/distributions.py b/routes/distributions.py index 2a6acb5..0f60d3e 100644 --- a/routes/distributions.py +++ b/routes/distributions.py @@ -1,7 +1,7 @@ -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from typing import List -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from config.database import SessionLocal from models.distribution import Distribution @@ -24,14 +24,12 @@ class DistributionCreate(BaseModel): class DistributionRead(DistributionCreate): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) @router.post("/", response_model=DistributionRead) async def create_distribution(dist: DistributionCreate, db: Session = Depends(get_db)): - db_dist = Distribution(**dist.dict()) + db_dist = Distribution(**dist.model_dump()) db.add(db_dist) db.commit() db.refresh(db_dist) diff --git a/routes/equipment.py b/routes/equipment.py index 99a2566..b9879e1 100644 --- a/routes/equipment.py +++ b/routes/equipment.py @@ -1,7 +1,7 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from config.database import SessionLocal from models.equipment import Equipment @@ -25,14 +25,12 @@ class EquipmentCreate(BaseModel): class EquipmentRead(EquipmentCreate): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) @router.post("/", response_model=EquipmentRead) async def create_equipment(item: EquipmentCreate, db: Session = Depends(get_db)): - db_item = Equipment(**item.dict()) + db_item = Equipment(**item.model_dump()) db.add(db_item) db.commit() db.refresh(db_item) diff --git a/routes/maintenance.py b/routes/maintenance.py index c69480c..7ed2400 100644 --- a/routes/maintenance.py +++ b/routes/maintenance.py @@ -2,7 +2,7 @@ from datetime import date from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, PositiveFloat +from pydantic import BaseModel, ConfigDict, PositiveFloat from sqlalchemy.orm import Session from config.database import SessionLocal @@ -38,9 +38,7 @@ class MaintenanceUpdate(MaintenanceBase): class MaintenanceRead(MaintenanceBase): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: @@ -56,7 +54,7 @@ def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: @router.post("/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED) def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get_db)): - db_maintenance = Maintenance(**maintenance.dict()) + db_maintenance = Maintenance(**maintenance.model_dump()) db.add(db_maintenance) db.commit() db.refresh(db_maintenance) @@ -80,7 +78,7 @@ def update_maintenance( db: Session = Depends(get_db), ): db_maintenance = _get_maintenance_or_404(db, maintenance_id) - for field, value in payload.dict().items(): + for field, value in payload.model_dump().items(): setattr(db_maintenance, field, value) db.commit() db.refresh(db_maintenance) diff --git a/routes/parameters.py b/routes/parameters.py index 9909bfb..fb6b7e9 100644 --- a/routes/parameters.py +++ b/routes/parameters.py @@ -1,10 +1,13 @@ +from typing import Any, Dict, List, Optional + from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, ConfigDict, field_validator from sqlalchemy.orm import Session + from config.database import SessionLocal +from models.distribution import Distribution from models.parameters import Parameter from models.scenario import Scenario -from pydantic import BaseModel -from typing import Optional, List router = APIRouter(prefix="/api/parameters", tags=["parameters"]) @@ -13,13 +16,34 @@ class ParameterCreate(BaseModel): scenario_id: int name: str value: float + distribution_id: Optional[int] = None + distribution_type: Optional[str] = None + distribution_parameters: Optional[Dict[str, Any]] = None + + @field_validator("distribution_type") + @classmethod + def normalize_type(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return value + normalized = value.strip().lower() + if not normalized: + return None + if normalized not in {"normal", "uniform", "triangular"}: + raise ValueError( + "distribution_type must be normal, uniform, or triangular") + return normalized + + @field_validator("distribution_parameters") + @classmethod + def empty_dict_to_none(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if value is None: + return None + return value or None class ParameterRead(ParameterCreate): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) # Dependency @@ -37,8 +61,27 @@ def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() if not scen: raise HTTPException(status_code=404, detail="Scenario not found") - new_param = Parameter(scenario_id=param.scenario_id, - name=param.name, value=param.value) + distribution_id = param.distribution_id + distribution_type = param.distribution_type + distribution_parameters = param.distribution_parameters + + if distribution_id is not None: + distribution = db.query(Distribution).filter( + Distribution.id == distribution_id).first() + if not distribution: + raise HTTPException( + status_code=404, detail="Distribution not found") + distribution_type = distribution.distribution_type + distribution_parameters = distribution.parameters or None + + new_param = Parameter( + scenario_id=param.scenario_id, + name=param.name, + value=param.value, + distribution_id=distribution_id, + distribution_type=distribution_type, + distribution_parameters=distribution_parameters, + ) db.add(new_param) db.commit() db.refresh(new_param) diff --git a/routes/production.py b/routes/production.py index 8df04d5..c0684b6 100644 --- a/routes/production.py +++ b/routes/production.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, status -from pydantic import BaseModel, PositiveFloat +from pydantic import BaseModel, ConfigDict, PositiveFloat from sqlalchemy.orm import Session from config.database import SessionLocal @@ -31,14 +31,12 @@ class ProductionOutputCreate(ProductionOutputBase): class ProductionOutputRead(ProductionOutputBase): id: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) @router.post("/", response_model=ProductionOutputRead, status_code=status.HTTP_201_CREATED) def create_production(item: ProductionOutputCreate, db: Session = Depends(get_db)): - db_item = ProductionOutput(**item.dict()) + db_item = ProductionOutput(**item.model_dump()) db.add(db_item) db.commit() db.refresh(db_item) diff --git a/routes/reporting.py b/routes/reporting.py index cbb2244..3714dcb 100644 --- a/routes/reporting.py +++ b/routes/reporting.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, cast from fastapi import APIRouter, HTTPException, Request, status from pydantic import BaseModel @@ -25,14 +25,16 @@ def _validate_payload(payload: Any) -> List[Dict[str, float]]: detail="Invalid input format", ) + typed_payload = cast(List[Any], payload) + validated: List[Dict[str, float]] = [] - for index, item in enumerate(payload): + for index, item in enumerate(typed_payload): if not isinstance(item, dict): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Entry at index {index} must be an object", ) - value = item.get("result") + value = cast(Dict[str, Any], item).get("result") if not isinstance(value, (int, float)): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -49,8 +51,13 @@ class ReportSummary(BaseModel): min: float max: float std_dev: float + variance: float percentile_10: float percentile_90: float + percentile_5: float + percentile_95: float + value_at_risk_95: float + expected_shortfall_95: float @router.post("/summary", response_model=ReportSummary) @@ -65,6 +72,11 @@ async def summary_report(request: Request): min=float(summary["min"]), max=float(summary["max"]), std_dev=float(summary["std_dev"]), + variance=float(summary["variance"]), percentile_10=float(summary["percentile_10"]), percentile_90=float(summary["percentile_90"]), + percentile_5=float(summary["percentile_5"]), + percentile_95=float(summary["percentile_95"]), + value_at_risk_95=float(summary["value_at_risk_95"]), + expected_shortfall_95=float(summary["expected_shortfall_95"]), ) diff --git a/routes/scenarios.py b/routes/scenarios.py index a1a82e8..052d122 100644 --- a/routes/scenarios.py +++ b/routes/scenarios.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from config.database import SessionLocal from models.scenario import Scenario -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing import Optional from datetime import datetime @@ -20,9 +20,7 @@ class ScenarioRead(ScenarioCreate): id: int created_at: datetime updated_at: Optional[datetime] = None - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) # Dependency diff --git a/routes/simulations.py b/routes/simulations.py index 162533f..b8c89b9 100644 --- a/routes/simulations.py +++ b/routes/simulations.py @@ -60,8 +60,8 @@ def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterI ) return [ SimulationParameterInput( - name=cast(str, item.name), - value=cast(float, item.value), + name=item.name, + value=item.value, ) for item in db_params ] @@ -86,7 +86,7 @@ async def simulate(payload: SimulationRunRequest, db: Session = Depends(get_db)) ) raw_results = run_simulation( - [param.dict(exclude_none=True) for param in parameters], + [param.model_dump(exclude_none=True) for param in parameters], iterations=payload.iterations, seed=payload.seed, ) diff --git a/services/reporting.py b/services/reporting.py index 6d7c21f..2950414 100644 --- a/services/reporting.py +++ b/services/reporting.py @@ -1,13 +1,14 @@ from statistics import mean, median, pstdev -from typing import Dict, Iterable, List, Union +from typing import Any, Dict, Iterable, List, Mapping, Union, cast -def _extract_results(simulation_results: Iterable[Dict[str, float]]) -> List[float]: +def _extract_results(simulation_results: Iterable[object]) -> List[float]: values: List[float] = [] for item in simulation_results: - if not isinstance(item, dict): + if not isinstance(item, Mapping): continue - value = item.get("result") + mapping_item = cast(Mapping[str, Any], item) + value = mapping_item.get("result") if isinstance(value, (int, float)): values.append(float(value)) return values @@ -39,8 +40,13 @@ def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Uni "min": 0.0, "max": 0.0, "std_dev": 0.0, + "variance": 0.0, "percentile_10": 0.0, "percentile_90": 0.0, + "percentile_5": 0.0, + "percentile_95": 0.0, + "value_at_risk_95": 0.0, + "expected_shortfall_95": 0.0, } summary: Dict[str, Union[float, int]] = { @@ -51,7 +57,21 @@ def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Uni "max": max(values), "percentile_10": _percentile(values, 10), "percentile_90": _percentile(values, 90), + "percentile_5": _percentile(values, 5), + "percentile_95": _percentile(values, 95), } - summary["std_dev"] = pstdev(values) if len(values) > 1 else 0.0 + std_dev = pstdev(values) if len(values) > 1 else 0.0 + summary["std_dev"] = std_dev + summary["variance"] = std_dev ** 2 + + var_95 = summary["percentile_5"] + summary["value_at_risk_95"] = var_95 + + tail_values = [value for value in values if value <= var_95] + if tail_values: + summary["expected_shortfall_95"] = mean(tail_values) + else: + summary["expected_shortfall_95"] = var_95 + return summary diff --git a/templates/Dashboard.html b/templates/Dashboard.html index 9c166d7..30f4eab 100644 --- a/templates/Dashboard.html +++ b/templates/Dashboard.html @@ -87,8 +87,16 @@ { key: "min", label: "Min" }, { key: "max", label: "Max" }, { key: "std_dev", label: "Std Dev" }, + { key: "variance", label: "Variance" }, + { key: "percentile_5", label: "5th Percentile" }, { key: "percentile_10", label: "10th Percentile" }, { key: "percentile_90", label: "90th Percentile" }, + { key: "percentile_95", label: "95th Percentile" }, + { key: "value_at_risk_95", label: "VaR (95%)" }, + { + key: "expected_shortfall_95", + label: "Expected Shortfall (95%)", + }, ]; async function fetchSummary(results) { @@ -123,12 +131,16 @@ const grid = document.getElementById("summary-grid"); grid.innerHTML = ""; SUMMARY_FIELDS.forEach(({ key, label }) => { - const value = summary[key] ?? 0; + const rawValue = summary[key]; + const numericValue = Number(rawValue); + const display = Number.isFinite(numericValue) + ? numericValue.toFixed(2) + : "—"; const metric = document.createElement("div"); metric.className = "metric"; metric.innerHTML = `
${label}
-
${value.toFixed(2)}
+
${display}
`; grid.appendChild(metric); }); @@ -138,14 +150,34 @@ function renderChart(summary) { const ctx = document.getElementById("summary-chart").getContext("2d"); - const dataPoints = [ - summary.min, - summary.percentile_10, - summary.median, - summary.mean, - summary.percentile_90, - summary.max, - ].map((value) => Number(value ?? 0)); + const percentilePoints = [ + { label: "Min", value: summary.min }, + { label: "P5", value: summary.percentile_5 }, + { label: "P10", value: summary.percentile_10 }, + { label: "Median", value: summary.median }, + { label: "Mean", value: summary.mean }, + { label: "P90", value: summary.percentile_90 }, + { label: "P95", value: summary.percentile_95 }, + { label: "Max", value: summary.max }, + ]; + + const labels = percentilePoints.map((point) => point.label); + const dataPoints = percentilePoints.map((point) => + Number(point.value ?? 0) + ); + + const tailRiskLines = [ + { label: "VaR (95%)", value: summary.value_at_risk_95 }, + { label: "ES (95%)", value: summary.expected_shortfall_95 }, + ] + .map(({ label, value }) => { + const numeric = Number(value); + if (!Number.isFinite(numeric)) { + return null; + } + return `${label}: ${numeric.toFixed(2)}`; + }) + .filter((line) => line !== null); if (chartInstance) { chartInstance.destroy(); @@ -154,7 +186,7 @@ chartInstance = new Chart(ctx, { type: "line", data: { - labels: ["Min", "P10", "Median", "Mean", "P90", "Max"], + labels, datasets: [ { label: "Result Summary", @@ -169,6 +201,11 @@ options: { plugins: { legend: { display: false }, + tooltip: { + callbacks: { + afterBody: () => tailRiskLines, + }, + }, }, scales: { y: { diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..5945e48 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,94 @@ +from typing import Generator + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from config.database import Base +from main import app + +SQLALCHEMY_TEST_URL = "sqlite:///:memory:" +engine = create_engine( + SQLALCHEMY_TEST_URL, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +TestingSessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture(scope="session", autouse=True) +def setup_database() -> Generator[None, None, None]: + # Ensure all model metadata is registered before creating tables + from models import ( + capex, + consumption, + distribution, + equipment, + maintenance, + opex, + parameters, + production_output, + scenario, + simulation_result, + ) # noqa: F401 - imported for side effects + + Base.metadata.create_all(bind=engine) + yield + Base.metadata.drop_all(bind=engine) + + +@pytest.fixture() +def db_session() -> Generator[Session, None, None]: + session = TestingSessionLocal() + try: + yield session + finally: + session.close() + + +@pytest.fixture() +def api_client(db_session: Session) -> Generator[TestClient, None, None]: + def override_get_db(): + try: + yield db_session + finally: + pass + + # override all routers that use get_db + from routes import ( + consumption, + costs, + distributions, + equipment, + maintenance, + parameters, + production, + reporting, + scenarios, + simulations, + ) + + overrides = { + consumption.get_db: override_get_db, + costs.get_db: override_get_db, + distributions.get_db: override_get_db, + equipment.get_db: override_get_db, + maintenance.get_db: override_get_db, + parameters.get_db: override_get_db, + production.get_db: override_get_db, + reporting.get_db: override_get_db, + scenarios.get_db: override_get_db, + simulations.get_db: override_get_db, + } + + for dependency, override in overrides.items(): + app.dependency_overrides[dependency] = override + + with TestClient(app) as client: + yield client + + for dependency in overrides: + app.dependency_overrides.pop(dependency, None) diff --git a/tests/unit/test_consumption.py b/tests/unit/test_consumption.py index b909f63..2edd5ef 100644 --- a/tests/unit/test_consumption.py +++ b/tests/unit/test_consumption.py @@ -1,42 +1,69 @@ +from uuid import uuid4 + +import pytest from fastapi.testclient import TestClient -from main import app -from config.database import Base, engine - -# Setup and teardown -def setup_module(module): - Base.metadata.create_all(bind=engine) +@pytest.fixture +def client(api_client: TestClient) -> TestClient: + return api_client -def teardown_module(module): - Base.metadata.drop_all(bind=engine) +def _create_scenario(client: TestClient) -> int: + payload = { + "name": f"Consumption Scenario {uuid4()}", + "description": "Scenario for consumption tests", + } + response = client.post("/api/scenarios/", json=payload) + assert response.status_code == 200 + return response.json()["id"] -client = TestClient(app) +def test_create_consumption(client: TestClient) -> None: + scenario_id = _create_scenario(client) + payload = { + "scenario_id": scenario_id, + "amount": 125.5, + "description": "Fuel usage baseline", + } + + response = client.post("/api/consumption/", json=payload) + assert response.status_code == 201 + body = response.json() + assert body["id"] > 0 + assert body["scenario_id"] == scenario_id + assert body["amount"] == pytest.approx(125.5) + assert body["description"] == "Fuel usage baseline" -def test_create_and_list_consumption(): - # Create a scenario to attach consumption - resp = client.post( - "/api/scenarios/", json={"name": "ConsScenario", "description": "consumption scenario"} - ) - assert resp.status_code == 200 - scenario = resp.json() - sid = scenario["id"] +def test_list_consumption_returns_created_items(client: TestClient) -> None: + scenario_id = _create_scenario(client) + values = [50.0, 80.75] + for amount in values: + response = client.post( + "/api/consumption/", + json={ + "scenario_id": scenario_id, + "amount": amount, + "description": f"Consumption {amount}", + }, + ) + assert response.status_code == 201 - # Create Consumption item - cons_payload = {"scenario_id": sid, "amount": 250.0, - "description": "Monthly consumption"} - resp2 = client.post("/api/consumption/", json=cons_payload) - assert resp2.status_code == 201 - cons = resp2.json() - assert cons["scenario_id"] == sid - assert cons["amount"] == 250.0 + list_response = client.get("/api/consumption/") + assert list_response.status_code == 200 + items = [item for item in list_response.json( + ) if item["scenario_id"] == scenario_id] + assert {item["amount"] for item in items} == set(values) - # List Consumption items - resp3 = client.get("/api/consumption/") - assert resp3.status_code == 200 - data = resp3.json() - assert any(item["amount"] == 250.0 and item["scenario_id"] - == sid for item in data) + +def test_create_consumption_rejects_negative_amount(client: TestClient) -> None: + scenario_id = _create_scenario(client) + payload = { + "scenario_id": scenario_id, + "amount": -10, + "description": "Invalid negative amount", + } + + response = client.post("/api/consumption/", json=payload) + assert response.status_code == 422 diff --git a/tests/unit/test_costs.py b/tests/unit/test_costs.py index c7c3576..c024f5d 100644 --- a/tests/unit/test_costs.py +++ b/tests/unit/test_costs.py @@ -1,8 +1,9 @@ -from fastapi.testclient import TestClient -from main import app -from config.database import Base, engine +from uuid import uuid4 -# Setup and teardown +from fastapi.testclient import TestClient + +from config.database import Base, engine +from main import app def setup_module(module): @@ -16,43 +17,89 @@ def teardown_module(module): client = TestClient(app) -def test_create_and_list_capex_and_opex(): - # Create a scenario to attach costs - resp = client.post( - "/api/scenarios/", json={"name": "CostScenario", "description": "cost scenario"} - ) - assert resp.status_code == 200 - scenario = resp.json() - sid = scenario["id"] +def _create_scenario() -> int: + payload = { + "name": f"CostScenario-{uuid4()}", + "description": "Cost tracking test scenario", + } + response = client.post("/api/scenarios/", json=payload) + assert response.status_code == 200 + return response.json()["id"] - # Create Capex item - capex_payload = {"scenario_id": sid, - "amount": 1000.0, "description": "Initial capex"} + +def test_create_and_list_capex_and_opex(): + sid = _create_scenario() + + capex_payload = { + "scenario_id": sid, + "amount": 1000.0, + "description": "Initial capex", + } resp2 = client.post("/api/costs/capex", json=capex_payload) assert resp2.status_code == 200 capex = resp2.json() assert capex["scenario_id"] == sid assert capex["amount"] == 1000.0 - # List Capex items resp3 = client.get("/api/costs/capex") assert resp3.status_code == 200 data = resp3.json() assert any(item["amount"] == 1000.0 and item["scenario_id"] == sid for item in data) - # Create Opex item - opex_payload = {"scenario_id": sid, "amount": 500.0, - "description": "Recurring opex"} + opex_payload = { + "scenario_id": sid, + "amount": 500.0, + "description": "Recurring opex", + } resp4 = client.post("/api/costs/opex", json=opex_payload) assert resp4.status_code == 200 opex = resp4.json() assert opex["scenario_id"] == sid assert opex["amount"] == 500.0 - # List Opex items resp5 = client.get("/api/costs/opex") assert resp5.status_code == 200 data_o = resp5.json() assert any(item["amount"] == 500.0 and item["scenario_id"] == sid for item in data_o) + + +def test_multiple_capex_entries(): + sid = _create_scenario() + amounts = [250.0, 750.0] + for amount in amounts: + resp = client.post( + "/api/costs/capex", + json={"scenario_id": sid, "amount": amount, + "description": f"Capex {amount}"}, + ) + assert resp.status_code == 200 + + resp = client.get("/api/costs/capex") + assert resp.status_code == 200 + data = resp.json() + retrieved_amounts = [item["amount"] + for item in data if item["scenario_id"] == sid] + for amount in amounts: + assert amount in retrieved_amounts + + +def test_multiple_opex_entries(): + sid = _create_scenario() + amounts = [120.0, 340.0] + for amount in amounts: + resp = client.post( + "/api/costs/opex", + json={"scenario_id": sid, "amount": amount, + "description": f"Opex {amount}"}, + ) + assert resp.status_code == 200 + + resp = client.get("/api/costs/opex") + assert resp.status_code == 200 + data = resp.json() + retrieved_amounts = [item["amount"] + for item in data if item["scenario_id"] == sid] + for amount in amounts: + assert amount in retrieved_amounts diff --git a/tests/unit/test_distribution.py b/tests/unit/test_distribution.py index 8341605..1dbb98b 100644 --- a/tests/unit/test_distribution.py +++ b/tests/unit/test_distribution.py @@ -1,9 +1,9 @@ -from fastapi.testclient import TestClient -from main import app -from config.database import Base, engine -from models.distribution import Distribution +from uuid import uuid4 -# Setup and teardown +from fastapi.testclient import TestClient + +from config.database import Base, engine +from main import app def setup_module(module): @@ -18,16 +18,54 @@ client = TestClient(app) def test_create_and_list_distribution(): - # Create distribution - payload = {"name": "NormalDist", "distribution_type": "normal", - "parameters": {"mu": 0, "sigma": 1}} + dist_name = f"NormalDist-{uuid4()}" + payload = { + "name": dist_name, + "distribution_type": "normal", + "parameters": {"mu": 0, "sigma": 1}, + } resp = client.post("/api/distributions/", json=payload) assert resp.status_code == 200 data = resp.json() - assert data["name"] == "NormalDist" + assert data["name"] == dist_name - # List distributions resp2 = client.get("/api/distributions/") assert resp2.status_code == 200 data2 = resp2.json() - assert any(d["name"] == "NormalDist" for d in data2) + assert any(d["name"] == dist_name for d in data2) + + +def test_duplicate_distribution_name_allowed(): + dist_name = f"DupDist-{uuid4()}" + payload = { + "name": dist_name, + "distribution_type": "uniform", + "parameters": {"min": 0, "max": 1}, + } + first = client.post("/api/distributions/", json=payload) + assert first.status_code == 200 + + duplicate = client.post("/api/distributions/", json=payload) + assert duplicate.status_code == 200 + + resp = client.get("/api/distributions/") + assert resp.status_code == 200 + matching = [item for item in resp.json() if item["name"] == dist_name] + assert len(matching) >= 2 + + +def test_list_distributions_returns_all(): + names = {f"ListDist-{uuid4()}" for _ in range(2)} + for name in names: + payload = { + "name": name, + "distribution_type": "triangular", + "parameters": {"min": 0, "max": 10, "mode": 5}, + } + resp = client.post("/api/distributions/", json=payload) + assert resp.status_code == 200 + + resp = client.get("/api/distributions/") + assert resp.status_code == 200 + found_names = {item["name"] for item in resp.json()} + assert names.issubset(found_names) diff --git a/tests/unit/test_equipment.py b/tests/unit/test_equipment.py index 47b53ae..2069b53 100644 --- a/tests/unit/test_equipment.py +++ b/tests/unit/test_equipment.py @@ -1,42 +1,77 @@ +from uuid import uuid4 + +import pytest from fastapi.testclient import TestClient -from main import app -from config.database import Base, engine - -# Setup and teardown -def setup_module(module): - Base.metadata.create_all(bind=engine) +@pytest.fixture +def client(api_client: TestClient) -> TestClient: + return api_client -def teardown_module(module): - Base.metadata.drop_all(bind=engine) +def _create_scenario(client: TestClient) -> int: + payload = { + "name": f"Equipment Scenario {uuid4()}", + "description": "Scenario for equipment tests", + } + response = client.post("/api/scenarios/", json=payload) + assert response.status_code == 200 + return response.json()["id"] -client = TestClient(app) +def test_create_equipment(client: TestClient) -> None: + scenario_id = _create_scenario(client) + payload = { + "scenario_id": scenario_id, + "name": "Excavator", + "description": "Heavy machinery", + } + + response = client.post("/api/equipment/", json=payload) + assert response.status_code == 200 + created = response.json() + assert created["id"] > 0 + assert created["scenario_id"] == scenario_id + assert created["name"] == "Excavator" + assert created["description"] == "Heavy machinery" -def test_create_and_list_equipment(): - # Create a scenario to attach equipment - resp = client.post( - "/api/scenarios/", json={"name": "EquipScenario", "description": "equipment scenario"} +def test_list_equipment_filters_by_scenario(client: TestClient) -> None: + target_scenario = _create_scenario(client) + other_scenario = _create_scenario(client) + + for scenario_id, name in [ + (target_scenario, "Bulldozer"), + (target_scenario, "Loader"), + (other_scenario, "Conveyor"), + ]: + response = client.post( + "/api/equipment/", + json={ + "scenario_id": scenario_id, + "name": name, + "description": f"Equipment {name}", + }, + ) + assert response.status_code == 200 + + list_response = client.get("/api/equipment/") + assert list_response.status_code == 200 + items = [ + item + for item in list_response.json() + if item["scenario_id"] == target_scenario + ] + assert {item["name"] for item in items} == {"Bulldozer", "Loader"} + + +def test_create_equipment_requires_name(client: TestClient) -> None: + scenario_id = _create_scenario(client) + response = client.post( + "/api/equipment/", + json={ + "scenario_id": scenario_id, + "description": "Missing name", + }, ) - assert resp.status_code == 200 - scenario = resp.json() - sid = scenario["id"] - - # Create Equipment item - eq_payload = {"scenario_id": sid, "name": "Excavator", - "description": "Heavy machinery"} - resp2 = client.post("/api/equipment/", json=eq_payload) - assert resp2.status_code == 200 - eq = resp2.json() - assert eq["scenario_id"] == sid - assert eq["name"] == "Excavator" - - # List Equipment items - resp3 = client.get("/api/equipment/") - assert resp3.status_code == 200 - data = resp3.json() - assert any(item["name"] == "Excavator" and item["scenario_id"] - == sid for item in data) + assert response.status_code == 422 diff --git a/tests/unit/test_maintenance.py b/tests/unit/test_maintenance.py index 5f9021e..afe85ad 100644 --- a/tests/unit/test_maintenance.py +++ b/tests/unit/test_maintenance.py @@ -1,75 +1,16 @@ from uuid import uuid4 +import pytest + from fastapi.testclient import TestClient -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from config.database import Base -from main import app -from routes import ( - consumption, - costs, - distributions, - equipment, - maintenance, - parameters, - production, - reporting, - scenarios, - simulations, - ui, -) - -SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" -engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={ - "check_same_thread": False}) -TestingSessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine) -def setup_module(module): - Base.metadata.create_all(bind=engine) +@pytest.fixture +def client(api_client: TestClient) -> TestClient: + return api_client -def teardown_module(module): - Base.metadata.drop_all(bind=engine) - - -def override_get_db(): - db = TestingSessionLocal() - try: - yield db - finally: - db.close() - - -app.dependency_overrides[maintenance.get_db] = override_get_db -app.dependency_overrides[equipment.get_db] = override_get_db -app.dependency_overrides[scenarios.get_db] = override_get_db -app.dependency_overrides[distributions.get_db] = override_get_db -app.dependency_overrides[parameters.get_db] = override_get_db -app.dependency_overrides[costs.get_db] = override_get_db -app.dependency_overrides[consumption.get_db] = override_get_db -app.dependency_overrides[production.get_db] = override_get_db -app.dependency_overrides[reporting.get_db] = override_get_db -app.dependency_overrides[simulations.get_db] = override_get_db - -app.include_router(maintenance.router) -app.include_router(equipment.router) -app.include_router(scenarios.router) -app.include_router(distributions.router) -app.include_router(ui.router) -app.include_router(parameters.router) -app.include_router(costs.router) -app.include_router(consumption.router) -app.include_router(production.router) -app.include_router(reporting.router) -app.include_router(simulations.router) - -client = TestClient(app) - - -def _create_scenario_and_equipment(): +def _create_scenario_and_equipment(client: TestClient): scenario_payload = { "name": f"Test Scenario {uuid4()}", "description": "Scenario for maintenance tests", @@ -99,8 +40,8 @@ def _create_maintenance_payload(equipment_id: int, scenario_id: int, description } -def test_create_and_list_maintenance(): - scenario_id, equipment_id = _create_scenario_and_equipment() +def test_create_and_list_maintenance(client: TestClient): + scenario_id, equipment_id = _create_scenario_and_equipment(client) payload = _create_maintenance_payload( equipment_id, scenario_id, "Create maintenance") @@ -117,8 +58,8 @@ def test_create_and_list_maintenance(): assert any(item["id"] == created["id"] for item in items) -def test_get_maintenance(): - scenario_id, equipment_id = _create_scenario_and_equipment() +def test_get_maintenance(client: TestClient): + scenario_id, equipment_id = _create_scenario_and_equipment(client) payload = _create_maintenance_payload( equipment_id, scenario_id, "Retrieve maintenance" ) @@ -134,8 +75,8 @@ def test_get_maintenance(): assert data["description"] == "Retrieve maintenance" -def test_update_maintenance(): - scenario_id, equipment_id = _create_scenario_and_equipment() +def test_update_maintenance(client: TestClient): + scenario_id, equipment_id = _create_scenario_and_equipment(client) create_response = client.post( "/api/maintenance/", json=_create_maintenance_payload( @@ -162,8 +103,8 @@ def test_update_maintenance(): assert updated["cost"] == 250.0 -def test_delete_maintenance(): - scenario_id, equipment_id = _create_scenario_and_equipment() +def test_delete_maintenance(client: TestClient): + scenario_id, equipment_id = _create_scenario_and_equipment(client) create_response = client.post( "/api/maintenance/", json=_create_maintenance_payload( diff --git a/tests/unit/test_parameter.py b/tests/unit/test_parameter.py deleted file mode 100644 index 862344d..0000000 --- a/tests/unit/test_parameter.py +++ /dev/null @@ -1,46 +0,0 @@ -from models.scenario import Scenario -from main import app -from config.database import Base, engine -from fastapi.testclient import TestClient -import pytest -import os -import sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - -# Setup and teardown - - -def setup_module(module): - Base.metadata.create_all(bind=engine) - - -def teardown_module(module): - Base.metadata.drop_all(bind=engine) - - -client = TestClient(app) - -# Helper to create a scenario - - -def create_test_scenario(): - resp = client.post("/api/scenarios/", - json={"name": "ParamTest", "description": "Desc"}) - assert resp.status_code == 200 - return resp.json()["id"] - - -def test_create_and_list_parameter(): - # Ensure scenario exists - scen_id = create_test_scenario() - # Create a parameter - resp = client.post( - "/api/parameters/", json={"scenario_id": scen_id, "name": "param1", "value": 3.14}) - assert resp.status_code == 200 - data = resp.json() - assert data["name"] == "param1" - # List parameters - resp2 = client.get("/api/parameters/") - assert resp2.status_code == 200 - data2 = resp2.json() - assert any(p["name"] == "param1" for p in data2) diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py new file mode 100644 index 0000000..86081a7 --- /dev/null +++ b/tests/unit/test_parameters.py @@ -0,0 +1,123 @@ +from typing import Any, Dict, List +from uuid import uuid4 + +from fastapi.testclient import TestClient + +from config.database import Base, engine +from main import app + + +def setup_module(module: object) -> None: + Base.metadata.create_all(bind=engine) + + +def teardown_module(module: object) -> None: + Base.metadata.drop_all(bind=engine) + + +def _create_scenario(name: str | None = None) -> int: + payload: Dict[str, Any] = { + "name": name or f"ParamScenario-{uuid4()}", + "description": "Parameter test scenario", + } + response = TestClient(app).post("/api/scenarios/", json=payload) + assert response.status_code == 200 + return response.json()["id"] + + +def _create_distribution() -> int: + payload: Dict[str, Any] = { + "name": f"NormalDist-{uuid4()}", + "distribution_type": "normal", + "parameters": {"mu": 10, "sigma": 2}, + } + response = TestClient(app).post("/api/distributions/", json=payload) + assert response.status_code == 200 + return response.json()["id"] + + +client = TestClient(app) + + +def test_create_and_list_parameter(): + scenario_id = _create_scenario() + distribution_id = _create_distribution() + parameter_payload: Dict[str, Any] = { + "scenario_id": scenario_id, + "name": f"param-{uuid4()}", + "value": 3.14, + "distribution_id": distribution_id, + } + + create_response = client.post("/api/parameters/", json=parameter_payload) + assert create_response.status_code == 200 + created = create_response.json() + assert created["scenario_id"] == scenario_id + assert created["name"] == parameter_payload["name"] + assert created["value"] == parameter_payload["value"] + assert created["distribution_id"] == distribution_id + assert created["distribution_type"] == "normal" + assert created["distribution_parameters"] == {"mu": 10, "sigma": 2} + + list_response = client.get("/api/parameters/") + assert list_response.status_code == 200 + params = list_response.json() + assert any(p["id"] == created["id"] for p in params) + + +def test_create_parameter_for_missing_scenario(): + payload: Dict[str, Any] = { + "scenario_id": 0, "name": "invalid", "value": 1.0} + response = client.post("/api/parameters/", json=payload) + assert response.status_code == 404 + assert response.json()["detail"] == "Scenario not found" + + +def test_multiple_parameters_listed(): + scenario_id = _create_scenario() + payloads: List[Dict[str, Any]] = [ + {"scenario_id": scenario_id, "name": f"alpha-{i}", "value": float(i)} + for i in range(2) + ] + + for payload in payloads: + resp = client.post("/api/parameters/", json=payload) + assert resp.status_code == 200 + + list_response = client.get("/api/parameters/") + assert list_response.status_code == 200 + names = {item["name"] for item in list_response.json()} + for payload in payloads: + assert payload["name"] in names + + +def test_parameter_inline_distribution_metadata(): + scenario_id = _create_scenario() + payload: Dict[str, Any] = { + "scenario_id": scenario_id, + "name": "inline-param", + "value": 7.5, + "distribution_type": "uniform", + "distribution_parameters": {"min": 5, "max": 10}, + } + + response = client.post("/api/parameters/", json=payload) + assert response.status_code == 200 + created = response.json() + assert created["distribution_id"] is None + assert created["distribution_type"] == "uniform" + assert created["distribution_parameters"] == {"min": 5, "max": 10} + + +def test_parameter_with_missing_distribution_reference(): + scenario_id = _create_scenario() + payload: Dict[str, Any] = { + "scenario_id": scenario_id, + "name": "missing-dist", + "value": 1.0, + "distribution_id": 9999, + } + + response = client.post("/api/parameters/", json=payload) + assert response.status_code == 404 + assert response.json()["detail"] == "Distribution not found" diff --git a/tests/unit/test_production.py b/tests/unit/test_production.py index 1f7e68c..9c2b640 100644 --- a/tests/unit/test_production.py +++ b/tests/unit/test_production.py @@ -1,42 +1,70 @@ +from uuid import uuid4 + +import pytest from fastapi.testclient import TestClient -from main import app -from config.database import Base, engine - -# Setup and teardown -def setup_module(module): - Base.metadata.create_all(bind=engine) +@pytest.fixture +def client(api_client: TestClient) -> TestClient: + return api_client -def teardown_module(module): - Base.metadata.drop_all(bind=engine) +def _create_scenario(client: TestClient) -> int: + payload = { + "name": f"Production Scenario {uuid4()}", + "description": "Scenario for production tests", + } + response = client.post("/api/scenarios/", json=payload) + assert response.status_code == 200 + return response.json()["id"] -client = TestClient(app) +def test_create_production_record(client: TestClient) -> None: + scenario_id = _create_scenario(client) + payload = { + "scenario_id": scenario_id, + "amount": 475.25, + "description": "Daily output", + } + + response = client.post("/api/production/", json=payload) + assert response.status_code == 201 + created = response.json() + assert created["scenario_id"] == scenario_id + assert created["amount"] == pytest.approx(475.25) + assert created["description"] == "Daily output" -def test_create_and_list_production_output(): - # Create a scenario to attach production output - resp = client.post( - "/api/scenarios/", json={"name": "ProdScenario", "description": "production scenario"} +def test_list_production_filters_by_scenario(client: TestClient) -> None: + target_scenario = _create_scenario(client) + other_scenario = _create_scenario(client) + + for scenario_id, amount in [(target_scenario, 100.0), (target_scenario, 150.0), (other_scenario, 200.0)]: + response = client.post( + "/api/production/", + json={ + "scenario_id": scenario_id, + "amount": amount, + "description": f"Output {amount}", + }, + ) + assert response.status_code == 201 + + list_response = client.get("/api/production/") + assert list_response.status_code == 200 + items = [item for item in list_response.json() + if item["scenario_id"] == target_scenario] + assert {item["amount"] for item in items} == {100.0, 150.0} + + +def test_create_production_rejects_negative_amount(client: TestClient) -> None: + scenario_id = _create_scenario(client) + response = client.post( + "/api/production/", + json={ + "scenario_id": scenario_id, + "amount": -5, + "description": "Invalid output", + }, ) - assert resp.status_code == 200 - scenario = resp.json() - sid = scenario["id"] - - # Create Production Output item - prod_payload = {"scenario_id": sid, - "amount": 300.0, "description": "Daily output"} - resp2 = client.post("/api/production/", json=prod_payload) - assert resp2.status_code == 201 - prod = resp2.json() - assert prod["scenario_id"] == sid - assert prod["amount"] == 300.0 - - # List Production Output items - resp3 = client.get("/api/production/") - assert resp3.status_code == 200 - data = resp3.json() - assert any(item["amount"] == 300.0 and item["scenario_id"] - == sid for item in data) + assert response.status_code == 422 diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py index c8df0f9..5a6a834 100644 --- a/tests/unit/test_reporting.py +++ b/tests/unit/test_reporting.py @@ -1,7 +1,10 @@ -from fastapi.testclient import TestClient +import math +from typing import Any, Dict, List + import pytest -from main import app +from fastapi.testclient import TestClient + from services.reporting import generate_report @@ -14,57 +17,77 @@ def test_generate_report_empty(): "min": 0.0, "max": 0.0, "std_dev": 0.0, + "variance": 0.0, "percentile_10": 0.0, "percentile_90": 0.0, + "percentile_5": 0.0, + "percentile_95": 0.0, + "value_at_risk_95": 0.0, + "expected_shortfall_95": 0.0, } def test_generate_report_with_values(): - values = [{"iteration": 1, "result": 10.0}, { - "iteration": 2, "result": 20.0}, {"iteration": 3, "result": 30.0}] + values: List[Dict[str, float]] = [ + {"iteration": 1, "result": 10.0}, + {"iteration": 2, "result": 20.0}, + {"iteration": 3, "result": 30.0}, + ] report = generate_report(values) assert report["count"] == 3 - assert report["mean"] == pytest.approx(20.0) - assert report["median"] == pytest.approx(20.0) - assert report["min"] == pytest.approx(10.0) - assert report["max"] == pytest.approx(30.0) - assert report["std_dev"] == pytest.approx(8.1649658, rel=1e-6) - assert report["percentile_10"] == pytest.approx(12.0) - assert report["percentile_90"] == pytest.approx(28.0) + assert math.isclose(float(report["mean"]), 20.0) + assert math.isclose(float(report["median"]), 20.0) + assert math.isclose(float(report["min"]), 10.0) + assert math.isclose(float(report["max"]), 30.0) + assert math.isclose(float(report["std_dev"]), 8.1649658, rel_tol=1e-6) + assert math.isclose(float(report["variance"]), 66.6666666, rel_tol=1e-6) + assert math.isclose(float(report["percentile_10"]), 12.0) + assert math.isclose(float(report["percentile_90"]), 28.0) + assert math.isclose(float(report["percentile_5"]), 11.0) + assert math.isclose(float(report["percentile_95"]), 29.0) + assert math.isclose(float(report["value_at_risk_95"]), 11.0) + assert math.isclose(float(report["expected_shortfall_95"]), 10.0) -def test_reporting_endpoint_invalid_input(): - client = TestClient(app) +@pytest.fixture +def client(api_client: TestClient) -> TestClient: + return api_client + + +def test_reporting_endpoint_invalid_input(client: TestClient): resp = client.post("/api/reporting/summary", json={}) assert resp.status_code == 400 assert resp.json()["detail"] == "Invalid input format" -def test_reporting_endpoint_success(): - client = TestClient(app) - input_data = [ +def test_reporting_endpoint_success(client: TestClient): + input_data: List[Dict[str, float]] = [ {"iteration": 1, "result": 10.0}, {"iteration": 2, "result": 20.0}, {"iteration": 3, "result": 30.0}, ] resp = client.post("/api/reporting/summary", json=input_data) assert resp.status_code == 200 - data = resp.json() + data: Dict[str, Any] = resp.json() assert data["count"] == 3 - assert data["mean"] == pytest.approx(20.0) + assert math.isclose(float(data["mean"]), 20.0) + assert math.isclose(float(data["variance"]), 66.6666666, rel_tol=1e-6) + assert math.isclose(float(data["value_at_risk_95"]), 11.0) + assert math.isclose(float(data["expected_shortfall_95"]), 10.0) -@pytest.mark.parametrize( - "payload,expected_detail", - [ - (["not-a-dict"], "Entry at index 0 must be an object"), - ([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"), - ([{"iteration": 1, "result": "bad"}], - "Entry at index 0 must include numeric 'result'"), - ], -) -def test_reporting_endpoint_validation_errors(payload, expected_detail): - client = TestClient(app) +validation_error_cases: List[tuple[List[Any], str]] = [ + (["not-a-dict"], "Entry at index 0 must be an object"), + ([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"), + ([{"iteration": 1, "result": "bad"}], + "Entry at index 0 must include numeric 'result'"), +] + + +@pytest.mark.parametrize("payload,expected_detail", validation_error_cases) +def test_reporting_endpoint_validation_errors( + client: TestClient, payload: List[Any], expected_detail: str +): resp = client.post("/api/reporting/summary", json=payload) assert resp.status_code == 400 assert resp.json()["detail"] == expected_detail diff --git a/tests/unit/test_scenario.py b/tests/unit/test_scenario.py index b4ed93f..fce4a28 100644 --- a/tests/unit/test_scenario.py +++ b/tests/unit/test_scenario.py @@ -1,14 +1,9 @@ -# ensure project root is on sys.path for module imports -from main import app -from routes.scenarios import router -from config.database import Base, engine -from fastapi.testclient import TestClient -import pytest -import os -import sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +from uuid import uuid4 -# Create tables for testing +from fastapi.testclient import TestClient + +from config.database import Base, engine +from main import app def setup_module(module): @@ -23,14 +18,28 @@ client = TestClient(app) def test_create_and_list_scenario(): - # Create a scenario + scenario_name = f"Scenario-{uuid4()}" response = client.post( - "/api/scenarios/", json={"name": "Test", "description": "Desc"}) + "/api/scenarios/", + json={"name": scenario_name, "description": "Integration test"}, + ) assert response.status_code == 200 data = response.json() - assert data["name"] == "Test" - # List scenarios + assert data["name"] == scenario_name + response2 = client.get("/api/scenarios/") assert response2.status_code == 200 data2 = response2.json() - assert any(s["name"] == "Test" for s in data2) + assert any(s["name"] == scenario_name for s in data2) + + +def test_create_duplicate_scenario_rejected(): + scenario_name = f"Duplicate-{uuid4()}" + payload = {"name": scenario_name, "description": "Primary"} + + first_resp = client.post("/api/scenarios/", json=payload) + assert first_resp.status_code == 200 + + second_resp = client.post("/api/scenarios/", json=payload) + assert second_resp.status_code == 400 + assert second_resp.json()["detail"] == "Scenario already exists" diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py index 2899e20..b2bfc41 100644 --- a/tests/unit/test_simulation.py +++ b/tests/unit/test_simulation.py @@ -1,42 +1,111 @@ -from services.simulation import run_simulation -from main import app -from config.database import Base, engine -from fastapi.testclient import TestClient +from uuid import uuid4 + import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session -# Setup and teardown +from models.simulation_result import SimulationResult +from services.simulation import run_simulation -def setup_module(module): - Base.metadata.create_all(bind=engine) +@pytest.fixture +def client(api_client: TestClient) -> TestClient: + return api_client -def teardown_module(module): - Base.metadata.drop_all(bind=engine) - - -client = TestClient(app) - -# Direct function test - - -def test_run_simulation_function_returns_list(): - results = run_simulation([], iterations=10) +def test_run_simulation_function_generates_samples(): + params = [ + {"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2}, + { + "name": "recovery", + "value": 0.9, + "distribution": "uniform", + "min": 0.8, + "max": 0.95, + }, + ] + results = run_simulation(params, iterations=5, seed=123) assert isinstance(results, list) - assert results == [] - -# Endpoint tests + assert len(results) == 5 + assert results[0]["iteration"] == 1 -def test_simulation_endpoint_no_params(): - resp = client.post("/api/simulations/run", json=[]) +def test_simulation_endpoint_no_params(client: TestClient): + scenario_payload = { + "name": f"NoParamScenario-{uuid4()}", + "description": "No parameters run", + } + scenario_resp = client.post("/api/scenarios/", json=scenario_payload) + assert scenario_resp.status_code == 200 + scenario_id = scenario_resp.json()["id"] + + resp = client.post( + "/api/simulations/run", + json={"scenario_id": scenario_id, "parameters": [], "iterations": 10}, + ) assert resp.status_code == 400 assert resp.json()["detail"] == "No parameters provided" -def test_simulation_endpoint_success(): - params = [{"name": "param1", "value": 2.5}] - resp = client.post("/api/simulations/run", json=params) +def test_simulation_endpoint_success( + client: TestClient, db_session: Session +): + scenario_payload = { + "name": f"SimScenario-{uuid4()}", + "description": "Simulation test", + } + scenario_resp = client.post("/api/scenarios/", json=scenario_payload) + assert scenario_resp.status_code == 200 + scenario_id = scenario_resp.json()["id"] + + params = [ + {"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5} + ] + payload = { + "scenario_id": scenario_id, + "parameters": params, + "iterations": 10, + "seed": 42, + } + + resp = client.post("/api/simulations/run", json=payload) assert resp.status_code == 200 data = resp.json() - assert isinstance(data, list) + assert data["scenario_id"] == scenario_id + assert len(data["results"]) == 10 + assert data["summary"]["count"] == 10 + + db_session.expire_all() + persisted = ( + db_session.query(SimulationResult) + .filter(SimulationResult.scenario_id == scenario_id) + .all() + ) + assert len(persisted) == 10 + + +def test_simulation_endpoint_uses_stored_parameters(client: TestClient): + scenario_payload = { + "name": f"StoredParams-{uuid4()}", + "description": "Stored parameter simulation", + } + scenario_resp = client.post("/api/scenarios/", json=scenario_payload) + assert scenario_resp.status_code == 200 + scenario_id = scenario_resp.json()["id"] + + parameter_payload = { + "scenario_id": scenario_id, + "name": "grade", + "value": 1.5, + } + param_resp = client.post("/api/parameters/", json=parameter_payload) + assert param_resp.status_code == 200 + + resp = client.post( + "/api/simulations/run", + json={"scenario_id": scenario_id, "iterations": 3, "seed": 7}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["summary"]["count"] == 3 + assert len(data["results"]) == 3