refactor: Centralize database session management in a shared dependency module

This commit is contained in:
2025-10-20 22:30:56 +02:00
parent 434be86b76
commit c6233e1a56
15 changed files with 61 additions and 153 deletions

View File

@@ -29,7 +29,7 @@ The architecture is documented in [docs/architecture.md](docs/architecture.md).
The project is organized into several key directories:
- `models/`: Contains SQLAlchemy models representing database tables.
- `routes/`: Defines FastAPI routes for API endpoints.
- `routes/`: Defines FastAPI routes for API endpoints; shared dependencies like `get_db` live in `routes/dependencies.py`.
- `services/`: Business logic and service layer.
- `components/`: Frontend components (to be defined).
- `config/`: Configuration files and settings.

View File

@@ -10,7 +10,7 @@ The backend leverages SQLAlchemy for ORM mapping to a PostgreSQL database.
## System Components
- **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns.
- **FastAPI backend** (`main.py`, `routes/`): hosts REST endpoints for scenarios, parameters, costs, consumption, production, equipment, maintenance, simulations, and reporting. Each router encapsulates request/response schemas and DB access patterns, leveraging a shared dependency module (`routes/dependencies.get_db`) for SQLAlchemy session management.
- **Service layer** (`services/`): houses business logic. `services/reporting.py` produces statistical summaries, while `services/simulation.py` provides the Monte Carlo integration point.
- **Persistence** (`models/`, `config/database.py`): SQLAlchemy models map to PostgreSQL tables in schema `bricsium_platform`. Relationships connect scenarios to derived domain entities.
- **Presentation** (`templates/`, `components/`): server-rendered views support data entry (scenario and parameter forms) and the dashboard visualization powered by Chart.js.

View File

@@ -151,3 +151,11 @@ Next actionable items:
4. Scaffold Monte Carlo Simulation endpoints (`services/simulation.py`, `routes/simulations.py`, tests).
5. Scaffold Reporting endpoints (`services/reporting.py`, `routes/reporting.py`, front-end Dashboard, tests).
6. Add CI job for tests and coverage.
## UI Template Audit (2025-10-20)
- Existing HTML templates: `ScenarioForm.html`, `ParameterInput.html`, and `Dashboard.html` (reporting summary view).
- Coverage gaps remain for costs, consumption, production, equipment, maintenance, and simulation workflows—no dedicated templates yet.
- Shared layout primitives (navigation/header/footer) are absent; current pages duplicate boilerplate markup.
- Dashboard currently covers reporting metrics but should be wired to a central `/` route once the shared layout lands.
- Next steps align with the updated TODO checklist: introduce a `base.html`, refactor existing templates to extend it, and scaffold placeholder pages for the remaining features.

View File

