from __future__ import annotations from contextlib import AbstractContextManager from typing import Callable, Sequence from sqlalchemy.orm import Session from config.database import SessionLocal from models import PricingSettings, Project, Role, Scenario from services.pricing import PricingMetadata from services.repositories import ( FinancialInputRepository, PricingSettingsRepository, PricingSettingsSeedResult, ProjectRepository, ProjectProfitabilityRepository, ProjectProcessingOpexRepository, ProjectCapexRepository, RoleRepository, ScenarioRepository, ScenarioProfitabilityRepository, ScenarioProcessingOpexRepository, ScenarioCapexRepository, SimulationParameterRepository, UserRepository, ensure_admin_user as ensure_admin_user_record, ensure_default_pricing_settings, ensure_default_roles, pricing_settings_to_metadata, ) from services.scenario_validation import ScenarioComparisonValidator class UnitOfWork(AbstractContextManager["UnitOfWork"]): """Simple unit-of-work wrapper around SQLAlchemy sessions.""" def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: self._session_factory = session_factory self.session: Session | None = None self._scenario_validator: ScenarioComparisonValidator | None = None self.projects: ProjectRepository | None = None self.scenarios: ScenarioRepository | None = None self.financial_inputs: FinancialInputRepository | None = None self.simulation_parameters: SimulationParameterRepository | None = None self.project_profitability: ProjectProfitabilityRepository | None = None self.project_capex: ProjectCapexRepository | None = None self.project_processing_opex: ProjectProcessingOpexRepository | None = None self.scenario_profitability: ScenarioProfitabilityRepository | None = None self.scenario_capex: ScenarioCapexRepository | None = None self.scenario_processing_opex: ScenarioProcessingOpexRepository | None = None self.users: UserRepository | None = None self.roles: RoleRepository | None = None self.pricing_settings: PricingSettingsRepository | None = None def __enter__(self) -> "UnitOfWork": self.session = self._session_factory() self.projects = ProjectRepository(self.session) self.scenarios = ScenarioRepository(self.session) self.financial_inputs = FinancialInputRepository(self.session) self.simulation_parameters = SimulationParameterRepository( self.session) self.project_profitability = ProjectProfitabilityRepository( self.session) self.project_capex = ProjectCapexRepository(self.session) self.project_processing_opex = ProjectProcessingOpexRepository( self.session) self.scenario_profitability = ScenarioProfitabilityRepository( self.session ) self.scenario_capex = ScenarioCapexRepository(self.session) self.scenario_processing_opex = ScenarioProcessingOpexRepository( self.session) self.users = UserRepository(self.session) self.roles = RoleRepository(self.session) self.pricing_settings = PricingSettingsRepository(self.session) self._scenario_validator = ScenarioComparisonValidator() return self def __exit__(self, exc_type, exc_value, traceback) -> None: assert self.session is not None if exc_type is None: self.session.commit() else: self.session.rollback() self.session.close() self._scenario_validator = None self.projects = None self.scenarios = None self.financial_inputs = None self.simulation_parameters = None self.project_profitability = None self.project_capex = None self.project_processing_opex = None self.scenario_profitability = None self.scenario_capex = None self.scenario_processing_opex = None self.users = None self.roles = None self.pricing_settings = None def flush(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.flush() def commit(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.commit() def rollback(self) -> None: if not self.session: raise RuntimeError("UnitOfWork session is not initialised") self.session.rollback() def validate_scenarios_for_comparison( self, scenario_ids: Sequence[int] ) -> list[Scenario]: if not self.session or not self._scenario_validator or not self.scenarios: raise RuntimeError("UnitOfWork session is not initialised") scenarios = [self.scenarios.get(scenario_id) for scenario_id in scenario_ids] self._scenario_validator.validate(scenarios) return scenarios def validate_scenario_models_for_comparison( self, scenarios: Sequence[Scenario] ) -> None: if not self._scenario_validator: raise RuntimeError("UnitOfWork session is not initialised") self._scenario_validator.validate(scenarios) def ensure_default_roles(self) -> list[Role]: if not self.roles: raise RuntimeError("UnitOfWork session is not initialised") return ensure_default_roles(self.roles) def ensure_admin_user( self, *, email: str, username: str, password: str, ) -> None: if not self.users or not self.roles: raise RuntimeError("UnitOfWork session is not initialised") ensure_default_roles(self.roles) ensure_admin_user_record( self.users, self.roles, email=email, username=username, password=password, ) def ensure_default_pricing_settings( self, *, metadata: PricingMetadata, slug: str = "default", name: str | None = None, description: str | None = None, ) -> PricingSettingsSeedResult: if not self.pricing_settings: raise RuntimeError("UnitOfWork session is not initialised") return ensure_default_pricing_settings( self.pricing_settings, metadata=metadata, slug=slug, name=name, description=description, ) def get_pricing_metadata( self, *, slug: str = "default", ) -> PricingMetadata | None: if not self.pricing_settings: raise RuntimeError("UnitOfWork session is not initialised") settings = self.pricing_settings.find_by_slug( slug, include_children=True, ) if settings is None: return None return pricing_settings_to_metadata(settings) def set_project_pricing_settings( self, project: Project, pricing_settings: PricingSettings | None, ) -> Project: if not self.projects: raise RuntimeError("UnitOfWork session is not initialised") return self.projects.set_pricing_settings(project, pricing_settings)