feat: implement CSV export functionality with customizable columns and formatters
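
Usage sketch for reviewers (illustrative only; `SimpleNamespace` stands in for the real Project model and the `owner` column is a hypothetical example of a dotted-path accessor, but the imports and the column/exporter API match this commit):

from types import SimpleNamespace

from services.export_serializers import (
    CSVExportColumn,
    CSVExporter,
    stream_projects_to_csv,
)

project = SimpleNamespace(
    name="Alpha", location="Nevada", operation_type="open_pit",
    description=None, created_at=None, updated_at=None,
    owner=SimpleNamespace(name="Ops"),
)

# Default project columns: each yielded chunk is a UTF-8 encoded CSV line,
# suitable for a streaming HTTP response.
for chunk in stream_projects_to_csv([project]):
    print(chunk.decode("utf-8"), end="")

# Custom columns: string accessors support dotted paths ("owner.name") or a
# callable, and formatters can be overridden per column.
exporter = CSVExporter([
    CSVExportColumn("name", "name", required=True),
    CSVExportColumn("owner", "owner.name"),
])
print(b"".join(exporter.iter_bytes([project])).decode("utf-8"), end="")
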
services/export_serializers.py
@@ -1,178 +1,240 @@
 from __future__ import annotations
 
+import csv
 from dataclasses import dataclass
-from datetime import timezone
-from functools import partial
-from typing import Callable, Iterable, Iterator, Sequence
+from datetime import date, datetime, timezone
+from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
+from enum import Enum
+from io import StringIO
+from typing import Any, Callable, Iterable, Iterator, Sequence
 
 from .export_query import ProjectExportFilters, ScenarioExportFilters
+CSVValueFormatter = Callable[[Any], str]
+Accessor = Callable[[Any], Any]
 
 __all__ = [
     "CSVExportColumn",
-    "CSVExportRowFactory",
-    "build_project_row_factory",
-    "build_scenario_row_factory",
+    "CSVExporter",
+    "default_project_columns",
+    "default_scenario_columns",
     "stream_projects_to_csv",
     "stream_scenarios_to_csv",
+    "default_formatter",
+    "format_datetime_utc",
+    "format_date_iso",
+    "format_decimal",
 ]
 
-CSVValueFormatter = Callable[[object | None], str]
-RowExtractor = Callable[[object], Sequence[object | None]]
 
 
-@dataclass(slots=True, frozen=True)
+@dataclass(slots=True)
 class CSVExportColumn:
-    """Describe how a field is rendered into CSV output."""
+    """Declarative description of a CSV export column."""
 
     header: str
-    accessor: str
+    accessor: Accessor | str
     formatter: CSVValueFormatter | None = None
+    required: bool = False
+
+    def resolve_accessor(self) -> Accessor:
+        if isinstance(self.accessor, str):
+            return _coerce_accessor(self.accessor)
+        return self.accessor
+
+    def value_for(self, entity: Any) -> Any:
+        accessor = self.resolve_accessor()
+        try:
+            return accessor(entity)
+        except Exception:  # pragma: no cover - defensive safeguard
+            return None
 
 
-@dataclass(slots=True, frozen=True)
-class CSVExportRowFactory:
-    """Builds row values for CSV serialization."""
+class CSVExporter:
+    """Stream Python objects as UTF-8 encoded CSV rows."""
 
-    columns: tuple[CSVExportColumn, ...]
+    def __init__(
+        self,
+        columns: Sequence[CSVExportColumn],
+        *,
+        include_header: bool = True,
+        line_terminator: str = "\n",
+    ) -> None:
+        if not columns:
+            raise ValueError("At least one column is required for CSV export.")
+        self._columns: tuple[CSVExportColumn, ...] = tuple(columns)
+        self._include_header = include_header
+        self._line_terminator = line_terminator
+
+    @property
+    def columns(self) -> tuple[CSVExportColumn, ...]:
+        return self._columns
 
     def headers(self) -> tuple[str, ...]:
-        return tuple(column.header for column in self.columns)
+        return tuple(column.header for column in self._columns)
 
-    def to_row(self, entity: object) -> tuple[str, ...]:
-        values: list[str] = []
-        for column in self.columns:
-            value = getattr(entity, column.accessor, None)
-            formatter = column.formatter or _default_formatter
-            values.append(formatter(value))
-        return tuple(values)
+    def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]:
+        buffer = StringIO()
+        writer = csv.writer(buffer, lineterminator=self._line_terminator)
+
+        if self._include_header:
+            writer.writerow(self.headers())
+            yield _drain_buffer(buffer)
+
+        for record in records:
+            writer.writerow(self._format_row(record))
+            yield _drain_buffer(buffer)
+
+    def _format_row(self, record: Any) -> list[str]:
+        formatted: list[str] = []
+        for column in self._columns:
+            accessor = column.resolve_accessor()
+            raw_value = accessor(record)
+            formatter = column.formatter or default_formatter
+            formatted.append(formatter(raw_value))
+        return formatted
 
 
-def build_project_row_factory(
-    *,
-    include_timestamps: bool = True,
-    include_description: bool = True,
-) -> CSVExportRowFactory:
-    columns: list[CSVExportColumn] = [
-        CSVExportColumn(header="name", accessor="name"),
-        CSVExportColumn(header="location", accessor="location"),
-        CSVExportColumn(header="operation_type", accessor="operation_type"),
-    ]
-    if include_description:
-        columns.append(
-            CSVExportColumn(header="description", accessor="description")
-        )
-    if include_timestamps:
-        columns.extend(
-            (
-                CSVExportColumn(
-                    header="created_at",
-                    accessor="created_at",
-                    formatter=_format_datetime,
-                ),
-                CSVExportColumn(
-                    header="updated_at",
-                    accessor="updated_at",
-                    formatter=_format_datetime,
-                ),
-            )
-        )
-    return CSVExportRowFactory(columns=tuple(columns))
 
 
-def build_scenario_row_factory(
+def default_project_columns(
     *,
     include_description: bool = True,
     include_timestamps: bool = True,
-) -> CSVExportRowFactory:
+) -> tuple[CSVExportColumn, ...]:
     columns: list[CSVExportColumn] = [
-        CSVExportColumn(header="project_name", accessor="project_name"),
-        CSVExportColumn(header="name", accessor="name"),
-        CSVExportColumn(header="status", accessor="status"),
-        CSVExportColumn(
-            header="start_date",
-            accessor="start_date",
-            formatter=_format_date,
-        ),
-        CSVExportColumn(
-            header="end_date",
-            accessor="end_date",
-            formatter=_format_date,
-        ),
-        CSVExportColumn(header="discount_rate",
-                        accessor="discount_rate", formatter=_format_number),
-        CSVExportColumn(header="currency", accessor="currency"),
-        CSVExportColumn(header="primary_resource",
-                        accessor="primary_resource"),
+        CSVExportColumn("name", "name", required=True),
+        CSVExportColumn("location", "location"),
+        CSVExportColumn("operation_type", "operation_type"),
     ]
     if include_description:
-        columns.append(
-            CSVExportColumn(header="description", accessor="description")
-        )
+        columns.append(CSVExportColumn("description", "description"))
     if include_timestamps:
         columns.extend(
             (
-                CSVExportColumn(
-                    header="created_at",
-                    accessor="created_at",
-                    formatter=_format_datetime,
-                ),
-                CSVExportColumn(
-                    header="updated_at",
-                    accessor="updated_at",
-                    formatter=_format_datetime,
-                ),
+                CSVExportColumn("created_at", "created_at",
+                                formatter=format_datetime_utc),
+                CSVExportColumn("updated_at", "updated_at",
+                                formatter=format_datetime_utc),
             )
         )
-    return CSVExportRowFactory(columns=tuple(columns))
+    return tuple(columns)
 
 
+def default_scenario_columns(
+    *,
+    include_description: bool = True,
+    include_timestamps: bool = True,
+) -> tuple[CSVExportColumn, ...]:
+    columns: list[CSVExportColumn] = [
+        CSVExportColumn(
+            "project_name",
+            lambda scenario: getattr(
+                getattr(scenario, "project", None), "name", None),
+            required=True,
+        ),
+        CSVExportColumn("name", "name", required=True),
+        CSVExportColumn("status", "status"),
+        CSVExportColumn("start_date", "start_date", formatter=format_date_iso),
+        CSVExportColumn("end_date", "end_date", formatter=format_date_iso),
+        CSVExportColumn("discount_rate", "discount_rate",
+                        formatter=format_decimal),
+        CSVExportColumn("currency", "currency"),
+        CSVExportColumn("primary_resource", "primary_resource"),
+    ]
+    if include_description:
+        columns.append(CSVExportColumn("description", "description"))
+    if include_timestamps:
+        columns.extend(
+            (
+                CSVExportColumn("created_at", "created_at",
+                                formatter=format_datetime_utc),
+                CSVExportColumn("updated_at", "updated_at",
+                                formatter=format_datetime_utc),
+            )
+        )
+    return tuple(columns)
 
 
 def stream_projects_to_csv(
-    projects: Sequence[object],
+    projects: Iterable[Any],
     *,
-    row_factory: CSVExportRowFactory | None = None,
-) -> Iterator[str]:
-    factory = row_factory or build_project_row_factory()
-    yield _join_row(factory.headers())
-    for project in projects:
-        yield _join_row(factory.to_row(project))
+    columns: Sequence[CSVExportColumn] | None = None,
+) -> Iterator[bytes]:
+    resolved_columns = tuple(columns or default_project_columns())
+    exporter = CSVExporter(resolved_columns)
+    yield from exporter.iter_bytes(projects)
 
 
 def stream_scenarios_to_csv(
-    scenarios: Sequence[object],
+    scenarios: Iterable[Any],
     *,
-    row_factory: CSVExportRowFactory | None = None,
-) -> Iterator[str]:
-    factory = row_factory or build_scenario_row_factory()
-    yield _join_row(factory.headers())
-    for scenario in scenarios:
-        yield _join_row(factory.to_row(scenario))
+    columns: Sequence[CSVExportColumn] | None = None,
+) -> Iterator[bytes]:
+    resolved_columns = tuple(columns or default_scenario_columns())
+    exporter = CSVExporter(resolved_columns)
+    yield from exporter.iter_bytes(scenarios)
 
 
-def _join_row(values: Iterable[str]) -> str:
-    return ",".join(values)
 
 
-def _default_formatter(value: object | None) -> str:
+def default_formatter(value: Any) -> str:
     if value is None:
         return ""
+    if isinstance(value, Enum):
+        return str(value.value)
+    if isinstance(value, Decimal):
+        return format_decimal(value)
+    if isinstance(value, datetime):
+        return format_datetime_utc(value)
+    if isinstance(value, date):
+        return format_date_iso(value)
     if isinstance(value, bool):
         return "true" if value else "false"
     return str(value)
 
 
-def _format_datetime(value: object | None) -> str:
-    from datetime import datetime
-
+def format_datetime_utc(value: Any) -> str:
     if not isinstance(value, datetime):
         return ""
-    return value.astimezone(timezone.utc).replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z")
+    if value.tzinfo is None:
+        value = value.replace(tzinfo=timezone.utc)
+    value = value.astimezone(timezone.utc)
+    return value.isoformat().replace("+00:00", "Z")
 
 
-def _format_date(value: object | None) -> str:
-    from datetime import date
-
+def format_date_iso(value: Any) -> str:
     if not isinstance(value, date):
         return ""
     return value.isoformat()
 
 
-def _format_number(value: object | None) -> str:
+def format_decimal(value: Any) -> str:
     if value is None:
         return ""
+    if isinstance(value, Decimal):
+        try:
+            quantised = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
+        except InvalidOperation:  # pragma: no cover - unexpected precision issues
+            quantised = value
+        return format(quantised, "f")
     if isinstance(value, (int, float)):
         return f"{value:.2f}"
+    return default_formatter(value)
 
 
+def _coerce_accessor(accessor: Accessor | str) -> Accessor:
+    if callable(accessor):
+        return accessor
+
+    path = [segment for segment in accessor.split(".") if segment]
+
+    def _resolve(entity: Any) -> Any:
+        current: Any = entity
+        for segment in path:
+            if current is None:
+                return None
+            current = getattr(current, segment, None)
+        return current
+
+    return _resolve
 
 
+def _drain_buffer(buffer: StringIO) -> bytes:
+    data = buffer.getvalue()
+    buffer.seek(0)
+    buffer.truncate(0)
+    return data.encode("utf-8")
tests/test_export_serializers.py (new file, 166 lines)
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from datetime import date, datetime, timezone
+from decimal import Decimal
+from typing import Any, Iterable
+
+import pytest
+
+from services.export_serializers import (
+    CSVExportColumn,
+    CSVExporter,
+    default_formatter,
+    default_project_columns,
+    default_scenario_columns,
+    format_date_iso,
+    format_datetime_utc,
+    format_decimal,
+    stream_projects_to_csv,
+    stream_scenarios_to_csv,
+)
+
+
+@dataclass(slots=True)
+class DummyProject:
+    name: str
+    location: str | None = None
+    operation_type: str = "open_pit"
+    description: str | None = None
+    created_at: datetime | None = None
+    updated_at: datetime | None = None
+
+
+@dataclass(slots=True)
+class DummyScenario:
+    project: DummyProject | None
+    name: str
+    status: str = "draft"
+    start_date: date | None = None
+    end_date: date | None = None
+    discount_rate: Decimal | None = None
+    currency: str | None = None
+    primary_resource: str | None = None
+    description: str | None = None
+    created_at: datetime | None = None
+    updated_at: datetime | None = None
+
+
+def collect_csv_bytes(chunks: Iterable[bytes]) -> list[str]:
+    return [chunk.decode("utf-8") for chunk in chunks]
+
+
+def test_csv_exporter_writes_header_and_rows() -> None:
+    exporter = CSVExporter(
+        [
+            CSVExportColumn("Name", "name"),
+            CSVExportColumn("Location", "location"),
+        ]
+    )
+
+    project = DummyProject(name="Alpha", location="Nevada")
+
+    chunks = collect_csv_bytes(exporter.iter_bytes([project]))
+
+    assert chunks[0] == "Name,Location\n"
+    assert chunks[1] == "Alpha,Nevada\n"
+
+
+def test_csv_exporter_handles_optional_values_and_default_formatter() -> None:
+    exporter = CSVExporter(
+        [
+            CSVExportColumn("Name", "name"),
+            CSVExportColumn("Description", "description"),
+        ]
+    )
+
+    project = DummyProject(name="Bravo")
+
+    chunks = collect_csv_bytes(exporter.iter_bytes([project]))
+
+    assert chunks[-1] == "Bravo,\n"
+
+
+def test_stream_projects_uses_default_columns() -> None:
+    projects = [
+        DummyProject(
+            name="Alpha",
+            location="Nevada",
+            operation_type="open_pit",
+            description="Primary",
+            created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
+            updated_at=datetime(2025, 1, 2, tzinfo=timezone.utc),
+        )
+    ]
+
+    chunks = collect_csv_bytes(stream_projects_to_csv(projects))
+
+    assert chunks[0].startswith("name,location,operation_type")
+    assert any("Alpha" in chunk for chunk in chunks)
+
+
+def test_stream_scenarios_resolves_project_name_accessor() -> None:
+    project = DummyProject(name="Project X")
+    scenario = DummyScenario(project=project, name="Scenario A")
+
+    chunks = collect_csv_bytes(stream_scenarios_to_csv([scenario]))
+
+    assert "Project X" in chunks[-1]
+    assert "Scenario A" in chunks[-1]
+
+
+def test_custom_formatter_applies() -> None:
+    def uppercase(value: Any) -> str:
+        return str(value).upper() if value is not None else ""
+
+    exporter = CSVExporter([
+        CSVExportColumn("Name", "name", formatter=uppercase),
+    ])
+
+    chunks = collect_csv_bytes(
+        exporter.iter_bytes([DummyProject(name="alpha")]))
+
+    assert chunks[-1] == "ALPHA\n"
+
+
+def test_default_formatter_handles_multiple_types() -> None:
+    assert default_formatter(None) == ""
+    assert default_formatter(True) == "true"
+    assert default_formatter(False) == "false"
+    assert default_formatter(Decimal("1.234")) == "1.23"
+    assert default_formatter(
+        datetime(2025, 1, 1, tzinfo=timezone.utc)).endswith("Z")
+    assert default_formatter(date(2025, 1, 1)) == "2025-01-01"
+
+
+def test_format_helpers() -> None:
+    assert format_date_iso(date(2025, 5, 1)) == "2025-05-01"
+    assert format_date_iso("not-a-date") == ""
+
+    ts = datetime(2025, 5, 1, 12, 0, tzinfo=timezone.utc)
+    assert format_datetime_utc(ts) == "2025-05-01T12:00:00Z"
+
+    assert format_datetime_utc("nope") == ""
+    assert format_decimal(None) == ""
+    assert format_decimal(Decimal("12.345")) == "12.35"
+    assert format_decimal(10) == "10.00"
+
+
+def test_default_project_columns_includes_required_fields() -> None:
+    columns = default_project_columns()
+    headers = [column.header for column in columns]
+    assert headers[:3] == ["name", "location", "operation_type"]
+
+
+def test_default_scenario_columns_handles_missing_project() -> None:
+    scenario = DummyScenario(project=None, name="Orphan Scenario")
+
+    exporter = CSVExporter(default_scenario_columns())
+    chunks = collect_csv_bytes(exporter.iter_bytes([scenario]))
+
+    assert chunks[-1].startswith(",Orphan Scenario")
+
+
+def test_csv_exporter_requires_columns() -> None:
+    with pytest.raises(ValueError):
+        CSVExporter([])