diff --git a/routes/simulations.py b/routes/simulations.py index 0a9f100..162533f 100644 --- a/routes/simulations.py +++ b/routes/simulations.py @@ -1,9 +1,15 @@ -from fastapi import APIRouter, HTTPException, Depends -from typing import List +from typing import List, Optional, cast -from services.simulation import run_simulation +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, PositiveInt from sqlalchemy.orm import Session + from config.database import SessionLocal +from models.parameters import Parameter +from models.scenario import Scenario +from models.simulation_result import SimulationResult +from services.reporting import generate_report +from services.simulation import run_simulation router = APIRouter(prefix="/api/simulations", tags=["Simulations"]) @@ -16,10 +22,109 @@ def get_db(): db.close() -@router.post("/run", response_model=List[dict]) -async def simulate(params: List[dict], iterations: int = 1000, db: Session = Depends(get_db)): - if not params: - raise HTTPException(status_code=400, detail="No parameters provided") - # TODO: you might use db to fetch scenario info or persist results - results = run_simulation(params, iterations) - return results +class SimulationParameterInput(BaseModel): + name: str + value: float + distribution: Optional[str] = "normal" + std_dev: Optional[float] = None + min: Optional[float] = None + max: Optional[float] = None + mode: Optional[float] = None + + +class SimulationRunRequest(BaseModel): + scenario_id: int + iterations: PositiveInt = 1000 + parameters: Optional[List[SimulationParameterInput]] = None + seed: Optional[int] = None + + +class SimulationResultItem(BaseModel): + iteration: int + result: float + + +class SimulationRunResponse(BaseModel): + scenario_id: int + iterations: int + results: List[SimulationResultItem] + summary: dict + + +def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]: + db_params = ( + db.query(Parameter) + .filter(Parameter.scenario_id == scenario_id) + .order_by(Parameter.id) + .all() + ) + return [ + SimulationParameterInput( + name=cast(str, item.name), + value=cast(float, item.value), + ) + for item in db_params + ] + + +@router.post("/run", response_model=SimulationRunResponse) +async def simulate(payload: SimulationRunRequest, db: Session = Depends(get_db)): + scenario = db.query(Scenario).filter( + Scenario.id == payload.scenario_id).first() + if scenario is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Scenario not found", + ) + + parameters = payload.parameters or _load_parameters( + db, payload.scenario_id) + if not parameters: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No parameters provided", + ) + + raw_results = run_simulation( + [param.dict(exclude_none=True) for param in parameters], + iterations=payload.iterations, + seed=payload.seed, + ) + + if not raw_results: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Simulation produced no results", + ) + + # Persist results (replace existing values for scenario) + db.query(SimulationResult).filter( + SimulationResult.scenario_id == payload.scenario_id + ).delete() + db.bulk_save_objects( + [ + SimulationResult( + scenario_id=payload.scenario_id, + iteration=item["iteration"], + result=item["result"], + ) + for item in raw_results + ] + ) + db.commit() + + summary = generate_report(raw_results) + + response = SimulationRunResponse( + scenario_id=payload.scenario_id, + iterations=payload.iterations, + results=[ + SimulationResultItem( + iteration=int(item["iteration"]), + result=float(item["result"]), + ) + for item in raw_results + ], + summary=summary, + ) + return response diff --git a/services/simulation.py b/services/simulation.py index db4a4ed..4a433f2 100644 --- a/services/simulation.py +++ b/services/simulation.py @@ -1,17 +1,140 @@ -from typing import Dict, List +from __future__ import annotations + +from dataclasses import dataclass +from random import Random +from typing import Dict, List, Literal, Optional, Sequence -def run_simulation(parameters: List[Dict[str, float]], iterations: int = 1000) -> List[Dict[str, float]]: - """ - Run Monte Carlo simulation with given parameters. +DEFAULT_STD_DEV_RATIO = 0.1 +DEFAULT_UNIFORM_SPAN_RATIO = 0.15 +DistributionType = Literal["normal", "uniform", "triangular"] - Args: - parameters: List of parameter dicts with keys 'name' and 'value'. - iterations: Number of simulation iterations. - Returns: - List of simulation result dicts for each iteration. - """ - # TODO: implement Monte Carlo logic +@dataclass +class SimulationParameter: + name: str + base_value: float + distribution: DistributionType + std_dev: Optional[float] = None + minimum: Optional[float] = None + maximum: Optional[float] = None + mode: Optional[float] = None + + +def _ensure_positive_span(span: float, fallback: float) -> float: + return span if span and span > 0 else fallback + + +def _compile_parameters(parameters: Sequence[Dict[str, float]]) -> List[SimulationParameter]: + compiled: List[SimulationParameter] = [] + for index, item in enumerate(parameters): + if "value" not in item: + raise ValueError( + f"Parameter at index {index} must include 'value'") + name = str(item.get("name", f"param_{index}")) + base_value = float(item["value"]) + distribution = str(item.get("distribution", "normal")).lower() + if distribution not in {"normal", "uniform", "triangular"}: + raise ValueError( + f"Parameter '{name}' has unsupported distribution '{distribution}'" + ) + + span_default = abs(base_value) * DEFAULT_UNIFORM_SPAN_RATIO or 1.0 + + if distribution == "normal": + std_dev = item.get("std_dev") + std_dev_value = float(std_dev) if std_dev is not None else abs( + base_value) * DEFAULT_STD_DEV_RATIO or 1.0 + compiled.append( + SimulationParameter( + name=name, + base_value=base_value, + distribution="normal", + std_dev=_ensure_positive_span(std_dev_value, 1.0), + ) + ) + continue + + minimum = item.get("min") + maximum = item.get("max") + if minimum is None or maximum is None: + minimum = base_value - span_default + maximum = base_value + span_default + minimum = float(minimum) + maximum = float(maximum) + if minimum >= maximum: + raise ValueError( + f"Parameter '{name}' requires 'min' < 'max' for {distribution} distribution" + ) + + if distribution == "uniform": + compiled.append( + SimulationParameter( + name=name, + base_value=base_value, + distribution="uniform", + minimum=minimum, + maximum=maximum, + ) + ) + else: # triangular + mode = item.get("mode") + if mode is None: + mode = base_value + mode_value = float(mode) + if not (minimum <= mode_value <= maximum): + raise ValueError( + f"Parameter '{name}' mode must be within min/max bounds for triangular distribution" + ) + compiled.append( + SimulationParameter( + name=name, + base_value=base_value, + distribution="triangular", + minimum=minimum, + maximum=maximum, + mode=mode_value, + ) + ) + return compiled + + +def _sample_parameter(rng: Random, param: SimulationParameter) -> float: + if param.distribution == "normal": + assert param.std_dev is not None + return rng.normalvariate(param.base_value, param.std_dev) + if param.distribution == "uniform": + assert param.minimum is not None and param.maximum is not None + return rng.uniform(param.minimum, param.maximum) + # triangular + assert ( + param.minimum is not None + and param.maximum is not None + and param.mode is not None + ) + return rng.triangular(param.minimum, param.maximum, param.mode) + + +def run_simulation( + parameters: Sequence[Dict[str, float]], + iterations: int = 1000, + seed: Optional[int] = None, +) -> List[Dict[str, float]]: + """Run a lightweight Monte Carlo simulation using configurable distributions.""" + + if iterations <= 0: + return [] + + compiled_params = _compile_parameters(parameters) + if not compiled_params: + return [] + + rng = Random(seed) results: List[Dict[str, float]] = [] + for iteration in range(1, iterations + 1): + total = 0.0 + for param in compiled_params: + sample = _sample_parameter(rng, param) + total += sample + results.append({"iteration": iteration, "result": total}) return results