feat: implement CSV export functionality with customizable columns and formatters

This commit is contained in:
2025-11-10 15:36:14 +01:00
parent 1a7581cda0
commit 5f183faa63
2 changed files with 347 additions and 119 deletions

View File

@@ -1,178 +1,240 @@
from __future__ import annotations
import csv
from dataclasses import dataclass
from datetime import timezone
from functools import partial
from typing import Callable, Iterable, Iterator, Sequence
from datetime import date, datetime, timezone
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from enum import Enum
from io import StringIO
from typing import Any, Callable, Iterable, Iterator, Sequence
from .export_query import ProjectExportFilters, ScenarioExportFilters
CSVValueFormatter = Callable[[Any], str]
Accessor = Callable[[Any], Any]
# Names exported by a star-import of this module.
# NOTE(review): "CSVExportRowFactory", "build_project_row_factory" and
# "build_scenario_row_factory" are left out because they appear to be the
# entries this commit removed; a name listed here that the module does not
# define makes a star-import raise AttributeError. Confirm before merging.
__all__ = [
    "CSVExportColumn",
    "CSVExporter",
    "default_project_columns",
    "default_scenario_columns",
    "stream_projects_to_csv",
    "stream_scenarios_to_csv",
    "default_formatter",
    "format_datetime_utc",
    "format_date_iso",
    "format_decimal",
]
CSVValueFormatter = Callable[[object | None], str]
RowExtractor = Callable[[object], Sequence[object | None]]
@dataclass(slots=True, frozen=True)
@dataclass(slots=True)
class CSVExportColumn:
"""Describe how a field is rendered into CSV output."""
"""Declarative description of a CSV export column."""
header: str
accessor: str
accessor: Accessor | str
formatter: CSVValueFormatter | None = None
required: bool = False
def resolve_accessor(self) -> Accessor:
if isinstance(self.accessor, str):
return _coerce_accessor(self.accessor)
return self.accessor
def value_for(self, entity: Any) -> Any:
accessor = self.resolve_accessor()
try:
return accessor(entity)
except Exception: # pragma: no cover - defensive safeguard
return None
class CSVExporter:
    """Stream Python objects as UTF-8 encoded CSV rows."""

    def __init__(
        self,
        columns: Sequence[CSVExportColumn],
        *,
        include_header: bool = True,
        line_terminator: str = "\n",
    ) -> None:
        """Snapshot the column configuration.

        Args:
            columns: Column descriptions, in output order.
            include_header: Emit a header row before the data rows.
            line_terminator: Passed to ``csv.writer`` as ``lineterminator``.

        Raises:
            ValueError: If *columns* is empty.
        """
        if not columns:
            raise ValueError("At least one column is required for CSV export.")
        self._columns: tuple[CSVExportColumn, ...] = tuple(columns)
        self._include_header = include_header
        self._line_terminator = line_terminator
        # Resolve each accessor once here rather than once per cell; the
        # column tuple is already a snapshot taken at construction time.
        self._accessors: tuple[Accessor, ...] = tuple(
            column.resolve_accessor() for column in self._columns
        )

    @property
    def columns(self) -> tuple[CSVExportColumn, ...]:
        """The column descriptions captured at construction time."""
        return self._columns

    def headers(self) -> tuple[str, ...]:
        """Return the header text of every column, in output order."""
        return tuple(column.header for column in self._columns)

    def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]:
        """Yield the CSV output one encoded row at a time.

        The header row (when enabled) is yielded first, then one chunk per
        record. A single ``StringIO`` is reused and drained after each row.
        """
        buffer = StringIO()
        writer = csv.writer(buffer, lineterminator=self._line_terminator)
        if self._include_header:
            writer.writerow(self.headers())
            yield _drain_buffer(buffer)
        for record in records:
            writer.writerow(self._format_row(record))
            yield _drain_buffer(buffer)

    def _format_row(self, record: Any) -> list[str]:
        """Return the formatted cell text for *record*, one item per column."""
        return [
            (column.formatter or default_formatter)(accessor(record))
            for column, accessor in zip(self._columns, self._accessors)
        ]
