feat: Add profitability calculation schemas and service functions

- Introduced Pydantic schemas for profitability calculations in `schemas/calculations.py`.
- Implemented service functions for profitability calculations in `services/calculations.py`.
- Added new exception class `ProfitabilityValidationError` for handling validation errors.
- Created repositories for managing project and scenario profitability snapshots.
- Developed a utility script for verifying authenticated routes.
- Added a new HTML template for the profitability calculator interface.
- Implemented a script to fix user ID sequence in the database.
This commit is contained in:
2025-11-12 22:22:29 +01:00
parent 6d496a599e
commit b1a6df9f90
15 changed files with 1654 additions and 0 deletions

View File

@@ -1,10 +1,12 @@
"""Service layer utilities."""
from .pricing import calculate_pricing, PricingInput, PricingMetadata, PricingResult
from .calculations import calculate_profitability
__all__ = [
"calculate_pricing",
"PricingInput",
"PricingMetadata",
"PricingResult",
"calculate_profitability",
]

205
services/calculations.py Normal file
View File

@@ -0,0 +1,205 @@
"""Service functions for financial calculations."""
from __future__ import annotations
from services.currency import CurrencyValidationError, normalise_currency
from services.exceptions import ProfitabilityValidationError
from services.financial import (
CashFlow,
ConvergenceError,
PaybackNotReachedError,
internal_rate_of_return,
net_present_value,
payback_period,
)
from services.pricing import PricingInput, PricingMetadata, PricingResult, calculate_pricing
from schemas.calculations import (
CashFlowEntry,
ProfitabilityCalculationRequest,
ProfitabilityCalculationResult,
ProfitabilityCosts,
ProfitabilityMetrics,
)
def _build_pricing_input(
request: ProfitabilityCalculationRequest,
) -> PricingInput:
"""Construct a pricing input instance including impurity overrides."""
impurity_values: dict[str, float] = {}
impurity_thresholds: dict[str, float] = {}
impurity_penalties: dict[str, float] = {}
for impurity in request.impurities:
code = impurity.name.strip()
if not code:
continue
code = code.upper()
if impurity.value is not None:
impurity_values[code] = float(impurity.value)
if impurity.threshold is not None:
impurity_thresholds[code] = float(impurity.threshold)
if impurity.penalty is not None:
impurity_penalties[code] = float(impurity.penalty)
pricing_input = PricingInput(
metal=request.metal,
ore_tonnage=request.ore_tonnage,
head_grade_pct=request.head_grade_pct,
recovery_pct=request.recovery_pct,
payable_pct=request.payable_pct,
reference_price=request.reference_price,
treatment_charge=request.treatment_charge,
smelting_charge=request.smelting_charge,
moisture_pct=request.moisture_pct,
moisture_threshold_pct=request.moisture_threshold_pct,
moisture_penalty_per_pct=request.moisture_penalty_per_pct,
impurity_ppm=impurity_values,
impurity_thresholds=impurity_thresholds,
impurity_penalty_per_ppm=impurity_penalties,
premiums=request.premiums,
fx_rate=request.fx_rate,
currency_code=request.currency_code,
)
return pricing_input
def _generate_cash_flows(
*,
periods: int,
net_per_period: float,
initial_capex: float,
) -> tuple[list[CashFlow], list[CashFlowEntry]]:
"""Create cash flow structures for financial metric calculations."""
cash_flow_models: list[CashFlow] = [
CashFlow(amount=-initial_capex, period_index=0)
]
cash_flow_entries: list[CashFlowEntry] = [
CashFlowEntry(
period=0,
revenue=0.0,
processing_opex=0.0,
sustaining_capex=0.0,
net=-initial_capex,
)
]
for period in range(1, periods + 1):
cash_flow_models.append(
CashFlow(amount=net_per_period, period_index=period))
cash_flow_entries.append(
CashFlowEntry(
period=period,
revenue=0.0,
processing_opex=0.0,
sustaining_capex=0.0,
net=net_per_period,
)
)
return cash_flow_models, cash_flow_entries
def calculate_profitability(
request: ProfitabilityCalculationRequest,
*,
metadata: PricingMetadata,
) -> ProfitabilityCalculationResult:
"""Calculate profitability metrics using pricing inputs and cost data."""
if request.periods <= 0:
raise ProfitabilityValidationError(
"Evaluation periods must be at least 1.", ["periods"]
)
pricing_input = _build_pricing_input(request)
try:
pricing_result: PricingResult = calculate_pricing(
pricing_input, metadata=metadata
)
except CurrencyValidationError as exc:
raise ProfitabilityValidationError(
str(exc), ["currency_code"]) from exc
periods = request.periods
revenue_total = float(pricing_result.net_revenue)
revenue_per_period = revenue_total / periods
processing_total = float(request.processing_opex) * periods
sustaining_total = float(request.sustaining_capex) * periods
initial_capex = float(request.initial_capex)
net_per_period = (
revenue_per_period
- float(request.processing_opex)
- float(request.sustaining_capex)
)
cash_flow_models, cash_flow_entries = _generate_cash_flows(
periods=periods,
net_per_period=net_per_period,
initial_capex=initial_capex,
)
# Update per-period entries to include explicit costs for presentation
for entry in cash_flow_entries[1:]:
entry.revenue = revenue_per_period
entry.processing_opex = float(request.processing_opex)
entry.sustaining_capex = float(request.sustaining_capex)
entry.net = net_per_period
discount_rate = (request.discount_rate or 0.0) / 100.0
npv_value = net_present_value(discount_rate, cash_flow_models)
try:
irr_value = internal_rate_of_return(cash_flow_models) * 100.0
except (ValueError, ZeroDivisionError, ConvergenceError):
irr_value = None
try:
payback_value = payback_period(cash_flow_models)
except (ValueError, PaybackNotReachedError):
payback_value = None
total_costs = processing_total + sustaining_total + initial_capex
total_net = revenue_total - total_costs
if revenue_total == 0:
margin_value = None
else:
margin_value = (total_net / revenue_total) * 100.0
currency = request.currency_code or pricing_result.currency
try:
currency = normalise_currency(currency)
except CurrencyValidationError as exc:
raise ProfitabilityValidationError(
str(exc), ["currency_code"]) from exc
costs = ProfitabilityCosts(
processing_opex_total=processing_total,
sustaining_capex_total=sustaining_total,
initial_capex=initial_capex,
)
metrics = ProfitabilityMetrics(
npv=npv_value,
irr=irr_value,
payback_period=payback_value,
margin=margin_value,
)
return ProfitabilityCalculationResult(
pricing=pricing_result,
costs=costs,
metrics=metrics,
cash_flows=cash_flow_entries,
currency=currency,
)
__all__ = ["calculate_profitability"]

