- Introduced a new template for listing scenarios associated with a project.
- Added metrics for total, active, draft, and archived scenarios.
- Implemented quick actions for creating new scenarios and reviewing the project overview.
- Enhanced navigation with breadcrumbs for a better user experience.

refactor: update Opex and Profitability templates for consistency
- Changed titles and button labels for clarity in the Opex and Profitability templates.
- Updated form IDs and action URLs to align with the new naming conventions.
- Improved navigation links to include scenario and project overviews.

test: add integration tests for Opex calculations
- Created new tests for the Opex calculation HTML and JSON flows.
- Validated successful calculations and ensured correct data persistence.
- Implemented tests for currency-mismatch and unsupported-frequency scenarios.

test: enhance project and scenario route tests
- Added tests to verify scenario list rendering and calculator shortcuts.
- Ensured scenario detail pages link back to the portfolio correctly.
- Validated that project detail pages show associated scenarios accurately.
1269 lines
43 KiB
Python
1269 lines
43 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Mapping, Sequence
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session, joinedload, selectinload
|
|
|
|
from models import (
|
|
FinancialInput,
|
|
Project,
|
|
PricingImpuritySettings,
|
|
PricingMetalSettings,
|
|
PricingSettings,
|
|
ProjectCapexSnapshot,
|
|
ProjectProfitability,
|
|
ProjectOpexSnapshot,
|
|
NavigationGroup,
|
|
NavigationLink,
|
|
Role,
|
|
Scenario,
|
|
ScenarioCapexSnapshot,
|
|
ScenarioProfitability,
|
|
ScenarioOpexSnapshot,
|
|
ScenarioStatus,
|
|
SimulationParameter,
|
|
User,
|
|
UserRole,
|
|
)
|
|
from services.exceptions import EntityConflictError, EntityNotFoundError
|
|
from services.export_query import ProjectExportFilters, ScenarioExportFilters
|
|
from services.pricing import PricingMetadata
|
|
|
|
|
|
def _enum_value(e):
|
|
"""Return the underlying value for Enum members, otherwise return as-is."""
|
|
return getattr(e, "value", e)
|
|
|
|
|
|
class NavigationRepository:
    """Data-access helpers for navigation groups and their links."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def list_groups_with_links(
        self,
        *,
        include_disabled: bool = False,
    ) -> Sequence[NavigationGroup]:
        """Return navigation groups ordered by ``sort_order`` then ``id``.

        Each group's ``links`` and those links' ``children`` are eagerly
        loaded with ``selectinload``. Groups whose ``is_enabled`` flag is not
        true are left out unless ``include_disabled`` is set.
        """
        link_loader = selectinload(NavigationGroup.links).selectinload(
            NavigationLink.children
        )
        query = select(NavigationGroup).options(link_loader)
        if not include_disabled:
            query = query.where(NavigationGroup.is_enabled.is_(True))
        query = query.order_by(NavigationGroup.sort_order, NavigationGroup.id)
        return self.session.execute(query).scalars().all()

    def get_group_by_slug(self, slug: str) -> NavigationGroup | None:
        """Return the group whose ``slug`` equals ``slug``, or ``None``."""
        match_slug = NavigationGroup.slug == slug
        found = self.session.execute(select(NavigationGroup).where(match_slug))
        return found.scalar_one_or_none()

    def get_link_by_slug(
        self,
        slug: str,
        *,
        group_id: int | None = None,
    ) -> NavigationLink | None:
        """Return the link with ``slug``, optionally limited to ``group_id``.

        Returns ``None`` when nothing matches.

        NOTE(review): when ``group_id`` is omitted this assumes slugs are
        unique across all groups; if they are only unique per group,
        ``scalar_one_or_none`` raises ``MultipleResultsFound`` instead of
        returning a value. Confirm against the model constraints.
        """
        criteria = [NavigationLink.slug == slug]
        if group_id is not None:
            criteria.append(NavigationLink.group_id == group_id)
        query = select(NavigationLink).where(*criteria)
        return self.session.execute(query).scalar_one_or_none()

    def add_group(self, group: NavigationGroup) -> NavigationGroup:
        """Stage ``group`` in the session, flush, and hand it back."""
        session = self.session
        session.add(group)
        session.flush()
        return group

    def add_link(self, link: NavigationLink) -> NavigationLink:
        """Stage ``link`` in the session, flush, and hand it back."""
        session = self.session
        session.add(link)
        session.flush()
        return link
|
|
|
|
class ProjectRepository:
    """Data-access helpers for :class:`Project` rows."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def list(
        self,
        *,
        with_children: bool = False,
        with_pricing: bool = False,
    ) -> Sequence[Project]:
        """Return every project, oldest ``created_at`` first.

        ``with_children`` eagerly loads ``scenarios`` and ``with_pricing``
        eagerly loads ``pricing_settings``, both through ``selectinload``.
        """
        loaders = []
        if with_children:
            loaders.append(selectinload(Project.scenarios))
        if with_pricing:
            loaders.append(selectinload(Project.pricing_settings))

        query = select(Project).order_by(Project.created_at)
        if loaders:
            query = query.options(*loaders)
        return self.session.execute(query).scalars().all()

    def count(self) -> int:
        """Return the total number of projects."""
        tally = self.session.execute(select(func.count(Project.id)))
        return tally.scalar_one()

    def recent(self, limit: int = 5) -> Sequence[Project]:
        """Return up to ``limit`` projects, most recently updated first."""
        newest_first = Project.updated_at.desc()
        query = select(Project).order_by(newest_first).limit(limit)
        return self.session.execute(query).scalars().all()

    def get(
        self,
        project_id: int,
        *,
        with_children: bool = False,
        with_pricing: bool = False,
    ) -> Project:
        """Fetch one project by primary key.

        ``with_children`` joins in ``scenarios`` and ``with_pricing`` joins
        in ``pricing_settings`` (``joinedload``). When the scenarios
        relationship is joined (presumably a one-to-many collection), the
        result is de-duplicated with ``unique()``.

        Raises:
            EntityNotFoundError: if no project has ``project_id``.
        """
        query = select(Project).where(Project.id == project_id)
        if with_children:
            query = query.options(joinedload(Project.scenarios))
        if with_pricing:
            query = query.options(joinedload(Project.pricing_settings))

        rows = self.session.execute(query)
        project = (rows.unique() if with_children else rows).scalar_one_or_none()
        if project is not None:
            return project
        raise EntityNotFoundError(f"Project {project_id} not found")

    def exists(self, project_id: int) -> bool:
        """Return whether a project with ``project_id`` is present."""
        probe = select(Project.id).where(Project.id == project_id).limit(1)
        hit = self.session.execute(probe).scalar_one_or_none()
        return hit is not None

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it.

        The outcome is reported through ``observe_project_operation`` as
        ``("create", "success")`` or ``("create", "error")``.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """

        def record(outcome: str) -> None:
            # The metrics module is imported lazily, only once the outcome
            # of the flush is known.
            from monitoring.metrics import observe_project_operation

            observe_project_operation("create", outcome)

        self.session.add(project)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            record("error")
            raise EntityConflictError(
                "Project violates uniqueness constraints") from exc
        record("success")
        return project

    def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]:
        """Look up projects by case-insensitive name.

        Empty or whitespace-only entries are ignored. The remaining entries
        are stripped and lower-cased, then matched against
        ``lower(Project.name)``. The result is keyed by each stored name,
        lower-cased.
        """
        wanted = set()
        for raw in names:
            if raw and raw.strip():
                wanted.add(raw.strip().lower())
        if not wanted:
            return {}

        query = select(Project).where(func.lower(Project.name).in_(wanted))
        found = self.session.execute(query).scalars().all()
        return {project.name.lower(): project for project in found}

    def filtered_for_export(
        self,
        filters: ProjectExportFilters | None = None,
        *,
        include_scenarios: bool = False,
        include_pricing: bool = False,
    ) -> Sequence[Project]:
        """Return projects matching ``filters``, ordered by name then id.

        Each populated filter field contributes one condition, and all
        conditions are ANDed together. A falsy ``filters`` applies no
        conditions. ``include_scenarios`` and ``include_pricing`` eagerly
        load the corresponding relationships via ``selectinload``.
        """
        query = select(Project)
        if include_scenarios:
            query = query.options(selectinload(Project.scenarios))
        if include_pricing:
            query = query.options(selectinload(Project.pricing_settings))

        if filters:
            conditions = []

            ids = filters.normalised_ids()
            if ids:
                conditions.append(Project.id.in_(ids))

            name_matches = filters.normalised_names()
            if name_matches:
                conditions.append(func.lower(Project.name).in_(name_matches))

            name_pattern = filters.name_search_pattern()
            if name_pattern:
                conditions.append(Project.name.ilike(name_pattern))

            locations = filters.normalised_locations()
            if locations:
                conditions.append(func.lower(Project.location).in_(locations))

            if filters.operation_types:
                conditions.append(
                    Project.operation_type.in_(filters.operation_types))

            # Timestamp bounds are inclusive on both ends.
            if filters.created_from:
                conditions.append(Project.created_at >= filters.created_from)
            if filters.created_to:
                conditions.append(Project.created_at <= filters.created_to)
            if filters.updated_from:
                conditions.append(Project.updated_at >= filters.updated_from)
            if filters.updated_to:
                conditions.append(Project.updated_at <= filters.updated_to)

            if conditions:
                query = query.where(*conditions)

        query = query.order_by(Project.name, Project.id)
        return self.session.execute(query).scalars().all()

    def delete(self, project_id: int) -> None:
        """Mark the project with ``project_id`` for deletion.

        Raises:
            EntityNotFoundError: propagated from :meth:`get` when missing.
        """
        self.session.delete(self.get(project_id))

    def set_pricing_settings(
        self,
        project: Project,
        pricing_settings: PricingSettings | None,
    ) -> Project:
        """Attach ``pricing_settings`` to ``project``, or clear it, then flush.

        Both the relationship and the ``pricing_settings_id`` column are
        assigned, so they agree before the flush.
        """
        project.pricing_settings = pricing_settings
        project.pricing_settings_id = (
            None if pricing_settings is None else pricing_settings.id
        )
        self.session.flush()
        return project
|
|
|
|
|
|
class ScenarioRepository:
    """Data-access helpers for :class:`Scenario` rows."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def list_for_project(
        self,
        project_id: int,
        *,
        with_children: bool = False,
    ) -> Sequence[Scenario]:
        """Return a project's scenarios, oldest ``created_at`` first.

        ``with_children`` eagerly loads ``financial_inputs`` and
        ``simulation_parameters`` via ``selectinload``.
        """
        query = (
            select(Scenario)
            .where(Scenario.project_id == project_id)
            .order_by(Scenario.created_at)
        )
        if not with_children:
            return self.session.execute(query).scalars().all()

        query = query.options(
            selectinload(Scenario.financial_inputs),
            selectinload(Scenario.simulation_parameters),
        )
        # unique() is applied whenever children are requested. selectinload
        # loads the collections in separate SELECTs, so it has no effect
        # on the result here.
        return self.session.execute(query).unique().scalars().all()

    def count(self) -> int:
        """Return the total number of scenarios."""
        tally = self.session.execute(select(func.count(Scenario.id)))
        return tally.scalar_one()

    def count_by_status(self, status: ScenarioStatus) -> int:
        """Return how many scenarios have ``status``.

        Enum members are unwrapped to ``.value`` before the comparison, so a
        raw value is accepted as well.
        """
        wanted = _enum_value(status)
        query = select(func.count(Scenario.id)).where(Scenario.status == wanted)
        return self.session.execute(query).scalar_one()

    def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]:
        """Return up to ``limit`` scenarios, most recently updated first.

        ``with_project`` joins in each scenario's ``project``.
        """
        query = select(Scenario)
        if with_project:
            query = query.options(joinedload(Scenario.project))
        query = query.order_by(Scenario.updated_at.desc()).limit(limit)
        return self.session.execute(query).scalars().all()

    def list_by_status(
        self,
        status: ScenarioStatus,
        *,
        limit: int | None = None,
        with_project: bool = False,
    ) -> Sequence[Scenario]:
        """Return scenarios with ``status``, most recently updated first.

        ``status`` may be an Enum member or a raw value. ``limit=None``
        means no limit. ``with_project`` joins in each scenario's
        ``project``.
        """
        query = select(Scenario).where(Scenario.status == _enum_value(status))
        query = query.order_by(Scenario.updated_at.desc())
        if with_project:
            query = query.options(joinedload(Scenario.project))
        if limit is not None:
            query = query.limit(limit)
        return self.session.execute(query).scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Fetch one scenario by primary key.

        ``with_children`` joins in ``financial_inputs`` and
        ``simulation_parameters``. The joined rows are then de-duplicated
        with ``unique()``.

        Raises:
            EntityNotFoundError: if no scenario has ``scenario_id``.
        """
        query = select(Scenario).where(Scenario.id == scenario_id)
        if with_children:
            query = query.options(
                joinedload(Scenario.financial_inputs),
                joinedload(Scenario.simulation_parameters),
            )
            scenario = self.session.execute(query).unique().scalar_one_or_none()
        else:
            scenario = self.session.execute(query).scalar_one_or_none()

        if scenario is None:
            raise EntityNotFoundError(f"Scenario {scenario_id} not found")
        return scenario

    def exists(self, scenario_id: int) -> bool:
        """Return whether a scenario with ``scenario_id`` is present."""
        probe = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
        hit = self.session.execute(probe).scalar_one_or_none()
        return hit is not None

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session and flush it.

        The outcome is reported through ``observe_scenario_operation`` as
        ``("create", "success")`` or ``("create", "error")``.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """

        def record(outcome: str) -> None:
            # The metrics module is imported lazily, only once the outcome
            # of the flush is known.
            from monitoring.metrics import observe_scenario_operation

            observe_scenario_operation("create", outcome)

        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            record("error")
            raise EntityConflictError("Scenario violates constraints") from exc
        record("success")
        return scenario

    def find_by_project_and_names(
        self,
        project_id: int,
        names: Iterable[str],
    ) -> Mapping[str, Scenario]:
        """Look up a project's scenarios by case-insensitive name.

        Empty or whitespace-only entries are ignored. The remaining entries
        are stripped and lower-cased before matching. The result is keyed by
        each stored name, lower-cased.
        """
        wanted = set()
        for raw in names:
            if raw and raw.strip():
                wanted.add(raw.strip().lower())
        if not wanted:
            return {}

        query = (
            select(Scenario)
            .where(Scenario.project_id == project_id)
            .where(func.lower(Scenario.name).in_(wanted))
        )
        found = self.session.execute(query).scalars().all()
        return {scenario.name.lower(): scenario for scenario in found}

    def filtered_for_export(
        self,
        filters: ScenarioExportFilters | None = None,
        *,
        include_project: bool = True,
    ) -> Sequence[Scenario]:
        """Return scenarios matching ``filters``, ordered by name then id.

        Each populated filter field contributes one condition, and all
        conditions are ANDed together. A falsy ``filters`` applies no
        conditions. ``include_project`` (on by default) joins in each
        scenario's ``project``.
        """
        query = select(Scenario)
        if include_project:
            query = query.options(joinedload(Scenario.project))

        if filters:
            conditions = []

            scenario_ids = filters.normalised_ids()
            if scenario_ids:
                conditions.append(Scenario.id.in_(scenario_ids))

            project_ids = filters.normalised_project_ids()
            if project_ids:
                conditions.append(Scenario.project_id.in_(project_ids))

            project_names = filters.normalised_project_names()
            if project_names:
                # Match by parent project name through an id subquery.
                matching_projects = select(Project.id).where(
                    func.lower(Project.name).in_(project_names)
                )
                conditions.append(Scenario.project_id.in_(matching_projects))

            name_pattern = filters.name_search_pattern()
            if name_pattern:
                conditions.append(Scenario.name.ilike(name_pattern))

            if filters.statuses:
                # Enum members and raw values are both accepted.
                raw_statuses = [_enum_value(item) for item in filters.statuses]
                conditions.append(Scenario.status.in_(raw_statuses))

            # Date and timestamp bounds are inclusive on both ends.
            if filters.start_date_from:
                conditions.append(Scenario.start_date >= filters.start_date_from)
            if filters.start_date_to:
                conditions.append(Scenario.start_date <= filters.start_date_to)
            if filters.end_date_from:
                conditions.append(Scenario.end_date >= filters.end_date_from)
            if filters.end_date_to:
                conditions.append(Scenario.end_date <= filters.end_date_to)
            if filters.created_from:
                conditions.append(Scenario.created_at >= filters.created_from)
            if filters.created_to:
                conditions.append(Scenario.created_at <= filters.created_to)
            if filters.updated_from:
                conditions.append(Scenario.updated_at >= filters.updated_from)
            if filters.updated_to:
                conditions.append(Scenario.updated_at <= filters.updated_to)

            currencies = filters.normalised_currencies()
            if currencies:
                conditions.append(func.upper(Scenario.currency).in_(currencies))

            if filters.primary_resources:
                conditions.append(
                    Scenario.primary_resource.in_(filters.primary_resources))

            if conditions:
                query = query.where(*conditions)

        query = query.order_by(Scenario.name, Scenario.id)
        return self.session.execute(query).scalars().all()

    def delete(self, scenario_id: int) -> None:
        """Mark the scenario with ``scenario_id`` for deletion.

        Raises:
            EntityNotFoundError: propagated from :meth:`get` when missing.
        """
        self.session.delete(self.get(scenario_id))
