feat: Implement SQLAlchemy enum helper and normalize enum values in database initialization

This commit is contained in:
2025-11-12 18:11:19 +01:00
parent bcdc9e861e
commit 1f892ebdbb
9 changed files with 143 additions and 24 deletions

View File

@@ -2,6 +2,7 @@
## 2025-11-12
- Completed local run verification: started application with `uvicorn main:app --reload` without errors, verified authenticated routes (/login, /, /projects/ui, /projects) load correctly with seeded data, and summarized findings for deployment pipeline readiness.
- Switched `models/performance_metric.py` to reuse the shared declarative base from `config.database`, clearing the SQLAlchemy 2.0 `declarative_base` deprecation warning and verifying repository tests still pass.
- Replaced the Alembic migration workflow with the idempotent Pydantic-backed initializer (`scripts/init_db.py`), added a guarded reset utility (`scripts/reset_db.py`), removed migration artifacts/tooling (Alembic directory, config, Docker entrypoint), refreshed the container entrypoint to invoke `uvicorn` directly, and updated installation/architecture docs plus the README to direct developers to the new seeding/reset flow.
- Eliminated Bandit hardcoded-secret findings by replacing literal JWT tokens and passwords across auth/security tests with randomized helpers drawn from `tests/utils/security.py`, ensuring fixtures still assert expected behaviours.

View File

@@ -1,6 +1,21 @@
from __future__ import annotations
from enum import Enum
from typing import Type
from sqlalchemy import Enum as SQLEnum
def sql_enum(enum_cls: Type[Enum], *, name: str) -> SQLEnum:
"""Build a SQLAlchemy Enum that maps using the enum member values."""
return SQLEnum(
enum_cls,
name=name,
create_type=False,
validate_strings=True,
values_callable=lambda enum_cls: [member.value for member in enum_cls],
)
class MiningOperationType(str, Enum):

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
from sqlalchemy import (
Date,
DateTime,
Enum as SQLEnum,
ForeignKey,
Integer,
Numeric,
@@ -18,7 +17,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
from sqlalchemy.sql import func
from config.database import Base
from .enums import CostBucket, FinancialCategory
from .enums import CostBucket, FinancialCategory, sql_enum
from services.currency import normalise_currency
if TYPE_CHECKING: # pragma: no cover
@@ -36,10 +35,10 @@ class FinancialInput(Base):
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
category: Mapped[FinancialCategory] = mapped_column(
SQLEnum(FinancialCategory, name="financialcategory", create_type=False), nullable=False
sql_enum(FinancialCategory, name="financialcategory"), nullable=False
)
cost_bucket: Mapped[CostBucket | None] = mapped_column(
SQLEnum(CostBucket, name="costbucket", create_type=False), nullable=True
sql_enum(CostBucket, name="costbucket"), nullable=True
)
amount: Mapped[float] = mapped_column(Numeric(18, 2), nullable=False)
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, List
from .enums import MiningOperationType
from .enums import MiningOperationType, sql_enum
from sqlalchemy import DateTime, Enum as SQLEnum, ForeignKey, Integer, String, Text
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
@@ -16,8 +16,6 @@ if TYPE_CHECKING: # pragma: no cover
from .pricing_settings import PricingSettings
class Project(Base):
"""Top-level mining project grouping multiple scenarios."""
@@ -27,7 +25,7 @@ class Project(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
location: Mapped[str | None] = mapped_column(String(255), nullable=True)
operation_type: Mapped[MiningOperationType] = mapped_column(
SQLEnum(MiningOperationType, name="miningoperationtype", create_type=False),
sql_enum(MiningOperationType, name="miningoperationtype"),
nullable=False,
default=MiningOperationType.OTHER,
)

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List
from sqlalchemy import (
Date,
DateTime,
Enum as SQLEnum,
ForeignKey,
Integer,
Numeric,
@@ -19,7 +18,7 @@ from sqlalchemy.sql import func
from config.database import Base
from services.currency import normalise_currency
from .enums import ResourceType, ScenarioStatus
from .enums import ResourceType, ScenarioStatus, sql_enum
if TYPE_CHECKING: # pragma: no cover
from .financial_input import FinancialInput
@@ -43,7 +42,7 @@ class Scenario(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[ScenarioStatus] = mapped_column(
SQLEnum(ScenarioStatus, name="scenariostatus", create_type=False),
sql_enum(ScenarioStatus, name="scenariostatus"),
nullable=False,
default=ScenarioStatus.DRAFT,
)
@@ -53,7 +52,7 @@ class Scenario(Base):
Numeric(5, 2), nullable=True)
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
primary_resource: Mapped[ResourceType | None] = mapped_column(
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True
sql_enum(ResourceType, name="resourcetype"), nullable=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()

View File

@@ -3,12 +3,11 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from .enums import DistributionType, ResourceType, StochasticVariable
from .enums import DistributionType, ResourceType, StochasticVariable, sql_enum
from sqlalchemy import (
JSON,
DateTime,
Enum as SQLEnum,
ForeignKey,
Integer,
Numeric,
@@ -34,13 +33,13 @@ class SimulationParameter(Base):
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
distribution: Mapped[DistributionType] = mapped_column(
SQLEnum(DistributionType, name="distributiontype", create_type=False), nullable=False
sql_enum(DistributionType, name="distributiontype"), nullable=False
)
variable: Mapped[StochasticVariable | None] = mapped_column(
SQLEnum(StochasticVariable, name="stochasticvariable", create_type=False), nullable=True
sql_enum(StochasticVariable, name="stochasticvariable"), nullable=True
)
resource_type: Mapped[ResourceType | None] = mapped_column(
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True
sql_enum(ResourceType, name="resourcetype"), nullable=True
)
mean_value: Mapped[float | None] = mapped_column(
Numeric(18, 4), nullable=True)

View File

@@ -17,8 +17,7 @@ Notes:
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Set
import os
import logging
from decimal import Decimal
@@ -191,7 +190,8 @@ TABLE_DDLS = [
currency VARCHAR(3),
primary_resource resourcetype,
created_at TIMESTAMPTZ DEFAULT now(),
updated_at TIMESTAMPTZ DEFAULT now()
updated_at TIMESTAMPTZ DEFAULT now(),
CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name)
);
""",
"""
@@ -206,7 +206,8 @@ TABLE_DDLS = [
effective_date DATE,
notes TEXT,
created_at TIMESTAMPTZ DEFAULT now(),
updated_at TIMESTAMPTZ DEFAULT now()
updated_at TIMESTAMPTZ DEFAULT now(),
CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name)
);
""",
"""
@@ -429,6 +430,61 @@ def ensure_enums(engine: Engine) -> None:
conn.execute(text(sql))
def _fetch_enum_values(conn, type_name: str) -> Set[str]:
rows = conn.execute(
text(
"""
SELECT e.enumlabel
FROM pg_enum e
JOIN pg_type t ON t.oid = e.enumtypid
WHERE t.typname = :type_name
"""
),
{"type_name": type_name},
)
return {row.enumlabel for row in rows}
def normalize_enum_values(engine: Engine) -> None:
with engine.begin() as conn:
for type_name, expected_values in ENUM_DEFINITIONS.items():
try:
existing_values = _fetch_enum_values(conn, type_name)
except Exception as exc: # pragma: no cover - system catalogs missing
logger.debug(
"Skipping enum normalization for %s due to error: %s",
type_name,
exc,
)
continue
expected_set = set(expected_values)
for value in list(existing_values):
if value in expected_set:
continue
normalized = value.lower()
if (
normalized != value
and normalized in expected_set
and normalized not in existing_values
):
logger.info(
"Renaming enum value %s.%s -> %s",
type_name,
value,
normalized,
)
conn.execute(
text(
f"ALTER TYPE {type_name} RENAME VALUE :old_value TO :new_value"
),
{"old_value": value, "new_value": normalized},
)
existing_values.remove(value)
existing_values.add(normalized)
def ensure_tables(engine: Engine) -> None:
with engine.begin() as conn:
for ddl in TABLE_DDLS:
@@ -436,6 +492,45 @@ def ensure_tables(engine: Engine) -> None:
conn.execute(text(ddl))
CONSTRAINT_DDLS = [
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'uq_scenarios_project_name'
) THEN
ALTER TABLE scenarios
ADD CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name);
END IF;
END;
$$;
""",
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'uq_financial_inputs_scenario_name'
) THEN
ALTER TABLE financial_inputs
ADD CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name);
END IF;
END;
$$;
""",
]
def ensure_constraints(engine: Engine) -> None:
with engine.begin() as conn:
for ddl in CONSTRAINT_DDLS:
logger.debug("Ensuring constraint via:\n%s", ddl)
conn.execute(text(ddl))
def seed_roles(engine: Engine) -> None:
with engine.begin() as conn:
for r in DEFAULT_ROLES:
@@ -657,7 +752,9 @@ def init_db(database_url: Optional[str] = None) -> None:
engine = _create_engine(database_url)
logger.info("Starting DB initialization using engine=%s", engine)
ensure_enums(engine)
normalize_enum_values(engine)
ensure_tables(engine)
ensure_constraints(engine)
seed_roles(engine)
seed_admin_user(engine)
ensure_default_pricing(engine)

View File

@@ -60,7 +60,7 @@ class FakeConnection:
sql = str(statement).strip()
lower_sql = sql.lower()
if lower_sql.startswith("do $$ begin"):
if lower_sql.startswith("do $$"):
match = re.search(r"create type\s+(\w+)\s+as enum", lower_sql)
if match:
self.state.enums.add(match.group(1))
@@ -194,6 +194,18 @@ class FakeConnection:
self.state.financial_inputs[key] = record
return FakeResult([])
if "from pg_enum" in lower_sql and "enumlabel" in lower_sql:
type_name_param = params.get("type_name")
if type_name_param is None:
return FakeResult([])
type_name = str(type_name_param)
values = init_db.ENUM_DEFINITIONS.get(type_name, [])
rows = [SimpleNamespace(enumlabel=value) for value in values]
return FakeResult(rows)
if lower_sql.startswith("alter type") and "rename value" in lower_sql:
return FakeResult([])
raise NotImplementedError(
f"Unhandled SQL during test execution: {sql}")

View File

@@ -14,7 +14,6 @@ from models import Role, User, UserRole
from dependencies import get_auth_session, require_current_user
from services.security import hash_password
from services.session import AuthSession, SessionTokens
from tests.conftest import app
from tests.utils.security import random_password, random_token
COOKIE_SOURCE = "cookie"