Enhance simulation API and logic: refactor parameter handling, add support for multiple distributions, and improve simulation result persistence
This commit is contained in:
@@ -1,9 +1,15 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import List
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from services.simulation import run_simulation
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, PositiveInt
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
from models.parameters import Parameter
|
||||
from models.scenario import Scenario
|
||||
from models.simulation_result import SimulationResult
|
||||
from services.reporting import generate_report
|
||||
from services.simulation import run_simulation
|
||||
|
||||
router = APIRouter(prefix="/api/simulations", tags=["Simulations"])
|
||||
|
||||
@@ -16,10 +22,109 @@ def get_db():
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/run", response_model=List[dict])
|
||||
async def simulate(params: List[dict], iterations: int = 1000, db: Session = Depends(get_db)):
|
||||
if not params:
|
||||
raise HTTPException(status_code=400, detail="No parameters provided")
|
||||
# TODO: you might use db to fetch scenario info or persist results
|
||||
results = run_simulation(params, iterations)
|
||||
return results
|
||||
class SimulationParameterInput(BaseModel):
|
||||
name: str
|
||||
value: float
|
||||
distribution: Optional[str] = "normal"
|
||||
std_dev: Optional[float] = None
|
||||
min: Optional[float] = None
|
||||
max: Optional[float] = None
|
||||
mode: Optional[float] = None
|
||||
|
||||
|
||||
class SimulationRunRequest(BaseModel):
|
||||
scenario_id: int
|
||||
iterations: PositiveInt = 1000
|
||||
parameters: Optional[List[SimulationParameterInput]] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class SimulationResultItem(BaseModel):
|
||||
iteration: int
|
||||
result: float
|
||||
|
||||
|
||||
class SimulationRunResponse(BaseModel):
|
||||
scenario_id: int
|
||||
iterations: int
|
||||
results: List[SimulationResultItem]
|
||||
summary: dict
|
||||
|
||||
|
||||
def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]:
|
||||
db_params = (
|
||||
db.query(Parameter)
|
||||
.filter(Parameter.scenario_id == scenario_id)
|
||||
.order_by(Parameter.id)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
SimulationParameterInput(
|
||||
name=cast(str, item.name),
|
||||
value=cast(float, item.value),
|
||||
)
|
||||
for item in db_params
|
||||
]
|
||||
|
||||
|
||||
@router.post("/run", response_model=SimulationRunResponse)
|
||||
async def simulate(payload: SimulationRunRequest, db: Session = Depends(get_db)):
|
||||
scenario = db.query(Scenario).filter(
|
||||
Scenario.id == payload.scenario_id).first()
|
||||
if scenario is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Scenario not found",
|
||||
)
|
||||
|
||||
parameters = payload.parameters or _load_parameters(
|
||||
db, payload.scenario_id)
|
||||
if not parameters:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No parameters provided",
|
||||
)
|
||||
|
||||
raw_results = run_simulation(
|
||||
[param.dict(exclude_none=True) for param in parameters],
|
||||
iterations=payload.iterations,
|
||||
seed=payload.seed,
|
||||
)
|
||||
|
||||
if not raw_results:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Simulation produced no results",
|
||||
)
|
||||
|
||||
# Persist results (replace existing values for scenario)
|
||||
db.query(SimulationResult).filter(
|
||||
SimulationResult.scenario_id == payload.scenario_id
|
||||
).delete()
|
||||
db.bulk_save_objects(
|
||||
[
|
||||
SimulationResult(
|
||||
scenario_id=payload.scenario_id,
|
||||
iteration=item["iteration"],
|
||||
result=item["result"],
|
||||
)
|
||||
for item in raw_results
|
||||
]
|
||||
)
|
||||
db.commit()
|
||||
|
||||
summary = generate_report(raw_results)
|
||||
|
||||
response = SimulationRunResponse(
|
||||
scenario_id=payload.scenario_id,
|
||||
iterations=payload.iterations,
|
||||
results=[
|
||||
SimulationResultItem(
|
||||
iteration=int(item["iteration"]),
|
||||
result=float(item["result"]),
|
||||
)
|
||||
for item in raw_results
|
||||
],
|
||||
summary=summary,
|
||||
)
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user