feat: Implement SQLAlchemy enum helper and normalize enum values in database initialization

This commit is contained in:
2025-11-12 18:11:19 +01:00
parent bcdc9e861e
commit 1f892ebdbb
9 changed files with 143 additions and 24 deletions

View File

@@ -2,6 +2,7 @@
## 2025-11-12
- Completed local run verification: started application with `uvicorn main:app --reload` without errors, verified authenticated routes (/login, /, /projects/ui, /projects) load correctly with seeded data, and summarized findings for deployment pipeline readiness.
- Switched `models/performance_metric.py` to reuse the shared declarative base from `config.database`, clearing the SQLAlchemy 2.0 `declarative_base` deprecation warning and verifying repository tests still pass.
- Replaced the Alembic migration workflow with the idempotent Pydantic-backed initializer (`scripts/init_db.py`), added a guarded reset utility (`scripts/reset_db.py`), removed migration artifacts/tooling (Alembic directory, config, Docker entrypoint), refreshed the container entrypoint to invoke `uvicorn` directly, and updated installation/architecture docs plus the README to direct developers to the new seeding/reset flow.
- Eliminated Bandit hardcoded-secret findings by replacing literal JWT tokens and passwords across auth/security tests with randomized helpers drawn from `tests/utils/security.py`, ensuring fixtures still assert expected behaviours.

View File

@@ -1,6 +1,21 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from typing import Type
from sqlalchemy import Enum as SQLEnum
def sql_enum(enum_cls: Type[Enum], *, name: str) -> SQLEnum:
    """Return a SQLAlchemy ``Enum`` type whose labels are the member values.

    Args:
        enum_cls: Python ``Enum`` subclass backing the column.
        name: Name of the database enum type.

    Returns:
        A SQLAlchemy ``Enum`` built with ``create_type=False`` and
        ``validate_strings=True``; its labels come from each member's
        ``.value``.
    """

    def _member_values(cls):
        # Passed as ``values_callable``: maps the enum class to its labels.
        return [member.value for member in cls]

    options = {
        "name": name,
        "create_type": False,
        "validate_strings": True,
        "values_callable": _member_values,
    }
    return SQLEnum(enum_cls, **options)