|
|
|
|
|
|
class ProjectProfitabilityRepository:
    """Data-access helpers for project-level profitability snapshots."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def create(self, snapshot: ProjectProfitability) -> ProjectProfitability:
        """Add ``snapshot`` to the session, flush, and return it."""
        session = self.session
        session.add(snapshot)
        session.flush()
        return snapshot

    @staticmethod
    def _newest_first_query(project_id: int):
        """Build a query for a project's snapshots, latest ``calculated_at`` first."""
        return (
            select(ProjectProfitability)
            .where(ProjectProfitability.project_id == project_id)
            .order_by(ProjectProfitability.calculated_at.desc())
        )

    def list_for_project(
        self,
        project_id: int,
        *,
        limit: int | None = None,
    ) -> Sequence[ProjectProfitability]:
        """Return a project's snapshots, newest first (``limit=None`` = all)."""
        query = self._newest_first_query(project_id)
        if limit is not None:
            query = query.limit(limit)
        return self.session.execute(query).scalars().all()

    def latest_for_project(
        self,
        project_id: int,
    ) -> ProjectProfitability | None:
        """Return the most recently calculated snapshot, or ``None``."""
        query = self._newest_first_query(project_id).limit(1)
        return self.session.execute(query).scalar_one_or_none()

    def delete(self, snapshot_id: int) -> None:
        """Mark the snapshot with ``snapshot_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such snapshot exists.
        """
        query = select(ProjectProfitability).where(
            ProjectProfitability.id == snapshot_id
        )
        snapshot = self.session.execute(query).scalar_one_or_none()
        if snapshot is None:
            raise EntityNotFoundError(
                f"Project profitability snapshot {snapshot_id} not found"
            )
        self.session.delete(snapshot)
