refactor: Centralize database session management in a shared dependency module

This commit is contained in:
2025-10-20 22:30:56 +02:00
parent 434be86b76
commit c6233e1a56
15 changed files with 61 additions and 153 deletions

View File

@@ -29,7 +29,7 @@ The architecture is documented in [docs/architecture.md](docs/architecture.md).
The project is organized into several key directories: The project is organized into several key directories:
- `models/`: Contains SQLAlchemy models representing database tables. - `models/`: Contains SQLAlchemy models representing database tables.
- `routes/`: Defines FastAPI routes for API endpoints. - `routes/`: Defines FastAPI routes for API endpoints; shared dependencies like `get_db` live in `routes/dependencies.py`.
- `services/`: Business logic and service layer. - `services/`: Business logic and service layer.
- `components/`: Frontend components (to be defined). - `components/`: Frontend components (to be defined).
- `config/`: Configuration files and settings. - `config/`: Configuration files and settings.

View File

@@ -10,7 +10,7 @@ The backend leverages SQLAlchemy for ORM mapping to a PostgreSQL database.
## System Components ## System Components
- **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns. - **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns, leveraging a shared dependency module (`routes/dependencies.get_db`) for SQLAlchemy session management.
- **Service layer** (`services/`): houses business logic. `services/reporting.py` produces statistical summaries, while `services/simulation.py` provides the Monte Carlo integration point. - **Service layer** (`services/`): houses business logic. `services/reporting.py` produces statistical summaries, while `services/simulation.py` provides the Monte Carlo integration point.
- **Persistence** (`models/`, `config/database.py`): SQLAlchemy models map to PostgreSQL tables in schema `bricsium_platform`. Relationships connect scenarios to derived domain entities. - **Persistence** (`models/`, `config/database.py`): SQLAlchemy models map to PostgreSQL tables in schema `bricsium_platform`. Relationships connect scenarios to derived domain entities.
- **Presentation** (`templates/`, `components/`): server-rendered views support data entry (scenario and parameter forms) and the dashboard visualization powered by Chart.js. - **Presentation** (`templates/`, `components/`): server-rendered views support data entry (scenario and parameter forms) and the dashboard visualization powered by Chart.js.

View File

@@ -151,3 +151,11 @@ Next actionable items:
4. Scaffold Monte Carlo Simulation endpoints (`services/simulation.py`, `routes/simulations.py`, tests). 4. Scaffold Monte Carlo Simulation endpoints (`services/simulation.py`, `routes/simulations.py`, tests).
5. Scaffold Reporting endpoints (`services/reporting.py`, `routes/reporting.py`, front-end Dashboard, tests). 5. Scaffold Reporting endpoints (`services/reporting.py`, `routes/reporting.py`, front-end Dashboard, tests).
6. Add CI job for tests and coverage. 6. Add CI job for tests and coverage.
## UI Template Audit (2025-10-20)
- Existing HTML templates: `ScenarioForm.html`, `ParameterInput.html`, and `Dashboard.html` (reporting summary view).
- Coverage gaps remain for costs, consumption, production, equipment, maintenance, and simulation workflows—no dedicated templates yet.
- Shared layout primitives (navigation/header/footer) are absent; current pages duplicate boilerplate markup.
- Dashboard currently covers reporting metrics but should be wired to a central `/` route once the shared layout lands.
- Next steps align with the updated TODO checklist: introduce a `base.html`, refactor existing templates to extend it, and scaffold placeholder pages for the remaining features.

View File

