179 lines
5.2 KiB
Python
179 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import timezone
|
|
from functools import partial
|
|
from typing import Callable, Iterable, Iterator, Sequence
|
|
|
|
from .export_query import ProjectExportFilters, ScenarioExportFilters
|
|
|
|
# Public API of this module. The streaming helpers are listed alongside the
# row-factory builders so that ``from <module> import *`` exposes every
# public entry point defined here.
__all__ = [
    "CSVExportColumn",
    "CSVExportRowFactory",
    "build_project_row_factory",
    "build_scenario_row_factory",
    "stream_projects_to_csv",
    "stream_scenarios_to_csv",
]

# Converts one extracted attribute value (possibly ``None``) into cell text.
CSVValueFormatter = Callable[[object | None], str]
# Maps an entity to a sequence of raw (unformatted) cell values.
RowExtractor = Callable[[object], Sequence[object | None]]
|
|
|
|
|
|
@dataclass(slots=True, frozen=True)
|
|
class CSVExportColumn:
|
|
"""Describe how a field is rendered into CSV output."""
|
|
|
|
header: str
|
|
accessor: str
|
|
formatter: CSVValueFormatter | None = None
|
|
|
|
|
|
@dataclass(slots=True, frozen=True)
|
|
class CSVExportRowFactory:
|
|
"""Builds row values for CSV serialization."""
|
|
|
|
columns: tuple[CSVExportColumn, ...]
|
|
|
|
def headers(self) -> tuple[str, ...]:
|
|
return tuple(column.header for column in self.columns)
|
|
|
|
def to_row(self, entity: object) -> tuple[str, ...]:
|
|
values: list[str] = []
|
|
for column in self.columns:
|
|
value = getattr(entity, column.accessor, None)
|
|
formatter = column.formatter or _default_formatter
|
|
values.append(formatter(value))
|
|
return tuple(values)
|
|
|
|
|
|
def build_project_row_factory(
    *,
    include_timestamps: bool = True,
    include_description: bool = True,
) -> CSVExportRowFactory:
    """Build the row factory used for project exports.

    The columns are always ``name``, ``location`` and ``operation_type``,
    followed by ``description`` when ``include_description`` is true and by
    ``created_at`` / ``updated_at`` (rendered with ``_format_datetime``) when
    ``include_timestamps`` is true. Header and accessor are the same string
    for every column.
    """
    layout = [
        CSVExportColumn(header=field, accessor=field)
        for field in ("name", "location", "operation_type")
    ]
    if include_description:
        layout += [CSVExportColumn(header="description", accessor="description")]
    if include_timestamps:
        layout += [
            CSVExportColumn(
                header=field,
                accessor=field,
                formatter=_format_datetime,
            )
            for field in ("created_at", "updated_at")
        ]
    return CSVExportRowFactory(columns=tuple(layout))
|
|
|
|
|
|
def build_scenario_row_factory(
    *,
    include_description: bool = True,
    include_timestamps: bool = True,
) -> CSVExportRowFactory:
    """Build the row factory used for scenario exports.

    The fixed columns are ``project_name``, ``name``, ``status``,
    ``start_date``, ``end_date``, ``discount_rate``, ``currency`` and
    ``primary_resource``. ``description`` is appended when
    ``include_description`` is true, and ``created_at`` / ``updated_at`` when
    ``include_timestamps`` is true. Header and accessor are the same string
    for every column.
    """
    # (field, formatter) pairs in output order; ``None`` leaves the column
    # without an explicit formatter.
    spec = [
        ("project_name", None),
        ("name", None),
        ("status", None),
        ("start_date", _format_date),
        ("end_date", _format_date),
        ("discount_rate", _format_number),
        ("currency", None),
        ("primary_resource", None),
    ]
    if include_description:
        spec.append(("description", None))
    if include_timestamps:
        spec.append(("created_at", _format_datetime))
        spec.append(("updated_at", _format_datetime))

    return CSVExportRowFactory(
        columns=tuple(
            CSVExportColumn(header=field, accessor=field, formatter=formatter)
            for field, formatter in spec
        )
    )
|
|
|
|
|
|
def stream_projects_to_csv(
    projects: Sequence[object],
    *,
    row_factory: CSVExportRowFactory | None = None,
) -> Iterator[str]:
    """Lazily yield CSV text for ``projects``: the header row, then one row each.

    Args:
        projects: Entities to export, iterated in order.
        row_factory: Column layout to use; when falsy, the layout from
            ``build_project_row_factory()`` is used.

    Yields:
        One string per row, each produced by ``_join_row``.
    """
    factory = row_factory if row_factory else build_project_row_factory()
    yield _join_row(factory.headers())
    yield from (_join_row(factory.to_row(item)) for item in projects)
|
|
|
|
|
|
def stream_scenarios_to_csv(
    scenarios: Sequence[object],
    *,
    row_factory: CSVExportRowFactory | None = None,
) -> Iterator[str]:
    """Lazily yield CSV text for ``scenarios``: the header row, then one row each.

    Args:
        scenarios: Entities to export, iterated in order.
        row_factory: Column layout to use; when falsy, the layout from
            ``build_scenario_row_factory()`` is used.

    Yields:
        One string per row, each produced by ``_join_row``.
    """
    factory = row_factory if row_factory else build_scenario_row_factory()
    yield _join_row(factory.headers())
    yield from (_join_row(factory.to_row(item)) for item in scenarios)
|
|
|
|
|
|
def _join_row(values: Iterable[str]) -> str:
|
|
return ",".join(values)
|
|
|
|
|
|
def _default_formatter(value: object | None) -> str:
|
|
if value is None:
|
|
return ""
|
|
return str(value)
|
|
|
|
|
|
def _format_datetime(value: object | None) -> str:
|
|
from datetime import datetime
|
|
|
|
if not isinstance(value, datetime):
|
|
return ""
|
|
return value.astimezone(timezone.utc).replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z")
|
|
|
|
|
|
def _format_date(value: object | None) -> str:
|
|
from datetime import date
|
|
|
|
if not isinstance(value, date):
|
|
return ""
|
|
return value.isoformat()
|
|
|
|
|
|
def _format_number(value: object | None) -> str:
|
|
if value is None:
|
|
return ""
|
|
return f"{value:.2f}"
|