feat: implement CSV export functionality with customizable columns and formatters

2025-11-10 15:36:14 +01:00
parent 1a7581cda0
commit 5f183faa63
2 changed files with 347 additions and 119 deletions
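A minimal usage sketch of the new streaming API (the import path follows the test module below; the DemoProject dataclass is a hypothetical stand-in for the real project model):

from dataclasses import dataclass
from datetime import datetime, timezone

from services.export_serializers import stream_projects_to_csv


@dataclass
class DemoProject:  # hypothetical stand-in for the real project model
    name: str
    location: str | None = None
    operation_type: str = "open_pit"
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None


projects = [
    DemoProject(
        name="Alpha",
        location="Nevada",
        created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
    )
]

# stream_projects_to_csv yields UTF-8 encoded chunks: the header row first,
# then one chunk per record, so the generator can back a streaming HTTP
# response or be written directly to disk.
with open("projects.csv", "wb") as fh:
    for chunk in stream_projects_to_csv(projects):
        fh.write(chunk)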

View File

@@ -1,178 +1,240 @@
from __future__ import annotations

import csv
from dataclasses import dataclass
from datetime import date, datetime, timezone
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from enum import Enum
from io import StringIO
from typing import Any, Callable, Iterable, Iterator, Sequence

CSVValueFormatter = Callable[[Any], str]
Accessor = Callable[[Any], Any]

__all__ = [
    "CSVExportColumn",
    "CSVExporter",
    "default_project_columns",
    "default_scenario_columns",
    "stream_projects_to_csv",
    "stream_scenarios_to_csv",
    "default_formatter",
    "format_datetime_utc",
    "format_date_iso",
    "format_decimal",
]


@dataclass(slots=True, frozen=True)
class CSVExportColumn:
    """Declarative description of a CSV export column."""

    header: str
    accessor: Accessor | str
    formatter: CSVValueFormatter | None = None
    required: bool = False

    def resolve_accessor(self) -> Accessor:
        if isinstance(self.accessor, str):
            return _coerce_accessor(self.accessor)
        return self.accessor

    def value_for(self, entity: Any) -> Any:
        accessor = self.resolve_accessor()
        try:
            return accessor(entity)
        except Exception:  # pragma: no cover - defensive safeguard
            return None


class CSVExporter:
    """Stream Python objects as UTF-8 encoded CSV rows."""

    def __init__(
        self,
        columns: Sequence[CSVExportColumn],
        *,
        include_header: bool = True,
        line_terminator: str = "\n",
    ) -> None:
        if not columns:
            raise ValueError("At least one column is required for CSV export.")
        self._columns: tuple[CSVExportColumn, ...] = tuple(columns)
        self._include_header = include_header
        self._line_terminator = line_terminator

    @property
    def columns(self) -> tuple[CSVExportColumn, ...]:
        return self._columns

    def headers(self) -> tuple[str, ...]:
        return tuple(column.header for column in self._columns)

    def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]:
        buffer = StringIO()
        writer = csv.writer(buffer, lineterminator=self._line_terminator)

        if self._include_header:
            writer.writerow(self.headers())
            yield _drain_buffer(buffer)

        for record in records:
            writer.writerow(self._format_row(record))
            yield _drain_buffer(buffer)

    def _format_row(self, record: Any) -> list[str]:
        formatted: list[str] = []
        for column in self._columns:
            accessor = column.resolve_accessor()
            raw_value = accessor(record)
            formatter = column.formatter or default_formatter
            formatted.append(formatter(raw_value))
        return formatted


def default_project_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    columns: list[CSVExportColumn] = [
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("location", "location"),
        CSVExportColumn("operation_type", "operation_type"),
    ]
    if include_description:
        columns.append(CSVExportColumn("description", "description"))
    if include_timestamps:
        columns.extend(
            (
                CSVExportColumn("created_at", "created_at",
                                formatter=format_datetime_utc),
                CSVExportColumn("updated_at", "updated_at",
                                formatter=format_datetime_utc),
            )
        )
    return tuple(columns)


def default_scenario_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    columns: list[CSVExportColumn] = [
        CSVExportColumn(
            "project_name",
            lambda scenario: getattr(
                getattr(scenario, "project", None), "name", None),
            required=True,
        ),
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("status", "status"),
        CSVExportColumn("start_date", "start_date", formatter=format_date_iso),
        CSVExportColumn("end_date", "end_date", formatter=format_date_iso),
        CSVExportColumn("discount_rate", "discount_rate",
                        formatter=format_decimal),
        CSVExportColumn("currency", "currency"),
        CSVExportColumn("primary_resource", "primary_resource"),
    ]
    if include_description:
        columns.append(CSVExportColumn("description", "description"))
    if include_timestamps:
        columns.extend(
            (
                CSVExportColumn("created_at", "created_at",
                                formatter=format_datetime_utc),
                CSVExportColumn("updated_at", "updated_at",
                                formatter=format_datetime_utc),
            )
        )
    return tuple(columns)


def stream_projects_to_csv(
    projects: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    resolved_columns = tuple(columns or default_project_columns())
    exporter = CSVExporter(resolved_columns)
    yield from exporter.iter_bytes(projects)


def stream_scenarios_to_csv(
    scenarios: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    resolved_columns = tuple(columns or default_scenario_columns())
    exporter = CSVExporter(resolved_columns)
    yield from exporter.iter_bytes(scenarios)


def default_formatter(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, Enum):
        return str(value.value)
    if isinstance(value, Decimal):
        return format_decimal(value)
    if isinstance(value, datetime):
        return format_datetime_utc(value)
    if isinstance(value, date):
        return format_date_iso(value)
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)


def format_datetime_utc(value: Any) -> str:
    if not isinstance(value, datetime):
        return ""
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    value = value.astimezone(timezone.utc)
    return value.isoformat().replace("+00:00", "Z")


def format_date_iso(value: Any) -> str:
    if not isinstance(value, date):
        return ""
    return value.isoformat()


def format_decimal(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, Decimal):
        try:
            quantised = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        except InvalidOperation:  # pragma: no cover - unexpected precision issues
            quantised = value
        return format(quantised, "f")
    if isinstance(value, (int, float)):
        return f"{value:.2f}"
    return default_formatter(value)


def _coerce_accessor(accessor: Accessor | str) -> Accessor:
    if callable(accessor):
        return accessor
    path = [segment for segment in accessor.split(".") if segment]

    def _resolve(entity: Any) -> Any:
        current: Any = entity
        for segment in path:
            if current is None:
                return None
            current = getattr(current, segment, None)
        return current

    return _resolve


def _drain_buffer(buffer: StringIO) -> bytes:
    data = buffer.getvalue()
    buffer.seek(0)
    buffer.truncate(0)
    return data.encode("utf-8")
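Columns are customizable in three ways shown above: a string accessor may be a dotted path (resolved by _coerce_accessor), a callable accessor may be passed directly, and each column may carry its own formatter. A brief sketch against CSVExporter; the Owner and Mine dataclasses and the percent formatter are hypothetical:

from dataclasses import dataclass
from decimal import Decimal

from services.export_serializers import CSVExportColumn, CSVExporter


@dataclass
class Owner:  # hypothetical related object
    name: str


@dataclass
class Mine:  # hypothetical record to export
    name: str
    owner: Owner | None
    grade: Decimal | None


def percent(value: Decimal | None) -> str:
    # custom per-column formatter: render a Decimal fraction as a percentage
    return f"{value * 100:.1f}%" if value is not None else ""


exporter = CSVExporter(
    [
        CSVExportColumn("name", "name"),
        CSVExportColumn("owner", "owner.name"),  # dotted-path accessor
        CSVExportColumn("grade_pct", "grade", formatter=percent),
    ]
)

rows = [Mine(name="North Pit", owner=Owner("Acme"), grade=Decimal("0.125"))]
print(b"".join(exporter.iter_bytes(rows)).decode("utf-8"))
# name,owner,grade_pct
# North Pit,Acme,12.5%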

View File

@@ -0,0 +1,166 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import Any, Iterable

import pytest

from services.export_serializers import (
    CSVExportColumn,
    CSVExporter,
    default_formatter,
    default_project_columns,
    default_scenario_columns,
    format_date_iso,
    format_datetime_utc,
    format_decimal,
    stream_projects_to_csv,
    stream_scenarios_to_csv,
)


@dataclass(slots=True)
class DummyProject:
    name: str
    location: str | None = None
    operation_type: str = "open_pit"
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None


@dataclass(slots=True)
class DummyScenario:
    project: DummyProject | None
    name: str
    status: str = "draft"
    start_date: date | None = None
    end_date: date | None = None
    discount_rate: Decimal | None = None
    currency: str | None = None
    primary_resource: str | None = None
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None


def collect_csv_bytes(chunks: Iterable[bytes]) -> list[str]:
    return [chunk.decode("utf-8") for chunk in chunks]


def test_csv_exporter_writes_header_and_rows() -> None:
    exporter = CSVExporter(
        [
            CSVExportColumn("Name", "name"),
            CSVExportColumn("Location", "location"),
        ]
    )
    project = DummyProject(name="Alpha", location="Nevada")

    chunks = collect_csv_bytes(exporter.iter_bytes([project]))

    assert chunks[0] == "Name,Location\n"
    assert chunks[1] == "Alpha,Nevada\n"


def test_csv_exporter_handles_optional_values_and_default_formatter() -> None:
    exporter = CSVExporter(
        [
            CSVExportColumn("Name", "name"),
            CSVExportColumn("Description", "description"),
        ]
    )
    project = DummyProject(name="Bravo")

    chunks = collect_csv_bytes(exporter.iter_bytes([project]))

    assert chunks[-1] == "Bravo,\n"


def test_stream_projects_uses_default_columns() -> None:
    projects = [
        DummyProject(
            name="Alpha",
            location="Nevada",
            operation_type="open_pit",
            description="Primary",
            created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
            updated_at=datetime(2025, 1, 2, tzinfo=timezone.utc),
        )
    ]

    chunks = collect_csv_bytes(stream_projects_to_csv(projects))

    assert chunks[0].startswith("name,location,operation_type")
    assert any("Alpha" in chunk for chunk in chunks)


def test_stream_scenarios_resolves_project_name_accessor() -> None:
    project = DummyProject(name="Project X")
    scenario = DummyScenario(project=project, name="Scenario A")

    chunks = collect_csv_bytes(stream_scenarios_to_csv([scenario]))

    assert "Project X" in chunks[-1]
    assert "Scenario A" in chunks[-1]


def test_custom_formatter_applies() -> None:
    def uppercase(value: Any) -> str:
        return str(value).upper() if value is not None else ""

    exporter = CSVExporter([
        CSVExportColumn("Name", "name", formatter=uppercase),
    ])

    chunks = collect_csv_bytes(
        exporter.iter_bytes([DummyProject(name="alpha")]))

    assert chunks[-1] == "ALPHA\n"


def test_default_formatter_handles_multiple_types() -> None:
    assert default_formatter(None) == ""
    assert default_formatter(True) == "true"
    assert default_formatter(False) == "false"
    assert default_formatter(Decimal("1.234")) == "1.23"
    assert default_formatter(
        datetime(2025, 1, 1, tzinfo=timezone.utc)).endswith("Z")
    assert default_formatter(date(2025, 1, 1)) == "2025-01-01"


def test_format_helpers() -> None:
    assert format_date_iso(date(2025, 5, 1)) == "2025-05-01"
    assert format_date_iso("not-a-date") == ""

    ts = datetime(2025, 5, 1, 12, 0, tzinfo=timezone.utc)
    assert format_datetime_utc(ts) == "2025-05-01T12:00:00Z"
    assert format_datetime_utc("nope") == ""

    assert format_decimal(None) == ""
    assert format_decimal(Decimal("12.345")) == "12.35"
    assert format_decimal(10) == "10.00"


def test_default_project_columns_includes_required_fields() -> None:
    columns = default_project_columns()
    headers = [column.header for column in columns]
    assert headers[:3] == ["name", "location", "operation_type"]


def test_default_scenario_columns_handles_missing_project() -> None:
    scenario = DummyScenario(project=None, name="Orphan Scenario")
    exporter = CSVExporter(default_scenario_columns())

    chunks = collect_csv_bytes(exporter.iter_bytes([scenario]))

    assert chunks[-1].startswith(",Orphan Scenario")


def test_csv_exporter_requires_columns() -> None:
    with pytest.raises(ValueError):
        CSVExporter([])