feat: Enhance dashboard metrics and summary statistics

- Added new summary fields: variance, 5th percentile, 95th percentile, VaR (95%), and expected shortfall (95%) to the dashboard.
- Updated the display logic for summary metrics to handle non-finite values gracefully.
- Modified the chart rendering to include additional percentile points and tail risk metrics in tooltips.

test: Introduce unit tests for consumption, costs, and other modules

- Created a comprehensive test suite for consumption, costs, equipment, maintenance, production, reporting, and simulation modules.
- Implemented fixtures for database setup and teardown using an in-memory SQLite database for isolated testing.
- Added tests for creating, listing, and validating various entities, ensuring proper error handling and response validation.

refactor: Consolidate parameter tests and remove deprecated files

- Merged parameter-related tests into a new test file for better organization and clarity.
- Removed the old parameter test file that was no longer in use.
- Improved test coverage for parameter creation, listing, and validation scenarios.

fix: Ensure proper validation and error handling in API endpoints

- Added validation to reject negative amounts in consumption and production records.
- Implemented checks to prevent duplicate scenario creation and ensure proper error messages are returned.
- Enhanced reporting endpoint tests to validate input formats and expected outputs.
This commit is contained in:
2025-10-20 22:06:39 +02:00
parent 606cb64ff1
commit 434be86b76
28 changed files with 945 additions and 401 deletions

View File

@@ -71,20 +71,21 @@ uvicorn main:app --reload
- **API base URL**: `http://localhost:8000/api` - **API base URL**: `http://localhost:8000/api`
- **Key routes**: - **Key routes**:
- `POST /api/scenarios/` create scenarios - `POST /api/scenarios/` create scenarios
- `POST /api/parameters/` manage process parameters - `POST /api/parameters/` manage process parameters; payload supports optional `distribution_id` or inline `distribution_type`/`distribution_parameters` fields for simulation metadata
- `POST /api/costs/capex` and `POST /api/costs/opex` capture project costs - `POST /api/costs/capex` and `POST /api/costs/opex` capture project costs
- `POST /api/consumption/` add consumption entries - `POST /api/consumption/` add consumption entries
- `POST /api/production/` register production output - `POST /api/production/` register production output
- `POST /api/equipment/` create equipment records - `POST /api/equipment/` create equipment records
- `POST /api/maintenance/` log maintenance events - `POST /api/maintenance/` log maintenance events
- `POST /api/reporting/summary` aggregate simulation results - `POST /api/reporting/summary` aggregate simulation results, returning count, mean/median, min/max, standard deviation, variance, percentile bands (5/10/90/95), value-at-risk (95%) and expected shortfall (95%)
### Dashboard Preview ### Dashboard Preview
1. Start the FastAPI server and navigate to `/dashboard` (served by `templates/Dashboard.html`). 1. Start the FastAPI server and navigate to `/ui/dashboard` (ensure `routes/ui.py` exposes this template or add a router that serves `templates/Dashboard.html`).
2. Use the "Load Sample Data" button to populate the JSON textarea with demo results. 2. Use the "Load Sample Data" button to populate the JSON textarea with demo results.
3. Select "Refresh Dashboard" to view calculated statistics and a distribution chart sourced from `/api/reporting/summary`. 3. Select "Refresh Dashboard" to post the dataset to `/api/reporting/summary` and render the returned statistics and distribution chart.
4. Paste your own simulation outputs (array of objects containing a numeric `result` property) to visualize custom runs. 4. Paste your own simulation outputs (array of objects containing a numeric `result` property) to visualize custom runs; the endpoint expects the same schema used by the reporting service.
5. If the summary endpoint is unavailable, the dashboard displays an inline error—refresh once the API is reachable.
## Testing ## Testing
@@ -96,6 +97,13 @@ To execute the unit test suite:
pytest pytest
``` ```
### Coverage Snapshot (2025-10-20)
- `pytest --cov=. --cov-report=term-missing` reports **95%** overall coverage across the project.
- Lower coverage hotspots to target next: `services/simulation.py` (79%), `middleware/validation.py` (78%), `routes/ui.py` (82%), and several API routers around lines 12-22 that create database sessions only.
- Deprecation cleanup migrated routes to Pydantic v2 patterns (`model_config = ConfigDict(...)`, `model_dump()`) and updated SQLAlchemy's `declarative_base`; reran `pytest` to confirm the suite passes without warnings.
- Coverage for route-heavy modules is primarily limited by error paths (e.g., bad request branches) that still need explicit tests.
## Database Objects ## Database Objects
The database is composed of several tables that store different types of information. The database is composed of several tables that store different types of information.

View File

@@ -1,6 +1,5 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import sessionmaker
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv

View File

@@ -4,6 +4,10 @@
CalMiner is a FastAPI application that collects mining project inputs, persists scenario-specific records, and surfaces aggregated insights. The platform targets Monte Carlo driven planning, with deterministic CRUD features in place and simulation logic staged for future work. CalMiner is a FastAPI application that collects mining project inputs, persists scenario-specific records, and surfaces aggregated insights. The platform targets Monte Carlo driven planning, with deterministic CRUD features in place and simulation logic staged for future work.
Frontend components are server-rendered Jinja2 templates, with Chart.js powering the dashboard visualization.
The backend leverages SQLAlchemy for ORM mapping to a PostgreSQL database.
## System Components ## System Components
- **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns. - **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns.
@@ -18,14 +22,54 @@ CalMiner is a FastAPI application that collects mining project inputs, persists
1. Users navigate to form templates or API clients to manage scenarios, parameters, and operational data. 1. Users navigate to form templates or API clients to manage scenarios, parameters, and operational data.
2. FastAPI routers validate payloads with Pydantic models, then delegate to SQLAlchemy sessions for persistence. 2. FastAPI routers validate payloads with Pydantic models, then delegate to SQLAlchemy sessions for persistence.
3. Simulation runs (placeholder `services/simulation.py`) will consume stored parameters to emit iteration results via `/api/simulations/run`. 3. Simulation runs (placeholder `services/simulation.py`) will consume stored parameters to emit iteration results via `/api/simulations/run`.
4. Reporting requests POST simulation outputs to `/api/reporting/summary`; the reporting service calculates aggregates (count, min/max, mean, median, percentiles, standard deviation). 4. Reporting requests POST simulation outputs to `/api/reporting/summary`; the reporting service calculates aggregates (count, min/max, mean, median, percentiles, standard deviation, variance, and tail-risk metrics at the 95% confidence level).
5. `templates/Dashboard.html` fetches summaries, renders metric cards, and plots distribution charts with Chart.js for stakeholder review. 5. `templates/Dashboard.html` fetches summaries, renders metric cards, and plots distribution charts with Chart.js for stakeholder review.
### Dashboard Flow Review — 2025-10-20
- The dashboard template depends on a future-facing HTML endpoint (e.g., `/dashboard`) that the current `routes/ui.py` router does not expose; wiring an explicit route is required before the page is reachable from the FastAPI app.
- Client-side logic calls `/api/reporting/summary` with raw simulation outputs and expects `result` fields, so any upstream changes to the reporting contract must maintain this schema.
- Initialization always loads the bundled sample data first, which is useful for demos but masks server errors—consider adding a failure banner when `/api/reporting/summary` is unavailable.
- No persistent storage backs the dashboard yet; users must paste or load JSON manually, aligning with the current MVP scope but highlighting an integration gap with the simulation results table.
### Reporting Pipeline and UI Integration
1. **Data Sources**
- Scenario-linked calculations (costs, consumption, production) produce raw figures stored in dedicated tables (`capex`, `opex`, `consumption`, `production_output`).
- Monte Carlo simulations (currently transient) generate arrays of `{ "result": float }` tuples that the dashboard or downstream tooling passes directly to reporting endpoints.
2. **API Contract**
- `POST /api/reporting/summary` accepts a JSON array of result objects and validates shape through `_validate_payload` in `routes/reporting.py`.
- On success it returns a structured payload (`ReportSummary`) containing count, mean, median, min/max, standard deviation, and percentile values, all as floats.
3. **Service Layer**
- `services/reporting.generate_report` converts the sanitized payload into descriptive statistics using Pythons standard library (`statistics` module) to avoid external dependencies.
- The service remains stateless; no database read/write occurs, which keeps summary calculations deterministic and idempotent.
- Extended KPIs (surfaced in the API and dashboard):
- `variance`: population variance computed as the square of the population standard deviation.
- `percentile_5` and `percentile_95`: lower and upper tail interpolated percentiles for sensitivity bounds.
- `value_at_risk_95`: 5th percentile threshold representing the minimum outcome within a 95% confidence band.
- `expected_shortfall_95`: mean of all outcomes at or below the `value_at_risk_95`, highlighting tail exposure.
4. **UI Consumption**
- `templates/Dashboard.html` posts the user-provided dataset to the summary endpoint, renders metric cards for each field, and charts the distribution using Chart.js.
- `SUMMARY_FIELDS` now includes variance, 5th/10th/90th/95th percentiles, and tail-risk metrics (VaR/Expected Shortfall at 95%); tooltip annotations surface the tail metrics alongside the percentile line chart.
- Error handling surfaces HTTP failures inline so users can address malformed JSON or backend availability issues without leaving the page.
5. **Future Integration Points**
- Once `/api/simulations/run` persists to `simulation_result`, the dashboard can fetch precalculated runs per scenario, removing the manual JSON step.
- Additional reporting endpoints (e.g., scenario comparisons) can reuse the same service layer, ensuring consistency across UI and API consumers.
## Data Model Highlights ## Data Model Highlights
- `scenario`: central entity describing a mining scenario; owns relationships to cost, consumption, production, equipment, and maintenance tables. - `scenario`: central entity describing a mining scenario; owns relationships to cost, consumption, production, equipment, and maintenance tables.
- `capex`, `opex`: monetary tracking linked to scenarios. - `capex`, `opex`: monetary tracking linked to scenarios.
- `consumption`: resource usage entries parameterized by scenario and description. - `consumption`: resource usage entries parameterized by scenario and description.
- `parameter`: scenario inputs with base `value` and optional distribution linkage via `distribution_id`, `distribution_type`, and JSON `distribution_parameters` to support simulation sampling.
- `production_output`: production metrics per scenario. - `production_output`: production metrics per scenario.
- `equipment` and `maintenance`: equipment inventory and maintenance events with dates/costs. - `equipment` and `maintenance`: equipment inventory and maintenance events with dates/costs.
- `simulation_result`: staging table for future Monte Carlo outputs (not yet populated by `run_simulation`). - `simulation_result`: staging table for future Monte Carlo outputs (not yet populated by `run_simulation`).