@@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, ConfigDict, PositiveFloat from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.consumption import Consumption from models.consumption import Consumption
from routes.dependencies import get_db
router = APIRouter(prefix="/api/consumption", tags=["Consumption"]) router = APIRouter(prefix="/api/consumption", tags=["Consumption"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class ConsumptionBase(BaseModel): class ConsumptionBase(BaseModel):
scenario_id: int scenario_id: int
amount: PositiveFloat amount: PositiveFloat

View File

@@ -1,23 +1,17 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal from sqlalchemy.orm import Session
from models.capex import Capex from models.capex import Capex
from models.opex import Opex from models.opex import Opex
from routes.dependencies import get_db
router = APIRouter(prefix="/api/costs", tags=["Costs"]) router = APIRouter(prefix="/api/costs", tags=["Costs"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Pydantic schemas for Capex # Pydantic schemas for Capex
class CapexCreate(BaseModel): class CapexCreate(BaseModel):
scenario_id: int scenario_id: int
amount: float amount: float

13
routes/dependencies.py Normal file
View File

@@ -0,0 +1,13 @@
from collections.abc import Generator
from sqlalchemy.orm import Session
from config.database import SessionLocal
def get_db() -> Generator[Session, None, None]:
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -1,25 +1,19 @@
from typing import Dict, List
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal from sqlalchemy.orm import Session
from models.distribution import Distribution from models.distribution import Distribution
from routes.dependencies import get_db
router = APIRouter(prefix="/api/distributions", tags=["Distributions"]) router = APIRouter(prefix="/api/distributions", tags=["Distributions"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class DistributionCreate(BaseModel): class DistributionCreate(BaseModel):
name: str name: str
distribution_type: str distribution_type: str
parameters: dict parameters: Dict[str, float | int]
class DistributionRead(DistributionCreate): class DistributionRead(DistributionCreate):

View File

@@ -1,22 +1,16 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal from sqlalchemy.orm import Session
from models.equipment import Equipment from models.equipment import Equipment
from routes.dependencies import get_db
router = APIRouter(prefix="/api/equipment", tags=["Equipment"]) router = APIRouter(prefix="/api/equipment", tags=["Equipment"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Pydantic schemas # Pydantic schemas
class EquipmentCreate(BaseModel): class EquipmentCreate(BaseModel):
scenario_id: int scenario_id: int
name: str name: str

View File

@@ -5,21 +5,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict, PositiveFloat from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.maintenance import Maintenance from models.maintenance import Maintenance
from routes.dependencies import get_db
router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"]) router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class MaintenanceBase(BaseModel): class MaintenanceBase(BaseModel):
equipment_id: int equipment_id: int
scenario_id: int scenario_id: int

View File

@@ -4,10 +4,10 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.distribution import Distribution from models.distribution import Distribution
from models.parameters import Parameter from models.parameters import Parameter
from models.scenario import Scenario from models.scenario import Scenario
from routes.dependencies import get_db
router = APIRouter(prefix="/api/parameters", tags=["parameters"]) router = APIRouter(prefix="/api/parameters", tags=["parameters"])
@@ -45,17 +45,6 @@ class ParameterRead(ParameterCreate):
id: int id: int
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@router.post("/", response_model=ParameterRead) @router.post("/", response_model=ParameterRead)
def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first()

View File

@@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, ConfigDict, PositiveFloat from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.production_output import ProductionOutput from models.production_output import ProductionOutput
from routes.dependencies import get_db
router = APIRouter(prefix="/api/production", tags=["Production"]) router = APIRouter(prefix="/api/production", tags=["Production"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class ProductionOutputBase(BaseModel): class ProductionOutputBase(BaseModel):
scenario_id: int scenario_id: int
amount: PositiveFloat amount: PositiveFloat

View File

@@ -3,21 +3,12 @@ from typing import Any, Dict, List, cast
from fastapi import APIRouter, HTTPException, Request, status from fastapi import APIRouter, HTTPException, Request, status
from pydantic import BaseModel from pydantic import BaseModel
from config.database import SessionLocal
from services.reporting import generate_report from services.reporting import generate_report
router = APIRouter(prefix="/api/reporting", tags=["Reporting"]) router = APIRouter(prefix="/api/reporting", tags=["Reporting"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
def _validate_payload(payload: Any) -> List[Dict[str, float]]: def _validate_payload(payload: Any) -> List[Dict[str, float]]:
if not isinstance(payload, list): if not isinstance(payload, list):
raise HTTPException( raise HTTPException(

View File

@@ -1,10 +1,12 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.scenario import Scenario
from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from models.scenario import Scenario
from routes.dependencies import get_db
router = APIRouter(prefix="/api/scenarios", tags=["scenarios"]) router = APIRouter(prefix="/api/scenarios", tags=["scenarios"])
@@ -22,29 +24,15 @@ class ScenarioRead(ScenarioCreate):
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@router.post("/", response_model=ScenarioRead) @router.post("/", response_model=ScenarioRead)
def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)): def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)):
print(f"Creating scenario with name: {scenario.name}")
db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first() db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first()
if db_s: if db_s:
print(f"Scenario with name {scenario.name} already exists.")
raise HTTPException(status_code=400, detail="Scenario already exists") raise HTTPException(status_code=400, detail="Scenario already exists")
new_s = Scenario(name=scenario.name, description=scenario.description) new_s = Scenario(name=scenario.name, description=scenario.description)
db.add(new_s) db.add(new_s)
db.commit() db.commit()
db.refresh(new_s) db.refresh(new_s)
print(f"Scenario with name {scenario.name} created successfully.")
return new_s return new_s

View File

@@ -1,27 +1,19 @@
from typing import List, Optional, cast from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, PositiveInt from pydantic import BaseModel, PositiveInt
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.parameters import Parameter from models.parameters import Parameter
from models.scenario import Scenario from models.scenario import Scenario
from models.simulation_result import SimulationResult from models.simulation_result import SimulationResult
from routes.dependencies import get_db
from services.reporting import generate_report from services.reporting import generate_report
from services.simulation import run_simulation from services.simulation import run_simulation
router = APIRouter(prefix="/api/simulations", tags=["Simulations"]) router = APIRouter(prefix="/api/simulations", tags=["Simulations"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class SimulationParameterInput(BaseModel): class SimulationParameterInput(BaseModel):
name: str name: str
value: float value: float
@@ -48,7 +40,7 @@ class SimulationRunResponse(BaseModel):
scenario_id: int scenario_id: int
iterations: int iterations: int
results: List[SimulationResultItem] results: List[SimulationResultItem]
summary: dict summary: Dict[str, float | int]
def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]: def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]:

View File

@@ -57,38 +57,11 @@ def api_client(db_session: Session) -> Generator[TestClient, None, None]:
finally: finally:
pass pass
# override all routers that use get_db from routes import dependencies as route_dependencies
from routes import (
consumption,
costs,
distributions,
equipment,
maintenance,
parameters,
production,
reporting,
scenarios,
simulations,
)
overrides = { app.dependency_overrides[route_dependencies.get_db] = override_get_db
consumption.get_db: override_get_db,
costs.get_db: override_get_db,
distributions.get_db: override_get_db,
equipment.get_db: override_get_db,
maintenance.get_db: override_get_db,
parameters.get_db: override_get_db,
production.get_db: override_get_db,
reporting.get_db: override_get_db,
scenarios.get_db: override_get_db,
simulations.get_db: override_get_db,
}
for dependency, override in overrides.items():
app.dependency_overrides[dependency] = override
with TestClient(app) as client: with TestClient(app) as client:
yield client yield client
for dependency in overrides: app.dependency_overrides.pop(route_dependencies.get_db, None)
app.dependency_overrides.pop(dependency, None)