|
|
|
|
|
|
class ScenarioProfitabilityRepository:
    """Data-access helpers for scenario-level profitability snapshots."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def create(self, snapshot: ScenarioProfitability) -> ScenarioProfitability:
        """Add ``snapshot`` to the session, flush, and return it."""
        session = self.session
        session.add(snapshot)
        session.flush()
        return snapshot

    @staticmethod
    def _newest_first_query(scenario_id: int):
        """Build a query for a scenario's snapshots, latest ``calculated_at`` first."""
        return (
            select(ScenarioProfitability)
            .where(ScenarioProfitability.scenario_id == scenario_id)
            .order_by(ScenarioProfitability.calculated_at.desc())
        )

    def list_for_scenario(
        self,
        scenario_id: int,
        *,
        limit: int | None = None,
    ) -> Sequence[ScenarioProfitability]:
        """Return a scenario's snapshots, newest first (``limit=None`` = all)."""
        query = self._newest_first_query(scenario_id)
        if limit is not None:
            query = query.limit(limit)
        return self.session.execute(query).scalars().all()

    def latest_for_scenario(
        self,
        scenario_id: int,
    ) -> ScenarioProfitability | None:
        """Return the most recently calculated snapshot, or ``None``."""
        query = self._newest_first_query(scenario_id).limit(1)
        return self.session.execute(query).scalar_one_or_none()

    def delete(self, snapshot_id: int) -> None:
        """Mark the snapshot with ``snapshot_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such snapshot exists.
        """
        query = select(ScenarioProfitability).where(
            ScenarioProfitability.id == snapshot_id
        )
        snapshot = self.session.execute(query).scalar_one_or_none()
        if snapshot is None:
            raise EntityNotFoundError(
                f"Scenario profitability snapshot {snapshot_id} not found"
            )
        self.session.delete(snapshot)
