Compare commits

5 commits: e0fa3861a6 ... b1a0153a8d

| SHA1 |
|---|
| b1a0153a8d |
| 609b0d779f |
| eaef99f0ac |
| 3bc124c11f |
| 7058eb4172 |
11 .env.example
```diff
@@ -9,3 +9,14 @@ DATABASE_PASSWORD=<password>
 DATABASE_NAME=calminer
 # Optional: set a schema (comma-separated for multiple entries)
 # DATABASE_SCHEMA=public
+
+# Default administrative credentials are provided at deployment time through environment variables
+# (`CALMINER_SEED_ADMIN_EMAIL`, `CALMINER_SEED_ADMIN_USERNAME`, `CALMINER_SEED_ADMIN_PASSWORD`, `CALMINER_SEED_ADMIN_ROLES`).
+# These values are consumed by a shared bootstrap helper on application startup, ensuring mandatory roles and the administrator account exist before any user interaction.
+CALMINER_SEED_ADMIN_EMAIL=<email>
+CALMINER_SEED_ADMIN_USERNAME=<username>
+CALMINER_SEED_ADMIN_PASSWORD=<password>
+CALMINER_SEED_ADMIN_ROLES=<roles>
+# Operators can request a managed credential reset by setting `CALMINER_SEED_FORCE=true`.
+# On the next startup the helper rotates the admin password and reapplies role assignments, so downstream environments must update stored secrets immediately after the reset.
+# CALMINER_SEED_FORCE=false
```
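For orientation, a minimal sketch of how a startup hook might consume these variables. The real helper (`bootstrap_admin` in `services/bootstrap.py`, wired into FastAPI startup in `main.py` below) is not part of this compare view, so the function name and return shape here are illustrative assumptions:

```python
# Hypothetical sketch only: the actual bootstrap helper is not shown in this
# diff, so the names and shapes below are assumptions, not the project's API.
import os


def read_seed_admin_settings() -> dict[str, object]:
    roles = os.environ.get("CALMINER_SEED_ADMIN_ROLES", "")
    return {
        # The three credential variables are required; unset values raise KeyError.
        "email": os.environ["CALMINER_SEED_ADMIN_EMAIL"],
        "username": os.environ["CALMINER_SEED_ADMIN_USERNAME"],
        "password": os.environ["CALMINER_SEED_ADMIN_PASSWORD"],
        # Roles arrive as a comma-separated string.
        "roles": [role.strip() for role in roles.split(",") if role.strip()],
        # CALMINER_SEED_FORCE=true requests a managed credential reset.
        "force_reset": os.environ.get("CALMINER_SEED_FORCE", "false").lower() == "true",
    }
```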
@@ -30,3 +30,6 @@

- Implemented environment-driven admin bootstrap settings, wired the `bootstrap_admin` helper into FastAPI startup, added pytest coverage for creation/idempotency/reset logic, and documented operational guidance in the RBAC plan and security concept.
- Retired the legacy authentication RBAC implementation plan document after migrating its guidance into live documentation and synchronized the contributor instructions to reflect the removal.
- Completed the Authentication & RBAC checklist by shipping the new models, migrations, repositories, guard dependencies, and integration tests.
- Documented the project/scenario import/export field mapping and file format guidelines in `calminer-docs/requirements/FR-008.md`, and introduced `schemas/imports.py` with Pydantic models that normalise incoming CSV/Excel rows for projects and scenarios.
- Added `services/importers.py` to load CSV/XLSX files into the new import schemas, pulled in `openpyxl` for Excel support, and covered the parsing behaviour with `tests/test_import_parsing.py`.
- Expanded the import ingestion workflow with staging previews, transactional persistence commits, FastAPI preview/commit endpoints under `/imports`, and new API tests (`tests/test_import_ingestion.py`, `tests/test_import_api.py`) ensuring end-to-end coverage; a usage sketch follows this list.
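A condensed sketch of that preview-then-commit round trip, driving the `/imports` endpoints the same way the new API tests do. The endpoint paths and payload shapes come from `routes/imports.py` below; the `main.app` import and the authentication setup are assumptions (the endpoints are guarded by `require_roles`, which the test fixtures override):

```python
# Sketch of the two-step import flow: preview stages the parsed rows and
# returns a token, commit replays the staged rows inside a unit of work.
from fastapi.testclient import TestClient

from main import app  # assumes role/auth dependencies are satisfied or overridden

client = TestClient(app)

csv_content = "name,location,operation_type\nProject A,Chile,open pit\n"
preview = client.post(
    "/imports/projects/preview",
    files={"file": ("projects.csv", csv_content, "text/csv")},
)
token = preview.json()["stage_token"]  # None when no rows were staged

# Committing consumes the token; replaying the same token yields a 404.
commit = client.post("/imports/projects/commit", json={"token": token})
print(commit.json()["summary"])  # e.g. {"created": 1, "updated": 0}
```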
dependencies.py

```diff
@@ -21,6 +21,7 @@ from services.session import (
     extract_session_tokens,
 )
 from services.unit_of_work import UnitOfWork
+from services.importers import ImportIngestionService


 def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
@@ -30,6 +31,15 @@ def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
     yield uow


+_IMPORT_INGESTION_SERVICE = ImportIngestionService(lambda: UnitOfWork())
+
+
+def get_import_ingestion_service() -> ImportIngestionService:
+    """Provide singleton import ingestion service."""
+
+    return _IMPORT_INGESTION_SERVICE
+
+
 def get_application_settings() -> Settings:
     """Provide cached application settings instance."""
```
2 main.py
```diff
@@ -16,6 +16,7 @@ from models import (
 )
 from routes.auth import router as auth_router
 from routes.dashboard import router as dashboard_router
+from routes.imports import router as imports_router
 from routes.projects import router as projects_router
 from routes.scenarios import router as scenarios_router
 from services.bootstrap import bootstrap_admin
@@ -61,6 +62,7 @@ async def ensure_admin_bootstrap() -> None:

 app.include_router(dashboard_router)
 app.include_router(auth_router)
+app.include_router(imports_router)
 app.include_router(projects_router)
 app.include_router(scenarios_router)

```
requirements.txt

```diff
@@ -11,4 +11,5 @@ numpy
 passlib
 argon2-cffi
 python-jose
 python-multipart
+openpyxl
```
143 routes/imports.py (new file)
@@ -0,0 +1,143 @@

```python
from __future__ import annotations

from io import BytesIO

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status

from dependencies import get_import_ingestion_service, require_roles
from models import User
from schemas.imports import (
    ImportCommitRequest,
    ProjectImportCommitResponse,
    ProjectImportPreviewResponse,
    ScenarioImportCommitResponse,
    ScenarioImportPreviewResponse,
)
from services.importers import ImportIngestionService, UnsupportedImportFormat

router = APIRouter(prefix="/imports", tags=["Imports"])

MANAGE_ROLES = ("project_manager", "admin")


async def _read_upload_file(upload: UploadFile) -> BytesIO:
    content = await upload.read()
    if not content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file is empty.",
        )
    return BytesIO(content)


@router.post(
    "/projects/preview",
    response_model=ProjectImportPreviewResponse,
    status_code=status.HTTP_200_OK,
)
async def preview_project_import(
    file: UploadFile = File(..., description="Project import file (CSV or Excel)"),
    _: User = Depends(require_roles(*MANAGE_ROLES)),
    ingestion_service: ImportIngestionService = Depends(get_import_ingestion_service),
) -> ProjectImportPreviewResponse:
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Filename is required for import.",
        )

    stream = await _read_upload_file(file)

    try:
        preview = ingestion_service.preview_projects(stream, file.filename)
    except UnsupportedImportFormat as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc

    return ProjectImportPreviewResponse.model_validate(preview)


@router.post(
    "/scenarios/preview",
    response_model=ScenarioImportPreviewResponse,
    status_code=status.HTTP_200_OK,
)
async def preview_scenario_import(
    file: UploadFile = File(..., description="Scenario import file (CSV or Excel)"),
    _: User = Depends(require_roles(*MANAGE_ROLES)),
    ingestion_service: ImportIngestionService = Depends(get_import_ingestion_service),
) -> ScenarioImportPreviewResponse:
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Filename is required for import.",
        )

    stream = await _read_upload_file(file)

    try:
        preview = ingestion_service.preview_scenarios(stream, file.filename)
    except UnsupportedImportFormat as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc

    return ScenarioImportPreviewResponse.model_validate(preview)


def _value_error_status(exc: ValueError) -> int:
    detail = str(exc)
    if detail.lower().startswith("unknown"):
        return status.HTTP_404_NOT_FOUND
    return status.HTTP_400_BAD_REQUEST


@router.post(
    "/projects/commit",
    response_model=ProjectImportCommitResponse,
    status_code=status.HTTP_200_OK,
)
async def commit_project_import_endpoint(
    payload: ImportCommitRequest,
    _: User = Depends(require_roles(*MANAGE_ROLES)),
    ingestion_service: ImportIngestionService = Depends(get_import_ingestion_service),
) -> ProjectImportCommitResponse:
    try:
        result = ingestion_service.commit_project_import(payload.token)
    except ValueError as exc:
        raise HTTPException(
            status_code=_value_error_status(exc),
            detail=str(exc),
        ) from exc

    return ProjectImportCommitResponse.model_validate(result)


@router.post(
    "/scenarios/commit",
    response_model=ScenarioImportCommitResponse,
    status_code=status.HTTP_200_OK,
)
async def commit_scenario_import_endpoint(
    payload: ImportCommitRequest,
    _: User = Depends(require_roles(*MANAGE_ROLES)),
    ingestion_service: ImportIngestionService = Depends(get_import_ingestion_service),
) -> ScenarioImportCommitResponse:
    try:
        result = ingestion_service.commit_scenario_import(payload.token)
    except ValueError as exc:
        raise HTTPException(
            status_code=_value_error_status(exc),
            detail=str(exc),
        ) from exc

    return ScenarioImportCommitResponse.model_validate(result)
```
292 schemas/imports.py (new file)
@@ -0,0 +1,292 @@

```python
from __future__ import annotations

from datetime import date, datetime
from typing import Any, Literal, Mapping

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from models import MiningOperationType, ResourceType, ScenarioStatus

PreviewStateLiteral = Literal["new", "update", "skip", "error"]


def _normalise_string(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    return str(value).strip()


def _strip_or_none(value: Any | None) -> str | None:
    if value is None:
        return None
    text = _normalise_string(value)
    return text or None


def _coerce_enum(value: Any, enum_cls: Any, aliases: Mapping[str, Any]) -> Any:
    if value is None:
        return value
    if isinstance(value, enum_cls):
        return value
    text = _normalise_string(value).lower()
    if not text:
        return None
    if text in aliases:
        return aliases[text]
    try:
        return enum_cls(text)
    except ValueError as exc:  # pragma: no cover - surfaced by Pydantic
        raise ValueError(
            f"Invalid value '{value}' for {enum_cls.__name__}"
        ) from exc


OPERATION_TYPE_ALIASES: dict[str, MiningOperationType] = {
    "open pit": MiningOperationType.OPEN_PIT,
    "openpit": MiningOperationType.OPEN_PIT,
    "underground": MiningOperationType.UNDERGROUND,
    "in-situ leach": MiningOperationType.IN_SITU_LEACH,
    "in situ": MiningOperationType.IN_SITU_LEACH,
    "placer": MiningOperationType.PLACER,
    "quarry": MiningOperationType.QUARRY,
    "mountaintop removal": MiningOperationType.MOUNTAINTOP_REMOVAL,
    "other": MiningOperationType.OTHER,
}

SCENARIO_STATUS_ALIASES: dict[str, ScenarioStatus] = {
    "draft": ScenarioStatus.DRAFT,
    "active": ScenarioStatus.ACTIVE,
    "archived": ScenarioStatus.ARCHIVED,
}

RESOURCE_TYPE_ALIASES: dict[str, ResourceType] = {
    key.replace("_", " ").lower(): value
    for key, value in ResourceType.__members__.items()
}
RESOURCE_TYPE_ALIASES.update(
    {value.value.replace("_", " ").lower(): value for value in ResourceType}
)


class ProjectImportRow(BaseModel):
    name: str
    location: str | None = None
    operation_type: MiningOperationType
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None

    model_config = ConfigDict(extra="forbid")

    @field_validator("name", mode="before")
    @classmethod
    def validate_name(cls, value: Any) -> str:
        text = _normalise_string(value)
        if not text:
            raise ValueError("Project name is required")
        return text

    @field_validator("location", "description", mode="before")
    @classmethod
    def optional_text(cls, value: Any | None) -> str | None:
        return _strip_or_none(value)

    @field_validator("operation_type", mode="before")
    @classmethod
    def map_operation_type(cls, value: Any) -> MiningOperationType | None:
        return _coerce_enum(value, MiningOperationType, OPERATION_TYPE_ALIASES)


class ScenarioImportRow(BaseModel):
    project_name: str
    name: str
    status: ScenarioStatus = ScenarioStatus.DRAFT
    start_date: date | None = None
    end_date: date | None = None
    discount_rate: float | None = None
    currency: str | None = None
    primary_resource: ResourceType | None = None
    description: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None

    model_config = ConfigDict(extra="forbid")

    @field_validator("project_name", "name", mode="before")
    @classmethod
    def validate_required_text(cls, value: Any, info) -> str:
        text = _normalise_string(value)
        if not text:
            raise ValueError(
                f"{info.field_name.replace('_', ' ').title()} is required")
        return text

    @field_validator("status", mode="before")
    @classmethod
    def map_status(cls, value: Any) -> ScenarioStatus | None:
        return _coerce_enum(value, ScenarioStatus, SCENARIO_STATUS_ALIASES)

    @field_validator("primary_resource", mode="before")
    @classmethod
    def map_resource(cls, value: Any) -> ResourceType | None:
        return _coerce_enum(value, ResourceType, RESOURCE_TYPE_ALIASES)

    @field_validator("description", mode="before")
    @classmethod
    def optional_description(cls, value: Any | None) -> str | None:
        return _strip_or_none(value)

    @field_validator("currency", mode="before")
    @classmethod
    def normalise_currency(cls, value: Any | None) -> str | None:
        if value is None:
            return None
        text = _normalise_string(value).upper()
        if not text:
            return None
        if len(text) != 3:
            raise ValueError("Currency code must be a 3-letter ISO value")
        return text

    @field_validator("discount_rate", mode="before")
    @classmethod
    def coerce_discount_rate(cls, value: Any | None) -> float | None:
        if value is None:
            return None
        if isinstance(value, (int, float)):
            return float(value)
        text = _normalise_string(value)
        if not text:
            return None
        if text.endswith("%"):
            text = text[:-1]
        try:
            return float(text)
        except ValueError as exc:
            raise ValueError("Discount rate must be numeric") from exc

    @model_validator(mode="after")
    def validate_dates(self) -> "ScenarioImportRow":
        if self.start_date and self.end_date and self.start_date > self.end_date:
            raise ValueError("End date must be on or after start date")
        return self


class ImportRowErrorModel(BaseModel):
    row_number: int
    field: str | None = None
    message: str

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ImportPreviewRowIssueModel(BaseModel):
    message: str
    field: str | None = None

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ImportPreviewRowIssuesModel(BaseModel):
    row_number: int
    state: PreviewStateLiteral | None = None
    issues: list[ImportPreviewRowIssueModel] = Field(default_factory=list)

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ImportPreviewSummaryModel(BaseModel):
    total_rows: int
    accepted: int
    skipped: int
    errored: int

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ProjectImportPreviewRow(BaseModel):
    row_number: int
    data: ProjectImportRow
    state: PreviewStateLiteral
    issues: list[str] = Field(default_factory=list)
    context: dict[str, Any] | None = None

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ScenarioImportPreviewRow(BaseModel):
    row_number: int
    data: ScenarioImportRow
    state: PreviewStateLiteral
    issues: list[str] = Field(default_factory=list)
    context: dict[str, Any] | None = None

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ProjectImportPreviewResponse(BaseModel):
    rows: list[ProjectImportPreviewRow]
    summary: ImportPreviewSummaryModel
    row_issues: list[ImportPreviewRowIssuesModel] = Field(default_factory=list)
    parser_errors: list[ImportRowErrorModel] = Field(default_factory=list)
    stage_token: str | None = None

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ScenarioImportPreviewResponse(BaseModel):
    rows: list[ScenarioImportPreviewRow]
    summary: ImportPreviewSummaryModel
    row_issues: list[ImportPreviewRowIssuesModel] = Field(default_factory=list)
    parser_errors: list[ImportRowErrorModel] = Field(default_factory=list)
    stage_token: str | None = None

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ImportCommitSummaryModel(BaseModel):
    created: int
    updated: int

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ProjectImportCommitRow(BaseModel):
    row_number: int
    data: ProjectImportRow
    context: dict[str, Any]

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ScenarioImportCommitRow(BaseModel):
    row_number: int
    data: ScenarioImportRow
    context: dict[str, Any]

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ProjectImportCommitResponse(BaseModel):
    token: str
    rows: list[ProjectImportCommitRow]
    summary: ImportCommitSummaryModel

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ScenarioImportCommitResponse(BaseModel):
    token: str
    rows: list[ScenarioImportCommitRow]
    summary: ImportCommitSummaryModel

    model_config = ConfigDict(from_attributes=True, extra="forbid")


class ImportCommitRequest(BaseModel):
    token: str

    model_config = ConfigDict(extra="forbid")
```
707 services/importers.py (new file)
@@ -0,0 +1,707 @@

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from types import MappingProxyType
from typing import Any, BinaryIO, Callable, Generic, Iterable, Mapping, TypeVar, cast
from uuid import uuid4

import pandas as pd
from pandas import DataFrame
from pydantic import BaseModel, ValidationError

from models import Project, Scenario
from schemas.imports import ProjectImportRow, ScenarioImportRow
from services.unit_of_work import UnitOfWork

TImportRow = TypeVar("TImportRow", bound=BaseModel)

PROJECT_COLUMNS: tuple[str, ...] = (
    "name",
    "location",
    "operation_type",
    "description",
    "created_at",
    "updated_at",
)

SCENARIO_COLUMNS: tuple[str, ...] = (
    "project_name",
    "name",
    "status",
    "start_date",
    "end_date",
    "discount_rate",
    "currency",
    "primary_resource",
    "description",
    "created_at",
    "updated_at",
)


@dataclass(slots=True)
class ImportRowError:
    row_number: int
    field: str | None
    message: str


@dataclass(slots=True)
class ParsedImportRow(Generic[TImportRow]):
    row_number: int
    data: TImportRow


@dataclass(slots=True)
class ImportResult(Generic[TImportRow]):
    rows: list[ParsedImportRow[TImportRow]]
    errors: list[ImportRowError]


class UnsupportedImportFormat(ValueError):
    pass


class ImportPreviewState(str, Enum):
    NEW = "new"
    UPDATE = "update"
    SKIP = "skip"
    ERROR = "error"


@dataclass(slots=True)
class ImportPreviewRow(Generic[TImportRow]):
    row_number: int
    data: TImportRow
    state: ImportPreviewState
    issues: list[str]
    context: dict[str, Any] | None = None


@dataclass(slots=True)
class ImportPreviewSummary:
    total_rows: int
    accepted: int
    skipped: int
    errored: int


@dataclass(slots=True)
class ImportPreview(Generic[TImportRow]):
    rows: list[ImportPreviewRow[TImportRow]]
    summary: ImportPreviewSummary
    row_issues: list["ImportPreviewRowIssues"]
    parser_errors: list[ImportRowError]
    stage_token: str | None


@dataclass(slots=True)
class StagedRow(Generic[TImportRow]):
    parsed: ParsedImportRow[TImportRow]
    context: dict[str, Any]


@dataclass(slots=True)
class ImportPreviewRowIssue:
    message: str
    field: str | None = None


@dataclass(slots=True)
class ImportPreviewRowIssues:
    row_number: int
    state: ImportPreviewState | None
    issues: list[ImportPreviewRowIssue]


@dataclass(slots=True)
class StagedImport(Generic[TImportRow]):
    token: str
    rows: list[StagedRow[TImportRow]]


@dataclass(slots=True, frozen=True)
class StagedRowView(Generic[TImportRow]):
    row_number: int
    data: TImportRow
    context: Mapping[str, Any]


@dataclass(slots=True, frozen=True)
class StagedImportView(Generic[TImportRow]):
    token: str
    rows: tuple[StagedRowView[TImportRow], ...]


@dataclass(slots=True, frozen=True)
class ImportCommitSummary:
    created: int
    updated: int


@dataclass(slots=True, frozen=True)
class ImportCommitResult(Generic[TImportRow]):
    token: str
    rows: tuple[StagedRowView[TImportRow], ...]
    summary: ImportCommitSummary


UnitOfWorkFactory = Callable[[], UnitOfWork]


class ImportIngestionService:
    """Coordinates parsing, validation, and preview staging for imports."""

    def __init__(self, uow_factory: UnitOfWorkFactory) -> None:
        self._uow_factory = uow_factory
        self._project_stage: dict[str, StagedImport[ProjectImportRow]] = {}
        self._scenario_stage: dict[str, StagedImport[ScenarioImportRow]] = {}

    def preview_projects(
        self,
        stream: BinaryIO,
        filename: str,
    ) -> ImportPreview[ProjectImportRow]:
        result = load_project_imports(stream, filename)
        parser_errors = result.errors

        preview_rows: list[ImportPreviewRow[ProjectImportRow]] = []
        staged_rows: list[StagedRow[ProjectImportRow]] = []
        accepted = skipped = errored = 0

        seen_names: set[str] = set()

        existing_by_name: dict[str, Project] = {}
        if result.rows:
            with self._uow_factory() as uow:
                if not uow.projects:
                    raise RuntimeError("Project repository is unavailable")
                existing_by_name = dict(
                    uow.projects.find_by_names(
                        parsed.data.name for parsed in result.rows
                    )
                )

        for parsed in result.rows:
            name_key = _normalise_key(parsed.data.name)
            issues: list[str] = []
            context: dict[str, Any] | None = None
            state = ImportPreviewState.NEW

            if name_key in seen_names:
                state = ImportPreviewState.SKIP
                issues.append("Duplicate project name within upload; row skipped.")
            else:
                seen_names.add(name_key)
                existing = existing_by_name.get(name_key)
                if existing:
                    state = ImportPreviewState.UPDATE
                    context = {
                        "mode": "update",
                        "project_id": existing.id,
                    }
                    issues.append("Existing project will be updated.")
                else:
                    context = {"mode": "create"}

            preview_rows.append(
                ImportPreviewRow(
                    row_number=parsed.row_number,
                    data=parsed.data,
                    state=state,
                    issues=issues,
                    context=context,
                )
            )

            if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}:
                accepted += 1
                staged_rows.append(
                    StagedRow(parsed=parsed, context=context or {"mode": "create"})
                )
            elif state == ImportPreviewState.SKIP:
                skipped += 1
            else:
                errored += 1

        parser_error_rows = {error.row_number for error in parser_errors}
        errored += len(parser_error_rows)
        total_rows = len(preview_rows) + len(parser_error_rows)

        summary = ImportPreviewSummary(
            total_rows=total_rows,
            accepted=accepted,
            skipped=skipped,
            errored=errored,
        )

        row_issues = _compile_row_issues(preview_rows, parser_errors)

        stage_token: str | None = None
        if staged_rows:
            stage_token = self._store_project_stage(staged_rows)

        return ImportPreview(
            rows=preview_rows,
            summary=summary,
            row_issues=row_issues,
            parser_errors=parser_errors,
            stage_token=stage_token,
        )

    def preview_scenarios(
        self,
        stream: BinaryIO,
        filename: str,
    ) -> ImportPreview[ScenarioImportRow]:
        result = load_scenario_imports(stream, filename)
        parser_errors = result.errors

        preview_rows: list[ImportPreviewRow[ScenarioImportRow]] = []
        staged_rows: list[StagedRow[ScenarioImportRow]] = []
        accepted = skipped = errored = 0

        seen_pairs: set[tuple[str, str]] = set()

        existing_projects: dict[str, Project] = {}
        existing_scenarios: dict[tuple[int, str], Scenario] = {}

        if result.rows:
            with self._uow_factory() as uow:
                if not uow.projects or not uow.scenarios:
                    raise RuntimeError("Repositories are unavailable")

                existing_projects = dict(
                    uow.projects.find_by_names(
                        parsed.data.project_name for parsed in result.rows
                    )
                )

                names_by_project: dict[int, set[str]] = {}
                for parsed in result.rows:
                    project = existing_projects.get(
                        _normalise_key(parsed.data.project_name)
                    )
                    if not project:
                        continue
                    names_by_project.setdefault(project.id, set()).add(
                        _normalise_key(parsed.data.name)
                    )

                for project_id, names in names_by_project.items():
                    matches = uow.scenarios.find_by_project_and_names(
                        project_id, names)
                    for name_key, scenario in matches.items():
                        existing_scenarios[(project_id, name_key)] = scenario

        for parsed in result.rows:
            project_key = _normalise_key(parsed.data.project_name)
            scenario_key = _normalise_key(parsed.data.name)
            issues: list[str] = []
            context: dict[str, Any] | None = None
            state = ImportPreviewState.NEW

            if (project_key, scenario_key) in seen_pairs:
                state = ImportPreviewState.SKIP
                issues.append(
                    "Duplicate scenario for project within upload; row skipped."
                )
            else:
                seen_pairs.add((project_key, scenario_key))
                project = existing_projects.get(project_key)
                if not project:
                    state = ImportPreviewState.ERROR
                    issues.append(
                        f"Project '{parsed.data.project_name}' does not exist."
                    )
                else:
                    context = {"mode": "create", "project_id": project.id}
                    existing = existing_scenarios.get((project.id, scenario_key))
                    if existing:
                        state = ImportPreviewState.UPDATE
                        context = {
                            "mode": "update",
                            "project_id": project.id,
                            "scenario_id": existing.id,
                        }
                        issues.append("Existing scenario will be updated.")

            preview_rows.append(
                ImportPreviewRow(
                    row_number=parsed.row_number,
                    data=parsed.data,
                    state=state,
                    issues=issues,
                    context=context,
                )
            )

            if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}:
                accepted += 1
                staged_rows.append(
                    StagedRow(parsed=parsed, context=context or {"mode": "create"})
                )
            elif state == ImportPreviewState.SKIP:
                skipped += 1
            else:
                errored += 1

        parser_error_rows = {error.row_number for error in parser_errors}
        errored += len(parser_error_rows)
        total_rows = len(preview_rows) + len(parser_error_rows)

        summary = ImportPreviewSummary(
            total_rows=total_rows,
            accepted=accepted,
            skipped=skipped,
            errored=errored,
        )

        row_issues = _compile_row_issues(preview_rows, parser_errors)

        stage_token: str | None = None
        if staged_rows:
            stage_token = self._store_scenario_stage(staged_rows)

        return ImportPreview(
            rows=preview_rows,
            summary=summary,
            row_issues=row_issues,
            parser_errors=parser_errors,
            stage_token=stage_token,
        )

    def get_staged_projects(
        self, token: str
    ) -> StagedImportView[ProjectImportRow] | None:
        staged = self._project_stage.get(token)
        if not staged:
            return None
        return _build_staged_view(staged)

    def get_staged_scenarios(
        self, token: str
    ) -> StagedImportView[ScenarioImportRow] | None:
        staged = self._scenario_stage.get(token)
        if not staged:
            return None
        return _build_staged_view(staged)

    def consume_staged_projects(
        self, token: str
    ) -> StagedImportView[ProjectImportRow] | None:
        staged = self._project_stage.pop(token, None)
        if not staged:
            return None
        return _build_staged_view(staged)

    def consume_staged_scenarios(
        self, token: str
    ) -> StagedImportView[ScenarioImportRow] | None:
        staged = self._scenario_stage.pop(token, None)
        if not staged:
            return None
        return _build_staged_view(staged)

    def clear_staged_projects(self, token: str) -> bool:
        return self._project_stage.pop(token, None) is not None

    def clear_staged_scenarios(self, token: str) -> bool:
        return self._scenario_stage.pop(token, None) is not None

    def commit_project_import(self, token: str) -> ImportCommitResult[ProjectImportRow]:
        staged = self._project_stage.get(token)
        if not staged:
            raise ValueError(f"Unknown project import token: {token}")

        staged_view = _build_staged_view(staged)
        created = updated = 0

        with self._uow_factory() as uow:
            if not uow.projects:
                raise RuntimeError("Project repository is unavailable")

            for row in staged.rows:
                mode = row.context.get("mode")
                data = row.parsed.data

                if mode == "create":
                    project = Project(
                        name=data.name,
                        location=data.location,
                        operation_type=data.operation_type,
                        description=data.description,
                    )
                    if data.created_at:
                        project.created_at = data.created_at
                    if data.updated_at:
                        project.updated_at = data.updated_at
                    uow.projects.create(project)
                    created += 1
                elif mode == "update":
                    project_id = row.context.get("project_id")
                    if not project_id:
                        raise ValueError(
                            "Staged project update is missing project_id context"
                        )
                    project = uow.projects.get(project_id)
                    project.name = data.name
                    project.location = data.location
                    project.operation_type = data.operation_type
                    project.description = data.description
                    if data.created_at:
                        project.created_at = data.created_at
                    if data.updated_at:
                        project.updated_at = data.updated_at
                    updated += 1
                else:
                    raise ValueError(f"Unsupported staged project mode: {mode!r}")

        self._project_stage.pop(token, None)
        return ImportCommitResult(
            token=token,
            rows=staged_view.rows,
            summary=ImportCommitSummary(created=created, updated=updated),
        )

    def commit_scenario_import(self, token: str) -> ImportCommitResult[ScenarioImportRow]:
        staged = self._scenario_stage.get(token)
        if not staged:
            raise ValueError(f"Unknown scenario import token: {token}")

        staged_view = _build_staged_view(staged)
        created = updated = 0

        with self._uow_factory() as uow:
            if not uow.scenarios or not uow.projects:
                raise RuntimeError("Scenario repositories are unavailable")

            for row in staged.rows:
                mode = row.context.get("mode")
                data = row.parsed.data

                project_id = row.context.get("project_id")
                if not project_id:
                    raise ValueError(
                        "Staged scenario row is missing project_id context"
                    )

                project = uow.projects.get(project_id)

                if mode == "create":
                    scenario = Scenario(
                        project_id=project.id,
                        name=data.name,
                        status=data.status,
                        start_date=data.start_date,
                        end_date=data.end_date,
                        discount_rate=data.discount_rate,
                        currency=data.currency,
                        primary_resource=data.primary_resource,
                        description=data.description,
                    )
                    if data.created_at:
                        scenario.created_at = data.created_at
                    if data.updated_at:
                        scenario.updated_at = data.updated_at
                    uow.scenarios.create(scenario)
                    created += 1
                elif mode == "update":
                    scenario_id = row.context.get("scenario_id")
                    if not scenario_id:
                        raise ValueError(
                            "Staged scenario update is missing scenario_id context"
                        )
                    scenario = uow.scenarios.get(scenario_id)
                    scenario.project_id = project.id
                    scenario.name = data.name
                    scenario.status = data.status
                    scenario.start_date = data.start_date
                    scenario.end_date = data.end_date
                    scenario.discount_rate = data.discount_rate
                    scenario.currency = data.currency
                    scenario.primary_resource = data.primary_resource
                    scenario.description = data.description
                    if data.created_at:
                        scenario.created_at = data.created_at
                    if data.updated_at:
                        scenario.updated_at = data.updated_at
                    updated += 1
                else:
                    raise ValueError(f"Unsupported staged scenario mode: {mode!r}")

        self._scenario_stage.pop(token, None)
        return ImportCommitResult(
            token=token,
            rows=staged_view.rows,
            summary=ImportCommitSummary(created=created, updated=updated),
        )

    def _store_project_stage(self, rows: list[StagedRow[ProjectImportRow]]) -> str:
        token = str(uuid4())
        self._project_stage[token] = StagedImport(token=token, rows=rows)
        return token

    def _store_scenario_stage(self, rows: list[StagedRow[ScenarioImportRow]]) -> str:
        token = str(uuid4())
        self._scenario_stage[token] = StagedImport(token=token, rows=rows)
        return token


def load_project_imports(stream: BinaryIO, filename: str) -> ImportResult[ProjectImportRow]:
    df = _load_dataframe(stream, filename)
    return _parse_dataframe(df, ProjectImportRow, PROJECT_COLUMNS)


def load_scenario_imports(stream: BinaryIO, filename: str) -> ImportResult[ScenarioImportRow]:
    df = _load_dataframe(stream, filename)
    return _parse_dataframe(df, ScenarioImportRow, SCENARIO_COLUMNS)


def _load_dataframe(stream: BinaryIO, filename: str) -> DataFrame:
    stream.seek(0)
    suffix = Path(filename).suffix.lower()
    if suffix == ".csv":
        df = pd.read_csv(stream, dtype=str, keep_default_na=False, encoding="utf-8")
    elif suffix in {".xls", ".xlsx"}:
        df = pd.read_excel(stream, dtype=str, engine="openpyxl")
    else:
        raise UnsupportedImportFormat(
            f"Unsupported file type: {suffix or 'unknown'}")
    df.columns = [str(col).strip().lower() for col in df.columns]
    return df


def _parse_dataframe(
    df: DataFrame,
    model: type[TImportRow],
    expected_columns: Iterable[str],
) -> ImportResult[TImportRow]:
    rows: list[ParsedImportRow[TImportRow]] = []
    errors: list[ImportRowError] = []
    for index, raw in enumerate(df.to_dict(orient="records"), start=2):
        payload = _prepare_payload(cast(dict[str, object], raw), expected_columns)
        try:
            rows.append(
                ParsedImportRow(row_number=index, data=model(**payload))
            )
        except ValidationError as exc:  # pragma: no cover - exercised via tests
            for detail in exc.errors():
                loc = ".".join(str(part) for part in detail.get("loc", [])) or None
                errors.append(
                    ImportRowError(
                        row_number=index,
                        field=loc,
                        message=detail.get("msg", "Invalid value"),
                    )
                )
    return ImportResult(rows=rows, errors=errors)


def _prepare_payload(
    raw: dict[str, object], expected_columns: Iterable[str]
) -> dict[str, object | None]:
    payload: dict[str, object | None] = {}
    for column in expected_columns:
        if column not in raw:
            continue
        value = raw.get(column)
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                value = None
        if value is not None and pd.isna(cast(Any, value)):
            value = None
        payload[column] = value
    return payload


def _normalise_key(value: str) -> str:
    return value.strip().lower()


def _build_staged_view(
    staged: StagedImport[TImportRow],
) -> StagedImportView[TImportRow]:
    rows = tuple(
        StagedRowView(
            row_number=row.parsed.row_number,
            data=cast(TImportRow, _deep_copy_model(row.parsed.data)),
            context=MappingProxyType(dict(row.context)),
        )
        for row in staged.rows
    )
    return StagedImportView(token=staged.token, rows=rows)


def _deep_copy_model(model: BaseModel) -> BaseModel:
    copy_method = getattr(model, "model_copy", None)
    if callable(copy_method):  # pydantic v2
        return cast(BaseModel, copy_method(deep=True))
    return model.copy(deep=True)  # type: ignore[attr-defined]


def _compile_row_issues(
    preview_rows: Iterable[ImportPreviewRow[Any]],
    parser_errors: Iterable[ImportRowError],
) -> list[ImportPreviewRowIssues]:
    issue_map: dict[int, ImportPreviewRowIssues] = {}

    def ensure_bundle(
        row_number: int,
        state: ImportPreviewState | None,
    ) -> ImportPreviewRowIssues:
        bundle = issue_map.get(row_number)
        if bundle is None:
            bundle = ImportPreviewRowIssues(
                row_number=row_number,
                state=state,
                issues=[],
            )
            issue_map[row_number] = bundle
        elif _state_priority(state) > _state_priority(bundle.state):
            bundle.state = state
        return bundle

    for row in preview_rows:
        if not row.issues:
            continue
        bundle = ensure_bundle(row.row_number, row.state)
        for message in row.issues:
            bundle.issues.append(ImportPreviewRowIssue(message=message))

    for error in parser_errors:
        bundle = ensure_bundle(error.row_number, ImportPreviewState.ERROR)
        bundle.issues.append(
            ImportPreviewRowIssue(message=error.message, field=error.field)
        )

    return sorted(issue_map.values(), key=lambda item: item.row_number)


def _state_priority(state: ImportPreviewState | None) -> int:
    if state is None:
        return -1
    if state == ImportPreviewState.ERROR:
        return 3
    if state == ImportPreviewState.SKIP:
        return 2
    if state == ImportPreviewState.UPDATE:
        return 1
    return 0
```
services/repositories.py

```diff
@@ -2,7 +2,7 @@ from __future__ import annotations

 from collections.abc import Iterable
 from datetime import datetime
-from typing import Sequence
+from typing import Mapping, Sequence

 from sqlalchemy import select, func
 from sqlalchemy.exc import IntegrityError
@@ -70,6 +70,15 @@ class ProjectRepository:
                 "Project violates uniqueness constraints") from exc
         return project

+    def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]:
+        normalised = {name.strip().lower()
+                      for name in names if name and name.strip()}
+        if not normalised:
+            return {}
+        stmt = select(Project).where(func.lower(Project.name).in_(normalised))
+        records = self.session.execute(stmt).scalars().all()
+        return {project.name.lower(): project for project in records}
+
     def delete(self, project_id: int) -> None:
         project = self.get(project_id)
         self.session.delete(project)
@@ -149,6 +158,25 @@ class ScenarioRepository:
             raise EntityConflictError("Scenario violates constraints") from exc
         return scenario

+    def find_by_project_and_names(
+        self,
+        project_id: int,
+        names: Iterable[str],
+    ) -> Mapping[str, Scenario]:
+        normalised = {name.strip().lower()
+                      for name in names if name and name.strip()}
+        if not normalised:
+            return {}
+        stmt = (
+            select(Scenario)
+            .where(
+                Scenario.project_id == project_id,
+                func.lower(Scenario.name).in_(normalised),
+            )
+        )
+        records = self.session.execute(stmt).scalars().all()
+        return {scenario.name.lower(): scenario for scenario in records}
+
     def delete(self, scenario_id: int) -> None:
         scenario = self.get(scenario_id)
         self.session.delete(scenario)
```
tests/conftest.py

```diff
@@ -11,12 +11,14 @@ from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import StaticPool

 from config.database import Base
-from dependencies import get_auth_session, get_unit_of_work
+from dependencies import get_auth_session, get_import_ingestion_service, get_unit_of_work
 from models import User
 from routes.auth import router as auth_router
 from routes.dashboard import router as dashboard_router
 from routes.projects import router as projects_router
 from routes.scenarios import router as scenarios_router
+from routes.imports import router as imports_router
+from services.importers import ImportIngestionService
 from services.unit_of_work import UnitOfWork
 from services.session import AuthSession, SessionTokens

@@ -51,6 +53,7 @@ def app(session_factory: sessionmaker) -> FastAPI:
     application.include_router(dashboard_router)
     application.include_router(projects_router)
     application.include_router(scenarios_router)
+    application.include_router(imports_router)

     def _override_uow() -> Iterator[UnitOfWork]:
         with UnitOfWork(session_factory=session_factory) as uow:
@@ -58,6 +61,18 @@ def app(session_factory: sessionmaker) -> FastAPI:

     application.dependency_overrides[get_unit_of_work] = _override_uow

+    def _ingestion_uow_factory() -> UnitOfWork:
+        return UnitOfWork(session_factory=session_factory)
+
+    ingestion_service = ImportIngestionService(_ingestion_uow_factory)
+
+    def _override_ingestion_service() -> ImportIngestionService:
+        return ingestion_service
+
+    application.dependency_overrides[
+        get_import_ingestion_service
+    ] = _override_ingestion_service
+
     with UnitOfWork(session_factory=session_factory) as uow:
         assert uow.users is not None
         uow.ensure_default_roles()
```
70 tests/test_import_api.py (new file)
@@ -0,0 +1,70 @@

```python
from __future__ import annotations

from fastapi.testclient import TestClient

from models.project import MiningOperationType, Project
from models.scenario import Scenario, ScenarioStatus


def test_project_import_preview_and_commit_flow(
    client: TestClient,
    unit_of_work_factory,
) -> None:
    with unit_of_work_factory() as uow:
        assert uow.projects is not None
        existing = Project(
            name="Existing Project",
            location="Chile",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(existing)

    csv_content = (
        "name,location,operation_type\n"
        "Existing Project,Peru,underground\n"
        "New Project,Canada,open pit\n"
    )

    preview_response = client.post(
        "/imports/projects/preview",
        files={"file": ("projects.csv", csv_content, "text/csv")},
    )
    assert preview_response.status_code == 200
    preview_data = preview_response.json()
    assert preview_data["summary"]["accepted"] == 2
    token = preview_data["stage_token"]
    assert token

    commit_response = client.post(
        "/imports/projects/commit",
        json={"token": token},
    )
    assert commit_response.status_code == 200
    commit_data = commit_response.json()
    assert commit_data["summary"] == {"created": 1, "updated": 1}

    with unit_of_work_factory() as uow:
        assert uow.projects is not None
        projects = {project.name: project for project in uow.projects.list()}
        assert "Existing Project" in projects and "New Project" in projects
        assert (
            projects["Existing Project"].operation_type
            == MiningOperationType.UNDERGROUND
        )

    repeat_commit = client.post(
        "/imports/projects/commit",
        json={"token": token},
    )
    assert repeat_commit.status_code == 404


def test_scenario_import_commit_invalid_token_returns_404(
    client: TestClient,
) -> None:
    response = client.post(
        "/imports/scenarios/commit",
        json={"token": "missing-token"},
    )
    assert response.status_code == 404
    assert "Unknown scenario import token" in response.json()["detail"]
```
395 tests/test_import_ingestion.py (new file)
@@ -0,0 +1,395 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from typing import Callable
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from models.project import MiningOperationType, Project
|
||||
from models.scenario import Scenario, ScenarioStatus
|
||||
from services.importers import (
|
||||
ImportCommitResult,
|
||||
ImportIngestionService,
|
||||
ImportPreviewState,
|
||||
StagedImportView,
|
||||
)
|
||||
from services.repositories import ProjectRepository
|
||||
from services.unit_of_work import UnitOfWork
|
||||
from schemas.imports import (
|
||||
ProjectImportCommitResponse,
|
||||
ProjectImportPreviewResponse,
|
||||
ScenarioImportCommitResponse,
|
||||
ScenarioImportPreviewResponse,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def ingestion_service(unit_of_work_factory: Callable[[], UnitOfWork]) -> ImportIngestionService:
|
||||
return ImportIngestionService(unit_of_work_factory)
|
||||
|
||||
|
||||
def test_preview_projects_flags_updates_and_duplicates(
|
||||
ingestion_service: ImportIngestionService,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork],
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
assert uow.projects is not None
|
||||
existing = Project(
|
||||
name="Project A",
|
||||
location="Chile",
|
||||
operation_type=MiningOperationType.OPEN_PIT,
|
||||
)
|
||||
uow.projects.create(existing)
|
||||
|
||||
csv_content = (
|
||||
"name,location,operation_type\n"
|
||||
"Project A,Peru,open pit\n"
|
||||
"Project B,Canada,underground\n"
|
||||
"Project B,Canada,underground\n"
|
||||
)
|
||||
stream = BytesIO(csv_content.encode("utf-8"))
|
||||
|
||||
preview = ingestion_service.preview_projects(stream, "projects.csv")
|
||||
|
||||
states = [row.state for row in preview.rows]
|
||||
assert states == [
|
||||
ImportPreviewState.UPDATE,
|
||||
ImportPreviewState.NEW,
|
||||
ImportPreviewState.SKIP,
|
||||
]
|
||||
assert preview.summary.total_rows == 3
|
||||
assert preview.summary.accepted == 2
|
||||
assert preview.summary.skipped == 1
|
||||
assert preview.summary.errored == 0
|
||||
assert preview.parser_errors == []
|
||||
assert preview.stage_token is not None
|
||||
issue_map = {bundle.row_number: bundle for bundle in preview.row_issues}
|
||||
assert 2 in issue_map and issue_map[2].state == ImportPreviewState.UPDATE
|
||||
assert {
|
||||
detail.message for detail in issue_map[2].issues
|
||||
} == {"Existing project will be updated."}
|
||||
assert 4 in issue_map and issue_map[4].state == ImportPreviewState.SKIP
|
||||
assert any(
|
||||
"Duplicate project name" in detail.message
|
||||
for detail in issue_map[4].issues
|
||||
)
|
||||
# type: ignore[attr-defined]
|
||||
staged = ingestion_service._project_stage[preview.stage_token]
|
||||
assert len(staged.rows) == 2
|
||||
update_context = preview.rows[0].context
|
||||
assert update_context is not None and update_context.get(
|
||||
"project_id") is not None
|
||||
|
||||
response_model = ProjectImportPreviewResponse.model_validate(preview)
|
||||
assert response_model.summary.accepted == preview.summary.accepted
|
||||
|
||||
|
||||
def test_preview_scenarios_validates_projects_and_updates(
|
||||
ingestion_service: ImportIngestionService,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork],
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
assert uow.projects is not None and uow.scenarios is not None
|
||||
project = Project(
|
||||
name="Existing Project",
|
||||
location="Chile",
|
||||
operation_type=MiningOperationType.OPEN_PIT,
|
||||
)
|
||||
uow.projects.create(project)
|
||||
scenario = Scenario(
|
||||
project_id=project.id,
|
||||
name="Existing Scenario",
|
||||
status=ScenarioStatus.ACTIVE,
|
||||
)
|
||||
uow.scenarios.create(scenario)
|
||||
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"project_name": "Existing Project",
|
||||
"name": "Existing Scenario",
|
||||
"status": "Active",
|
||||
},
|
||||
{
|
||||
"project_name": "Existing Project",
|
||||
"name": "New Scenario",
|
||||
"status": "Draft",
|
||||
},
|
||||
{
|
||||
"project_name": "Missing Project",
|
||||
"name": "Ghost Scenario",
|
||||
"status": "Draft",
|
||||
},
|
||||
{
|
||||
"project_name": "Existing Project",
|
||||
"name": "New Scenario",
|
||||
"status": "Draft",
|
||||
},
|
||||
]
|
||||
)
|
||||
buffer = BytesIO()
|
||||
df.to_csv(buffer, index=False)
|
||||
buffer.seek(0)
|
||||
|
||||
preview = ingestion_service.preview_scenarios(buffer, "scenarios.csv")
|
||||
|
||||
states = [row.state for row in preview.rows]
|
||||
assert states == [
|
||||
ImportPreviewState.UPDATE,
|
||||
ImportPreviewState.NEW,
|
||||
ImportPreviewState.ERROR,
|
||||
ImportPreviewState.SKIP,
|
||||
]
|
||||
assert preview.summary.total_rows == 4
|
||||
assert preview.summary.accepted == 2
|
||||
assert preview.summary.skipped == 1
|
||||
assert preview.summary.errored == 1
|
||||
assert preview.stage_token is not None
|
||||
issue_map = {bundle.row_number: bundle for bundle in preview.row_issues}
|
||||
assert 2 in issue_map and issue_map[2].state == ImportPreviewState.UPDATE
|
||||
assert 4 in issue_map and issue_map[4].state == ImportPreviewState.ERROR
|
||||
assert any(
|
||||
"does not exist" in detail.message
|
||||
for detail in issue_map[4].issues
|
||||
)
|
||||
# type: ignore[attr-defined]
|
||||
staged = ingestion_service._scenario_stage[preview.stage_token]
|
||||
assert len(staged.rows) == 2
|
||||
error_row = preview.rows[2]
|
||||
assert any("does not exist" in msg for msg in error_row.issues)
|
||||
|
||||
response_model = ScenarioImportPreviewResponse.model_validate(preview)
|
||||
assert response_model.summary.errored == preview.summary.errored
|
||||
|
||||
|
||||
def test_preview_scenarios_aggregates_parser_errors(
|
||||
ingestion_service: ImportIngestionService,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork],
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
assert uow.projects is not None
|
||||
project = Project(
|
||||
name="Existing Project",
|
||||
location="Chile",
|
||||
operation_type=MiningOperationType.OPEN_PIT,
|
||||
)
|
||||
uow.projects.create(project)
|
||||
|
||||
csv_content = (
|
||||
"project_name,name,status\n"
|
||||
"Existing Project,Broken Scenario,UNKNOWN_STATUS\n"
|
||||
)
|
||||
stream = BytesIO(csv_content.encode("utf-8"))
|
||||
|
||||
preview = ingestion_service.preview_scenarios(stream, "invalid.csv")
|
||||
|
||||
assert preview.rows == []
|
||||
assert preview.summary.total_rows == 1
|
||||
assert preview.summary.errored == 1
|
||||
assert preview.stage_token is None
|
||||
assert len(preview.parser_errors) == 1
|
||||
issue_map = {bundle.row_number: bundle for bundle in preview.row_issues}
|
||||
assert 2 in issue_map
|
||||
bundle = issue_map[2]
|
||||
assert bundle.state == ImportPreviewState.ERROR
|
||||
assert any(detail.field == "status" for detail in bundle.issues)
|
||||
assert all(detail.message for detail in bundle.issues)
|
||||
|
||||
response_model = ScenarioImportPreviewResponse.model_validate(preview)
|
||||
assert response_model.summary.total_rows == 1
|
||||
|
||||
|
||||
def test_consume_staged_projects_removes_token(
|
||||
ingestion_service: ImportIngestionService,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork],
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
assert uow.projects is not None
|
||||
|
||||
csv_content = (
|
||||
"name,location,operation_type\n"
|
||||
"Project X,Peru,open pit\n"
|
||||
)
|
||||
stream = BytesIO(csv_content.encode("utf-8"))
|
||||
|
||||
preview = ingestion_service.preview_projects(stream, "projects.csv")
|
||||
assert preview.stage_token is not None
|
||||
token = preview.stage_token
|
||||
|
||||
initial_view = ingestion_service.get_staged_projects(token)
|
||||
assert isinstance(initial_view, StagedImportView)
|
||||
consumed = ingestion_service.consume_staged_projects(token)
|
||||
assert consumed == initial_view
|
||||
assert ingestion_service.get_staged_projects(token) is None
|
||||
assert ingestion_service.consume_staged_projects(token) is None
|
||||
|
||||
|
||||


def test_clear_staged_scenarios_drops_entry(
    ingestion_service: ImportIngestionService,
    unit_of_work_factory: Callable[[], UnitOfWork],
) -> None:
    with unit_of_work_factory() as uow:
        assert uow.projects is not None
        project = Project(
            name="Project Y",
            location="Chile",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(project)

    csv_content = (
        "project_name,name,status\n"
        "Project Y,Scenario 1,Active\n"
    )
    stream = BytesIO(csv_content.encode("utf-8"))

    preview = ingestion_service.preview_scenarios(stream, "scenarios.csv")
    assert preview.stage_token is not None
    token = preview.stage_token

    assert ingestion_service.get_staged_scenarios(token) is not None
    assert ingestion_service.clear_staged_scenarios(token) is True
    assert ingestion_service.get_staged_scenarios(token) is None
    assert ingestion_service.clear_staged_scenarios(token) is False


def test_commit_project_import_applies_create_and_update(
    ingestion_service: ImportIngestionService,
    unit_of_work_factory: Callable[[], UnitOfWork],
) -> None:
    with unit_of_work_factory() as uow:
        assert uow.projects is not None
        existing = Project(
            name="Project A",
            location="Chile",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(existing)

    csv_content = (
        "name,location,operation_type\n"
        "Project A,Peru,underground\n"
        "Project B,Canada,open pit\n"
    )
    stream = BytesIO(csv_content.encode("utf-8"))
    preview = ingestion_service.preview_projects(stream, "projects.csv")
    assert preview.stage_token is not None

    result = ingestion_service.commit_project_import(preview.stage_token)
    assert isinstance(result, ImportCommitResult)
    assert result.summary.created == 1
    assert result.summary.updated == 1
    assert ingestion_service.get_staged_projects(preview.stage_token) is None

    commit_response = ProjectImportCommitResponse.model_validate(result)
    assert commit_response.summary.updated == 1

    with unit_of_work_factory() as uow:
        assert uow.projects is not None
        projects = uow.projects.list()
        names = sorted(project.name for project in projects)
        assert names == ["Project A", "Project B"]
        updated_project = next(p for p in projects if p.name == "Project A")
        assert updated_project.location == "Peru"
        assert updated_project.operation_type == MiningOperationType.UNDERGROUND
        new_project = next(p for p in projects if p.name == "Project B")
        assert new_project.location == "Canada"
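

# --- Illustrative sketch (editor's addition, not part of this commit) ---
# The created/updated counts asserted above imply a name-keyed upsert during
# commit, roughly along these lines. `_upsert_project` is a hypothetical
# helper; the real path is ImportIngestionService.commit_project_import.
def _upsert_project(uow: UnitOfWork, staged: Project) -> str:
    assert uow.projects is not None
    match = next(
        (p for p in uow.projects.list() if p.name == staged.name), None
    )
    if match is None:
        uow.projects.create(staged)
        return "created"
    # Existing project: refresh the mutable columns in place.
    match.location = staged.location
    match.operation_type = staged.operation_type
    return "updated"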


def test_commit_project_import_with_invalid_token_raises(
    ingestion_service: ImportIngestionService,
) -> None:
    with pytest.raises(ValueError):
        ingestion_service.commit_project_import("missing-token")


def test_commit_scenario_import_applies_create_and_update(
    ingestion_service: ImportIngestionService,
    unit_of_work_factory: Callable[[], UnitOfWork],
) -> None:
    with unit_of_work_factory() as uow:
        assert uow.projects is not None and uow.scenarios is not None
        project = Project(
            name="Project X",
            location="Chile",
            operation_type=MiningOperationType.OPEN_PIT,
        )
        uow.projects.create(project)
        scenario = Scenario(
            project_id=project.id,
            name="Existing Scenario",
            status=ScenarioStatus.ACTIVE,
        )
        uow.scenarios.create(scenario)

    csv_content = (
        "project_name,name,status\n"
        "Project X,Existing Scenario,Archived\n"
        "Project X,New Scenario,Draft\n"
    )
    stream = BytesIO(csv_content.encode("utf-8"))
    preview = ingestion_service.preview_scenarios(stream, "scenarios.csv")
    assert preview.stage_token is not None

    result = ingestion_service.commit_scenario_import(preview.stage_token)
    assert result.summary.created == 1
    assert result.summary.updated == 1
    assert ingestion_service.get_staged_scenarios(preview.stage_token) is None

    commit_response = ScenarioImportCommitResponse.model_validate(result)
    assert commit_response.summary.created == 1

    with unit_of_work_factory() as uow:
        assert uow.projects is not None and uow.scenarios is not None
        scenarios = uow.scenarios.list_for_project(uow.projects.list()[0].id)
        names = sorted(scenario.name for scenario in scenarios)
        assert names == ["Existing Scenario", "New Scenario"]
        updated_scenario = next(
            item for item in scenarios if item.name == "Existing Scenario"
        )
        assert updated_scenario.status == ScenarioStatus.ARCHIVED
        new_scenario = next(
            item for item in scenarios if item.name == "New Scenario"
        )
        assert new_scenario.status == ScenarioStatus.DRAFT


def test_commit_scenario_import_with_invalid_token_raises(
    ingestion_service: ImportIngestionService,
) -> None:
    with pytest.raises(ValueError):
        ingestion_service.commit_scenario_import("missing-token")


def test_commit_project_import_rolls_back_on_failure(
    ingestion_service: ImportIngestionService,
    unit_of_work_factory: Callable[[], UnitOfWork],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    with unit_of_work_factory() as uow:
        assert uow.projects is not None

    csv_content = (
        "name,location,operation_type\n"
        "Project Fail,Peru,open pit\n"
    )
    stream = BytesIO(csv_content.encode("utf-8"))
    preview = ingestion_service.preview_projects(stream, "projects.csv")
    assert preview.stage_token is not None
    token = preview.stage_token

    def _boom(self: ProjectRepository, project: Project) -> Project:
        raise RuntimeError("boom")

    monkeypatch.setattr(ProjectRepository, "create", _boom)

    with pytest.raises(RuntimeError):
        ingestion_service.commit_project_import(token)

    # Token should still be present for retry.
    assert ingestion_service.get_staged_projects(token) is not None

    with unit_of_work_factory() as uow:
        assert uow.projects is not None
        assert uow.projects.count() == 0
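

# --- Illustrative sketch (editor's addition, not part of this commit) ---
# The rollback behaviour verified above suggests the commit path opens a
# single UnitOfWork, lets its context manager roll back on any exception, and
# only discards the staging token after the transaction succeeds.
# `commit_staged` and its parameters are hypothetical stand-ins for the
# service internals.
from typing import Callable


def commit_staged(
    store: "_StagedStore",  # hypothetical staging store with .get/.consume
    uow_factory: Callable[[], UnitOfWork],
    token: str,
    apply_rows: Callable[[UnitOfWork], ImportCommitResult],
) -> ImportCommitResult:
    if store.get(token) is None:
        raise ValueError(f"unknown stage token: {token}")
    with uow_factory() as uow:  # an exception exits here with a rollback...
        result = apply_rows(uow)
    store.consume(token)  # ...so the token survives for a retry
    return result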
78
tests/test_import_parsing.py
Normal file
@@ -0,0 +1,78 @@
from __future__ import annotations

from io import BytesIO

import pandas as pd
import pytest

from services.importers import ImportResult, load_project_imports, load_scenario_imports
from schemas.imports import ProjectImportRow, ScenarioImportRow


def test_load_project_imports_from_csv() -> None:
    csv_content = (
        "name,location,operation_type,description\n"
        "Project A,Chile,open pit,First project\n"
        "Project B,,underground,Second project\n"
    )
    stream = BytesIO(csv_content.encode("utf-8"))

    result = load_project_imports(stream, "projects.csv")

    assert isinstance(result, ImportResult)
    assert len(result.rows) == 2
    assert not result.errors
    first = result.rows[0]
    assert first.row_number == 2
    assert isinstance(first.data, ProjectImportRow)
    assert first.data.name == "Project A"
    assert first.data.operation_type.value == "open_pit"
    second = result.rows[1]
    assert second.row_number == 3
    assert isinstance(second.data, ProjectImportRow)
    assert second.data.location is None


def test_load_scenario_imports_from_excel() -> None:
    df = pd.DataFrame(
        [
            {
                "project_name": "Project A",
                "name": "Scenario 1",
                "status": "Active",
                "start_date": "2025-01-01",
                "end_date": "2025-12-31",
                "discount_rate": "7.5%",
                "currency": "usd",
                "primary_resource": "Electricity",
            }
        ]
    )
    buffer = BytesIO()
    df.to_excel(buffer, index=False)
    buffer.seek(0)

    result = load_scenario_imports(buffer, "scenarios.xlsx")

    assert len(result.rows) == 1
    assert not result.errors
    row = result.rows[0]
    assert row.row_number == 2
    assert isinstance(row.data, ScenarioImportRow)
    assert row.data.status.value == "active"
    assert row.data.currency == "USD"
    assert row.data.discount_rate == pytest.approx(7.5)
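

# --- Illustrative sketch (editor's addition, not part of this commit) ---
# The coercions asserted above ("Active" -> active, "usd" -> "USD",
# "7.5%" -> 7.5) could be expressed as Pydantic v2 field validators along
# these lines. `_NormalisedScenario` is a hypothetical stand-in for the
# models in schemas/imports.py.
from pydantic import BaseModel, field_validator


class _NormalisedScenario(BaseModel):
    status: str
    currency: str
    discount_rate: float

    @field_validator("discount_rate", mode="before")
    @classmethod
    def _strip_percent_sign(cls, value: object) -> object:
        # Accept "7.5%" and let Pydantic coerce the remaining "7.5" to float.
        if isinstance(value, str):
            return value.strip().rstrip("%")
        return value

    @field_validator("currency")
    @classmethod
    def _normalise_currency(cls, value: str) -> str:
        return value.strip().upper()

    @field_validator("status")
    @classmethod
    def _normalise_status(cls, value: str) -> str:
        return value.strip().lower()


# Example: _NormalisedScenario(status="Active", currency="usd",
# discount_rate="7.5%") yields status="active", currency="USD",
# discount_rate=7.5.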


def test_import_errors_include_row_numbers() -> None:
    csv_content = "name,operation_type\n,open pit\n"
    stream = BytesIO(csv_content.encode("utf-8"))

    result = load_project_imports(stream, "projects.csv")

    assert len(result.rows) == 0
    assert len(result.errors) == 1
    error = result.errors[0]
    assert error.row_number == 2
    assert error.field == "name"
    assert "required" in error.message