feat: Enhance currency handling and validation across scenarios
- Updated form template to prefill currency input with default value and added help text for clarity. - Modified integration tests to assert more descriptive error messages for invalid currency codes. - Introduced new tests for currency normalization and validation in various scenarios, including imports and exports. - Added comprehensive tests for pricing calculations, ensuring defaults are respected and overrides function correctly. - Implemented unit tests for pricing settings repository, ensuring CRUD operations and default settings are handled properly. - Enhanced scenario pricing evaluation tests to validate currency handling and metadata defaults. - Added simulation tests to ensure Monte Carlo runs are accurate and handle various distribution scenarios.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Mapping, Sequence
|
||||
|
||||
@@ -11,6 +12,9 @@ from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
from models import (
|
||||
FinancialInput,
|
||||
Project,
|
||||
PricingImpuritySettings,
|
||||
PricingMetalSettings,
|
||||
PricingSettings,
|
||||
ResourceType,
|
||||
Role,
|
||||
Scenario,
|
||||
@@ -21,6 +25,7 @@ from models import (
|
||||
)
|
||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||
from services.export_query import ProjectExportFilters, ScenarioExportFilters
|
||||
from services.pricing import PricingMetadata
|
||||
|
||||
|
||||
class ProjectRepository:
|
||||
@@ -29,10 +34,17 @@ class ProjectRepository:
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self, *, with_children: bool = False) -> Sequence[Project]:
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
with_children: bool = False,
|
||||
with_pricing: bool = False,
|
||||
) -> Sequence[Project]:
|
||||
stmt = select(Project).order_by(Project.created_at)
|
||||
if with_children:
|
||||
stmt = stmt.options(selectinload(Project.scenarios))
|
||||
if with_pricing:
|
||||
stmt = stmt.options(selectinload(Project.pricing_settings))
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def count(self) -> int:
|
||||
@@ -47,10 +59,18 @@ class ProjectRepository:
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, project_id: int, *, with_children: bool = False) -> Project:
|
||||
def get(
|
||||
self,
|
||||
project_id: int,
|
||||
*,
|
||||
with_children: bool = False,
|
||||
with_pricing: bool = False,
|
||||
) -> Project:
|
||||
stmt = select(Project).where(Project.id == project_id)
|
||||
if with_children:
|
||||
stmt = stmt.options(joinedload(Project.scenarios))
|
||||
if with_pricing:
|
||||
stmt = stmt.options(joinedload(Project.pricing_settings))
|
||||
result = self.session.execute(stmt)
|
||||
if with_children:
|
||||
result = result.unique()
|
||||
@@ -86,10 +106,13 @@ class ProjectRepository:
|
||||
filters: ProjectExportFilters | None = None,
|
||||
*,
|
||||
include_scenarios: bool = False,
|
||||
include_pricing: bool = False,
|
||||
) -> Sequence[Project]:
|
||||
stmt = select(Project)
|
||||
if include_scenarios:
|
||||
stmt = stmt.options(selectinload(Project.scenarios))
|
||||
if include_pricing:
|
||||
stmt = stmt.options(selectinload(Project.pricing_settings))
|
||||
|
||||
if filters:
|
||||
ids = filters.normalised_ids()
|
||||
@@ -131,6 +154,18 @@ class ProjectRepository:
|
||||
project = self.get(project_id)
|
||||
self.session.delete(project)
|
||||
|
||||
def set_pricing_settings(
|
||||
self,
|
||||
project: Project,
|
||||
pricing_settings: PricingSettings | None,
|
||||
) -> Project:
|
||||
project.pricing_settings = pricing_settings
|
||||
project.pricing_settings_id = (
|
||||
pricing_settings.id if pricing_settings is not None else None
|
||||
)
|
||||
self.session.flush()
|
||||
return project
|
||||
|
||||
|
||||
class ScenarioRepository:
|
||||
"""Persistence operations for Scenario entities."""
|
||||
@@ -138,13 +173,26 @@ class ScenarioRepository:
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_project(self, project_id: int) -> Sequence[Scenario]:
|
||||
def list_for_project(
|
||||
self,
|
||||
project_id: int,
|
||||
*,
|
||||
with_children: bool = False,
|
||||
) -> Sequence[Scenario]:
|
||||
stmt = (
|
||||
select(Scenario)
|
||||
.where(Scenario.project_id == project_id)
|
||||
.order_by(Scenario.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
if with_children:
|
||||
stmt = stmt.options(
|
||||
selectinload(Scenario.financial_inputs),
|
||||
selectinload(Scenario.simulation_parameters),
|
||||
)
|
||||
result = self.session.execute(stmt)
|
||||
if with_children:
|
||||
result = result.unique()
|
||||
return result.scalars().all()
|
||||
|
||||
def count(self) -> int:
|
||||
stmt = select(func.count(Scenario.id))
|
||||
@@ -376,6 +424,101 @@ class SimulationParameterRepository:
|
||||
self.session.delete(entity)
|
||||
|
||||
|
||||
class PricingSettingsRepository:
|
||||
"""Persistence operations for pricing configuration entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self, *, include_children: bool = False) -> Sequence[PricingSettings]:
|
||||
stmt = select(PricingSettings).order_by(PricingSettings.created_at)
|
||||
if include_children:
|
||||
stmt = stmt.options(
|
||||
selectinload(PricingSettings.metal_overrides),
|
||||
selectinload(PricingSettings.impurity_overrides),
|
||||
)
|
||||
result = self.session.execute(stmt)
|
||||
if include_children:
|
||||
result = result.unique()
|
||||
return result.scalars().all()
|
||||
|
||||
def get(self, settings_id: int, *, include_children: bool = False) -> PricingSettings:
|
||||
stmt = select(PricingSettings).where(PricingSettings.id == settings_id)
|
||||
if include_children:
|
||||
stmt = stmt.options(
|
||||
selectinload(PricingSettings.metal_overrides),
|
||||
selectinload(PricingSettings.impurity_overrides),
|
||||
)
|
||||
result = self.session.execute(stmt)
|
||||
if include_children:
|
||||
result = result.unique()
|
||||
settings = result.scalar_one_or_none()
|
||||
if settings is None:
|
||||
raise EntityNotFoundError(
|
||||
f"Pricing settings {settings_id} not found")
|
||||
return settings
|
||||
|
||||
def find_by_slug(
|
||||
self,
|
||||
slug: str,
|
||||
*,
|
||||
include_children: bool = False,
|
||||
) -> PricingSettings | None:
|
||||
normalised = slug.strip().lower()
|
||||
stmt = select(PricingSettings).where(
|
||||
PricingSettings.slug == normalised)
|
||||
if include_children:
|
||||
stmt = stmt.options(
|
||||
selectinload(PricingSettings.metal_overrides),
|
||||
selectinload(PricingSettings.impurity_overrides),
|
||||
)
|
||||
result = self.session.execute(stmt)
|
||||
if include_children:
|
||||
result = result.unique()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def get_by_slug(self, slug: str, *, include_children: bool = False) -> PricingSettings:
|
||||
settings = self.find_by_slug(slug, include_children=include_children)
|
||||
if settings is None:
|
||||
raise EntityNotFoundError(
|
||||
f"Pricing settings slug '{slug}' not found"
|
||||
)
|
||||
return settings
|
||||
|
||||
def create(self, settings: PricingSettings) -> PricingSettings:
|
||||
self.session.add(settings)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - relies on DB constraints
|
||||
raise EntityConflictError(
|
||||
"Pricing settings violates constraints") from exc
|
||||
return settings
|
||||
|
||||
def delete(self, settings_id: int) -> None:
|
||||
settings = self.get(settings_id, include_children=True)
|
||||
self.session.delete(settings)
|
||||
|
||||
def attach_metal_override(
|
||||
self,
|
||||
settings: PricingSettings,
|
||||
override: PricingMetalSettings,
|
||||
) -> PricingMetalSettings:
|
||||
settings.metal_overrides.append(override)
|
||||
self.session.add(override)
|
||||
self.session.flush()
|
||||
return override
|
||||
|
||||
def attach_impurity_override(
|
||||
self,
|
||||
settings: PricingSettings,
|
||||
override: PricingImpuritySettings,
|
||||
) -> PricingImpuritySettings:
|
||||
settings.impurity_overrides.append(override)
|
||||
self.session.add(override)
|
||||
self.session.flush()
|
||||
return override
|
||||
|
||||
|
||||
class RoleRepository:
|
||||
"""Persistence operations for Role entities."""
|
||||
|
||||
@@ -507,6 +650,159 @@ class UserRepository:
|
||||
self.session.flush()
|
||||
|
||||
|
||||
DEFAULT_PRICING_SETTINGS_NAME = "Default Pricing Settings"
|
||||
DEFAULT_PRICING_SETTINGS_DESCRIPTION = (
|
||||
"Default pricing configuration generated from environment metadata."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PricingSettingsSeedResult:
|
||||
settings: PricingSettings
|
||||
created: bool
|
||||
updated_fields: int
|
||||
impurity_upserts: int
|
||||
|
||||
|
||||
def ensure_default_pricing_settings(
|
||||
repo: PricingSettingsRepository,
|
||||
*,
|
||||
metadata: PricingMetadata,
|
||||
slug: str = "default",
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> PricingSettingsSeedResult:
|
||||
"""Ensure a baseline pricing settings record exists and matches metadata defaults."""
|
||||
|
||||
normalised_slug = (slug or "default").strip().lower() or "default"
|
||||
target_name = name or DEFAULT_PRICING_SETTINGS_NAME
|
||||
target_description = description or DEFAULT_PRICING_SETTINGS_DESCRIPTION
|
||||
|
||||
updated_fields = 0
|
||||
impurity_upserts = 0
|
||||
|
||||
try:
|
||||
settings = repo.get_by_slug(normalised_slug, include_children=True)
|
||||
created = False
|
||||
except EntityNotFoundError:
|
||||
settings = PricingSettings(
|
||||
name=target_name,
|
||||
slug=normalised_slug,
|
||||
description=target_description,
|
||||
default_currency=metadata.default_currency,
|
||||
default_payable_pct=metadata.default_payable_pct,
|
||||
moisture_threshold_pct=metadata.moisture_threshold_pct,
|
||||
moisture_penalty_per_pct=metadata.moisture_penalty_per_pct,
|
||||
)
|
||||
settings.metadata_payload = None
|
||||
settings = repo.create(settings)
|
||||
created = True
|
||||
else:
|
||||
if settings.name != target_name:
|
||||
settings.name = target_name
|
||||
updated_fields += 1
|
||||
if target_description and settings.description != target_description:
|
||||
settings.description = target_description
|
||||
updated_fields += 1
|
||||
if settings.default_currency != metadata.default_currency:
|
||||
settings.default_currency = metadata.default_currency
|
||||
updated_fields += 1
|
||||
if float(settings.default_payable_pct) != float(metadata.default_payable_pct):
|
||||
settings.default_payable_pct = metadata.default_payable_pct
|
||||
updated_fields += 1
|
||||
if float(settings.moisture_threshold_pct) != float(metadata.moisture_threshold_pct):
|
||||
settings.moisture_threshold_pct = metadata.moisture_threshold_pct
|
||||
updated_fields += 1
|
||||
if float(settings.moisture_penalty_per_pct) != float(metadata.moisture_penalty_per_pct):
|
||||
settings.moisture_penalty_per_pct = metadata.moisture_penalty_per_pct
|
||||
updated_fields += 1
|
||||
|
||||
impurity_thresholds = {
|
||||
code.strip().upper(): float(value)
|
||||
for code, value in (metadata.impurity_thresholds or {}).items()
|
||||
if code.strip()
|
||||
}
|
||||
impurity_penalties = {
|
||||
code.strip().upper(): float(value)
|
||||
for code, value in (metadata.impurity_penalty_per_ppm or {}).items()
|
||||
if code.strip()
|
||||
}
|
||||
|
||||
if impurity_thresholds or impurity_penalties:
|
||||
existing_map = {
|
||||
override.impurity_code: override
|
||||
for override in settings.impurity_overrides
|
||||
}
|
||||
target_codes = set(impurity_thresholds) | set(impurity_penalties)
|
||||
for code in sorted(target_codes):
|
||||
threshold_value = impurity_thresholds.get(code, 0.0)
|
||||
penalty_value = impurity_penalties.get(code, 0.0)
|
||||
existing = existing_map.get(code)
|
||||
if existing is None:
|
||||
repo.attach_impurity_override(
|
||||
settings,
|
||||
PricingImpuritySettings(
|
||||
impurity_code=code,
|
||||
threshold_ppm=threshold_value,
|
||||
penalty_per_ppm=penalty_value,
|
||||
),
|
||||
)
|
||||
impurity_upserts += 1
|
||||
continue
|
||||
changed = False
|
||||
if float(existing.threshold_ppm) != float(threshold_value):
|
||||
existing.threshold_ppm = threshold_value
|
||||
changed = True
|
||||
if float(existing.penalty_per_ppm) != float(penalty_value):
|
||||
existing.penalty_per_ppm = penalty_value
|
||||
changed = True
|
||||
if changed:
|
||||
updated_fields += 1
|
||||
|
||||
if updated_fields > 0 or impurity_upserts > 0:
|
||||
repo.session.flush()
|
||||
|
||||
return PricingSettingsSeedResult(
|
||||
settings=settings,
|
||||
created=created,
|
||||
updated_fields=updated_fields,
|
||||
impurity_upserts=impurity_upserts,
|
||||
)
|
||||
|
||||
|
||||
def pricing_settings_to_metadata(settings: PricingSettings) -> PricingMetadata:
|
||||
"""Convert a persisted pricing settings record into metadata defaults."""
|
||||
|
||||
payload = settings.metadata_payload or {}
|
||||
payload_thresholds = payload.get("impurity_thresholds") or {}
|
||||
payload_penalties = payload.get("impurity_penalty_per_ppm") or {}
|
||||
|
||||
thresholds: dict[str, float] = {
|
||||
code.strip().upper(): float(value)
|
||||
for code, value in payload_thresholds.items()
|
||||
if isinstance(code, str) and code.strip()
|
||||
}
|
||||
penalties: dict[str, float] = {
|
||||
code.strip().upper(): float(value)
|
||||
for code, value in payload_penalties.items()
|
||||
if isinstance(code, str) and code.strip()
|
||||
}
|
||||
|
||||
for override in settings.impurity_overrides:
|
||||
code = override.impurity_code.strip().upper()
|
||||
thresholds[code] = float(override.threshold_ppm)
|
||||
penalties[code] = float(override.penalty_per_ppm)
|
||||
|
||||
return PricingMetadata(
|
||||
default_payable_pct=float(settings.default_payable_pct),
|
||||
default_currency=settings.default_currency,
|
||||
moisture_threshold_pct=float(settings.moisture_threshold_pct),
|
||||
moisture_penalty_per_pct=float(settings.moisture_penalty_per_pct),
|
||||
impurity_thresholds=thresholds,
|
||||
impurity_penalty_per_ppm=penalties,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
|
||||
{
|
||||
"name": "admin",
|
||||
|
||||
Reference in New Issue
Block a user