View File

@@ -1,17 +1,26 @@
from sqlalchemy import Column, Integer, String, Float, ForeignKey from typing import Any, Dict, Optional
from sqlalchemy.orm import relationship
from sqlalchemy import ForeignKey, JSON
from sqlalchemy.orm import Mapped, mapped_column, relationship
from config.database import Base from config.database import Base
class Parameter(Base): class Parameter(Base):
__tablename__ = "parameter" __tablename__ = "parameter"
id = Column(Integer, primary_key=True, index=True) id: Mapped[int] = mapped_column(primary_key=True, index=True)
scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) scenario_id: Mapped[int] = mapped_column(
name = Column(String, nullable=False) ForeignKey("scenario.id"), nullable=False)
value = Column(Float, nullable=False) name: Mapped[str] = mapped_column(nullable=False)
value: Mapped[float] = mapped_column(nullable=False)
distribution_id: Mapped[Optional[int]] = mapped_column(
ForeignKey("distribution.id"), nullable=True)
distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True)
distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column(
JSON, nullable=True)
scenario = relationship("Scenario", back_populates="parameters") scenario = relationship("Scenario", back_populates="parameters")
distribution = relationship("Distribution")
def __repr__(self): def __repr__(self):
return f"<Parameter id={self.id} name={self.name} value={self.value}>" return f"<Parameter id={self.id} name={self.name} value={self.value}>"

View File

@@ -1,7 +1,7 @@
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, PositiveFloat from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal from config.database import SessionLocal
@@ -31,14 +31,12 @@ class ConsumptionCreate(ConsumptionBase):
class ConsumptionRead(ConsumptionBase): class ConsumptionRead(ConsumptionBase):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
@router.post("/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED)
def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)): def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)):
db_item = Consumption(**item.dict()) db_item = Consumption(**item.model_dump())
db.add(db_item) db.add(db_item)
db.commit() db.commit()
db.refresh(db_item) db.refresh(db_item)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal from config.database import SessionLocal
from models.capex import Capex from models.capex import Capex
from models.opex import Opex from models.opex import Opex
@@ -26,9 +26,7 @@ class CapexCreate(BaseModel):
class CapexRead(CapexCreate): class CapexRead(CapexCreate):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
# Pydantic schemas for Opex # Pydantic schemas for Opex
@@ -40,15 +38,13 @@ class OpexCreate(BaseModel):
class OpexRead(OpexCreate): class OpexRead(OpexCreate):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
# Capex endpoints # Capex endpoints
@router.post("/capex", response_model=CapexRead) @router.post("/capex", response_model=CapexRead)
def create_capex(item: CapexCreate, db: Session = Depends(get_db)): def create_capex(item: CapexCreate, db: Session = Depends(get_db)):
db_item = Capex(**item.dict()) db_item = Capex(**item.model_dump())
db.add(db_item) db.add(db_item)
db.commit() db.commit()
db.refresh(db_item) db.refresh(db_item)
@@ -63,7 +59,7 @@ def list_capex(db: Session = Depends(get_db)):
# Opex endpoints # Opex endpoints
@router.post("/opex", response_model=OpexRead) @router.post("/opex", response_model=OpexRead)
def create_opex(item: OpexCreate, db: Session = Depends(get_db)): def create_opex(item: OpexCreate, db: Session = Depends(get_db)):
db_item = Opex(**item.dict()) db_item = Opex(**item.model_dump())
db.add(db_item) db.add(db_item)
db.commit() db.commit()
db.refresh(db_item) db.refresh(db_item)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal from config.database import SessionLocal
from models.distribution import Distribution from models.distribution import Distribution
@@ -24,14 +24,12 @@ class DistributionCreate(BaseModel):
class DistributionRead(DistributionCreate): class DistributionRead(DistributionCreate):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
@router.post("/", response_model=DistributionRead) @router.post("/", response_model=DistributionRead)
async def create_distribution(dist: DistributionCreate, db: Session = Depends(get_db)): async def create_distribution(dist: DistributionCreate, db: Session = Depends(get_db)):
db_dist = Distribution(**dist.dict()) db_dist = Distribution(**dist.model_dump())
db.add(db_dist) db.add(db_dist)
db.commit() db.commit()
db.refresh(db_dist) db.refresh(db_dist)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal from config.database import SessionLocal
from models.equipment import Equipment from models.equipment import Equipment
@@ -25,14 +25,12 @@ class EquipmentCreate(BaseModel):
class EquipmentRead(EquipmentCreate): class EquipmentRead(EquipmentCreate):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
@router.post("/", response_model=EquipmentRead) @router.post("/", response_model=EquipmentRead)
async def create_equipment(item: EquipmentCreate, db: Session = Depends(get_db)): async def create_equipment(item: EquipmentCreate, db: Session = Depends(get_db)):
db_item = Equipment(**item.dict()) db_item = Equipment(**item.model_dump())
db.add(db_item) db.add(db_item)
db.commit() db.commit()
db.refresh(db_item) db.refresh(db_item)

View File

