Files
calminer/services/export_serializers.py

241 lines
7.3 KiB
Python

from __future__ import annotations
import csv
from dataclasses import dataclass
from datetime import date, datetime, timezone
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from enum import Enum
from io import StringIO
from typing import Any, Callable, Iterable, Iterator, Sequence
# Formatter signature: receives a column's raw value and returns the cell text.
CSVValueFormatter = Callable[[Any], str]
# Accessor signature: receives the exported entity and returns the raw value.
Accessor = Callable[[Any], Any]
# Names exposed by star-imports of this module (its public API).
__all__ = [
"CSVExportColumn",
"CSVExporter",
"default_project_columns",
"default_scenario_columns",
"stream_projects_to_csv",
"stream_scenarios_to_csv",
"default_formatter",
"format_datetime_utc",
"format_date_iso",
"format_decimal",
]
@dataclass(slots=True)
class CSVExportColumn:
"""Declarative description of a CSV export column."""
header: str
accessor: Accessor | str
formatter: CSVValueFormatter | None = None
required: bool = False
def resolve_accessor(self) -> Accessor:
if isinstance(self.accessor, str):
return _coerce_accessor(self.accessor)
return self.accessor
def value_for(self, entity: Any) -> Any:
accessor = self.resolve_accessor()
try:
return accessor(entity)
except Exception: # pragma: no cover - defensive safeguard
return None
class CSVExporter:
    """Stream Python objects as UTF-8 encoded CSV rows.

    Every chunk produced by :meth:`iter_bytes` is the UTF-8 encoding of
    exactly one CSV row (the header row first, when enabled), so output can be
    forwarded incrementally without building the whole file in memory.
    """

    def __init__(
        self,
        columns: Sequence[CSVExportColumn],
        *,
        include_header: bool = True,
        line_terminator: str = "\n",
    ) -> None:
        """Configure the exporter.

        Args:
            columns: Column definitions, in output order.
            include_header: Emit a header row built from each column's
                ``header`` before any records.
            line_terminator: Row terminator handed to :func:`csv.writer`.

        Raises:
            ValueError: If ``columns`` contains no columns.
        """
        # Materialise before validating: a generator/iterator is always truthy,
        # so checking ``not columns`` first would let an empty one through.
        resolved = tuple(columns)
        if not resolved:
            raise ValueError("At least one column is required for CSV export.")
        self._columns: tuple[CSVExportColumn, ...] = resolved
        self._include_header = include_header
        self._line_terminator = line_terminator

    @property
    def columns(self) -> tuple[CSVExportColumn, ...]:
        """The configured columns, in output order."""
        return self._columns

    def headers(self) -> tuple[str, ...]:
        """Return the header text of every column, in output order."""
        return tuple(column.header for column in self._columns)

    def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]:
        """Yield the CSV encoding of ``records`` one row at a time.

        Args:
            records: Objects to export; each one becomes a single row.

        Yields:
            UTF-8 bytes of one row per chunk, preceded by the header row when
            ``include_header`` was set.
        """
        buffer = StringIO()
        writer = csv.writer(buffer, lineterminator=self._line_terminator)
        if self._include_header:
            writer.writerow(self.headers())
            yield _drain_buffer(buffer)
        # Resolve accessors/formatters once per export rather than once per
        # cell: string accessors would otherwise rebuild their path-walker for
        # every row.
        plan = self._cell_plan()
        for record in records:
            writer.writerow(self._format_row(record, plan))
            yield _drain_buffer(buffer)

    def _cell_plan(self) -> tuple[tuple[Accessor, CSVValueFormatter], ...]:
        """Pair each column's resolved accessor with its effective formatter."""
        return tuple(
            (column.resolve_accessor(), column.formatter or default_formatter)
            for column in self._columns
        )

    def _format_row(
        self,
        record: Any,
        plan: tuple[tuple[Accessor, CSVValueFormatter], ...] | None = None,
    ) -> list[str]:
        """Render ``record`` as a list of cell strings.

        ``plan`` is an optional precomputed result of :meth:`_cell_plan`; when
        omitted it is computed for this call.
        """
        cells = self._cell_plan() if plan is None else plan
        return [formatter(accessor(record)) for accessor, formatter in cells]
def _optional_columns(
    *,
    include_description: bool,
    include_timestamps: bool,
) -> list[CSVExportColumn]:
    """Return the optional trailing columns shared by the default layouts.

    Adds ``description`` when requested, then ``created_at`` and
    ``updated_at`` (rendered with :func:`format_datetime_utc`) when requested.
    """
    trailing: list[CSVExportColumn] = []
    if include_description:
        trailing.append(CSVExportColumn("description", "description"))
    if include_timestamps:
        trailing.extend(
            CSVExportColumn(stamp, stamp, formatter=format_datetime_utc)
            for stamp in ("created_at", "updated_at")
        )
    return trailing


def default_project_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    """Return the default column layout for project exports.

    Columns, in order: ``name`` (marked required), ``location``,
    ``operation_type``, then optionally ``description`` and the
    ``created_at``/``updated_at`` timestamps.

    Args:
        include_description: Append the ``description`` column.
        include_timestamps: Append ``created_at`` and ``updated_at``.
    """
    columns: list[CSVExportColumn] = [
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("location", "location"),
        CSVExportColumn("operation_type", "operation_type"),
    ]
    columns.extend(
        _optional_columns(
            include_description=include_description,
            include_timestamps=include_timestamps,
        )
    )
    return tuple(columns)


