# calminer/services/importers.py
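"""Importers for project and scenario spreadsheet uploads.

Parses CSV/Excel streams into validated rows, builds non-destructive import
previews, and stages accepted rows behind one-time tokens until committed.
"""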
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, BinaryIO, Callable, Generic, Iterable, Mapping, TypeVar, cast
from uuid import uuid4
from types import MappingProxyType
import pandas as pd
from pandas import DataFrame
from pydantic import BaseModel, ValidationError
from models import Project, Scenario
from schemas.imports import ProjectImportRow, ScenarioImportRow
from services.unit_of_work import UnitOfWork

TImportRow = TypeVar("TImportRow", bound=BaseModel)

PROJECT_COLUMNS: tuple[str, ...] = (
    "name",
    "location",
    "operation_type",
    "description",
    "created_at",
    "updated_at",
)

SCENARIO_COLUMNS: tuple[str, ...] = (
    "project_name",
    "name",
    "status",
    "start_date",
    "end_date",
    "discount_rate",
    "currency",
    "primary_resource",
    "description",
    "created_at",
    "updated_at",
)


@dataclass(slots=True)
class ImportRowError:
    """A parsing or validation error tied to one spreadsheet row."""

    row_number: int
    field: str | None
    message: str


@dataclass(slots=True)
class ParsedImportRow(Generic[TImportRow]):
    """A successfully validated row plus its source row number."""

    row_number: int
    data: TImportRow


@dataclass(slots=True)
class ImportResult(Generic[TImportRow]):
    """Outcome of parsing an upload: validated rows and per-row errors."""

    rows: list[ParsedImportRow[TImportRow]]
    errors: list[ImportRowError]


class UnsupportedImportFormat(ValueError):
    """Raised when an uploaded file has an extension we cannot parse."""


class ImportPreviewState(str, Enum):
    """How a previewed row will be handled if the import is committed."""

    NEW = "new"
    UPDATE = "update"
    SKIP = "skip"
    ERROR = "error"


@dataclass(slots=True)
class ImportPreviewRow(Generic[TImportRow]):
    """A previewed row with its resolved state, issues, and commit context."""

    row_number: int
    data: TImportRow
    state: ImportPreviewState
    issues: list[str]
    context: dict[str, Any] | None = None


@dataclass(slots=True)
class ImportPreviewSummary:
    """Aggregate row counts for one preview run."""

    total_rows: int
    accepted: int
    skipped: int
    errored: int


@dataclass(slots=True)
class ImportPreview(Generic[TImportRow]):
    """Complete preview payload returned to callers, including a stage token."""

    rows: list[ImportPreviewRow[TImportRow]]
    summary: ImportPreviewSummary
    row_issues: list["ImportPreviewRowIssues"]
    parser_errors: list[ImportRowError]
    stage_token: str | None


@dataclass(slots=True)
class StagedRow(Generic[TImportRow]):
    """A row accepted during preview, held until the import is committed."""

    parsed: ParsedImportRow[TImportRow]
    context: dict[str, Any]


@dataclass(slots=True)
class ImportPreviewRowIssue:
    """A single issue message, optionally tied to a specific field."""

    message: str
    field: str | None = None


@dataclass(slots=True)
class ImportPreviewRowIssues:
    """All issues for one row, carrying the row's most severe state."""

    row_number: int
    state: ImportPreviewState | None
    issues: list[ImportPreviewRowIssue]


@dataclass(slots=True)
class StagedImport(Generic[TImportRow]):
    """Server-side record of staged rows, keyed by a one-time token."""

    token: str
    rows: list[StagedRow[TImportRow]]


@dataclass(slots=True, frozen=True)
class StagedRowView(Generic[TImportRow]):
    """Read-only snapshot of a staged row handed out to consumers."""

    row_number: int
    data: TImportRow
    context: Mapping[str, Any]


@dataclass(slots=True, frozen=True)
class StagedImportView(Generic[TImportRow]):
    """Read-only snapshot of an entire staged import."""

    token: str
    rows: tuple[StagedRowView[TImportRow], ...]


UnitOfWorkFactory = Callable[[], UnitOfWork]


class ImportIngestionService:
    """Coordinates parsing, validation, and preview staging for imports."""

    def __init__(self, uow_factory: UnitOfWorkFactory) -> None:
        self._uow_factory = uow_factory
        self._project_stage: dict[str, StagedImport[ProjectImportRow]] = {}
        self._scenario_stage: dict[str, StagedImport[ScenarioImportRow]] = {}

    def preview_projects(
        self,
        stream: BinaryIO,
        filename: str,
    ) -> ImportPreview[ProjectImportRow]:
        """Parse an upload and build a non-destructive preview of project rows."""
result = load_project_imports(stream, filename)
parser_errors = result.errors
preview_rows: list[ImportPreviewRow[ProjectImportRow]] = []
staged_rows: list[StagedRow[ProjectImportRow]] = []
accepted = skipped = errored = 0
seen_names: set[str] = set()
existing_by_name: dict[str, Project] = {}
if result.rows:
with self._uow_factory() as uow:
if not uow.projects:
raise RuntimeError("Project repository is unavailable")
existing_by_name = dict(
uow.projects.find_by_names(
parsed.data.name for parsed in result.rows
)
)
for parsed in result.rows:
name_key = _normalise_key(parsed.data.name)
issues: list[str] = []
context: dict[str, Any] | None = None
state = ImportPreviewState.NEW
if name_key in seen_names:
state = ImportPreviewState.SKIP
issues.append(
"Duplicate project name within upload; row skipped.")
else:
seen_names.add(name_key)
existing = existing_by_name.get(name_key)
if existing:
state = ImportPreviewState.UPDATE
context = {
"mode": "update",
"project_id": existing.id,
}
issues.append("Existing project will be updated.")
else:
context = {"mode": "create"}
preview_rows.append(
ImportPreviewRow(
row_number=parsed.row_number,
data=parsed.data,
state=state,
issues=issues,
context=context,
)
)
            if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}:
                accepted += 1
                staged_rows.append(
                    StagedRow(parsed=parsed, context=context or {"mode": "create"})
                )
            elif state == ImportPreviewState.SKIP:
                skipped += 1
            else:
                errored += 1
parser_error_rows = {error.row_number for error in parser_errors}
errored += len(parser_error_rows)
total_rows = len(preview_rows) + len(parser_error_rows)
summary = ImportPreviewSummary(
total_rows=total_rows,
accepted=accepted,
skipped=skipped,
errored=errored,
)
row_issues = _compile_row_issues(preview_rows, parser_errors)
stage_token: str | None = None
if staged_rows:
stage_token = self._store_project_stage(staged_rows)
return ImportPreview(
rows=preview_rows,
summary=summary,
row_issues=row_issues,
parser_errors=parser_errors,
stage_token=stage_token,
        )

    def preview_scenarios(
        self,
        stream: BinaryIO,
        filename: str,
    ) -> ImportPreview[ScenarioImportRow]:
        """Parse an upload and preview scenario rows against existing projects."""
result = load_scenario_imports(stream, filename)
parser_errors = result.errors
preview_rows: list[ImportPreviewRow[ScenarioImportRow]] = []
staged_rows: list[StagedRow[ScenarioImportRow]] = []
accepted = skipped = errored = 0
seen_pairs: set[tuple[str, str]] = set()
existing_projects: dict[str, Project] = {}
existing_scenarios: dict[tuple[int, str], Scenario] = {}
if result.rows:
with self._uow_factory() as uow:
if not uow.projects or not uow.scenarios:
raise RuntimeError("Repositories are unavailable")
existing_projects = dict(
uow.projects.find_by_names(
parsed.data.project_name for parsed in result.rows
)
)
names_by_project: dict[int, set[str]] = {}
for parsed in result.rows:
project = existing_projects.get(
_normalise_key(parsed.data.project_name)
)
if not project:
continue
names_by_project.setdefault(project.id, set()).add(
_normalise_key(parsed.data.name)
)
for project_id, names in names_by_project.items():
matches = uow.scenarios.find_by_project_and_names(
project_id, names)
for name_key, scenario in matches.items():
existing_scenarios[(project_id, name_key)] = scenario
for parsed in result.rows:
project_key = _normalise_key(parsed.data.project_name)
scenario_key = _normalise_key(parsed.data.name)
issues: list[str] = []
context: dict[str, Any] | None = None
state = ImportPreviewState.NEW
if (project_key, scenario_key) in seen_pairs:
state = ImportPreviewState.SKIP
issues.append(
"Duplicate scenario for project within upload; row skipped."
)
else:
seen_pairs.add((project_key, scenario_key))
project = existing_projects.get(project_key)
if not project:
state = ImportPreviewState.ERROR
issues.append(
f"Project '{parsed.data.project_name}' does not exist."
)
else:
context = {"mode": "create", "project_id": project.id}
existing = existing_scenarios.get(
(project.id, scenario_key))
if existing:
state = ImportPreviewState.UPDATE
context = {
"mode": "update",
"project_id": project.id,
"scenario_id": existing.id,
}
issues.append("Existing scenario will be updated.")
preview_rows.append(
ImportPreviewRow(
row_number=parsed.row_number,
data=parsed.data,
state=state,
issues=issues,
context=context,
)
)
            if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}:
                accepted += 1
                staged_rows.append(
                    StagedRow(parsed=parsed, context=context or {"mode": "create"})
                )
            elif state == ImportPreviewState.SKIP:
                skipped += 1
            else:
                errored += 1
parser_error_rows = {error.row_number for error in parser_errors}
errored += len(parser_error_rows)
total_rows = len(preview_rows) + len(parser_error_rows)
summary = ImportPreviewSummary(
total_rows=total_rows,
accepted=accepted,
skipped=skipped,
errored=errored,
)
row_issues = _compile_row_issues(preview_rows, parser_errors)
stage_token: str | None = None
if staged_rows:
stage_token = self._store_scenario_stage(staged_rows)
return ImportPreview(
rows=preview_rows,
summary=summary,
row_issues=row_issues,
parser_errors=parser_errors,
stage_token=stage_token,
)

    def get_staged_projects(
        self, token: str
    ) -> StagedImportView[ProjectImportRow] | None:
        """Return a read-only view of staged project rows, if the token is live."""
        staged = self._project_stage.get(token)
        if not staged:
            return None
        return _build_staged_view(staged)

    def get_staged_scenarios(
        self, token: str
    ) -> StagedImportView[ScenarioImportRow] | None:
        """Return a read-only view of staged scenario rows, if the token is live."""
        staged = self._scenario_stage.get(token)
        if not staged:
            return None
        return _build_staged_view(staged)

    def consume_staged_projects(
        self, token: str
    ) -> StagedImportView[ProjectImportRow] | None:
        """Pop staged project rows so the token can be used exactly once."""
        staged = self._project_stage.pop(token, None)
        if not staged:
            return None
        return _build_staged_view(staged)

    def consume_staged_scenarios(
        self, token: str
    ) -> StagedImportView[ScenarioImportRow] | None:
        """Pop staged scenario rows so the token can be used exactly once."""
        staged = self._scenario_stage.pop(token, None)
        if not staged:
            return None
        return _build_staged_view(staged)

    def clear_staged_projects(self, token: str) -> bool:
        """Discard staged project rows; True if the token existed."""
        return self._project_stage.pop(token, None) is not None

    def clear_staged_scenarios(self, token: str) -> bool:
        """Discard staged scenario rows; True if the token existed."""
        return self._scenario_stage.pop(token, None) is not None

    def _store_project_stage(
        self, rows: list[StagedRow[ProjectImportRow]]
    ) -> str:
        """Stage accepted project rows under a fresh one-time token."""
        token = str(uuid4())
        self._project_stage[token] = StagedImport(token=token, rows=rows)
        return token

    def _store_scenario_stage(
        self, rows: list[StagedRow[ScenarioImportRow]]
    ) -> str:
        """Stage accepted scenario rows under a fresh one-time token."""
        token = str(uuid4())
        self._scenario_stage[token] = StagedImport(token=token, rows=rows)
        return token
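
# Illustrative usage (a sketch, not part of this module's public API):
# ``make_unit_of_work`` is an assumed application-level factory returning a
# UnitOfWork; any real wiring lives with the caller.
#
#     service = ImportIngestionService(make_unit_of_work)
#     with open("projects.csv", "rb") as fh:
#         preview = service.preview_projects(fh, "projects.csv")
#     if preview.stage_token is not None:
#         staged = service.consume_staged_projects(preview.stage_token)
#         # ``staged`` rows are deep-copied snapshots; the token is now spent.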


def load_project_imports(
    stream: BinaryIO, filename: str
) -> ImportResult[ProjectImportRow]:
    """Parse a CSV or Excel upload into validated project rows."""
    df = _load_dataframe(stream, filename)
    return _parse_dataframe(df, ProjectImportRow, PROJECT_COLUMNS)


def load_scenario_imports(
    stream: BinaryIO, filename: str
) -> ImportResult[ScenarioImportRow]:
    """Parse a CSV or Excel upload into validated scenario rows."""
    df = _load_dataframe(stream, filename)
    return _parse_dataframe(df, ScenarioImportRow, SCENARIO_COLUMNS)


def _load_dataframe(stream: BinaryIO, filename: str) -> DataFrame:
    """Read an uploaded CSV or Excel stream into a string-typed DataFrame."""
    stream.seek(0)
    suffix = Path(filename).suffix.lower()
    if suffix == ".csv":
        df = pd.read_csv(
            stream, dtype=str, keep_default_na=False, encoding="utf-8"
        )
    elif suffix in {".xls", ".xlsx"}:
        # openpyxl reads only .xlsx; legacy .xls workbooks need the xlrd engine.
        engine = "openpyxl" if suffix == ".xlsx" else "xlrd"
        df = pd.read_excel(stream, dtype=str, engine=engine)
    else:
        raise UnsupportedImportFormat(
            f"Unsupported file type: {suffix or 'unknown'}"
        )
    # Normalise headers so column matching ignores case and stray whitespace.
    df.columns = [str(col).strip().lower() for col in df.columns]
    return df


def _parse_dataframe(
    df: DataFrame,
    model: type[TImportRow],
    expected_columns: Iterable[str],
) -> ImportResult[TImportRow]:
    """Validate each row against ``model``, collecting per-field errors."""
    rows: list[ParsedImportRow[TImportRow]] = []
    errors: list[ImportRowError] = []
    # start=2: row 1 of the source sheet is the header, so the first data row
    # is row 2; reported row numbers then match what the user sees.
    for index, raw in enumerate(df.to_dict(orient="records"), start=2):
        payload = _prepare_payload(cast(dict[str, object], raw), expected_columns)
        try:
            rows.append(
                ParsedImportRow(row_number=index, data=model(**payload))
            )
        except ValidationError as exc:
            for detail in exc.errors():
                loc = ".".join(
                    str(part) for part in detail.get("loc", [])
                ) or None
                errors.append(
                    ImportRowError(
                        row_number=index,
                        field=loc,
                        message=detail.get("msg", "Invalid value"),
                    )
                )
    return ImportResult(rows=rows, errors=errors)


def _prepare_payload(
    raw: dict[str, object], expected_columns: Iterable[str]
) -> dict[str, object | None]:
    """Trim strings and normalise blanks and NaN cells to ``None``."""
    payload: dict[str, object | None] = {}
    for column in expected_columns:
        if column not in raw:
            continue
        value = raw.get(column)
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                value = None
        if value is not None and pd.isna(cast(Any, value)):
            # Cells pandas read as NaN (e.g. empty Excel cells) become None.
            value = None
        payload[column] = value
    return payload


def _normalise_key(value: str) -> str:
    """Lowercase and trim a name so lookups ignore case and whitespace."""
    return value.strip().lower()


def _build_staged_view(
    staged: StagedImport[TImportRow],
) -> StagedImportView[TImportRow]:
    """Snapshot staged rows as immutable views with deep-copied payloads."""
    rows = tuple(
        StagedRowView(
            row_number=row.parsed.row_number,
            data=cast(TImportRow, _deep_copy_model(row.parsed.data)),
            context=MappingProxyType(dict(row.context)),
        )
        for row in staged.rows
    )
    return StagedImportView(token=staged.token, rows=rows)


def _deep_copy_model(model: BaseModel) -> BaseModel:
    """Deep-copy a pydantic model, supporting both the v1 and v2 APIs."""
    copy_method = getattr(model, "model_copy", None)
    if callable(copy_method):  # pydantic v2
        return cast(BaseModel, copy_method(deep=True))
    return model.copy(deep=True)  # type: ignore[attr-defined]


def _compile_row_issues(
    preview_rows: Iterable[ImportPreviewRow[Any]],
    parser_errors: Iterable[ImportRowError],
) -> list[ImportPreviewRowIssues]:
    """Merge preview issues and parser errors into one bundle per row."""
    issue_map: dict[int, ImportPreviewRowIssues] = {}

    def ensure_bundle(
        row_number: int,
        state: ImportPreviewState | None,
    ) -> ImportPreviewRowIssues:
        bundle = issue_map.get(row_number)
        if bundle is None:
            bundle = ImportPreviewRowIssues(
                row_number=row_number,
                state=state,
                issues=[],
            )
            issue_map[row_number] = bundle
        elif _state_priority(state) > _state_priority(bundle.state):
            # Keep the most severe state seen for this row.
            bundle.state = state
        return bundle

    for row in preview_rows:
        if not row.issues:
            continue
        bundle = ensure_bundle(row.row_number, row.state)
        for message in row.issues:
            bundle.issues.append(ImportPreviewRowIssue(message=message))
    for error in parser_errors:
        bundle = ensure_bundle(error.row_number, ImportPreviewState.ERROR)
        bundle.issues.append(
            ImportPreviewRowIssue(message=error.message, field=error.field)
        )
    return sorted(issue_map.values(), key=lambda item: item.row_number)


def _state_priority(state: ImportPreviewState | None) -> int:
    """Rank states so the most severe outcome wins when rows merge."""
    if state is None:
        return -1
    if state == ImportPreviewState.ERROR:
        return 3
    if state == ImportPreviewState.SKIP:
        return 2
    if state == ImportPreviewState.UPDATE:
        return 1
    return 0
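

if __name__ == "__main__":  # pragma: no cover - ad-hoc smoke check
    # A minimal, hedged demo of the parsing path only: no database access and
    # no staging. The header mirrors PROJECT_COLUMNS; whether the sample row
    # validates depends entirely on the rules ProjectImportRow enforces, so
    # the output may legitimately report errors instead of parsed rows.
    import io

    sample = io.BytesIO(
        b"name,location,operation_type,description,created_at,updated_at\n"
        b"Demo Project,Nowhere,unknown,Illustrative row only,,\n"
    )
    demo = load_project_imports(sample, "sample.csv")
    print(f"parsed {len(demo.rows)} row(s), {len(demo.errors)} error(s)")
    for error in demo.errors:
        print(f"row {error.row_number}: {error.field or '-'}: {error.message}")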