241 lines
7.3 KiB
Python
241 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
import csv
|
|
from dataclasses import dataclass
|
|
from datetime import date, datetime, timezone
|
|
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
|
|
from enum import Enum
|
|
from io import StringIO
|
|
from typing import Any, Callable, Iterable, Iterator, Sequence
|
|
|
|
# Per-column formatter signature: raw cell value -> CSV cell text.
CSVValueFormatter = Callable[[Any], str]
# Per-column accessor signature: exported entity -> raw cell value.
Accessor = Callable[[Any], Any]

# Public API of this module.
__all__ = [
    # Column model and the streaming exporter.
    "CSVExportColumn",
    "CSVExporter",
    # Ready-made column sets.
    "default_project_columns",
    "default_scenario_columns",
    # Convenience generators yielding encoded CSV chunks.
    "stream_projects_to_csv",
    "stream_scenarios_to_csv",
    # Cell value formatters.
    "default_formatter",
    "format_datetime_utc",
    "format_date_iso",
    "format_decimal",
]
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class CSVExportColumn:
|
|
"""Declarative description of a CSV export column."""
|
|
|
|
header: str
|
|
accessor: Accessor | str
|
|
formatter: CSVValueFormatter | None = None
|
|
required: bool = False
|
|
|
|
def resolve_accessor(self) -> Accessor:
|
|
if isinstance(self.accessor, str):
|
|
return _coerce_accessor(self.accessor)
|
|
return self.accessor
|
|
|
|
def value_for(self, entity: Any) -> Any:
|
|
accessor = self.resolve_accessor()
|
|
try:
|
|
return accessor(entity)
|
|
except Exception: # pragma: no cover - defensive safeguard
|
|
return None
|
|
|
|
|
|
class CSVExporter:
    """Stream Python objects as UTF-8 encoded CSV rows.

    Each column contributes one cell per record: the column's accessor pulls
    a raw value out of the record, and the column's formatter (or
    ``default_formatter`` when none is set) turns it into text. Output is
    produced incrementally by :meth:`iter_bytes`.

    Exceptions raised by accessors or formatters are not caught here; they
    propagate to whoever is consuming :meth:`iter_bytes`.
    """

    def __init__(
        self,
        columns: Sequence[CSVExportColumn],
        *,
        include_header: bool = True,
        line_terminator: str = "\n",
    ) -> None:
        """Create an exporter for *columns*.

        Args:
            columns: Column definitions in output order (copied into a tuple).
            include_header: Emit a header row built from the column headers.
            line_terminator: Row terminator passed to :func:`csv.writer`.

        Raises:
            ValueError: If *columns* is empty.
        """
        if not columns:
            raise ValueError("At least one column is required for CSV export.")
        self._columns: tuple[CSVExportColumn, ...] = tuple(columns)
        self._include_header = include_header
        self._line_terminator = line_terminator

    @property
    def columns(self) -> tuple[CSVExportColumn, ...]:
        """The column definitions, in output order."""
        return self._columns

    def headers(self) -> tuple[str, ...]:
        """Return every column's header text, in output order."""
        return tuple(column.header for column in self._columns)

    def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]:
        """Yield the CSV export of *records* as UTF-8 encoded chunks.

        The header row (if enabled) is yielded as its own chunk, followed by
        one chunk per record. Records are pulled from *records* one at a time.
        """
        buffer = StringIO()
        writer = csv.writer(buffer, lineterminator=self._line_terminator)
        # Resolve accessors/formatters once per export instead of once per
        # cell: string accessors would otherwise rebuild their closure (and
        # re-split the dotted path) for every single record.
        plan = self._row_plan()

        if self._include_header:
            writer.writerow(self.headers())
            yield _drain_buffer(buffer)

        for record in records:
            writer.writerow(self._render_row(record, plan))
            yield _drain_buffer(buffer)

    def _format_row(self, record: Any) -> list[str]:
        """Format a single *record* into one string per column."""
        return self._render_row(record, self._row_plan())

    def _row_plan(self) -> list[tuple[Accessor, CSVValueFormatter]]:
        """Pair each column's resolved accessor with its effective formatter."""
        return [
            (column.resolve_accessor(), column.formatter or default_formatter)
            for column in self._columns
        ]

    @staticmethod
    def _render_row(
        record: Any,
        plan: Sequence[tuple[Accessor, CSVValueFormatter]],
    ) -> list[str]:
        """Apply a precomputed row *plan* to *record*."""
        return [formatter(accessor(record)) for accessor, formatter in plan]
