from __future__ import annotations from dataclasses import dataclass from datetime import timezone from functools import partial from typing import Callable, Iterable, Iterator, Sequence from .export_query import ProjectExportFilters, ScenarioExportFilters __all__ = [ "CSVExportColumn", "CSVExportRowFactory", "build_project_row_factory", "build_scenario_row_factory", ] CSVValueFormatter = Callable[[object | None], str] RowExtractor = Callable[[object], Sequence[object | None]] @dataclass(slots=True, frozen=True) class CSVExportColumn: """Describe how a field is rendered into CSV output.""" header: str accessor: str formatter: CSVValueFormatter | None = None @dataclass(slots=True, frozen=True) class CSVExportRowFactory: """Builds row values for CSV serialization.""" columns: tuple[CSVExportColumn, ...] def headers(self) -> tuple[str, ...]: return tuple(column.header for column in self.columns) def to_row(self, entity: object) -> tuple[str, ...]: values: list[str] = [] for column in self.columns: value = getattr(entity, column.accessor, None) formatter = column.formatter or _default_formatter values.append(formatter(value)) return tuple(values) def build_project_row_factory( *, include_timestamps: bool = True, include_description: bool = True, ) -> CSVExportRowFactory: columns: list[CSVExportColumn] = [ CSVExportColumn(header="name", accessor="name"), CSVExportColumn(header="location", accessor="location"), CSVExportColumn(header="operation_type", accessor="operation_type"), ] if include_description: columns.append( CSVExportColumn(header="description", accessor="description") ) if include_timestamps: columns.extend( ( CSVExportColumn( header="created_at", accessor="created_at", formatter=_format_datetime, ), CSVExportColumn( header="updated_at", accessor="updated_at", formatter=_format_datetime, ), ) ) return CSVExportRowFactory(columns=tuple(columns)) def build_scenario_row_factory( *, include_description: bool = True, include_timestamps: bool = True, ) -> CSVExportRowFactory: columns: list[CSVExportColumn] = [ CSVExportColumn(header="project_name", accessor="project_name"), CSVExportColumn(header="name", accessor="name"), CSVExportColumn(header="status", accessor="status"), CSVExportColumn( header="start_date", accessor="start_date", formatter=_format_date, ), CSVExportColumn( header="end_date", accessor="end_date", formatter=_format_date, ), CSVExportColumn(header="discount_rate", accessor="discount_rate", formatter=_format_number), CSVExportColumn(header="currency", accessor="currency"), CSVExportColumn(header="primary_resource", accessor="primary_resource"), ] if include_description: columns.append( CSVExportColumn(header="description", accessor="description") ) if include_timestamps: columns.extend( ( CSVExportColumn( header="created_at", accessor="created_at", formatter=_format_datetime, ), CSVExportColumn( header="updated_at", accessor="updated_at", formatter=_format_datetime, ), ) ) return CSVExportRowFactory(columns=tuple(columns)) def stream_projects_to_csv( projects: Sequence[object], *, row_factory: CSVExportRowFactory | None = None, ) -> Iterator[str]: factory = row_factory or build_project_row_factory() yield _join_row(factory.headers()) for project in projects: yield _join_row(factory.to_row(project)) def stream_scenarios_to_csv( scenarios: Sequence[object], *, row_factory: CSVExportRowFactory | None = None, ) -> Iterator[str]: factory = row_factory or build_scenario_row_factory() yield _join_row(factory.headers()) for scenario in scenarios: yield _join_row(factory.to_row(scenario)) def _join_row(values: Iterable[str]) -> str: return ",".join(values) def _default_formatter(value: object | None) -> str: if value is None: return "" return str(value) def _format_datetime(value: object | None) -> str: from datetime import datetime if not isinstance(value, datetime): return "" return value.astimezone(timezone.utc).replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z") def _format_date(value: object | None) -> str: from datetime import date if not isinstance(value, date): return "" return value.isoformat() def _format_number(value: object | None) -> str: if value is None: return "" return f"{value:.2f}"