feat: Implement SQLAlchemy enum helper and normalize enum values in database initialization
This commit is contained in:
@@ -17,8 +17,7 @@ Notes:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Set
|
||||
import os
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
@@ -191,7 +190,8 @@ TABLE_DDLS = [
|
||||
currency VARCHAR(3),
|
||||
primary_resource resourcetype,
|
||||
created_at TIMESTAMPTZ DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
updated_at TIMESTAMPTZ DEFAULT now(),
|
||||
CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
@@ -206,7 +206,8 @@ TABLE_DDLS = [
|
||||
effective_date DATE,
|
||||
notes TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
updated_at TIMESTAMPTZ DEFAULT now(),
|
||||
CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
@@ -429,6 +430,61 @@ def ensure_enums(engine: Engine) -> None:
|
||||
conn.execute(text(sql))
|
||||
|
||||
|
||||
def _fetch_enum_values(conn, type_name: str) -> Set[str]:
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT e.enumlabel
|
||||
FROM pg_enum e
|
||||
JOIN pg_type t ON t.oid = e.enumtypid
|
||||
WHERE t.typname = :type_name
|
||||
"""
|
||||
),
|
||||
{"type_name": type_name},
|
||||
)
|
||||
return {row.enumlabel for row in rows}
|
||||
|
||||
|
||||
def normalize_enum_values(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for type_name, expected_values in ENUM_DEFINITIONS.items():
|
||||
try:
|
||||
existing_values = _fetch_enum_values(conn, type_name)
|
||||
except Exception as exc: # pragma: no cover - system catalogs missing
|
||||
logger.debug(
|
||||
"Skipping enum normalization for %s due to error: %s",
|
||||
type_name,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
expected_set = set(expected_values)
|
||||
for value in list(existing_values):
|
||||
if value in expected_set:
|
||||
continue
|
||||
|
||||
normalized = value.lower()
|
||||
if (
|
||||
normalized != value
|
||||
and normalized in expected_set
|
||||
and normalized not in existing_values
|
||||
):
|
||||
logger.info(
|
||||
"Renaming enum value %s.%s -> %s",
|
||||
type_name,
|
||||
value,
|
||||
normalized,
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
f"ALTER TYPE {type_name} RENAME VALUE :old_value TO :new_value"
|
||||
),
|
||||
{"old_value": value, "new_value": normalized},
|
||||
)
|
||||
existing_values.remove(value)
|
||||
existing_values.add(normalized)
|
||||
|
||||
|
||||
def ensure_tables(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for ddl in TABLE_DDLS:
|
||||
@@ -436,6 +492,45 @@ def ensure_tables(engine: Engine) -> None:
|
||||
conn.execute(text(ddl))
|
||||
|
||||
|
||||
CONSTRAINT_DDLS = [
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_constraint
|
||||
WHERE conname = 'uq_scenarios_project_name'
|
||||
) THEN
|
||||
ALTER TABLE scenarios
|
||||
ADD CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
""",
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_constraint
|
||||
WHERE conname = 'uq_financial_inputs_scenario_name'
|
||||
) THEN
|
||||
ALTER TABLE financial_inputs
|
||||
ADD CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def ensure_constraints(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for ddl in CONSTRAINT_DDLS:
|
||||
logger.debug("Ensuring constraint via:\n%s", ddl)
|
||||
conn.execute(text(ddl))
|
||||
|
||||
|
||||
def seed_roles(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for r in DEFAULT_ROLES:
|
||||
@@ -657,7 +752,9 @@ def init_db(database_url: Optional[str] = None) -> None:
|
||||
engine = _create_engine(database_url)
|
||||
logger.info("Starting DB initialization using engine=%s", engine)
|
||||
ensure_enums(engine)
|
||||
normalize_enum_values(engine)
|
||||
ensure_tables(engine)
|
||||
ensure_constraints(engine)
|
||||
seed_roles(engine)
|
||||
seed_admin_user(engine)
|
||||
ensure_default_pricing(engine)
|
||||
|
||||
Reference in New Issue
Block a user