"""Parameter routes, mounted under ``/api/parameters``."""

from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.orm import Session

from models.distribution import Distribution
from models.parameters import Parameter
from models.scenario import Scenario
from routes.dependencies import get_db

router = APIRouter(prefix="/api/parameters", tags=["parameters"])

# Distribution kinds a Parameter may carry. Shared by request validation and
# by create_parameter so that a type copied from a Distribution row is held
# to the same rule as a client-supplied one.
_SUPPORTED_DISTRIBUTION_TYPES = frozenset({"normal", "uniform", "triangular"})


def _normalize_distribution_type(value: Optional[str]) -> Optional[str]:
    """Return *value* stripped and lower-cased, or None if missing or blank.

    Raises:
        ValueError: if the normalized value is not a supported type.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        return None
    if normalized not in _SUPPORTED_DISTRIBUTION_TYPES:
        raise ValueError(
            "distribution_type must be normal, uniform, or triangular"
        )
    return normalized


class ParameterCreate(BaseModel):
    """Request body for creating a Parameter.

    When ``distribution_id`` is set, ``create_parameter`` ignores the
    request's ``distribution_type`` / ``distribution_parameters`` and copies
    them from the referenced Distribution instead.
    """

    scenario_id: int
    name: str
    value: float
    distribution_id: Optional[int] = None
    distribution_type: Optional[str] = None
    distribution_parameters: Optional[Dict[str, Any]] = None

    @field_validator("distribution_type")
    @classmethod
    def normalize_type(cls, value: Optional[str]) -> Optional[str]:
        """Lower-case/strip the type; blank becomes None; unknown kinds fail."""
        return _normalize_distribution_type(value)

    @field_validator("distribution_parameters")
    @classmethod
    def empty_dict_to_none(
        cls, value: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Treat an empty dict the same as an omitted one (None)."""
        return value or None


class ParameterRead(ParameterCreate):
    """Response schema for a stored Parameter, built from ORM attributes.

    Inherits the ParameterCreate validators, so every stored
    ``distribution_type`` must be one of the supported kinds or the
    response fails validation.
    """

    id: int

    model_config = ConfigDict(from_attributes=True)


@router.post("/", response_model=ParameterRead)
def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
    """Create a Parameter attached to an existing Scenario.

    If ``param.distribution_id`` is given, the distribution type and
    parameters are taken from that Distribution row, overriding any values
    in the request body.

    Raises:
        HTTPException 404: the scenario or referenced distribution is missing.
        HTTPException 422: the referenced distribution's type is not
            supported. This is checked before writing so that no row is
            stored that ParameterRead could not serialize.
    """
    scenario = db.query(Scenario).filter(Scenario.id == param.scenario_id).first()
    if not scenario:
        raise HTTPException(status_code=404, detail="Scenario not found")

    distribution_id = param.distribution_id
    distribution_type = param.distribution_type
    distribution_parameters = param.distribution_parameters

    if distribution_id is not None:
        distribution = (
            db.query(Distribution)
            .filter(Distribution.id == distribution_id)
            .first()
        )
        if not distribution:
            raise HTTPException(
                status_code=404, detail="Distribution not found"
            )
        # The stored type bypasses the request validator, so normalize and
        # validate it here; otherwise an unsupported value would be committed
        # and then break response serialization for this and list calls.
        try:
            distribution_type = _normalize_distribution_type(
                distribution.distribution_type
            )
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        distribution_parameters = distribution.parameters or None

    new_param = Parameter(
        scenario_id=param.scenario_id,
        name=param.name,
        value=param.value,
        distribution_id=distribution_id,
        distribution_type=distribution_type,
        distribution_parameters=distribution_parameters,
    )
    db.add(new_param)
    db.commit()
    db.refresh(new_param)
    return new_param


@router.get("/", response_model=List[ParameterRead])
def list_parameters(db: Session = Depends(get_db)):
    """Return every stored Parameter (unfiltered, unpaginated)."""
    return db.query(Parameter).all()