feat: Implement SQLAlchemy enum helper and normalize enum values in database initialization

This commit is contained in:
2025-11-12 18:11:19 +01:00
parent bcdc9e861e
commit 1f892ebdbb
9 changed files with 143 additions and 24 deletions

View File

@@ -17,8 +17,7 @@ Notes:
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Set
import os
import logging
from decimal import Decimal
@@ -191,7 +190,8 @@ TABLE_DDLS = [
currency VARCHAR(3),
primary_resource resourcetype,
created_at TIMESTAMPTZ DEFAULT now(),
updated_at TIMESTAMPTZ DEFAULT now()
updated_at TIMESTAMPTZ DEFAULT now(),
CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name)
);
""",
"""
@@ -206,7 +206,8 @@ TABLE_DDLS = [
effective_date DATE,
notes TEXT,
created_at TIMESTAMPTZ DEFAULT now(),
updated_at TIMESTAMPTZ DEFAULT now()
updated_at TIMESTAMPTZ DEFAULT now(),
CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name)
);
""",
"""
@@ -429,6 +430,61 @@ def ensure_enums(engine: Engine) -> None:
conn.execute(text(sql))
def _fetch_enum_values(conn, type_name: str) -> Set[str]:
rows = conn.execute(
text(
"""
SELECT e.enumlabel
FROM pg_enum e
JOIN pg_type t ON t.oid = e.enumtypid
WHERE t.typname = :type_name
"""
),
{"type_name": type_name},
)
return {row.enumlabel for row in rows}
def normalize_enum_values(engine: Engine) -> None:
with engine.begin() as conn:
for type_name, expected_values in ENUM_DEFINITIONS.items():
try:
existing_values = _fetch_enum_values(conn, type_name)
except Exception as exc: # pragma: no cover - system catalogs missing
logger.debug(
"Skipping enum normalization for %s due to error: %s",
type_name,
exc,
)
continue
expected_set = set(expected_values)
for value in list(existing_values):
if value in expected_set:
continue
normalized = value.lower()
if (
normalized != value
and normalized in expected_set
and normalized not in existing_values
):
logger.info(
"Renaming enum value %s.%s -> %s",
type_name,
value,
normalized,
)
conn.execute(
text(
f"ALTER TYPE {type_name} RENAME VALUE :old_value TO :new_value"
),
{"old_value": value, "new_value": normalized},
)
existing_values.remove(value)
existing_values.add(normalized)
def ensure_tables(engine: Engine) -> None:
with engine.begin() as conn:
for ddl in TABLE_DDLS:
@@ -436,6 +492,45 @@ def ensure_tables(engine: Engine) -> None:
conn.execute(text(ddl))
CONSTRAINT_DDLS = [
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'uq_scenarios_project_name'
) THEN
ALTER TABLE scenarios
ADD CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name);
END IF;
END;
$$;
""",
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'uq_financial_inputs_scenario_name'
) THEN
ALTER TABLE financial_inputs
ADD CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name);
END IF;
END;
$$;
""",
]
def ensure_constraints(engine: Engine) -> None:
with engine.begin() as conn:
for ddl in CONSTRAINT_DDLS:
logger.debug("Ensuring constraint via:\n%s", ddl)
conn.execute(text(ddl))
def seed_roles(engine: Engine) -> None:
with engine.begin() as conn:
for r in DEFAULT_ROLES:
@@ -657,7 +752,9 @@ def init_db(database_url: Optional[str] = None) -> None:
engine = _create_engine(database_url)
logger.info("Starting DB initialization using engine=%s", engine)
ensure_enums(engine)
normalize_enum_values(engine)
ensure_tables(engine)
ensure_constraints(engine)
seed_roles(engine)
seed_admin_user(engine)
ensure_default_pricing(engine)