|
|
|
|
|
|
class ProjectCapexRepository:
    """Data-access helpers for project-level capex snapshots."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def create(self, snapshot: ProjectCapexSnapshot) -> ProjectCapexSnapshot:
        """Add ``snapshot`` to the session, flush, and return it."""
        session = self.session
        session.add(snapshot)
        session.flush()
        return snapshot

    @staticmethod
    def _newest_first_query(project_id: int):
        """Build a query for a project's snapshots, latest ``calculated_at`` first."""
        return (
            select(ProjectCapexSnapshot)
            .where(ProjectCapexSnapshot.project_id == project_id)
            .order_by(ProjectCapexSnapshot.calculated_at.desc())
        )

    def list_for_project(
        self,
        project_id: int,
        *,
        limit: int | None = None,
    ) -> Sequence[ProjectCapexSnapshot]:
        """Return a project's snapshots, newest first (``limit=None`` = all)."""
        query = self._newest_first_query(project_id)
        if limit is not None:
            query = query.limit(limit)
        return self.session.execute(query).scalars().all()

    def latest_for_project(
        self,
        project_id: int,
    ) -> ProjectCapexSnapshot | None:
        """Return the most recently calculated snapshot, or ``None``."""
        query = self._newest_first_query(project_id).limit(1)
        return self.session.execute(query).scalar_one_or_none()

    def delete(self, snapshot_id: int) -> None:
        """Mark the snapshot with ``snapshot_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such snapshot exists.
        """
        query = select(ProjectCapexSnapshot).where(
            ProjectCapexSnapshot.id == snapshot_id
        )
        snapshot = self.session.execute(query).scalar_one_or_none()
        if snapshot is None:
            raise EntityNotFoundError(
                f"Project capex snapshot {snapshot_id} not found"
            )
        self.session.delete(snapshot)
|
|
|
|
|
|
class ScenarioCapexRepository:
    """Data-access helpers for scenario-level capex snapshots."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def create(self, snapshot: ScenarioCapexSnapshot) -> ScenarioCapexSnapshot:
        """Add ``snapshot`` to the session, flush, and return it."""
        session = self.session
        session.add(snapshot)
        session.flush()
        return snapshot

    @staticmethod
    def _newest_first_query(scenario_id: int):
        """Build a query for a scenario's snapshots, latest ``calculated_at`` first."""
        return (
            select(ScenarioCapexSnapshot)
            .where(ScenarioCapexSnapshot.scenario_id == scenario_id)
            .order_by(ScenarioCapexSnapshot.calculated_at.desc())
        )

    def list_for_scenario(
        self,
        scenario_id: int,
        *,
        limit: int | None = None,
    ) -> Sequence[ScenarioCapexSnapshot]:
        """Return a scenario's snapshots, newest first (``limit=None`` = all)."""
        query = self._newest_first_query(scenario_id)
        if limit is not None:
            query = query.limit(limit)
        return self.session.execute(query).scalars().all()

    def latest_for_scenario(
        self,
        scenario_id: int,
    ) -> ScenarioCapexSnapshot | None:
        """Return the most recently calculated snapshot, or ``None``."""
        query = self._newest_first_query(scenario_id).limit(1)
        return self.session.execute(query).scalar_one_or_none()

    def delete(self, snapshot_id: int) -> None:
        """Mark the snapshot with ``snapshot_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such snapshot exists.
        """
        query = select(ScenarioCapexSnapshot).where(
            ScenarioCapexSnapshot.id == snapshot_id
        )
        snapshot = self.session.execute(query).scalar_one_or_none()
        if snapshot is None:
            raise EntityNotFoundError(
                f"Scenario capex snapshot {snapshot_id} not found"
            )
        self.session.delete(snapshot)
