feat: Implement SQLAlchemy enum helper and normalize enum values in database initialization
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
## 2025-11-12
|
||||
|
||||
- Completed local run verification: started application with `uvicorn main:app --reload` without errors, verified authenticated routes (/login, /, /projects/ui, /projects) load correctly with seeded data, and summarized findings for deployment pipeline readiness.
|
||||
- Switched `models/performance_metric.py` to reuse the shared declarative base from `config.database`, clearing the SQLAlchemy 2.0 `declarative_base` deprecation warning and verifying repository tests still pass.
|
||||
- Replaced the Alembic migration workflow with the idempotent Pydantic-backed initializer (`scripts/init_db.py`), added a guarded reset utility (`scripts/reset_db.py`), removed migration artifacts/tooling (Alembic directory, config, Docker entrypoint), refreshed the container entrypoint to invoke `uvicorn` directly, and updated installation/architecture docs plus the README to direct developers to the new seeding/reset flow.
|
||||
- Eliminated Bandit hardcoded-secret findings by replacing literal JWT tokens and passwords across auth/security tests with randomized helpers drawn from `tests/utils/security.py`, ensuring fixtures still assert expected behaviours.
|
||||
|
||||
@@ -1,6 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
|
||||
|
||||
def sql_enum(enum_cls: Type[Enum], *, name: str) -> SQLEnum:
|
||||
"""Build a SQLAlchemy Enum that maps using the enum member values."""
|
||||
|
||||
return SQLEnum(
|
||||
enum_cls,
|
||||
name=name,
|
||||
create_type=False,
|
||||
validate_strings=True,
|
||||
values_callable=lambda enum_cls: [member.value for member in enum_cls],
|
||||
)
|
||||
|
||||
|
||||
class MiningOperationType(str, Enum):
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
@@ -18,7 +17,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from config.database import Base
|
||||
from .enums import CostBucket, FinancialCategory
|
||||
from .enums import CostBucket, FinancialCategory, sql_enum
|
||||
from services.currency import normalise_currency
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
@@ -36,10 +35,10 @@ class FinancialInput(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[FinancialCategory] = mapped_column(
|
||||
SQLEnum(FinancialCategory, name="financialcategory", create_type=False), nullable=False
|
||||
sql_enum(FinancialCategory, name="financialcategory"), nullable=False
|
||||
)
|
||||
cost_bucket: Mapped[CostBucket | None] = mapped_column(
|
||||
SQLEnum(CostBucket, name="costbucket", create_type=False), nullable=True
|
||||
sql_enum(CostBucket, name="costbucket"), nullable=True
|
||||
)
|
||||
amount: Mapped[float] = mapped_column(Numeric(18, 2), nullable=False)
|
||||
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from .enums import MiningOperationType
|
||||
from .enums import MiningOperationType, sql_enum
|
||||
|
||||
from sqlalchemy import DateTime, Enum as SQLEnum, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
@@ -16,8 +16,6 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
from .pricing_settings import PricingSettings
|
||||
|
||||
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""Top-level mining project grouping multiple scenarios."""
|
||||
|
||||
@@ -27,7 +25,7 @@ class Project(Base):
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
location: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
operation_type: Mapped[MiningOperationType] = mapped_column(
|
||||
SQLEnum(MiningOperationType, name="miningoperationtype", create_type=False),
|
||||
sql_enum(MiningOperationType, name="miningoperationtype"),
|
||||
nullable=False,
|
||||
default=MiningOperationType.OTHER,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
@@ -19,7 +18,7 @@ from sqlalchemy.sql import func
|
||||
|
||||
from config.database import Base
|
||||
from services.currency import normalise_currency
|
||||
from .enums import ResourceType, ScenarioStatus
|
||||
from .enums import ResourceType, ScenarioStatus, sql_enum
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .financial_input import FinancialInput
|
||||
@@ -43,7 +42,7 @@ class Scenario(Base):
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[ScenarioStatus] = mapped_column(
|
||||
SQLEnum(ScenarioStatus, name="scenariostatus", create_type=False),
|
||||
sql_enum(ScenarioStatus, name="scenariostatus"),
|
||||
nullable=False,
|
||||
default=ScenarioStatus.DRAFT,
|
||||
)
|
||||
@@ -53,7 +52,7 @@ class Scenario(Base):
|
||||
Numeric(5, 2), nullable=True)
|
||||
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
||||
primary_resource: Mapped[ResourceType | None] = mapped_column(
|
||||
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True
|
||||
sql_enum(ResourceType, name="resourcetype"), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
|
||||
@@ -3,12 +3,11 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .enums import DistributionType, ResourceType, StochasticVariable
|
||||
from .enums import DistributionType, ResourceType, StochasticVariable, sql_enum
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
@@ -34,13 +33,13 @@ class SimulationParameter(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
distribution: Mapped[DistributionType] = mapped_column(
|
||||
SQLEnum(DistributionType, name="distributiontype", create_type=False), nullable=False
|
||||
sql_enum(DistributionType, name="distributiontype"), nullable=False
|
||||
)
|
||||
variable: Mapped[StochasticVariable | None] = mapped_column(
|
||||
SQLEnum(StochasticVariable, name="stochasticvariable", create_type=False), nullable=True
|
||||
sql_enum(StochasticVariable, name="stochasticvariable"), nullable=True
|
||||
)
|
||||
resource_type: Mapped[ResourceType | None] = mapped_column(
|
||||
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True
|
||||
sql_enum(ResourceType, name="resourcetype"), nullable=True
|
||||
)
|
||||
mean_value: Mapped[float | None] = mapped_column(
|
||||
Numeric(18, 4), nullable=True)
|
||||
|
||||
@@ -17,8 +17,7 @@ Notes:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Set
|
||||
import os
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
@@ -191,7 +190,8 @@ TABLE_DDLS = [
|
||||
currency VARCHAR(3),
|
||||
primary_resource resourcetype,
|
||||
created_at TIMESTAMPTZ DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
updated_at TIMESTAMPTZ DEFAULT now(),
|
||||
CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
@@ -206,7 +206,8 @@ TABLE_DDLS = [
|
||||
effective_date DATE,
|
||||
notes TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
updated_at TIMESTAMPTZ DEFAULT now(),
|
||||
CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
@@ -429,6 +430,61 @@ def ensure_enums(engine: Engine) -> None:
|
||||
conn.execute(text(sql))
|
||||
|
||||
|
||||
def _fetch_enum_values(conn, type_name: str) -> Set[str]:
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT e.enumlabel
|
||||
FROM pg_enum e
|
||||
JOIN pg_type t ON t.oid = e.enumtypid
|
||||
WHERE t.typname = :type_name
|
||||
"""
|
||||
),
|
||||
{"type_name": type_name},
|
||||
)
|
||||
return {row.enumlabel for row in rows}
|
||||
|
||||
|
||||
def normalize_enum_values(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for type_name, expected_values in ENUM_DEFINITIONS.items():
|
||||
try:
|
||||
existing_values = _fetch_enum_values(conn, type_name)
|
||||
except Exception as exc: # pragma: no cover - system catalogs missing
|
||||
logger.debug(
|
||||
"Skipping enum normalization for %s due to error: %s",
|
||||
type_name,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
expected_set = set(expected_values)
|
||||
for value in list(existing_values):
|
||||
if value in expected_set:
|
||||
continue
|
||||
|
||||
normalized = value.lower()
|
||||
if (
|
||||
normalized != value
|
||||
and normalized in expected_set
|
||||
and normalized not in existing_values
|
||||
):
|
||||
logger.info(
|
||||
"Renaming enum value %s.%s -> %s",
|
||||
type_name,
|
||||
value,
|
||||
normalized,
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
f"ALTER TYPE {type_name} RENAME VALUE :old_value TO :new_value"
|
||||
),
|
||||
{"old_value": value, "new_value": normalized},
|
||||
)
|
||||
existing_values.remove(value)
|
||||
existing_values.add(normalized)
|
||||
|
||||
|
||||
def ensure_tables(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for ddl in TABLE_DDLS:
|
||||
@@ -436,6 +492,45 @@ def ensure_tables(engine: Engine) -> None:
|
||||
conn.execute(text(ddl))
|
||||
|
||||
|
||||
CONSTRAINT_DDLS = [
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_constraint
|
||||
WHERE conname = 'uq_scenarios_project_name'
|
||||
) THEN
|
||||
ALTER TABLE scenarios
|
||||
ADD CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
""",
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_constraint
|
||||
WHERE conname = 'uq_financial_inputs_scenario_name'
|
||||
) THEN
|
||||
ALTER TABLE financial_inputs
|
||||
ADD CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def ensure_constraints(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for ddl in CONSTRAINT_DDLS:
|
||||
logger.debug("Ensuring constraint via:\n%s", ddl)
|
||||
conn.execute(text(ddl))
|
||||
|
||||
|
||||
def seed_roles(engine: Engine) -> None:
|
||||
with engine.begin() as conn:
|
||||
for r in DEFAULT_ROLES:
|
||||
@@ -657,7 +752,9 @@ def init_db(database_url: Optional[str] = None) -> None:
|
||||
engine = _create_engine(database_url)
|
||||
logger.info("Starting DB initialization using engine=%s", engine)
|
||||
ensure_enums(engine)
|
||||
normalize_enum_values(engine)
|
||||
ensure_tables(engine)
|
||||
ensure_constraints(engine)
|
||||
seed_roles(engine)
|
||||
seed_admin_user(engine)
|
||||
ensure_default_pricing(engine)
|
||||
|
||||
@@ -60,7 +60,7 @@ class FakeConnection:
|
||||
sql = str(statement).strip()
|
||||
lower_sql = sql.lower()
|
||||
|
||||
if lower_sql.startswith("do $$ begin"):
|
||||
if lower_sql.startswith("do $$"):
|
||||
match = re.search(r"create type\s+(\w+)\s+as enum", lower_sql)
|
||||
if match:
|
||||
self.state.enums.add(match.group(1))
|
||||
@@ -194,6 +194,18 @@ class FakeConnection:
|
||||
self.state.financial_inputs[key] = record
|
||||
return FakeResult([])
|
||||
|
||||
if "from pg_enum" in lower_sql and "enumlabel" in lower_sql:
|
||||
type_name_param = params.get("type_name")
|
||||
if type_name_param is None:
|
||||
return FakeResult([])
|
||||
type_name = str(type_name_param)
|
||||
values = init_db.ENUM_DEFINITIONS.get(type_name, [])
|
||||
rows = [SimpleNamespace(enumlabel=value) for value in values]
|
||||
return FakeResult(rows)
|
||||
|
||||
if lower_sql.startswith("alter type") and "rename value" in lower_sql:
|
||||
return FakeResult([])
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Unhandled SQL during test execution: {sql}")
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from models import Role, User, UserRole
|
||||
from dependencies import get_auth_session, require_current_user
|
||||
from services.security import hash_password
|
||||
from services.session import AuthSession, SessionTokens
|
||||
from tests.conftest import app
|
||||
from tests.utils.security import random_password, random_token
|
||||
|
||||
COOKIE_SOURCE = "cookie"
|
||||
|
||||
Reference in New Issue
Block a user