Files
calminer/routes/parameters.py
zwitschi 434be86b76 feat: Enhance dashboard metrics and summary statistics
- Added new summary fields: variance, 5th percentile, 95th percentile, VaR (95%), and expected shortfall (95%) to the dashboard.
- Updated the display logic for summary metrics to handle non-finite values gracefully.
- Modified the chart rendering to include additional percentile points and tail risk metrics in tooltips.

test: Introduce unit tests for consumption, costs, and other modules

- Created a comprehensive test suite for consumption, costs, equipment, maintenance, production, reporting, and simulation modules.
- Implemented fixtures for database setup and teardown using an in-memory SQLite database for isolated testing.
- Added tests for creating, listing, and validating various entities, ensuring proper error handling and response validation.

refactor: Consolidate parameter tests and remove deprecated files

- Merged parameter-related tests into a new test file for better organization and clarity.
- Removed the old parameter test file that was no longer in use.
- Improved test coverage for parameter creation, listing, and validation scenarios.

fix: Ensure proper validation and error handling in API endpoints

- Added validation to reject negative amounts in consumption and production records.
- Implemented checks to prevent duplicate scenario creation and ensure proper error messages are returned.
- Enhanced reporting endpoint tests to validate input formats and expected outputs.
2025-10-20 22:06:39 +02:00

94 lines
2.9 KiB
Python

from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.distribution import Distribution
from models.parameters import Parameter
from models.scenario import Scenario
# Router collecting the parameter endpoints defined in this module; every
# route registered on it is served under the "/api/parameters" prefix.
router = APIRouter(prefix="/api/parameters", tags=["parameters"])
class ParameterCreate(BaseModel):
    """Request payload describing a parameter to be created.

    ``distribution_type`` is trimmed and lower-cased, and must be one of
    ``normal``, ``uniform`` or ``triangular``; a blank string becomes
    ``None``. An empty ``distribution_parameters`` mapping also becomes
    ``None``.
    """

    scenario_id: int
    name: str
    value: float
    distribution_id: Optional[int] = None
    distribution_type: Optional[str] = None
    distribution_parameters: Optional[Dict[str, Any]] = None

    @field_validator("distribution_type")
    @classmethod
    def normalize_type(cls, value: Optional[str]) -> Optional[str]:
        """Return the trimmed, lower-cased type or ``None`` when blank.

        Raises:
            ValueError: If the cleaned value is not a supported type.
        """
        cleaned = value.strip().lower() if value is not None else ""
        if cleaned == "":
            # Covers both a missing value and a whitespace-only string.
            return None
        if cleaned in {"normal", "uniform", "triangular"}:
            return cleaned
        raise ValueError(
            "distribution_type must be normal, uniform, or triangular")

    @field_validator("distribution_parameters")
    @classmethod
    def empty_dict_to_none(
        cls, value: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Return the mapping unchanged, or ``None`` when missing or empty."""
        if value:
            return value
        return None
class ParameterRead(ParameterCreate):
    """Response schema: the creation fields plus the stored ``id``."""

    # from_attributes lets pydantic populate the schema from object
    # attributes rather than requiring a dict.
    model_config = ConfigDict(from_attributes=True)

    id: int
# Dependency
def get_db():
    """Yield a session from ``SessionLocal`` and close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Close the session whether the consumer finished normally or raised.
        session.close()
@router.post("/", response_model=ParameterRead)
def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
    """Create a parameter for an existing scenario and return the stored row.

    When ``param.distribution_id`` is set, the referenced distribution's
    type and parameters replace whatever the payload supplied for those
    two fields.

    Raises:
        HTTPException: 404 when the scenario, or the referenced
            distribution, cannot be found.
    """
    scenario = db.query(Scenario).filter(
        Scenario.id == param.scenario_id).first()
    if not scenario:
        raise HTTPException(status_code=404, detail="Scenario not found")

    # Start from the payload values; a stored distribution overrides them.
    dist_type, dist_params = (
        param.distribution_type,
        param.distribution_parameters,
    )
    if param.distribution_id is not None:
        stored = db.query(Distribution).filter(
            Distribution.id == param.distribution_id).first()
        if not stored:
            raise HTTPException(
                status_code=404, detail="Distribution not found")
        dist_type = stored.distribution_type
        # An empty stored parameter mapping is recorded as None.
        dist_params = stored.parameters or None

    record = Parameter(
        scenario_id=param.scenario_id,
        name=param.name,
        value=param.value,
        distribution_id=param.distribution_id,
        distribution_type=dist_type,
        distribution_parameters=dist_params,
    )
    db.add(record)
    db.commit()
    # Reload the row so database-populated attributes are available.
    db.refresh(record)
    return record
@router.get("/", response_model=List[ParameterRead])
def list_parameters(db: Session = Depends(get_db)):
    """Return every ``Parameter`` row, without filtering."""
    all_parameters = db.query(Parameter).all()
    return all_parameters