|
|
|
|
|
|
class ProjectOpexRepository:
    """Data-access helpers for project-level opex snapshots."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def create(
        self, snapshot: ProjectOpexSnapshot
    ) -> ProjectOpexSnapshot:
        """Add ``snapshot`` to the session, flush, and return it."""
        session = self.session
        session.add(snapshot)
        session.flush()
        return snapshot

    @staticmethod
    def _newest_first_query(project_id: int):
        """Build a query for a project's snapshots, latest ``calculated_at`` first."""
        return (
            select(ProjectOpexSnapshot)
            .where(ProjectOpexSnapshot.project_id == project_id)
            .order_by(ProjectOpexSnapshot.calculated_at.desc())
        )

    def list_for_project(
        self,
        project_id: int,
        *,
        limit: int | None = None,
    ) -> Sequence[ProjectOpexSnapshot]:
        """Return a project's snapshots, newest first (``limit=None`` = all)."""
        query = self._newest_first_query(project_id)
        if limit is not None:
            query = query.limit(limit)
        return self.session.execute(query).scalars().all()

    def latest_for_project(
        self,
        project_id: int,
    ) -> ProjectOpexSnapshot | None:
        """Return the most recently calculated snapshot, or ``None``."""
        query = self._newest_first_query(project_id).limit(1)
        return self.session.execute(query).scalar_one_or_none()

    def delete(self, snapshot_id: int) -> None:
        """Mark the snapshot with ``snapshot_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such snapshot exists.
        """
        query = select(ProjectOpexSnapshot).where(
            ProjectOpexSnapshot.id == snapshot_id
        )
        snapshot = self.session.execute(query).scalar_one_or_none()
        if snapshot is None:
            raise EntityNotFoundError(
                f"Project opex snapshot {snapshot_id} not found"
            )
        self.session.delete(snapshot)
|
|
|
|
|
|
class ScenarioOpexRepository:
    """Data-access helpers for scenario-level opex snapshots."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def create(
        self, snapshot: ScenarioOpexSnapshot
    ) -> ScenarioOpexSnapshot:
        """Add ``snapshot`` to the session, flush, and return it."""
        session = self.session
        session.add(snapshot)
        session.flush()
        return snapshot

    @staticmethod
    def _newest_first_query(scenario_id: int):
        """Build a query for a scenario's snapshots, latest ``calculated_at`` first."""
        return (
            select(ScenarioOpexSnapshot)
            .where(ScenarioOpexSnapshot.scenario_id == scenario_id)
            .order_by(ScenarioOpexSnapshot.calculated_at.desc())
        )

    def list_for_scenario(
        self,
        scenario_id: int,
        *,
        limit: int | None = None,
    ) -> Sequence[ScenarioOpexSnapshot]:
        """Return a scenario's snapshots, newest first (``limit=None`` = all)."""
        query = self._newest_first_query(scenario_id)
        if limit is not None:
            query = query.limit(limit)
        return self.session.execute(query).scalars().all()

    def latest_for_scenario(
        self,
        scenario_id: int,
    ) -> ScenarioOpexSnapshot | None:
        """Return the most recently calculated snapshot, or ``None``."""
        query = self._newest_first_query(scenario_id).limit(1)
        return self.session.execute(query).scalar_one_or_none()

    def delete(self, snapshot_id: int) -> None:
        """Mark the snapshot with ``snapshot_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such snapshot exists.
        """
        query = select(ScenarioOpexSnapshot).where(
            ScenarioOpexSnapshot.id == snapshot_id
        )
        snapshot = self.session.execute(query).scalar_one_or_none()
        if snapshot is None:
            raise EntityNotFoundError(
                f"Scenario opex snapshot {snapshot_id} not found"
            )
        self.session.delete(snapshot)
|
|
|
|
|
|
class FinancialInputRepository:
    """Data-access helpers for :class:`FinancialInput` rows."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return a scenario's financial inputs, oldest ``created_at`` first."""
        query = select(FinancialInput).where(
            FinancialInput.scenario_id == scenario_id
        )
        query = query.order_by(FinancialInput.created_at)
        return self.session.execute(query).scalars().all()

    def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
        """Add all ``inputs`` to the session in one batch, flush, and return them.

        The iterable is materialised into a list first, so generators are
        accepted.

        NOTE(review): despite the name, this only calls ``Session.add_all``.
        Whether existing rows are updated rather than inserted depends on
        each object's session state. Confirm callers rely on that.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        batch = [*inputs]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints") from exc
        return batch

    def delete(self, input_id: int) -> None:
        """Mark the financial input with ``input_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such input exists.
        """
        lookup = select(FinancialInput).where(FinancialInput.id == input_id)
        doomed = self.session.execute(lookup).scalar_one_or_none()
        if doomed is None:
            raise EntityNotFoundError(f"Financial input {input_id} not found")
        self.session.delete(doomed)

    def latest_created_at(self) -> datetime | None:
        """Return the newest ``created_at`` across all inputs, or ``None``."""
        newest = FinancialInput.created_at.desc()
        query = select(FinancialInput.created_at).order_by(newest).limit(1)
        return self.session.execute(query).scalar_one_or_none()
|
|
|
|
|
|
class SimulationParameterRepository:
    """Data-access helpers for :class:`SimulationParameter` rows."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return a scenario's parameters, oldest ``created_at`` first."""
        query = select(SimulationParameter).where(
            SimulationParameter.scenario_id == scenario_id
        )
        query = query.order_by(SimulationParameter.created_at)
        return self.session.execute(query).scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Add all ``parameters`` to the session in one batch, flush, and return them.

        The iterable is materialised into a list first, so generators are
        accepted.

        NOTE(review): despite the name, this only calls ``Session.add_all``.
        Whether existing rows are updated rather than inserted depends on
        each object's session state. Confirm callers rely on that.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        batch = [*parameters]
        self.session.add_all(batch)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints") from exc
        return batch

    def delete(self, parameter_id: int) -> None:
        """Mark the parameter with ``parameter_id`` for deletion.

        Raises:
            EntityNotFoundError: if no such parameter exists.
        """
        lookup = select(SimulationParameter).where(
            SimulationParameter.id == parameter_id
        )
        doomed = self.session.execute(lookup).scalar_one_or_none()
        if doomed is None:
            raise EntityNotFoundError(
                f"Simulation parameter {parameter_id} not found")
        self.session.delete(doomed)
