feat: implement CSV export functionality with customizable columns and formatters

2025-11-10 15:36:14 +01:00
parent 1a7581cda0
commit 5f183faa63
2 changed files with 347 additions and 119 deletions
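The new module replaces the old row-factory helpers with column-driven exporters. As a quick orientation (an illustrative sketch, not part of the commit: the services.export_serializers import path is taken from the test module below, and DummyProject is a stand-in for a real model), the API introduced here is used like this:

# Sketch only: DummyProject stands in for a real ORM model; the exporter
# accepts any object whose attributes match the column accessors.
from dataclasses import dataclass
from datetime import datetime, timezone

from services.export_serializers import (
    CSVExportColumn,
    CSVExporter,
    default_project_columns,
    format_datetime_utc,
)


@dataclass
class DummyProject:
    name: str
    location: str | None = None
    operation_type: str = "open_pit"
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None


projects = [
    DummyProject(
        name="Alpha",
        location="Nevada",
        created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
    )
]

# Default column set, streamed as UTF-8 encoded CSV chunks.
exporter = CSVExporter(default_project_columns())
print(b"".join(exporter.iter_bytes(projects)).decode("utf-8"))

# Custom column set: headers, accessors, and per-column formatters are free-form.
columns = (
    CSVExportColumn("Project", "name"),
    CSVExportColumn("Created (UTC)", "created_at", formatter=format_datetime_utc),
)
print(b"".join(CSVExporter(columns).iter_bytes(projects)).decode("utf-8"))

Each element yielded by iter_bytes is one UTF-8 encoded CSV row, so callers can write chunks to a file or response as they arrive.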

View File

@@ -1,178 +1,240 @@
from __future__ import annotations

import csv
from dataclasses import dataclass
from datetime import date, datetime, timezone
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from enum import Enum
from io import StringIO
from typing import Any, Callable, Iterable, Iterator, Sequence

from .export_query import ProjectExportFilters, ScenarioExportFilters

CSVValueFormatter = Callable[[Any], str]
Accessor = Callable[[Any], Any]

__all__ = [
    "CSVExportColumn",
    "CSVExporter",
    "default_project_columns",
    "default_scenario_columns",
    "stream_projects_to_csv",
    "stream_scenarios_to_csv",
    "default_formatter",
    "format_datetime_utc",
    "format_date_iso",
    "format_decimal",
]


@dataclass(slots=True)
class CSVExportColumn:
    """Declarative description of a CSV export column."""

    header: str
    accessor: Accessor | str
    formatter: CSVValueFormatter | None = None
    required: bool = False

    def resolve_accessor(self) -> Accessor:
        if isinstance(self.accessor, str):
            return _coerce_accessor(self.accessor)
        return self.accessor

    def value_for(self, entity: Any) -> Any:
        accessor = self.resolve_accessor()
        try:
            return accessor(entity)
        except Exception:  # pragma: no cover - defensive safeguard
            return None


class CSVExporter:
    """Stream Python objects as UTF-8 encoded CSV rows."""

    def __init__(
        self,
        columns: Sequence[CSVExportColumn],
        *,
        include_header: bool = True,
        line_terminator: str = "\n",
    ) -> None:
        if not columns:
            raise ValueError("At least one column is required for CSV export.")
        self._columns: tuple[CSVExportColumn, ...] = tuple(columns)
        self._include_header = include_header
        self._line_terminator = line_terminator

    @property
    def columns(self) -> tuple[CSVExportColumn, ...]:
        return self._columns

    def headers(self) -> tuple[str, ...]:
        return tuple(column.header for column in self._columns)

    def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]:
        buffer = StringIO()
        writer = csv.writer(buffer, lineterminator=self._line_terminator)
        if self._include_header:
            writer.writerow(self.headers())
            yield _drain_buffer(buffer)
        for record in records:
            writer.writerow(self._format_row(record))
            yield _drain_buffer(buffer)

    def _format_row(self, record: Any) -> list[str]:
        formatted: list[str] = []
        for column in self._columns:
            accessor = column.resolve_accessor()
            raw_value = accessor(record)
            formatter = column.formatter or default_formatter
            formatted.append(formatter(raw_value))
        return formatted


def default_project_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    columns: list[CSVExportColumn] = [
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("location", "location"),
        CSVExportColumn("operation_type", "operation_type"),
    ]
    if include_description:
        columns.append(CSVExportColumn("description", "description"))
    if include_timestamps:
        columns.extend(
            (
                CSVExportColumn("created_at", "created_at",
                                formatter=format_datetime_utc),
                CSVExportColumn("updated_at", "updated_at",
                                formatter=format_datetime_utc),
            )
        )
    return tuple(columns)


def default_scenario_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    columns: list[CSVExportColumn] = [
        CSVExportColumn(
            "project_name",
            lambda scenario: getattr(
                getattr(scenario, "project", None), "name", None),
            required=True,
        ),
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("status", "status"),
        CSVExportColumn("start_date", "start_date", formatter=format_date_iso),
        CSVExportColumn("end_date", "end_date", formatter=format_date_iso),
        CSVExportColumn("discount_rate", "discount_rate",
                        formatter=format_decimal),
        CSVExportColumn("currency", "currency"),
        CSVExportColumn("primary_resource", "primary_resource"),
    ]
    if include_description:
        columns.append(CSVExportColumn("description", "description"))
    if include_timestamps:
        columns.extend(
            (
                CSVExportColumn("created_at", "created_at",
                                formatter=format_datetime_utc),
                CSVExportColumn("updated_at", "updated_at",
                                formatter=format_datetime_utc),
            )
        )
    return tuple(columns)


def stream_projects_to_csv(
    projects: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    resolved_columns = tuple(columns or default_project_columns())
    exporter = CSVExporter(resolved_columns)
    yield from exporter.iter_bytes(projects)


def stream_scenarios_to_csv(
    scenarios: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    resolved_columns = tuple(columns or default_scenario_columns())
    exporter = CSVExporter(resolved_columns)
    yield from exporter.iter_bytes(scenarios)


def default_formatter(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, Enum):
        return str(value.value)
    if isinstance(value, Decimal):
        return format_decimal(value)
    if isinstance(value, datetime):
        return format_datetime_utc(value)
    if isinstance(value, date):
        return format_date_iso(value)
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)


def format_datetime_utc(value: Any) -> str:
    if not isinstance(value, datetime):
        return ""
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    value = value.astimezone(timezone.utc)
    return value.isoformat().replace("+00:00", "Z")


def format_date_iso(value: Any) -> str:
    if not isinstance(value, date):
        return ""
    return value.isoformat()


def format_decimal(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, Decimal):
        try:
            quantised = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        except InvalidOperation:  # pragma: no cover - unexpected precision issues
            quantised = value
        return format(quantised, "f")
    if isinstance(value, (int, float)):
        return f"{value:.2f}"
    return default_formatter(value)


def _coerce_accessor(accessor: Accessor | str) -> Accessor:
    if callable(accessor):
        return accessor
    path = [segment for segment in accessor.split(".") if segment]

    def _resolve(entity: Any) -> Any:
        current: Any = entity
        for segment in path:
            if current is None:
                return None
            current = getattr(current, segment, None)
        return current

    return _resolve


def _drain_buffer(buffer: StringIO) -> bytes:
    data = buffer.getvalue()
    buffer.seek(0)
    buffer.truncate(0)
    return data.encode("utf-8")
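
Because iter_bytes drains its buffer after every row, the stream_* helpers can feed a chunked HTTP response without materialising the export in memory. A hedged sketch of that wiring, assuming a FastAPI app and a hypothetical load_projects() query helper (neither appears in this commit):

# Illustrative only: FastAPI and load_projects() are assumptions, not part of
# this commit. StreamingResponse consumes the byte generator lazily, so large
# exports are never fully held in memory.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from services.export_serializers import stream_projects_to_csv

app = FastAPI()


def load_projects():  # hypothetical data-access helper
    return []


@app.get("/projects/export.csv")
def export_projects() -> StreamingResponse:
    return StreamingResponse(
        stream_projects_to_csv(load_projects()),
        media_type="text/csv",
        headers={"Content-Disposition": 'attachment; filename="projects.csv"'},
    )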

View File

@@ -0,0 +1,166 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import Any, Iterable

import pytest

from services.export_serializers import (
    CSVExportColumn,
    CSVExporter,
    default_formatter,
    default_project_columns,
    default_scenario_columns,
    format_date_iso,
    format_datetime_utc,
    format_decimal,
    stream_projects_to_csv,
    stream_scenarios_to_csv,
)


@dataclass(slots=True)
class DummyProject:
    name: str
    location: str | None = None
    operation_type: str = "open_pit"
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None


@dataclass(slots=True)
class DummyScenario:
    project: DummyProject | None
    name: str
    status: str = "draft"
    start_date: date | None = None
    end_date: date | None = None
    discount_rate: Decimal | None = None
    currency: str | None = None
    primary_resource: str | None = None
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None


def collect_csv_bytes(chunks: Iterable[bytes]) -> list[str]:
    return [chunk.decode("utf-8") for chunk in chunks]


def test_csv_exporter_writes_header_and_rows() -> None:
    exporter = CSVExporter(
        [
            CSVExportColumn("Name", "name"),
            CSVExportColumn("Location", "location"),
        ]
    )
    project = DummyProject(name="Alpha", location="Nevada")
    chunks = collect_csv_bytes(exporter.iter_bytes([project]))
    assert chunks[0] == "Name,Location\n"
    assert chunks[1] == "Alpha,Nevada\n"


def test_csv_exporter_handles_optional_values_and_default_formatter() -> None:
    exporter = CSVExporter(
        [
            CSVExportColumn("Name", "name"),
            CSVExportColumn("Description", "description"),
        ]
    )
    project = DummyProject(name="Bravo")
    chunks = collect_csv_bytes(exporter.iter_bytes([project]))
    assert chunks[-1] == "Bravo,\n"


def test_stream_projects_uses_default_columns() -> None:
    projects = [
        DummyProject(
            name="Alpha",
            location="Nevada",
            operation_type="open_pit",
            description="Primary",
            created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
            updated_at=datetime(2025, 1, 2, tzinfo=timezone.utc),
        )
    ]
    chunks = collect_csv_bytes(stream_projects_to_csv(projects))
    assert chunks[0].startswith("name,location,operation_type")
    assert any("Alpha" in chunk for chunk in chunks)


def test_stream_scenarios_resolves_project_name_accessor() -> None:
    project = DummyProject(name="Project X")
    scenario = DummyScenario(project=project, name="Scenario A")
    chunks = collect_csv_bytes(stream_scenarios_to_csv([scenario]))
    assert "Project X" in chunks[-1]
    assert "Scenario A" in chunks[-1]


def test_custom_formatter_applies() -> None:
    def uppercase(value: Any) -> str:
        return str(value).upper() if value is not None else ""

    exporter = CSVExporter([
        CSVExportColumn("Name", "name", formatter=uppercase),
    ])
    chunks = collect_csv_bytes(
        exporter.iter_bytes([DummyProject(name="alpha")]))
    assert chunks[-1] == "ALPHA\n"


def test_default_formatter_handles_multiple_types() -> None:
    assert default_formatter(None) == ""
    assert default_formatter(True) == "true"
    assert default_formatter(False) == "false"
    assert default_formatter(Decimal("1.234")) == "1.23"
    assert default_formatter(
        datetime(2025, 1, 1, tzinfo=timezone.utc)).endswith("Z")
    assert default_formatter(date(2025, 1, 1)) == "2025-01-01"


def test_format_helpers() -> None:
    assert format_date_iso(date(2025, 5, 1)) == "2025-05-01"
    assert format_date_iso("not-a-date") == ""
    ts = datetime(2025, 5, 1, 12, 0, tzinfo=timezone.utc)
    assert format_datetime_utc(ts) == "2025-05-01T12:00:00Z"
    assert format_datetime_utc("nope") == ""
    assert format_decimal(None) == ""
    assert format_decimal(Decimal("12.345")) == "12.35"
    assert format_decimal(10) == "10.00"


def test_default_project_columns_includes_required_fields() -> None:
    columns = default_project_columns()
    headers = [column.header for column in columns]
    assert headers[:3] == ["name", "location", "operation_type"]


def test_default_scenario_columns_handles_missing_project() -> None:
    scenario = DummyScenario(project=None, name="Orphan Scenario")
    exporter = CSVExporter(default_scenario_columns())
    chunks = collect_csv_bytes(exporter.iter_bytes([scenario]))
    assert chunks[-1].startswith(",Orphan Scenario")


def test_csv_exporter_requires_columns() -> None:
    with pytest.raises(ValueError):
        CSVExporter([])
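
One case the suite leaves untested is quoting: field values containing commas or embedded newlines must survive a round trip through csv.reader. A possible follow-up test, sketched here as a suggestion rather than part of the commit; it reuses CSVExporter, CSVExportColumn, and DummyProject from the module above:

# Possible additional test, not part of this commit: values containing commas
# and newlines must survive a round trip through Python's csv reader unchanged.
import csv
import io


def test_values_with_commas_round_trip() -> None:
    exporter = CSVExporter(
        [
            CSVExportColumn("Name", "name"),
            CSVExportColumn("Description", "description"),
        ]
    )
    project = DummyProject(name="Alpha, Inc.", description="line one\nline two")
    payload = b"".join(exporter.iter_bytes([project])).decode("utf-8")
    rows = list(csv.reader(io.StringIO(payload)))
    assert rows[1] == ["Alpha, Inc.", "line one\nline two"]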