From 1a7581cda0033e239883e401b163562e8f62a107 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Mon, 10 Nov 2025 15:36:06 +0100 Subject: [PATCH] feat: add export filters for projects and scenarios with filtering capabilities --- services/export_query.py | 113 +++++++++++++++++++++ services/export_serializers.py | 178 +++++++++++++++++++++++++++++++++ services/repositories.py | 118 ++++++++++++++++++++++ tests/test_repositories.py | 126 ++++++++++++++++++++++- 4 files changed, 533 insertions(+), 2 deletions(-) create mode 100644 services/export_query.py create mode 100644 services/export_serializers.py diff --git a/services/export_query.py b/services/export_query.py new file mode 100644 index 0000000..acbcb6e --- /dev/null +++ b/services/export_query.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date, datetime +from typing import Iterable + +from models import MiningOperationType, ResourceType, ScenarioStatus + + +def _normalise_lower_strings(values: Iterable[str]) -> tuple[str, ...]: + unique: set[str] = set() + for value in values: + if not value: + continue + trimmed = value.strip().lower() + if not trimmed: + continue + unique.add(trimmed) + return tuple(sorted(unique)) + + +def _normalise_upper_strings(values: Iterable[str]) -> tuple[str, ...]: + unique: set[str] = set() + for value in values: + if not value: + continue + trimmed = value.strip().upper() + if not trimmed: + continue + unique.add(trimmed) + return tuple(sorted(unique)) + + +@dataclass(slots=True, frozen=True) +class ProjectExportFilters: + """Filter parameters for project export queries.""" + + ids: tuple[int, ...] = () + names: tuple[str, ...] = () + name_contains: str | None = None + locations: tuple[str, ...] = () + operation_types: tuple[MiningOperationType, ...] = () + created_from: datetime | None = None + created_to: datetime | None = None + updated_from: datetime | None = None + updated_to: datetime | None = None + + def normalised_ids(self) -> tuple[int, ...]: + unique = {identifier for identifier in self.ids if identifier > 0} + return tuple(sorted(unique)) + + def normalised_names(self) -> tuple[str, ...]: + return _normalise_lower_strings(self.names) + + def normalised_locations(self) -> tuple[str, ...]: + return _normalise_lower_strings(self.locations) + + def name_search_pattern(self) -> str | None: + if not self.name_contains: + return None + pattern = self.name_contains.strip() + if not pattern: + return None + return f"%{pattern}%" + + +@dataclass(slots=True, frozen=True) +class ScenarioExportFilters: + """Filter parameters for scenario export queries.""" + + ids: tuple[int, ...] = () + project_ids: tuple[int, ...] = () + project_names: tuple[str, ...] = () + name_contains: str | None = None + statuses: tuple[ScenarioStatus, ...] = () + start_date_from: date | None = None + start_date_to: date | None = None + end_date_from: date | None = None + end_date_to: date | None = None + created_from: datetime | None = None + created_to: datetime | None = None + updated_from: datetime | None = None + updated_to: datetime | None = None + currencies: tuple[str, ...] = () + primary_resources: tuple[ResourceType, ...] = () + + def normalised_ids(self) -> tuple[int, ...]: + unique = {identifier for identifier in self.ids if identifier > 0} + return tuple(sorted(unique)) + + def normalised_project_ids(self) -> tuple[int, ...]: + unique = {identifier for identifier in self.project_ids if identifier > 0} + return tuple(sorted(unique)) + + def normalised_project_names(self) -> tuple[str, ...]: + return _normalise_lower_strings(self.project_names) + + def name_search_pattern(self) -> str | None: + if not self.name_contains: + return None + pattern = self.name_contains.strip() + if not pattern: + return None + return f"%{pattern}%" + + def normalised_currencies(self) -> tuple[str, ...]: + return _normalise_upper_strings(self.currencies) + + +__all__ = ( + "ProjectExportFilters", + "ScenarioExportFilters", +) diff --git a/services/export_serializers.py b/services/export_serializers.py new file mode 100644 index 0000000..6036788 --- /dev/null +++ b/services/export_serializers.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timezone +from functools import partial +from typing import Callable, Iterable, Iterator, Sequence + +from .export_query import ProjectExportFilters, ScenarioExportFilters + +__all__ = [ + "CSVExportColumn", + "CSVExportRowFactory", + "build_project_row_factory", + "build_scenario_row_factory", +] + +CSVValueFormatter = Callable[[object | None], str] +RowExtractor = Callable[[object], Sequence[object | None]] + + +@dataclass(slots=True, frozen=True) +class CSVExportColumn: + """Describe how a field is rendered into CSV output.""" + + header: str + accessor: str + formatter: CSVValueFormatter | None = None + + +@dataclass(slots=True, frozen=True) +class CSVExportRowFactory: + """Builds row values for CSV serialization.""" + + columns: tuple[CSVExportColumn, ...] + + def headers(self) -> tuple[str, ...]: + return tuple(column.header for column in self.columns) + + def to_row(self, entity: object) -> tuple[str, ...]: + values: list[str] = [] + for column in self.columns: + value = getattr(entity, column.accessor, None) + formatter = column.formatter or _default_formatter + values.append(formatter(value)) + return tuple(values) + + +def build_project_row_factory( + *, + include_timestamps: bool = True, + include_description: bool = True, +) -> CSVExportRowFactory: + columns: list[CSVExportColumn] = [ + CSVExportColumn(header="name", accessor="name"), + CSVExportColumn(header="location", accessor="location"), + CSVExportColumn(header="operation_type", accessor="operation_type"), + ] + if include_description: + columns.append( + CSVExportColumn(header="description", accessor="description") + ) + if include_timestamps: + columns.extend( + ( + CSVExportColumn( + header="created_at", + accessor="created_at", + formatter=_format_datetime, + ), + CSVExportColumn( + header="updated_at", + accessor="updated_at", + formatter=_format_datetime, + ), + ) + ) + return CSVExportRowFactory(columns=tuple(columns)) + + +def build_scenario_row_factory( + *, + include_description: bool = True, + include_timestamps: bool = True, +) -> CSVExportRowFactory: + columns: list[CSVExportColumn] = [ + CSVExportColumn(header="project_name", accessor="project_name"), + CSVExportColumn(header="name", accessor="name"), + CSVExportColumn(header="status", accessor="status"), + CSVExportColumn( + header="start_date", + accessor="start_date", + formatter=_format_date, + ), + CSVExportColumn( + header="end_date", + accessor="end_date", + formatter=_format_date, + ), + CSVExportColumn(header="discount_rate", + accessor="discount_rate", formatter=_format_number), + CSVExportColumn(header="currency", accessor="currency"), + CSVExportColumn(header="primary_resource", + accessor="primary_resource"), + ] + if include_description: + columns.append( + CSVExportColumn(header="description", accessor="description") + ) + if include_timestamps: + columns.extend( + ( + CSVExportColumn( + header="created_at", + accessor="created_at", + formatter=_format_datetime, + ), + CSVExportColumn( + header="updated_at", + accessor="updated_at", + formatter=_format_datetime, + ), + ) + ) + return CSVExportRowFactory(columns=tuple(columns)) + + +def stream_projects_to_csv( + projects: Sequence[object], + *, + row_factory: CSVExportRowFactory | None = None, +) -> Iterator[str]: + factory = row_factory or build_project_row_factory() + yield _join_row(factory.headers()) + for project in projects: + yield _join_row(factory.to_row(project)) + + +def stream_scenarios_to_csv( + scenarios: Sequence[object], + *, + row_factory: CSVExportRowFactory | None = None, +) -> Iterator[str]: + factory = row_factory or build_scenario_row_factory() + yield _join_row(factory.headers()) + for scenario in scenarios: + yield _join_row(factory.to_row(scenario)) + + +def _join_row(values: Iterable[str]) -> str: + return ",".join(values) + + +def _default_formatter(value: object | None) -> str: + if value is None: + return "" + return str(value) + + +def _format_datetime(value: object | None) -> str: + from datetime import datetime + + if not isinstance(value, datetime): + return "" + return value.astimezone(timezone.utc).replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z") + + +def _format_date(value: object | None) -> str: + from datetime import date + + if not isinstance(value, date): + return "" + return value.isoformat() + + +def _format_number(value: object | None) -> str: + if value is None: + return "" + return f"{value:.2f}" diff --git a/services/repositories.py b/services/repositories.py index 64e7108..189dea8 100644 --- a/services/repositories.py +++ b/services/repositories.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session, joinedload, selectinload from models import ( FinancialInput, Project, + ResourceType, Role, Scenario, ScenarioStatus, @@ -19,6 +20,7 @@ from models import ( UserRole, ) from services.exceptions import EntityConflictError, EntityNotFoundError +from services.export_query import ProjectExportFilters, ScenarioExportFilters class ProjectRepository: @@ -79,6 +81,52 @@ class ProjectRepository: records = self.session.execute(stmt).scalars().all() return {project.name.lower(): project for project in records} + def filtered_for_export( + self, + filters: ProjectExportFilters | None = None, + *, + include_scenarios: bool = False, + ) -> Sequence[Project]: + stmt = select(Project) + if include_scenarios: + stmt = stmt.options(selectinload(Project.scenarios)) + + if filters: + ids = filters.normalised_ids() + if ids: + stmt = stmt.where(Project.id.in_(ids)) + + name_matches = filters.normalised_names() + if name_matches: + stmt = stmt.where(func.lower(Project.name).in_(name_matches)) + + name_pattern = filters.name_search_pattern() + if name_pattern: + stmt = stmt.where(Project.name.ilike(name_pattern)) + + locations = filters.normalised_locations() + if locations: + stmt = stmt.where(func.lower(Project.location).in_(locations)) + + if filters.operation_types: + stmt = stmt.where(Project.operation_type.in_( + filters.operation_types)) + + if filters.created_from: + stmt = stmt.where(Project.created_at >= filters.created_from) + + if filters.created_to: + stmt = stmt.where(Project.created_at <= filters.created_to) + + if filters.updated_from: + stmt = stmt.where(Project.updated_at >= filters.updated_from) + + if filters.updated_to: + stmt = stmt.where(Project.updated_at <= filters.updated_to) + + stmt = stmt.order_by(Project.name, Project.id) + return self.session.execute(stmt).scalars().all() + def delete(self, project_id: int) -> None: project = self.get(project_id) self.session.delete(project) @@ -177,6 +225,76 @@ class ScenarioRepository: records = self.session.execute(stmt).scalars().all() return {scenario.name.lower(): scenario for scenario in records} + def filtered_for_export( + self, + filters: ScenarioExportFilters | None = None, + *, + include_project: bool = True, + ) -> Sequence[Scenario]: + stmt = select(Scenario) + if include_project: + stmt = stmt.options(joinedload(Scenario.project)) + + if filters: + scenario_ids = filters.normalised_ids() + if scenario_ids: + stmt = stmt.where(Scenario.id.in_(scenario_ids)) + + project_ids = filters.normalised_project_ids() + if project_ids: + stmt = stmt.where(Scenario.project_id.in_(project_ids)) + + project_names = filters.normalised_project_names() + if project_names: + project_id_select = select(Project.id).where( + func.lower(Project.name).in_(project_names) + ) + stmt = stmt.where(Scenario.project_id.in_(project_id_select)) + + name_pattern = filters.name_search_pattern() + if name_pattern: + stmt = stmt.where(Scenario.name.ilike(name_pattern)) + + if filters.statuses: + stmt = stmt.where(Scenario.status.in_(filters.statuses)) + + if filters.start_date_from: + stmt = stmt.where(Scenario.start_date >= + filters.start_date_from) + + if filters.start_date_to: + stmt = stmt.where(Scenario.start_date <= filters.start_date_to) + + if filters.end_date_from: + stmt = stmt.where(Scenario.end_date >= filters.end_date_from) + + if filters.end_date_to: + stmt = stmt.where(Scenario.end_date <= filters.end_date_to) + + if filters.created_from: + stmt = stmt.where(Scenario.created_at >= filters.created_from) + + if filters.created_to: + stmt = stmt.where(Scenario.created_at <= filters.created_to) + + if filters.updated_from: + stmt = stmt.where(Scenario.updated_at >= filters.updated_from) + + if filters.updated_to: + stmt = stmt.where(Scenario.updated_at <= filters.updated_to) + + currencies = filters.normalised_currencies() + if currencies: + stmt = stmt.where(func.upper( + Scenario.currency).in_(currencies)) + + if filters.primary_resources: + stmt = stmt.where(Scenario.primary_resource.in_( + filters.primary_resources)) + + stmt = stmt.order_by(Scenario.name, Scenario.id) + return self.session.execute(stmt).scalars().all() + def delete(self, scenario_id: int) -> None: scenario = self.get(scenario_id) self.session.delete(scenario) diff --git a/tests/test_repositories.py b/tests/test_repositories.py index 7e00910..1ea6eb9 100644 --- a/tests/test_repositories.py +++ b/tests/test_repositories.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterator -from datetime import datetime, timezone +from datetime import date, datetime, timezone import pytest from sqlalchemy import create_engine @@ -16,6 +16,7 @@ from models import ( Project, Scenario, ScenarioStatus, + ResourceType, SimulationParameter, StochasticVariable, ) @@ -25,6 +26,7 @@ from services.repositories import ( ScenarioRepository, SimulationParameterRepository, ) +from services.export_query import ProjectExportFilters, ScenarioExportFilters from services.unit_of_work import UnitOfWork @@ -136,6 +138,7 @@ def test_unit_of_work_commit_and_rollback(engine) -> None: # Commit path with UnitOfWork(session_factory=TestingSession) as uow: + assert uow.projects is not None uow.projects.create( Project(name="Project Delta", operation_type=MiningOperationType.PLACER) ) @@ -147,6 +150,7 @@ def test_unit_of_work_commit_and_rollback(engine) -> None: # Rollback path with pytest.raises(RuntimeError): with UnitOfWork(session_factory=TestingSession) as uow: + assert uow.projects is not None uow.projects.create( Project(name="Project Epsilon", operation_type=MiningOperationType.OTHER) ) @@ -241,4 +245,122 @@ def test_financial_input_repository_latest_created_at(session: Session) -> None: ) latest = repo.latest_created_at() - assert latest == new_timestamp \ No newline at end of file + assert latest == new_timestamp + + +def test_project_repository_filtered_for_export(session: Session) -> None: + repo = ProjectRepository(session) + + alpha_created = datetime(2025, 1, 1, 9, 30, tzinfo=timezone.utc) + alpha_updated = datetime(2025, 1, 2, 12, 0, tzinfo=timezone.utc) + bravo_created = datetime(2025, 2, 1, 9, 30, tzinfo=timezone.utc) + bravo_updated = datetime(2025, 2, 2, 12, 0, tzinfo=timezone.utc) + + project_alpha = Project( + name="Alpha", + location="Nevada", + operation_type=MiningOperationType.OPEN_PIT, + description="Primary export candidate", + ) + project_alpha.created_at = alpha_created + project_alpha.updated_at = alpha_updated + + project_bravo = Project( + name="Bravo", + location="Ontario", + operation_type=MiningOperationType.UNDERGROUND, + description="Excluded project", + ) + project_bravo.created_at = bravo_created + project_bravo.updated_at = bravo_updated + + scenario_alpha = Scenario( + name="Alpha Scenario", + project=project_alpha, + status=ScenarioStatus.ACTIVE, + ) + + session.add_all([project_alpha, project_bravo, scenario_alpha]) + session.flush() + + filters = ProjectExportFilters( + ids=(project_alpha.id, project_alpha.id, -5), + names=("Alpha", " alpha ", ""), + name_contains="alp", + locations=(" nevada ", ""), + operation_types=(MiningOperationType.OPEN_PIT,), + created_from=alpha_created, + created_to=alpha_created, + updated_from=alpha_updated, + updated_to=alpha_updated, + ) + + results = repo.filtered_for_export(filters, include_scenarios=True) + + assert [project.name for project in results] == ["Alpha"] + assert len(results[0].scenarios) == 1 + assert results[0].scenarios[0].name == "Alpha Scenario" + + +def test_scenario_repository_filtered_for_export(session: Session) -> None: + repo = ScenarioRepository(session) + + project_export = Project( + name="Export Project", + operation_type=MiningOperationType.PLACER, + ) + project_other = Project( + name="Other Project", + operation_type=MiningOperationType.OTHER, + ) + + scenario_match = Scenario( + name="Case Alpha", + project=project_export, + status=ScenarioStatus.ACTIVE, + start_date=date(2025, 1, 5), + end_date=date(2025, 2, 1), + discount_rate=7.5, + currency="usd", + primary_resource=ResourceType.EXPLOSIVES, + ) + scenario_match.created_at = datetime(2025, 1, 6, tzinfo=timezone.utc) + scenario_match.updated_at = datetime(2025, 1, 16, tzinfo=timezone.utc) + + scenario_other = Scenario( + name="Case Beta", + project=project_other, + status=ScenarioStatus.DRAFT, + start_date=date(2024, 12, 20), + end_date=date(2025, 3, 1), + currency="cad", + primary_resource=ResourceType.WATER, + ) + scenario_other.created_at = datetime(2024, 12, 25, tzinfo=timezone.utc) + scenario_other.updated_at = datetime(2025, 3, 5, tzinfo=timezone.utc) + + session.add_all([project_export, project_other, scenario_match, scenario_other]) + session.flush() + + filters = ScenarioExportFilters( + ids=(scenario_match.id, scenario_match.id, -1), + project_ids=(project_export.id, 0), + project_names=(" Export Project ", "EXPORT PROJECT"), + name_contains="case", + statuses=(ScenarioStatus.ACTIVE,), + start_date_from=date(2025, 1, 1), + start_date_to=date(2025, 1, 31), + end_date_from=date(2025, 1, 31), + end_date_to=date(2025, 2, 28), + created_from=datetime(2025, 1, 1, tzinfo=timezone.utc), + created_to=datetime(2025, 1, 31, tzinfo=timezone.utc), + updated_from=datetime(2025, 1, 10, tzinfo=timezone.utc), + updated_to=datetime(2025, 1, 31, tzinfo=timezone.utc), + currencies=(" usd ", "USD"), + primary_resources=(ResourceType.EXPLOSIVES,), + ) + + results = repo.filtered_for_export(filters, include_project=True) + + assert [scenario.name for scenario in results] == ["Case Alpha"] + assert results[0].project.name == "Export Project" \ No newline at end of file