def default_scenario_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    """Return the default column layout for scenario exports.

    Columns, in order: ``project_name`` (read from ``scenario.project.name``,
    ``None`` when either attribute is missing), ``name``, ``status``,
    ``start_date``/``end_date`` (ISO dates), ``discount_rate`` (two-decimal
    text), ``currency``, ``primary_resource``, then optionally
    ``description`` and the ``created_at``/``updated_at`` timestamps.

    Args:
        include_description: Append the ``description`` column.
        include_timestamps: Append ``created_at`` and ``updated_at``.
    """
    columns: list[CSVExportColumn] = [
        CSVExportColumn(
            "project_name",
            lambda scenario: getattr(
                getattr(scenario, "project", None), "name", None),
            required=True,
        ),
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("status", "status"),
        CSVExportColumn("start_date", "start_date", formatter=format_date_iso),
        CSVExportColumn("end_date", "end_date", formatter=format_date_iso),
        CSVExportColumn("discount_rate", "discount_rate",
                        formatter=format_decimal),
        CSVExportColumn("currency", "currency"),
        CSVExportColumn("primary_resource", "primary_resource"),
    ]
    columns.extend(
        _optional_columns(
            include_description=include_description,
            include_timestamps=include_timestamps,
        )
    )
    return tuple(columns)
def stream_projects_to_csv(
    projects: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    """Yield ``projects`` as UTF-8 encoded CSV chunks, one row per chunk.

    When ``columns`` is ``None`` or empty, :func:`default_project_columns`
    supplies the layout.
    """
    chosen = columns if columns else default_project_columns()
    yield from CSVExporter(tuple(chosen)).iter_bytes(projects)
def stream_scenarios_to_csv(
    scenarios: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    """Yield ``scenarios`` as UTF-8 encoded CSV chunks, one row per chunk.

    When ``columns`` is ``None`` or empty, :func:`default_scenario_columns`
    supplies the layout.
    """
    chosen = columns if columns else default_scenario_columns()
    yield from CSVExporter(tuple(chosen)).iter_bytes(scenarios)
def default_formatter(value: Any) -> str:
if value is None:
return ""
if isinstance(value, Enum):
return str(value.value)
if isinstance(value, Decimal):
return format_decimal(value)
if isinstance(value, datetime):
return format_datetime_utc(value)
if isinstance(value, date):
return format_date_iso(value)
if isinstance(value, bool):
return "true" if value else "false"
return str(value)
def format_datetime_utc(value: Any) -> str:
    """Render a datetime as an ISO-8601 UTC timestamp ending in ``Z``.

    Naive datetimes are taken to already be in UTC; aware ones are converted.
    Any non-datetime input yields ``""``.
    """
    if isinstance(value, datetime):
        aware = value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
        stamp = aware.astimezone(timezone.utc).isoformat()
        return stamp.replace("+00:00", "Z")
    return ""
def format_date_iso(value: Any) -> str:
    """Return ``value.isoformat()`` for dates (datetimes included), else ``""``."""
    return value.isoformat() if isinstance(value, date) else ""
def format_decimal(value: Any) -> str:
    """Render a number with exactly two decimal places, rounding half-up.

    ``Decimal``, ``int`` and ``float`` inputs all go through the same
    ``Decimal.quantize(..., ROUND_HALF_UP)`` path, so equal numbers export
    identically regardless of their Python type. Floats are converted via
    ``str()`` (their shortest repr), so ``2.675`` rounds to ``"2.68"`` rather
    than following its binary approximation; non-finite floats therefore
    render like the equivalent Decimal (``"NaN"``, ``"Infinity"``).

    ``None`` becomes ``""``. Booleans and any other type are delegated to
    :func:`default_formatter`.
    """
    if value is None:
        return ""
    if isinstance(value, bool):
        # bool subclasses int; exporting True/False as "1.00"/"0.00" is
        # never meaningful, so format it as a boolean instead.
        return default_formatter(value)
    if isinstance(value, float):
        value = Decimal(str(value))
    elif isinstance(value, int):
        value = Decimal(value)
    if isinstance(value, Decimal):
        try:
            quantised = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        except InvalidOperation:  # pragma: no cover - infinities / precision overflow
            # Quantising infinities, or values whose digits exceed the
            # context precision, signals InvalidOperation; export unrounded.
            quantised = value
        return format(quantised, "f")
    return default_formatter(value)
def _coerce_accessor(accessor: Accessor | str) -> Accessor:
    """Turn ``accessor`` into a callable.

    Callables are returned as-is. A string is treated as a dot-separated
    attribute path (empty segments ignored); the returned callable walks it
    with ``getattr(..., None)`` and yields ``None`` as soon as a step is
    missing or ``None``.
    """
    if not callable(accessor):
        segments = tuple(part for part in accessor.split(".") if part)

        def _walk(entity: Any) -> Any:
            node = entity
            for name in segments:
                if node is None:
                    break
                node = getattr(node, name, None)
            return node

        return _walk
    return accessor
def _drain_buffer(buffer: StringIO) -> bytes:
    """Return everything written to ``buffer`` as UTF-8 and empty it for reuse."""
    text = buffer.getvalue()
    # Rewind first: truncate() without an argument cuts at the current position.
    buffer.seek(0)
    buffer.truncate()
    return text.encode("utf-8")