feat: Enhance dashboard metrics and summary statistics
- Added new summary fields: variance, 5th percentile, 95th percentile, VaR (95%), and expected shortfall (95%) to the dashboard. - Updated the display logic for summary metrics to handle non-finite values gracefully. - Modified the chart rendering to include additional percentile points and tail risk metrics in tooltips. test: Introduce unit tests for consumption, costs, and other modules - Created a comprehensive test suite for consumption, costs, equipment, maintenance, production, reporting, and simulation modules. - Implemented fixtures for database setup and teardown using an in-memory SQLite database for isolated testing. - Added tests for creating, listing, and validating various entities, ensuring proper error handling and response validation. refactor: Consolidate parameter tests and remove deprecated files - Merged parameter-related tests into a new test file for better organization and clarity. - Removed the old parameter test file that was no longer in use. - Improved test coverage for parameter creation, listing, and validation scenarios. fix: Ensure proper validation and error handling in API endpoints - Added validation to reject negative amounts in consumption and production records. - Implemented checks to prevent duplicate scenario creation and ensure proper error messages are returned. - Enhanced reporting endpoint tests to validate input formats and expected outputs.
This commit is contained in:
18
README.md
18
README.md
@@ -71,20 +71,21 @@ uvicorn main:app --reload
|
||||
- **API base URL**: `http://localhost:8000/api`
|
||||
- **Key routes**:
|
||||
- `POST /api/scenarios/` create scenarios
|
||||
- `POST /api/parameters/` manage process parameters
|
||||
- `POST /api/parameters/` manage process parameters; payload supports optional `distribution_id` or inline `distribution_type`/`distribution_parameters` fields for simulation metadata
|
||||
- `POST /api/costs/capex` and `POST /api/costs/opex` capture project costs
|
||||
- `POST /api/consumption/` add consumption entries
|
||||
- `POST /api/production/` register production output
|
||||
- `POST /api/equipment/` create equipment records
|
||||
- `POST /api/maintenance/` log maintenance events
|
||||
- `POST /api/reporting/summary` aggregate simulation results
|
||||
- `POST /api/reporting/summary` aggregate simulation results, returning count, mean/median, min/max, standard deviation, variance, percentile bands (5/10/90/95), value-at-risk (95%) and expected shortfall (95%)
|
||||
|
||||
### Dashboard Preview
|
||||
|
||||
1. Start the FastAPI server and navigate to `/dashboard` (served by `templates/Dashboard.html`).
|
||||
1. Start the FastAPI server and navigate to `/ui/dashboard` (ensure `routes/ui.py` exposes this template or add a router that serves `templates/Dashboard.html`).
|
||||
2. Use the "Load Sample Data" button to populate the JSON textarea with demo results.
|
||||
3. Select "Refresh Dashboard" to view calculated statistics and a distribution chart sourced from `/api/reporting/summary`.
|
||||
4. Paste your own simulation outputs (array of objects containing a numeric `result` property) to visualize custom runs.
|
||||
3. Select "Refresh Dashboard" to post the dataset to `/api/reporting/summary` and render the returned statistics and distribution chart.
|
||||
4. Paste your own simulation outputs (array of objects containing a numeric `result` property) to visualize custom runs; the endpoint expects the same schema used by the reporting service.
|
||||
5. If the summary endpoint is unavailable, the dashboard displays an inline error—refresh once the API is reachable.
|
||||
|
||||
## Testing
|
||||
|
||||
@@ -96,6 +97,13 @@ To execute the unit test suite:
|
||||
pytest
|
||||
```
|
||||
|
||||
### Coverage Snapshot (2025-10-20)
|
||||
|
||||
- `pytest --cov=. --cov-report=term-missing` reports **95%** overall coverage across the project.
|
||||
- Lower coverage hotspots to target next: `services/simulation.py` (79%), `middleware/validation.py` (78%), `routes/ui.py` (82%), and several API routers around lines 12-22 that create database sessions only.
|
||||
- Deprecation cleanup migrated routes to Pydantic v2 patterns (`model_config = ConfigDict(...)`, `model_dump()`) and updated SQLAlchemy's `declarative_base`; reran `pytest` to confirm the suite passes without warnings.
|
||||
- Coverage for route-heavy modules is primarily limited by error paths (e.g., bad request branches) that still need explicit tests.
|
||||
|
||||
## Database Objects
|
||||
|
||||
The database is composed of several tables that store different types of information.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
|
||||
CalMiner is a FastAPI application that collects mining project inputs, persists scenario-specific records, and surfaces aggregated insights. The platform targets Monte Carlo driven planning, with deterministic CRUD features in place and simulation logic staged for future work.
|
||||
|
||||
Frontend components are server-rendered Jinja2 templates, with Chart.js powering the dashboard visualization.
|
||||
|
||||
The backend leverages SQLAlchemy for ORM mapping to a PostgreSQL database.
|
||||
|
||||
## System Components
|
||||
|
||||
- **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns.
|
||||
@@ -18,14 +22,54 @@ CalMiner is a FastAPI application that collects mining project inputs, persists
|
||||
1. Users navigate to form templates or API clients to manage scenarios, parameters, and operational data.
|
||||
2. FastAPI routers validate payloads with Pydantic models, then delegate to SQLAlchemy sessions for persistence.
|
||||
3. Simulation runs (placeholder `services/simulation.py`) will consume stored parameters to emit iteration results via `/api/simulations/run`.
|
||||
4. Reporting requests POST simulation outputs to `/api/reporting/summary`; the reporting service calculates aggregates (count, min/max, mean, median, percentiles, standard deviation).
|
||||
4. Reporting requests POST simulation outputs to `/api/reporting/summary`; the reporting service calculates aggregates (count, min/max, mean, median, percentiles, standard deviation, variance, and tail-risk metrics at the 95% confidence level).
|
||||
5. `templates/Dashboard.html` fetches summaries, renders metric cards, and plots distribution charts with Chart.js for stakeholder review.
|
||||
|
||||
### Dashboard Flow Review — 2025-10-20
|
||||
|
||||
- The dashboard template depends on a future-facing HTML endpoint (e.g., `/dashboard`) that the current `routes/ui.py` router does not expose; wiring an explicit route is required before the page is reachable from the FastAPI app.
|
||||
- Client-side logic calls `/api/reporting/summary` with raw simulation outputs and expects `result` fields, so any upstream changes to the reporting contract must maintain this schema.
|
||||
- Initialization always loads the bundled sample data first, which is useful for demos but masks server errors—consider adding a failure banner when `/api/reporting/summary` is unavailable.
|
||||
- No persistent storage backs the dashboard yet; users must paste or load JSON manually, aligning with the current MVP scope but highlighting an integration gap with the simulation results table.
|
||||
|
||||
### Reporting Pipeline and UI Integration
|
||||
|
||||
1. **Data Sources**
|
||||
|
||||
- Scenario-linked calculations (costs, consumption, production) produce raw figures stored in dedicated tables (`capex`, `opex`, `consumption`, `production_output`).
|
||||
- Monte Carlo simulations (currently transient) generate arrays of `{ "result": float }` tuples that the dashboard or downstream tooling passes directly to reporting endpoints.
|
||||
|
||||
2. **API Contract**
|
||||
|
||||
- `POST /api/reporting/summary` accepts a JSON array of result objects and validates shape through `_validate_payload` in `routes/reporting.py`.
|
||||
- On success it returns a structured payload (`ReportSummary`) containing count, mean, median, min/max, standard deviation, and percentile values, all as floats.
|
||||
|
||||
3. **Service Layer**
|
||||
|
||||
- `services/reporting.generate_report` converts the sanitized payload into descriptive statistics using Python’s standard library (`statistics` module) to avoid external dependencies.
|
||||
- The service remains stateless; no database read/write occurs, which keeps summary calculations deterministic and idempotent.
|
||||
- Extended KPIs (surfaced in the API and dashboard):
|
||||
- `variance`: population variance computed as the square of the population standard deviation.
|
||||
- `percentile_5` and `percentile_95`: lower and upper tail interpolated percentiles for sensitivity bounds.
|
||||
- `value_at_risk_95`: 5th percentile threshold representing the minimum outcome within a 95% confidence band.
|
||||
- `expected_shortfall_95`: mean of all outcomes at or below the `value_at_risk_95`, highlighting tail exposure.
|
||||
|
||||
4. **UI Consumption**
|
||||
|
||||
- `templates/Dashboard.html` posts the user-provided dataset to the summary endpoint, renders metric cards for each field, and charts the distribution using Chart.js.
|
||||
- `SUMMARY_FIELDS` now includes variance, 5th/10th/90th/95th percentiles, and tail-risk metrics (VaR/Expected Shortfall at 95%); tooltip annotations surface the tail metrics alongside the percentile line chart.
|
||||
- Error handling surfaces HTTP failures inline so users can address malformed JSON or backend availability issues without leaving the page.
|
||||
|
||||
5. **Future Integration Points**
|
||||
- Once `/api/simulations/run` persists to `simulation_result`, the dashboard can fetch precalculated runs per scenario, removing the manual JSON step.
|
||||
- Additional reporting endpoints (e.g., scenario comparisons) can reuse the same service layer, ensuring consistency across UI and API consumers.
|
||||
|
||||
## Data Model Highlights
|
||||
|
||||
- `scenario`: central entity describing a mining scenario; owns relationships to cost, consumption, production, equipment, and maintenance tables.
|
||||
- `capex`, `opex`: monetary tracking linked to scenarios.
|
||||
- `consumption`: resource usage entries parameterized by scenario and description.
|
||||
- `parameter`: scenario inputs with base `value` and optional distribution linkage via `distribution_id`, `distribution_type`, and JSON `distribution_parameters` to support simulation sampling.
|
||||
- `production_output`: production metrics per scenario.
|
||||
- `equipment` and `maintenance`: equipment inventory and maintenance events with dates/costs.
|
||||
- `simulation_result`: staging table for future Monte Carlo outputs (not yet populated by `run_simulation`).
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
from sqlalchemy import Column, Integer, String, Float, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from config.database import Base
|
||||
|
||||
|
||||
class Parameter(Base):
|
||||
__tablename__ = "parameter"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
value = Column(Float, nullable=False)
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
scenario_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("scenario.id"), nullable=False)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
value: Mapped[float] = mapped_column(nullable=False)
|
||||
distribution_id: Mapped[Optional[int]] = mapped_column(
|
||||
ForeignKey("distribution.id"), nullable=True)
|
||||
distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True)
|
||||
distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column(
|
||||
JSON, nullable=True)
|
||||
|
||||
scenario = relationship("Scenario", back_populates="parameters")
|
||||
distribution = relationship("Distribution")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Parameter id={self.id} name={self.name} value={self.value}>"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from pydantic import BaseModel, PositiveFloat
|
||||
from pydantic import BaseModel, ConfigDict, PositiveFloat
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
@@ -31,14 +31,12 @@ class ConsumptionCreate(ConsumptionBase):
|
||||
|
||||
class ConsumptionRead(ConsumptionBase):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)):
|
||||
db_item = Consumption(**item.dict())
|
||||
db_item = Consumption(**item.model_dump())
|
||||
db.add(db_item)
|
||||
db.commit()
|
||||
db.refresh(db_item)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from config.database import SessionLocal
|
||||
from models.capex import Capex
|
||||
from models.opex import Opex
|
||||
@@ -26,9 +26,7 @@ class CapexCreate(BaseModel):
|
||||
|
||||
class CapexRead(CapexCreate):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Pydantic schemas for Opex
|
||||
@@ -40,15 +38,13 @@ class OpexCreate(BaseModel):
|
||||
|
||||
class OpexRead(OpexCreate):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Capex endpoints
|
||||
@router.post("/capex", response_model=CapexRead)
|
||||
def create_capex(item: CapexCreate, db: Session = Depends(get_db)):
|
||||
db_item = Capex(**item.dict())
|
||||
db_item = Capex(**item.model_dump())
|
||||
db.add(db_item)
|
||||
db.commit()
|
||||
db.refresh(db_item)
|
||||
@@ -63,7 +59,7 @@ def list_capex(db: Session = Depends(get_db)):
|
||||
# Opex endpoints
|
||||
@router.post("/opex", response_model=OpexRead)
|
||||
def create_opex(item: OpexCreate, db: Session = Depends(get_db)):
|
||||
db_item = Opex(**item.dict())
|
||||
db_item = Opex(**item.model_dump())
|
||||
db.add(db_item)
|
||||
db.commit()
|
||||
db.refresh(db_item)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from config.database import SessionLocal
|
||||
from models.distribution import Distribution
|
||||
|
||||
@@ -24,14 +24,12 @@ class DistributionCreate(BaseModel):
|
||||
|
||||
class DistributionRead(DistributionCreate):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/", response_model=DistributionRead)
|
||||
async def create_distribution(dist: DistributionCreate, db: Session = Depends(get_db)):
|
||||
db_dist = Distribution(**dist.dict())
|
||||
db_dist = Distribution(**dist.model_dump())
|
||||
db.add(db_dist)
|
||||
db.commit()
|
||||
db.refresh(db_dist)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from config.database import SessionLocal
|
||||
from models.equipment import Equipment
|
||||
|
||||
@@ -25,14 +25,12 @@ class EquipmentCreate(BaseModel):
|
||||
|
||||
class EquipmentRead(EquipmentCreate):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/", response_model=EquipmentRead)
|
||||
async def create_equipment(item: EquipmentCreate, db: Session = Depends(get_db)):
|
||||
db_item = Equipment(**item.dict())
|
||||
db_item = Equipment(**item.model_dump())
|
||||
db.add(db_item)
|
||||
db.commit()
|
||||
db.refresh(db_item)
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import date
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, PositiveFloat
|
||||
from pydantic import BaseModel, ConfigDict, PositiveFloat
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
@@ -38,9 +38,7 @@ class MaintenanceUpdate(MaintenanceBase):
|
||||
|
||||
class MaintenanceRead(MaintenanceBase):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance:
|
||||
@@ -56,7 +54,7 @@ def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance:
|
||||
|
||||
@router.post("/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get_db)):
|
||||
db_maintenance = Maintenance(**maintenance.dict())
|
||||
db_maintenance = Maintenance(**maintenance.model_dump())
|
||||
db.add(db_maintenance)
|
||||
db.commit()
|
||||
db.refresh(db_maintenance)
|
||||
@@ -80,7 +78,7 @@ def update_maintenance(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
db_maintenance = _get_maintenance_or_404(db, maintenance_id)
|
||||
for field, value in payload.dict().items():
|
||||
for field, value in payload.model_dump().items():
|
||||
setattr(db_maintenance, field, value)
|
||||
db.commit()
|
||||
db.refresh(db_maintenance)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
from models.distribution import Distribution
|
||||
from models.parameters import Parameter
|
||||
from models.scenario import Scenario
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
|
||||
router = APIRouter(prefix="/api/parameters", tags=["parameters"])
|
||||
|
||||
@@ -13,13 +16,34 @@ class ParameterCreate(BaseModel):
|
||||
scenario_id: int
|
||||
name: str
|
||||
value: float
|
||||
distribution_id: Optional[int] = None
|
||||
distribution_type: Optional[str] = None
|
||||
distribution_parameters: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("distribution_type")
|
||||
@classmethod
|
||||
def normalize_type(cls, value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
return None
|
||||
if normalized not in {"normal", "uniform", "triangular"}:
|
||||
raise ValueError(
|
||||
"distribution_type must be normal, uniform, or triangular")
|
||||
return normalized
|
||||
|
||||
@field_validator("distribution_parameters")
|
||||
@classmethod
|
||||
def empty_dict_to_none(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if value is None:
|
||||
return None
|
||||
return value or None
|
||||
|
||||
|
||||
class ParameterRead(ParameterCreate):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# Dependency
|
||||
|
||||
@@ -37,8 +61,27 @@ def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
|
||||
scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first()
|
||||
if not scen:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
new_param = Parameter(scenario_id=param.scenario_id,
|
||||
name=param.name, value=param.value)
|
||||
distribution_id = param.distribution_id
|
||||
distribution_type = param.distribution_type
|
||||
distribution_parameters = param.distribution_parameters
|
||||
|
||||
if distribution_id is not None:
|
||||
distribution = db.query(Distribution).filter(
|
||||
Distribution.id == distribution_id).first()
|
||||
if not distribution:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Distribution not found")
|
||||
distribution_type = distribution.distribution_type
|
||||
distribution_parameters = distribution.parameters or None
|
||||
|
||||
new_param = Parameter(
|
||||
scenario_id=param.scenario_id,
|
||||
name=param.name,
|
||||
value=param.value,
|
||||
distribution_id=distribution_id,
|
||||
distribution_type=distribution_type,
|
||||
distribution_parameters=distribution_parameters,
|
||||
)
|
||||
db.add(new_param)
|
||||
db.commit()
|
||||
db.refresh(new_param)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from pydantic import BaseModel, PositiveFloat
|
||||
from pydantic import BaseModel, ConfigDict, PositiveFloat
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
@@ -31,14 +31,12 @@ class ProductionOutputCreate(ProductionOutputBase):
|
||||
|
||||
class ProductionOutputRead(ProductionOutputBase):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/", response_model=ProductionOutputRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_production(item: ProductionOutputCreate, db: Session = Depends(get_db)):
|
||||
db_item = ProductionOutput(**item.dict())
|
||||
db_item = ProductionOutput(**item.model_dump())
|
||||
db.add(db_item)
|
||||
db.commit()
|
||||
db.refresh(db_item)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, cast
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
@@ -25,14 +25,16 @@ def _validate_payload(payload: Any) -> List[Dict[str, float]]:
|
||||
detail="Invalid input format",
|
||||
)
|
||||
|
||||
typed_payload = cast(List[Any], payload)
|
||||
|
||||
validated: List[Dict[str, float]] = []
|
||||
for index, item in enumerate(payload):
|
||||
for index, item in enumerate(typed_payload):
|
||||
if not isinstance(item, dict):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Entry at index {index} must be an object",
|
||||
)
|
||||
value = item.get("result")
|
||||
value = cast(Dict[str, Any], item).get("result")
|
||||
if not isinstance(value, (int, float)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -49,8 +51,13 @@ class ReportSummary(BaseModel):
|
||||
min: float
|
||||
max: float
|
||||
std_dev: float
|
||||
variance: float
|
||||
percentile_10: float
|
||||
percentile_90: float
|
||||
percentile_5: float
|
||||
percentile_95: float
|
||||
value_at_risk_95: float
|
||||
expected_shortfall_95: float
|
||||
|
||||
|
||||
@router.post("/summary", response_model=ReportSummary)
|
||||
@@ -65,6 +72,11 @@ async def summary_report(request: Request):
|
||||
min=float(summary["min"]),
|
||||
max=float(summary["max"]),
|
||||
std_dev=float(summary["std_dev"]),
|
||||
variance=float(summary["variance"]),
|
||||
percentile_10=float(summary["percentile_10"]),
|
||||
percentile_90=float(summary["percentile_90"]),
|
||||
percentile_5=float(summary["percentile_5"]),
|
||||
percentile_95=float(summary["percentile_95"]),
|
||||
value_at_risk_95=float(summary["value_at_risk_95"]),
|
||||
expected_shortfall_95=float(summary["expected_shortfall_95"]),
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from config.database import SessionLocal
|
||||
from models.scenario import Scenario
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
@@ -20,9 +20,7 @@ class ScenarioRead(ScenarioCreate):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# Dependency
|
||||
|
||||
|
||||
@@ -60,8 +60,8 @@ def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterI
|
||||
)
|
||||
return [
|
||||
SimulationParameterInput(
|
||||
name=cast(str, item.name),
|
||||
value=cast(float, item.value),
|
||||
name=item.name,
|
||||
value=item.value,
|
||||
)
|
||||
for item in db_params
|
||||
]
|
||||
@@ -86,7 +86,7 @@ async def simulate(payload: SimulationRunRequest, db: Session = Depends(get_db))
|
||||
)
|
||||
|
||||
raw_results = run_simulation(
|
||||
[param.dict(exclude_none=True) for param in parameters],
|
||||
[param.model_dump(exclude_none=True) for param in parameters],
|
||||
iterations=payload.iterations,
|
||||
seed=payload.seed,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from statistics import mean, median, pstdev
|
||||
from typing import Dict, Iterable, List, Union
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Union, cast
|
||||
|
||||
|
||||
def _extract_results(simulation_results: Iterable[Dict[str, float]]) -> List[float]:
|
||||
def _extract_results(simulation_results: Iterable[object]) -> List[float]:
|
||||
values: List[float] = []
|
||||
for item in simulation_results:
|
||||
if not isinstance(item, dict):
|
||||
if not isinstance(item, Mapping):
|
||||
continue
|
||||
value = item.get("result")
|
||||
mapping_item = cast(Mapping[str, Any], item)
|
||||
value = mapping_item.get("result")
|
||||
if isinstance(value, (int, float)):
|
||||
values.append(float(value))
|
||||
return values
|
||||
@@ -39,8 +40,13 @@ def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Uni
|
||||
"min": 0.0,
|
||||
"max": 0.0,
|
||||
"std_dev": 0.0,
|
||||
"variance": 0.0,
|
||||
"percentile_10": 0.0,
|
||||
"percentile_90": 0.0,
|
||||
"percentile_5": 0.0,
|
||||
"percentile_95": 0.0,
|
||||
"value_at_risk_95": 0.0,
|
||||
"expected_shortfall_95": 0.0,
|
||||
}
|
||||
|
||||
summary: Dict[str, Union[float, int]] = {
|
||||
@@ -51,7 +57,21 @@ def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Uni
|
||||
"max": max(values),
|
||||
"percentile_10": _percentile(values, 10),
|
||||
"percentile_90": _percentile(values, 90),
|
||||
"percentile_5": _percentile(values, 5),
|
||||
"percentile_95": _percentile(values, 95),
|
||||
}
|
||||
|
||||
summary["std_dev"] = pstdev(values) if len(values) > 1 else 0.0
|
||||
std_dev = pstdev(values) if len(values) > 1 else 0.0
|
||||
summary["std_dev"] = std_dev
|
||||
summary["variance"] = std_dev ** 2
|
||||
|
||||
var_95 = summary["percentile_5"]
|
||||
summary["value_at_risk_95"] = var_95
|
||||
|
||||
tail_values = [value for value in values if value <= var_95]
|
||||
if tail_values:
|
||||
summary["expected_shortfall_95"] = mean(tail_values)
|
||||
else:
|
||||
summary["expected_shortfall_95"] = var_95
|
||||
|
||||
return summary
|
||||
|
||||
@@ -87,8 +87,16 @@
|
||||
{ key: "min", label: "Min" },
|
||||
{ key: "max", label: "Max" },
|
||||
{ key: "std_dev", label: "Std Dev" },
|
||||
{ key: "variance", label: "Variance" },
|
||||
{ key: "percentile_5", label: "5th Percentile" },
|
||||
{ key: "percentile_10", label: "10th Percentile" },
|
||||
{ key: "percentile_90", label: "90th Percentile" },
|
||||
{ key: "percentile_95", label: "95th Percentile" },
|
||||
{ key: "value_at_risk_95", label: "VaR (95%)" },
|
||||
{
|
||||
key: "expected_shortfall_95",
|
||||
label: "Expected Shortfall (95%)",
|
||||
},
|
||||
];
|
||||
|
||||
async function fetchSummary(results) {
|
||||
@@ -123,12 +131,16 @@
|
||||
const grid = document.getElementById("summary-grid");
|
||||
grid.innerHTML = "";
|
||||
SUMMARY_FIELDS.forEach(({ key, label }) => {
|
||||
const value = summary[key] ?? 0;
|
||||
const rawValue = summary[key];
|
||||
const numericValue = Number(rawValue);
|
||||
const display = Number.isFinite(numericValue)
|
||||
? numericValue.toFixed(2)
|
||||
: "—";
|
||||
const metric = document.createElement("div");
|
||||
metric.className = "metric";
|
||||
metric.innerHTML = `
|
||||
<div class="metric-label">${label}</div>
|
||||
<div class="metric-value">${value.toFixed(2)}</div>
|
||||
<div class="metric-value">${display}</div>
|
||||
`;
|
||||
grid.appendChild(metric);
|
||||
});
|
||||
@@ -138,14 +150,34 @@
|
||||
|
||||
function renderChart(summary) {
|
||||
const ctx = document.getElementById("summary-chart").getContext("2d");
|
||||
const dataPoints = [
|
||||
summary.min,
|
||||
summary.percentile_10,
|
||||
summary.median,
|
||||
summary.mean,
|
||||
summary.percentile_90,
|
||||
summary.max,
|
||||
].map((value) => Number(value ?? 0));
|
||||
const percentilePoints = [
|
||||
{ label: "Min", value: summary.min },
|
||||
{ label: "P5", value: summary.percentile_5 },
|
||||
{ label: "P10", value: summary.percentile_10 },
|
||||
{ label: "Median", value: summary.median },
|
||||
{ label: "Mean", value: summary.mean },
|
||||
{ label: "P90", value: summary.percentile_90 },
|
||||
{ label: "P95", value: summary.percentile_95 },
|
||||
{ label: "Max", value: summary.max },
|
||||
];
|
||||
|
||||
const labels = percentilePoints.map((point) => point.label);
|
||||
const dataPoints = percentilePoints.map((point) =>
|
||||
Number(point.value ?? 0)
|
||||
);
|
||||
|
||||
const tailRiskLines = [
|
||||
{ label: "VaR (95%)", value: summary.value_at_risk_95 },
|
||||
{ label: "ES (95%)", value: summary.expected_shortfall_95 },
|
||||
]
|
||||
.map(({ label, value }) => {
|
||||
const numeric = Number(value);
|
||||
if (!Number.isFinite(numeric)) {
|
||||
return null;
|
||||
}
|
||||
return `${label}: ${numeric.toFixed(2)}`;
|
||||
})
|
||||
.filter((line) => line !== null);
|
||||
|
||||
if (chartInstance) {
|
||||
chartInstance.destroy();
|
||||
@@ -154,7 +186,7 @@
|
||||
chartInstance = new Chart(ctx, {
|
||||
type: "line",
|
||||
data: {
|
||||
labels: ["Min", "P10", "Median", "Mean", "P90", "Max"],
|
||||
labels,
|
||||
datasets: [
|
||||
{
|
||||
label: "Result Summary",
|
||||
@@ -169,6 +201,11 @@
|
||||
options: {
|
||||
plugins: {
|
||||
legend: { display: false },
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
afterBody: () => tailRiskLines,
|
||||
},
|
||||
},
|
||||
},
|
||||
scales: {
|
||||
y: {
|
||||
|
||||
94
tests/unit/conftest.py
Normal file
94
tests/unit/conftest.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from config.database import Base
|
||||
from main import app
|
||||
|
||||
SQLALCHEMY_TEST_URL = "sqlite:///:memory:"
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_TEST_URL,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_database() -> Generator[None, None, None]:
|
||||
# Ensure all model metadata is registered before creating tables
|
||||
from models import (
|
||||
capex,
|
||||
consumption,
|
||||
distribution,
|
||||
equipment,
|
||||
maintenance,
|
||||
opex,
|
||||
parameters,
|
||||
production_output,
|
||||
scenario,
|
||||
simulation_result,
|
||||
) # noqa: F401 - imported for side effects
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
yield
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
session = TestingSessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def api_client(db_session: Session) -> Generator[TestClient, None, None]:
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
# override all routers that use get_db
|
||||
from routes import (
|
||||
consumption,
|
||||
costs,
|
||||
distributions,
|
||||
equipment,
|
||||
maintenance,
|
||||
parameters,
|
||||
production,
|
||||
reporting,
|
||||
scenarios,
|
||||
simulations,
|
||||
)
|
||||
|
||||
overrides = {
|
||||
consumption.get_db: override_get_db,
|
||||
costs.get_db: override_get_db,
|
||||
distributions.get_db: override_get_db,
|
||||
equipment.get_db: override_get_db,
|
||||
maintenance.get_db: override_get_db,
|
||||
parameters.get_db: override_get_db,
|
||||
production.get_db: override_get_db,
|
||||
reporting.get_db: override_get_db,
|
||||
scenarios.get_db: override_get_db,
|
||||
simulations.get_db: override_get_db,
|
||||
}
|
||||
|
||||
for dependency, override in overrides.items():
|
||||
app.dependency_overrides[dependency] = override
|
||||
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
for dependency in overrides:
|
||||
app.dependency_overrides.pop(dependency, None)
|
||||
@@ -1,42 +1,69 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
from config.database import Base, engine
|
||||
|
||||
# Setup and teardown
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@pytest.fixture
|
||||
def client(api_client: TestClient) -> TestClient:
|
||||
return api_client
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
def _create_scenario(client: TestClient) -> int:
|
||||
payload = {
|
||||
"name": f"Consumption Scenario {uuid4()}",
|
||||
"description": "Scenario for consumption tests",
|
||||
}
|
||||
response = client.post("/api/scenarios/", json=payload)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
def test_create_consumption(client: TestClient) -> None:
|
||||
scenario_id = _create_scenario(client)
|
||||
payload = {
|
||||
"scenario_id": scenario_id,
|
||||
"amount": 125.5,
|
||||
"description": "Fuel usage baseline",
|
||||
}
|
||||
|
||||
response = client.post("/api/consumption/", json=payload)
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert body["id"] > 0
|
||||
assert body["scenario_id"] == scenario_id
|
||||
assert body["amount"] == pytest.approx(125.5)
|
||||
assert body["description"] == "Fuel usage baseline"
|
||||
|
||||
|
||||
def test_create_and_list_consumption():
|
||||
# Create a scenario to attach consumption
|
||||
resp = client.post(
|
||||
"/api/scenarios/", json={"name": "ConsScenario", "description": "consumption scenario"}
|
||||
def test_list_consumption_returns_created_items(client: TestClient) -> None:
|
||||
scenario_id = _create_scenario(client)
|
||||
values = [50.0, 80.75]
|
||||
for amount in values:
|
||||
response = client.post(
|
||||
"/api/consumption/",
|
||||
json={
|
||||
"scenario_id": scenario_id,
|
||||
"amount": amount,
|
||||
"description": f"Consumption {amount}",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
scenario = resp.json()
|
||||
sid = scenario["id"]
|
||||
assert response.status_code == 201
|
||||
|
||||
# Create Consumption item
|
||||
cons_payload = {"scenario_id": sid, "amount": 250.0,
|
||||
"description": "Monthly consumption"}
|
||||
resp2 = client.post("/api/consumption/", json=cons_payload)
|
||||
assert resp2.status_code == 201
|
||||
cons = resp2.json()
|
||||
assert cons["scenario_id"] == sid
|
||||
assert cons["amount"] == 250.0
|
||||
list_response = client.get("/api/consumption/")
|
||||
assert list_response.status_code == 200
|
||||
items = [item for item in list_response.json(
|
||||
) if item["scenario_id"] == scenario_id]
|
||||
assert {item["amount"] for item in items} == set(values)
|
||||
|
||||
# List Consumption items
|
||||
resp3 = client.get("/api/consumption/")
|
||||
assert resp3.status_code == 200
|
||||
data = resp3.json()
|
||||
assert any(item["amount"] == 250.0 and item["scenario_id"]
|
||||
== sid for item in data)
|
||||
|
||||
def test_create_consumption_rejects_negative_amount(client: TestClient) -> None:
|
||||
scenario_id = _create_scenario(client)
|
||||
payload = {
|
||||
"scenario_id": scenario_id,
|
||||
"amount": -10,
|
||||
"description": "Invalid negative amount",
|
||||
}
|
||||
|
||||
response = client.post("/api/consumption/", json=payload)
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
from config.database import Base, engine
|
||||
from uuid import uuid4
|
||||
|
||||
# Setup and teardown
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from config.database import Base, engine
|
||||
from main import app
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
@@ -16,43 +17,89 @@ def teardown_module(module):
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_create_and_list_capex_and_opex():
|
||||
# Create a scenario to attach costs
|
||||
resp = client.post(
|
||||
"/api/scenarios/", json={"name": "CostScenario", "description": "cost scenario"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
scenario = resp.json()
|
||||
sid = scenario["id"]
|
||||
def _create_scenario() -> int:
|
||||
payload = {
|
||||
"name": f"CostScenario-{uuid4()}",
|
||||
"description": "Cost tracking test scenario",
|
||||
}
|
||||
response = client.post("/api/scenarios/", json=payload)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
# Create Capex item
|
||||
capex_payload = {"scenario_id": sid,
|
||||
"amount": 1000.0, "description": "Initial capex"}
|
||||
|
||||
def test_create_and_list_capex_and_opex():
|
||||
sid = _create_scenario()
|
||||
|
||||
capex_payload = {
|
||||
"scenario_id": sid,
|
||||
"amount": 1000.0,
|
||||
"description": "Initial capex",
|
||||
}
|
||||
resp2 = client.post("/api/costs/capex", json=capex_payload)
|
||||
assert resp2.status_code == 200
|
||||
capex = resp2.json()
|
||||
assert capex["scenario_id"] == sid
|
||||
assert capex["amount"] == 1000.0
|
||||
|
||||
# List Capex items
|
||||
resp3 = client.get("/api/costs/capex")
|
||||
assert resp3.status_code == 200
|
||||
data = resp3.json()
|
||||
assert any(item["amount"] == 1000.0 and item["scenario_id"]
|
||||
== sid for item in data)
|
||||
|
||||
# Create Opex item
|
||||
opex_payload = {"scenario_id": sid, "amount": 500.0,
|
||||
"description": "Recurring opex"}
|
||||
opex_payload = {
|
||||
"scenario_id": sid,
|
||||
"amount": 500.0,
|
||||
"description": "Recurring opex",
|
||||
}
|
||||
resp4 = client.post("/api/costs/opex", json=opex_payload)
|
||||
assert resp4.status_code == 200
|
||||
opex = resp4.json()
|
||||
assert opex["scenario_id"] == sid
|
||||
assert opex["amount"] == 500.0
|
||||
|
||||
# List Opex items
|
||||
resp5 = client.get("/api/costs/opex")
|
||||
assert resp5.status_code == 200
|
||||
data_o = resp5.json()
|
||||
assert any(item["amount"] == 500.0 and item["scenario_id"]
|
||||
== sid for item in data_o)
|
||||
|
||||
|
||||
def test_multiple_capex_entries():
|
||||
sid = _create_scenario()
|
||||
amounts = [250.0, 750.0]
|
||||
for amount in amounts:
|
||||
resp = client.post(
|
||||
"/api/costs/capex",
|
||||
json={"scenario_id": sid, "amount": amount,
|
||||
"description": f"Capex {amount}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = client.get("/api/costs/capex")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
retrieved_amounts = [item["amount"]
|
||||
for item in data if item["scenario_id"] == sid]
|
||||
for amount in amounts:
|
||||
assert amount in retrieved_amounts
|
||||
|
||||
|
||||
def test_multiple_opex_entries():
|
||||
sid = _create_scenario()
|
||||
amounts = [120.0, 340.0]
|
||||
for amount in amounts:
|
||||
resp = client.post(
|
||||
"/api/costs/opex",
|
||||
json={"scenario_id": sid, "amount": amount,
|
||||
"description": f"Opex {amount}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = client.get("/api/costs/opex")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
retrieved_amounts = [item["amount"]
|
||||
for item in data if item["scenario_id"] == sid]
|
||||
for amount in amounts:
|
||||
assert amount in retrieved_amounts
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
from config.database import Base, engine
|
||||
from models.distribution import Distribution
|
||||
from uuid import uuid4
|
||||
|
||||
# Setup and teardown
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from config.database import Base, engine
|
||||
from main import app
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
@@ -18,16 +18,54 @@ client = TestClient(app)
|
||||
|
||||
|
||||
def test_create_and_list_distribution():
|
||||
# Create distribution
|
||||
payload = {"name": "NormalDist", "distribution_type": "normal",
|
||||
"parameters": {"mu": 0, "sigma": 1}}
|
||||
dist_name = f"NormalDist-{uuid4()}"
|
||||
payload = {
|
||||
"name": dist_name,
|
||||
"distribution_type": "normal",
|
||||
"parameters": {"mu": 0, "sigma": 1},
|
||||
}
|
||||
resp = client.post("/api/distributions/", json=payload)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "NormalDist"
|
||||
assert data["name"] == dist_name
|
||||
|
||||
# List distributions
|
||||
resp2 = client.get("/api/distributions/")
|
||||
assert resp2.status_code == 200
|
||||
data2 = resp2.json()
|
||||
assert any(d["name"] == "NormalDist" for d in data2)
|
||||
assert any(d["name"] == dist_name for d in data2)
|
||||
|
||||
|
||||
def test_duplicate_distribution_name_allowed():
|
||||
dist_name = f"DupDist-{uuid4()}"
|
||||
payload = {
|
||||
"name": dist_name,
|
||||
"distribution_type": "uniform",
|
||||
"parameters": {"min": 0, "max": 1},
|
||||
}
|
||||
first = client.post("/api/distributions/", json=payload)
|
||||
assert first.status_code == 200
|
||||
|
||||
duplicate = client.post("/api/distributions/", json=payload)
|
||||
assert duplicate.status_code == 200
|
||||
|
||||
resp = client.get("/api/distributions/")
|
||||
assert resp.status_code == 200
|
||||
matching = [item for item in resp.json() if item["name"] == dist_name]
|
||||
assert len(matching) >= 2
|
||||
|
||||
|
||||
def test_list_distributions_returns_all():
|
||||
names = {f"ListDist-{uuid4()}" for _ in range(2)}
|
||||
for name in names:
|
||||
payload = {
|
||||
"name": name,
|
||||
"distribution_type": "triangular",
|
||||
"parameters": {"min": 0, "max": 10, "mode": 5},
|
||||
}
|
||||
resp = client.post("/api/distributions/", json=payload)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = client.get("/api/distributions/")
|
||||
assert resp.status_code == 200
|
||||
found_names = {item["name"] for item in resp.json()}
|
||||
assert names.issubset(found_names)
|
||||
|
||||
@@ -1,42 +1,77 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
from config.database import Base, engine
|
||||
|
||||
# Setup and teardown
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@pytest.fixture
|
||||
def client(api_client: TestClient) -> TestClient:
|
||||
return api_client
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
def _create_scenario(client: TestClient) -> int:
|
||||
payload = {
|
||||
"name": f"Equipment Scenario {uuid4()}",
|
||||
"description": "Scenario for equipment tests",
|
||||
}
|
||||
response = client.post("/api/scenarios/", json=payload)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
def test_create_equipment(client: TestClient) -> None:
|
||||
scenario_id = _create_scenario(client)
|
||||
payload = {
|
||||
"scenario_id": scenario_id,
|
||||
"name": "Excavator",
|
||||
"description": "Heavy machinery",
|
||||
}
|
||||
|
||||
response = client.post("/api/equipment/", json=payload)
|
||||
assert response.status_code == 200
|
||||
created = response.json()
|
||||
assert created["id"] > 0
|
||||
assert created["scenario_id"] == scenario_id
|
||||
assert created["name"] == "Excavator"
|
||||
assert created["description"] == "Heavy machinery"
|
||||
|
||||
|
||||
def test_create_and_list_equipment():
|
||||
# Create a scenario to attach equipment
|
||||
resp = client.post(
|
||||
"/api/scenarios/", json={"name": "EquipScenario", "description": "equipment scenario"}
|
||||
def test_list_equipment_filters_by_scenario(client: TestClient) -> None:
|
||||
target_scenario = _create_scenario(client)
|
||||
other_scenario = _create_scenario(client)
|
||||
|
||||
for scenario_id, name in [
|
||||
(target_scenario, "Bulldozer"),
|
||||
(target_scenario, "Loader"),
|
||||
(other_scenario, "Conveyor"),
|
||||
]:
|
||||
response = client.post(
|
||||
"/api/equipment/",
|
||||
json={
|
||||
"scenario_id": scenario_id,
|
||||
"name": name,
|
||||
"description": f"Equipment {name}",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
scenario = resp.json()
|
||||
sid = scenario["id"]
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create Equipment item
|
||||
eq_payload = {"scenario_id": sid, "name": "Excavator",
|
||||
"description": "Heavy machinery"}
|
||||
resp2 = client.post("/api/equipment/", json=eq_payload)
|
||||
assert resp2.status_code == 200
|
||||
eq = resp2.json()
|
||||
assert eq["scenario_id"] == sid
|
||||
assert eq["name"] == "Excavator"
|
||||
list_response = client.get("/api/equipment/")
|
||||
assert list_response.status_code == 200
|
||||
items = [
|
||||
item
|
||||
for item in list_response.json()
|
||||
if item["scenario_id"] == target_scenario
|
||||
]
|
||||
assert {item["name"] for item in items} == {"Bulldozer", "Loader"}
|
||||
|
||||
# List Equipment items
|
||||
resp3 = client.get("/api/equipment/")
|
||||
assert resp3.status_code == 200
|
||||
data = resp3.json()
|
||||
assert any(item["name"] == "Excavator" and item["scenario_id"]
|
||||
== sid for item in data)
|
||||
|
||||
def test_create_equipment_requires_name(client: TestClient) -> None:
|
||||
scenario_id = _create_scenario(client)
|
||||
response = client.post(
|
||||
"/api/equipment/",
|
||||
json={
|
||||
"scenario_id": scenario_id,
|
||||
"description": "Missing name",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -1,75 +1,16 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from main import app
|
||||
from routes import (
|
||||
consumption,
|
||||
costs,
|
||||
distributions,
|
||||
equipment,
|
||||
maintenance,
|
||||
parameters,
|
||||
production,
|
||||
reporting,
|
||||
scenarios,
|
||||
simulations,
|
||||
ui,
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={
|
||||
"check_same_thread": False})
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@pytest.fixture
|
||||
def client(api_client: TestClient) -> TestClient:
|
||||
return api_client
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
def override_get_db():
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
app.dependency_overrides[maintenance.get_db] = override_get_db
|
||||
app.dependency_overrides[equipment.get_db] = override_get_db
|
||||
app.dependency_overrides[scenarios.get_db] = override_get_db
|
||||
app.dependency_overrides[distributions.get_db] = override_get_db
|
||||
app.dependency_overrides[parameters.get_db] = override_get_db
|
||||
app.dependency_overrides[costs.get_db] = override_get_db
|
||||
app.dependency_overrides[consumption.get_db] = override_get_db
|
||||
app.dependency_overrides[production.get_db] = override_get_db
|
||||
app.dependency_overrides[reporting.get_db] = override_get_db
|
||||
app.dependency_overrides[simulations.get_db] = override_get_db
|
||||
|
||||
app.include_router(maintenance.router)
|
||||
app.include_router(equipment.router)
|
||||
app.include_router(scenarios.router)
|
||||
app.include_router(distributions.router)
|
||||
app.include_router(ui.router)
|
||||
app.include_router(parameters.router)
|
||||
app.include_router(costs.router)
|
||||
app.include_router(consumption.router)
|
||||
app.include_router(production.router)
|
||||
app.include_router(reporting.router)
|
||||
app.include_router(simulations.router)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def _create_scenario_and_equipment():
|
||||
def _create_scenario_and_equipment(client: TestClient):
|
||||
scenario_payload = {
|
||||
"name": f"Test Scenario {uuid4()}",
|
||||
"description": "Scenario for maintenance tests",
|
||||
@@ -99,8 +40,8 @@ def _create_maintenance_payload(equipment_id: int, scenario_id: int, description
|
||||
}
|
||||
|
||||
|
||||
def test_create_and_list_maintenance():
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment()
|
||||
def test_create_and_list_maintenance(client: TestClient):
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment(client)
|
||||
payload = _create_maintenance_payload(
|
||||
equipment_id, scenario_id, "Create maintenance")
|
||||
|
||||
@@ -117,8 +58,8 @@ def test_create_and_list_maintenance():
|
||||
assert any(item["id"] == created["id"] for item in items)
|
||||
|
||||
|
||||
def test_get_maintenance():
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment()
|
||||
def test_get_maintenance(client: TestClient):
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment(client)
|
||||
payload = _create_maintenance_payload(
|
||||
equipment_id, scenario_id, "Retrieve maintenance"
|
||||
)
|
||||
@@ -134,8 +75,8 @@ def test_get_maintenance():
|
||||
assert data["description"] == "Retrieve maintenance"
|
||||
|
||||
|
||||
def test_update_maintenance():
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment()
|
||||
def test_update_maintenance(client: TestClient):
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment(client)
|
||||
create_response = client.post(
|
||||
"/api/maintenance/",
|
||||
json=_create_maintenance_payload(
|
||||
@@ -162,8 +103,8 @@ def test_update_maintenance():
|
||||
assert updated["cost"] == 250.0
|
||||
|
||||
|
||||
def test_delete_maintenance():
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment()
|
||||
def test_delete_maintenance(client: TestClient):
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment(client)
|
||||
create_response = client.post(
|
||||
"/api/maintenance/",
|
||||
json=_create_maintenance_payload(
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
from models.scenario import Scenario
|
||||
from main import app
|
||||
from config.database import Base, engine
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
# Setup and teardown
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Helper to create a scenario
|
||||
|
||||
|
||||
def create_test_scenario():
|
||||
resp = client.post("/api/scenarios/",
|
||||
json={"name": "ParamTest", "description": "Desc"})
|
||||
assert resp.status_code == 200
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
def test_create_and_list_parameter():
|
||||
# Ensure scenario exists
|
||||
scen_id = create_test_scenario()
|
||||
# Create a parameter
|
||||
resp = client.post(
|
||||
"/api/parameters/", json={"scenario_id": scen_id, "name": "param1", "value": 3.14})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "param1"
|
||||
# List parameters
|
||||
resp2 = client.get("/api/parameters/")
|
||||
assert resp2.status_code == 200
|
||||
data2 = resp2.json()
|
||||
assert any(p["name"] == "param1" for p in data2)
|
||||
123
tests/unit/test_parameters.py
Normal file
123
tests/unit/test_parameters.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Any, Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from config.database import Base, engine
|
||||
from main import app
|
||||
|
||||
|
||||
def setup_module(module: object) -> None:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def teardown_module(module: object) -> None:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
def _create_scenario(name: str | None = None) -> int:
|
||||
payload: Dict[str, Any] = {
|
||||
"name": name or f"ParamScenario-{uuid4()}",
|
||||
"description": "Parameter test scenario",
|
||||
}
|
||||
response = TestClient(app).post("/api/scenarios/", json=payload)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def _create_distribution() -> int:
|
||||
payload: Dict[str, Any] = {
|
||||
"name": f"NormalDist-{uuid4()}",
|
||||
"distribution_type": "normal",
|
||||
"parameters": {"mu": 10, "sigma": 2},
|
||||
}
|
||||
response = TestClient(app).post("/api/distributions/", json=payload)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_create_and_list_parameter():
|
||||
scenario_id = _create_scenario()
|
||||
distribution_id = _create_distribution()
|
||||
parameter_payload: Dict[str, Any] = {
|
||||
"scenario_id": scenario_id,
|
||||
"name": f"param-{uuid4()}",
|
||||
"value": 3.14,
|
||||
"distribution_id": distribution_id,
|
||||
}
|
||||
|
||||
create_response = client.post("/api/parameters/", json=parameter_payload)
|
||||
assert create_response.status_code == 200
|
||||
created = create_response.json()
|
||||
assert created["scenario_id"] == scenario_id
|
||||
assert created["name"] == parameter_payload["name"]
|
||||
assert created["value"] == parameter_payload["value"]
|
||||
assert created["distribution_id"] == distribution_id
|
||||
assert created["distribution_type"] == "normal"
|
||||
assert created["distribution_parameters"] == {"mu": 10, "sigma": 2}
|
||||
|
||||
list_response = client.get("/api/parameters/")
|
||||
assert list_response.status_code == 200
|
||||
params = list_response.json()
|
||||
assert any(p["id"] == created["id"] for p in params)
|
||||
|
||||
|
||||
def test_create_parameter_for_missing_scenario():
|
||||
payload: Dict[str, Any] = {
|
||||
"scenario_id": 0, "name": "invalid", "value": 1.0}
|
||||
response = client.post("/api/parameters/", json=payload)
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Scenario not found"
|
||||
|
||||
|
||||
def test_multiple_parameters_listed():
|
||||
scenario_id = _create_scenario()
|
||||
payloads: List[Dict[str, Any]] = [
|
||||
{"scenario_id": scenario_id, "name": f"alpha-{i}", "value": float(i)}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
for payload in payloads:
|
||||
resp = client.post("/api/parameters/", json=payload)
|
||||
assert resp.status_code == 200
|
||||
|
||||
list_response = client.get("/api/parameters/")
|
||||
assert list_response.status_code == 200
|
||||
names = {item["name"] for item in list_response.json()}
|
||||
for payload in payloads:
|
||||
assert payload["name"] in names
|
||||
|
||||
|
||||
def test_parameter_inline_distribution_metadata():
|
||||
scenario_id = _create_scenario()
|
||||
payload: Dict[str, Any] = {
|
||||
"scenario_id": scenario_id,
|
||||
"name": "inline-param",
|
||||
"value": 7.5,
|
||||
"distribution_type": "uniform",
|
||||
"distribution_parameters": {"min": 5, "max": 10},
|
||||
}
|
||||
|
||||
response = client.post("/api/parameters/", json=payload)
|
||||
assert response.status_code == 200
|
||||
created = response.json()
|
||||
assert created["distribution_id"] is None
|
||||
assert created["distribution_type"] == "uniform"
|
||||
assert created["distribution_parameters"] == {"min": 5, "max": 10}
|
||||
|
||||
|
||||
def test_parameter_with_missing_distribution_reference():
|
||||
scenario_id = _create_scenario()
|
||||
payload: Dict[str, Any] = {
|
||||
"scenario_id": scenario_id,
|
||||
"name": "missing-dist",
|
||||
"value": 1.0,
|
||||
"distribution_id": 9999,
|
||||
}
|
||||
|
||||
response = client.post("/api/parameters/", json=payload)
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Distribution not found"
|
||||
@@ -1,42 +1,70 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
from config.database import Base, engine
|
||||
|
||||
# Setup and teardown
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@pytest.fixture
|
||||
def client(api_client: TestClient) -> TestClient:
|
||||
return api_client
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
def _create_scenario(client: TestClient) -> int:
|
||||
payload = {
|
||||
"name": f"Production Scenario {uuid4()}",
|
||||
"description": "Scenario for production tests",
|
||||
}
|
||||
response = client.post("/api/scenarios/", json=payload)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
def test_create_production_record(client: TestClient) -> None:
|
||||
scenario_id = _create_scenario(client)
|
||||
payload = {
|
||||
"scenario_id": scenario_id,
|
||||
"amount": 475.25,
|
||||
"description": "Daily output",
|
||||
}
|
||||
|
||||
response = client.post("/api/production/", json=payload)
|
||||
assert response.status_code == 201
|
||||
created = response.json()
|
||||
assert created["scenario_id"] == scenario_id
|
||||
assert created["amount"] == pytest.approx(475.25)
|
||||
assert created["description"] == "Daily output"
|
||||
|
||||
|
||||
def test_create_and_list_production_output():
|
||||
# Create a scenario to attach production output
|
||||
resp = client.post(
|
||||
"/api/scenarios/", json={"name": "ProdScenario", "description": "production scenario"}
|
||||
def test_list_production_filters_by_scenario(client: TestClient) -> None:
|
||||
target_scenario = _create_scenario(client)
|
||||
other_scenario = _create_scenario(client)
|
||||
|
||||
for scenario_id, amount in [(target_scenario, 100.0), (target_scenario, 150.0), (other_scenario, 200.0)]:
|
||||
response = client.post(
|
||||
"/api/production/",
|
||||
json={
|
||||
"scenario_id": scenario_id,
|
||||
"amount": amount,
|
||||
"description": f"Output {amount}",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
scenario = resp.json()
|
||||
sid = scenario["id"]
|
||||
assert response.status_code == 201
|
||||
|
||||
# Create Production Output item
|
||||
prod_payload = {"scenario_id": sid,
|
||||
"amount": 300.0, "description": "Daily output"}
|
||||
resp2 = client.post("/api/production/", json=prod_payload)
|
||||
assert resp2.status_code == 201
|
||||
prod = resp2.json()
|
||||
assert prod["scenario_id"] == sid
|
||||
assert prod["amount"] == 300.0
|
||||
list_response = client.get("/api/production/")
|
||||
assert list_response.status_code == 200
|
||||
items = [item for item in list_response.json()
|
||||
if item["scenario_id"] == target_scenario]
|
||||
assert {item["amount"] for item in items} == {100.0, 150.0}
|
||||
|
||||
# List Production Output items
|
||||
resp3 = client.get("/api/production/")
|
||||
assert resp3.status_code == 200
|
||||
data = resp3.json()
|
||||
assert any(item["amount"] == 300.0 and item["scenario_id"]
|
||||
== sid for item in data)
|
||||
|
||||
def test_create_production_rejects_negative_amount(client: TestClient) -> None:
|
||||
scenario_id = _create_scenario(client)
|
||||
response = client.post(
|
||||
"/api/production/",
|
||||
json={
|
||||
"scenario_id": scenario_id,
|
||||
"amount": -5,
|
||||
"description": "Invalid output",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from fastapi.testclient import TestClient
|
||||
import math
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from main import app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from services.reporting import generate_report
|
||||
|
||||
|
||||
@@ -14,57 +17,77 @@ def test_generate_report_empty():
|
||||
"min": 0.0,
|
||||
"max": 0.0,
|
||||
"std_dev": 0.0,
|
||||
"variance": 0.0,
|
||||
"percentile_10": 0.0,
|
||||
"percentile_90": 0.0,
|
||||
"percentile_5": 0.0,
|
||||
"percentile_95": 0.0,
|
||||
"value_at_risk_95": 0.0,
|
||||
"expected_shortfall_95": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def test_generate_report_with_values():
|
||||
values = [{"iteration": 1, "result": 10.0}, {
|
||||
"iteration": 2, "result": 20.0}, {"iteration": 3, "result": 30.0}]
|
||||
values: List[Dict[str, float]] = [
|
||||
{"iteration": 1, "result": 10.0},
|
||||
{"iteration": 2, "result": 20.0},
|
||||
{"iteration": 3, "result": 30.0},
|
||||
]
|
||||
report = generate_report(values)
|
||||
assert report["count"] == 3
|
||||
assert report["mean"] == pytest.approx(20.0)
|
||||
assert report["median"] == pytest.approx(20.0)
|
||||
assert report["min"] == pytest.approx(10.0)
|
||||
assert report["max"] == pytest.approx(30.0)
|
||||
assert report["std_dev"] == pytest.approx(8.1649658, rel=1e-6)
|
||||
assert report["percentile_10"] == pytest.approx(12.0)
|
||||
assert report["percentile_90"] == pytest.approx(28.0)
|
||||
assert math.isclose(float(report["mean"]), 20.0)
|
||||
assert math.isclose(float(report["median"]), 20.0)
|
||||
assert math.isclose(float(report["min"]), 10.0)
|
||||
assert math.isclose(float(report["max"]), 30.0)
|
||||
assert math.isclose(float(report["std_dev"]), 8.1649658, rel_tol=1e-6)
|
||||
assert math.isclose(float(report["variance"]), 66.6666666, rel_tol=1e-6)
|
||||
assert math.isclose(float(report["percentile_10"]), 12.0)
|
||||
assert math.isclose(float(report["percentile_90"]), 28.0)
|
||||
assert math.isclose(float(report["percentile_5"]), 11.0)
|
||||
assert math.isclose(float(report["percentile_95"]), 29.0)
|
||||
assert math.isclose(float(report["value_at_risk_95"]), 11.0)
|
||||
assert math.isclose(float(report["expected_shortfall_95"]), 10.0)
|
||||
|
||||
|
||||
def test_reporting_endpoint_invalid_input():
|
||||
client = TestClient(app)
|
||||
@pytest.fixture
|
||||
def client(api_client: TestClient) -> TestClient:
|
||||
return api_client
|
||||
|
||||
|
||||
def test_reporting_endpoint_invalid_input(client: TestClient):
|
||||
resp = client.post("/api/reporting/summary", json={})
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["detail"] == "Invalid input format"
|
||||
|
||||
|
||||
def test_reporting_endpoint_success():
|
||||
client = TestClient(app)
|
||||
input_data = [
|
||||
def test_reporting_endpoint_success(client: TestClient):
|
||||
input_data: List[Dict[str, float]] = [
|
||||
{"iteration": 1, "result": 10.0},
|
||||
{"iteration": 2, "result": 20.0},
|
||||
{"iteration": 3, "result": 30.0},
|
||||
]
|
||||
resp = client.post("/api/reporting/summary", json=input_data)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
data: Dict[str, Any] = resp.json()
|
||||
assert data["count"] == 3
|
||||
assert data["mean"] == pytest.approx(20.0)
|
||||
assert math.isclose(float(data["mean"]), 20.0)
|
||||
assert math.isclose(float(data["variance"]), 66.6666666, rel_tol=1e-6)
|
||||
assert math.isclose(float(data["value_at_risk_95"]), 11.0)
|
||||
assert math.isclose(float(data["expected_shortfall_95"]), 10.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload,expected_detail",
|
||||
[
|
||||
validation_error_cases: List[tuple[List[Any], str]] = [
|
||||
(["not-a-dict"], "Entry at index 0 must be an object"),
|
||||
([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"),
|
||||
([{"iteration": 1, "result": "bad"}],
|
||||
"Entry at index 0 must include numeric 'result'"),
|
||||
],
|
||||
)
|
||||
def test_reporting_endpoint_validation_errors(payload, expected_detail):
|
||||
client = TestClient(app)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("payload,expected_detail", validation_error_cases)
|
||||
def test_reporting_endpoint_validation_errors(
|
||||
client: TestClient, payload: List[Any], expected_detail: str
|
||||
):
|
||||
resp = client.post("/api/reporting/summary", json=payload)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["detail"] == expected_detail
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
# ensure project root is on sys.path for module imports
|
||||
from main import app
|
||||
from routes.scenarios import router
|
||||
from config.database import Base, engine
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
from uuid import uuid4
|
||||
|
||||
# Create tables for testing
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from config.database import Base, engine
|
||||
from main import app
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
@@ -23,14 +18,28 @@ client = TestClient(app)
|
||||
|
||||
|
||||
def test_create_and_list_scenario():
|
||||
# Create a scenario
|
||||
scenario_name = f"Scenario-{uuid4()}"
|
||||
response = client.post(
|
||||
"/api/scenarios/", json={"name": "Test", "description": "Desc"})
|
||||
"/api/scenarios/",
|
||||
json={"name": scenario_name, "description": "Integration test"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test"
|
||||
# List scenarios
|
||||
assert data["name"] == scenario_name
|
||||
|
||||
response2 = client.get("/api/scenarios/")
|
||||
assert response2.status_code == 200
|
||||
data2 = response2.json()
|
||||
assert any(s["name"] == "Test" for s in data2)
|
||||
assert any(s["name"] == scenario_name for s in data2)
|
||||
|
||||
|
||||
def test_create_duplicate_scenario_rejected():
|
||||
scenario_name = f"Duplicate-{uuid4()}"
|
||||
payload = {"name": scenario_name, "description": "Primary"}
|
||||
|
||||
first_resp = client.post("/api/scenarios/", json=payload)
|
||||
assert first_resp.status_code == 200
|
||||
|
||||
second_resp = client.post("/api/scenarios/", json=payload)
|
||||
assert second_resp.status_code == 400
|
||||
assert second_resp.json()["detail"] == "Scenario already exists"
|
||||
|
||||
@@ -1,42 +1,111 @@
|
||||
from services.simulation import run_simulation
|
||||
from main import app
|
||||
from config.database import Base, engine
|
||||
from fastapi.testclient import TestClient
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Setup and teardown
|
||||
from models.simulation_result import SimulationResult
|
||||
from services.simulation import run_simulation
|
||||
|
||||
|
||||
def setup_module(module):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@pytest.fixture
|
||||
def client(api_client: TestClient) -> TestClient:
|
||||
return api_client
|
||||
|
||||
|
||||
def teardown_module(module):
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Direct function test
|
||||
|
||||
|
||||
def test_run_simulation_function_returns_list():
|
||||
results = run_simulation([], iterations=10)
|
||||
def test_run_simulation_function_generates_samples():
|
||||
params = [
|
||||
{"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2},
|
||||
{
|
||||
"name": "recovery",
|
||||
"value": 0.9,
|
||||
"distribution": "uniform",
|
||||
"min": 0.8,
|
||||
"max": 0.95,
|
||||
},
|
||||
]
|
||||
results = run_simulation(params, iterations=5, seed=123)
|
||||
assert isinstance(results, list)
|
||||
assert results == []
|
||||
|
||||
# Endpoint tests
|
||||
assert len(results) == 5
|
||||
assert results[0]["iteration"] == 1
|
||||
|
||||
|
||||
def test_simulation_endpoint_no_params():
|
||||
resp = client.post("/api/simulations/run", json=[])
|
||||
def test_simulation_endpoint_no_params(client: TestClient):
|
||||
scenario_payload = {
|
||||
"name": f"NoParamScenario-{uuid4()}",
|
||||
"description": "No parameters run",
|
||||
}
|
||||
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
|
||||
assert scenario_resp.status_code == 200
|
||||
scenario_id = scenario_resp.json()["id"]
|
||||
|
||||
resp = client.post(
|
||||
"/api/simulations/run",
|
||||
json={"scenario_id": scenario_id, "parameters": [], "iterations": 10},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["detail"] == "No parameters provided"
|
||||
|
||||
|
||||
def test_simulation_endpoint_success():
|
||||
params = [{"name": "param1", "value": 2.5}]
|
||||
resp = client.post("/api/simulations/run", json=params)
|
||||
def test_simulation_endpoint_success(
|
||||
client: TestClient, db_session: Session
|
||||
):
|
||||
scenario_payload = {
|
||||
"name": f"SimScenario-{uuid4()}",
|
||||
"description": "Simulation test",
|
||||
}
|
||||
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
|
||||
assert scenario_resp.status_code == 200
|
||||
scenario_id = scenario_resp.json()["id"]
|
||||
|
||||
params = [
|
||||
{"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5}
|
||||
]
|
||||
payload = {
|
||||
"scenario_id": scenario_id,
|
||||
"parameters": params,
|
||||
"iterations": 10,
|
||||
"seed": 42,
|
||||
}
|
||||
|
||||
resp = client.post("/api/simulations/run", json=payload)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
assert data["scenario_id"] == scenario_id
|
||||
assert len(data["results"]) == 10
|
||||
assert data["summary"]["count"] == 10
|
||||
|
||||
db_session.expire_all()
|
||||
persisted = (
|
||||
db_session.query(SimulationResult)
|
||||
.filter(SimulationResult.scenario_id == scenario_id)
|
||||
.all()
|
||||
)
|
||||
assert len(persisted) == 10
|
||||
|
||||
|
||||
def test_simulation_endpoint_uses_stored_parameters(client: TestClient):
|
||||
scenario_payload = {
|
||||
"name": f"StoredParams-{uuid4()}",
|
||||
"description": "Stored parameter simulation",
|
||||
}
|
||||
scenario_resp = client.post("/api/scenarios/", json=scenario_payload)
|
||||
assert scenario_resp.status_code == 200
|
||||
scenario_id = scenario_resp.json()["id"]
|
||||
|
||||
parameter_payload = {
|
||||
"scenario_id": scenario_id,
|
||||
"name": "grade",
|
||||
"value": 1.5,
|
||||
}
|
||||
param_resp = client.post("/api/parameters/", json=parameter_payload)
|
||||
assert param_resp.status_code == 200
|
||||
|
||||
resp = client.post(
|
||||
"/api/simulations/run",
|
||||
json={"scenario_id": scenario_id, "iterations": 3, "seed": 7},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["summary"]["count"] == 3
|
||||
assert len(data["results"]) == 3
|
||||
|
||||
Reference in New Issue
Block a user