feat: implement CRUD APIs for projects and scenarios with validated schemas

This commit is contained in:
2025-11-09 17:23:10 +01:00
parent 8bf46b80c8
commit 61b42b3041
14 changed files with 380 additions and 38 deletions

View File

@@ -101,8 +101,10 @@ def upgrade() -> None:
sa.Column("location", sa.String(length=255), nullable=True), sa.Column("location", sa.String(length=255), nullable=True),
sa.Column("operation_type", mining_operation_type, nullable=False), sa.Column("operation_type", mining_operation_type, nullable=False),
sa.Column("description", sa.Text(), nullable=True), sa.Column("description", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True),
server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"), sa.UniqueConstraint("name"),
) )
@@ -117,16 +119,21 @@ def upgrade() -> None:
sa.Column("status", scenario_status, nullable=False), sa.Column("status", scenario_status, nullable=False),
sa.Column("start_date", sa.Date(), nullable=True), sa.Column("start_date", sa.Date(), nullable=True),
sa.Column("end_date", sa.Date(), nullable=True), sa.Column("end_date", sa.Date(), nullable=True),
sa.Column("discount_rate", sa.Numeric(precision=5, scale=2), nullable=True), sa.Column("discount_rate", sa.Numeric(
precision=5, scale=2), nullable=True),
sa.Column("currency", sa.String(length=3), nullable=True), sa.Column("currency", sa.String(length=3), nullable=True),
sa.Column("primary_resource", resource_type, nullable=True), sa.Column("primary_resource", resource_type, nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), sa.Column("updated_at", sa.DateTime(timezone=True),
server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(
["project_id"], ["projects.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f("ix_scenarios_id"), "scenarios", ["id"], unique=False) op.create_index(op.f("ix_scenarios_id"), "scenarios", ["id"], unique=False)
op.create_index(op.f("ix_scenarios_project_id"), "scenarios", ["project_id"], unique=False) op.create_index(op.f("ix_scenarios_project_id"),
"scenarios", ["project_id"], unique=False)
op.create_table( op.create_table(
"financial_inputs", "financial_inputs",
@@ -139,13 +146,18 @@ def upgrade() -> None:
sa.Column("currency", sa.String(length=3), nullable=True), sa.Column("currency", sa.String(length=3), nullable=True),
sa.Column("effective_date", sa.Date(), nullable=True), sa.Column("effective_date", sa.Date(), nullable=True),
sa.Column("notes", sa.Text(), nullable=True), sa.Column("notes", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(["scenario_id"], ["scenarios.id"], ondelete="CASCADE"), sa.Column("updated_at", sa.DateTime(timezone=True),
server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(
["scenario_id"], ["scenarios.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f("ix_financial_inputs_id"), "financial_inputs", ["id"], unique=False) op.create_index(op.f("ix_financial_inputs_id"),
op.create_index(op.f("ix_financial_inputs_scenario_id"), "financial_inputs", ["scenario_id"], unique=False) "financial_inputs", ["id"], unique=False)
op.create_index(op.f("ix_financial_inputs_scenario_id"),
"financial_inputs", ["scenario_id"], unique=False)
op.create_table( op.create_table(
"simulation_parameters", "simulation_parameters",
@@ -155,28 +167,41 @@ def upgrade() -> None:
sa.Column("distribution", distribution_type, nullable=False), sa.Column("distribution", distribution_type, nullable=False),
sa.Column("variable", stochastic_variable, nullable=True), sa.Column("variable", stochastic_variable, nullable=True),
sa.Column("resource_type", resource_type, nullable=True), sa.Column("resource_type", resource_type, nullable=True),
sa.Column("mean_value", sa.Numeric(precision=18, scale=4), nullable=True), sa.Column("mean_value", sa.Numeric(
sa.Column("standard_deviation", sa.Numeric(precision=18, scale=4), nullable=True), precision=18, scale=4), nullable=True),
sa.Column("minimum_value", sa.Numeric(precision=18, scale=4), nullable=True), sa.Column("standard_deviation", sa.Numeric(
sa.Column("maximum_value", sa.Numeric(precision=18, scale=4), nullable=True), precision=18, scale=4), nullable=True),
sa.Column("minimum_value", sa.Numeric(
precision=18, scale=4), nullable=True),
sa.Column("maximum_value", sa.Numeric(
precision=18, scale=4), nullable=True),
sa.Column("unit", sa.String(length=32), nullable=True), sa.Column("unit", sa.String(length=32), nullable=True),
sa.Column("metadata", sa.JSON(), nullable=True), sa.Column("configuration", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(["scenario_id"], ["scenarios.id"], ondelete="CASCADE"), sa.Column("updated_at", sa.DateTime(timezone=True),
server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(
["scenario_id"], ["scenarios.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f("ix_simulation_parameters_id"), "simulation_parameters", ["id"], unique=False) op.create_index(op.f("ix_simulation_parameters_id"),
op.create_index(op.f("ix_simulation_parameters_scenario_id"), "simulation_parameters", ["scenario_id"], unique=False) "simulation_parameters", ["id"], unique=False)
op.create_index(op.f("ix_simulation_parameters_scenario_id"),
"simulation_parameters", ["scenario_id"], unique=False)
def downgrade() -> None: def downgrade() -> None:
op.drop_index(op.f("ix_simulation_parameters_scenario_id"), table_name="simulation_parameters") op.drop_index(op.f("ix_simulation_parameters_scenario_id"),
op.drop_index(op.f("ix_simulation_parameters_id"), table_name="simulation_parameters") table_name="simulation_parameters")
op.drop_index(op.f("ix_simulation_parameters_id"),
table_name="simulation_parameters")
op.drop_table("simulation_parameters") op.drop_table("simulation_parameters")
op.drop_index(op.f("ix_financial_inputs_scenario_id"), table_name="financial_inputs") op.drop_index(op.f("ix_financial_inputs_scenario_id"),
op.drop_index(op.f("ix_financial_inputs_id"), table_name="financial_inputs") table_name="financial_inputs")
op.drop_index(op.f("ix_financial_inputs_id"),
table_name="financial_inputs")
op.drop_table("financial_inputs") op.drop_table("financial_inputs")
op.drop_index(op.f("ix_scenarios_project_id"), table_name="scenarios") op.drop_index(op.f("ix_scenarios_project_id"), table_name="scenarios")

View File

@@ -6,3 +6,4 @@
- Added core SQLAlchemy domain models, shared metadata descriptors, and Alembic migration setup (with initial schema snapshot) to establish the persistence layer foundation. - Added core SQLAlchemy domain models, shared metadata descriptors, and Alembic migration setup (with initial schema snapshot) to establish the persistence layer foundation.
- Introduced repository and unit-of-work helpers for projects, scenarios, financial inputs, and simulation parameters to support service-layer operations. - Introduced repository and unit-of-work helpers for projects, scenarios, financial inputs, and simulation parameters to support service-layer operations.
- Added SQLite-backed pytest coverage for repository and unit-of-work behaviours to validate persistence interactions. - Added SQLite-backed pytest coverage for repository and unit-of-work behaviours to validate persistence interactions.
- Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application.

1
config/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Configuration package."""

12
dependencies.py Normal file
View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from collections.abc import Generator
from services.unit_of_work import UnitOfWork
def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
"""FastAPI dependency yielding a unit-of-work instance."""
with UnitOfWork() as uow:
yield uow

View File

@@ -10,6 +10,8 @@ from models import (
Scenario, Scenario,
SimulationParameter, SimulationParameter,
) )
from routes.projects import router as projects_router
from routes.scenarios import router as scenarios_router
# Initialize database schema (imports above ensure models are registered) # Initialize database schema (imports above ensure models are registered)
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@@ -29,5 +31,7 @@ async def health() -> dict[str, str]:
return {"status": "ok"} return {"status": "ok"}
app.mount("/static", StaticFiles(directory="static"), name="static") app.include_router(projects_router)
app.include_router(scenarios_router)
app.mount("/static", StaticFiles(directory="static"), name="static")

View File

@@ -52,12 +52,16 @@ class SimulationParameter(Base):
resource_type: Mapped[ResourceType | None] = mapped_column( resource_type: Mapped[ResourceType | None] = mapped_column(
SQLEnum(ResourceType), nullable=True SQLEnum(ResourceType), nullable=True
) )
mean_value: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) mean_value: Mapped[float | None] = mapped_column(
standard_deviation: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) Numeric(18, 4), nullable=True)
minimum_value: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) standard_deviation: Mapped[float | None] = mapped_column(
maximum_value: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) Numeric(18, 4), nullable=True)
minimum_value: Mapped[float | None] = mapped_column(
Numeric(18, 4), nullable=True)
maximum_value: Mapped[float | None] = mapped_column(
Numeric(18, 4), nullable=True)
unit: Mapped[str | None] = mapped_column(String(32), nullable=True) unit: Mapped[str | None] = mapped_column(String(32), nullable=True)
metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True) configuration: Mapped[dict | None] = mapped_column(JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )

View File

@@ -14,3 +14,6 @@ exclude = '''
)/ )/
''' '''
[tool.pytest.ini_options]
pythonpath = ["."]

1
routes/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""API route registrations."""

76
routes/projects.py Normal file
View File

@@ -0,0 +1,76 @@
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from dependencies import get_unit_of_work
from models import Project
from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
from services.exceptions import EntityConflictError, EntityNotFoundError
from services.unit_of_work import UnitOfWork
router = APIRouter(prefix="/projects", tags=["Projects"])
def _to_read_model(project: Project) -> ProjectRead:
return ProjectRead.model_validate(project)
@router.get("", response_model=List[ProjectRead])
def list_projects(uow: UnitOfWork = Depends(get_unit_of_work)) -> List[ProjectRead]:
projects = uow.projects.list()
return [_to_read_model(project) for project in projects]
@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED)
def create_project(
payload: ProjectCreate, uow: UnitOfWork = Depends(get_unit_of_work)
) -> ProjectRead:
project = Project(**payload.model_dump())
try:
created = uow.projects.create(project)
except EntityConflictError as exc:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail=str(exc)
) from exc
return _to_read_model(created)
@router.get("/{project_id}", response_model=ProjectRead)
def get_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> ProjectRead:
try:
project = uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return _to_read_model(project)
@router.put("/{project_id}", response_model=ProjectRead)
def update_project(
project_id: int,
payload: ProjectUpdate,
uow: UnitOfWork = Depends(get_unit_of_work),
) -> ProjectRead:
try:
project = uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(project, field, value)
uow.flush()
return _to_read_model(project)
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> None:
try:
uow.projects.delete(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc

103
routes/scenarios.py Normal file
View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from dependencies import get_unit_of_work
from models import Scenario
from schemas.scenario import ScenarioCreate, ScenarioRead, ScenarioUpdate
from services.exceptions import EntityConflictError, EntityNotFoundError
from services.unit_of_work import UnitOfWork
router = APIRouter(tags=["Scenarios"])
def _to_read_model(scenario: Scenario) -> ScenarioRead:
return ScenarioRead.model_validate(scenario)
@router.get(
"/projects/{project_id}/scenarios",
response_model=List[ScenarioRead],
)
def list_scenarios_for_project(
project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
) -> List[ScenarioRead]:
try:
uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
scenarios = uow.scenarios.list_for_project(project_id)
return [_to_read_model(scenario) for scenario in scenarios]
@router.post(
"/projects/{project_id}/scenarios",
response_model=ScenarioRead,
status_code=status.HTTP_201_CREATED,
)
def create_scenario_for_project(
project_id: int,
payload: ScenarioCreate,
uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioRead:
try:
uow.projects.get(project_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
scenario = Scenario(project_id=project_id, **payload.model_dump())
try:
created = uow.scenarios.create(scenario)
except EntityConflictError as exc:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
return _to_read_model(created)
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
def get_scenario(
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
) -> ScenarioRead:
try:
scenario = uow.scenarios.get(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return _to_read_model(scenario)
@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead)
def update_scenario(
scenario_id: int,
payload: ScenarioUpdate,
uow: UnitOfWork = Depends(get_unit_of_work),
) -> ScenarioRead:
try:
scenario = uow.scenarios.get(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(scenario, field, value)
uow.flush()
return _to_read_model(scenario)
@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_scenario(
scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work)
) -> None:
try:
uow.scenarios.delete(scenario_id)
except EntityNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc

37
schemas/project.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, ConfigDict
from models import MiningOperationType
class ProjectBase(BaseModel):
name: str
location: str | None = None
operation_type: MiningOperationType
description: str | None = None
model_config = ConfigDict(extra="forbid")
class ProjectCreate(ProjectBase):
pass
class ProjectUpdate(BaseModel):
name: str | None = None
location: str | None = None
operation_type: MiningOperationType | None = None
description: str | None = None
model_config = ConfigDict(extra="forbid")
class ProjectRead(ProjectBase):
id: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)

66
schemas/scenario.py Normal file
View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from datetime import date, datetime
from pydantic import BaseModel, ConfigDict, field_validator
from models import ResourceType, ScenarioStatus
class ScenarioBase(BaseModel):
name: str
description: str | None = None
status: ScenarioStatus = ScenarioStatus.DRAFT
start_date: date | None = None
end_date: date | None = None
discount_rate: float | None = None
currency: str | None = None
primary_resource: ResourceType | None = None
model_config = ConfigDict(extra="forbid")
@field_validator("currency")
@classmethod
def normalise_currency(cls, value: str | None) -> str | None:
if value is None:
return value
value = value.upper()
if len(value) != 3:
raise ValueError("Currency code must be a 3-letter ISO value")
return value
class ScenarioCreate(ScenarioBase):
pass
class ScenarioUpdate(BaseModel):
name: str | None = None
description: str | None = None
status: ScenarioStatus | None = None
start_date: date | None = None
end_date: date | None = None
discount_rate: float | None = None
currency: str | None = None
primary_resource: ResourceType | None = None
model_config = ConfigDict(extra="forbid")
@field_validator("currency")
@classmethod
def normalise_currency(cls, value: str | None) -> str | None:
if value is None:
return value
value = value.upper()
if len(value) != 3:
raise ValueError("Currency code must be a 3-letter ISO value")
return value
class ScenarioRead(ScenarioBase):
id: int
project_id: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)

1
services/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Service layer utilities."""

View File

@@ -35,7 +35,8 @@ class ProjectRepository:
try: try:
self.session.flush() self.session.flush()
except IntegrityError as exc: # pragma: no cover - reliance on DB constraints except IntegrityError as exc: # pragma: no cover - reliance on DB constraints
raise EntityConflictError("Project violates uniqueness constraints") from exc raise EntityConflictError(
"Project violates uniqueness constraints") from exc
return project return project
def delete(self, project_id: int) -> None: def delete(self, project_id: int) -> None:
@@ -64,7 +65,10 @@ class ScenarioRepository:
joinedload(Scenario.financial_inputs), joinedload(Scenario.financial_inputs),
joinedload(Scenario.simulation_parameters), joinedload(Scenario.simulation_parameters),
) )
scenario = self.session.execute(stmt).scalar_one_or_none() result = self.session.execute(stmt)
if with_children:
result = result.unique()
scenario = result.scalar_one_or_none()
if scenario is None: if scenario is None:
raise EntityNotFoundError(f"Scenario {scenario_id} not found") raise EntityNotFoundError(f"Scenario {scenario_id} not found")
return scenario return scenario
@@ -102,7 +106,8 @@ class FinancialInputRepository:
try: try:
self.session.flush() self.session.flush()
except IntegrityError as exc: # pragma: no cover except IntegrityError as exc: # pragma: no cover
raise EntityConflictError("Financial input violates constraints") from exc raise EntityConflictError(
"Financial input violates constraints") from exc
return entities return entities
def delete(self, input_id: int) -> None: def delete(self, input_id: int) -> None:
@@ -135,12 +140,15 @@ class SimulationParameterRepository:
try: try:
self.session.flush() self.session.flush()
except IntegrityError as exc: # pragma: no cover except IntegrityError as exc: # pragma: no cover
raise EntityConflictError("Simulation parameter violates constraints") from exc raise EntityConflictError(
"Simulation parameter violates constraints") from exc
return entities return entities
def delete(self, parameter_id: int) -> None: def delete(self, parameter_id: int) -> None:
stmt = select(SimulationParameter).where(SimulationParameter.id == parameter_id) stmt = select(SimulationParameter).where(
SimulationParameter.id == parameter_id)
entity = self.session.execute(stmt).scalar_one_or_none() entity = self.session.execute(stmt).scalar_one_or_none()
if entity is None: if entity is None:
raise EntityNotFoundError(f"Simulation parameter {parameter_id} not found") raise EntityNotFoundError(
f"Simulation parameter {parameter_id} not found")
self.session.delete(entity) self.session.delete(entity)