"""Routes handling financial calculation workflows.""" from __future__ import annotations from decimal import Decimal from typing import Any from fastapi import APIRouter, Depends, Query, Request, status from fastapi.responses import HTMLResponse, JSONResponse, Response from fastapi.templating import Jinja2Templates from pydantic import ValidationError from starlette.datastructures import FormData from dependencies import get_pricing_metadata, get_unit_of_work, require_authenticated_user from models import ( Project, ProjectProfitability, Scenario, ScenarioProfitability, User, ) from schemas.calculations import ( ProfitabilityCalculationRequest, ProfitabilityCalculationResult, ) from services.calculations import calculate_profitability from services.exceptions import EntityNotFoundError, ProfitabilityValidationError from services.pricing import PricingMetadata from services.unit_of_work import UnitOfWork router = APIRouter(prefix="/calculations", tags=["Calculations"]) templates = Jinja2Templates(directory="templates") _SUPPORTED_METALS: tuple[dict[str, str], ...] = ( {"value": "copper", "label": "Copper"}, {"value": "gold", "label": "Gold"}, {"value": "lithium", "label": "Lithium"}, ) _SUPPORTED_METAL_VALUES = {entry["value"] for entry in _SUPPORTED_METALS} _DEFAULT_EVALUATION_PERIODS = 10 def _combine_impurity_metadata(metadata: PricingMetadata) -> list[dict[str, object]]: """Build impurity rows combining thresholds and penalties.""" thresholds = getattr(metadata, "impurity_thresholds", {}) or {} penalties = getattr(metadata, "impurity_penalty_per_ppm", {}) or {} impurity_codes = sorted({*thresholds.keys(), *penalties.keys()}) combined: list[dict[str, object]] = [] for code in impurity_codes: combined.append( { "name": code, "threshold": float(thresholds.get(code, 0.0)), "penalty": float(penalties.get(code, 0.0)), "value": None, } ) return combined def _value_or_blank(value: Any) -> Any: if value is None: return "" if isinstance(value, Decimal): return float(value) return value def _normalise_impurity_entries(entries: Any) -> list[dict[str, Any]]: if not entries: return [] normalised: list[dict[str, Any]] = [] for entry in entries: if isinstance(entry, dict): getter = entry.get # type: ignore[assignment] else: def getter(key, default=None, _entry=entry): return getattr( _entry, key, default) normalised.append( { "name": getter("name", "") or "", "value": _value_or_blank(getter("value")), "threshold": _value_or_blank(getter("threshold")), "penalty": _value_or_blank(getter("penalty")), } ) return normalised def _build_default_form_data( *, metadata: PricingMetadata, project: Project | None, scenario: Scenario | None, ) -> dict[str, Any]: payable_default = ( float(metadata.default_payable_pct) if getattr(metadata, "default_payable_pct", None) is not None else 100.0 ) moisture_threshold_default = ( float(metadata.moisture_threshold_pct) if getattr(metadata, "moisture_threshold_pct", None) is not None else 0.0 ) moisture_penalty_default = ( float(metadata.moisture_penalty_per_pct) if getattr(metadata, "moisture_penalty_per_pct", None) is not None else 0.0 ) base_metal_entry = next(iter(_SUPPORTED_METALS), None) metal = base_metal_entry["value"] if base_metal_entry else "" scenario_resource = getattr(scenario, "primary_resource", None) if scenario_resource is not None: candidate = getattr(scenario_resource, "value", str(scenario_resource)) if candidate in _SUPPORTED_METAL_VALUES: metal = candidate currency = "" scenario_currency = getattr(scenario, "currency", None) metadata_currency = getattr(metadata, "default_currency", None) if scenario_currency: currency = str(scenario_currency).upper() elif metadata_currency: currency = str(metadata_currency).upper() discount_rate = "" scenario_discount = getattr(scenario, "discount_rate", None) if scenario_discount is not None: discount_rate = float(scenario_discount) # type: ignore[arg-type] return { "metal": metal, "ore_tonnage": "", "head_grade_pct": "", "recovery_pct": "", "payable_pct": payable_default, "reference_price": "", "treatment_charge": "", "smelting_charge": "", "processing_opex": "", "moisture_pct": "", "moisture_threshold_pct": moisture_threshold_default, "moisture_penalty_per_pct": moisture_penalty_default, "premiums": "", "fx_rate": 1.0, "currency_code": currency, "impurities": None, "initial_capex": "", "sustaining_capex": "", "discount_rate": discount_rate, "periods": _DEFAULT_EVALUATION_PERIODS, } def _prepare_form_data_for_display( *, defaults: dict[str, Any], overrides: dict[str, Any] | None = None, allow_empty_override: bool = False, ) -> dict[str, Any]: data = dict(defaults) if overrides: for key, value in overrides.items(): if key == "csrf_token": continue if key == "impurities": data["impurities"] = _normalise_impurity_entries(value) continue if value is None and not allow_empty_override: continue data[key] = _value_or_blank(value) # Normalise defaults and ensure strings for None. for key, value in list(data.items()): if key == "impurities": if value is None: data[key] = None else: data[key] = _normalise_impurity_entries(value) continue data[key] = _value_or_blank(value) return data def _prepare_default_context( request: Request, *, project: Project | None = None, scenario: Scenario | None = None, metadata: PricingMetadata, form_data: dict[str, Any] | None = None, allow_empty_override: bool = False, result: ProfitabilityCalculationResult | None = None, ) -> dict[str, object]: """Assemble template context shared across calculation endpoints.""" defaults = _build_default_form_data( metadata=metadata, project=project, scenario=scenario, ) data = _prepare_form_data_for_display( defaults=defaults, overrides=form_data, allow_empty_override=allow_empty_override, ) return { "request": request, "project": project, "scenario": scenario, "metadata": metadata, "metadata_impurities": _combine_impurity_metadata(metadata), "supported_metals": _SUPPORTED_METALS, "data": data, "result": result, "errors": [], "notices": [], "cancel_url": request.headers.get("Referer"), "form_action": request.url.path, "csrf_token": None, "default_periods": _DEFAULT_EVALUATION_PERIODS, } def _load_project_and_scenario( *, uow: UnitOfWork, project_id: int | None, scenario_id: int | None, ) -> tuple[Project | None, Scenario | None]: project: Project | None = None scenario: Scenario | None = None if project_id is not None and uow.projects is not None: try: project = uow.projects.get(project_id, with_children=False) except EntityNotFoundError: project = None if scenario_id is not None and uow.scenarios is not None: try: scenario = uow.scenarios.get(scenario_id, with_children=False) except EntityNotFoundError: scenario = None if scenario is not None and project is None: project = scenario.project return project, scenario def _is_json_request(request: Request) -> bool: content_type = request.headers.get("content-type", "").lower() accept = request.headers.get("accept", "").lower() return "application/json" in content_type or "application/json" in accept def _normalise_form_value(value: Any) -> Any: if isinstance(value, str): stripped = value.strip() return stripped if stripped != "" else None return value def _form_to_payload(form: FormData) -> dict[str, Any]: data: dict[str, Any] = {} impurities: dict[int, dict[str, Any]] = {} for key, value in form.multi_items(): normalised_value = _normalise_form_value(value) if key.startswith("impurities[") and "]" in key: try: index_part = key.split("[", 1)[1] index_str, remainder = index_part.split("]", 1) field = remainder.strip("[]") if not field: continue index = int(index_str) except (ValueError, IndexError): continue entry = impurities.setdefault(index, {}) entry[field] = normalised_value continue if key == "csrf_token": continue data[key] = normalised_value if impurities: ordered = [] for _, entry in sorted(impurities.items()): if not entry.get("name"): continue ordered.append(entry) if ordered: data["impurities"] = ordered return data async def _extract_payload(request: Request) -> dict[str, Any]: if request.headers.get("content-type", "").lower().startswith("application/json"): return await request.json() form = await request.form() return _form_to_payload(form) def _list_from_context(context: dict[str, Any], key: str) -> list: value = context.get(key) if isinstance(value, list): return value new_list: list = [] context[key] = new_list return new_list def _should_persist_snapshot( *, project: Project | None, scenario: Scenario | None, payload: ProfitabilityCalculationRequest, ) -> bool: """Determine whether to persist the profitability result. Current strategy persists automatically when a scenario or project context is provided. This can be refined later to honour explicit user choices. """ return bool(scenario or project) def _persist_profitability_snapshots( *, uow: UnitOfWork, project: Project | None, scenario: Scenario | None, user: User | None, request_model: ProfitabilityCalculationRequest, result: ProfitabilityCalculationResult, ) -> None: if not _should_persist_snapshot(project=project, scenario=scenario, payload=request_model): return created_by_id = getattr(user, "id", None) revenue_total = float(result.pricing.net_revenue) processing_total = float(result.costs.processing_opex_total) sustaining_total = float(result.costs.sustaining_capex_total) initial_capex = float(result.costs.initial_capex) net_cash_flow_total = revenue_total - ( processing_total + sustaining_total + initial_capex ) npv_value = ( float(result.metrics.npv) if result.metrics.npv is not None else None ) irr_value = ( float(result.metrics.irr) if result.metrics.irr is not None else None ) payback_value = ( float(result.metrics.payback_period) if result.metrics.payback_period is not None else None ) margin_value = ( float(result.metrics.margin) if result.metrics.margin is not None else None ) payload = { "request": request_model.model_dump(mode="json"), "result": result.model_dump(), } if scenario and uow.scenario_profitability: scenario_snapshot = ScenarioProfitability( scenario_id=scenario.id, created_by_id=created_by_id, calculation_source="calculations.profitability", currency_code=result.currency, npv=npv_value, irr_pct=irr_value, payback_period_years=payback_value, margin_pct=margin_value, revenue_total=revenue_total, processing_opex_total=processing_total, sustaining_capex_total=sustaining_total, initial_capex=initial_capex, net_cash_flow_total=net_cash_flow_total, payload=payload, ) uow.scenario_profitability.create(scenario_snapshot) if project and uow.project_profitability: project_snapshot = ProjectProfitability( project_id=project.id, created_by_id=created_by_id, calculation_source="calculations.profitability", currency_code=result.currency, npv=npv_value, irr_pct=irr_value, payback_period_years=payback_value, margin_pct=margin_value, revenue_total=revenue_total, processing_opex_total=processing_total, sustaining_capex_total=sustaining_total, initial_capex=initial_capex, net_cash_flow_total=net_cash_flow_total, payload=payload, ) uow.project_profitability.create(project_snapshot) @router.get( "/profitability", response_class=HTMLResponse, name="calculations.profitability_form", ) def profitability_form( request: Request, _: User = Depends(require_authenticated_user), metadata: PricingMetadata = Depends(get_pricing_metadata), uow: UnitOfWork = Depends(get_unit_of_work), project_id: int | None = Query( None, description="Optional project identifier"), scenario_id: int | None = Query( None, description="Optional scenario identifier"), ) -> HTMLResponse: """Render the profitability calculation form with default metadata.""" project, scenario = _load_project_and_scenario( uow=uow, project_id=project_id, scenario_id=scenario_id ) context = _prepare_default_context( request, project=project, scenario=scenario, metadata=metadata, ) return templates.TemplateResponse("scenarios/profitability.html", context) @router.post( "/profitability", name="calculations.profitability_submit", ) async def profitability_submit( request: Request, current_user: User = Depends(require_authenticated_user), metadata: PricingMetadata = Depends(get_pricing_metadata), uow: UnitOfWork = Depends(get_unit_of_work), project_id: int | None = Query( None, description="Optional project identifier"), scenario_id: int | None = Query( None, description="Optional scenario identifier"), ) -> Response: """Handle profitability calculations and return HTML or JSON.""" wants_json = _is_json_request(request) payload_data = await _extract_payload(request) try: request_model = ProfitabilityCalculationRequest.model_validate( payload_data) result = calculate_profitability(request_model, metadata=metadata) except ValidationError as exc: if wants_json: return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={"errors": exc.errors()}, ) project, scenario = _load_project_and_scenario( uow=uow, project_id=project_id, scenario_id=scenario_id ) context = _prepare_default_context( request, project=project, scenario=scenario, metadata=metadata, form_data=payload_data, allow_empty_override=True, ) errors = _list_from_context(context, "errors") errors.extend( [f"{err['loc']} - {err['msg']}" for err in exc.errors()] ) return templates.TemplateResponse( "scenarios/profitability.html", context, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, ) except ProfitabilityValidationError as exc: if wants_json: return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={ "errors": exc.field_errors or [], "message": exc.message, }, ) project, scenario = _load_project_and_scenario( uow=uow, project_id=project_id, scenario_id=scenario_id ) context = _prepare_default_context( request, project=project, scenario=scenario, metadata=metadata, form_data=payload_data, allow_empty_override=True, ) messages = list(exc.field_errors or []) or [exc.message] errors = _list_from_context(context, "errors") errors.extend(messages) return templates.TemplateResponse( "scenarios/profitability.html", context, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, ) project, scenario = _load_project_and_scenario( uow=uow, project_id=project_id, scenario_id=scenario_id ) _persist_profitability_snapshots( uow=uow, project=project, scenario=scenario, user=current_user, request_model=request_model, result=result, ) if wants_json: return JSONResponse( status_code=status.HTTP_200_OK, content=result.model_dump(), ) context = _prepare_default_context( request, project=project, scenario=scenario, metadata=metadata, form_data=request_model.model_dump(mode="json"), result=result, ) notices = _list_from_context(context, "notices") notices.append("Profitability calculation completed successfully.") return templates.TemplateResponse( "scenarios/profitability.html", context, status_code=status.HTTP_200_OK, )