|
|
|
|
|
|
def default_project_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    """Build the standard column set for exporting projects.

    The base columns are ``name`` (required), ``location`` and
    ``operation_type``. A ``description`` column and the UTC-formatted
    ``created_at``/``updated_at`` pair are appended, in that order, when the
    matching flag is true.
    """
    core = (
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("location", "location"),
        CSVExportColumn("operation_type", "operation_type"),
    )
    description = (
        (CSVExportColumn("description", "description"),)
        if include_description
        else ()
    )
    timestamps = (
        (
            CSVExportColumn(
                "created_at", "created_at", formatter=format_datetime_utc
            ),
            CSVExportColumn(
                "updated_at", "updated_at", formatter=format_datetime_utc
            ),
        )
        if include_timestamps
        else ()
    )
    return core + description + timestamps
|
|
|
|
|
|
def default_scenario_columns(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> tuple[CSVExportColumn, ...]:
    """Build the standard column set for exporting scenarios.

    The base columns are ``project_name`` (read from
    ``scenario.project.name``, required), ``name`` (required), ``status``,
    ISO-formatted ``start_date``/``end_date``, ``discount_rate`` (via
    ``format_decimal``), ``currency`` and ``primary_resource``. A
    ``description`` column and the UTC-formatted ``created_at``/``updated_at``
    pair are appended, in that order, when the matching flag is true.
    """

    def project_name_of(scenario: Any) -> Any:
        # A scenario without a project (or a project without a name) gives None.
        return getattr(getattr(scenario, "project", None), "name", None)

    core = (
        CSVExportColumn("project_name", project_name_of, required=True),
        CSVExportColumn("name", "name", required=True),
        CSVExportColumn("status", "status"),
        CSVExportColumn("start_date", "start_date", formatter=format_date_iso),
        CSVExportColumn("end_date", "end_date", formatter=format_date_iso),
        CSVExportColumn(
            "discount_rate", "discount_rate", formatter=format_decimal
        ),
        CSVExportColumn("currency", "currency"),
        CSVExportColumn("primary_resource", "primary_resource"),
    )
    description = (
        (CSVExportColumn("description", "description"),)
        if include_description
        else ()
    )
    timestamps = (
        (
            CSVExportColumn(
                "created_at", "created_at", formatter=format_datetime_utc
            ),
            CSVExportColumn(
                "updated_at", "updated_at", formatter=format_datetime_utc
            ),
        )
        if include_timestamps
        else ()
    )
    return core + description + timestamps
|
|
|
|
|
|
def stream_projects_to_csv(
    projects: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    """Stream *projects* as UTF-8 encoded CSV chunks.

    Args:
        projects: Project objects to export.
        columns: Column set to use; when ``None`` or empty, the result of
            :func:`default_project_columns` is used instead.

    Yields:
        The header chunk, then one encoded chunk per project.
    """
    chosen = tuple(columns) if columns else default_project_columns()
    yield from CSVExporter(chosen).iter_bytes(projects)
|
|
|
|
|
|
def stream_scenarios_to_csv(
    scenarios: Iterable[Any],
    *,
    columns: Sequence[CSVExportColumn] | None = None,
) -> Iterator[bytes]:
    """Stream *scenarios* as UTF-8 encoded CSV chunks.

    Args:
        scenarios: Scenario objects to export.
        columns: Column set to use; when ``None`` or empty, the result of
            :func:`default_scenario_columns` is used instead.

    Yields:
        The header chunk, then one encoded chunk per scenario.
    """
    chosen = tuple(columns) if columns else default_scenario_columns()
    yield from CSVExporter(chosen).iter_bytes(scenarios)
|
|
|
|
|
|
def default_formatter(value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
if isinstance(value, Enum):
|
|
return str(value.value)
|
|
if isinstance(value, Decimal):
|
|
return format_decimal(value)
|
|
if isinstance(value, datetime):
|
|
return format_datetime_utc(value)
|
|
if isinstance(value, date):
|
|
return format_date_iso(value)
|
|
if isinstance(value, bool):
|
|
return "true" if value else "false"
|
|
return str(value)
|
|
|
|
|
|
def format_datetime_utc(value: Any) -> str:
    """Render *value* as an ISO-8601 UTC timestamp ending in ``Z``.

    Naive datetimes are assumed to already be in UTC; aware ones are
    converted to UTC first. Values that are not ``datetime`` instances
    (including plain ``date`` objects and ``None``) yield an empty string.
    """
    if not isinstance(value, datetime):
        return ""
    # A datetime is only "aware" when utcoffset() is not None; a tzinfo that
    # returns None must be treated as naive too, otherwise astimezone() would
    # interpret it as *local* time instead of the documented UTC assumption.
    if value.utcoffset() is None:
        value = value.replace(tzinfo=timezone.utc)
    value = value.astimezone(timezone.utc)
    return value.isoformat().replace("+00:00", "Z")
|
|
|
|
|
|
def format_date_iso(value: Any) -> str:
    """Return ``value.isoformat()`` for date (and datetime) values, else ``""``."""
    return value.isoformat() if isinstance(value, date) else ""
|
|
|
|
|
|
def format_decimal(value: Any) -> str:
    """Render a numeric value with exactly two decimal places.

    ``None`` becomes ``""``. Decimals, ints and floats are all quantised to
    two places with ``ROUND_HALF_UP`` (ties round away from zero), so the
    same number formats identically whatever its Python type. Non-finite
    floats and values too long for the decimal context keep the plain
    ``.2f`` formatting; any other value is handed to ``default_formatter``.
    """
    if value is None:
        return ""
    cent = Decimal("0.01")
    if isinstance(value, Decimal):
        try:
            quantised = value.quantize(cent, rounding=ROUND_HALF_UP)
        except InvalidOperation:  # pragma: no cover - unexpected precision issues
            quantised = value
        return format(quantised, "f")
    if isinstance(value, (int, float)):
        # ``f"{x:.2f}"`` rounds half-to-even on the binary float value
        # (0.125 -> "0.12", 2.675 -> "2.67") and converts large ints to float,
        # losing digits. Going through Decimal matches the branch above;
        # str(float) is the shortest repr, i.e. the number as written.
        exact = Decimal(str(value)) if isinstance(value, float) else Decimal(value)
        if exact.is_finite():
            try:
                return format(exact.quantize(cent, rounding=ROUND_HALF_UP), "f")
            except InvalidOperation:
                pass  # too many digits for the context: keep legacy output
        return f"{value:.2f}"
    return default_formatter(value)
|
|
|
|
|
|
def _coerce_accessor(accessor: Accessor | str) -> Accessor:
|
|
if callable(accessor):
|
|
return accessor
|
|
|
|
path = [segment for segment in accessor.split(".") if segment]
|
|
|
|
def _resolve(entity: Any) -> Any:
|
|
current: Any = entity
|
|
for segment in path:
|
|
if current is None:
|
|
return None
|
|
current = getattr(current, segment, None)
|
|
return current
|
|
|
|
return _resolve
|
|
|
|
|
|
def _drain_buffer(buffer: StringIO) -> bytes:
|
|
data = buffer.getvalue()
|
|
buffer.seek(0)
|
|
buffer.truncate(0)
|
|
return data.encode("utf-8")
|