Files
calminer/routes/simulations.py
zwitschi 97b1c0360b
Some checks failed
Run Tests / e2e tests (push) Failing after 1m27s
Run Tests / lint tests (push) Failing after 6s
Run Tests / unit tests (push) Failing after 7s
Refactor test cases for improved readability and consistency
- Updated test functions in various test files to enhance code clarity by formatting long lines and improving indentation.
- Adjusted assertions to use multi-line formatting for better readability.
- Added new test cases for theme settings API to ensure proper functionality.
- Ensured consistent use of line breaks and spacing across test files for uniformity.
2025-10-27 10:32:55 +01:00

127 lines
3.4 KiB
Python

from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, PositiveInt
from sqlalchemy.orm import Session
from models.parameters import Parameter
from models.scenario import Scenario
from models.simulation_result import SimulationResult
from routes.dependencies import get_db
from services.reporting import generate_report
from services.simulation import run_simulation
# Router for the simulation endpoints: every route registered on it is served
# under the /api/simulations prefix and tagged "Simulations".
router = APIRouter(prefix="/api/simulations", tags=["Simulations"])
class SimulationParameterInput(BaseModel):
    """One input parameter handed to the simulation.

    Only ``name`` and ``value`` are required. Every other field is optional;
    the run endpoint in this module dumps parameters with
    ``exclude_none=True``, so fields left as ``None`` are not forwarded to
    the simulation service.
    """

    name: str
    value: float
    # Distribution identifier, "normal" when omitted. This model does not
    # restrict which strings are accepted.
    distribution: str | None = "normal"
    # Optional distribution settings.
    # NOTE(review): which of these applies to which distribution is decided
    # by the simulation service, not by this model -- confirm there.
    std_dev: float | None = None
    min: float | None = None
    max: float | None = None
    mode: float | None = None
class SimulationRunRequest(BaseModel):
    """Request body accepted by the simulation run endpoint."""

    # Scenario the simulation is run for.
    scenario_id: int
    # Validated as a positive integer; 1000 when the client omits it.
    iterations: PositiveInt = 1000
    # Explicit simulation inputs. When missing or empty, the run endpoint
    # uses the parameters stored for the scenario instead.
    parameters: list[SimulationParameterInput] | None = None
    # Passed through to the simulation service as its ``seed`` argument.
    seed: int | None = None
class SimulationResultItem(BaseModel):
    """A single simulation outcome: an iteration number and its result."""

    iteration: int
    result: float
class SimulationRunResponse(BaseModel):
    """Response body returned by the simulation run endpoint."""

    scenario_id: int
    # Filled by the run endpoint with the iteration count from the request.
    iterations: int
    # One entry per result produced by the simulation.
    results: list[SimulationResultItem]
    # Summary values keyed by name; the run endpoint fills this from
    # ``generate_report``.
    summary: dict[str, float | int]
def _load_parameters(
    db: Session, scenario_id: int
) -> List[SimulationParameterInput]:
    """Build simulation inputs from the parameters stored for a scenario.

    Rows are read ordered by ``Parameter.id``. Only ``name`` and ``value``
    are copied from each row, so every returned input keeps the model
    defaults for its remaining fields.

    NOTE(review): any distribution settings stored on ``Parameter`` are not
    carried over here -- confirm that this is intended.

    Args:
        db: Session used to query the ``Parameter`` table.
        scenario_id: Scenario whose parameters are loaded.

    Returns:
        One ``SimulationParameterInput`` per stored row; an empty list when
        the scenario has no parameters.
    """
    scenario_rows = db.query(Parameter).filter(
        Parameter.scenario_id == scenario_id
    )
    ordered_rows = scenario_rows.order_by(Parameter.id)
    return [
        SimulationParameterInput(name=row.name, value=row.value)
        for row in ordered_rows.all()
    ]
def _replace_scenario_results(db: Session, scenario_id: int, raw_results) -> None:
    """Swap the stored results of a scenario for a freshly computed set.

    The delete of the old rows and the bulk insert of the new ones are
    committed together. If any step raises, the session is rolled back before
    the error propagates, so the pending delete/insert is discarded and the
    session is not left in a failed transaction.

    Args:
        db: Session used for the delete, insert and commit.
        scenario_id: Scenario whose results are replaced.
        raw_results: Items indexed by ``"iteration"`` and ``"result"``.
    """
    try:
        db.query(SimulationResult).filter(
            SimulationResult.scenario_id == scenario_id
        ).delete()
        db.bulk_save_objects(
            [
                SimulationResult(
                    scenario_id=scenario_id,
                    iteration=item["iteration"],
                    result=item["result"],
                )
                for item in raw_results
            ]
        )
        db.commit()
    except Exception:
        # Undo the partial replace so the caller never sees a half-applied
        # state; the original error is re-raised unchanged.
        db.rollback()
        raise


@router.post("/run", response_model=SimulationRunResponse)
async def simulate(
    payload: SimulationRunRequest, db: Session = Depends(get_db)
):
    """Run a simulation for a scenario, store its results and report them.

    The parameters come from the request when it supplies a non-empty list,
    otherwise from the rows stored for the scenario. Existing results of the
    scenario are replaced by the new ones.

    NOTE(review): this is a coroutine, yet every call it makes (session
    queries, ``run_simulation``, commit) is synchronous, so the event loop is
    blocked while a request is processed. FastAPI runs plain ``def``
    endpoints in a threadpool -- consider dropping ``async`` if nothing
    awaits this function directly.

    Args:
        payload: Validated request body.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A ``SimulationRunResponse`` with the per-iteration results and the
        summary produced by ``generate_report``.

    Raises:
        HTTPException: 404 when the scenario does not exist; 400 when no
            parameters are available or the simulation yields no results.
    """
    scenario = (
        db.query(Scenario).filter(Scenario.id == payload.scenario_id).first()
    )
    if scenario is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scenario not found",
        )

    # A missing *or empty* parameter list falls back to the stored ones.
    parameters = payload.parameters or _load_parameters(db, payload.scenario_id)
    if not parameters:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No parameters provided",
        )

    raw_results = run_simulation(
        # Unset optional fields are dropped rather than passed on as None.
        [param.model_dump(exclude_none=True) for param in parameters],
        iterations=payload.iterations,
        seed=payload.seed,
    )
    if not raw_results:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Simulation produced no results",
        )

    # Persist results (replace existing values for scenario).
    _replace_scenario_results(db, payload.scenario_id, raw_results)

    summary = generate_report(raw_results)
    return SimulationRunResponse(
        scenario_id=payload.scenario_id,
        iterations=payload.iterations,
        results=[
            SimulationResultItem(
                iteration=int(item["iteration"]),
                result=float(item["result"]),
            )
            for item in raw_results
        ],
        summary=summary,
    )