diff --git a/changelog.md b/changelog.md index 34770e7..e449756 100644 --- a/changelog.md +++ b/changelog.md @@ -30,3 +30,5 @@ - Implemented environment-driven admin bootstrap settings, wired the `bootstrap_admin` helper into FastAPI startup, added pytest coverage for creation/idempotency/reset logic, and documented operational guidance in the RBAC plan and security concept. - Retired the legacy authentication RBAC implementation plan document after migrating its guidance into live documentation and synchronized the contributor instructions to reflect the removal. - Completed the Authentication & RBAC checklist by shipping the new models, migrations, repositories, guard dependencies, and integration tests. +- Documented the project/scenario import/export field mapping and file format guidelines in `calminer-docs/requirements/FR-008.md`, and introduced `schemas/imports.py` with Pydantic models that normalise incoming CSV/Excel rows for projects and scenarios. +- Added `services/importers.py` to load CSV/XLSX files into the new import schemas, pulled in `openpyxl` for Excel support, and covered the parsing behaviour with `tests/test_import_parsing.py`. diff --git a/requirements.txt b/requirements.txt index 7227da4..7658aa7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ numpy passlib argon2-cffi python-jose -python-multipart \ No newline at end of file +python-multipart +openpyxl \ No newline at end of file diff --git a/schemas/imports.py b/schemas/imports.py new file mode 100644 index 0000000..a697c94 --- /dev/null +++ b/schemas/imports.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import Any, Mapping + +from pydantic import BaseModel, ConfigDict, field_validator, model_validator + +from models import MiningOperationType, ResourceType, ScenarioStatus + + +def _normalise_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value.strip() + return str(value).strip() + + +def _strip_or_none(value: Any | None) -> str | None: + if value is None: + return None + text = _normalise_string(value) + return text or None + + +def _coerce_enum(value: Any, enum_cls: Any, aliases: Mapping[str, Any]) -> Any: + if value is None: + return value + if isinstance(value, enum_cls): + return value + text = _normalise_string(value).lower() + if not text: + return None + if text in aliases: + return aliases[text] + try: + return enum_cls(text) + except ValueError as exc: # pragma: no cover - surfaced by Pydantic + raise ValueError( + f"Invalid value '{value}' for {enum_cls.__name__}") from exc + + +OPERATION_TYPE_ALIASES: dict[str, MiningOperationType] = { + "open pit": MiningOperationType.OPEN_PIT, + "openpit": MiningOperationType.OPEN_PIT, + "underground": MiningOperationType.UNDERGROUND, + "in-situ leach": MiningOperationType.IN_SITU_LEACH, + "in situ": MiningOperationType.IN_SITU_LEACH, + "placer": MiningOperationType.PLACER, + "quarry": MiningOperationType.QUARRY, + "mountaintop removal": MiningOperationType.MOUNTAINTOP_REMOVAL, + "other": MiningOperationType.OTHER, +} + + +SCENARIO_STATUS_ALIASES: dict[str, ScenarioStatus] = { + "draft": ScenarioStatus.DRAFT, + "active": ScenarioStatus.ACTIVE, + "archived": ScenarioStatus.ARCHIVED, +} + + +RESOURCE_TYPE_ALIASES: dict[str, ResourceType] = { + key.replace("_", " ").lower(): value for key, value in ResourceType.__members__.items() +} +RESOURCE_TYPE_ALIASES.update( + {value.value.replace("_", " ").lower(): value for value in 
ResourceType} +) + + +class ProjectImportRow(BaseModel): + name: str + location: str | None = None + operation_type: MiningOperationType + description: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("name", mode="before") + @classmethod + def validate_name(cls, value: Any) -> str: + text = _normalise_string(value) + if not text: + raise ValueError("Project name is required") + return text + + @field_validator("location", "description", mode="before") + @classmethod + def optional_text(cls, value: Any | None) -> str | None: + return _strip_or_none(value) + + @field_validator("operation_type", mode="before") + @classmethod + def map_operation_type(cls, value: Any) -> MiningOperationType | None: + return _coerce_enum(value, MiningOperationType, OPERATION_TYPE_ALIASES) + + +class ScenarioImportRow(BaseModel): + project_name: str + name: str + status: ScenarioStatus = ScenarioStatus.DRAFT + start_date: date | None = None + end_date: date | None = None + discount_rate: float | None = None + currency: str | None = None + primary_resource: ResourceType | None = None + description: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("project_name", "name", mode="before") + @classmethod + def validate_required_text(cls, value: Any, info) -> str: + text = _normalise_string(value) + if not text: + raise ValueError( + f"{info.field_name.replace('_', ' ').title()} is required") + return text + + @field_validator("status", mode="before") + @classmethod + def map_status(cls, value: Any) -> ScenarioStatus | None: + return _coerce_enum(value, ScenarioStatus, SCENARIO_STATUS_ALIASES) + + @field_validator("primary_resource", mode="before") + @classmethod + def map_resource(cls, value: Any) -> ResourceType | None: + return _coerce_enum(value, ResourceType, RESOURCE_TYPE_ALIASES) + + @field_validator("description", mode="before") + @classmethod + def optional_description(cls, value: Any | None) -> str | None: + return _strip_or_none(value) + + @field_validator("currency", mode="before") + @classmethod + def normalise_currency(cls, value: Any | None) -> str | None: + if value is None: + return None + text = _normalise_string(value).upper() + if not text: + return None + if len(text) != 3: + raise ValueError("Currency code must be a 3-letter ISO value") + return text + + @field_validator("discount_rate", mode="before") + @classmethod + def coerce_discount_rate(cls, value: Any | None) -> float | None: + if value is None: + return None + if isinstance(value, (int, float)): + return float(value) + text = _normalise_string(value) + if not text: + return None + if text.endswith("%"): + text = text[:-1] + try: + return float(text) + except ValueError as exc: + raise ValueError("Discount rate must be numeric") from exc + + @model_validator(mode="after") + def validate_dates(self) -> "ScenarioImportRow": + if self.start_date and self.end_date and self.start_date > self.end_date: + raise ValueError("End date must be on or after start date") + return self diff --git a/services/importers.py b/services/importers.py new file mode 100644 index 0000000..921f2fd --- /dev/null +++ b/services/importers.py @@ -0,0 +1,564 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, BinaryIO, Callable, Generic, Iterable, Mapping, 
TypeVar, cast +from uuid import uuid4 +from types import MappingProxyType + +import pandas as pd +from pandas import DataFrame +from pydantic import BaseModel, ValidationError + +from models import Project, Scenario +from schemas.imports import ProjectImportRow, ScenarioImportRow +from services.unit_of_work import UnitOfWork + +TImportRow = TypeVar("TImportRow", bound=BaseModel) + +PROJECT_COLUMNS: tuple[str, ...] = ( + "name", + "location", + "operation_type", + "description", + "created_at", + "updated_at", +) + +SCENARIO_COLUMNS: tuple[str, ...] = ( + "project_name", + "name", + "status", + "start_date", + "end_date", + "discount_rate", + "currency", + "primary_resource", + "description", + "created_at", + "updated_at", +) + + +@dataclass(slots=True) +class ImportRowError: + row_number: int + field: str | None + message: str + + +@dataclass(slots=True) +class ParsedImportRow(Generic[TImportRow]): + row_number: int + data: TImportRow + + +@dataclass(slots=True) +class ImportResult(Generic[TImportRow]): + rows: list[ParsedImportRow[TImportRow]] + errors: list[ImportRowError] + + +class UnsupportedImportFormat(ValueError): + pass + + +class ImportPreviewState(str, Enum): + NEW = "new" + UPDATE = "update" + SKIP = "skip" + ERROR = "error" + + +@dataclass(slots=True) +class ImportPreviewRow(Generic[TImportRow]): + row_number: int + data: TImportRow + state: ImportPreviewState + issues: list[str] + context: dict[str, Any] | None = None + + +@dataclass(slots=True) +class ImportPreviewSummary: + total_rows: int + accepted: int + skipped: int + errored: int + + +@dataclass(slots=True) +class ImportPreview(Generic[TImportRow]): + rows: list[ImportPreviewRow[TImportRow]] + summary: ImportPreviewSummary + row_issues: list["ImportPreviewRowIssues"] + parser_errors: list[ImportRowError] + stage_token: str | None + + +@dataclass(slots=True) +class StagedRow(Generic[TImportRow]): + parsed: ParsedImportRow[TImportRow] + context: dict[str, Any] + + +@dataclass(slots=True) +class ImportPreviewRowIssue: + message: str + field: str | None = None + + +@dataclass(slots=True) +class ImportPreviewRowIssues: + row_number: int + state: ImportPreviewState | None + issues: list[ImportPreviewRowIssue] + + +@dataclass(slots=True) +class StagedImport(Generic[TImportRow]): + token: str + rows: list[StagedRow[TImportRow]] + + +@dataclass(slots=True, frozen=True) +class StagedRowView(Generic[TImportRow]): + row_number: int + data: TImportRow + context: Mapping[str, Any] + + +@dataclass(slots=True, frozen=True) +class StagedImportView(Generic[TImportRow]): + token: str + rows: tuple[StagedRowView[TImportRow], ...] 
+ + +UnitOfWorkFactory = Callable[[], UnitOfWork] + + +class ImportIngestionService: + """Coordinates parsing, validation, and preview staging for imports.""" + + def __init__(self, uow_factory: UnitOfWorkFactory) -> None: + self._uow_factory = uow_factory + self._project_stage: dict[str, StagedImport[ProjectImportRow]] = {} + self._scenario_stage: dict[str, StagedImport[ScenarioImportRow]] = {} + + def preview_projects( + self, + stream: BinaryIO, + filename: str, + ) -> ImportPreview[ProjectImportRow]: + result = load_project_imports(stream, filename) + parser_errors = result.errors + + preview_rows: list[ImportPreviewRow[ProjectImportRow]] = [] + staged_rows: list[StagedRow[ProjectImportRow]] = [] + accepted = skipped = errored = 0 + + seen_names: set[str] = set() + + existing_by_name: dict[str, Project] = {} + if result.rows: + with self._uow_factory() as uow: + if not uow.projects: + raise RuntimeError("Project repository is unavailable") + existing_by_name = dict( + uow.projects.find_by_names( + parsed.data.name for parsed in result.rows + ) + ) + + for parsed in result.rows: + name_key = _normalise_key(parsed.data.name) + issues: list[str] = [] + context: dict[str, Any] | None = None + state = ImportPreviewState.NEW + + if name_key in seen_names: + state = ImportPreviewState.SKIP + issues.append( + "Duplicate project name within upload; row skipped.") + else: + seen_names.add(name_key) + existing = existing_by_name.get(name_key) + if existing: + state = ImportPreviewState.UPDATE + context = { + "mode": "update", + "project_id": existing.id, + } + issues.append("Existing project will be updated.") + else: + context = {"mode": "create"} + + preview_rows.append( + ImportPreviewRow( + row_number=parsed.row_number, + data=parsed.data, + state=state, + issues=issues, + context=context, + ) + ) + + if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}: + accepted += 1 + staged_rows.append( + StagedRow(parsed=parsed, context=context or { + "mode": "create"}) + ) + elif state == ImportPreviewState.SKIP: + skipped += 1 + else: + errored += 1 + + parser_error_rows = {error.row_number for error in parser_errors} + errored += len(parser_error_rows) + total_rows = len(preview_rows) + len(parser_error_rows) + + summary = ImportPreviewSummary( + total_rows=total_rows, + accepted=accepted, + skipped=skipped, + errored=errored, + ) + + row_issues = _compile_row_issues(preview_rows, parser_errors) + + stage_token: str | None = None + if staged_rows: + stage_token = self._store_project_stage(staged_rows) + + return ImportPreview( + rows=preview_rows, + summary=summary, + row_issues=row_issues, + parser_errors=parser_errors, + stage_token=stage_token, + ) + + def preview_scenarios( + self, + stream: BinaryIO, + filename: str, + ) -> ImportPreview[ScenarioImportRow]: + result = load_scenario_imports(stream, filename) + parser_errors = result.errors + + preview_rows: list[ImportPreviewRow[ScenarioImportRow]] = [] + staged_rows: list[StagedRow[ScenarioImportRow]] = [] + accepted = skipped = errored = 0 + + seen_pairs: set[tuple[str, str]] = set() + + existing_projects: dict[str, Project] = {} + existing_scenarios: dict[tuple[int, str], Scenario] = {} + + if result.rows: + with self._uow_factory() as uow: + if not uow.projects or not uow.scenarios: + raise RuntimeError("Repositories are unavailable") + + existing_projects = dict( + uow.projects.find_by_names( + parsed.data.project_name for parsed in result.rows + ) + ) + + names_by_project: dict[int, set[str]] = {} + for parsed in result.rows: + 
project = existing_projects.get( + _normalise_key(parsed.data.project_name) + ) + if not project: + continue + names_by_project.setdefault(project.id, set()).add( + _normalise_key(parsed.data.name) + ) + + for project_id, names in names_by_project.items(): + matches = uow.scenarios.find_by_project_and_names( + project_id, names) + for name_key, scenario in matches.items(): + existing_scenarios[(project_id, name_key)] = scenario + + for parsed in result.rows: + project_key = _normalise_key(parsed.data.project_name) + scenario_key = _normalise_key(parsed.data.name) + issues: list[str] = [] + context: dict[str, Any] | None = None + state = ImportPreviewState.NEW + + if (project_key, scenario_key) in seen_pairs: + state = ImportPreviewState.SKIP + issues.append( + "Duplicate scenario for project within upload; row skipped." + ) + else: + seen_pairs.add((project_key, scenario_key)) + project = existing_projects.get(project_key) + if not project: + state = ImportPreviewState.ERROR + issues.append( + f"Project '{parsed.data.project_name}' does not exist." + ) + else: + context = {"mode": "create", "project_id": project.id} + existing = existing_scenarios.get( + (project.id, scenario_key)) + if existing: + state = ImportPreviewState.UPDATE + context = { + "mode": "update", + "project_id": project.id, + "scenario_id": existing.id, + } + issues.append("Existing scenario will be updated.") + + preview_rows.append( + ImportPreviewRow( + row_number=parsed.row_number, + data=parsed.data, + state=state, + issues=issues, + context=context, + ) + ) + + if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}: + accepted += 1 + staged_rows.append( + StagedRow(parsed=parsed, context=context or { + "mode": "create"}) + ) + elif state == ImportPreviewState.SKIP: + skipped += 1 + else: + errored += 1 + + parser_error_rows = {error.row_number for error in parser_errors} + errored += len(parser_error_rows) + total_rows = len(preview_rows) + len(parser_error_rows) + + summary = ImportPreviewSummary( + total_rows=total_rows, + accepted=accepted, + skipped=skipped, + errored=errored, + ) + + row_issues = _compile_row_issues(preview_rows, parser_errors) + + stage_token: str | None = None + if staged_rows: + stage_token = self._store_scenario_stage(staged_rows) + + return ImportPreview( + rows=preview_rows, + summary=summary, + row_issues=row_issues, + parser_errors=parser_errors, + stage_token=stage_token, + ) + + def get_staged_projects( + self, token: str + ) -> StagedImportView[ProjectImportRow] | None: + staged = self._project_stage.get(token) + if not staged: + return None + return _build_staged_view(staged) + + def get_staged_scenarios( + self, token: str + ) -> StagedImportView[ScenarioImportRow] | None: + staged = self._scenario_stage.get(token) + if not staged: + return None + return _build_staged_view(staged) + + def consume_staged_projects( + self, token: str + ) -> StagedImportView[ProjectImportRow] | None: + staged = self._project_stage.pop(token, None) + if not staged: + return None + return _build_staged_view(staged) + + def consume_staged_scenarios( + self, token: str + ) -> StagedImportView[ScenarioImportRow] | None: + staged = self._scenario_stage.pop(token, None) + if not staged: + return None + return _build_staged_view(staged) + + def clear_staged_projects(self, token: str) -> bool: + return self._project_stage.pop(token, None) is not None + + def clear_staged_scenarios(self, token: str) -> bool: + return self._scenario_stage.pop(token, None) is not None + + def _store_project_stage( + 
self, rows: list[StagedRow[ProjectImportRow]] + ) -> str: + token = str(uuid4()) + self._project_stage[token] = StagedImport(token=token, rows=rows) + return token + + def _store_scenario_stage( + self, rows: list[StagedRow[ScenarioImportRow]] + ) -> str: + token = str(uuid4()) + self._scenario_stage[token] = StagedImport(token=token, rows=rows) + return token + + +def load_project_imports(stream: BinaryIO, filename: str) -> ImportResult[ProjectImportRow]: + df = _load_dataframe(stream, filename) + return _parse_dataframe(df, ProjectImportRow, PROJECT_COLUMNS) + + +def load_scenario_imports(stream: BinaryIO, filename: str) -> ImportResult[ScenarioImportRow]: + df = _load_dataframe(stream, filename) + return _parse_dataframe(df, ScenarioImportRow, SCENARIO_COLUMNS) + + +def _load_dataframe(stream: BinaryIO, filename: str) -> DataFrame: + stream.seek(0) + suffix = Path(filename).suffix.lower() + if suffix == ".csv": + df = pd.read_csv(stream, dtype=str, + keep_default_na=False, encoding="utf-8") + elif suffix in {".xls", ".xlsx"}: + df = pd.read_excel(stream, dtype=str, engine="openpyxl") + else: + raise UnsupportedImportFormat( + f"Unsupported file type: {suffix or 'unknown'}") + df.columns = [str(col).strip().lower() for col in df.columns] + return df + + +def _parse_dataframe( + df: DataFrame, + model: type[TImportRow], + expected_columns: Iterable[str], +) -> ImportResult[TImportRow]: + rows: list[ParsedImportRow[TImportRow]] = [] + errors: list[ImportRowError] = [] + for index, raw in enumerate(df.to_dict(orient="records"), start=2): + payload = _prepare_payload( + cast(dict[str, object], raw), expected_columns) + try: + rows.append( + ParsedImportRow(row_number=index, data=model(**payload)) + ) + except ValidationError as exc: # pragma: no cover - exercised via tests + for detail in exc.errors(): + loc = ".".join(str(part) + for part in detail.get("loc", [])) or None + errors.append( + ImportRowError( + row_number=index, + field=loc, + message=detail.get("msg", "Invalid value"), + ) + ) + return ImportResult(rows=rows, errors=errors) + + +def _prepare_payload( + raw: dict[str, object], expected_columns: Iterable[str] +) -> dict[str, object | None]: + payload: dict[str, object | None] = {} + for column in expected_columns: + if column not in raw: + continue + value = raw.get(column) + if isinstance(value, str): + value = value.strip() + if value == "": + value = None + if value is not None and pd.isna(cast(Any, value)): + value = None + payload[column] = value + return payload + + +def _normalise_key(value: str) -> str: + return value.strip().lower() + + +def _build_staged_view( + staged: StagedImport[TImportRow], +) -> StagedImportView[TImportRow]: + rows = tuple( + StagedRowView( + row_number=row.parsed.row_number, + data=cast(TImportRow, _deep_copy_model(row.parsed.data)), + context=MappingProxyType(dict(row.context)), + ) + for row in staged.rows + ) + return StagedImportView(token=staged.token, rows=rows) + + +def _deep_copy_model(model: BaseModel) -> BaseModel: + copy_method = getattr(model, "model_copy", None) + if callable(copy_method): # pydantic v2 + return cast(BaseModel, copy_method(deep=True)) + return model.copy(deep=True) # type: ignore[attr-defined] + + +def _compile_row_issues( + preview_rows: Iterable[ImportPreviewRow[Any]], + parser_errors: Iterable[ImportRowError], +) -> list[ImportPreviewRowIssues]: + issue_map: dict[int, ImportPreviewRowIssues] = {} + + def ensure_bundle( + row_number: int, + state: ImportPreviewState | None, + ) -> ImportPreviewRowIssues: + bundle 
= issue_map.get(row_number) + if bundle is None: + bundle = ImportPreviewRowIssues( + row_number=row_number, + state=state, + issues=[], + ) + issue_map[row_number] = bundle + else: + if _state_priority(state) > _state_priority(bundle.state): + bundle.state = state + return bundle + + for row in preview_rows: + if not row.issues: + continue + bundle = ensure_bundle(row.row_number, row.state) + for message in row.issues: + bundle.issues.append(ImportPreviewRowIssue(message=message)) + + for error in parser_errors: + bundle = ensure_bundle(error.row_number, ImportPreviewState.ERROR) + bundle.issues.append( + ImportPreviewRowIssue(message=error.message, field=error.field) + ) + + return sorted(issue_map.values(), key=lambda item: item.row_number) + + +def _state_priority(state: ImportPreviewState | None) -> int: + if state is None: + return -1 + if state == ImportPreviewState.ERROR: + return 3 + if state == ImportPreviewState.SKIP: + return 2 + if state == ImportPreviewState.UPDATE: + return 1 + return 0 diff --git a/services/repositories.py b/services/repositories.py index 170e33e..64e7108 100644 --- a/services/repositories.py +++ b/services/repositories.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Iterable from datetime import datetime -from typing import Sequence +from typing import Mapping, Sequence from sqlalchemy import select, func from sqlalchemy.exc import IntegrityError @@ -70,6 +70,15 @@ class ProjectRepository: "Project violates uniqueness constraints") from exc return project + def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]: + normalised = {name.strip().lower() + for name in names if name and name.strip()} + if not normalised: + return {} + stmt = select(Project).where(func.lower(Project.name).in_(normalised)) + records = self.session.execute(stmt).scalars().all() + return {project.name.lower(): project for project in records} + def delete(self, project_id: int) -> None: project = self.get(project_id) self.session.delete(project) @@ -149,6 +158,25 @@ class ScenarioRepository: raise EntityConflictError("Scenario violates constraints") from exc return scenario + def find_by_project_and_names( + self, + project_id: int, + names: Iterable[str], + ) -> Mapping[str, Scenario]: + normalised = {name.strip().lower() + for name in names if name and name.strip()} + if not normalised: + return {} + stmt = ( + select(Scenario) + .where( + Scenario.project_id == project_id, + func.lower(Scenario.name).in_(normalised), + ) + ) + records = self.session.execute(stmt).scalars().all() + return {scenario.name.lower(): scenario for scenario in records} + def delete(self, scenario_id: int) -> None: scenario = self.get(scenario_id) self.session.delete(scenario) diff --git a/tests/test_import_ingestion.py b/tests/test_import_ingestion.py new file mode 100644 index 0000000..cfba1f2 --- /dev/null +++ b/tests/test_import_ingestion.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +from io import BytesIO +from typing import Callable + +import pandas as pd +import pytest + +from models.project import MiningOperationType, Project +from models.scenario import Scenario, ScenarioStatus +from services.importers import ( + ImportIngestionService, + ImportPreviewState, + StagedImportView, +) +from services.unit_of_work import UnitOfWork + + +@pytest.fixture() +def ingestion_service(unit_of_work_factory: Callable[[], UnitOfWork]) -> ImportIngestionService: + return ImportIngestionService(unit_of_work_factory) + + +def 
test_preview_projects_flags_updates_and_duplicates( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None + existing = Project( + name="Project A", + location="Chile", + operation_type=MiningOperationType.OPEN_PIT, + ) + uow.projects.create(existing) + + csv_content = ( + "name,location,operation_type\n" + "Project A,Peru,open pit\n" + "Project B,Canada,underground\n" + "Project B,Canada,underground\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + + preview = ingestion_service.preview_projects(stream, "projects.csv") + + states = [row.state for row in preview.rows] + assert states == [ + ImportPreviewState.UPDATE, + ImportPreviewState.NEW, + ImportPreviewState.SKIP, + ] + assert preview.summary.total_rows == 3 + assert preview.summary.accepted == 2 + assert preview.summary.skipped == 1 + assert preview.summary.errored == 0 + assert preview.parser_errors == [] + assert preview.stage_token is not None + issue_map = {bundle.row_number: bundle for bundle in preview.row_issues} + assert 2 in issue_map and issue_map[2].state == ImportPreviewState.UPDATE + assert { + detail.message for detail in issue_map[2].issues + } == {"Existing project will be updated."} + assert 4 in issue_map and issue_map[4].state == ImportPreviewState.SKIP + assert any( + "Duplicate project name" in detail.message + for detail in issue_map[4].issues + ) + # type: ignore[attr-defined] + staged = ingestion_service._project_stage[preview.stage_token] + assert len(staged.rows) == 2 + update_context = preview.rows[0].context + assert update_context is not None and update_context.get( + "project_id") is not None + + +def test_preview_scenarios_validates_projects_and_updates( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None and uow.scenarios is not None + project = Project( + name="Existing Project", + location="Chile", + operation_type=MiningOperationType.OPEN_PIT, + ) + uow.projects.create(project) + scenario = Scenario( + project_id=project.id, + name="Existing Scenario", + status=ScenarioStatus.ACTIVE, + ) + uow.scenarios.create(scenario) + + df = pd.DataFrame( + [ + { + "project_name": "Existing Project", + "name": "Existing Scenario", + "status": "Active", + }, + { + "project_name": "Existing Project", + "name": "New Scenario", + "status": "Draft", + }, + { + "project_name": "Missing Project", + "name": "Ghost Scenario", + "status": "Draft", + }, + { + "project_name": "Existing Project", + "name": "New Scenario", + "status": "Draft", + }, + ] + ) + buffer = BytesIO() + df.to_csv(buffer, index=False) + buffer.seek(0) + + preview = ingestion_service.preview_scenarios(buffer, "scenarios.csv") + + states = [row.state for row in preview.rows] + assert states == [ + ImportPreviewState.UPDATE, + ImportPreviewState.NEW, + ImportPreviewState.ERROR, + ImportPreviewState.SKIP, + ] + assert preview.summary.total_rows == 4 + assert preview.summary.accepted == 2 + assert preview.summary.skipped == 1 + assert preview.summary.errored == 1 + assert preview.stage_token is not None + issue_map = {bundle.row_number: bundle for bundle in preview.row_issues} + assert 2 in issue_map and issue_map[2].state == ImportPreviewState.UPDATE + assert 4 in issue_map and issue_map[4].state == ImportPreviewState.ERROR + assert any( + "does not exist" in detail.message + for detail in 
issue_map[4].issues + ) + # type: ignore[attr-defined] + staged = ingestion_service._scenario_stage[preview.stage_token] + assert len(staged.rows) == 2 + error_row = preview.rows[2] + assert any("does not exist" in msg for msg in error_row.issues) + + +def test_preview_scenarios_aggregates_parser_errors( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None + project = Project( + name="Existing Project", + location="Chile", + operation_type=MiningOperationType.OPEN_PIT, + ) + uow.projects.create(project) + + csv_content = ( + "project_name,name,status\n" + "Existing Project,Broken Scenario,UNKNOWN_STATUS\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + + preview = ingestion_service.preview_scenarios(stream, "invalid.csv") + + assert preview.rows == [] + assert preview.summary.total_rows == 1 + assert preview.summary.errored == 1 + assert preview.stage_token is None + assert len(preview.parser_errors) == 1 + issue_map = {bundle.row_number: bundle for bundle in preview.row_issues} + assert 2 in issue_map + bundle = issue_map[2] + assert bundle.state == ImportPreviewState.ERROR + assert any(detail.field == "status" for detail in bundle.issues) + assert all(detail.message for detail in bundle.issues) + + +def test_consume_staged_projects_removes_token( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None + + csv_content = ( + "name,location,operation_type\n" + "Project X,Peru,open pit\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + + preview = ingestion_service.preview_projects(stream, "projects.csv") + assert preview.stage_token is not None + token = preview.stage_token + + initial_view = ingestion_service.get_staged_projects(token) + assert isinstance(initial_view, StagedImportView) + consumed = ingestion_service.consume_staged_projects(token) + assert consumed == initial_view + assert ingestion_service.get_staged_projects(token) is None + assert ingestion_service.consume_staged_projects(token) is None + + +def test_clear_staged_scenarios_drops_entry( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None + project = Project( + name="Project Y", + location="Chile", + operation_type=MiningOperationType.OPEN_PIT, + ) + uow.projects.create(project) + + csv_content = ( + "project_name,name,status\n" + "Project Y,Scenario 1,Active\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + + preview = ingestion_service.preview_scenarios(stream, "scenarios.csv") + assert preview.stage_token is not None + token = preview.stage_token + + assert ingestion_service.get_staged_scenarios(token) is not None + assert ingestion_service.clear_staged_scenarios(token) is True + assert ingestion_service.get_staged_scenarios(token) is None + assert ingestion_service.clear_staged_scenarios(token) is False diff --git a/tests/test_import_parsing.py b/tests/test_import_parsing.py new file mode 100644 index 0000000..6c79974 --- /dev/null +++ b/tests/test_import_parsing.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from io import BytesIO + +import pandas as pd +import pytest + +from services.importers import ImportResult, load_project_imports, load_scenario_imports +from schemas.imports import ProjectImportRow, 
ScenarioImportRow + + +def test_load_project_imports_from_csv() -> None: + csv_content = ( + "name,location,operation_type,description\n" + "Project A,Chile,open pit,First project\n" + "Project B,,underground,Second project\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + + result = load_project_imports(stream, "projects.csv") + + assert isinstance(result, ImportResult) + assert len(result.rows) == 2 + assert not result.errors + first = result.rows[0] + assert first.row_number == 2 + assert isinstance(first.data, ProjectImportRow) + assert first.data.name == "Project A" + assert first.data.operation_type.value == "open_pit" + second = result.rows[1] + assert second.row_number == 3 + assert isinstance(second.data, ProjectImportRow) + assert second.data.location is None + + +def test_load_scenario_imports_from_excel() -> None: + df = pd.DataFrame( + [ + { + "project_name": "Project A", + "name": "Scenario 1", + "status": "Active", + "start_date": "2025-01-01", + "end_date": "2025-12-31", + "discount_rate": "7.5%", + "currency": "usd", + "primary_resource": "Electricity", + } + ] + ) + buffer = BytesIO() + df.to_excel(buffer, index=False) + buffer.seek(0) + + result = load_scenario_imports(buffer, "scenarios.xlsx") + + assert len(result.rows) == 1 + assert not result.errors + row = result.rows[0] + assert row.row_number == 2 + assert isinstance(row.data, ScenarioImportRow) + assert row.data.status.value == "active" + assert row.data.currency == "USD" + assert row.data.discount_rate == pytest.approx(7.5) + + +def test_import_errors_include_row_numbers() -> None: + csv_content = "name,operation_type\n,open pit\n" + stream = BytesIO(csv_content.encode("utf-8")) + + result = load_project_imports(stream, "projects.csv") + + assert len(result.rows) == 0 + assert len(result.errors) == 1 + error = result.errors[0] + assert error.row_number == 2 + assert error.field == "name" + assert "required" in error.message
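
For reviewers, a minimal usage sketch of how the pieces added in this patch fit together: `load_project_imports` turns an uploaded file into validated `ProjectImportRow` objects, `ImportIngestionService` builds a preview against existing records, and the stage token is redeemed once when the user confirms the import. `make_unit_of_work` is a placeholder for whatever `UnitOfWork` factory the application actually wires in (the tests above use the `unit_of_work_factory` fixture); the sample CSV and file names are illustrative only, not part of this change.

from io import BytesIO

from services.importers import ImportIngestionService, load_project_imports


def make_unit_of_work():
    # Placeholder: replace with the application's real UnitOfWork factory.
    raise NotImplementedError("wire in the configured session/UnitOfWork factory")


# 1. Parse and validate rows without touching the database. Row numbers match
#    the spreadsheet, so the first data row reports as row 2.
stream = BytesIO(b"name,location,operation_type\nProject A,Chile,open pit\n")
result = load_project_imports(stream, "projects.csv")
for parsed in result.rows:
    print(parsed.row_number, parsed.data.name, parsed.data.operation_type)
for error in result.errors:
    print(error.row_number, error.field, error.message)

# 2. Build a preview against existing records and stage the accepted rows
#    (preview_projects rewinds the stream itself before re-parsing it).
service = ImportIngestionService(make_unit_of_work)
preview = service.preview_projects(stream, "projects.csv")
print(preview.summary, preview.stage_token)

# 3. Redeem the stage token exactly once when the import is confirmed; a second
#    consume (or a consume after clear) returns None.
if preview.stage_token is not None:
    staged = service.consume_staged_projects(preview.stage_token)
    assert staged is not None and len(staged.rows) == 1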