Files
calminer/services/export_serializers.py

179 lines
5.2 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from datetime import timezone
from functools import partial
from typing import Callable, Iterable, Iterator, Sequence
from .export_query import ProjectExportFilters, ScenarioExportFilters
# Names exported by a star import of this module.
# NOTE(review): the public stream_projects_to_csv / stream_scenarios_to_csv
# functions defined in this module are not listed here — confirm whether the
# omission is intentional.
__all__ = [
    "CSVExportColumn",
    "CSVExportRowFactory",
    "build_project_row_factory",
    "build_scenario_row_factory",
]
# Signature of a cell formatter: receives the raw attribute value (possibly
# None) and returns the text to place in the CSV cell.
CSVValueFormatter = Callable[[object | None], str]
# Signature of a callable that maps an entity to a sequence of raw cell values.
# NOTE(review): not referenced by any code visible in this module.
RowExtractor = Callable[[object], Sequence[object | None]]
@dataclass(slots=True, frozen=True)
class CSVExportColumn:
"""Describe how a field is rendered into CSV output."""
header: str
accessor: str
formatter: CSVValueFormatter | None = None
@dataclass(slots=True, frozen=True)
class CSVExportRowFactory:
"""Builds row values for CSV serialization."""
columns: tuple[CSVExportColumn, ...]
def headers(self) -> tuple[str, ...]:
return tuple(column.header for column in self.columns)
def to_row(self, entity: object) -> tuple[str, ...]:
values: list[str] = []
for column in self.columns:
value = getattr(entity, column.accessor, None)
formatter = column.formatter or _default_formatter
values.append(formatter(value))
return tuple(values)
def build_project_row_factory(
    *,
    include_timestamps: bool = True,
    include_description: bool = True,
) -> CSVExportRowFactory:
    """Create the row factory describing the project CSV layout.

    The ``name``, ``location`` and ``operation_type`` columns are always
    present. ``description`` follows when *include_description* is true, and
    ``created_at`` / ``updated_at`` (rendered with ``_format_datetime``) come
    last when *include_timestamps* is true. Every header equals its accessor.
    """
    plain_fields = ["name", "location", "operation_type"]
    if include_description:
        plain_fields.append("description")
    timestamp_fields = ("created_at", "updated_at") if include_timestamps else ()

    specs = [
        CSVExportColumn(header=field_name, accessor=field_name)
        for field_name in plain_fields
    ]
    specs += [
        CSVExportColumn(
            header=field_name,
            accessor=field_name,
            formatter=_format_datetime,
        )
        for field_name in timestamp_fields
    ]
    return CSVExportRowFactory(columns=tuple(specs))
def build_scenario_row_factory(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> CSVExportRowFactory:
    """Create the row factory describing the scenario CSV layout.

    Fixed columns, in order: ``project_name``, ``name``, ``status``,
    ``start_date`` and ``end_date`` (rendered with ``_format_date``),
    ``discount_rate`` (rendered with ``_format_number``), ``currency`` and
    ``primary_resource``. ``description`` follows when *include_description*
    is true, and ``created_at`` / ``updated_at`` (rendered with
    ``_format_datetime``) come last when *include_timestamps* is true.
    Every header equals its accessor.
    """
    # (field name, formatter) pairs; None leaves the column without a formatter.
    layout = [
        ("project_name", None),
        ("name", None),
        ("status", None),
        ("start_date", _format_date),
        ("end_date", _format_date),
        ("discount_rate", _format_number),
        ("currency", None),
        ("primary_resource", None),
    ]
    if include_description:
        layout.append(("description", None))
    if include_timestamps:
        layout.append(("created_at", _format_datetime))
        layout.append(("updated_at", _format_datetime))

    specs = tuple(
        CSVExportColumn(header=field_name, accessor=field_name, formatter=render)
        for field_name, render in layout
    )
    return CSVExportRowFactory(columns=specs)
def stream_projects_to_csv(
    projects: Sequence[object],
    *,
    row_factory: CSVExportRowFactory | None = None,
) -> Iterator[str]:
    """Lazily yield CSV text for *projects*: the header row, then one row each.

    Args:
        projects: Entities to export; iterated once, in order.
        row_factory: Column layout to use. A falsy value (the default
            ``None``) selects ``build_project_row_factory()``.

    Yields:
        One string per row exactly as returned by ``_join_row``; nothing is
        appended to it here.
    """
    active_factory = row_factory if row_factory else build_project_row_factory()
    yield _join_row(active_factory.headers())
    rows = map(active_factory.to_row, projects)
    yield from map(_join_row, rows)
def stream_scenarios_to_csv(
    scenarios: Sequence[object],
    *,
    row_factory: CSVExportRowFactory | None = None,
) -> Iterator[str]:
    """Lazily yield CSV text for *scenarios*: the header row, then one row each.

    Args:
        scenarios: Entities to export; iterated once, in order.
        row_factory: Column layout to use. A falsy value (the default
            ``None``) selects ``build_scenario_row_factory()``.

    Yields:
        One string per row exactly as returned by ``_join_row``; nothing is
        appended to it here.
    """
    active_factory = row_factory if row_factory else build_scenario_row_factory()
    yield _join_row(active_factory.headers())
    yield from (
        _join_row(active_factory.to_row(item)) for item in scenarios
    )
def _join_row(values: Iterable[str]) -> str:
    """Join already-formatted cell values into one CSV record.

    Cells containing a comma, a double quote or a line break are wrapped in
    double quotes with embedded quotes doubled (RFC 4180 style). A plain
    ``",".join`` would let such values (for example free-text descriptions)
    shift or split columns. Cells without those characters are emitted
    unchanged.

    Args:
        values: The cell texts of one row, in column order.

    Returns:
        The record as a single string, without a trailing line terminator.

    Raises:
        TypeError: If a cell is not a string.
    """
    return ",".join(_quote_csv_field(cell) for cell in values)


def _quote_csv_field(cell: str) -> str:
    """Return *cell* quoted for CSV when it holds a comma, quote or line break."""
    if any(special in cell for special in (",", '"', "\r", "\n")):
        return '"' + cell.replace('"', '""') + '"'
    return cell
def _default_formatter(value: object | None) -> str:
if value is None:
return ""
return str(value)
def _format_datetime(value: object | None) -> str:
from datetime import datetime
if not isinstance(value, datetime):
return ""
return value.astimezone(timezone.utc).replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z")
def _format_date(value: object | None) -> str:
from datetime import date
if not isinstance(value, date):
return ""
return value.isoformat()
def _format_number(value: object | None) -> str:
if value is None:
return ""
return f"{value:.2f}"