From eaef99f0acdc79cdd37aff9f981819ecf99fc383 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Mon, 10 Nov 2025 09:20:41 +0100 Subject: [PATCH] feat: enhance import functionality with commit results and summary models for projects and scenarios --- schemas/imports.py | 122 ++++++++++++++++++++++++- services/importers.py | 143 +++++++++++++++++++++++++++++ tests/test_import_ingestion.py | 158 +++++++++++++++++++++++++++++++++ 3 files changed, 422 insertions(+), 1 deletion(-) diff --git a/schemas/imports.py b/schemas/imports.py index a697c94..c60b9a7 100644 --- a/schemas/imports.py +++ b/schemas/imports.py @@ -2,11 +2,14 @@ from __future__ import annotations from datetime import date, datetime from typing import Any, Mapping +from typing import Literal -from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from models import MiningOperationType, ResourceType, ScenarioStatus +PreviewStateLiteral = Literal["new", "update", "skip", "error"] + def _normalise_string(value: Any) -> str: if value is None: @@ -170,3 +173,120 @@ class ScenarioImportRow(BaseModel): if self.start_date and self.end_date and self.start_date > self.end_date: raise ValueError("End date must be on or after start date") return self + + +class ImportRowErrorModel(BaseModel): + row_number: int + field: str | None = None + message: str + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportPreviewRowIssueModel(BaseModel): + message: str + field: str | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportPreviewRowIssuesModel(BaseModel): + row_number: int + state: PreviewStateLiteral | None = None + issues: list[ImportPreviewRowIssueModel] = Field(default_factory=list) + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportPreviewSummaryModel(BaseModel): + total_rows: int + accepted: int + skipped: int + errored: int + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportPreviewRow(BaseModel): + row_number: int + data: ProjectImportRow + state: PreviewStateLiteral + issues: list[str] = Field(default_factory=list) + context: dict[str, Any] | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ScenarioImportPreviewRow(BaseModel): + row_number: int + data: ScenarioImportRow + state: PreviewStateLiteral + issues: list[str] = Field(default_factory=list) + context: dict[str, Any] | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportPreviewResponse(BaseModel): + rows: list[ProjectImportPreviewRow] + summary: ImportPreviewSummaryModel + row_issues: list[ImportPreviewRowIssuesModel] = Field(default_factory=list) + parser_errors: list[ImportRowErrorModel] = Field(default_factory=list) + stage_token: str | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ScenarioImportPreviewResponse(BaseModel): + rows: list[ScenarioImportPreviewRow] + summary: ImportPreviewSummaryModel + row_issues: list[ImportPreviewRowIssuesModel] = Field(default_factory=list) + parser_errors: list[ImportRowErrorModel] = Field(default_factory=list) + stage_token: str | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportCommitSummaryModel(BaseModel): + created: int + updated: int + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportCommitRow(BaseModel): + row_number: int + data: ProjectImportRow + context: dict[str, Any] + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ScenarioImportCommitRow(BaseModel): + row_number: int + data: ScenarioImportRow + context: dict[str, Any] + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportCommitResponse(BaseModel): + token: str + rows: list[ProjectImportCommitRow] + summary: ImportCommitSummaryModel + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ScenarioImportCommitResponse(BaseModel): + token: str + rows: list[ScenarioImportCommitRow] + summary: ImportCommitSummaryModel + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportCommitRequest(BaseModel): + token: str + + model_config = ConfigDict(extra="forbid") diff --git a/services/importers.py b/services/importers.py index 921f2fd..4594783 100644 --- a/services/importers.py +++ b/services/importers.py @@ -135,6 +135,19 @@ class StagedImportView(Generic[TImportRow]): rows: tuple[StagedRowView[TImportRow], ...] +@dataclass(slots=True, frozen=True) +class ImportCommitSummary: + created: int + updated: int + + +@dataclass(slots=True, frozen=True) +class ImportCommitResult(Generic[TImportRow]): + token: str + rows: tuple[StagedRowView[TImportRow], ...] + summary: ImportCommitSummary + + UnitOfWorkFactory = Callable[[], UnitOfWork] @@ -402,6 +415,136 @@ class ImportIngestionService: def clear_staged_scenarios(self, token: str) -> bool: return self._scenario_stage.pop(token, None) is not None + def commit_project_import(self, token: str) -> ImportCommitResult[ProjectImportRow]: + staged = self._project_stage.get(token) + if not staged: + raise ValueError(f"Unknown project import token: {token}") + + staged_view = _build_staged_view(staged) + created = updated = 0 + + with self._uow_factory() as uow: + if not uow.projects: + raise RuntimeError("Project repository is unavailable") + + for row in staged.rows: + mode = row.context.get("mode") + data = row.parsed.data + + if mode == "create": + project = Project( + name=data.name, + location=data.location, + operation_type=data.operation_type, + description=data.description, + ) + if data.created_at: + project.created_at = data.created_at + if data.updated_at: + project.updated_at = data.updated_at + uow.projects.create(project) + created += 1 + elif mode == "update": + project_id = row.context.get("project_id") + if not project_id: + raise ValueError( + "Staged project update is missing project_id context" + ) + project = uow.projects.get(project_id) + project.name = data.name + project.location = data.location + project.operation_type = data.operation_type + project.description = data.description + if data.created_at: + project.created_at = data.created_at + if data.updated_at: + project.updated_at = data.updated_at + updated += 1 + else: + raise ValueError( + f"Unsupported staged project mode: {mode!r}") + + self._project_stage.pop(token, None) + return ImportCommitResult( + token=token, + rows=staged_view.rows, + summary=ImportCommitSummary(created=created, updated=updated), + ) + + def commit_scenario_import(self, token: str) -> ImportCommitResult[ScenarioImportRow]: + staged = self._scenario_stage.get(token) + if not staged: + raise ValueError(f"Unknown scenario import token: {token}") + + staged_view = _build_staged_view(staged) + created = updated = 0 + + with self._uow_factory() as uow: + if not uow.scenarios or not uow.projects: + raise RuntimeError("Scenario repositories are unavailable") + + for row in staged.rows: + mode = row.context.get("mode") + data = row.parsed.data + + project_id = row.context.get("project_id") + if not project_id: + raise ValueError( + "Staged scenario row is missing project_id context" + ) + + project = uow.projects.get(project_id) + + if mode == "create": + scenario = Scenario( + project_id=project.id, + name=data.name, + status=data.status, + start_date=data.start_date, + end_date=data.end_date, + discount_rate=data.discount_rate, + currency=data.currency, + primary_resource=data.primary_resource, + description=data.description, + ) + if data.created_at: + scenario.created_at = data.created_at + if data.updated_at: + scenario.updated_at = data.updated_at + uow.scenarios.create(scenario) + created += 1 + elif mode == "update": + scenario_id = row.context.get("scenario_id") + if not scenario_id: + raise ValueError( + "Staged scenario update is missing scenario_id context" + ) + scenario = uow.scenarios.get(scenario_id) + scenario.project_id = project.id + scenario.name = data.name + scenario.status = data.status + scenario.start_date = data.start_date + scenario.end_date = data.end_date + scenario.discount_rate = data.discount_rate + scenario.currency = data.currency + scenario.primary_resource = data.primary_resource + scenario.description = data.description + if data.created_at: + scenario.created_at = data.created_at + if data.updated_at: + scenario.updated_at = data.updated_at + updated += 1 + else: + raise ValueError( + f"Unsupported staged scenario mode: {mode!r}") + + self._scenario_stage.pop(token, None) + return ImportCommitResult( + token=token, + rows=staged_view.rows, + summary=ImportCommitSummary(created=created, updated=updated), + ) + def _store_project_stage( self, rows: list[StagedRow[ProjectImportRow]] ) -> str: diff --git a/tests/test_import_ingestion.py b/tests/test_import_ingestion.py index cfba1f2..97f9174 100644 --- a/tests/test_import_ingestion.py +++ b/tests/test_import_ingestion.py @@ -9,11 +9,19 @@ import pytest from models.project import MiningOperationType, Project from models.scenario import Scenario, ScenarioStatus from services.importers import ( + ImportCommitResult, ImportIngestionService, ImportPreviewState, StagedImportView, ) +from services.repositories import ProjectRepository from services.unit_of_work import UnitOfWork +from schemas.imports import ( + ProjectImportCommitResponse, + ProjectImportPreviewResponse, + ScenarioImportCommitResponse, + ScenarioImportPreviewResponse, +) @pytest.fixture() @@ -73,6 +81,9 @@ def test_preview_projects_flags_updates_and_duplicates( assert update_context is not None and update_context.get( "project_id") is not None + response_model = ProjectImportPreviewResponse.model_validate(preview) + assert response_model.summary.accepted == preview.summary.accepted + def test_preview_scenarios_validates_projects_and_updates( ingestion_service: ImportIngestionService, @@ -148,6 +159,9 @@ def test_preview_scenarios_validates_projects_and_updates( error_row = preview.rows[2] assert any("does not exist" in msg for msg in error_row.issues) + response_model = ScenarioImportPreviewResponse.model_validate(preview) + assert response_model.summary.errored == preview.summary.errored + def test_preview_scenarios_aggregates_parser_errors( ingestion_service: ImportIngestionService, @@ -182,6 +196,9 @@ def test_preview_scenarios_aggregates_parser_errors( assert any(detail.field == "status" for detail in bundle.issues) assert all(detail.message for detail in bundle.issues) + response_model = ScenarioImportPreviewResponse.model_validate(preview) + assert response_model.summary.total_rows == 1 + def test_consume_staged_projects_removes_token( ingestion_service: ImportIngestionService, @@ -235,3 +252,144 @@ def test_clear_staged_scenarios_drops_entry( assert ingestion_service.clear_staged_scenarios(token) is True assert ingestion_service.get_staged_scenarios(token) is None assert ingestion_service.clear_staged_scenarios(token) is False + + +def test_commit_project_import_applies_create_and_update( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None + existing = Project( + name="Project A", + location="Chile", + operation_type=MiningOperationType.OPEN_PIT, + ) + uow.projects.create(existing) + + csv_content = ( + "name,location,operation_type\n" + "Project A,Peru,underground\n" + "Project B,Canada,open pit\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + preview = ingestion_service.preview_projects(stream, "projects.csv") + assert preview.stage_token is not None + + result = ingestion_service.commit_project_import(preview.stage_token) + assert isinstance(result, ImportCommitResult) + assert result.summary.created == 1 + assert result.summary.updated == 1 + assert ingestion_service.get_staged_projects(preview.stage_token) is None + + commit_response = ProjectImportCommitResponse.model_validate(result) + assert commit_response.summary.updated == 1 + + with unit_of_work_factory() as uow: + assert uow.projects is not None + projects = uow.projects.list() + names = sorted(project.name for project in projects) + assert names == ["Project A", "Project B"] + updated_project = next(p for p in projects if p.name == "Project A") + assert updated_project.location == "Peru" + assert updated_project.operation_type == MiningOperationType.UNDERGROUND + new_project = next(p for p in projects if p.name == "Project B") + assert new_project.location == "Canada" + + +def test_commit_project_import_with_invalid_token_raises( + ingestion_service: ImportIngestionService, +) -> None: + with pytest.raises(ValueError): + ingestion_service.commit_project_import("missing-token") + + +def test_commit_scenario_import_applies_create_and_update( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None and uow.scenarios is not None + project = Project( + name="Project X", + location="Chile", + operation_type=MiningOperationType.OPEN_PIT, + ) + uow.projects.create(project) + scenario = Scenario( + project_id=project.id, + name="Existing Scenario", + status=ScenarioStatus.ACTIVE, + ) + uow.scenarios.create(scenario) + + csv_content = ( + "project_name,name,status\n" + "Project X,Existing Scenario,Archived\n" + "Project X,New Scenario,Draft\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + preview = ingestion_service.preview_scenarios(stream, "scenarios.csv") + assert preview.stage_token is not None + + result = ingestion_service.commit_scenario_import(preview.stage_token) + assert result.summary.created == 1 + assert result.summary.updated == 1 + assert ingestion_service.get_staged_scenarios(preview.stage_token) is None + + commit_response = ScenarioImportCommitResponse.model_validate(result) + assert commit_response.summary.created == 1 + + with unit_of_work_factory() as uow: + assert uow.projects is not None and uow.scenarios is not None + scenarios = uow.scenarios.list_for_project(uow.projects.list()[0].id) + names = sorted(scenario.name for scenario in scenarios) + assert names == ["Existing Scenario", "New Scenario"] + updated_scenario = next( + item for item in scenarios if item.name == "Existing Scenario" + ) + assert updated_scenario.status == ScenarioStatus.ARCHIVED + new_scenario = next( + item for item in scenarios if item.name == "New Scenario" + ) + assert new_scenario.status == ScenarioStatus.DRAFT + + +def test_commit_scenario_import_with_invalid_token_raises( + ingestion_service: ImportIngestionService, +) -> None: + with pytest.raises(ValueError): + ingestion_service.commit_scenario_import("missing-token") + + +def test_commit_project_import_rolls_back_on_failure( + ingestion_service: ImportIngestionService, + unit_of_work_factory: Callable[[], UnitOfWork], + monkeypatch: pytest.MonkeyPatch, +) -> None: + with unit_of_work_factory() as uow: + assert uow.projects is not None + + csv_content = ( + "name,location,operation_type\n" + "Project Fail,Peru,open pit\n" + ) + stream = BytesIO(csv_content.encode("utf-8")) + preview = ingestion_service.preview_projects(stream, "projects.csv") + assert preview.stage_token is not None + token = preview.stage_token + + def _boom(self: ProjectRepository, project: Project) -> Project: + raise RuntimeError("boom") + + monkeypatch.setattr(ProjectRepository, "create", _boom) + + with pytest.raises(RuntimeError): + ingestion_service.commit_project_import(token) + + # Token should still be present for retry. + assert ingestion_service.get_staged_projects(token) is not None + + with unit_of_work_factory() as uow: + assert uow.projects is not None + assert uow.projects.count() == 0