From 61b42b3041d5e2006c481d21468947500a710a5a Mon Sep 17 00:00:00 2001 From: zwitschi Date: Sun, 9 Nov 2025 17:23:10 +0100 Subject: [PATCH] feat: implement CRUD APIs for projects and scenarios with validated schemas --- .../versions/20251109_01_initial_schema.py | 77 ++++++++----- changelog.md | 1 + config/__init__.py | 1 + dependencies.py | 12 ++ main.py | 6 +- models/simulation_parameter.py | 14 ++- pyproject.toml | 3 + routes/__init__.py | 1 + routes/projects.py | 76 +++++++++++++ routes/scenarios.py | 103 ++++++++++++++++++ schemas/project.py | 37 +++++++ schemas/scenario.py | 66 +++++++++++ services/__init__.py | 1 + services/repositories.py | 20 +++- 14 files changed, 380 insertions(+), 38 deletions(-) create mode 100644 config/__init__.py create mode 100644 dependencies.py create mode 100644 routes/__init__.py create mode 100644 routes/projects.py create mode 100644 routes/scenarios.py create mode 100644 schemas/project.py create mode 100644 schemas/scenario.py create mode 100644 services/__init__.py diff --git a/alembic/versions/20251109_01_initial_schema.py b/alembic/versions/20251109_01_initial_schema.py index 020a56b..cf282d5 100644 --- a/alembic/versions/20251109_01_initial_schema.py +++ b/alembic/versions/20251109_01_initial_schema.py @@ -101,8 +101,10 @@ def upgrade() -> None: sa.Column("location", sa.String(length=255), nullable=True), sa.Column("operation_type", mining_operation_type, nullable=False), sa.Column("description", sa.Text(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name"), ) @@ -117,16 +119,21 @@ def upgrade() -> None: sa.Column("status", scenario_status, nullable=False), sa.Column("start_date", sa.Date(), nullable=True), sa.Column("end_date", sa.Date(), nullable=True), - sa.Column("discount_rate", sa.Numeric(precision=5, scale=2), nullable=True), + sa.Column("discount_rate", sa.Numeric( + precision=5, scale=2), nullable=True), sa.Column("currency", sa.String(length=3), nullable=True), sa.Column("primary_resource", resource_type, nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.Column("created_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.ForeignKeyConstraint( + ["project_id"], ["projects.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), ) op.create_index(op.f("ix_scenarios_id"), "scenarios", ["id"], unique=False) - op.create_index(op.f("ix_scenarios_project_id"), "scenarios", ["project_id"], unique=False) + op.create_index(op.f("ix_scenarios_project_id"), + "scenarios", ["project_id"], unique=False) op.create_table( "financial_inputs", @@ -139,13 +146,18 @@ def upgrade() -> None: sa.Column("currency", sa.String(length=3), nullable=True), sa.Column("effective_date", sa.Date(), nullable=True), sa.Column("notes", sa.Text(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.ForeignKeyConstraint(["scenario_id"], ["scenarios.id"], ondelete="CASCADE"), + sa.Column("created_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.ForeignKeyConstraint( + ["scenario_id"], ["scenarios.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_financial_inputs_id"), "financial_inputs", ["id"], unique=False) - op.create_index(op.f("ix_financial_inputs_scenario_id"), "financial_inputs", ["scenario_id"], unique=False) + op.create_index(op.f("ix_financial_inputs_id"), + "financial_inputs", ["id"], unique=False) + op.create_index(op.f("ix_financial_inputs_scenario_id"), + "financial_inputs", ["scenario_id"], unique=False) op.create_table( "simulation_parameters", @@ -155,28 +167,41 @@ def upgrade() -> None: sa.Column("distribution", distribution_type, nullable=False), sa.Column("variable", stochastic_variable, nullable=True), sa.Column("resource_type", resource_type, nullable=True), - sa.Column("mean_value", sa.Numeric(precision=18, scale=4), nullable=True), - sa.Column("standard_deviation", sa.Numeric(precision=18, scale=4), nullable=True), - sa.Column("minimum_value", sa.Numeric(precision=18, scale=4), nullable=True), - sa.Column("maximum_value", sa.Numeric(precision=18, scale=4), nullable=True), + sa.Column("mean_value", sa.Numeric( + precision=18, scale=4), nullable=True), + sa.Column("standard_deviation", sa.Numeric( + precision=18, scale=4), nullable=True), + sa.Column("minimum_value", sa.Numeric( + precision=18, scale=4), nullable=True), + sa.Column("maximum_value", sa.Numeric( + precision=18, scale=4), nullable=True), sa.Column("unit", sa.String(length=32), nullable=True), - sa.Column("metadata", sa.JSON(), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), - sa.ForeignKeyConstraint(["scenario_id"], ["scenarios.id"], ondelete="CASCADE"), + sa.Column("configuration", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.ForeignKeyConstraint( + ["scenario_id"], ["scenarios.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_simulation_parameters_id"), "simulation_parameters", ["id"], unique=False) - op.create_index(op.f("ix_simulation_parameters_scenario_id"), "simulation_parameters", ["scenario_id"], unique=False) + op.create_index(op.f("ix_simulation_parameters_id"), + "simulation_parameters", ["id"], unique=False) + op.create_index(op.f("ix_simulation_parameters_scenario_id"), + "simulation_parameters", ["scenario_id"], unique=False) def downgrade() -> None: - op.drop_index(op.f("ix_simulation_parameters_scenario_id"), table_name="simulation_parameters") - op.drop_index(op.f("ix_simulation_parameters_id"), table_name="simulation_parameters") + op.drop_index(op.f("ix_simulation_parameters_scenario_id"), + table_name="simulation_parameters") + op.drop_index(op.f("ix_simulation_parameters_id"), + table_name="simulation_parameters") op.drop_table("simulation_parameters") - op.drop_index(op.f("ix_financial_inputs_scenario_id"), table_name="financial_inputs") - op.drop_index(op.f("ix_financial_inputs_id"), table_name="financial_inputs") + op.drop_index(op.f("ix_financial_inputs_scenario_id"), + table_name="financial_inputs") + op.drop_index(op.f("ix_financial_inputs_id"), + table_name="financial_inputs") op.drop_table("financial_inputs") op.drop_index(op.f("ix_scenarios_project_id"), table_name="scenarios") diff --git a/changelog.md b/changelog.md index 579efe5..a436c92 100644 --- a/changelog.md +++ b/changelog.md @@ -6,3 +6,4 @@ - Added core SQLAlchemy domain models, shared metadata descriptors, and Alembic migration setup (with initial schema snapshot) to establish the persistence layer foundation. - Introduced repository and unit-of-work helpers for projects, scenarios, financial inputs, and simulation parameters to support service-layer operations. - Added SQLite-backed pytest coverage for repository and unit-of-work behaviours to validate persistence interactions. +- Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application. diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..56096f2 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1 @@ +"""Configuration package.""" diff --git a/dependencies.py b/dependencies.py new file mode 100644 index 0000000..e492586 --- /dev/null +++ b/dependencies.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from collections.abc import Generator + +from services.unit_of_work import UnitOfWork + + +def get_unit_of_work() -> Generator[UnitOfWork, None, None]: + """FastAPI dependency yielding a unit-of-work instance.""" + + with UnitOfWork() as uow: + yield uow diff --git a/main.py b/main.py index e1cbc9a..4fd284f 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,8 @@ from models import ( Scenario, SimulationParameter, ) +from routes.projects import router as projects_router +from routes.scenarios import router as scenarios_router # Initialize database schema (imports above ensure models are registered) Base.metadata.create_all(bind=engine) @@ -29,5 +31,7 @@ async def health() -> dict[str, str]: return {"status": "ok"} -app.mount("/static", StaticFiles(directory="static"), name="static") +app.include_router(projects_router) +app.include_router(scenarios_router) +app.mount("/static", StaticFiles(directory="static"), name="static") diff --git a/models/simulation_parameter.py b/models/simulation_parameter.py index e77b1b2..5b7ac16 100644 --- a/models/simulation_parameter.py +++ b/models/simulation_parameter.py @@ -52,12 +52,16 @@ class SimulationParameter(Base): resource_type: Mapped[ResourceType | None] = mapped_column( SQLEnum(ResourceType), nullable=True ) - mean_value: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) - standard_deviation: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) - minimum_value: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) - maximum_value: Mapped[float | None] = mapped_column(Numeric(18, 4), nullable=True) + mean_value: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) + standard_deviation: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) + minimum_value: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) + maximum_value: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) unit: Mapped[str | None] = mapped_column(String(32), nullable=True) - metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True) + configuration: Mapped[dict | None] = mapped_column(JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False, server_default=func.now() ) diff --git a/pyproject.toml b/pyproject.toml index 35be63b..de07a01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,3 +14,6 @@ exclude = ''' )/ ''' +[tool.pytest.ini_options] +pythonpath = ["."] + diff --git a/routes/__init__.py b/routes/__init__.py new file mode 100644 index 0000000..3f06ec5 --- /dev/null +++ b/routes/__init__.py @@ -0,0 +1 @@ +"""API route registrations.""" \ No newline at end of file diff --git a/routes/projects.py b/routes/projects.py new file mode 100644 index 0000000..a362e1f --- /dev/null +++ b/routes/projects.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, status + +from dependencies import get_unit_of_work +from models import Project +from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate +from services.exceptions import EntityConflictError, EntityNotFoundError +from services.unit_of_work import UnitOfWork + +router = APIRouter(prefix="/projects", tags=["Projects"]) + + +def _to_read_model(project: Project) -> ProjectRead: + return ProjectRead.model_validate(project) + + +@router.get("", response_model=List[ProjectRead]) +def list_projects(uow: UnitOfWork = Depends(get_unit_of_work)) -> List[ProjectRead]: + projects = uow.projects.list() + return [_to_read_model(project) for project in projects] + + +@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED) +def create_project( + payload: ProjectCreate, uow: UnitOfWork = Depends(get_unit_of_work) +) -> ProjectRead: + project = Project(**payload.model_dump()) + try: + created = uow.projects.create(project) + except EntityConflictError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=str(exc) + ) from exc + return _to_read_model(created) + + +@router.get("/{project_id}", response_model=ProjectRead) +def get_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> ProjectRead: + try: + project = uow.projects.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + return _to_read_model(project) + + +@router.put("/{project_id}", response_model=ProjectRead) +def update_project( + project_id: int, + payload: ProjectUpdate, + uow: UnitOfWork = Depends(get_unit_of_work), +) -> ProjectRead: + try: + project = uow.projects.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + + update_data = payload.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(project, field, value) + + uow.flush() + return _to_read_model(project) + + +@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_project(project_id: int, uow: UnitOfWork = Depends(get_unit_of_work)) -> None: + try: + uow.projects.delete(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc diff --git a/routes/scenarios.py b/routes/scenarios.py new file mode 100644 index 0000000..c34d81d --- /dev/null +++ b/routes/scenarios.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, status + +from dependencies import get_unit_of_work +from models import Scenario +from schemas.scenario import ScenarioCreate, ScenarioRead, ScenarioUpdate +from services.exceptions import EntityConflictError, EntityNotFoundError +from services.unit_of_work import UnitOfWork + +router = APIRouter(tags=["Scenarios"]) + + +def _to_read_model(scenario: Scenario) -> ScenarioRead: + return ScenarioRead.model_validate(scenario) + + +@router.get( + "/projects/{project_id}/scenarios", + response_model=List[ScenarioRead], +) +def list_scenarios_for_project( + project_id: int, uow: UnitOfWork = Depends(get_unit_of_work) +) -> List[ScenarioRead]: + try: + uow.projects.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + + scenarios = uow.scenarios.list_for_project(project_id) + return [_to_read_model(scenario) for scenario in scenarios] + + +@router.post( + "/projects/{project_id}/scenarios", + response_model=ScenarioRead, + status_code=status.HTTP_201_CREATED, +) +def create_scenario_for_project( + project_id: int, + payload: ScenarioCreate, + uow: UnitOfWork = Depends(get_unit_of_work), +) -> ScenarioRead: + try: + uow.projects.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + + scenario = Scenario(project_id=project_id, **payload.model_dump()) + + try: + created = uow.scenarios.create(scenario) + except EntityConflictError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + return _to_read_model(created) + + +@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead) +def get_scenario( + scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work) +) -> ScenarioRead: + try: + scenario = uow.scenarios.get(scenario_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + return _to_read_model(scenario) + + +@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead) +def update_scenario( + scenario_id: int, + payload: ScenarioUpdate, + uow: UnitOfWork = Depends(get_unit_of_work), +) -> ScenarioRead: + try: + scenario = uow.scenarios.get(scenario_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + + update_data = payload.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(scenario, field, value) + + uow.flush() + return _to_read_model(scenario) + + +@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_scenario( + scenario_id: int, uow: UnitOfWork = Depends(get_unit_of_work) +) -> None: + try: + uow.scenarios.delete(scenario_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc diff --git a/schemas/project.py b/schemas/project.py new file mode 100644 index 0000000..1b0107d --- /dev/null +++ b/schemas/project.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel, ConfigDict + +from models import MiningOperationType + + +class ProjectBase(BaseModel): + name: str + location: str | None = None + operation_type: MiningOperationType + description: str | None = None + + model_config = ConfigDict(extra="forbid") + + +class ProjectCreate(ProjectBase): + pass + + +class ProjectUpdate(BaseModel): + name: str | None = None + location: str | None = None + operation_type: MiningOperationType | None = None + description: str | None = None + + model_config = ConfigDict(extra="forbid") + + +class ProjectRead(ProjectBase): + id: int + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) diff --git a/schemas/scenario.py b/schemas/scenario.py new file mode 100644 index 0000000..6681cba --- /dev/null +++ b/schemas/scenario.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from datetime import date, datetime + +from pydantic import BaseModel, ConfigDict, field_validator + +from models import ResourceType, ScenarioStatus + + +class ScenarioBase(BaseModel): + name: str + description: str | None = None + status: ScenarioStatus = ScenarioStatus.DRAFT + start_date: date | None = None + end_date: date | None = None + discount_rate: float | None = None + currency: str | None = None + primary_resource: ResourceType | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("currency") + @classmethod + def normalise_currency(cls, value: str | None) -> str | None: + if value is None: + return value + value = value.upper() + if len(value) != 3: + raise ValueError("Currency code must be a 3-letter ISO value") + return value + + +class ScenarioCreate(ScenarioBase): + pass + + +class ScenarioUpdate(BaseModel): + name: str | None = None + description: str | None = None + status: ScenarioStatus | None = None + start_date: date | None = None + end_date: date | None = None + discount_rate: float | None = None + currency: str | None = None + primary_resource: ResourceType | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("currency") + @classmethod + def normalise_currency(cls, value: str | None) -> str | None: + if value is None: + return value + value = value.upper() + if len(value) != 3: + raise ValueError("Currency code must be a 3-letter ISO value") + return value + + +class ScenarioRead(ScenarioBase): + id: int + project_id: int + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..1476ae9 --- /dev/null +++ b/services/__init__.py @@ -0,0 +1 @@ +"""Service layer utilities.""" diff --git a/services/repositories.py b/services/repositories.py index bafff7b..21c10da 100644 --- a/services/repositories.py +++ b/services/repositories.py @@ -35,7 +35,8 @@ class ProjectRepository: try: self.session.flush() except IntegrityError as exc: # pragma: no cover - reliance on DB constraints - raise EntityConflictError("Project violates uniqueness constraints") from exc + raise EntityConflictError( + "Project violates uniqueness constraints") from exc return project def delete(self, project_id: int) -> None: @@ -64,7 +65,10 @@ class ScenarioRepository: joinedload(Scenario.financial_inputs), joinedload(Scenario.simulation_parameters), ) - scenario = self.session.execute(stmt).scalar_one_or_none() + result = self.session.execute(stmt) + if with_children: + result = result.unique() + scenario = result.scalar_one_or_none() if scenario is None: raise EntityNotFoundError(f"Scenario {scenario_id} not found") return scenario @@ -102,7 +106,8 @@ class FinancialInputRepository: try: self.session.flush() except IntegrityError as exc: # pragma: no cover - raise EntityConflictError("Financial input violates constraints") from exc + raise EntityConflictError( + "Financial input violates constraints") from exc return entities def delete(self, input_id: int) -> None: @@ -135,12 +140,15 @@ class SimulationParameterRepository: try: self.session.flush() except IntegrityError as exc: # pragma: no cover - raise EntityConflictError("Simulation parameter violates constraints") from exc + raise EntityConflictError( + "Simulation parameter violates constraints") from exc return entities def delete(self, parameter_id: int) -> None: - stmt = select(SimulationParameter).where(SimulationParameter.id == parameter_id) + stmt = select(SimulationParameter).where( + SimulationParameter.id == parameter_id) entity = self.session.execute(stmt).scalar_one_or_none() if entity is None: - raise EntityNotFoundError(f"Simulation parameter {parameter_id} not found") + raise EntityNotFoundError( + f"Simulation parameter {parameter_id} not found") self.session.delete(entity)