@@ -2,7 +2,7 @@ from datetime import date
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, PositiveFloat from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal from config.database import SessionLocal
@@ -38,9 +38,7 @@ class MaintenanceUpdate(MaintenanceBase):
class MaintenanceRead(MaintenanceBase): class MaintenanceRead(MaintenanceBase):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance:
@@ -56,7 +54,7 @@ def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance:
@router.post("/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED)
def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get_db)): def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get_db)):
db_maintenance = Maintenance(**maintenance.dict()) db_maintenance = Maintenance(**maintenance.model_dump())
db.add(db_maintenance) db.add(db_maintenance)
db.commit() db.commit()
db.refresh(db_maintenance) db.refresh(db_maintenance)
@@ -80,7 +78,7 @@ def update_maintenance(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
db_maintenance = _get_maintenance_or_404(db, maintenance_id) db_maintenance = _get_maintenance_or_404(db, maintenance_id)
for field, value in payload.dict().items(): for field, value in payload.model_dump().items():
setattr(db_maintenance, field, value) setattr(db_maintenance, field, value)
db.commit() db.commit()
db.refresh(db_maintenance) db.refresh(db_maintenance)

View File

@@ -1,10 +1,13 @@
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal from config.database import SessionLocal
from models.distribution import Distribution
from models.parameters import Parameter from models.parameters import Parameter
from models.scenario import Scenario from models.scenario import Scenario
from pydantic import BaseModel
from typing import Optional, List
router = APIRouter(prefix="/api/parameters", tags=["parameters"]) router = APIRouter(prefix="/api/parameters", tags=["parameters"])
@@ -13,13 +16,34 @@ class ParameterCreate(BaseModel):
scenario_id: int scenario_id: int
name: str name: str
value: float value: float
distribution_id: Optional[int] = None
distribution_type: Optional[str] = None
distribution_parameters: Optional[Dict[str, Any]] = None
@field_validator("distribution_type")
@classmethod
def normalize_type(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value
normalized = value.strip().lower()
if not normalized:
return None
if normalized not in {"normal", "uniform", "triangular"}:
raise ValueError(
"distribution_type must be normal, uniform, or triangular")
return normalized
@field_validator("distribution_parameters")
@classmethod
def empty_dict_to_none(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if value is None:
return None
return value or None
class ParameterRead(ParameterCreate): class ParameterRead(ParameterCreate):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
# Dependency # Dependency
@@ -37,8 +61,27 @@ def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first()
if not scen: if not scen:
raise HTTPException(status_code=404, detail="Scenario not found") raise HTTPException(status_code=404, detail="Scenario not found")
new_param = Parameter(scenario_id=param.scenario_id, distribution_id = param.distribution_id
name=param.name, value=param.value) distribution_type = param.distribution_type
distribution_parameters = param.distribution_parameters
if distribution_id is not None:
distribution = db.query(Distribution).filter(
Distribution.id == distribution_id).first()
if not distribution:
raise HTTPException(
status_code=404, detail="Distribution not found")
distribution_type = distribution.distribution_type
distribution_parameters = distribution.parameters or None
new_param = Parameter(
scenario_id=param.scenario_id,
name=param.name,
value=param.value,
distribution_id=distribution_id,
distribution_type=distribution_type,
distribution_parameters=distribution_parameters,
)
db.add(new_param) db.add(new_param)
db.commit() db.commit()
db.refresh(new_param) db.refresh(new_param)

View File

@@ -1,7 +1,7 @@
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, PositiveFloat from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal from config.database import SessionLocal
@@ -31,14 +31,12 @@ class ProductionOutputCreate(ProductionOutputBase):
class ProductionOutputRead(ProductionOutputBase): class ProductionOutputRead(ProductionOutputBase):
id: int id: int
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
@router.post("/", response_model=ProductionOutputRead, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=ProductionOutputRead, status_code=status.HTTP_201_CREATED)
def create_production(item: ProductionOutputCreate, db: Session = Depends(get_db)): def create_production(item: ProductionOutputCreate, db: Session = Depends(get_db)):
db_item = ProductionOutput(**item.dict()) db_item = ProductionOutput(**item.model_dump())
db.add(db_item) db.add(db_item)
db.commit() db.commit()
db.refresh(db_item) db.refresh(db_item)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List from typing import Any, Dict, List, cast
from fastapi import APIRouter, HTTPException, Request, status from fastapi import APIRouter, HTTPException, Request, status
from pydantic import BaseModel from pydantic import BaseModel
@@ -25,14 +25,16 @@ def _validate_payload(payload: Any) -> List[Dict[str, float]]:
detail="Invalid input format", detail="Invalid input format",
) )
typed_payload = cast(List[Any], payload)
validated: List[Dict[str, float]] = [] validated: List[Dict[str, float]] = []
for index, item in enumerate(payload): for index, item in enumerate(typed_payload):
if not isinstance(item, dict): if not isinstance(item, dict):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Entry at index {index} must be an object", detail=f"Entry at index {index} must be an object",
) )
value = item.get("result") value = cast(Dict[str, Any], item).get("result")
if not isinstance(value, (int, float)): if not isinstance(value, (int, float)):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@@ -49,8 +51,13 @@ class ReportSummary(BaseModel):
min: float min: float
max: float max: float
std_dev: float std_dev: float
variance: float
percentile_10: float percentile_10: float
percentile_90: float percentile_90: float
percentile_5: float
percentile_95: float
value_at_risk_95: float
expected_shortfall_95: float
@router.post("/summary", response_model=ReportSummary) @router.post("/summary", response_model=ReportSummary)
@@ -65,6 +72,11 @@ async def summary_report(request: Request):
min=float(summary["min"]), min=float(summary["min"]),
max=float(summary["max"]), max=float(summary["max"]),
std_dev=float(summary["std_dev"]), std_dev=float(summary["std_dev"]),
variance=float(summary["variance"]),
percentile_10=float(summary["percentile_10"]), percentile_10=float(summary["percentile_10"]),
percentile_90=float(summary["percentile_90"]), percentile_90=float(summary["percentile_90"]),
percentile_5=float(summary["percentile_5"]),
percentile_95=float(summary["percentile_95"]),
value_at_risk_95=float(summary["value_at_risk_95"]),
expected_shortfall_95=float(summary["expected_shortfall_95"]),
) )

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal from config.database import SessionLocal
from models.scenario import Scenario from models.scenario import Scenario
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
@@ -20,9 +20,7 @@ class ScenarioRead(ScenarioCreate):
id: int id: int
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class Config:
orm_mode = True
# Dependency # Dependency

View File

@@ -60,8 +60,8 @@ def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterI
) )
return [ return [
SimulationParameterInput( SimulationParameterInput(
name=cast(str, item.name), name=item.name,
value=cast(float, item.value), value=item.value,
) )
for item in db_params for item in db_params
] ]
@@ -86,7 +86,7 @@ async def simulate(payload: SimulationRunRequest, db: Session = Depends(get_db))
) )
raw_results = run_simulation( raw_results = run_simulation(
[param.dict(exclude_none=True) for param in parameters], [param.model_dump(exclude_none=True) for param in parameters],
iterations=payload.iterations, iterations=payload.iterations,
seed=payload.seed, seed=payload.seed,
) )

View File