def build_project_row_factory(
    *,
    include_timestamps: bool = True,
    include_description: bool = True,
) -> CSVExportRowFactory:
    """Build a row factory describing the project CSV columns.

    Always includes ``name``, ``location`` and ``operation_type``.

    Args:
        include_timestamps: Append ``created_at`` and ``updated_at``.
        include_description: Append ``description``.

    Returns:
        A ``CSVExportRowFactory`` holding the selected columns as a tuple.
    """
    # NOTE(review): ``CSVExportRowFactory`` and ``_format_datetime`` are not
    # defined inside this function; confirm the module still provides them.
    columns: list[CSVExportColumn] = [
        CSVExportColumn(header="name", accessor="name"),
        CSVExportColumn(header="location", accessor="location"),
        CSVExportColumn(header="operation_type", accessor="operation_type"),
    ]
    if include_description:
        columns.append(
            CSVExportColumn(header="description", accessor="description")
        )
    if include_timestamps:
        columns.extend(
            (
                CSVExportColumn(
                    header="created_at",
                    accessor="created_at",
                    formatter=_format_datetime,
                ),
                CSVExportColumn(
                    header="updated_at",
                    accessor="updated_at",
                    formatter=_format_datetime,
                ),
            )
        )
    return CSVExportRowFactory(columns=tuple(columns))
def default_project_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    """Return the default column set for a project CSV export.

    Always includes ``name`` (marked required), ``location`` and
    ``operation_type``.

    Args:
        include_description: Append a ``description`` column.
        include_timestamps: Append ``created_at`` and ``updated_at`` columns
            formatted with ``format_datetime_utc``.
    """
    columns: list[CSVExportColumn] = [
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("location", "location"),
        CSVExportColumn("operation_type", "operation_type"),
    ]
    if include_description:
        columns.append(CSVExportColumn("description", "description"))
    if include_timestamps:
        columns.extend(
            (
                CSVExportColumn(
                    "created_at", "created_at", formatter=format_datetime_utc
                ),
                CSVExportColumn(
                    "updated_at", "updated_at", formatter=format_datetime_utc
                ),
            )
        )
    return tuple(columns)
def default_scenario_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    """Return the default column set for a scenario CSV export.

    Args:
        include_description: Append a ``description`` column.
        include_timestamps: Append ``created_at`` and ``updated_at`` columns
            formatted with ``format_datetime_utc``.
    """
    columns: list[CSVExportColumn] = [
        # Dotted path: reads ``scenario.project.name`` and yields None when
        # either ``project`` or ``name`` is missing or None.
        CSVExportColumn("project_name", "project.name", required=True),
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("status", "status"),
        CSVExportColumn("start_date", "start_date", formatter=format_date_iso),
        CSVExportColumn("end_date", "end_date", formatter=format_date_iso),
        CSVExportColumn(
            "discount_rate", "discount_rate", formatter=format_decimal
        ),
        CSVExportColumn("currency", "currency"),
        CSVExportColumn("primary_resource", "primary_resource"),
    ]
    if include_description:
        columns.append(CSVExportColumn("description", "description"))
    if include_timestamps:
        columns.extend(
            (
                CSVExportColumn(
                    "created_at", "created_at", formatter=format_datetime_utc
                ),
                CSVExportColumn(
                    "updated_at", "updated_at", formatter=format_datetime_utc
                ),
            )
        )
    return tuple(columns)