@@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.consumption import Consumption
from routes.dependencies import get_db
router = APIRouter(prefix="/api/consumption", tags=["Consumption"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class ConsumptionBase(BaseModel):
scenario_id: int
amount: PositiveFloat

View File

@@ -1,23 +1,17 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal
from sqlalchemy.orm import Session
from models.capex import Capex
from models.opex import Opex
from routes.dependencies import get_db
router = APIRouter(prefix="/api/costs", tags=["Costs"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Pydantic schemas for Capex
class CapexCreate(BaseModel):
scenario_id: int
amount: float

13
routes/dependencies.py Normal file
View File

@@ -0,0 +1,13 @@
from collections.abc import Generator
from sqlalchemy.orm import Session
from config.database import SessionLocal
def get_db() -> Generator[Session, None, None]:
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -1,25 +1,19 @@
from typing import Dict, List
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal
from sqlalchemy.orm import Session
from models.distribution import Distribution
from routes.dependencies import get_db
router = APIRouter(prefix="/api/distributions", tags=["Distributions"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class DistributionCreate(BaseModel):
name: str
distribution_type: str
parameters: dict
parameters: Dict[str, float | int]
class DistributionRead(DistributionCreate):

View File

@@ -1,22 +1,16 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal
from sqlalchemy.orm import Session
from models.equipment import Equipment
from routes.dependencies import get_db
router = APIRouter(prefix="/api/equipment", tags=["Equipment"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Pydantic schemas
class EquipmentCreate(BaseModel):
scenario_id: int
name: str

View File

@@ -5,21 +5,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.maintenance import Maintenance
from routes.dependencies import get_db
router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class MaintenanceBase(BaseModel):
equipment_id: int
scenario_id: int

View File

@@ -4,10 +4,10 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.distribution import Distribution
from models.parameters import Parameter
from models.scenario import Scenario
from routes.dependencies import get_db
router = APIRouter(prefix="/api/parameters", tags=["parameters"])
@@ -45,17 +45,6 @@ class ParameterRead(ParameterCreate):
id: int
model_config = ConfigDict(from_attributes=True)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@router.post("/", response_model=ParameterRead)
def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first()

View File

@@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.production_output import ProductionOutput
from routes.dependencies import get_db
router = APIRouter(prefix="/api/production", tags=["Production"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class ProductionOutputBase(BaseModel):
scenario_id: int
amount: PositiveFloat

View File

@@ -3,21 +3,12 @@ from typing import Any, Dict, List, cast
from fastapi import APIRouter, HTTPException, Request, status
from pydantic import BaseModel
from config.database import SessionLocal
from services.reporting import generate_report
router = APIRouter(prefix="/api/reporting", tags=["Reporting"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
def _validate_payload(payload: Any) -> List[Dict[str, float]]:
if not isinstance(payload, list):
raise HTTPException(

View File

@@ -1,10 +1,12 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.scenario import Scenario
from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from models.scenario import Scenario
from routes.dependencies import get_db
router = APIRouter(prefix="/api/scenarios", tags=["scenarios"])
@@ -22,29 +24,15 @@ class ScenarioRead(ScenarioCreate):
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@router.post("/", response_model=ScenarioRead)
def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)):
print(f"Creating scenario with name: {scenario.name}")
db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first()
if db_s:
print(f"Scenario with name {scenario.name} already exists.")
raise HTTPException(status_code=400, detail="Scenario already exists")
new_s = Scenario(name=scenario.name, description=scenario.description)
db.add(new_s)
db.commit()
db.refresh(new_s)
print(f"Scenario with name {scenario.name} created successfully.")
return new_s

View File

@@ -1,27 +1,19 @@
from typing import List, Optional, cast
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, PositiveInt
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.parameters import Parameter
from models.scenario import Scenario
from models.simulation_result import SimulationResult
from routes.dependencies import get_db
from services.reporting import generate_report
from services.simulation import run_simulation
router = APIRouter(prefix="/api/simulations", tags=["Simulations"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class SimulationParameterInput(BaseModel):
name: str
value: float
@@ -48,7 +40,7 @@ class SimulationRunResponse(BaseModel):
scenario_id: int
iterations: int
results: List[SimulationResultItem]
summary: dict
summary: Dict[str, float | int]
def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]:

View File

@@ -57,38 +57,11 @@ def api_client(db_session: Session) -> Generator[TestClient, None, None]:
finally:
pass
# override all routers that use get_db
from routes import (
consumption,
costs,
distributions,
equipment,
maintenance,
parameters,
production,
reporting,
scenarios,
simulations,
)
from routes import dependencies as route_dependencies
overrides = {
consumption.get_db: override_get_db,
costs.get_db: override_get_db,
distributions.get_db: override_get_db,
equipment.get_db: override_get_db,
maintenance.get_db: override_get_db,
parameters.get_db: override_get_db,
production.get_db: override_get_db,
reporting.get_db: override_get_db,
scenarios.get_db: override_get_db,
simulations.get_db: override_get_db,
}
for dependency, override in overrides.items():
app.dependency_overrides[dependency] = override
app.dependency_overrides[route_dependencies.get_db] = override_get_db
with TestClient(app) as client:
yield client
for dependency in overrides:
app.dependency_overrides.pop(dependency, None)
app.dependency_overrides.pop(route_dependencies.get_db, None)