|
|
|
|
|
|
class PricingSettingsRepository:
    """Data-access helpers for pricing configuration and its overrides."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    @staticmethod
    def _with_overrides(query):
        """Return ``query`` with metal and impurity overrides eagerly loaded."""
        return query.options(
            selectinload(PricingSettings.metal_overrides),
            selectinload(PricingSettings.impurity_overrides),
        )

    def list(self, *, include_children: bool = False) -> Sequence[PricingSettings]:
        """Return all pricing settings, oldest ``created_at`` first.

        ``include_children`` eagerly loads both override collections.
        """
        query = select(PricingSettings).order_by(PricingSettings.created_at)
        if not include_children:
            return self.session.execute(query).scalars().all()
        rows = self.session.execute(self._with_overrides(query))
        return rows.unique().scalars().all()

    def get(self, settings_id: int, *, include_children: bool = False) -> PricingSettings:
        """Fetch pricing settings by primary key.

        Raises:
            EntityNotFoundError: if no settings have ``settings_id``.
        """
        query = select(PricingSettings).where(PricingSettings.id == settings_id)
        if include_children:
            query = self._with_overrides(query)
            found = self.session.execute(query).unique().scalar_one_or_none()
        else:
            found = self.session.execute(query).scalar_one_or_none()

        if found is None:
            raise EntityNotFoundError(
                f"Pricing settings {settings_id} not found")
        return found

    def find_by_slug(
        self,
        slug: str,
        *,
        include_children: bool = False,
    ) -> PricingSettings | None:
        """Return the settings whose slug matches ``slug``, or ``None``.

        ``slug`` is stripped and lower-cased before the exact comparison.
        """
        query = select(PricingSettings).where(
            PricingSettings.slug == slug.strip().lower())
        if not include_children:
            return self.session.execute(query).scalar_one_or_none()
        rows = self.session.execute(self._with_overrides(query))
        return rows.unique().scalar_one_or_none()

    def get_by_slug(self, slug: str, *, include_children: bool = False) -> PricingSettings:
        """Like :meth:`find_by_slug`, but raise when nothing matches.

        Raises:
            EntityNotFoundError: if no settings match ``slug``.
        """
        found = self.find_by_slug(slug, include_children=include_children)
        if found is not None:
            return found
        raise EntityNotFoundError(
            f"Pricing settings slug '{slug}' not found"
        )

    def create(self, settings: PricingSettings) -> PricingSettings:
        """Add ``settings`` to the session and flush.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(settings)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - relies on DB constraints
            raise EntityConflictError(
                "Pricing settings violates constraints") from exc
        return settings

    def delete(self, settings_id: int) -> None:
        """Mark the settings with ``settings_id`` for deletion.

        The overrides are loaded first through ``get(..., include_children=True)``.

        Raises:
            EntityNotFoundError: propagated from :meth:`get` when missing.
        """
        doomed = self.get(settings_id, include_children=True)
        self.session.delete(doomed)

    def attach_metal_override(
        self,
        settings: PricingSettings,
        override: PricingMetalSettings,
    ) -> PricingMetalSettings:
        """Append ``override`` to ``settings.metal_overrides``, stage it, and flush."""
        settings.metal_overrides.append(override)
        session = self.session
        session.add(override)
        session.flush()
        return override

    def attach_impurity_override(
        self,
        settings: PricingSettings,
        override: PricingImpuritySettings,
    ) -> PricingImpuritySettings:
        """Append ``override`` to ``settings.impurity_overrides``, stage it, and flush."""
        settings.impurity_overrides.append(override)
        session = self.session
        session.add(override)
        session.flush()
        return override
|
|
|
|
|
|
class RoleRepository:
    """Data-access helpers for :class:`Role` rows."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def list(self) -> Sequence[Role]:
        """Return every role, ordered alphabetically by name."""
        rows = self.session.execute(select(Role).order_by(Role.name))
        return rows.scalars().all()

    def get(self, role_id: int) -> Role:
        """Fetch a role by primary key.

        Raises:
            EntityNotFoundError: if no role has ``role_id``.
        """
        found = self.session.execute(
            select(Role).where(Role.id == role_id)
        ).scalar_one_or_none()
        if found is not None:
            return found
        raise EntityNotFoundError(f"Role {role_id} not found")

    def get_by_name(self, name: str) -> Role | None:
        """Return the role whose name equals ``name`` exactly, or ``None``."""
        found = self.session.execute(select(Role).where(Role.name == name))
        return found.scalar_one_or_none()

    def create(self, role: Role) -> Role:
        """Add ``role`` to the session and flush.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(role)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Role violates uniqueness constraints") from exc
        return role
|
|
|
|
|
|
class UserRepository:
    """Data-access helpers for users and their role assignments."""

    def __init__(self, session: Session) -> None:
        # Every query and write below goes through this session.
        self.session = session

    def list(self, *, with_roles: bool = False) -> Sequence[User]:
        """Return every user, oldest ``created_at`` first.

        ``with_roles`` eagerly loads ``roles`` via ``selectinload``.
        """
        query = select(User)
        if with_roles:
            query = query.options(selectinload(User.roles))
        query = query.order_by(User.created_at)
        return self.session.execute(query).scalars().all()

    def _apply_role_option(self, stmt, with_roles: bool):
        """Add role eager-loading options to ``stmt`` when ``with_roles`` is set.

        This joins ``role_assignments`` and each assignment's ``role``, and
        also selects ``roles``. When ``with_roles`` is false, ``stmt`` is
        returned untouched.
        """
        if not with_roles:
            return stmt
        return stmt.options(
            joinedload(User.role_assignments).joinedload(UserRole.role),
            selectinload(User.roles),
        )

    def _fetch_one(self, criterion, with_roles: bool) -> User | None:
        """Run a single-user lookup on ``criterion``, refreshing identity-map copies.

        ``populate_existing`` makes the loaded rows overwrite any instance
        already in the session. When role options are applied, the joined
        rows are de-duplicated with ``unique()``.
        """
        query = (
            select(User)
            .where(criterion)
            .execution_options(populate_existing=True)
        )
        query = self._apply_role_option(query, with_roles)
        rows = self.session.execute(query)
        if with_roles:
            rows = rows.unique()
        return rows.scalar_one_or_none()

    def get(self, user_id: int, *, with_roles: bool = False) -> User:
        """Fetch a user by primary key.

        Raises:
            EntityNotFoundError: if no user has ``user_id``.
        """
        user = self._fetch_one(User.id == user_id, with_roles)
        if user is None:
            raise EntityNotFoundError(f"User {user_id} not found")
        return user

    def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose email equals ``email`` exactly, or ``None``."""
        return self._fetch_one(User.email == email, with_roles)

    def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose username equals ``username`` exactly, or ``None``."""
        return self._fetch_one(User.username == username, with_roles)

    def create(self, user: User) -> User:
        """Add ``user`` to the session and flush.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(user)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "User violates uniqueness constraints") from exc
        return user

    @staticmethod
    def _assignment_query(user_id: int, role_id: int):
        """Build the lookup for the ``UserRole`` linking ``user_id`` to ``role_id``."""
        return select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )

    def assign_role(
        self,
        *,
        user_id: int,
        role_id: int,
        granted_by: int | None = None,
    ) -> UserRole:
        """Grant ``role_id`` to ``user_id``, reusing an existing assignment.

        If an assignment already exists, it is returned unchanged and
        ``granted_by`` is ignored. Otherwise a new one is added and flushed.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        lookup = self._assignment_query(user_id, role_id)
        existing = self.session.execute(lookup).scalar_one_or_none()
        if existing:
            return existing

        created = UserRole(
            user_id=user_id,
            role_id=role_id,
            granted_by=granted_by,
        )
        self.session.add(created)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Assignment violates constraints") from exc
        return created

    def revoke_role(self, *, user_id: int, role_id: int) -> None:
        """Remove the assignment of ``role_id`` from ``user_id`` and flush.

        Raises:
            EntityNotFoundError: if the user does not hold the role.
        """
        lookup = self._assignment_query(user_id, role_id)
        assignment = self.session.execute(lookup).scalar_one_or_none()
        if assignment is None:
            raise EntityNotFoundError(
                f"Role {role_id} not assigned to user {user_id}")
        self.session.delete(assignment)
        self.session.flush()