View File

@@ -26,3 +26,14 @@ class ScenarioValidationError(Exception):
def __str__(self) -> str: # pragma: no cover - mirrors message for logging
return self.message
@dataclass(eq=False)
class ProfitabilityValidationError(Exception):
"""Raised when profitability calculation inputs fail domain validation."""
message: str
field_errors: Sequence[str] | None = None
def __str__(self) -> str: # pragma: no cover - mirrors message for logging
return self.message

View File

@@ -15,8 +15,10 @@ from models import (
PricingImpuritySettings,
PricingMetalSettings,
PricingSettings,
ProjectProfitability,
Role,
Scenario,
ScenarioProfitability,
ScenarioStatus,
SimulationParameter,
User,
@@ -367,6 +369,106 @@ class ScenarioRepository:
self.session.delete(scenario)
class ProjectProfitabilityRepository:
"""Persistence operations for project-level profitability snapshots."""
def __init__(self, session: Session) -> None:
self.session = session
def create(self, snapshot: ProjectProfitability) -> ProjectProfitability:
self.session.add(snapshot)
self.session.flush()
return snapshot
def list_for_project(
self,
project_id: int,
*,
limit: int | None = None,
) -> Sequence[ProjectProfitability]:
stmt = (
select(ProjectProfitability)
.where(ProjectProfitability.project_id == project_id)
.order_by(ProjectProfitability.calculated_at.desc())
)
if limit is not None:
stmt = stmt.limit(limit)
return self.session.execute(stmt).scalars().all()
def latest_for_project(
self,
project_id: int,
) -> ProjectProfitability | None:
stmt = (
select(ProjectProfitability)
.where(ProjectProfitability.project_id == project_id)
.order_by(ProjectProfitability.calculated_at.desc())
.limit(1)
)
return self.session.execute(stmt).scalar_one_or_none()
def delete(self, snapshot_id: int) -> None:
stmt = select(ProjectProfitability).where(
ProjectProfitability.id == snapshot_id
)
entity = self.session.execute(stmt).scalar_one_or_none()
if entity is None:
raise EntityNotFoundError(
f"Project profitability snapshot {snapshot_id} not found"
)
self.session.delete(entity)
class ScenarioProfitabilityRepository:
"""Persistence operations for scenario-level profitability snapshots."""
def __init__(self, session: Session) -> None:
self.session = session
def create(self, snapshot: ScenarioProfitability) -> ScenarioProfitability:
self.session.add(snapshot)
self.session.flush()
return snapshot
def list_for_scenario(
self,
scenario_id: int,
*,
limit: int | None = None,
) -> Sequence[ScenarioProfitability]:
stmt = (
select(ScenarioProfitability)
.where(ScenarioProfitability.scenario_id == scenario_id)
.order_by(ScenarioProfitability.calculated_at.desc())
)
if limit is not None:
stmt = stmt.limit(limit)
return self.session.execute(stmt).scalars().all()
def latest_for_scenario(
self,
scenario_id: int,
) -> ScenarioProfitability | None:
stmt = (
select(ScenarioProfitability)
.where(ScenarioProfitability.scenario_id == scenario_id)
.order_by(ScenarioProfitability.calculated_at.desc())
.limit(1)
)
return self.session.execute(stmt).scalar_one_or_none()
def delete(self, snapshot_id: int) -> None:
stmt = select(ScenarioProfitability).where(
ScenarioProfitability.id == snapshot_id
)
entity = self.session.execute(stmt).scalar_one_or_none()
if entity is None:
raise EntityNotFoundError(
f"Scenario profitability snapshot {snapshot_id} not found"
)
self.session.delete(entity)
class FinancialInputRepository:
"""Persistence operations for FinancialInput entities."""

View File

@@ -13,8 +13,10 @@ from services.repositories import (
PricingSettingsRepository,
PricingSettingsSeedResult,
ProjectRepository,
ProjectProfitabilityRepository,
RoleRepository,
ScenarioRepository,
ScenarioProfitabilityRepository,
SimulationParameterRepository,
UserRepository,
ensure_admin_user as ensure_admin_user_record,
@@ -36,6 +38,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self.scenarios: ScenarioRepository | None = None
self.financial_inputs: FinancialInputRepository | None = None
self.simulation_parameters: SimulationParameterRepository | None = None
self.project_profitability: ProjectProfitabilityRepository | None = None
self.scenario_profitability: ScenarioProfitabilityRepository | None = None
self.users: UserRepository | None = None
self.roles: RoleRepository | None = None
self.pricing_settings: PricingSettingsRepository | None = None
@@ -47,6 +51,11 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self.financial_inputs = FinancialInputRepository(self.session)
self.simulation_parameters = SimulationParameterRepository(
self.session)
self.project_profitability = ProjectProfitabilityRepository(
self.session)
self.scenario_profitability = ScenarioProfitabilityRepository(
self.session
)
self.users = UserRepository(self.session)
self.roles = RoleRepository(self.session)
self.pricing_settings = PricingSettingsRepository(self.session)
@@ -65,6 +74,8 @@ class UnitOfWork(AbstractContextManager["UnitOfWork"]):
self.scenarios = None
self.financial_inputs = None
self.simulation_parameters = None
self.project_profitability = None
self.scenario_profitability = None
self.users = None
self.roles = None
self.pricing_settings = None