def stream_projects_to_csv(
    projects: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    """Yield UTF-8 encoded CSV chunks for *projects*.

    Args:
        projects: Records to export, consumed lazily.
        columns: Column descriptions to use. ``None`` or an empty sequence
            selects ``default_project_columns()``.

    Yields:
        The header row first, then one encoded row per project.
    """
    resolved_columns = tuple(columns or default_project_columns())
    exporter = CSVExporter(resolved_columns)
    yield from exporter.iter_bytes(projects)
def stream_scenarios_to_csv(
    scenarios: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    """Yield UTF-8 encoded CSV chunks for *scenarios*.

    Args:
        scenarios: Records to export, consumed lazily.
        columns: Column descriptions to use. ``None`` or an empty sequence
            selects ``default_scenario_columns()``.

    Yields:
        The header row first, then one encoded row per scenario.
    """
    resolved_columns = tuple(columns or default_scenario_columns())
    exporter = CSVExporter(resolved_columns)
    yield from exporter.iter_bytes(scenarios)
def _join_row(values: Iterable[str]) -> str:
    """Join *values* with commas into a single line of text.

    No quoting or escaping is applied, and no line terminator is appended.
    """
    return ",".join(values)
def default_formatter(value: Any) -> str:
    """Convert *value* to its CSV text using type-based rules.

    ``None`` becomes an empty string, enum members use their ``value``,
    decimals, datetimes and dates go to their dedicated formatters, booleans
    become ``"true"``/``"false"`` and anything else is passed to ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, Enum):
        return str(value.value)
    if isinstance(value, Decimal):
        return format_decimal(value)
    # ``datetime`` is a subclass of ``date``, so it must be tested first.
    if isinstance(value, datetime):
        return format_datetime_utc(value)
    if isinstance(value, date):
        return format_date_iso(value)
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)
def format_datetime_utc(value: Any) -> str:
    """Render a datetime as an ISO-8601 string in UTC with a ``Z`` suffix.

    Returns an empty string for anything that is not a ``datetime``.
    """
    if not isinstance(value, datetime):
        return ""
    if value.tzinfo is None:
        # A naive value is labelled as UTC rather than converted from the
        # process-local timezone, so output does not depend on the host.
        value = value.replace(tzinfo=timezone.utc)
    value = value.astimezone(timezone.utc)
    return value.isoformat().replace("+00:00", "Z")
def format_date_iso(value: Any) -> str:
    """Render a date as ``YYYY-MM-DD``.

    Returns an empty string for anything that is not a ``date``.
    """
    if not isinstance(value, date):
        return ""
    if isinstance(value, datetime):
        # ``datetime`` subclasses ``date``; without this its ``isoformat``
        # would emit the time portion into a date-only column.
        value = value.date()
    return value.isoformat()
def format_decimal(value: Any) -> str:
    """Render a numeric value with exactly two decimal places.

    ``None`` becomes an empty string. ``Decimal`` values are rounded half-up
    to two places; ``int`` and ``float`` use ``:.2f`` formatting. Any other
    type is handed to ``default_formatter``.
    """
    if value is None:
        return ""
    if isinstance(value, Decimal):
        try:
            quantised = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        except InvalidOperation:  # pragma: no cover - unexpected precision issues
            # Fall back to the unrounded value rather than failing the export.
            quantised = value
        return format(quantised, "f")
    if isinstance(value, (int, float)):
        return f"{value:.2f}"
    return default_formatter(value)
def _coerce_accessor(accessor: Accessor | str) -> Accessor:
    """Return a callable accessor for *accessor*.

    A callable is returned unchanged. A string is treated as a dot-separated
    attribute path (empty segments are ignored) and turned into a function
    that walks the path with ``getattr``, yielding ``None`` as soon as a step
    is ``None`` or an attribute is missing.
    """
    if callable(accessor):
        return accessor

    path = tuple(segment for segment in accessor.split(".") if segment)

    def _resolve(entity: Any) -> Any:
        current: Any = entity
        for segment in path:
            if current is None:
                return None
            current = getattr(current, segment, None)
        return current

    return _resolve
def _drain_buffer(buffer: StringIO) -> bytes:
    """Return the buffer's text encoded as UTF-8 and leave the buffer empty.

    The position is rewound before truncating so the next write starts at
    offset zero instead of padding from the old position.
    """
    data = buffer.getvalue()
    buffer.seek(0)
    buffer.truncate(0)
    return data.encode("utf-8")