class MiningOperationType(str, Enum): class MiningOperationType(str, Enum):

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
from sqlalchemy import ( from sqlalchemy import (
Date, Date,
DateTime, DateTime,
Enum as SQLEnum,
ForeignKey, ForeignKey,
Integer, Integer,
Numeric, Numeric,
@@ -18,7 +17,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
from sqlalchemy.sql import func from sqlalchemy.sql import func
from config.database import Base from config.database import Base
from .enums import CostBucket, FinancialCategory from .enums import CostBucket, FinancialCategory, sql_enum
from services.currency import normalise_currency from services.currency import normalise_currency
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
@@ -36,10 +35,10 @@ class FinancialInput(Base):
) )
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
category: Mapped[FinancialCategory] = mapped_column( category: Mapped[FinancialCategory] = mapped_column(
SQLEnum(FinancialCategory, name="financialcategory", create_type=False), nullable=False sql_enum(FinancialCategory, name="financialcategory"), nullable=False
) )
cost_bucket: Mapped[CostBucket | None] = mapped_column( cost_bucket: Mapped[CostBucket | None] = mapped_column(
SQLEnum(CostBucket, name="costbucket", create_type=False), nullable=True sql_enum(CostBucket, name="costbucket"), nullable=True
) )
amount: Mapped[float] = mapped_column(Numeric(18, 2), nullable=False) amount: Mapped[float] = mapped_column(Numeric(18, 2), nullable=False)
currency: Mapped[str | None] = mapped_column(String(3), nullable=True) currency: Mapped[str | None] = mapped_column(String(3), nullable=True)

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
from .enums import MiningOperationType from .enums import MiningOperationType, sql_enum
from sqlalchemy import DateTime, Enum as SQLEnum, ForeignKey, Integer, String, Text from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func from sqlalchemy.sql import func
@@ -16,8 +16,6 @@ if TYPE_CHECKING: # pragma: no cover
from .pricing_settings import PricingSettings from .pricing_settings import PricingSettings
class Project(Base): class Project(Base):
"""Top-level mining project grouping multiple scenarios.""" """Top-level mining project grouping multiple scenarios."""
@@ -27,7 +25,7 @@ class Project(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
location: Mapped[str | None] = mapped_column(String(255), nullable=True) location: Mapped[str | None] = mapped_column(String(255), nullable=True)
operation_type: Mapped[MiningOperationType] = mapped_column( operation_type: Mapped[MiningOperationType] = mapped_column(
SQLEnum(MiningOperationType, name="miningoperationtype", create_type=False), sql_enum(MiningOperationType, name="miningoperationtype"),
nullable=False, nullable=False,
default=MiningOperationType.OTHER, default=MiningOperationType.OTHER,
) )

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List
from sqlalchemy import ( from sqlalchemy import (
Date, Date,
DateTime, DateTime,
Enum as SQLEnum,
ForeignKey, ForeignKey,
Integer, Integer,
Numeric, Numeric,
@@ -19,7 +18,7 @@ from sqlalchemy.sql import func
from config.database import Base from config.database import Base
from services.currency import normalise_currency from services.currency import normalise_currency
from .enums import ResourceType, ScenarioStatus from .enums import ResourceType, ScenarioStatus, sql_enum
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from .financial_input import FinancialInput from .financial_input import FinancialInput
@@ -43,7 +42,7 @@ class Scenario(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True) description: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[ScenarioStatus] = mapped_column( status: Mapped[ScenarioStatus] = mapped_column(
SQLEnum(ScenarioStatus, name="scenariostatus", create_type=False), sql_enum(ScenarioStatus, name="scenariostatus"),
nullable=False, nullable=False,
default=ScenarioStatus.DRAFT, default=ScenarioStatus.DRAFT,
) )
@@ -53,7 +52,7 @@ class Scenario(Base):
Numeric(5, 2), nullable=True) Numeric(5, 2), nullable=True)
currency: Mapped[str | None] = mapped_column(String(3), nullable=True) currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
primary_resource: Mapped[ResourceType | None] = mapped_column( primary_resource: Mapped[ResourceType | None] = mapped_column(
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True sql_enum(ResourceType, name="resourcetype"), nullable=True
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()

View File

@@ -3,12 +3,11 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from .enums import DistributionType, ResourceType, StochasticVariable from .enums import DistributionType, ResourceType, StochasticVariable, sql_enum
from sqlalchemy import ( from sqlalchemy import (
JSON, JSON,
DateTime, DateTime,
Enum as SQLEnum,
ForeignKey, ForeignKey,
Integer, Integer,
Numeric, Numeric,
@@ -34,13 +33,13 @@ class SimulationParameter(Base):
) )
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
distribution: Mapped[DistributionType] = mapped_column( distribution: Mapped[DistributionType] = mapped_column(
SQLEnum(DistributionType, name="distributiontype", create_type=False), nullable=False sql_enum(DistributionType, name="distributiontype"), nullable=False
) )
variable: Mapped[StochasticVariable | None] = mapped_column( variable: Mapped[StochasticVariable | None] = mapped_column(
SQLEnum(StochasticVariable, name="stochasticvariable", create_type=False), nullable=True sql_enum(StochasticVariable, name="stochasticvariable"), nullable=True
) )
resource_type: Mapped[ResourceType | None] = mapped_column( resource_type: Mapped[ResourceType | None] = mapped_column(
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True sql_enum(ResourceType, name="resourcetype"), nullable=True
) )
mean_value: Mapped[float | None] = mapped_column( mean_value: Mapped[float | None] = mapped_column(
Numeric(18, 4), nullable=True) Numeric(18, 4), nullable=True)

View File

@@ -17,8 +17,7 @@ Notes:
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from typing import List, Optional, Set
from typing import List, Optional
import os import os
import logging import logging
from decimal import Decimal from decimal import Decimal
@@ -191,7 +190,8 @@ TABLE_DDLS = [
currency VARCHAR(3), currency VARCHAR(3),
primary_resource resourcetype, primary_resource resourcetype,
created_at TIMESTAMPTZ DEFAULT now(), created_at TIMESTAMPTZ DEFAULT now(),
updated_at TIMESTAMPTZ DEFAULT now() updated_at TIMESTAMPTZ DEFAULT now(),
CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name)
); );
""", """,
""" """
@@ -206,7 +206,8 @@ TABLE_DDLS = [
effective_date DATE, effective_date DATE,
notes TEXT, notes TEXT,
created_at TIMESTAMPTZ DEFAULT now(), created_at TIMESTAMPTZ DEFAULT now(),
updated_at TIMESTAMPTZ DEFAULT now() updated_at TIMESTAMPTZ DEFAULT now(),
CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name)
); );
""", """,
""" """
@@ -429,6 +430,61 @@ def ensure_enums(engine: Engine) -> None:
conn.execute(text(sql)) conn.execute(text(sql))
def _fetch_enum_values(conn, type_name: str) -> Set[str]:
    """Return the labels currently registered for the enum type ``type_name``.

    Joins ``pg_enum`` to ``pg_type`` and collects ``enumlabel`` for every row
    whose ``typname`` equals ``type_name``.

    Args:
        conn: Connection-like object exposing ``execute(statement, params)``.
        type_name: Name of the enum type to look up.

    Returns:
        The set of labels found; empty when the query yields no rows.
    """
    query = text(
        """
SELECT e.enumlabel
FROM pg_enum e
JOIN pg_type t ON t.oid = e.enumtypid
WHERE t.typname = :type_name
"""
    )
    result = conn.execute(query, {"type_name": type_name})
    labels: Set[str] = set()
    for record in result:
        labels.add(record.enumlabel)
    return labels
