- Updated test functions in various test files to enhance code clarity by formatting long lines and improving indentation. - Adjusted assertions to use multi-line formatting for better readability. - Added new test cases for theme settings API to ensure proper functionality. - Ensured consistent use of line breaks and spacing across test files for uniformity.
91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel, ConfigDict, field_validator
|
|
from sqlalchemy.orm import Session
|
|
|
|
from models.distribution import Distribution
|
|
from models.parameters import Parameter
|
|
from models.scenario import Scenario
|
|
from routes.dependencies import get_db
|
|
|
|
# Router for the parameter endpoints: the route decorators below attach to
# it, under the "/api/parameters" prefix and the "parameters" tag.
router = APIRouter(
    prefix="/api/parameters",
    tags=["parameters"],
)
|
|
|
|
|
|
class ParameterCreate(BaseModel):
    # Request schema for creating a parameter.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic publishes a model docstring as the JSON-schema description.

    # Id of the scenario the parameter is created under.
    scenario_id: int
    name: str
    value: float
    # Optional distribution settings: a reference to a stored distribution
    # and/or an inline type + parameters pair.
    distribution_id: Optional[int] = None
    distribution_type: Optional[str] = None
    distribution_parameters: Optional[Dict[str, Any]] = None

    @field_validator("distribution_type")
    @classmethod
    def normalize_type(cls, value: Optional[str]) -> Optional[str]:
        """Trim and lower-case the distribution type.

        Returns:
            None when the input is None or blank after trimming, otherwise
            the cleaned type name.

        Raises:
            ValueError: if the cleaned name is not normal, uniform, or
                triangular.
        """
        if value is None:
            return None

        cleaned = value.strip().lower()
        if not cleaned:
            # Whitespace-only input is treated the same as "not provided".
            return None

        supported = {"normal", "uniform", "triangular"}
        if cleaned in supported:
            return cleaned
        raise ValueError(
            "distribution_type must be normal, uniform, or triangular"
        )

    @field_validator("distribution_parameters")
    @classmethod
    def empty_dict_to_none(
        cls, value: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Collapse a missing or empty mapping to None; keep anything else."""
        return value if value else None
|
|
|
|
|
|
class ParameterRead(ParameterCreate):
    # Response schema: every ParameterCreate field plus the record's id.
    # from_attributes=True lets pydantic build the model by reading object
    # attributes (the routes in this module return ORM objects, not dicts).
    model_config = ConfigDict(from_attributes=True)

    id: int
|
|
|
|
|
|
@router.post("/", response_model=ParameterRead)
def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
    # Create a Parameter row for an existing scenario and return it.
    #
    # Responds 404 when the scenario, or the distribution named by
    # param.distribution_id, does not exist.
    # NOTE: comments rather than a docstring, because FastAPI publishes a
    # handler docstring as the OpenAPI operation description.
    scenario_query = db.query(Scenario).filter(
        Scenario.id == param.scenario_id
    )
    if not scenario_query.first():
        raise HTTPException(status_code=404, detail="Scenario not found")

    # Start from the inline values supplied in the request body.
    resolved_type = param.distribution_type
    resolved_parameters = param.distribution_parameters

    if param.distribution_id is not None:
        stored = (
            db.query(Distribution)
            .filter(Distribution.id == param.distribution_id)
            .first()
        )
        if not stored:
            raise HTTPException(
                status_code=404, detail="Distribution not found"
            )
        # A referenced distribution overrides whatever was sent inline.
        resolved_type = stored.distribution_type
        resolved_parameters = stored.parameters or None

    record = Parameter(
        scenario_id=param.scenario_id,
        name=param.name,
        value=param.value,
        distribution_id=param.distribution_id,
        distribution_type=resolved_type,
        distribution_parameters=resolved_parameters,
    )

    # Persist, then refresh the instance from the database before returning.
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
@router.get("/", response_model=List[ParameterRead])
def list_parameters(db: Session = Depends(get_db)):
    # Return every Parameter row; no filtering or pagination is applied.
    all_parameters = db.query(Parameter).all()
    return all_parameters
|