Files
calminer/config/database.py

84 lines
2.3 KiB
Python

import os
from urllib.parse import quote

from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
# Populate os.environ from a local .env file (python-dotenv) before the
# module-level configuration further down calls _build_database_url(), which
# reads its settings from the environment at import time.
load_dotenv()
def _build_database_url() -> str:
    """Construct the SQLAlchemy database URL from environment variables.

    Resolution order:

    1. ``DATABASE_URL`` -- if set to a non-blank value it is returned as-is
       (surrounding whitespace stripped) for backward compatibility.
    2. ``CALMINER_USE_SQLITE`` -- when ``true``/``1``/``yes``
       (case-insensitive), a SQLite URL is built from ``DATABASE_PATH``
       (default ``./data/calminer.db``); its parent directory is created
       when the path has one.
    3. Otherwise the URL is assembled from the granular variables
       ``DATABASE_DRIVER`` (default ``postgresql``), ``DATABASE_HOST``,
       ``DATABASE_PORT`` (default ``5432``), ``DATABASE_USER``,
       ``DATABASE_PASSWORD`` (optional), ``DATABASE_NAME`` and
       ``DATABASE_SCHEMA`` (default ``public``; PostgreSQL drivers only).

    Returns:
        A URL string suitable for ``sqlalchemy.create_engine``.

    Raises:
        RuntimeError: if granular configuration is used and any of
            ``DATABASE_HOST``, ``DATABASE_USER`` or ``DATABASE_NAME`` is
            missing or empty.
    """
    legacy_url = os.environ.get("DATABASE_URL", "").strip()
    if legacy_url:
        return legacy_url

    use_sqlite = os.environ.get("CALMINER_USE_SQLITE", "").lower() in ("true", "1", "yes")
    if use_sqlite:
        db_path = os.environ.get("DATABASE_PATH", "./data/calminer.db")
        # os.makedirs("") raises FileNotFoundError, so only create a parent
        # directory when the path actually has one (not e.g. "calminer.db"
        # or ":memory:").
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        return f"sqlite:///{db_path}"

    driver = os.environ.get("DATABASE_DRIVER", "postgresql")
    host = os.environ.get("DATABASE_HOST")
    port = os.environ.get("DATABASE_PORT", "5432")
    user = os.environ.get("DATABASE_USER")
    password = os.environ.get("DATABASE_PASSWORD")
    database = os.environ.get("DATABASE_NAME")
    schema = os.environ.get("DATABASE_SCHEMA", "public")

    missing = [
        var_name
        for var_name, value in (
            ("DATABASE_HOST", host),
            ("DATABASE_USER", user),
            ("DATABASE_NAME", database),
        )
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing database configuration: set DATABASE_URL or provide "
            f"granular variables ({', '.join(missing)})"
        )

    # Percent-encode credentials so characters such as "@", ":" or "/" in a
    # password are not mistaken for URL delimiters. An unset password is
    # omitted entirely instead of being rendered as the text "None".
    credentials = quote(user, safe="")
    if password is not None:
        credentials += f":{quote(password, safe='')}"

    url = f"{driver}://{credentials}@{host}"
    if port:
        url += f":{port}"
    url += f"/{database}"
    # "-csearch_path=..." is a libpq connection option, so it is only
    # meaningful for PostgreSQL ("postgresql" or "postgresql+<dbapi>").
    if schema and driver.split("+", 1)[0] == "postgresql":
        url += f"?options=-csearch_path={schema}"
    return url
# Resolved once at import time; the engine and session factory below share it.
DATABASE_URL = _build_database_url()

# SQL statement logging keeps its previous default (on) but can now be turned
# off without a code change: echo is enabled only when DATABASE_ECHO is unset
# or set to "true", "1" or "yes" (case-insensitive).
engine = create_engine(
    DATABASE_URL,
    echo=os.environ.get("DATABASE_ECHO", "true").lower() in ("true", "1", "yes"),
    future=True,
)

# expire_on_commit=False keeps attribute values loaded on ORM objects after a
# commit, so objects handed back from a UnitOfWork remain readable for the rest
# of the request cycle instead of triggering DetachedInstanceError once the
# session that loaded them has committed.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)

# Declarative base class; presumably subclassed by the project's ORM models
# defined elsewhere.
Base = declarative_base()
def get_db():
    """Yield one session from ``SessionLocal`` and close it afterwards.

    The session is closed when the generator finishes, is closed, or has an
    exception thrown into it. Presumably consumed as a per-request dependency;
    confirm against callers.
    """
    with SessionLocal() as session:
        yield session