from __future__ import annotations from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime from typing import Mapping, Sequence from sqlalchemy import select, func from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, joinedload, selectinload from models import ( FinancialInput, Project, PricingImpuritySettings, PricingMetalSettings, PricingSettings, ProjectCapexSnapshot, ProjectProfitability, Role, Scenario, ScenarioCapexSnapshot, ScenarioProfitability, ScenarioStatus, SimulationParameter, User, UserRole, ) from services.exceptions import EntityConflictError, EntityNotFoundError from services.export_query import ProjectExportFilters, ScenarioExportFilters from services.pricing import PricingMetadata def _enum_value(e): """Return the underlying value for Enum members, otherwise return as-is.""" return getattr(e, "value", e) class ProjectRepository: """Persistence operations for Project entities.""" def __init__(self, session: Session) -> None: self.session = session def list( self, *, with_children: bool = False, with_pricing: bool = False, ) -> Sequence[Project]: stmt = select(Project).order_by(Project.created_at) if with_children: stmt = stmt.options(selectinload(Project.scenarios)) if with_pricing: stmt = stmt.options(selectinload(Project.pricing_settings)) return self.session.execute(stmt).scalars().all() def count(self) -> int: stmt = select(func.count(Project.id)) return self.session.execute(stmt).scalar_one() def recent(self, limit: int = 5) -> Sequence[Project]: stmt = ( select(Project) .order_by(Project.updated_at.desc()) .limit(limit) ) return self.session.execute(stmt).scalars().all() def get( self, project_id: int, *, with_children: bool = False, with_pricing: bool = False, ) -> Project: stmt = select(Project).where(Project.id == project_id) if with_children: stmt = stmt.options(joinedload(Project.scenarios)) if with_pricing: stmt = stmt.options(joinedload(Project.pricing_settings)) result = self.session.execute(stmt) if with_children: result = result.unique() project = result.scalar_one_or_none() if project is None: raise EntityNotFoundError(f"Project {project_id} not found") return project def exists(self, project_id: int) -> bool: stmt = select(Project.id).where(Project.id == project_id).limit(1) return self.session.execute(stmt).scalar_one_or_none() is not None def create(self, project: Project) -> Project: self.session.add(project) try: self.session.flush() except IntegrityError as exc: # pragma: no cover - reliance on DB constraints from monitoring.metrics import observe_project_operation observe_project_operation("create", "error") raise EntityConflictError( "Project violates uniqueness constraints") from exc from monitoring.metrics import observe_project_operation observe_project_operation("create", "success") return project def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]: normalised = {name.strip().lower() for name in names if name and name.strip()} if not normalised: return {} stmt = select(Project).where(func.lower(Project.name).in_(normalised)) records = self.session.execute(stmt).scalars().all() return {project.name.lower(): project for project in records} def filtered_for_export( self, filters: ProjectExportFilters | None = None, *, include_scenarios: bool = False, include_pricing: bool = False, ) -> Sequence[Project]: stmt = select(Project) if include_scenarios: stmt = stmt.options(selectinload(Project.scenarios)) if include_pricing: stmt = stmt.options(selectinload(Project.pricing_settings)) if filters: ids = filters.normalised_ids() if ids: stmt = stmt.where(Project.id.in_(ids)) name_matches = filters.normalised_names() if name_matches: stmt = stmt.where(func.lower(Project.name).in_(name_matches)) name_pattern = filters.name_search_pattern() if name_pattern: stmt = stmt.where(Project.name.ilike(name_pattern)) locations = filters.normalised_locations() if locations: stmt = stmt.where(func.lower(Project.location).in_(locations)) if filters.operation_types: stmt = stmt.where(Project.operation_type.in_( filters.operation_types)) if filters.created_from: stmt = stmt.where(Project.created_at >= filters.created_from) if filters.created_to: stmt = stmt.where(Project.created_at <= filters.created_to) if filters.updated_from: stmt = stmt.where(Project.updated_at >= filters.updated_from) if filters.updated_to: stmt = stmt.where(Project.updated_at <= filters.updated_to) stmt = stmt.order_by(Project.name, Project.id) return self.session.execute(stmt).scalars().all() def delete(self, project_id: int) -> None: project = self.get(project_id) self.session.delete(project) def set_pricing_settings( self, project: Project, pricing_settings: PricingSettings | None, ) -> Project: project.pricing_settings = pricing_settings project.pricing_settings_id = ( pricing_settings.id if pricing_settings is not None else None ) self.session.flush() return project class ScenarioRepository: """Persistence operations for Scenario entities.""" def __init__(self, session: Session) -> None: self.session = session def list_for_project( self, project_id: int, *, with_children: bool = False, ) -> Sequence[Scenario]: stmt = ( select(Scenario) .where(Scenario.project_id == project_id) .order_by(Scenario.created_at) ) if with_children: stmt = stmt.options( selectinload(Scenario.financial_inputs), selectinload(Scenario.simulation_parameters), ) result = self.session.execute(stmt) if with_children: result = result.unique() return result.scalars().all() def count(self) -> int: stmt = select(func.count(Scenario.id)) return self.session.execute(stmt).scalar_one() def count_by_status(self, status: ScenarioStatus) -> int: status_val = _enum_value(status) stmt = select(func.count(Scenario.id)).where( Scenario.status == status_val) return self.session.execute(stmt).scalar_one() def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]: stmt = select(Scenario).order_by( Scenario.updated_at.desc()).limit(limit) if with_project: stmt = stmt.options(joinedload(Scenario.project)) return self.session.execute(stmt).scalars().all() def list_by_status( self, status: ScenarioStatus, *, limit: int | None = None, with_project: bool = False, ) -> Sequence[Scenario]: status_val = _enum_value(status) stmt = ( select(Scenario) .where(Scenario.status == status_val) .order_by(Scenario.updated_at.desc()) ) if with_project: stmt = stmt.options(joinedload(Scenario.project)) if limit is not None: stmt = stmt.limit(limit) return self.session.execute(stmt).scalars().all() def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario: stmt = select(Scenario).where(Scenario.id == scenario_id) if with_children: stmt = stmt.options( joinedload(Scenario.financial_inputs), joinedload(Scenario.simulation_parameters), ) result = self.session.execute(stmt) if with_children: result = result.unique() scenario = result.scalar_one_or_none() if scenario is None: raise EntityNotFoundError(f"Scenario {scenario_id} not found") return scenario def exists(self, scenario_id: int) -> bool: stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1) return self.session.execute(stmt).scalar_one_or_none() is not None def create(self, scenario: Scenario) -> Scenario: self.session.add(scenario) try: self.session.flush() except IntegrityError as exc: # pragma: no cover from monitoring.metrics import observe_scenario_operation observe_scenario_operation("create", "error") raise EntityConflictError("Scenario violates constraints") from exc from monitoring.metrics import observe_scenario_operation observe_scenario_operation("create", "success") return scenario def find_by_project_and_names( self, project_id: int, names: Iterable[str], ) -> Mapping[str, Scenario]: normalised = {name.strip().lower() for name in names if name and name.strip()} if not normalised: return {} stmt = ( select(Scenario) .where( Scenario.project_id == project_id, func.lower(Scenario.name).in_(normalised), ) ) records = self.session.execute(stmt).scalars().all() return {scenario.name.lower(): scenario for scenario in records} def filtered_for_export( self, filters: ScenarioExportFilters | None = None, *, include_project: bool = True, ) -> Sequence[Scenario]: stmt = select(Scenario) if include_project: stmt = stmt.options(joinedload(Scenario.project)) if filters: scenario_ids = filters.normalised_ids() if scenario_ids: stmt = stmt.where(Scenario.id.in_(scenario_ids)) project_ids = filters.normalised_project_ids() if project_ids: stmt = stmt.where(Scenario.project_id.in_(project_ids)) project_names = filters.normalised_project_names() if project_names: project_id_select = select(Project.id).where( func.lower(Project.name).in_(project_names) ) stmt = stmt.where(Scenario.project_id.in_(project_id_select)) name_pattern = filters.name_search_pattern() if name_pattern: stmt = stmt.where(Scenario.name.ilike(name_pattern)) if filters.statuses: # Accept Enum members or raw values in filters.statuses status_values = [ _enum_value(s) for s in (filters.statuses or []) ] stmt = stmt.where(Scenario.status.in_(status_values)) if filters.start_date_from: stmt = stmt.where(Scenario.start_date >= filters.start_date_from) if filters.start_date_to: stmt = stmt.where(Scenario.start_date <= filters.start_date_to) if filters.end_date_from: stmt = stmt.where(Scenario.end_date >= filters.end_date_from) if filters.end_date_to: stmt = stmt.where(Scenario.end_date <= filters.end_date_to) if filters.created_from: stmt = stmt.where(Scenario.created_at >= filters.created_from) if filters.created_to: stmt = stmt.where(Scenario.created_at <= filters.created_to) if filters.updated_from: stmt = stmt.where(Scenario.updated_at >= filters.updated_from) if filters.updated_to: stmt = stmt.where(Scenario.updated_at <= filters.updated_to) currencies = filters.normalised_currencies() if currencies: stmt = stmt.where(func.upper( Scenario.currency).in_(currencies)) if filters.primary_resources: stmt = stmt.where(Scenario.primary_resource.in_( filters.primary_resources)) stmt = stmt.order_by(Scenario.name, Scenario.id) return self.session.execute(stmt).scalars().all() def delete(self, scenario_id: int) -> None: scenario = self.get(scenario_id) self.session.delete(scenario) class ProjectProfitabilityRepository: """Persistence operations for project-level profitability snapshots.""" def __init__(self, session: Session) -> None: self.session = session def create(self, snapshot: ProjectProfitability) -> ProjectProfitability: self.session.add(snapshot) self.session.flush() return snapshot def list_for_project( self, project_id: int, *, limit: int | None = None, ) -> Sequence[ProjectProfitability]: stmt = ( select(ProjectProfitability) .where(ProjectProfitability.project_id == project_id) .order_by(ProjectProfitability.calculated_at.desc()) ) if limit is not None: stmt = stmt.limit(limit) return self.session.execute(stmt).scalars().all() def latest_for_project( self, project_id: int, ) -> ProjectProfitability | None: stmt = ( select(ProjectProfitability) .where(ProjectProfitability.project_id == project_id) .order_by(ProjectProfitability.calculated_at.desc()) .limit(1) ) return self.session.execute(stmt).scalar_one_or_none() def delete(self, snapshot_id: int) -> None: stmt = select(ProjectProfitability).where( ProjectProfitability.id == snapshot_id ) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError( f"Project profitability snapshot {snapshot_id} not found" ) self.session.delete(entity) class ScenarioProfitabilityRepository: """Persistence operations for scenario-level profitability snapshots.""" def __init__(self, session: Session) -> None: self.session = session def create(self, snapshot: ScenarioProfitability) -> ScenarioProfitability: self.session.add(snapshot) self.session.flush() return snapshot def list_for_scenario( self, scenario_id: int, *, limit: int | None = None, ) -> Sequence[ScenarioProfitability]: stmt = ( select(ScenarioProfitability) .where(ScenarioProfitability.scenario_id == scenario_id) .order_by(ScenarioProfitability.calculated_at.desc()) ) if limit is not None: stmt = stmt.limit(limit) return self.session.execute(stmt).scalars().all() def latest_for_scenario( self, scenario_id: int, ) -> ScenarioProfitability | None: stmt = ( select(ScenarioProfitability) .where(ScenarioProfitability.scenario_id == scenario_id) .order_by(ScenarioProfitability.calculated_at.desc()) .limit(1) ) return self.session.execute(stmt).scalar_one_or_none() def delete(self, snapshot_id: int) -> None: stmt = select(ScenarioProfitability).where( ScenarioProfitability.id == snapshot_id ) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError( f"Scenario profitability snapshot {snapshot_id} not found" ) self.session.delete(entity) class ProjectCapexRepository: """Persistence operations for project-level capex snapshots.""" def __init__(self, session: Session) -> None: self.session = session def create(self, snapshot: ProjectCapexSnapshot) -> ProjectCapexSnapshot: self.session.add(snapshot) self.session.flush() return snapshot def list_for_project( self, project_id: int, *, limit: int | None = None, ) -> Sequence[ProjectCapexSnapshot]: stmt = ( select(ProjectCapexSnapshot) .where(ProjectCapexSnapshot.project_id == project_id) .order_by(ProjectCapexSnapshot.calculated_at.desc()) ) if limit is not None: stmt = stmt.limit(limit) return self.session.execute(stmt).scalars().all() def latest_for_project( self, project_id: int, ) -> ProjectCapexSnapshot | None: stmt = ( select(ProjectCapexSnapshot) .where(ProjectCapexSnapshot.project_id == project_id) .order_by(ProjectCapexSnapshot.calculated_at.desc()) .limit(1) ) return self.session.execute(stmt).scalar_one_or_none() def delete(self, snapshot_id: int) -> None: stmt = select(ProjectCapexSnapshot).where( ProjectCapexSnapshot.id == snapshot_id ) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError( f"Project capex snapshot {snapshot_id} not found" ) self.session.delete(entity) class ScenarioCapexRepository: """Persistence operations for scenario-level capex snapshots.""" def __init__(self, session: Session) -> None: self.session = session def create(self, snapshot: ScenarioCapexSnapshot) -> ScenarioCapexSnapshot: self.session.add(snapshot) self.session.flush() return snapshot def list_for_scenario( self, scenario_id: int, *, limit: int | None = None, ) -> Sequence[ScenarioCapexSnapshot]: stmt = ( select(ScenarioCapexSnapshot) .where(ScenarioCapexSnapshot.scenario_id == scenario_id) .order_by(ScenarioCapexSnapshot.calculated_at.desc()) ) if limit is not None: stmt = stmt.limit(limit) return self.session.execute(stmt).scalars().all() def latest_for_scenario( self, scenario_id: int, ) -> ScenarioCapexSnapshot | None: stmt = ( select(ScenarioCapexSnapshot) .where(ScenarioCapexSnapshot.scenario_id == scenario_id) .order_by(ScenarioCapexSnapshot.calculated_at.desc()) .limit(1) ) return self.session.execute(stmt).scalar_one_or_none() def delete(self, snapshot_id: int) -> None: stmt = select(ScenarioCapexSnapshot).where( ScenarioCapexSnapshot.id == snapshot_id ) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError( f"Scenario capex snapshot {snapshot_id} not found" ) self.session.delete(entity) class FinancialInputRepository: """Persistence operations for FinancialInput entities.""" def __init__(self, session: Session) -> None: self.session = session def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]: stmt = ( select(FinancialInput) .where(FinancialInput.scenario_id == scenario_id) .order_by(FinancialInput.created_at) ) return self.session.execute(stmt).scalars().all() def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]: entities = list(inputs) self.session.add_all(entities) try: self.session.flush() except IntegrityError as exc: # pragma: no cover raise EntityConflictError( "Financial input violates constraints") from exc return entities def delete(self, input_id: int) -> None: stmt = select(FinancialInput).where(FinancialInput.id == input_id) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError(f"Financial input {input_id} not found") self.session.delete(entity) def latest_created_at(self) -> datetime | None: stmt = ( select(FinancialInput.created_at) .order_by(FinancialInput.created_at.desc()) .limit(1) ) return self.session.execute(stmt).scalar_one_or_none() class SimulationParameterRepository: """Persistence operations for SimulationParameter entities.""" def __init__(self, session: Session) -> None: self.session = session def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]: stmt = ( select(SimulationParameter) .where(SimulationParameter.scenario_id == scenario_id) .order_by(SimulationParameter.created_at) ) return self.session.execute(stmt).scalars().all() def bulk_upsert( self, parameters: Iterable[SimulationParameter] ) -> Sequence[SimulationParameter]: entities = list(parameters) self.session.add_all(entities) try: self.session.flush() except IntegrityError as exc: # pragma: no cover raise EntityConflictError( "Simulation parameter violates constraints") from exc return entities def delete(self, parameter_id: int) -> None: stmt = select(SimulationParameter).where( SimulationParameter.id == parameter_id) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: raise EntityNotFoundError( f"Simulation parameter {parameter_id} not found") self.session.delete(entity) class PricingSettingsRepository: """Persistence operations for pricing configuration entities.""" def __init__(self, session: Session) -> None: self.session = session def list(self, *, include_children: bool = False) -> Sequence[PricingSettings]: stmt = select(PricingSettings).order_by(PricingSettings.created_at) if include_children: stmt = stmt.options( selectinload(PricingSettings.metal_overrides), selectinload(PricingSettings.impurity_overrides), ) result = self.session.execute(stmt) if include_children: result = result.unique() return result.scalars().all() def get(self, settings_id: int, *, include_children: bool = False) -> PricingSettings: stmt = select(PricingSettings).where(PricingSettings.id == settings_id) if include_children: stmt = stmt.options( selectinload(PricingSettings.metal_overrides), selectinload(PricingSettings.impurity_overrides), ) result = self.session.execute(stmt) if include_children: result = result.unique() settings = result.scalar_one_or_none() if settings is None: raise EntityNotFoundError( f"Pricing settings {settings_id} not found") return settings def find_by_slug( self, slug: str, *, include_children: bool = False, ) -> PricingSettings | None: normalised = slug.strip().lower() stmt = select(PricingSettings).where( PricingSettings.slug == normalised) if include_children: stmt = stmt.options( selectinload(PricingSettings.metal_overrides), selectinload(PricingSettings.impurity_overrides), ) result = self.session.execute(stmt) if include_children: result = result.unique() return result.scalar_one_or_none() def get_by_slug(self, slug: str, *, include_children: bool = False) -> PricingSettings: settings = self.find_by_slug(slug, include_children=include_children) if settings is None: raise EntityNotFoundError( f"Pricing settings slug '{slug}' not found" ) return settings def create(self, settings: PricingSettings) -> PricingSettings: self.session.add(settings) try: self.session.flush() except IntegrityError as exc: # pragma: no cover - relies on DB constraints raise EntityConflictError( "Pricing settings violates constraints") from exc return settings def delete(self, settings_id: int) -> None: settings = self.get(settings_id, include_children=True) self.session.delete(settings) def attach_metal_override( self, settings: PricingSettings, override: PricingMetalSettings, ) -> PricingMetalSettings: settings.metal_overrides.append(override) self.session.add(override) self.session.flush() return override def attach_impurity_override( self, settings: PricingSettings, override: PricingImpuritySettings, ) -> PricingImpuritySettings: settings.impurity_overrides.append(override) self.session.add(override) self.session.flush() return override class RoleRepository: """Persistence operations for Role entities.""" def __init__(self, session: Session) -> None: self.session = session def list(self) -> Sequence[Role]: stmt = select(Role).order_by(Role.name) return self.session.execute(stmt).scalars().all() def get(self, role_id: int) -> Role: stmt = select(Role).where(Role.id == role_id) role = self.session.execute(stmt).scalar_one_or_none() if role is None: raise EntityNotFoundError(f"Role {role_id} not found") return role def get_by_name(self, name: str) -> Role | None: stmt = select(Role).where(Role.name == name) return self.session.execute(stmt).scalar_one_or_none() def create(self, role: Role) -> Role: self.session.add(role) try: self.session.flush() except IntegrityError as exc: # pragma: no cover - DB constraint enforcement raise EntityConflictError( "Role violates uniqueness constraints") from exc return role class UserRepository: """Persistence operations for User entities and their role assignments.""" def __init__(self, session: Session) -> None: self.session = session def list(self, *, with_roles: bool = False) -> Sequence[User]: stmt = select(User).order_by(User.created_at) if with_roles: stmt = stmt.options(selectinload(User.roles)) return self.session.execute(stmt).scalars().all() def _apply_role_option(self, stmt, with_roles: bool): if with_roles: stmt = stmt.options( joinedload(User.role_assignments).joinedload(UserRole.role), selectinload(User.roles), ) return stmt def get(self, user_id: int, *, with_roles: bool = False) -> User: stmt = select(User).where(User.id == user_id).execution_options( populate_existing=True) stmt = self._apply_role_option(stmt, with_roles) result = self.session.execute(stmt) if with_roles: result = result.unique() user = result.scalar_one_or_none() if user is None: raise EntityNotFoundError(f"User {user_id} not found") return user def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None: stmt = select(User).where(User.email == email).execution_options( populate_existing=True) stmt = self._apply_role_option(stmt, with_roles) result = self.session.execute(stmt) if with_roles: result = result.unique() return result.scalar_one_or_none() def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None: stmt = select(User).where(User.username == username).execution_options(populate_existing=True) stmt = self._apply_role_option(stmt, with_roles) result = self.session.execute(stmt) if with_roles: result = result.unique() return result.scalar_one_or_none() def create(self, user: User) -> User: self.session.add(user) try: self.session.flush() except IntegrityError as exc: # pragma: no cover - DB constraint enforcement raise EntityConflictError( "User violates uniqueness constraints") from exc return user def assign_role( self, *, user_id: int, role_id: int, granted_by: int | None = None, ) -> UserRole: stmt = select(UserRole).where( UserRole.user_id == user_id, UserRole.role_id == role_id, ) assignment = self.session.execute(stmt).scalar_one_or_none() if assignment: return assignment assignment = UserRole( user_id=user_id, role_id=role_id, granted_by=granted_by, ) self.session.add(assignment) try: self.session.flush() except IntegrityError as exc: # pragma: no cover - DB constraint enforcement raise EntityConflictError( "Assignment violates constraints") from exc return assignment def revoke_role(self, *, user_id: int, role_id: int) -> None: stmt = select(UserRole).where( UserRole.user_id == user_id, UserRole.role_id == role_id, ) assignment = self.session.execute(stmt).scalar_one_or_none() if assignment is None: raise EntityNotFoundError( f"Role {role_id} not assigned to user {user_id}") self.session.delete(assignment) self.session.flush() DEFAULT_PRICING_SETTINGS_NAME = "Default Pricing Settings" DEFAULT_PRICING_SETTINGS_DESCRIPTION = ( "Default pricing configuration generated from environment metadata." ) @dataclass(slots=True) class PricingSettingsSeedResult: settings: PricingSettings created: bool updated_fields: int impurity_upserts: int def ensure_default_pricing_settings( repo: PricingSettingsRepository, *, metadata: PricingMetadata, slug: str = "default", name: str | None = None, description: str | None = None, ) -> PricingSettingsSeedResult: """Ensure a baseline pricing settings record exists and matches metadata defaults.""" normalised_slug = (slug or "default").strip().lower() or "default" target_name = name or DEFAULT_PRICING_SETTINGS_NAME target_description = description or DEFAULT_PRICING_SETTINGS_DESCRIPTION updated_fields = 0 impurity_upserts = 0 try: settings = repo.get_by_slug(normalised_slug, include_children=True) created = False except EntityNotFoundError: settings = PricingSettings( name=target_name, slug=normalised_slug, description=target_description, default_currency=metadata.default_currency, default_payable_pct=metadata.default_payable_pct, moisture_threshold_pct=metadata.moisture_threshold_pct, moisture_penalty_per_pct=metadata.moisture_penalty_per_pct, ) settings.metadata_payload = None settings = repo.create(settings) created = True else: if settings.name != target_name: settings.name = target_name updated_fields += 1 if target_description and settings.description != target_description: settings.description = target_description updated_fields += 1 if settings.default_currency != metadata.default_currency: settings.default_currency = metadata.default_currency updated_fields += 1 if float(settings.default_payable_pct) != float(metadata.default_payable_pct): settings.default_payable_pct = metadata.default_payable_pct updated_fields += 1 if float(settings.moisture_threshold_pct) != float(metadata.moisture_threshold_pct): settings.moisture_threshold_pct = metadata.moisture_threshold_pct updated_fields += 1 if float(settings.moisture_penalty_per_pct) != float(metadata.moisture_penalty_per_pct): settings.moisture_penalty_per_pct = metadata.moisture_penalty_per_pct updated_fields += 1 impurity_thresholds = { code.strip().upper(): float(value) for code, value in (metadata.impurity_thresholds or {}).items() if code.strip() } impurity_penalties = { code.strip().upper(): float(value) for code, value in (metadata.impurity_penalty_per_ppm or {}).items() if code.strip() } if impurity_thresholds or impurity_penalties: existing_map = { override.impurity_code: override for override in settings.impurity_overrides } target_codes = set(impurity_thresholds) | set(impurity_penalties) for code in sorted(target_codes): threshold_value = impurity_thresholds.get(code, 0.0) penalty_value = impurity_penalties.get(code, 0.0) existing = existing_map.get(code) if existing is None: repo.attach_impurity_override( settings, PricingImpuritySettings( impurity_code=code, threshold_ppm=threshold_value, penalty_per_ppm=penalty_value, ), ) impurity_upserts += 1 continue changed = False if float(existing.threshold_ppm) != float(threshold_value): existing.threshold_ppm = threshold_value changed = True if float(existing.penalty_per_ppm) != float(penalty_value): existing.penalty_per_ppm = penalty_value changed = True if changed: updated_fields += 1 if updated_fields > 0 or impurity_upserts > 0: repo.session.flush() return PricingSettingsSeedResult( settings=settings, created=created, updated_fields=updated_fields, impurity_upserts=impurity_upserts, ) def pricing_settings_to_metadata(settings: PricingSettings) -> PricingMetadata: """Convert a persisted pricing settings record into metadata defaults.""" payload = settings.metadata_payload or {} payload_thresholds = payload.get("impurity_thresholds") or {} payload_penalties = payload.get("impurity_penalty_per_ppm") or {} thresholds: dict[str, float] = { code.strip().upper(): float(value) for code, value in payload_thresholds.items() if isinstance(code, str) and code.strip() } penalties: dict[str, float] = { code.strip().upper(): float(value) for code, value in payload_penalties.items() if isinstance(code, str) and code.strip() } for override in settings.impurity_overrides: code = override.impurity_code.strip().upper() thresholds[code] = float(override.threshold_ppm) penalties[code] = float(override.penalty_per_ppm) return PricingMetadata( default_payable_pct=float(settings.default_payable_pct), default_currency=settings.default_currency, moisture_threshold_pct=float(settings.moisture_threshold_pct), moisture_penalty_per_pct=float(settings.moisture_penalty_per_pct), impurity_thresholds=thresholds, impurity_penalty_per_ppm=penalties, ) DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = ( { "name": "admin", "display_name": "Administrator", "description": "Full platform access with user management rights.", }, { "name": "project_manager", "display_name": "Project Manager", "description": "Manage projects, scenarios, and associated data.", }, { "name": "analyst", "display_name": "Analyst", "description": "Review dashboards and scenario outputs.", }, { "name": "viewer", "display_name": "Viewer", "description": "Read-only access to assigned projects and reports.", }, ) def ensure_default_roles(role_repo: RoleRepository) -> list[Role]: """Ensure standard roles exist, creating missing ones. Returns all current role records in creation order. """ roles: list[Role] = [] for definition in DEFAULT_ROLE_DEFINITIONS: existing = role_repo.get_by_name(definition["name"]) if existing: roles.append(existing) continue role = Role(**definition) roles.append(role_repo.create(role)) return roles def ensure_admin_user( user_repo: UserRepository, role_repo: RoleRepository, *, email: str, username: str, password: str, ) -> User: """Ensure an administrator user exists and holds the admin role.""" user = user_repo.get_by_email(email, with_roles=True) if user is None: user = User( email=email, username=username, password_hash=User.hash_password(password), is_active=True, is_superuser=True, ) user_repo.create(user) else: if not user.is_active: user.is_active = True if not user.is_superuser: user.is_superuser = True user_repo.session.flush() admin_role = role_repo.get_by_name("admin") if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called admin_role = role_repo.create( Role( name="admin", display_name="Administrator", description="Full platform access with user management rights.", ) ) user_repo.assign_role( user_id=user.id, role_id=admin_role.id, granted_by=user.id, ) return user