def _quote_enum_label(label: str) -> str:
    """Render ``label`` as a SQL string literal that is safe inside ``text()``."""
    # Single quotes are doubled for SQL; colons are backslash-escaped so that
    # SQLAlchemy's text() does not read them as bind-parameter markers.
    escaped = label.replace("'", "''").replace(":", "\\:")
    return f"'{escaped}'"


def normalize_enum_values(engine: Engine) -> None:
    """Rename wrongly-cased enum labels in the database to their expected form.

    For every type in ``ENUM_DEFINITIONS`` the existing labels are fetched.
    A label is renamed to its lower-cased form when it is not an expected
    value itself, the lower-cased form is expected, and that form is not
    already present on the type. Everything else is left untouched.

    Args:
        engine: Engine used to open the transaction the renames run in.
    """
    with engine.begin() as conn:
        for type_name, expected_values in ENUM_DEFINITIONS.items():
            try:
                existing_values = _fetch_enum_values(conn, type_name)
            except Exception as exc:  # pragma: no cover - system catalogs missing
                # NOTE(review): on PostgreSQL a failed statement aborts the
                # enclosing transaction, so later statements issued on this
                # connection may fail too - consider a savepoint here.
                logger.debug(
                    "Skipping enum normalization for %s due to error: %s",
                    type_name,
                    exc,
                )
                continue

            expected_set = set(expected_values)
            # sorted() takes a snapshot (the set is mutated below) and makes
            # the outcome deterministic when several case variants coexist.
            for value in sorted(existing_values):
                if value in expected_set:
                    continue
                normalized = value.lower()
                if (
                    normalized == value
                    or normalized not in expected_set
                    or normalized in existing_values
                ):
                    continue
                logger.info(
                    "Renaming enum value %s.%s -> %s",
                    type_name,
                    value,
                    normalized,
                )
                # ALTER TYPE is a utility statement: PostgreSQL requires string
                # constants here and rejects server-side bind parameters, so
                # the labels are inlined as quoted literals. ``type_name``
                # comes from ENUM_DEFINITIONS, not from external input.
                conn.execute(
                    text(
                        f"ALTER TYPE {type_name} RENAME VALUE "
                        f"{_quote_enum_label(value)} TO "
                        f"{_quote_enum_label(normalized)}"
                    )
                )
                existing_values.remove(value)
                existing_values.add(normalized)