@@ -1,13 +1,14 @@
from statistics import mean, median, pstdev from statistics import mean, median, pstdev
from typing import Dict, Iterable, List, Union from typing import Any, Dict, Iterable, List, Mapping, Union, cast
def _extract_results(simulation_results: Iterable[Dict[str, float]]) -> List[float]: def _extract_results(simulation_results: Iterable[object]) -> List[float]:
values: List[float] = [] values: List[float] = []
for item in simulation_results: for item in simulation_results:
if not isinstance(item, dict): if not isinstance(item, Mapping):
continue continue
value = item.get("result") mapping_item = cast(Mapping[str, Any], item)
value = mapping_item.get("result")
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
values.append(float(value)) values.append(float(value))
return values return values
@@ -39,8 +40,13 @@ def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Uni
"min": 0.0, "min": 0.0,
"max": 0.0, "max": 0.0,
"std_dev": 0.0, "std_dev": 0.0,
"variance": 0.0,
"percentile_10": 0.0, "percentile_10": 0.0,
"percentile_90": 0.0, "percentile_90": 0.0,
"percentile_5": 0.0,
"percentile_95": 0.0,
"value_at_risk_95": 0.0,
"expected_shortfall_95": 0.0,
} }
summary: Dict[str, Union[float, int]] = { summary: Dict[str, Union[float, int]] = {
@@ -51,7 +57,21 @@ def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Uni
"max": max(values), "max": max(values),
"percentile_10": _percentile(values, 10), "percentile_10": _percentile(values, 10),
"percentile_90": _percentile(values, 90), "percentile_90": _percentile(values, 90),
"percentile_5": _percentile(values, 5),
"percentile_95": _percentile(values, 95),
} }
summary["std_dev"] = pstdev(values) if len(values) > 1 else 0.0 std_dev = pstdev(values) if len(values) > 1 else 0.0
summary["std_dev"] = std_dev
summary["variance"] = std_dev ** 2
var_95 = summary["percentile_5"]
summary["value_at_risk_95"] = var_95
tail_values = [value for value in values if value <= var_95]
if tail_values:
summary["expected_shortfall_95"] = mean(tail_values)
else:
summary["expected_shortfall_95"] = var_95
return summary return summary

View File

