refactor: Centralize database session management in a shared dependency module

This commit is contained in:
2025-10-20 22:30:56 +02:00
parent 434be86b76
commit c6233e1a56
15 changed files with 61 additions and 153 deletions

View File

@@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.consumption import Consumption
from routes.dependencies import get_db
router = APIRouter(prefix="/api/consumption", tags=["Consumption"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class ConsumptionBase(BaseModel):
scenario_id: int
amount: PositiveFloat

View File

@@ -1,23 +1,17 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal
from sqlalchemy.orm import Session
from models.capex import Capex
from models.opex import Opex
from routes.dependencies import get_db
router = APIRouter(prefix="/api/costs", tags=["Costs"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Pydantic schemas for Capex
class CapexCreate(BaseModel):
scenario_id: int
amount: float

13
routes/dependencies.py Normal file
View File

@@ -0,0 +1,13 @@
from collections.abc import Generator
from sqlalchemy.orm import Session
from config.database import SessionLocal
def get_db() -> Generator[Session, None, None]:
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -1,25 +1,19 @@
from typing import Dict, List
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal
from sqlalchemy.orm import Session
from models.distribution import Distribution
from routes.dependencies import get_db
router = APIRouter(prefix="/api/distributions", tags=["Distributions"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class DistributionCreate(BaseModel):
name: str
distribution_type: str
parameters: dict
parameters: Dict[str, float | int]
class DistributionRead(DistributionCreate):

View File

@@ -1,22 +1,16 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict
from config.database import SessionLocal
from sqlalchemy.orm import Session
from models.equipment import Equipment
from routes.dependencies import get_db
router = APIRouter(prefix="/api/equipment", tags=["Equipment"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Pydantic schemas
class EquipmentCreate(BaseModel):
scenario_id: int
name: str

View File

@@ -5,21 +5,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.maintenance import Maintenance
from routes.dependencies import get_db
router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class MaintenanceBase(BaseModel):
equipment_id: int
scenario_id: int

View File

@@ -4,10 +4,10 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.distribution import Distribution
from models.parameters import Parameter
from models.scenario import Scenario
from routes.dependencies import get_db
router = APIRouter(prefix="/api/parameters", tags=["parameters"])
@@ -45,17 +45,6 @@ class ParameterRead(ParameterCreate):
id: int
model_config = ConfigDict(from_attributes=True)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@router.post("/", response_model=ParameterRead)
def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first()

View File

@@ -4,21 +4,11 @@ from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, ConfigDict, PositiveFloat
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.production_output import ProductionOutput
from routes.dependencies import get_db
router = APIRouter(prefix="/api/production", tags=["Production"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class ProductionOutputBase(BaseModel):
scenario_id: int
amount: PositiveFloat

View File

@@ -3,21 +3,12 @@ from typing import Any, Dict, List, cast
from fastapi import APIRouter, HTTPException, Request, status
from pydantic import BaseModel
from config.database import SessionLocal
from services.reporting import generate_report
router = APIRouter(prefix="/api/reporting", tags=["Reporting"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
def _validate_payload(payload: Any) -> List[Dict[str, float]]:
if not isinstance(payload, list):
raise HTTPException(

View File

@@ -1,10 +1,12 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.scenario import Scenario
from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from models.scenario import Scenario
from routes.dependencies import get_db
router = APIRouter(prefix="/api/scenarios", tags=["scenarios"])
@@ -22,29 +24,15 @@ class ScenarioRead(ScenarioCreate):
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@router.post("/", response_model=ScenarioRead)
def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)):
print(f"Creating scenario with name: {scenario.name}")
db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first()
if db_s:
print(f"Scenario with name {scenario.name} already exists.")
raise HTTPException(status_code=400, detail="Scenario already exists")
new_s = Scenario(name=scenario.name, description=scenario.description)
db.add(new_s)
db.commit()
db.refresh(new_s)
print(f"Scenario with name {scenario.name} created successfully.")
return new_s

View File

@@ -1,27 +1,19 @@
from typing import List, Optional, cast
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, PositiveInt
from sqlalchemy.orm import Session
from config.database import SessionLocal
from models.parameters import Parameter
from models.scenario import Scenario
from models.simulation_result import SimulationResult
from routes.dependencies import get_db
from services.reporting import generate_report
from services.simulation import run_simulation
router = APIRouter(prefix="/api/simulations", tags=["Simulations"])
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
class SimulationParameterInput(BaseModel):
name: str
value: float
@@ -48,7 +40,7 @@ class SimulationRunResponse(BaseModel):
scenario_id: int
iterations: int
results: List[SimulationResultItem]
summary: dict
summary: Dict[str, float | int]
def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]: