feat: implement import functionality for projects and scenarios with CSV/XLSX support, including validation and error handling
services/importers.py (new file, 564 lines)

@@ -0,0 +1,564 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from types import MappingProxyType
from typing import Any, BinaryIO, Callable, Generic, Iterable, Mapping, TypeVar, cast
from uuid import uuid4

import pandas as pd
from pandas import DataFrame
from pydantic import BaseModel, ValidationError

from models import Project, Scenario
from schemas.imports import ProjectImportRow, ScenarioImportRow
from services.unit_of_work import UnitOfWork

TImportRow = TypeVar("TImportRow", bound=BaseModel)

# Expected spreadsheet headers (already trimmed and lower-cased) per import type.
PROJECT_COLUMNS: tuple[str, ...] = (
    "name",
    "location",
    "operation_type",
    "description",
    "created_at",
    "updated_at",
)

SCENARIO_COLUMNS: tuple[str, ...] = (
    "project_name",
    "name",
    "status",
    "start_date",
    "end_date",
    "discount_rate",
    "currency",
    "primary_resource",
    "description",
    "created_at",
    "updated_at",
)
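
# Header matching is forgiving: _load_dataframe trims and lower-cases column
# names, and column order does not matter, so a projects file may begin e.g.:
#
#   name,location,operation_type,description,created_at,updated_at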


@dataclass(slots=True)
class ImportRowError:
    """A validation error attached to a specific spreadsheet row."""

    row_number: int
    field: str | None
    message: str


@dataclass(slots=True)
class ParsedImportRow(Generic[TImportRow]):
    """A row that passed schema validation."""

    row_number: int
    data: TImportRow


@dataclass(slots=True)
class ImportResult(Generic[TImportRow]):
    """Outcome of parsing a whole file: valid rows plus per-row errors."""

    rows: list[ParsedImportRow[TImportRow]]
    errors: list[ImportRowError]


class UnsupportedImportFormat(ValueError):
    """Raised when the uploaded file is neither CSV nor XLSX."""


class ImportPreviewState(str, Enum):
    NEW = "new"
    UPDATE = "update"
    SKIP = "skip"
    ERROR = "error"


@dataclass(slots=True)
class ImportPreviewRow(Generic[TImportRow]):
    row_number: int
    data: TImportRow
    state: ImportPreviewState
    issues: list[str]
    context: dict[str, Any] | None = None


@dataclass(slots=True)
class ImportPreviewSummary:
    total_rows: int
    accepted: int
    skipped: int
    errored: int


@dataclass(slots=True)
class ImportPreview(Generic[TImportRow]):
    rows: list[ImportPreviewRow[TImportRow]]
    summary: ImportPreviewSummary
    row_issues: list["ImportPreviewRowIssues"]
    parser_errors: list[ImportRowError]
    stage_token: str | None


@dataclass(slots=True)
class StagedRow(Generic[TImportRow]):
    parsed: ParsedImportRow[TImportRow]
    context: dict[str, Any]


@dataclass(slots=True)
class ImportPreviewRowIssue:
    message: str
    field: str | None = None


@dataclass(slots=True)
class ImportPreviewRowIssues:
    row_number: int
    state: ImportPreviewState | None
    issues: list[ImportPreviewRowIssue]


@dataclass(slots=True)
class StagedImport(Generic[TImportRow]):
    token: str
    rows: list[StagedRow[TImportRow]]


@dataclass(slots=True, frozen=True)
class StagedRowView(Generic[TImportRow]):
    """Read-only snapshot of a staged row handed back to callers."""

    row_number: int
    data: TImportRow
    context: Mapping[str, Any]


@dataclass(slots=True, frozen=True)
class StagedImportView(Generic[TImportRow]):
    token: str
    rows: tuple[StagedRowView[TImportRow], ...]


UnitOfWorkFactory = Callable[[], UnitOfWork]
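
# Any zero-argument callable returning a context-managed UnitOfWork can serve
# as the factory; a minimal wiring sketch (`make_uow` is illustrative only):
#
#   def make_uow() -> UnitOfWork:
#       return UnitOfWork(session_factory)  # hypothetical constructor
#
#   service = ImportIngestionService(uow_factory=make_uow)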


class ImportIngestionService:
    """Coordinates parsing, validation, and preview staging for imports."""

    def __init__(self, uow_factory: UnitOfWorkFactory) -> None:
        self._uow_factory = uow_factory
        self._project_stage: dict[str, StagedImport[ProjectImportRow]] = {}
        self._scenario_stage: dict[str, StagedImport[ScenarioImportRow]] = {}

    def preview_projects(
        self,
        stream: BinaryIO,
        filename: str,
    ) -> ImportPreview[ProjectImportRow]:
        result = load_project_imports(stream, filename)
        parser_errors = result.errors

        preview_rows: list[ImportPreviewRow[ProjectImportRow]] = []
        staged_rows: list[StagedRow[ProjectImportRow]] = []
        accepted = skipped = errored = 0

        seen_names: set[str] = set()

        existing_by_name: dict[str, Project] = {}
        if result.rows:
            with self._uow_factory() as uow:
                if not uow.projects:
                    raise RuntimeError("Project repository is unavailable")
                existing_by_name = dict(
                    uow.projects.find_by_names(
                        parsed.data.name for parsed in result.rows
                    )
                )

        for parsed in result.rows:
            name_key = _normalise_key(parsed.data.name)
            issues: list[str] = []
            context: dict[str, Any] | None = None
            state = ImportPreviewState.NEW

            if name_key in seen_names:
                state = ImportPreviewState.SKIP
                issues.append("Duplicate project name within upload; row skipped.")
            else:
                seen_names.add(name_key)
                existing = existing_by_name.get(name_key)
                if existing:
                    state = ImportPreviewState.UPDATE
                    context = {"mode": "update", "project_id": existing.id}
                    issues.append("Existing project will be updated.")
                else:
                    context = {"mode": "create"}

            preview_rows.append(
                ImportPreviewRow(
                    row_number=parsed.row_number,
                    data=parsed.data,
                    state=state,
                    issues=issues,
                    context=context,
                )
            )

            if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}:
                accepted += 1
                staged_rows.append(
                    StagedRow(parsed=parsed, context=context or {"mode": "create"})
                )
            elif state == ImportPreviewState.SKIP:
                skipped += 1
            else:
                errored += 1

        # Rows that failed schema validation never reach preview_rows, so the
        # two sets are disjoint and can simply be summed.
        parser_error_rows = {error.row_number for error in parser_errors}
        errored += len(parser_error_rows)
        total_rows = len(preview_rows) + len(parser_error_rows)

        summary = ImportPreviewSummary(
            total_rows=total_rows,
            accepted=accepted,
            skipped=skipped,
            errored=errored,
        )

        row_issues = _compile_row_issues(preview_rows, parser_errors)

        stage_token: str | None = None
        if staged_rows:
            stage_token = self._store_project_stage(staged_rows)

        return ImportPreview(
            rows=preview_rows,
            summary=summary,
            row_issues=row_issues,
            parser_errors=parser_errors,
            stage_token=stage_token,
        )
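
    # Example: a 3-row upload whose third row repeats an earlier project name
    # previews as total_rows=3, accepted=2, skipped=1, errored=0, with the
    # duplicate row marked ImportPreviewState.SKIP.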

    def preview_scenarios(
        self,
        stream: BinaryIO,
        filename: str,
    ) -> ImportPreview[ScenarioImportRow]:
        result = load_scenario_imports(stream, filename)
        parser_errors = result.errors

        preview_rows: list[ImportPreviewRow[ScenarioImportRow]] = []
        staged_rows: list[StagedRow[ScenarioImportRow]] = []
        accepted = skipped = errored = 0

        seen_pairs: set[tuple[str, str]] = set()

        existing_projects: dict[str, Project] = {}
        existing_scenarios: dict[tuple[int, str], Scenario] = {}

        if result.rows:
            with self._uow_factory() as uow:
                if not uow.projects or not uow.scenarios:
                    raise RuntimeError("Repositories are unavailable")

                existing_projects = dict(
                    uow.projects.find_by_names(
                        parsed.data.project_name for parsed in result.rows
                    )
                )

                # Group scenario names by resolved project so each project
                # needs only one lookup query.
                names_by_project: dict[int, set[str]] = {}
                for parsed in result.rows:
                    project = existing_projects.get(
                        _normalise_key(parsed.data.project_name)
                    )
                    if not project:
                        continue
                    names_by_project.setdefault(project.id, set()).add(
                        _normalise_key(parsed.data.name)
                    )

                for project_id, names in names_by_project.items():
                    matches = uow.scenarios.find_by_project_and_names(project_id, names)
                    for name_key, scenario in matches.items():
                        existing_scenarios[(project_id, name_key)] = scenario

        for parsed in result.rows:
            project_key = _normalise_key(parsed.data.project_name)
            scenario_key = _normalise_key(parsed.data.name)
            issues: list[str] = []
            context: dict[str, Any] | None = None
            state = ImportPreviewState.NEW

            if (project_key, scenario_key) in seen_pairs:
                state = ImportPreviewState.SKIP
                issues.append(
                    "Duplicate scenario for project within upload; row skipped."
                )
            else:
                seen_pairs.add((project_key, scenario_key))
                project = existing_projects.get(project_key)
                if not project:
                    state = ImportPreviewState.ERROR
                    issues.append(
                        f"Project '{parsed.data.project_name}' does not exist."
                    )
                else:
                    context = {"mode": "create", "project_id": project.id}
                    existing = existing_scenarios.get((project.id, scenario_key))
                    if existing:
                        state = ImportPreviewState.UPDATE
                        context = {
                            "mode": "update",
                            "project_id": project.id,
                            "scenario_id": existing.id,
                        }
                        issues.append("Existing scenario will be updated.")

            preview_rows.append(
                ImportPreviewRow(
                    row_number=parsed.row_number,
                    data=parsed.data,
                    state=state,
                    issues=issues,
                    context=context,
                )
            )

            if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}:
                accepted += 1
                staged_rows.append(
                    StagedRow(parsed=parsed, context=context or {"mode": "create"})
                )
            elif state == ImportPreviewState.SKIP:
                skipped += 1
            else:
                errored += 1

        parser_error_rows = {error.row_number for error in parser_errors}
        errored += len(parser_error_rows)
        total_rows = len(preview_rows) + len(parser_error_rows)

        summary = ImportPreviewSummary(
            total_rows=total_rows,
            accepted=accepted,
            skipped=skipped,
            errored=errored,
        )

        row_issues = _compile_row_issues(preview_rows, parser_errors)

        stage_token: str | None = None
        if staged_rows:
            stage_token = self._store_scenario_stage(staged_rows)

        return ImportPreview(
            rows=preview_rows,
            summary=summary,
            row_issues=row_issues,
            parser_errors=parser_errors,
            stage_token=stage_token,
        )

    def get_staged_projects(
        self, token: str
    ) -> StagedImportView[ProjectImportRow] | None:
        staged = self._project_stage.get(token)
        if not staged:
            return None
        return _build_staged_view(staged)

    def get_staged_scenarios(
        self, token: str
    ) -> StagedImportView[ScenarioImportRow] | None:
        staged = self._scenario_stage.get(token)
        if not staged:
            return None
        return _build_staged_view(staged)

    def consume_staged_projects(
        self, token: str
    ) -> StagedImportView[ProjectImportRow] | None:
        """Return the staged rows and remove them, so a token is single-use."""
        staged = self._project_stage.pop(token, None)
        if not staged:
            return None
        return _build_staged_view(staged)

    def consume_staged_scenarios(
        self, token: str
    ) -> StagedImportView[ScenarioImportRow] | None:
        staged = self._scenario_stage.pop(token, None)
        if not staged:
            return None
        return _build_staged_view(staged)

    def clear_staged_projects(self, token: str) -> bool:
        return self._project_stage.pop(token, None) is not None

    def clear_staged_scenarios(self, token: str) -> bool:
        return self._scenario_stage.pop(token, None) is not None

    def _store_project_stage(
        self, rows: list[StagedRow[ProjectImportRow]]
    ) -> str:
        token = str(uuid4())
        self._project_stage[token] = StagedImport(token=token, rows=rows)
        return token

    def _store_scenario_stage(
        self, rows: list[StagedRow[ScenarioImportRow]]
    ) -> str:
        token = str(uuid4())
        self._scenario_stage[token] = StagedImport(token=token, rows=rows)
        return token
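
    # Typical two-step flow (sketch; `make_uow` is whatever factory the app
    # wires in):
    #
    #   service = ImportIngestionService(uow_factory=make_uow)
    #   with open("projects.csv", "rb") as fh:
    #       preview = service.preview_projects(fh, "projects.csv")
    #   if preview.stage_token:
    #       staged = service.consume_staged_projects(preview.stage_token)
    #       # persist staged.rows, honouring each row.context["mode"]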


def load_project_imports(
    stream: BinaryIO, filename: str
) -> ImportResult[ProjectImportRow]:
    df = _load_dataframe(stream, filename)
    return _parse_dataframe(df, ProjectImportRow, PROJECT_COLUMNS)


def load_scenario_imports(
    stream: BinaryIO, filename: str
) -> ImportResult[ScenarioImportRow]:
    df = _load_dataframe(stream, filename)
    return _parse_dataframe(df, ScenarioImportRow, SCENARIO_COLUMNS)


def _load_dataframe(stream: BinaryIO, filename: str) -> DataFrame:
    stream.seek(0)
    suffix = Path(filename).suffix.lower()
    if suffix == ".csv":
        # Read everything as strings; keep_default_na=False stops pandas from
        # turning empty cells into NaN here (blank handling happens later).
        df = pd.read_csv(stream, dtype=str, keep_default_na=False, encoding="utf-8")
    elif suffix == ".xlsx":
        # openpyxl only reads the XLSX format; legacy .xls uploads fall
        # through to the error below instead of failing inside the engine.
        df = pd.read_excel(stream, dtype=str, engine="openpyxl")
    else:
        raise UnsupportedImportFormat(f"Unsupported file type: {suffix or 'unknown'}")
    df.columns = [str(col).strip().lower() for col in df.columns]
    return df
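
# e.g. at an upload endpoint (names `upload` and `error_response` are
# illustrative):
#
#   try:
#       result = load_project_imports(upload.stream, upload.filename)
#   except UnsupportedImportFormat as exc:
#       return error_response(str(exc))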


def _parse_dataframe(
    df: DataFrame,
    model: type[TImportRow],
    expected_columns: Iterable[str],
) -> ImportResult[TImportRow]:
    rows: list[ParsedImportRow[TImportRow]] = []
    errors: list[ImportRowError] = []
    # start=2: row 1 of the spreadsheet is the header, so the first data row
    # is row 2 from the user's point of view.
    for index, raw in enumerate(df.to_dict(orient="records"), start=2):
        payload = _prepare_payload(cast(dict[str, object], raw), expected_columns)
        try:
            rows.append(ParsedImportRow(row_number=index, data=model(**payload)))
        except ValidationError as exc:  # pragma: no cover - exercised via tests
            for detail in exc.errors():
                loc = ".".join(str(part) for part in detail.get("loc", [])) or None
                errors.append(
                    ImportRowError(
                        row_number=index,
                        field=loc,
                        message=detail.get("msg", "Invalid value"),
                    )
                )
    return ImportResult(rows=rows, errors=errors)


def _prepare_payload(
    raw: dict[str, object], expected_columns: Iterable[str]
) -> dict[str, object | None]:
    payload: dict[str, object | None] = {}
    for column in expected_columns:
        if column not in raw:
            continue
        value = raw.get(column)
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                value = None
        # Excel cells that were empty arrive as NaN rather than "".
        if value is not None and pd.isna(cast(Any, value)):
            value = None
        payload[column] = value
    return payload
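
# e.g. _prepare_payload({"name": " Mine A ", "location": ""}, ("name", "location"))
# returns {"name": "Mine A", "location": None}.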


def _normalise_key(value: str) -> str:
    return value.strip().lower()


def _build_staged_view(
    staged: StagedImport[TImportRow],
) -> StagedImportView[TImportRow]:
    # Hand back deep copies behind read-only mappings so callers cannot
    # mutate the staged state between preview and commit.
    rows = tuple(
        StagedRowView(
            row_number=row.parsed.row_number,
            data=cast(TImportRow, _deep_copy_model(row.parsed.data)),
            context=MappingProxyType(dict(row.context)),
        )
        for row in staged.rows
    )
    return StagedImportView(token=staged.token, rows=rows)


def _deep_copy_model(model: BaseModel) -> BaseModel:
    copy_method = getattr(model, "model_copy", None)
    if callable(copy_method):  # pydantic v2
        return cast(BaseModel, copy_method(deep=True))
    return model.copy(deep=True)  # pydantic v1 fallback  # type: ignore[attr-defined]


def _compile_row_issues(
    preview_rows: Iterable[ImportPreviewRow[Any]],
    parser_errors: Iterable[ImportRowError],
) -> list[ImportPreviewRowIssues]:
    issue_map: dict[int, ImportPreviewRowIssues] = {}

    def ensure_bundle(
        row_number: int,
        state: ImportPreviewState | None,
    ) -> ImportPreviewRowIssues:
        bundle = issue_map.get(row_number)
        if bundle is None:
            bundle = ImportPreviewRowIssues(
                row_number=row_number,
                state=state,
                issues=[],
            )
            issue_map[row_number] = bundle
        elif _state_priority(state) > _state_priority(bundle.state):
            # Keep the most severe state seen for this row.
            bundle.state = state
        return bundle

    for row in preview_rows:
        if not row.issues:
            continue
        bundle = ensure_bundle(row.row_number, row.state)
        for message in row.issues:
            bundle.issues.append(ImportPreviewRowIssue(message=message))

    for error in parser_errors:
        bundle = ensure_bundle(error.row_number, ImportPreviewState.ERROR)
        bundle.issues.append(
            ImportPreviewRowIssue(message=error.message, field=error.field)
        )

    return sorted(issue_map.values(), key=lambda item: item.row_number)


def _state_priority(state: ImportPreviewState | None) -> int:
    """Severity order: ERROR > SKIP > UPDATE > NEW > unset."""
    if state is None:
        return -1
    if state == ImportPreviewState.ERROR:
        return 3
    if state == ImportPreviewState.SKIP:
        return 2
    if state == ImportPreviewState.UPDATE:
        return 1
    return 0

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from collections.abc import Iterable
 from datetime import datetime
-from typing import Sequence
+from typing import Mapping, Sequence
 
 from sqlalchemy import select, func
 from sqlalchemy.exc import IntegrityError
@@ -70,6 +70,15 @@ class ProjectRepository:
             "Project violates uniqueness constraints") from exc
         return project
 
+    def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]:
+        """Map normalised (stripped, lower-cased) names to existing projects."""
+        normalised = {name.strip().lower() for name in names if name and name.strip()}
+        if not normalised:
+            return {}
+        stmt = select(Project).where(func.lower(Project.name).in_(normalised))
+        records = self.session.execute(stmt).scalars().all()
+        return {project.name.strip().lower(): project for project in records}
+
     def delete(self, project_id: int) -> None:
         project = self.get(project_id)
         self.session.delete(project)
@@ -149,6 +158,25 @@ class ScenarioRepository:
             raise EntityConflictError("Scenario violates constraints") from exc
         return scenario
 
+    def find_by_project_and_names(
+        self,
+        project_id: int,
+        names: Iterable[str],
+    ) -> Mapping[str, Scenario]:
+        """Map normalised scenario names to a project's existing scenarios."""
+        normalised = {
+            name.strip().lower() for name in names if name and name.strip()
+        }
+        if not normalised:
+            return {}
+        stmt = select(Scenario).where(
+            Scenario.project_id == project_id,
+            func.lower(Scenario.name).in_(normalised),
+        )
+        # One query per project resolves all candidate names at once.
+        records = self.session.execute(stmt).scalars().all()
+        return {scenario.name.strip().lower(): scenario for scenario in records}
+
     def delete(self, scenario_id: int) -> None:
         scenario = self.get(scenario_id)
         self.session.delete(scenario)