|
|
|
|
|
|
# Fallback display name and description for the baseline pricing settings
# record. ensure_default_pricing_settings uses them when no name/description
# (or an empty one) is supplied.
DEFAULT_PRICING_SETTINGS_NAME = "Default Pricing Settings"
DEFAULT_PRICING_SETTINGS_DESCRIPTION = (
    "Default pricing configuration "
    "generated from environment metadata."
)
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class PricingSettingsSeedResult:
|
|
settings: PricingSettings
|
|
created: bool
|
|
updated_fields: int
|
|
impurity_upserts: int
|
|
|
|
|
|
def _sync_pricing_settings_scalars(
    settings: PricingSettings,
    *,
    name: str,
    description: str,
    metadata: PricingMetadata,
) -> int:
    """Overwrite drifted name/description/default fields on *settings*; return the change count."""
    changes = 0
    if settings.name != name:
        settings.name = name
        changes += 1
    if settings.description != description:
        settings.description = description
        changes += 1
    if settings.default_currency != metadata.default_currency:
        settings.default_currency = metadata.default_currency
        changes += 1
    # Numeric defaults are compared as floats so values held in a different
    # numeric type (presumably Decimal columns) compare by value.
    for field_name in (
        "default_payable_pct",
        "moisture_threshold_pct",
        "moisture_penalty_per_pct",
    ):
        target = getattr(metadata, field_name)
        if float(getattr(settings, field_name)) != float(target):
            setattr(settings, field_name, target)
            changes += 1
    return changes


def _sync_pricing_impurity_overrides(
    repo: PricingSettingsRepository,
    settings: PricingSettings,
    metadata: PricingMetadata,
) -> tuple[int, int]:
    """Reconcile ``settings.impurity_overrides`` with the impurity maps in *metadata*.

    Codes are stripped and upper-cased, and blank codes are ignored. A code
    present in only one map gets ``0.0`` for the missing value.

    Returns:
        ``(updated, attached)``: the number of existing overrides whose
        threshold or penalty changed, and the number of new overrides attached.
    """
    thresholds = {
        code.strip().upper(): float(value)
        for code, value in (metadata.impurity_thresholds or {}).items()
        if code.strip()
    }
    penalties = {
        code.strip().upper(): float(value)
        for code, value in (metadata.impurity_penalty_per_ppm or {}).items()
        if code.strip()
    }
    target_codes = set(thresholds) | set(penalties)
    if not target_codes:
        return 0, 0

    # Normalise stored codes the same way as the metadata codes. Otherwise an
    # override persisted as e.g. "fe" would not match "FE", and a duplicate
    # override would be attached instead of the existing one being updated.
    existing_by_code = {
        override.impurity_code.strip().upper(): override
        for override in settings.impurity_overrides
    }

    updated = 0
    attached = 0
    for code in sorted(target_codes):
        threshold_value = thresholds.get(code, 0.0)
        penalty_value = penalties.get(code, 0.0)
        existing = existing_by_code.get(code)
        if existing is None:
            repo.attach_impurity_override(
                settings,
                PricingImpuritySettings(
                    impurity_code=code,
                    threshold_ppm=threshold_value,
                    penalty_per_ppm=penalty_value,
                ),
            )
            attached += 1
            continue

        changed = False
        if float(existing.threshold_ppm) != threshold_value:
            existing.threshold_ppm = threshold_value
            changed = True
        if float(existing.penalty_per_ppm) != penalty_value:
            existing.penalty_per_ppm = penalty_value
            changed = True
        if changed:
            updated += 1
    return updated, attached