@@ -87,8 +87,16 @@
{ key: "min", label: "Min" }, { key: "min", label: "Min" },
{ key: "max", label: "Max" }, { key: "max", label: "Max" },
{ key: "std_dev", label: "Std Dev" }, { key: "std_dev", label: "Std Dev" },
{ key: "variance", label: "Variance" },
{ key: "percentile_5", label: "5th Percentile" },
{ key: "percentile_10", label: "10th Percentile" }, { key: "percentile_10", label: "10th Percentile" },
{ key: "percentile_90", label: "90th Percentile" }, { key: "percentile_90", label: "90th Percentile" },
{ key: "percentile_95", label: "95th Percentile" },
{ key: "value_at_risk_95", label: "VaR (95%)" },
{
key: "expected_shortfall_95",
label: "Expected Shortfall (95%)",
},
]; ];
async function fetchSummary(results) { async function fetchSummary(results) {
@@ -123,12 +131,16 @@
const grid = document.getElementById("summary-grid"); const grid = document.getElementById("summary-grid");
grid.innerHTML = ""; grid.innerHTML = "";
SUMMARY_FIELDS.forEach(({ key, label }) => { SUMMARY_FIELDS.forEach(({ key, label }) => {
const value = summary[key] ?? 0; const rawValue = summary[key];
const numericValue = Number(rawValue);
const display = Number.isFinite(numericValue)
? numericValue.toFixed(2)
: "—";
const metric = document.createElement("div"); const metric = document.createElement("div");
metric.className = "metric"; metric.className = "metric";
metric.innerHTML = ` metric.innerHTML = `
<div class="metric-label">${label}</div> <div class="metric-label">${label}</div>
<div class="metric-value">${value.toFixed(2)}</div> <div class="metric-value">${display}</div>
`; `;
grid.appendChild(metric); grid.appendChild(metric);
}); });
@@ -138,14 +150,34 @@
function renderChart(summary) { function renderChart(summary) {
const ctx = document.getElementById("summary-chart").getContext("2d"); const ctx = document.getElementById("summary-chart").getContext("2d");
const dataPoints = [ const percentilePoints = [
summary.min, { label: "Min", value: summary.min },
summary.percentile_10, { label: "P5", value: summary.percentile_5 },
summary.median, { label: "P10", value: summary.percentile_10 },
summary.mean, { label: "Median", value: summary.median },
summary.percentile_90, { label: "Mean", value: summary.mean },
summary.max, { label: "P90", value: summary.percentile_90 },
].map((value) => Number(value ?? 0)); { label: "P95", value: summary.percentile_95 },
{ label: "Max", value: summary.max },
];
const labels = percentilePoints.map((point) => point.label);
const dataPoints = percentilePoints.map((point) =>
Number(point.value ?? 0)
);
const tailRiskLines = [
{ label: "VaR (95%)", value: summary.value_at_risk_95 },
{ label: "ES (95%)", value: summary.expected_shortfall_95 },
]
.map(({ label, value }) => {
const numeric = Number(value);
if (!Number.isFinite(numeric)) {
return null;
}
return `${label}: ${numeric.toFixed(2)}`;
})
.filter((line) => line !== null);
if (chartInstance) { if (chartInstance) {
chartInstance.destroy(); chartInstance.destroy();
@@ -154,7 +186,7 @@
chartInstance = new Chart(ctx, { chartInstance = new Chart(ctx, {
type: "line", type: "line",
data: { data: {
labels: ["Min", "P10", "Median", "Mean", "P90", "Max"], labels,
datasets: [ datasets: [
{ {
label: "Result Summary", label: "Result Summary",
@@ -169,6 +201,11 @@
options: { options: {
plugins: { plugins: {
legend: { display: false }, legend: { display: false },
tooltip: {
callbacks: {
afterBody: () => tailRiskLines,
},
},
}, },
scales: { scales: {
y: { y: {

94
tests/unit/conftest.py Normal file
View File

@@ -0,0 +1,94 @@
from typing import Generator
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from config.database import Base
from main import app
SQLALCHEMY_TEST_URL = "sqlite:///:memory:"
engine = create_engine(
SQLALCHEMY_TEST_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="session", autouse=True)
def setup_database() -> Generator[None, None, None]:
# Ensure all model metadata is registered before creating tables
from models import (
capex,
consumption,
distribution,
equipment,
maintenance,
opex,
parameters,
production_output,
scenario,
simulation_result,
) # noqa: F401 - imported for side effects
Base.metadata.create_all(bind=engine)
yield
Base.metadata.drop_all(bind=engine)
@pytest.fixture()
def db_session() -> Generator[Session, None, None]:
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
@pytest.fixture()
def api_client(db_session: Session) -> Generator[TestClient, None, None]:
def override_get_db():
try:
yield db_session
finally:
pass
# override all routers that use get_db
from routes import (
consumption,
costs,
distributions,
equipment,
maintenance,
parameters,
production,
reporting,
scenarios,
simulations,
)
overrides = {
consumption.get_db: override_get_db,
costs.get_db: override_get_db,
distributions.get_db: override_get_db,
equipment.get_db: override_get_db,
maintenance.get_db: override_get_db,
parameters.get_db: override_get_db,
production.get_db: override_get_db,
reporting.get_db: override_get_db,
scenarios.get_db: override_get_db,
simulations.get_db: override_get_db,
}
for dependency, override in overrides.items():
app.dependency_overrides[dependency] = override
with TestClient(app) as client:
yield client
for dependency in overrides:
app.dependency_overrides.pop(dependency, None)

View File

@@ -1,42 +1,69 @@
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from main import app
from config.database import Base, engine
# Setup and teardown
def setup_module(module): @pytest.fixture
Base.metadata.create_all(bind=engine) def client(api_client: TestClient) -> TestClient:
return api_client
def teardown_module(module): def _create_scenario(client: TestClient) -> int:
Base.metadata.drop_all(bind=engine) payload = {
"name": f"Consumption Scenario {uuid4()}",
"description": "Scenario for consumption tests",
}
response = client.post("/api/scenarios/", json=payload)
assert response.status_code == 200
return response.json()["id"]
client = TestClient(app) def test_create_consumption(client: TestClient) -> None:
scenario_id = _create_scenario(client)
payload = {
"scenario_id": scenario_id,
"amount": 125.5,
"description": "Fuel usage baseline",
}
response = client.post("/api/consumption/", json=payload)
assert response.status_code == 201
body = response.json()
assert body["id"] > 0
assert body["scenario_id"] == scenario_id
assert body["amount"] == pytest.approx(125.5)
assert body["description"] == "Fuel usage baseline"
def test_create_and_list_consumption(): def test_list_consumption_returns_created_items(client: TestClient) -> None:
# Create a scenario to attach consumption scenario_id = _create_scenario(client)
resp = client.post( values = [50.0, 80.75]
"/api/scenarios/", json={"name": "ConsScenario", "description": "consumption scenario"} for amount in values:
) response = client.post(
assert resp.status_code == 200 "/api/consumption/",
scenario = resp.json() json={
sid = scenario["id"] "scenario_id": scenario_id,
"amount": amount,
"description": f"Consumption {amount}",
},
)
assert response.status_code == 201
# Create Consumption item list_response = client.get("/api/consumption/")
cons_payload = {"scenario_id": sid, "amount": 250.0, assert list_response.status_code == 200
"description": "Monthly consumption"} items = [item for item in list_response.json(
resp2 = client.post("/api/consumption/", json=cons_payload) ) if item["scenario_id"] == scenario_id]
assert resp2.status_code == 201 assert {item["amount"] for item in items} == set(values)
cons = resp2.json()
assert cons["scenario_id"] == sid
assert cons["amount"] == 250.0
# List Consumption items
resp3 = client.get("/api/consumption/") def test_create_consumption_rejects_negative_amount(client: TestClient) -> None:
assert resp3.status_code == 200 scenario_id = _create_scenario(client)
data = resp3.json() payload = {
assert any(item["amount"] == 250.0 and item["scenario_id"] "scenario_id": scenario_id,
== sid for item in data) "amount": -10,
"description": "Invalid negative amount",
}
response = client.post("/api/consumption/", json=payload)
assert response.status_code == 422

View File

@@ -1,8 +1,9 @@
from fastapi.testclient import TestClient from uuid import uuid4
from main import app
from config.database import Base, engine
# Setup and teardown from fastapi.testclient import TestClient
from config.database import Base, engine
from main import app
def setup_module(module): def setup_module(module):
@@ -16,43 +17,89 @@ def teardown_module(module):
client = TestClient(app) client = TestClient(app)
def test_create_and_list_capex_and_opex(): def _create_scenario() -> int:
# Create a scenario to attach costs payload = {
resp = client.post( "name": f"CostScenario-{uuid4()}",
"/api/scenarios/", json={"name": "CostScenario", "description": "cost scenario"} "description": "Cost tracking test scenario",
) }
assert resp.status_code == 200 response = client.post("/api/scenarios/", json=payload)
scenario = resp.json() assert response.status_code == 200
sid = scenario["id"] return response.json()["id"]
# Create Capex item
capex_payload = {"scenario_id": sid, def test_create_and_list_capex_and_opex():
"amount": 1000.0, "description": "Initial capex"} sid = _create_scenario()
capex_payload = {
"scenario_id": sid,
"amount": 1000.0,
"description": "Initial capex",
}
resp2 = client.post("/api/costs/capex", json=capex_payload) resp2 = client.post("/api/costs/capex", json=capex_payload)
assert resp2.status_code == 200 assert resp2.status_code == 200
capex = resp2.json() capex = resp2.json()
assert capex["scenario_id"] == sid assert capex["scenario_id"] == sid
assert capex["amount"] == 1000.0 assert capex["amount"] == 1000.0
# List Capex items
resp3 = client.get("/api/costs/capex") resp3 = client.get("/api/costs/capex")
assert resp3.status_code == 200 assert resp3.status_code == 200
data = resp3.json() data = resp3.json()
assert any(item["amount"] == 1000.0 and item["scenario_id"] assert any(item["amount"] == 1000.0 and item["scenario_id"]
== sid for item in data) == sid for item in data)
# Create Opex item opex_payload = {
opex_payload = {"scenario_id": sid, "amount": 500.0, "scenario_id": sid,
"description": "Recurring opex"} "amount": 500.0,
"description": "Recurring opex",
}
resp4 = client.post("/api/costs/opex", json=opex_payload) resp4 = client.post("/api/costs/opex", json=opex_payload)
assert resp4.status_code == 200 assert resp4.status_code == 200
opex = resp4.json() opex = resp4.json()
assert opex["scenario_id"] == sid assert opex["scenario_id"] == sid
assert opex["amount"] == 500.0 assert opex["amount"] == 500.0
# List Opex items
resp5 = client.get("/api/costs/opex") resp5 = client.get("/api/costs/opex")
assert resp5.status_code == 200 assert resp5.status_code == 200
data_o = resp5.json() data_o = resp5.json()
assert any(item["amount"] == 500.0 and item["scenario_id"] assert any(item["amount"] == 500.0 and item["scenario_id"]
== sid for item in data_o) == sid for item in data_o)
def test_multiple_capex_entries():
sid = _create_scenario()
amounts = [250.0, 750.0]
for amount in amounts:
resp = client.post(
"/api/costs/capex",
json={"scenario_id": sid, "amount": amount,
"description": f"Capex {amount}"},
)
assert resp.status_code == 200
resp = client.get("/api/costs/capex")
assert resp.status_code == 200
data = resp.json()
retrieved_amounts = [item["amount"]
for item in data if item["scenario_id"] == sid]
for amount in amounts:
assert amount in retrieved_amounts
def test_multiple_opex_entries():
sid = _create_scenario()
amounts = [120.0, 340.0]
for amount in amounts:
resp = client.post(
"/api/costs/opex",
json={"scenario_id": sid, "amount": amount,
"description": f"Opex {amount}"},
)
assert resp.status_code == 200
resp = client.get("/api/costs/opex")
assert resp.status_code == 200
data = resp.json()
retrieved_amounts = [item["amount"]
for item in data if item["scenario_id"] == sid]
for amount in amounts:
assert amount in retrieved_amounts

View File

@@ -1,9 +1,9 @@
from fastapi.testclient import TestClient from uuid import uuid4
from main import app
from config.database import Base, engine
from models.distribution import Distribution
# Setup and teardown from fastapi.testclient import TestClient
from config.database import Base, engine
from main import app
def setup_module(module): def setup_module(module):
@@ -18,16 +18,54 @@ client = TestClient(app)
def test_create_and_list_distribution(): def test_create_and_list_distribution():
# Create distribution dist_name = f"NormalDist-{uuid4()}"
payload = {"name": "NormalDist", "distribution_type": "normal", payload = {
"parameters": {"mu": 0, "sigma": 1}} "name": dist_name,
"distribution_type": "normal",
"parameters": {"mu": 0, "sigma": 1},
}
resp = client.post("/api/distributions/", json=payload) resp = client.post("/api/distributions/", json=payload)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["name"] == "NormalDist" assert data["name"] == dist_name
# List distributions
resp2 = client.get("/api/distributions/") resp2 = client.get("/api/distributions/")
assert resp2.status_code == 200 assert resp2.status_code == 200
data2 = resp2.json() data2 = resp2.json()
assert any(d["name"] == "NormalDist" for d in data2) assert any(d["name"] == dist_name for d in data2)
def test_duplicate_distribution_name_allowed():
dist_name = f"DupDist-{uuid4()}"
payload = {
"name": dist_name,
"distribution_type": "uniform",
"parameters": {"min": 0, "max": 1},
}
first = client.post("/api/distributions/", json=payload)
assert first.status_code == 200
duplicate = client.post("/api/distributions/", json=payload)
assert duplicate.status_code == 200
resp = client.get("/api/distributions/")
assert resp.status_code == 200
matching = [item for item in resp.json() if item["name"] == dist_name]
assert len(matching) >= 2
def test_list_distributions_returns_all():
names = {f"ListDist-{uuid4()}" for _ in range(2)}
for name in names:
payload = {
"name": name,
"distribution_type": "triangular",
"parameters": {"min": 0, "max": 10, "mode": 5},
}
resp = client.post("/api/distributions/", json=payload)
assert resp.status_code == 200
resp = client.get("/api/distributions/")
assert resp.status_code == 200
found_names = {item["name"] for item in resp.json()}
assert names.issubset(found_names)

View File

@@ -1,42 +1,77 @@
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from main import app
from config.database import Base, engine
# Setup and teardown
def setup_module(module): @pytest.fixture
Base.metadata.create_all(bind=engine) def client(api_client: TestClient) -> TestClient:
return api_client
def teardown_module(module): def _create_scenario(client: TestClient) -> int:
Base.metadata.drop_all(bind=engine) payload = {
"name": f"Equipment Scenario {uuid4()}",
"description": "Scenario for equipment tests",
}
response = client.post("/api/scenarios/", json=payload)
assert response.status_code == 200
return response.json()["id"]
client = TestClient(app) def test_create_equipment(client: TestClient) -> None:
scenario_id = _create_scenario(client)
payload = {
"scenario_id": scenario_id,
"name": "Excavator",
"description": "Heavy machinery",
}
response = client.post("/api/equipment/", json=payload)
assert response.status_code == 200
created = response.json()
assert created["id"] > 0
assert created["scenario_id"] == scenario_id
assert created["name"] == "Excavator"
assert created["description"] == "Heavy machinery"
def test_create_and_list_equipment(): def test_list_equipment_filters_by_scenario(client: TestClient) -> None:
# Create a scenario to attach equipment target_scenario = _create_scenario(client)
resp = client.post( other_scenario = _create_scenario(client)
"/api/scenarios/", json={"name": "EquipScenario", "description": "equipment scenario"}
for scenario_id, name in [
(target_scenario, "Bulldozer"),
(target_scenario, "Loader"),
(other_scenario, "Conveyor"),
]:
response = client.post(
"/api/equipment/",
json={
"scenario_id": scenario_id,
"name": name,
"description": f"Equipment {name}",
},
)
assert response.status_code == 200
list_response = client.get("/api/equipment/")
assert list_response.status_code == 200
items = [
item
for item in list_response.json()
if item["scenario_id"] == target_scenario
]
assert {item["name"] for item in items} == {"Bulldozer", "Loader"}
def test_create_equipment_requires_name(client: TestClient) -> None:
scenario_id = _create_scenario(client)
response = client.post(
"/api/equipment/",
json={
"scenario_id": scenario_id,
"description": "Missing name",
},
) )
assert resp.status_code == 200 assert response.status_code == 422
scenario = resp.json()
sid = scenario["id"]
# Create Equipment item
eq_payload = {"scenario_id": sid, "name": "Excavator",
"description": "Heavy machinery"}
resp2 = client.post("/api/equipment/", json=eq_payload)
assert resp2.status_code == 200
eq = resp2.json()
assert eq["scenario_id"] == sid
assert eq["name"] == "Excavator"
# List Equipment items
resp3 = client.get("/api/equipment/")
assert resp3.status_code == 200
data = resp3.json()
assert any(item["name"] == "Excavator" and item["scenario_id"]
== sid for item in data)

View File

@@ -1,75 +1,16 @@
from uuid import uuid4 from uuid import uuid4
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from config.database import Base
from main import app
from routes import (
consumption,
costs,
distributions,
equipment,
maintenance,
parameters,
production,
reporting,
scenarios,
simulations,
ui,
)
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={
"check_same_thread": False})
TestingSessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine)
def setup_module(module): @pytest.fixture
Base.metadata.create_all(bind=engine) def client(api_client: TestClient) -> TestClient:
return api_client
def teardown_module(module): def _create_scenario_and_equipment(client: TestClient):
Base.metadata.drop_all(bind=engine)
def override_get_db():
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
app.dependency_overrides[maintenance.get_db] = override_get_db
app.dependency_overrides[equipment.get_db] = override_get_db
app.dependency_overrides[scenarios.get_db] = override_get_db
app.dependency_overrides[distributions.get_db] = override_get_db
app.dependency_overrides[parameters.get_db] = override_get_db
app.dependency_overrides[costs.get_db] = override_get_db
app.dependency_overrides[consumption.get_db] = override_get_db
app.dependency_overrides[production.get_db] = override_get_db
app.dependency_overrides[reporting.get_db] = override_get_db
app.dependency_overrides[simulations.get_db] = override_get_db
app.include_router(maintenance.router)
app.include_router(equipment.router)
app.include_router(scenarios.router)
app.include_router(distributions.router)
app.include_router(ui.router)
app.include_router(parameters.router)
app.include_router(costs.router)
app.include_router(consumption.router)
app.include_router(production.router)
app.include_router(reporting.router)
app.include_router(simulations.router)
client = TestClient(app)
def _create_scenario_and_equipment():
scenario_payload = { scenario_payload = {
"name": f"Test Scenario {uuid4()}", "name": f"Test Scenario {uuid4()}",
"description": "Scenario for maintenance tests", "description": "Scenario for maintenance tests",
@@ -99,8 +40,8 @@ def _create_maintenance_payload(equipment_id: int, scenario_id: int, description
} }
def test_create_and_list_maintenance(): def test_create_and_list_maintenance(client: TestClient):
scenario_id, equipment_id = _create_scenario_and_equipment() scenario_id, equipment_id = _create_scenario_and_equipment(client)
payload = _create_maintenance_payload( payload = _create_maintenance_payload(
equipment_id, scenario_id, "Create maintenance") equipment_id, scenario_id, "Create maintenance")
@@ -117,8 +58,8 @@ def test_create_and_list_maintenance():
assert any(item["id"] == created["id"] for item in items) assert any(item["id"] == created["id"] for item in items)
def test_get_maintenance(): def test_get_maintenance(client: TestClient):
scenario_id, equipment_id = _create_scenario_and_equipment() scenario_id, equipment_id = _create_scenario_and_equipment(client)
payload = _create_maintenance_payload( payload = _create_maintenance_payload(
equipment_id, scenario_id, "Retrieve maintenance" equipment_id, scenario_id, "Retrieve maintenance"
) )
@@ -134,8 +75,8 @@ def test_get_maintenance():
assert data["description"] == "Retrieve maintenance" assert data["description"] == "Retrieve maintenance"
def test_update_maintenance(): def test_update_maintenance(client: TestClient):
scenario_id, equipment_id = _create_scenario_and_equipment() scenario_id, equipment_id = _create_scenario_and_equipment(client)
create_response = client.post( create_response = client.post(
"/api/maintenance/", "/api/maintenance/",
json=_create_maintenance_payload( json=_create_maintenance_payload(
@@ -162,8 +103,8 @@ def test_update_maintenance():
assert updated["cost"] == 250.0 assert updated["cost"] == 250.0
def test_delete_maintenance(): def test_delete_maintenance(client: TestClient):
scenario_id, equipment_id = _create_scenario_and_equipment() scenario_id, equipment_id = _create_scenario_and_equipment(client)
create_response = client.post( create_response = client.post(
"/api/maintenance/", "/api/maintenance/",
json=_create_maintenance_payload( json=_create_maintenance_payload(

View File

@@ -1,46 +0,0 @@
from models.scenario import Scenario
from main import app
from config.database import Base, engine
from fastapi.testclient import TestClient
import pytest
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# Setup and teardown
def setup_module(module):
Base.metadata.create_all(bind=engine)
def teardown_module(module):
Base.metadata.drop_all(bind=engine)
client = TestClient(app)
# Helper to create a scenario
def create_test_scenario():
resp = client.post("/api/scenarios/",
json={"name": "ParamTest", "description": "Desc"})
assert resp.status_code == 200
return resp.json()["id"]
def test_create_and_list_parameter():
# Ensure scenario exists
scen_id = create_test_scenario()
# Create a parameter
resp = client.post(
"/api/parameters/", json={"scenario_id": scen_id, "name": "param1", "value": 3.14})
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "param1"
# List parameters
resp2 = client.get("/api/parameters/")
assert resp2.status_code == 200
data2 = resp2.json()
assert any(p["name"] == "param1" for p in data2)

View File

@@ -0,0 +1,123 @@
from typing import Any, Dict, List
from uuid import uuid4
from fastapi.testclient import TestClient
from config.database import Base, engine
from main import app
def setup_module(module: object) -> None:
Base.metadata.create_all(bind=engine)
def teardown_module(module: object) -> None:
Base.metadata.drop_all(bind=engine)
def _create_scenario(name: str | None = None) -> int:
payload: Dict[str, Any] = {
"name": name or f"ParamScenario-{uuid4()}",
"description": "Parameter test scenario",
}
response = TestClient(app).post("/api/scenarios/", json=payload)
assert response.status_code == 200
return response.json()["id"]
def _create_distribution() -> int:
payload: Dict[str, Any] = {
"name": f"NormalDist-{uuid4()}",
"distribution_type": "normal",
"parameters": {"mu": 10, "sigma": 2},
}
response = TestClient(app).post("/api/distributions/", json=payload)
assert response.status_code == 200
return response.json()["id"]
client = TestClient(app)
def test_create_and_list_parameter():
scenario_id = _create_scenario()
distribution_id = _create_distribution()
parameter_payload: Dict[str, Any] = {
"scenario_id": scenario_id,
"name": f"param-{uuid4()}",
"value": 3.14,
"distribution_id": distribution_id,
}
create_response = client.post("/api/parameters/", json=parameter_payload)
assert create_response.status_code == 200
created = create_response.json()
assert created["scenario_id"] == scenario_id
assert created["name"] == parameter_payload["name"]
assert created["value"] == parameter_payload["value"]
assert created["distribution_id"] == distribution_id
assert created["distribution_type"] == "normal"
assert created["distribution_parameters"] == {"mu": 10, "sigma": 2}
list_response = client.get("/api/parameters/")
assert list_response.status_code == 200
params = list_response.json()
assert any(p["id"] == created["id"] for p in params)
def test_create_parameter_for_missing_scenario():
payload: Dict[str, Any] = {
"scenario_id": 0, "name": "invalid", "value": 1.0}
response = client.post("/api/parameters/", json=payload)
assert response.status_code == 404
assert response.json()["detail"] == "Scenario not found"
def test_multiple_parameters_listed():
scenario_id = _create_scenario()
payloads: List[Dict[str, Any]] = [
{"scenario_id": scenario_id, "name": f"alpha-{i}", "value": float(i)}
for i in range(2)
]
for payload in payloads:
resp = client.post("/api/parameters/", json=payload)
assert resp.status_code == 200
list_response = client.get("/api/parameters/")
assert list_response.status_code == 200
names = {item["name"] for item in list_response.json()}
for payload in payloads:
assert payload["name"] in names
def test_parameter_inline_distribution_metadata():
scenario_id = _create_scenario()
payload: Dict[str, Any] = {
"scenario_id": scenario_id,
"name": "inline-param",
"value": 7.5,
"distribution_type": "uniform",
"distribution_parameters": {"min": 5, "max": 10},
}
response = client.post("/api/parameters/", json=payload)
assert response.status_code == 200
created = response.json()
assert created["distribution_id"] is None
assert created["distribution_type"] == "uniform"
assert created["distribution_parameters"] == {"min": 5, "max": 10}
def test_parameter_with_missing_distribution_reference():
scenario_id = _create_scenario()
payload: Dict[str, Any] = {
"scenario_id": scenario_id,
"name": "missing-dist",
"value": 1.0,
"distribution_id": 9999,
}
response = client.post("/api/parameters/", json=payload)
assert response.status_code == 404
assert response.json()["detail"] == "Distribution not found"

View File

@@ -1,42 +1,70 @@
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from main import app
from config.database import Base, engine
# Setup and teardown
def setup_module(module): @pytest.fixture
Base.metadata.create_all(bind=engine) def client(api_client: TestClient) -> TestClient:
return api_client
def teardown_module(module): def _create_scenario(client: TestClient) -> int:
Base.metadata.drop_all(bind=engine) payload = {
"name": f"Production Scenario {uuid4()}",
"description": "Scenario for production tests",
}
response = client.post("/api/scenarios/", json=payload)
assert response.status_code == 200
return response.json()["id"]
client = TestClient(app) def test_create_production_record(client: TestClient) -> None:
scenario_id = _create_scenario(client)
payload = {
"scenario_id": scenario_id,
"amount": 475.25,
"description": "Daily output",
}
response = client.post("/api/production/", json=payload)
assert response.status_code == 201
created = response.json()
assert created["scenario_id"] == scenario_id
assert created["amount"] == pytest.approx(475.25)
assert created["description"] == "Daily output"
def test_create_and_list_production_output(): def test_list_production_filters_by_scenario(client: TestClient) -> None:
# Create a scenario to attach production output target_scenario = _create_scenario(client)
resp = client.post( other_scenario = _create_scenario(client)
"/api/scenarios/", json={"name": "ProdScenario", "description": "production scenario"}
for scenario_id, amount in [(target_scenario, 100.0), (target_scenario, 150.0), (other_scenario, 200.0)]:
response = client.post(
"/api/production/",
json={
"scenario_id": scenario_id,
"amount": amount,
"description": f"Output {amount}",
},
)
assert response.status_code == 201
list_response = client.get("/api/production/")
assert list_response.status_code == 200
items = [item for item in list_response.json()
if item["scenario_id"] == target_scenario]
assert {item["amount"] for item in items} == {100.0, 150.0}
def test_create_production_rejects_negative_amount(client: TestClient) -> None:
scenario_id = _create_scenario(client)
response = client.post(
"/api/production/",
json={
"scenario_id": scenario_id,
"amount": -5,
"description": "Invalid output",
},
) )
assert resp.status_code == 200 assert response.status_code == 422
scenario = resp.json()
sid = scenario["id"]
# Create Production Output item
prod_payload = {"scenario_id": sid,
"amount": 300.0, "description": "Daily output"}
resp2 = client.post("/api/production/", json=prod_payload)
assert resp2.status_code == 201
prod = resp2.json()
assert prod["scenario_id"] == sid
assert prod["amount"] == 300.0
# List Production Output items
resp3 = client.get("/api/production/")
assert resp3.status_code == 200
data = resp3.json()
assert any(item["amount"] == 300.0 and item["scenario_id"]
== sid for item in data)

View File

@@ -1,7 +1,10 @@
from fastapi.testclient import TestClient import math
from typing import Any, Dict, List
import pytest import pytest
from main import app from fastapi.testclient import TestClient
from services.reporting import generate_report from services.reporting import generate_report
@@ -14,57 +17,77 @@ def test_generate_report_empty():
"min": 0.0, "min": 0.0,
"max": 0.0, "max": 0.0,
"std_dev": 0.0, "std_dev": 0.0,
"variance": 0.0,
"percentile_10": 0.0, "percentile_10": 0.0,
"percentile_90": 0.0, "percentile_90": 0.0,
"percentile_5": 0.0,
"percentile_95": 0.0,
"value_at_risk_95": 0.0,
"expected_shortfall_95": 0.0,
} }
def test_generate_report_with_values(): def test_generate_report_with_values():
values = [{"iteration": 1, "result": 10.0}, { values: List[Dict[str, float]] = [
"iteration": 2, "result": 20.0}, {"iteration": 3, "result": 30.0}] {"iteration": 1, "result": 10.0},
{"iteration": 2, "result": 20.0},
{"iteration": 3, "result": 30.0},
]
report = generate_report(values) report = generate_report(values)
assert report["count"] == 3 assert report["count"] == 3
assert report["mean"] == pytest.approx(20.0) assert math.isclose(float(report["mean"]), 20.0)
assert report["median"] == pytest.approx(20.0) assert math.isclose(float(report["median"]), 20.0)
assert report["min"] == pytest.approx(10.0) assert math.isclose(float(report["min"]), 10.0)
assert report["max"] == pytest.approx(30.0) assert math.isclose(float(report["max"]), 30.0)
assert report["std_dev"] == pytest.approx(8.1649658, rel=1e-6) assert math.isclose(float(report["std_dev"]), 8.1649658, rel_tol=1e-6)
assert report["percentile_10"] == pytest.approx(12.0) assert math.isclose(float(report["variance"]), 66.6666666, rel_tol=1e-6)
assert report["percentile_90"] == pytest.approx(28.0) assert math.isclose(float(report["percentile_10"]), 12.0)
assert math.isclose(float(report["percentile_90"]), 28.0)
assert math.isclose(float(report["percentile_5"]), 11.0)
assert math.isclose(float(report["percentile_95"]), 29.0)
assert math.isclose(float(report["value_at_risk_95"]), 11.0)
assert math.isclose(float(report["expected_shortfall_95"]), 10.0)
def test_reporting_endpoint_invalid_input(): @pytest.fixture
client = TestClient(app) def client(api_client: TestClient) -> TestClient:
return api_client
def test_reporting_endpoint_invalid_input(client: TestClient):
resp = client.post("/api/reporting/summary", json={}) resp = client.post("/api/reporting/summary", json={})
assert resp.status_code == 400 assert resp.status_code == 400
assert resp.json()["detail"] == "Invalid input format" assert resp.json()["detail"] == "Invalid input format"
def test_reporting_endpoint_success(): def test_reporting_endpoint_success(client: TestClient):
client = TestClient(app) input_data: List[Dict[str, float]] = [
input_data = [
{"iteration": 1, "result": 10.0}, {"iteration": 1, "result": 10.0},
{"iteration": 2, "result": 20.0}, {"iteration": 2, "result": 20.0},
{"iteration": 3, "result": 30.0}, {"iteration": 3, "result": 30.0},
] ]
resp = client.post("/api/reporting/summary", json=input_data) resp = client.post("/api/reporting/summary", json=input_data)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data: Dict[str, Any] = resp.json()
assert data["count"] == 3 assert data["count"] == 3
assert data["mean"] == pytest.approx(20.0) assert math.isclose(float(data["mean"]), 20.0)
assert math.isclose(float(data["variance"]), 66.6666666, rel_tol=1e-6)
assert math.isclose(float(data["value_at_risk_95"]), 11.0)
assert math.isclose(float(data["expected_shortfall_95"]), 10.0)
@pytest.mark.parametrize( validation_error_cases: List[tuple[List[Any], str]] = [
"payload,expected_detail", (["not-a-dict"], "Entry at index 0 must be an object"),
[ ([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"),
(["not-a-dict"], "Entry at index 0 must be an object"), ([{"iteration": 1, "result": "bad"}],
([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"), "Entry at index 0 must include numeric 'result'"),
([{"iteration": 1, "result": "bad"}], ]
"Entry at index 0 must include numeric 'result'"),
],
) @pytest.mark.parametrize("payload,expected_detail", validation_error_cases)
def test_reporting_endpoint_validation_errors(payload, expected_detail): def test_reporting_endpoint_validation_errors(
client = TestClient(app) client: TestClient, payload: List[Any], expected_detail: str
):
resp = client.post("/api/reporting/summary", json=payload) resp = client.post("/api/reporting/summary", json=payload)
assert resp.status_code == 400 assert resp.status_code == 400
assert resp.json()["detail"] == expected_detail assert resp.json()["detail"] == expected_detail

View File

@@ -1,14 +1,9 @@
# ensure project root is on sys.path for module imports from uuid import uuid4
from main import app
from routes.scenarios import router
from config.database import Base, engine
from fastapi.testclient import TestClient
import pytest
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# Create tables for testing from fastapi.testclient import TestClient
from config.database import Base, engine
from main import app
def setup_module(module): def setup_module(module):
@@ -23,14 +18,28 @@ client = TestClient(app)
def test_create_and_list_scenario(): def test_create_and_list_scenario():
# Create a scenario scenario_name = f"Scenario-{uuid4()}"
response = client.post( response = client.post(
"/api/scenarios/", json={"name": "Test", "description": "Desc"}) "/api/scenarios/",
json={"name": scenario_name, "description": "Integration test"},
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["name"] == "Test" assert data["name"] == scenario_name
# List scenarios
response2 = client.get("/api/scenarios/") response2 = client.get("/api/scenarios/")
assert response2.status_code == 200 assert response2.status_code == 200
data2 = response2.json() data2 = response2.json()
assert any(s["name"] == "Test" for s in data2) assert any(s["name"] == scenario_name for s in data2)
def test_create_duplicate_scenario_rejected():
scenario_name = f"Duplicate-{uuid4()}"
payload = {"name": scenario_name, "description": "Primary"}
first_resp = client.post("/api/scenarios/", json=payload)
assert first_resp.status_code == 200
second_resp = client.post("/api/scenarios/", json=payload)
assert second_resp.status_code == 400
assert second_resp.json()["detail"] == "Scenario already exists"

View File

@@ -1,42 +1,111 @@
from services.simulation import run_simulation from uuid import uuid4
from main import app
from config.database import Base, engine
from fastapi.testclient import TestClient
import pytest import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
# Setup and teardown from models.simulation_result import SimulationResult
from services.simulation import run_simulation
def setup_module(module): @pytest.fixture
Base.metadata.create_all(bind=engine) def client(api_client: TestClient) -> TestClient:
return api_client
def teardown_module(module): def test_run_simulation_function_generates_samples():
Base.metadata.drop_all(bind=engine) params = [
{"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2},
{
client = TestClient(app) "name": "recovery",
"value": 0.9,
# Direct function test "distribution": "uniform",
"min": 0.8,
"max": 0.95,
def test_run_simulation_function_returns_list(): },
results = run_simulation([], iterations=10) ]
results = run_simulation(params, iterations=5, seed=123)
assert isinstance(results, list) assert isinstance(results, list)
assert results == [] assert len(results) == 5
assert results[0]["iteration"] == 1
# Endpoint tests
def test_simulation_endpoint_no_params(): def test_simulation_endpoint_no_params(client: TestClient):
resp = client.post("/api/simulations/run", json=[]) scenario_payload = {
"name": f"NoParamScenario-{uuid4()}",
"description": "No parameters run",
}
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
assert scenario_resp.status_code == 200
scenario_id = scenario_resp.json()["id"]
resp = client.post(
"/api/simulations/run",
json={"scenario_id": scenario_id, "parameters": [], "iterations": 10},
)
assert resp.status_code == 400 assert resp.status_code == 400
assert resp.json()["detail"] == "No parameters provided" assert resp.json()["detail"] == "No parameters provided"
def test_simulation_endpoint_success(): def test_simulation_endpoint_success(
params = [{"name": "param1", "value": 2.5}] client: TestClient, db_session: Session
resp = client.post("/api/simulations/run", json=params) ):
scenario_payload = {
"name": f"SimScenario-{uuid4()}",
"description": "Simulation test",
}
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
assert scenario_resp.status_code == 200
scenario_id = scenario_resp.json()["id"]
params = [
{"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5}
]
payload = {
"scenario_id": scenario_id,
"parameters": params,
"iterations": 10,
"seed": 42,
}
resp = client.post("/api/simulations/run", json=payload)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert isinstance(data, list) assert data["scenario_id"] == scenario_id
assert len(data["results"]) == 10
assert data["summary"]["count"] == 10
db_session.expire_all()
persisted = (
db_session.query(SimulationResult)
.filter(SimulationResult.scenario_id == scenario_id)
.all()
)
assert len(persisted) == 10
def test_simulation_endpoint_uses_stored_parameters(client: TestClient):
scenario_payload = {
"name": f"StoredParams-{uuid4()}",
"description": "Stored parameter simulation",
}
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
assert scenario_resp.status_code == 200
scenario_id = scenario_resp.json()["id"]
parameter_payload = {
"scenario_id": scenario_id,
"name": "grade",
"value": 1.5,
}
param_resp = client.post("/api/parameters/", json=parameter_payload)
assert param_resp.status_code == 200
resp = client.post(
"/api/simulations/run",
json={"scenario_id": scenario_id, "iterations": 3, "seed": 7},
)
assert resp.status_code == 200
data = resp.json()
assert data["summary"]["count"] == 3
assert len(data["results"]) == 3