From c6233e1a56cf9786bcbd171c141bc9aa83a956f7 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Mon, 20 Oct 2025 22:30:56 +0200 Subject: [PATCH] refactor: Centralize database session management in a shared dependency module --- README.md | 2 +- docs/architecture.md | 2 +- docs/implementation_plan.md | 8 ++++++++ routes/consumption.py | 12 +----------- routes/costs.py | 20 +++++++------------- routes/dependencies.py | 13 +++++++++++++ routes/distributions.py | 18 ++++++------------ routes/equipment.py | 20 +++++++------------- routes/maintenance.py | 10 +--------- routes/parameters.py | 13 +------------ routes/production.py | 12 +----------- routes/reporting.py | 9 --------- routes/scenarios.py | 28 ++++++++-------------------- routes/simulations.py | 14 +++----------- tests/unit/conftest.py | 33 +++------------------------------ 15 files changed, 61 insertions(+), 153 deletions(-) create mode 100644 routes/dependencies.py diff --git a/README.md b/README.md index 4a8b747..e224a2e 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ The architecture is documented in [docs/architecture.md](docs/architecture.md). The project is organized into several key directories: - `models/`: Contains SQLAlchemy models representing database tables. -- `routes/`: Defines FastAPI routes for API endpoints. +- `routes/`: Defines FastAPI routes for API endpoints; shared dependencies like `get_db` live in `routes/dependencies.py`. - `services/`: Business logic and service layer. - `components/`: Frontend components (to be defined). - `config/`: Configuration files and settings. diff --git a/docs/architecture.md b/docs/architecture.md index 2a5d739..cd1cfe8 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -10,7 +10,7 @@ The backend leverages SQLAlchemy for ORM mapping to a PostgreSQL database. ## System Components -- **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns. +- **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns, leveraging a shared dependency module (`routes/dependencies.get_db`) for SQLAlchemy session management. - **Service layer** (`services/`): houses business logic. `services/reporting.py` produces statistical summaries, while `services/simulation.py` provides the Monte Carlo integration point. - **Persistence** (`models/`, `config/database.py`): SQLAlchemy models map to PostgreSQL tables in schema `bricsium_platform`. Relationships connect scenarios to derived domain entities. - **Presentation** (`templates/`, `components/`): server-rendered views support data entry (scenario and parameter forms) and the dashboard visualization powered by Chart.js. diff --git a/docs/implementation_plan.md b/docs/implementation_plan.md index 7e2f373..17bf40f 100644 --- a/docs/implementation_plan.md +++ b/docs/implementation_plan.md @@ -151,3 +151,11 @@ Next actionable items: 4. Scaffold Monte Carlo Simulation endpoints (`services/simulation.py`, `routes/simulations.py`, tests). 5. Scaffold Reporting endpoints (`services/reporting.py`, `routes/reporting.py`, front-end Dashboard, tests). 6. Add CI job for tests and coverage. + +## UI Template Audit (2025-10-20) + +- Existing HTML templates: `ScenarioForm.html`, `ParameterInput.html`, and `Dashboard.html` (reporting summary view). +- Coverage gaps remain for costs, consumption, production, equipment, maintenance, and simulation workflows—no dedicated templates yet. +- Shared layout primitives (navigation/header/footer) are absent; current pages duplicate boilerplate markup. +- Dashboard currently covers reporting metrics but should be wired to a central `/` route once the shared layout lands. +- Next steps align with the updated TODO checklist: introduce a `base.html`, refactor existing templates to extend it, and scaffold placeholder pages for the remaining features. diff --git a/routes/consumption.py b/routes/consumption.py index f9e3db3..9201124 100644 --- a/routes/consumption.py +++ b/routes/consumption.py @@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status from pydantic import BaseModel, ConfigDict, PositiveFloat from sqlalchemy.orm import Session -from config.database import SessionLocal from models.consumption import Consumption +from routes.dependencies import get_db router = APIRouter(prefix="/api/consumption", tags=["Consumption"]) - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - class ConsumptionBase(BaseModel): scenario_id: int amount: PositiveFloat diff --git a/routes/costs.py b/routes/costs.py index 0c8a226..277dcd8 100644 --- a/routes/costs.py +++ b/routes/costs.py @@ -1,23 +1,17 @@ -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session from typing import List, Optional + +from fastapi import APIRouter, Depends from pydantic import BaseModel, ConfigDict -from config.database import SessionLocal +from sqlalchemy.orm import Session + from models.capex import Capex from models.opex import Opex +from routes.dependencies import get_db router = APIRouter(prefix="/api/costs", tags=["Costs"]) - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - # Pydantic schemas for Capex + + class CapexCreate(BaseModel): scenario_id: int amount: float diff --git a/routes/dependencies.py b/routes/dependencies.py new file mode 100644 index 0000000..0afc871 --- /dev/null +++ b/routes/dependencies.py @@ -0,0 +1,13 @@ +from collections.abc import Generator + +from sqlalchemy.orm import Session + +from config.database import SessionLocal + + +def get_db() -> Generator[Session, None, None]: + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/routes/distributions.py b/routes/distributions.py index 0f60d3e..8c409c3 100644 --- a/routes/distributions.py +++ b/routes/distributions.py @@ -1,25 +1,19 @@ +from typing import Dict, List + from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session -from typing import List from pydantic import BaseModel, ConfigDict -from config.database import SessionLocal +from sqlalchemy.orm import Session + from models.distribution import Distribution +from routes.dependencies import get_db router = APIRouter(prefix="/api/distributions", tags=["Distributions"]) -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - class DistributionCreate(BaseModel): name: str distribution_type: str - parameters: dict + parameters: Dict[str, float | int] class DistributionRead(DistributionCreate): diff --git a/routes/equipment.py b/routes/equipment.py index b9879e1..c8aecbd 100644 --- a/routes/equipment.py +++ b/routes/equipment.py @@ -1,22 +1,16 @@ -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session from typing import List, Optional + +from fastapi import APIRouter, Depends from pydantic import BaseModel, ConfigDict -from config.database import SessionLocal +from sqlalchemy.orm import Session + from models.equipment import Equipment +from routes.dependencies import get_db router = APIRouter(prefix="/api/equipment", tags=["Equipment"]) - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - # Pydantic schemas + + class EquipmentCreate(BaseModel): scenario_id: int name: str diff --git a/routes/maintenance.py b/routes/maintenance.py index 7ed2400..d7f0f49 100644 --- a/routes/maintenance.py +++ b/routes/maintenance.py @@ -5,21 +5,13 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, ConfigDict, PositiveFloat from sqlalchemy.orm import Session -from config.database import SessionLocal from models.maintenance import Maintenance +from routes.dependencies import get_db router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"]) -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - class MaintenanceBase(BaseModel): equipment_id: int scenario_id: int diff --git a/routes/parameters.py b/routes/parameters.py index fb6b7e9..39e67e4 100644 --- a/routes/parameters.py +++ b/routes/parameters.py @@ -4,10 +4,10 @@ from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, ConfigDict, field_validator from sqlalchemy.orm import Session -from config.database import SessionLocal from models.distribution import Distribution from models.parameters import Parameter from models.scenario import Scenario +from routes.dependencies import get_db router = APIRouter(prefix="/api/parameters", tags=["parameters"]) @@ -45,17 +45,6 @@ class ParameterRead(ParameterCreate): id: int model_config = ConfigDict(from_attributes=True) -# Dependency - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - @router.post("/", response_model=ParameterRead) def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() diff --git a/routes/production.py b/routes/production.py index c0684b6..0d1de6a 100644 --- a/routes/production.py +++ b/routes/production.py @@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status from pydantic import BaseModel, ConfigDict, PositiveFloat from sqlalchemy.orm import Session -from config.database import SessionLocal from models.production_output import ProductionOutput +from routes.dependencies import get_db router = APIRouter(prefix="/api/production", tags=["Production"]) - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - class ProductionOutputBase(BaseModel): scenario_id: int amount: PositiveFloat diff --git a/routes/reporting.py b/routes/reporting.py index 3714dcb..09a9417 100644 --- a/routes/reporting.py +++ b/routes/reporting.py @@ -3,21 +3,12 @@ from typing import Any, Dict, List, cast from fastapi import APIRouter, HTTPException, Request, status from pydantic import BaseModel -from config.database import SessionLocal from services.reporting import generate_report router = APIRouter(prefix="/api/reporting", tags=["Reporting"]) -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - def _validate_payload(payload: Any) -> List[Dict[str, float]]: if not isinstance(payload, list): raise HTTPException( diff --git a/routes/scenarios.py b/routes/scenarios.py index 052d122..11dab40 100644 --- a/routes/scenarios.py +++ b/routes/scenarios.py @@ -1,10 +1,12 @@ -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.orm import Session -from config.database import SessionLocal -from models.scenario import Scenario -from pydantic import BaseModel, ConfigDict -from typing import Optional from datetime import datetime +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, ConfigDict +from sqlalchemy.orm import Session + +from models.scenario import Scenario +from routes.dependencies import get_db router = APIRouter(prefix="/api/scenarios", tags=["scenarios"]) @@ -22,29 +24,15 @@ class ScenarioRead(ScenarioCreate): updated_at: Optional[datetime] = None model_config = ConfigDict(from_attributes=True) -# Dependency - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - @router.post("/", response_model=ScenarioRead) def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)): - print(f"Creating scenario with name: {scenario.name}") db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first() if db_s: - print(f"Scenario with name {scenario.name} already exists.") raise HTTPException(status_code=400, detail="Scenario already exists") new_s = Scenario(name=scenario.name, description=scenario.description) db.add(new_s) db.commit() db.refresh(new_s) - print(f"Scenario with name {scenario.name} created successfully.") return new_s diff --git a/routes/simulations.py b/routes/simulations.py index b8c89b9..b00c8c1 100644 --- a/routes/simulations.py +++ b/routes/simulations.py @@ -1,27 +1,19 @@ -from typing import List, Optional, cast +from typing import Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, PositiveInt from sqlalchemy.orm import Session -from config.database import SessionLocal from models.parameters import Parameter from models.scenario import Scenario from models.simulation_result import SimulationResult +from routes.dependencies import get_db from services.reporting import generate_report from services.simulation import run_simulation router = APIRouter(prefix="/api/simulations", tags=["Simulations"]) -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - class SimulationParameterInput(BaseModel): name: str value: float @@ -48,7 +40,7 @@ class SimulationRunResponse(BaseModel): scenario_id: int iterations: int results: List[SimulationResultItem] - summary: dict + summary: Dict[str, float | int] def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5945e48..23ddb18 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -57,38 +57,11 @@ def api_client(db_session: Session) -> Generator[TestClient, None, None]: finally: pass - # override all routers that use get_db - from routes import ( - consumption, - costs, - distributions, - equipment, - maintenance, - parameters, - production, - reporting, - scenarios, - simulations, - ) + from routes import dependencies as route_dependencies - overrides = { - consumption.get_db: override_get_db, - costs.get_db: override_get_db, - distributions.get_db: override_get_db, - equipment.get_db: override_get_db, - maintenance.get_db: override_get_db, - parameters.get_db: override_get_db, - production.get_db: override_get_db, - reporting.get_db: override_get_db, - scenarios.get_db: override_get_db, - simulations.get_db: override_get_db, - } - - for dependency, override in overrides.items(): - app.dependency_overrides[dependency] = override + app.dependency_overrides[route_dependencies.get_db] = override_get_db with TestClient(app) as client: yield client - for dependency in overrides: - app.dependency_overrides.pop(dependency, None) + app.dependency_overrides.pop(route_dependencies.get_db, None)