def ensure_tables(engine: Engine) -> None: def ensure_tables(engine: Engine) -> None:
with engine.begin() as conn: with engine.begin() as conn:
for ddl in TABLE_DDLS: for ddl in TABLE_DDLS:
@@ -436,6 +492,45 @@ def ensure_tables(engine: Engine) -> None:
conn.execute(text(ddl)) conn.execute(text(ddl))
# Idempotent DDL statements, one per named unique constraint. Each DO block
# adds its constraint only when no pg_constraint row with that name exists.
# NOTE(review): the existence check matches on conname alone (not on table or
# schema) - confirm no other relation can carry the same constraint name.
CONSTRAINT_DDLS = [
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'uq_scenarios_project_name'
) THEN
ALTER TABLE scenarios
ADD CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name);
END IF;
END;
$$;
""",
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'uq_financial_inputs_scenario_name'
) THEN
ALTER TABLE financial_inputs
ADD CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name);
END IF;
END;
$$;
""",
]
def ensure_constraints(engine: Engine) -> None:
    """Execute every statement in ``CONSTRAINT_DDLS`` in one ``engine.begin()`` block.

    Args:
        engine: Engine used to open the connection the statements run on.
    """
    with engine.begin() as connection:
        for statement in CONSTRAINT_DDLS:
            logger.debug("Ensuring constraint via:\n%s", statement)
            connection.execute(text(statement))
def seed_roles(engine: Engine) -> None: def seed_roles(engine: Engine) -> None:
with engine.begin() as conn: with engine.begin() as conn:
for r in DEFAULT_ROLES: for r in DEFAULT_ROLES:
@@ -657,7 +752,9 @@ def init_db(database_url: Optional[str] = None) -> None:
engine = _create_engine(database_url) engine = _create_engine(database_url)
logger.info("Starting DB initialization using engine=%s", engine) logger.info("Starting DB initialization using engine=%s", engine)
ensure_enums(engine) ensure_enums(engine)
normalize_enum_values(engine)
ensure_tables(engine) ensure_tables(engine)
ensure_constraints(engine)
seed_roles(engine) seed_roles(engine)
seed_admin_user(engine) seed_admin_user(engine)
ensure_default_pricing(engine) ensure_default_pricing(engine)

View File

@@ -60,7 +60,7 @@ class FakeConnection:
sql = str(statement).strip() sql = str(statement).strip()
lower_sql = sql.lower() lower_sql = sql.lower()
if lower_sql.startswith("do $$ begin"): if lower_sql.startswith("do $$"):
match = re.search(r"create type\s+(\w+)\s+as enum", lower_sql) match = re.search(r"create type\s+(\w+)\s+as enum", lower_sql)
if match: if match:
self.state.enums.add(match.group(1)) self.state.enums.add(match.group(1))
@@ -194,6 +194,18 @@ class FakeConnection:
self.state.financial_inputs[key] = record self.state.financial_inputs[key] = record
return FakeResult([]) return FakeResult([])
if "from pg_enum" in lower_sql and "enumlabel" in lower_sql:
type_name_param = params.get("type_name")
if type_name_param is None:
return FakeResult([])
type_name = str(type_name_param)
values = init_db.ENUM_DEFINITIONS.get(type_name, [])
rows = [SimpleNamespace(enumlabel=value) for value in values]
return FakeResult(rows)
if lower_sql.startswith("alter type") and "rename value" in lower_sql:
return FakeResult([])
raise NotImplementedError( raise NotImplementedError(
f"Unhandled SQL during test execution: {sql}") f"Unhandled SQL during test execution: {sql}")

View File

@@ -14,7 +14,6 @@ from models import Role, User, UserRole
from dependencies import get_auth_session, require_current_user from dependencies import get_auth_session, require_current_user
from services.security import hash_password from services.security import hash_password
from services.session import AuthSession, SessionTokens from services.session import AuthSession, SessionTokens
from tests.conftest import app
from tests.utils.security import random_password, random_token from tests.utils.security import random_password, random_token
COOKIE_SOURCE = "cookie" COOKIE_SOURCE = "cookie"