diff --git a/services/export_serializers.py b/services/export_serializers.py index 6036788..8e05da2 100644 --- a/services/export_serializers.py +++ b/services/export_serializers.py @@ -1,178 +1,240 @@ from __future__ import annotations +import csv from dataclasses import dataclass -from datetime import timezone -from functools import partial -from typing import Callable, Iterable, Iterator, Sequence +from datetime import date, datetime, timezone +from decimal import Decimal, InvalidOperation, ROUND_HALF_UP +from enum import Enum +from io import StringIO +from typing import Any, Callable, Iterable, Iterator, Sequence -from .export_query import ProjectExportFilters, ScenarioExportFilters +CSVValueFormatter = Callable[[Any], str] +Accessor = Callable[[Any], Any] __all__ = [ "CSVExportColumn", - "CSVExportRowFactory", - "build_project_row_factory", - "build_scenario_row_factory", + "CSVExporter", + "default_project_columns", + "default_scenario_columns", + "stream_projects_to_csv", + "stream_scenarios_to_csv", + "default_formatter", + "format_datetime_utc", + "format_date_iso", + "format_decimal", ] -CSVValueFormatter = Callable[[object | None], str] -RowExtractor = Callable[[object], Sequence[object | None]] - -@dataclass(slots=True, frozen=True) +@dataclass(slots=True) class CSVExportColumn: - """Describe how a field is rendered into CSV output.""" + """Declarative description of a CSV export column.""" header: str - accessor: str + accessor: Accessor | str formatter: CSVValueFormatter | None = None + required: bool = False + + def resolve_accessor(self) -> Accessor: + if isinstance(self.accessor, str): + return _coerce_accessor(self.accessor) + return self.accessor + + def value_for(self, entity: Any) -> Any: + accessor = self.resolve_accessor() + try: + return accessor(entity) + except Exception: # pragma: no cover - defensive safeguard + return None -@dataclass(slots=True, frozen=True) -class CSVExportRowFactory: - """Builds row values for CSV serialization.""" +class CSVExporter: + """Stream Python objects as UTF-8 encoded CSV rows.""" - columns: tuple[CSVExportColumn, ...] + def __init__( + self, + columns: Sequence[CSVExportColumn], + *, + include_header: bool = True, + line_terminator: str = "\n", + ) -> None: + if not columns: + raise ValueError("At least one column is required for CSV export.") + self._columns: tuple[CSVExportColumn, ...] = tuple(columns) + self._include_header = include_header + self._line_terminator = line_terminator + + @property + def columns(self) -> tuple[CSVExportColumn, ...]: + return self._columns def headers(self) -> tuple[str, ...]: - return tuple(column.header for column in self.columns) + return tuple(column.header for column in self._columns) - def to_row(self, entity: object) -> tuple[str, ...]: - values: list[str] = [] - for column in self.columns: - value = getattr(entity, column.accessor, None) - formatter = column.formatter or _default_formatter - values.append(formatter(value)) - return tuple(values) + def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]: + buffer = StringIO() + writer = csv.writer(buffer, lineterminator=self._line_terminator) + + if self._include_header: + writer.writerow(self.headers()) + yield _drain_buffer(buffer) + + for record in records: + writer.writerow(self._format_row(record)) + yield _drain_buffer(buffer) + + def _format_row(self, record: Any) -> list[str]: + formatted: list[str] = [] + for column in self._columns: + accessor = column.resolve_accessor() + raw_value = accessor(record) + formatter = column.formatter or default_formatter + formatted.append(formatter(raw_value)) + return formatted -def build_project_row_factory( - *, - include_timestamps: bool = True, - include_description: bool = True, -) -> CSVExportRowFactory: - columns: list[CSVExportColumn] = [ - CSVExportColumn(header="name", accessor="name"), - CSVExportColumn(header="location", accessor="location"), - CSVExportColumn(header="operation_type", accessor="operation_type"), - ] - if include_description: - columns.append( - CSVExportColumn(header="description", accessor="description") - ) - if include_timestamps: - columns.extend( - ( - CSVExportColumn( - header="created_at", - accessor="created_at", - formatter=_format_datetime, - ), - CSVExportColumn( - header="updated_at", - accessor="updated_at", - formatter=_format_datetime, - ), - ) - ) - return CSVExportRowFactory(columns=tuple(columns)) - - -def build_scenario_row_factory( +def default_project_columns( *, include_description: bool = True, include_timestamps: bool = True, -) -> CSVExportRowFactory: +) -> tuple[CSVExportColumn, ...]: columns: list[CSVExportColumn] = [ - CSVExportColumn(header="project_name", accessor="project_name"), - CSVExportColumn(header="name", accessor="name"), - CSVExportColumn(header="status", accessor="status"), - CSVExportColumn( - header="start_date", - accessor="start_date", - formatter=_format_date, - ), - CSVExportColumn( - header="end_date", - accessor="end_date", - formatter=_format_date, - ), - CSVExportColumn(header="discount_rate", - accessor="discount_rate", formatter=_format_number), - CSVExportColumn(header="currency", accessor="currency"), - CSVExportColumn(header="primary_resource", - accessor="primary_resource"), + CSVExportColumn("name", "name", required=True), + CSVExportColumn("location", "location"), + CSVExportColumn("operation_type", "operation_type"), ] if include_description: - columns.append( - CSVExportColumn(header="description", accessor="description") - ) + columns.append(CSVExportColumn("description", "description")) if include_timestamps: columns.extend( ( - CSVExportColumn( - header="created_at", - accessor="created_at", - formatter=_format_datetime, - ), - CSVExportColumn( - header="updated_at", - accessor="updated_at", - formatter=_format_datetime, - ), + CSVExportColumn("created_at", "created_at", + formatter=format_datetime_utc), + CSVExportColumn("updated_at", "updated_at", + formatter=format_datetime_utc), ) ) - return CSVExportRowFactory(columns=tuple(columns)) + return tuple(columns) + + +def default_scenario_columns( + *, + include_description: bool = True, + include_timestamps: bool = True, +) -> tuple[CSVExportColumn, ...]: + columns: list[CSVExportColumn] = [ + CSVExportColumn( + "project_name", + lambda scenario: getattr( + getattr(scenario, "project", None), "name", None), + required=True, + ), + CSVExportColumn("name", "name", required=True), + CSVExportColumn("status", "status"), + CSVExportColumn("start_date", "start_date", formatter=format_date_iso), + CSVExportColumn("end_date", "end_date", formatter=format_date_iso), + CSVExportColumn("discount_rate", "discount_rate", + formatter=format_decimal), + CSVExportColumn("currency", "currency"), + CSVExportColumn("primary_resource", "primary_resource"), + ] + if include_description: + columns.append(CSVExportColumn("description", "description")) + if include_timestamps: + columns.extend( + ( + CSVExportColumn("created_at", "created_at", + formatter=format_datetime_utc), + CSVExportColumn("updated_at", "updated_at", + formatter=format_datetime_utc), + ) + ) + return tuple(columns) def stream_projects_to_csv( - projects: Sequence[object], + projects: Iterable[Any], *, - row_factory: CSVExportRowFactory | None = None, -) -> Iterator[str]: - factory = row_factory or build_project_row_factory() - yield _join_row(factory.headers()) - for project in projects: - yield _join_row(factory.to_row(project)) + columns: Sequence[CSVExportColumn] | None = None, +) -> Iterator[bytes]: + resolved_columns = tuple(columns or default_project_columns()) + exporter = CSVExporter(resolved_columns) + yield from exporter.iter_bytes(projects) def stream_scenarios_to_csv( - scenarios: Sequence[object], + scenarios: Iterable[Any], *, - row_factory: CSVExportRowFactory | None = None, -) -> Iterator[str]: - factory = row_factory or build_scenario_row_factory() - yield _join_row(factory.headers()) - for scenario in scenarios: - yield _join_row(factory.to_row(scenario)) + columns: Sequence[CSVExportColumn] | None = None, +) -> Iterator[bytes]: + resolved_columns = tuple(columns or default_scenario_columns()) + exporter = CSVExporter(resolved_columns) + yield from exporter.iter_bytes(scenarios) -def _join_row(values: Iterable[str]) -> str: - return ",".join(values) - - -def _default_formatter(value: object | None) -> str: +def default_formatter(value: Any) -> str: if value is None: return "" + if isinstance(value, Enum): + return str(value.value) + if isinstance(value, Decimal): + return format_decimal(value) + if isinstance(value, datetime): + return format_datetime_utc(value) + if isinstance(value, date): + return format_date_iso(value) + if isinstance(value, bool): + return "true" if value else "false" return str(value) -def _format_datetime(value: object | None) -> str: - from datetime import datetime - +def format_datetime_utc(value: Any) -> str: if not isinstance(value, datetime): return "" - return value.astimezone(timezone.utc).replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z") + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + value = value.astimezone(timezone.utc) + return value.isoformat().replace("+00:00", "Z") -def _format_date(value: object | None) -> str: - from datetime import date - +def format_date_iso(value: Any) -> str: if not isinstance(value, date): return "" return value.isoformat() -def _format_number(value: object | None) -> str: +def format_decimal(value: Any) -> str: if value is None: return "" - return f"{value:.2f}" + if isinstance(value, Decimal): + try: + quantised = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + except InvalidOperation: # pragma: no cover - unexpected precision issues + quantised = value + return format(quantised, "f") + if isinstance(value, (int, float)): + return f"{value:.2f}" + return default_formatter(value) + + +def _coerce_accessor(accessor: Accessor | str) -> Accessor: + if callable(accessor): + return accessor + + path = [segment for segment in accessor.split(".") if segment] + + def _resolve(entity: Any) -> Any: + current: Any = entity + for segment in path: + if current is None: + return None + current = getattr(current, segment, None) + return current + + return _resolve + + +def _drain_buffer(buffer: StringIO) -> bytes: + data = buffer.getvalue() + buffer.seek(0) + buffer.truncate(0) + return data.encode("utf-8") diff --git a/tests/test_export_serializers.py b/tests/test_export_serializers.py new file mode 100644 index 0000000..7529d5d --- /dev/null +++ b/tests/test_export_serializers.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date, datetime, timezone +from decimal import Decimal +from typing import Any, Iterable + +import pytest + +from services.export_serializers import ( + CSVExportColumn, + CSVExporter, + default_formatter, + default_project_columns, + default_scenario_columns, + format_date_iso, + format_datetime_utc, + format_decimal, + stream_projects_to_csv, + stream_scenarios_to_csv, +) + + +@dataclass(slots=True) +class DummyProject: + name: str + location: str | None = None + operation_type: str = "open_pit" + description: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + + +@dataclass(slots=True) +class DummyScenario: + project: DummyProject | None + name: str + status: str = "draft" + start_date: date | None = None + end_date: date | None = None + discount_rate: Decimal | None = None + currency: str | None = None + primary_resource: str | None = None + description: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + + +def collect_csv_bytes(chunks: Iterable[bytes]) -> list[str]: + return [chunk.decode("utf-8") for chunk in chunks] + + +def test_csv_exporter_writes_header_and_rows() -> None: + exporter = CSVExporter( + [ + CSVExportColumn("Name", "name"), + CSVExportColumn("Location", "location"), + ] + ) + + project = DummyProject(name="Alpha", location="Nevada") + + chunks = collect_csv_bytes(exporter.iter_bytes([project])) + + assert chunks[0] == "Name,Location\n" + assert chunks[1] == "Alpha,Nevada\n" + + +def test_csv_exporter_handles_optional_values_and_default_formatter() -> None: + exporter = CSVExporter( + [ + CSVExportColumn("Name", "name"), + CSVExportColumn("Description", "description"), + ] + ) + + project = DummyProject(name="Bravo") + + chunks = collect_csv_bytes(exporter.iter_bytes([project])) + + assert chunks[-1] == "Bravo,\n" + + +def test_stream_projects_uses_default_columns() -> None: + projects = [ + DummyProject( + name="Alpha", + location="Nevada", + operation_type="open_pit", + description="Primary", + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 2, tzinfo=timezone.utc), + ) + ] + + chunks = collect_csv_bytes(stream_projects_to_csv(projects)) + + assert chunks[0].startswith("name,location,operation_type") + assert any("Alpha" in chunk for chunk in chunks) + + +def test_stream_scenarios_resolves_project_name_accessor() -> None: + project = DummyProject(name="Project X") + scenario = DummyScenario(project=project, name="Scenario A") + + chunks = collect_csv_bytes(stream_scenarios_to_csv([scenario])) + + assert "Project X" in chunks[-1] + assert "Scenario A" in chunks[-1] + + +def test_custom_formatter_applies() -> None: + def uppercase(value: Any) -> str: + return str(value).upper() if value is not None else "" + + exporter = CSVExporter([ + CSVExportColumn("Name", "name", formatter=uppercase), + ]) + + chunks = collect_csv_bytes( + exporter.iter_bytes([DummyProject(name="alpha")])) + + assert chunks[-1] == "ALPHA\n" + + +def test_default_formatter_handles_multiple_types() -> None: + assert default_formatter(None) == "" + assert default_formatter(True) == "true" + assert default_formatter(False) == "false" + assert default_formatter(Decimal("1.234")) == "1.23" + assert default_formatter( + datetime(2025, 1, 1, tzinfo=timezone.utc)).endswith("Z") + assert default_formatter(date(2025, 1, 1)) == "2025-01-01" + + +def test_format_helpers() -> None: + assert format_date_iso(date(2025, 5, 1)) == "2025-05-01" + assert format_date_iso("not-a-date") == "" + + ts = datetime(2025, 5, 1, 12, 0, tzinfo=timezone.utc) + assert format_datetime_utc(ts) == "2025-05-01T12:00:00Z" + + assert format_datetime_utc("nope") == "" + assert format_decimal(None) == "" + assert format_decimal(Decimal("12.345")) == "12.35" + assert format_decimal(10) == "10.00" + + +def test_default_project_columns_includes_required_fields() -> None: + columns = default_project_columns() + headers = [column.header for column in columns] + assert headers[:3] == ["name", "location", "operation_type"] + + +def test_default_scenario_columns_handles_missing_project() -> None: + scenario = DummyScenario(project=None, name="Orphan Scenario") + + exporter = CSVExporter(default_scenario_columns()) + chunks = collect_csv_bytes(exporter.iter_bytes([scenario])) + + assert chunks[-1].startswith(",Orphan Scenario") + + +def test_csv_exporter_requires_columns() -> None: + with pytest.raises(ValueError): + CSVExporter([])