def ensure_default_pricing_settings(
    repo: PricingSettingsRepository,
    *,
    metadata: PricingMetadata,
    slug: str = "default",
    name: str | None = None,
    description: str | None = None,
) -> PricingSettingsSeedResult:
    """Ensure a baseline pricing settings record exists and matches metadata defaults.

    The record is looked up by the normalised slug. If it is missing, it is
    created from *metadata*. Otherwise its name, description, currency and
    percentage defaults are overwritten where they differ. In both cases the
    impurity overrides are reconciled with ``metadata.impurity_thresholds`` and
    ``metadata.impurity_penalty_per_ppm``. The session is flushed only when
    something changed after the lookup/creation.

    Args:
        repo: Repository used to load/create settings and attach overrides.
        metadata: Source of the default values to enforce.
        slug: Record identifier. It is stripped and lower-cased, and falls back
            to ``"default"`` when empty.
        name: Desired display name. Defaults to ``DEFAULT_PRICING_SETTINGS_NAME``.
        description: Desired description. Defaults to
            ``DEFAULT_PRICING_SETTINGS_DESCRIPTION``.

    Returns:
        A ``PricingSettingsSeedResult`` with the record, whether it was
        created, the number of updated fields/overrides, and the number of
        newly attached impurity overrides.
    """
    normalised_slug = (slug or "default").strip().lower() or "default"
    target_name = name or DEFAULT_PRICING_SETTINGS_NAME
    target_description = description or DEFAULT_PRICING_SETTINGS_DESCRIPTION

    try:
        settings = repo.get_by_slug(normalised_slug, include_children=True)
    except EntityNotFoundError:
        settings = PricingSettings(
            name=target_name,
            slug=normalised_slug,
            description=target_description,
            default_currency=metadata.default_currency,
            default_payable_pct=metadata.default_payable_pct,
            moisture_threshold_pct=metadata.moisture_threshold_pct,
            moisture_penalty_per_pct=metadata.moisture_penalty_per_pct,
        )
        settings.metadata_payload = None
        settings = repo.create(settings)
        created = True
        updated_fields = 0
    else:
        created = False
        updated_fields = _sync_pricing_settings_scalars(
            settings,
            name=target_name,
            description=target_description,
            metadata=metadata,
        )

    override_updates, impurity_upserts = _sync_pricing_impurity_overrides(
        repo, settings, metadata
    )
    updated_fields += override_updates

    if updated_fields > 0 or impurity_upserts > 0:
        repo.session.flush()

    return PricingSettingsSeedResult(
        settings=settings,
        created=created,
        updated_fields=updated_fields,
        impurity_upserts=impurity_upserts,
    )
|
|
|
|
|
|
def pricing_settings_to_metadata(settings: PricingSettings) -> PricingMetadata:
    """Translate a stored ``PricingSettings`` row into a ``PricingMetadata`` value.

    Scalar defaults are copied, with the percentage fields coerced to ``float``.
    The impurity maps are first seeded from ``settings.metadata_payload``, using
    its ``"impurity_thresholds"`` and ``"impurity_penalty_per_ppm"`` entries.
    Only non-blank string codes are kept, normalised to stripped upper case.
    Each row in ``settings.impurity_overrides`` is then applied on top and
    takes precedence for its code.
    """

    def _normalised(raw_map) -> dict[str, float]:
        # Skip non-string or blank codes; canonicalise the rest to upper case.
        out: dict[str, float] = {}
        for raw_code, raw_value in raw_map.items():
            if not (isinstance(raw_code, str) and raw_code.strip()):
                continue
            out[raw_code.strip().upper()] = float(raw_value)
        return out

    payload = settings.metadata_payload or {}
    thresholds = _normalised(payload.get("impurity_thresholds") or {})
    penalties = _normalised(payload.get("impurity_penalty_per_ppm") or {})

    for override in settings.impurity_overrides:
        key = override.impurity_code.strip().upper()
        thresholds[key] = float(override.threshold_ppm)
        penalties[key] = float(override.penalty_per_ppm)

    return PricingMetadata(
        default_payable_pct=float(settings.default_payable_pct),
        default_currency=settings.default_currency,
        moisture_threshold_pct=float(settings.moisture_threshold_pct),
        moisture_penalty_per_pct=float(settings.moisture_penalty_per_pct),
        impurity_thresholds=thresholds,
        impurity_penalty_per_ppm=penalties,
    )
|
|
|
|
|
|
# Built-in roles seeded by ensure_default_roles, in the order they are checked
# and created. Each entry provides the keyword arguments passed to ``Role(...)``.
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = tuple(
    {"name": role_name, "display_name": label, "description": blurb}
    for role_name, label, blurb in (
        (
            "admin",
            "Administrator",
            "Full platform access with user management rights.",
        ),
        (
            "project_manager",
            "Project Manager",
            "Manage projects, scenarios, and associated data.",
        ),
        (
            "analyst",
            "Analyst",
            "Review dashboards and scenario outputs.",
        ),
        (
            "viewer",
            "Viewer",
            "Read-only access to assigned projects and reports.",
        ),
    )
)
|
|
|
|
|
|
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
    """Ensure every role in ``DEFAULT_ROLE_DEFINITIONS`` exists, creating missing ones.

    Existing roles are returned as found; their display name and description
    are not updated. Roles not listed in ``DEFAULT_ROLE_DEFINITIONS`` are
    neither touched nor returned.

    Returns:
        One ``Role`` per default definition, in ``DEFAULT_ROLE_DEFINITIONS``
        order, covering both pre-existing and newly created records.
    """
    roles: list[Role] = []
    for definition in DEFAULT_ROLE_DEFINITIONS:
        role = role_repo.get_by_name(definition["name"])
        if role is None:
            role = role_repo.create(Role(**definition))
        roles.append(role)
    return roles
|
|
|
|
|
|
def ensure_admin_user(
    user_repo: UserRepository,
    role_repo: RoleRepository,
    *,
    email: str,
    username: str,
    password: str,
) -> User:
    """Ensure an administrator user exists and holds the admin role.

    The user is looked up by *email*:

    - If absent, it is created with *username* and a hash of *password*,
      marked active and superuser.
    - If present, it is re-activated and promoted to superuser where needed.
      Its username and password are left unchanged.

    The ``admin`` role is created if missing and assigned to the user, who is
    recorded as the granter.

    Returns:
        The user, re-loaded with roles so its role collections include the
        admin assignment.
    """
    user = user_repo.get_by_email(email, with_roles=True)
    if user is None:
        user = User(
            email=email,
            username=username,
            password_hash=User.hash_password(password),
            is_active=True,
            is_superuser=True,
        )
        user_repo.create(user)
    else:
        if not user.is_active:
            user.is_active = True
        if not user.is_superuser:
            user.is_superuser = True
        user_repo.session.flush()

    admin_role = role_repo.get_by_name("admin")
    if admin_role is None:  # pragma: no cover - safety if ensure_default_roles wasn't called
        admin_role = role_repo.create(
            Role(
                name="admin",
                display_name="Administrator",
                description="Full platform access with user management rights.",
            )
        )

    user_repo.assign_role(
        user_id=user.id,
        role_id=admin_role.id,
        granted_by=user.id,
    )
    # An existing user was loaded with its roles *before* the admin assignment
    # was flushed, so its role collections may be stale. UserRepository.get
    # uses populate_existing, which refreshes that same instance.
    return user_repo.get(user.id, with_roles=True)
|