feat: Enhance currency handling and validation across scenarios

- Updated form template to prefill currency input with default value and added help text for clarity.
- Modified integration tests to assert more descriptive error messages for invalid currency codes.
- Introduced new tests for currency normalization and validation in various scenarios, including imports and exports.
- Added comprehensive tests for pricing calculations, ensuring defaults are respected and overrides function correctly.
- Implemented unit tests for pricing settings repository, ensuring CRUD operations and default settings are handled properly.
- Enhanced scenario pricing evaluation tests to validate currency handling and metadata defaults.
- Added simulation tests to ensure Monte Carlo runs are accurate and handle various distribution scenarios.
This commit is contained in:
2025-11-11 18:29:59 +01:00
parent 032e6d2681
commit 795a9f99f4
50 changed files with 5110 additions and 81 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Mapping, Sequence
@@ -11,6 +12,9 @@ from sqlalchemy.orm import Session, joinedload, selectinload
from models import (
FinancialInput,
Project,
PricingImpuritySettings,
PricingMetalSettings,
PricingSettings,
ResourceType,
Role,
Scenario,
@@ -21,6 +25,7 @@ from models import (
)
from services.exceptions import EntityConflictError, EntityNotFoundError
from services.export_query import ProjectExportFilters, ScenarioExportFilters
from services.pricing import PricingMetadata
class ProjectRepository:
@@ -29,10 +34,17 @@ class ProjectRepository:
def __init__(self, session: Session) -> None:
self.session = session
def list(self, *, with_children: bool = False) -> Sequence[Project]:
def list(
self,
*,
with_children: bool = False,
with_pricing: bool = False,
) -> Sequence[Project]:
stmt = select(Project).order_by(Project.created_at)
if with_children:
stmt = stmt.options(selectinload(Project.scenarios))
if with_pricing:
stmt = stmt.options(selectinload(Project.pricing_settings))
return self.session.execute(stmt).scalars().all()
def count(self) -> int:
@@ -47,10 +59,18 @@ class ProjectRepository:
)
return self.session.execute(stmt).scalars().all()
def get(self, project_id: int, *, with_children: bool = False) -> Project:
def get(
self,
project_id: int,
*,
with_children: bool = False,
with_pricing: bool = False,
) -> Project:
stmt = select(Project).where(Project.id == project_id)
if with_children:
stmt = stmt.options(joinedload(Project.scenarios))
if with_pricing:
stmt = stmt.options(joinedload(Project.pricing_settings))
result = self.session.execute(stmt)
if with_children:
result = result.unique()
@@ -86,10 +106,13 @@ class ProjectRepository:
filters: ProjectExportFilters | None = None,
*,
include_scenarios: bool = False,
include_pricing: bool = False,
) -> Sequence[Project]:
stmt = select(Project)
if include_scenarios:
stmt = stmt.options(selectinload(Project.scenarios))
if include_pricing:
stmt = stmt.options(selectinload(Project.pricing_settings))
if filters:
ids = filters.normalised_ids()
@@ -131,6 +154,18 @@ class ProjectRepository:
project = self.get(project_id)
self.session.delete(project)
def set_pricing_settings(
self,
project: Project,
pricing_settings: PricingSettings | None,
) -> Project:
project.pricing_settings = pricing_settings
project.pricing_settings_id = (
pricing_settings.id if pricing_settings is not None else None
)
self.session.flush()
return project
class ScenarioRepository:
"""Persistence operations for Scenario entities."""
@@ -138,13 +173,26 @@ class ScenarioRepository:
def __init__(self, session: Session) -> None:
self.session = session
def list_for_project(self, project_id: int) -> Sequence[Scenario]:
def list_for_project(
self,
project_id: int,
*,
with_children: bool = False,
) -> Sequence[Scenario]:
stmt = (
select(Scenario)
.where(Scenario.project_id == project_id)
.order_by(Scenario.created_at)
)
return self.session.execute(stmt).scalars().all()
if with_children:
stmt = stmt.options(
selectinload(Scenario.financial_inputs),
selectinload(Scenario.simulation_parameters),
)
result = self.session.execute(stmt)
if with_children:
result = result.unique()
return result.scalars().all()
def count(self) -> int:
stmt = select(func.count(Scenario.id))
@@ -376,6 +424,101 @@ class SimulationParameterRepository:
self.session.delete(entity)
class PricingSettingsRepository:
"""Persistence operations for pricing configuration entities."""
def __init__(self, session: Session) -> None:
self.session = session
def list(self, *, include_children: bool = False) -> Sequence[PricingSettings]:
stmt = select(PricingSettings).order_by(PricingSettings.created_at)
if include_children:
stmt = stmt.options(
selectinload(PricingSettings.metal_overrides),
selectinload(PricingSettings.impurity_overrides),
)
result = self.session.execute(stmt)
if include_children:
result = result.unique()
return result.scalars().all()
def get(self, settings_id: int, *, include_children: bool = False) -> PricingSettings:
stmt = select(PricingSettings).where(PricingSettings.id == settings_id)
if include_children:
stmt = stmt.options(
selectinload(PricingSettings.metal_overrides),
selectinload(PricingSettings.impurity_overrides),
)
result = self.session.execute(stmt)
if include_children:
result = result.unique()
settings = result.scalar_one_or_none()
if settings is None:
raise EntityNotFoundError(
f"Pricing settings {settings_id} not found")
return settings
def find_by_slug(
self,
slug: str,
*,
include_children: bool = False,
) -> PricingSettings | None:
normalised = slug.strip().lower()
stmt = select(PricingSettings).where(
PricingSettings.slug == normalised)
if include_children:
stmt = stmt.options(
selectinload(PricingSettings.metal_overrides),
selectinload(PricingSettings.impurity_overrides),
)
result = self.session.execute(stmt)
if include_children:
result = result.unique()
return result.scalar_one_or_none()
def get_by_slug(self, slug: str, *, include_children: bool = False) -> PricingSettings:
settings = self.find_by_slug(slug, include_children=include_children)
if settings is None:
raise EntityNotFoundError(
f"Pricing settings slug '{slug}' not found"
)
return settings
def create(self, settings: PricingSettings) -> PricingSettings:
self.session.add(settings)
try:
self.session.flush()
except IntegrityError as exc: # pragma: no cover - relies on DB constraints
raise EntityConflictError(
"Pricing settings violates constraints") from exc
return settings
def delete(self, settings_id: int) -> None:
settings = self.get(settings_id, include_children=True)
self.session.delete(settings)
def attach_metal_override(
self,
settings: PricingSettings,
override: PricingMetalSettings,
) -> PricingMetalSettings:
settings.metal_overrides.append(override)
self.session.add(override)
self.session.flush()
return override
def attach_impurity_override(
self,
settings: PricingSettings,
override: PricingImpuritySettings,
) -> PricingImpuritySettings:
settings.impurity_overrides.append(override)
self.session.add(override)
self.session.flush()
return override
class RoleRepository:
"""Persistence operations for Role entities."""
@@ -507,6 +650,159 @@ class UserRepository:
self.session.flush()
DEFAULT_PRICING_SETTINGS_NAME = "Default Pricing Settings"
DEFAULT_PRICING_SETTINGS_DESCRIPTION = (
"Default pricing configuration generated from environment metadata."
)
@dataclass(slots=True)
class PricingSettingsSeedResult:
settings: PricingSettings
created: bool
updated_fields: int
impurity_upserts: int
def ensure_default_pricing_settings(
repo: PricingSettingsRepository,
*,
metadata: PricingMetadata,
slug: str = "default",
name: str | None = None,
description: str | None = None,
) -> PricingSettingsSeedResult:
"""Ensure a baseline pricing settings record exists and matches metadata defaults."""
normalised_slug = (slug or "default").strip().lower() or "default"
target_name = name or DEFAULT_PRICING_SETTINGS_NAME
target_description = description or DEFAULT_PRICING_SETTINGS_DESCRIPTION
updated_fields = 0
impurity_upserts = 0
try:
settings = repo.get_by_slug(normalised_slug, include_children=True)
created = False
except EntityNotFoundError:
settings = PricingSettings(
name=target_name,
slug=normalised_slug,
description=target_description,
default_currency=metadata.default_currency,
default_payable_pct=metadata.default_payable_pct,
moisture_threshold_pct=metadata.moisture_threshold_pct,
moisture_penalty_per_pct=metadata.moisture_penalty_per_pct,
)
settings.metadata_payload = None
settings = repo.create(settings)
created = True
else:
if settings.name != target_name:
settings.name = target_name
updated_fields += 1
if target_description and settings.description != target_description:
settings.description = target_description
updated_fields += 1
if settings.default_currency != metadata.default_currency:
settings.default_currency = metadata.default_currency
updated_fields += 1
if float(settings.default_payable_pct) != float(metadata.default_payable_pct):
settings.default_payable_pct = metadata.default_payable_pct
updated_fields += 1
if float(settings.moisture_threshold_pct) != float(metadata.moisture_threshold_pct):
settings.moisture_threshold_pct = metadata.moisture_threshold_pct
updated_fields += 1
if float(settings.moisture_penalty_per_pct) != float(metadata.moisture_penalty_per_pct):
settings.moisture_penalty_per_pct = metadata.moisture_penalty_per_pct
updated_fields += 1
impurity_thresholds = {
code.strip().upper(): float(value)
for code, value in (metadata.impurity_thresholds or {}).items()
if code.strip()
}
impurity_penalties = {
code.strip().upper(): float(value)
for code, value in (metadata.impurity_penalty_per_ppm or {}).items()
if code.strip()
}
if impurity_thresholds or impurity_penalties:
existing_map = {
override.impurity_code: override
for override in settings.impurity_overrides
}
target_codes = set(impurity_thresholds) | set(impurity_penalties)
for code in sorted(target_codes):
threshold_value = impurity_thresholds.get(code, 0.0)
penalty_value = impurity_penalties.get(code, 0.0)
existing = existing_map.get(code)
if existing is None:
repo.attach_impurity_override(
settings,
PricingImpuritySettings(
impurity_code=code,
threshold_ppm=threshold_value,
penalty_per_ppm=penalty_value,
),
)
impurity_upserts += 1
continue
changed = False
if float(existing.threshold_ppm) != float(threshold_value):
existing.threshold_ppm = threshold_value
changed = True
if float(existing.penalty_per_ppm) != float(penalty_value):
existing.penalty_per_ppm = penalty_value
changed = True
if changed:
updated_fields += 1
if updated_fields > 0 or impurity_upserts > 0:
repo.session.flush()
return PricingSettingsSeedResult(
settings=settings,
created=created,
updated_fields=updated_fields,
impurity_upserts=impurity_upserts,
)
def pricing_settings_to_metadata(settings: PricingSettings) -> PricingMetadata:
"""Convert a persisted pricing settings record into metadata defaults."""
payload = settings.metadata_payload or {}
payload_thresholds = payload.get("impurity_thresholds") or {}
payload_penalties = payload.get("impurity_penalty_per_ppm") or {}
thresholds: dict[str, float] = {
code.strip().upper(): float(value)
for code, value in payload_thresholds.items()
if isinstance(code, str) and code.strip()
}
penalties: dict[str, float] = {
code.strip().upper(): float(value)
for code, value in payload_penalties.items()
if isinstance(code, str) and code.strip()
}
for override in settings.impurity_overrides:
code = override.impurity_code.strip().upper()
thresholds[code] = float(override.threshold_ppm)
penalties[code] = float(override.penalty_per_ppm)
return PricingMetadata(
default_payable_pct=float(settings.default_payable_pct),
default_currency=settings.default_currency,
moisture_threshold_pct=float(settings.moisture_threshold_pct),
moisture_penalty_per_pct=float(settings.moisture_penalty_per_pct),
impurity_thresholds=thresholds,
impurity_penalty_per_ppm=penalties,
)
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
{
"name": "admin",