feat: Implement SQLAlchemy enum helper and normalize enum values in database initialization
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
## 2025-11-12
|
## 2025-11-12
|
||||||
|
|
||||||
|
- Completed local run verification: started application with `uvicorn main:app --reload` without errors, verified authenticated routes (/login, /, /projects/ui, /projects) load correctly with seeded data, and summarized findings for deployment pipeline readiness.
|
||||||
- Switched `models/performance_metric.py` to reuse the shared declarative base from `config.database`, clearing the SQLAlchemy 2.0 `declarative_base` deprecation warning and verifying repository tests still pass.
|
- Switched `models/performance_metric.py` to reuse the shared declarative base from `config.database`, clearing the SQLAlchemy 2.0 `declarative_base` deprecation warning and verifying repository tests still pass.
|
||||||
- Replaced the Alembic migration workflow with the idempotent Pydantic-backed initializer (`scripts/init_db.py`), added a guarded reset utility (`scripts/reset_db.py`), removed migration artifacts/tooling (Alembic directory, config, Docker entrypoint), refreshed the container entrypoint to invoke `uvicorn` directly, and updated installation/architecture docs plus the README to direct developers to the new seeding/reset flow.
|
- Replaced the Alembic migration workflow with the idempotent Pydantic-backed initializer (`scripts/init_db.py`), added a guarded reset utility (`scripts/reset_db.py`), removed migration artifacts/tooling (Alembic directory, config, Docker entrypoint), refreshed the container entrypoint to invoke `uvicorn` directly, and updated installation/architecture docs plus the README to direct developers to the new seeding/reset flow.
|
||||||
- Eliminated Bandit hardcoded-secret findings by replacing literal JWT tokens and passwords across auth/security tests with randomized helpers drawn from `tests/utils/security.py`, ensuring fixtures still assert expected behaviours.
|
- Eliminated Bandit hardcoded-secret findings by replacing literal JWT tokens and passwords across auth/security tests with randomized helpers drawn from `tests/utils/security.py`, ensuring fixtures still assert expected behaviours.
|
||||||
|
|||||||
@@ -1,6 +1,21 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from sqlalchemy import Enum as SQLEnum
|
||||||
|
|
||||||
|
|
||||||
|
def sql_enum(enum_cls: Type[Enum], *, name: str) -> SQLEnum:
|
||||||
|
"""Build a SQLAlchemy Enum that maps using the enum member values."""
|
||||||
|
|
||||||
|
return SQLEnum(
|
||||||
|
enum_cls,
|
||||||
|
name=name,
|
||||||
|
create_type=False,
|
||||||
|
validate_strings=True,
|
||||||
|
values_callable=lambda enum_cls: [member.value for member in enum_cls],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MiningOperationType(str, Enum):
|
class MiningOperationType(str, Enum):
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
|
|||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Date,
|
Date,
|
||||||
DateTime,
|
DateTime,
|
||||||
Enum as SQLEnum,
|
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
Numeric,
|
Numeric,
|
||||||
@@ -18,7 +17,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
|
|||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
from config.database import Base
|
from config.database import Base
|
||||||
from .enums import CostBucket, FinancialCategory
|
from .enums import CostBucket, FinancialCategory, sql_enum
|
||||||
from services.currency import normalise_currency
|
from services.currency import normalise_currency
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
@@ -36,10 +35,10 @@ class FinancialInput(Base):
|
|||||||
)
|
)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
category: Mapped[FinancialCategory] = mapped_column(
|
category: Mapped[FinancialCategory] = mapped_column(
|
||||||
SQLEnum(FinancialCategory, name="financialcategory", create_type=False), nullable=False
|
sql_enum(FinancialCategory, name="financialcategory"), nullable=False
|
||||||
)
|
)
|
||||||
cost_bucket: Mapped[CostBucket | None] = mapped_column(
|
cost_bucket: Mapped[CostBucket | None] = mapped_column(
|
||||||
SQLEnum(CostBucket, name="costbucket", create_type=False), nullable=True
|
sql_enum(CostBucket, name="costbucket"), nullable=True
|
||||||
)
|
)
|
||||||
amount: Mapped[float] = mapped_column(Numeric(18, 2), nullable=False)
|
amount: Mapped[float] = mapped_column(Numeric(18, 2), nullable=False)
|
||||||
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
from .enums import MiningOperationType
|
from .enums import MiningOperationType, sql_enum
|
||||||
|
|
||||||
from sqlalchemy import DateTime, Enum as SQLEnum, ForeignKey, Integer, String, Text
|
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
@@ -16,8 +16,6 @@ if TYPE_CHECKING: # pragma: no cover
|
|||||||
from .pricing_settings import PricingSettings
|
from .pricing_settings import PricingSettings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Project(Base):
|
class Project(Base):
|
||||||
"""Top-level mining project grouping multiple scenarios."""
|
"""Top-level mining project grouping multiple scenarios."""
|
||||||
|
|
||||||
@@ -27,7 +25,7 @@ class Project(Base):
|
|||||||
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||||
location: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
location: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
operation_type: Mapped[MiningOperationType] = mapped_column(
|
operation_type: Mapped[MiningOperationType] = mapped_column(
|
||||||
SQLEnum(MiningOperationType, name="miningoperationtype", create_type=False),
|
sql_enum(MiningOperationType, name="miningoperationtype"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=MiningOperationType.OTHER,
|
default=MiningOperationType.OTHER,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List
|
|||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Date,
|
Date,
|
||||||
DateTime,
|
DateTime,
|
||||||
Enum as SQLEnum,
|
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
Numeric,
|
Numeric,
|
||||||
@@ -19,7 +18,7 @@ from sqlalchemy.sql import func
|
|||||||
|
|
||||||
from config.database import Base
|
from config.database import Base
|
||||||
from services.currency import normalise_currency
|
from services.currency import normalise_currency
|
||||||
from .enums import ResourceType, ScenarioStatus
|
from .enums import ResourceType, ScenarioStatus, sql_enum
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from .financial_input import FinancialInput
|
from .financial_input import FinancialInput
|
||||||
@@ -43,7 +42,7 @@ class Scenario(Base):
|
|||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
status: Mapped[ScenarioStatus] = mapped_column(
|
status: Mapped[ScenarioStatus] = mapped_column(
|
||||||
SQLEnum(ScenarioStatus, name="scenariostatus", create_type=False),
|
sql_enum(ScenarioStatus, name="scenariostatus"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=ScenarioStatus.DRAFT,
|
default=ScenarioStatus.DRAFT,
|
||||||
)
|
)
|
||||||
@@ -53,7 +52,7 @@ class Scenario(Base):
|
|||||||
Numeric(5, 2), nullable=True)
|
Numeric(5, 2), nullable=True)
|
||||||
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
||||||
primary_resource: Mapped[ResourceType | None] = mapped_column(
|
primary_resource: Mapped[ResourceType | None] = mapped_column(
|
||||||
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True
|
sql_enum(ResourceType, name="resourcetype"), nullable=True
|
||||||
)
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ from __future__ import annotations
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from .enums import DistributionType, ResourceType, StochasticVariable
|
from .enums import DistributionType, ResourceType, StochasticVariable, sql_enum
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
JSON,
|
JSON,
|
||||||
DateTime,
|
DateTime,
|
||||||
Enum as SQLEnum,
|
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
Numeric,
|
Numeric,
|
||||||
@@ -34,13 +33,13 @@ class SimulationParameter(Base):
|
|||||||
)
|
)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
distribution: Mapped[DistributionType] = mapped_column(
|
distribution: Mapped[DistributionType] = mapped_column(
|
||||||
SQLEnum(DistributionType, name="distributiontype", create_type=False), nullable=False
|
sql_enum(DistributionType, name="distributiontype"), nullable=False
|
||||||
)
|
)
|
||||||
variable: Mapped[StochasticVariable | None] = mapped_column(
|
variable: Mapped[StochasticVariable | None] = mapped_column(
|
||||||
SQLEnum(StochasticVariable, name="stochasticvariable", create_type=False), nullable=True
|
sql_enum(StochasticVariable, name="stochasticvariable"), nullable=True
|
||||||
)
|
)
|
||||||
resource_type: Mapped[ResourceType | None] = mapped_column(
|
resource_type: Mapped[ResourceType | None] = mapped_column(
|
||||||
SQLEnum(ResourceType, name="resourcetype", create_type=False), nullable=True
|
sql_enum(ResourceType, name="resourcetype"), nullable=True
|
||||||
)
|
)
|
||||||
mean_value: Mapped[float | None] = mapped_column(
|
mean_value: Mapped[float | None] = mapped_column(
|
||||||
Numeric(18, 4), nullable=True)
|
Numeric(18, 4), nullable=True)
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ Notes:
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from typing import List, Optional, Set
|
||||||
from typing import List, Optional
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
@@ -191,7 +190,8 @@ TABLE_DDLS = [
|
|||||||
currency VARCHAR(3),
|
currency VARCHAR(3),
|
||||||
primary_resource resourcetype,
|
primary_resource resourcetype,
|
||||||
created_at TIMESTAMPTZ DEFAULT now(),
|
created_at TIMESTAMPTZ DEFAULT now(),
|
||||||
updated_at TIMESTAMPTZ DEFAULT now()
|
updated_at TIMESTAMPTZ DEFAULT now(),
|
||||||
|
CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name)
|
||||||
);
|
);
|
||||||
""",
|
""",
|
||||||
"""
|
"""
|
||||||
@@ -206,7 +206,8 @@ TABLE_DDLS = [
|
|||||||
effective_date DATE,
|
effective_date DATE,
|
||||||
notes TEXT,
|
notes TEXT,
|
||||||
created_at TIMESTAMPTZ DEFAULT now(),
|
created_at TIMESTAMPTZ DEFAULT now(),
|
||||||
updated_at TIMESTAMPTZ DEFAULT now()
|
updated_at TIMESTAMPTZ DEFAULT now(),
|
||||||
|
CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name)
|
||||||
);
|
);
|
||||||
""",
|
""",
|
||||||
"""
|
"""
|
||||||
@@ -429,6 +430,61 @@ def ensure_enums(engine: Engine) -> None:
|
|||||||
conn.execute(text(sql))
|
conn.execute(text(sql))
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_enum_values(conn, type_name: str) -> Set[str]:
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT e.enumlabel
|
||||||
|
FROM pg_enum e
|
||||||
|
JOIN pg_type t ON t.oid = e.enumtypid
|
||||||
|
WHERE t.typname = :type_name
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"type_name": type_name},
|
||||||
|
)
|
||||||
|
return {row.enumlabel for row in rows}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_enum_values(engine: Engine) -> None:
|
||||||
|
with engine.begin() as conn:
|
||||||
|
for type_name, expected_values in ENUM_DEFINITIONS.items():
|
||||||
|
try:
|
||||||
|
existing_values = _fetch_enum_values(conn, type_name)
|
||||||
|
except Exception as exc: # pragma: no cover - system catalogs missing
|
||||||
|
logger.debug(
|
||||||
|
"Skipping enum normalization for %s due to error: %s",
|
||||||
|
type_name,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
expected_set = set(expected_values)
|
||||||
|
for value in list(existing_values):
|
||||||
|
if value in expected_set:
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized = value.lower()
|
||||||
|
if (
|
||||||
|
normalized != value
|
||||||
|
and normalized in expected_set
|
||||||
|
and normalized not in existing_values
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Renaming enum value %s.%s -> %s",
|
||||||
|
type_name,
|
||||||
|
value,
|
||||||
|
normalized,
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
f"ALTER TYPE {type_name} RENAME VALUE :old_value TO :new_value"
|
||||||
|
),
|
||||||
|
{"old_value": value, "new_value": normalized},
|
||||||
|
)
|
||||||
|
existing_values.remove(value)
|
||||||
|
existing_values.add(normalized)
|
||||||
|
|
||||||
|
|
||||||
def ensure_tables(engine: Engine) -> None:
|
def ensure_tables(engine: Engine) -> None:
|
||||||
with engine.begin() as conn:
|
with engine.begin() as conn:
|
||||||
for ddl in TABLE_DDLS:
|
for ddl in TABLE_DDLS:
|
||||||
@@ -436,6 +492,45 @@ def ensure_tables(engine: Engine) -> None:
|
|||||||
conn.execute(text(ddl))
|
conn.execute(text(ddl))
|
||||||
|
|
||||||
|
|
||||||
|
CONSTRAINT_DDLS = [
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_constraint
|
||||||
|
WHERE conname = 'uq_scenarios_project_name'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE scenarios
|
||||||
|
ADD CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name);
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_constraint
|
||||||
|
WHERE conname = 'uq_financial_inputs_scenario_name'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE financial_inputs
|
||||||
|
ADD CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name);
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_constraints(engine: Engine) -> None:
|
||||||
|
with engine.begin() as conn:
|
||||||
|
for ddl in CONSTRAINT_DDLS:
|
||||||
|
logger.debug("Ensuring constraint via:\n%s", ddl)
|
||||||
|
conn.execute(text(ddl))
|
||||||
|
|
||||||
|
|
||||||
def seed_roles(engine: Engine) -> None:
|
def seed_roles(engine: Engine) -> None:
|
||||||
with engine.begin() as conn:
|
with engine.begin() as conn:
|
||||||
for r in DEFAULT_ROLES:
|
for r in DEFAULT_ROLES:
|
||||||
@@ -657,7 +752,9 @@ def init_db(database_url: Optional[str] = None) -> None:
|
|||||||
engine = _create_engine(database_url)
|
engine = _create_engine(database_url)
|
||||||
logger.info("Starting DB initialization using engine=%s", engine)
|
logger.info("Starting DB initialization using engine=%s", engine)
|
||||||
ensure_enums(engine)
|
ensure_enums(engine)
|
||||||
|
normalize_enum_values(engine)
|
||||||
ensure_tables(engine)
|
ensure_tables(engine)
|
||||||
|
ensure_constraints(engine)
|
||||||
seed_roles(engine)
|
seed_roles(engine)
|
||||||
seed_admin_user(engine)
|
seed_admin_user(engine)
|
||||||
ensure_default_pricing(engine)
|
ensure_default_pricing(engine)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class FakeConnection:
|
|||||||
sql = str(statement).strip()
|
sql = str(statement).strip()
|
||||||
lower_sql = sql.lower()
|
lower_sql = sql.lower()
|
||||||
|
|
||||||
if lower_sql.startswith("do $$ begin"):
|
if lower_sql.startswith("do $$"):
|
||||||
match = re.search(r"create type\s+(\w+)\s+as enum", lower_sql)
|
match = re.search(r"create type\s+(\w+)\s+as enum", lower_sql)
|
||||||
if match:
|
if match:
|
||||||
self.state.enums.add(match.group(1))
|
self.state.enums.add(match.group(1))
|
||||||
@@ -194,6 +194,18 @@ class FakeConnection:
|
|||||||
self.state.financial_inputs[key] = record
|
self.state.financial_inputs[key] = record
|
||||||
return FakeResult([])
|
return FakeResult([])
|
||||||
|
|
||||||
|
if "from pg_enum" in lower_sql and "enumlabel" in lower_sql:
|
||||||
|
type_name_param = params.get("type_name")
|
||||||
|
if type_name_param is None:
|
||||||
|
return FakeResult([])
|
||||||
|
type_name = str(type_name_param)
|
||||||
|
values = init_db.ENUM_DEFINITIONS.get(type_name, [])
|
||||||
|
rows = [SimpleNamespace(enumlabel=value) for value in values]
|
||||||
|
return FakeResult(rows)
|
||||||
|
|
||||||
|
if lower_sql.startswith("alter type") and "rename value" in lower_sql:
|
||||||
|
return FakeResult([])
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unhandled SQL during test execution: {sql}")
|
f"Unhandled SQL during test execution: {sql}")
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from models import Role, User, UserRole
|
|||||||
from dependencies import get_auth_session, require_current_user
|
from dependencies import get_auth_session, require_current_user
|
||||||
from services.security import hash_password
|
from services.security import hash_password
|
||||||
from services.session import AuthSession, SessionTokens
|
from services.session import AuthSession, SessionTokens
|
||||||
from tests.conftest import app
|
|
||||||
from tests.utils.security import random_password, random_token
|
from tests.utils.security import random_password, random_token
|
||||||
|
|
||||||
COOKIE_SOURCE = "cookie"
|
COOKIE_SOURCE = "cookie"
|
||||||
|
|||||||
Reference in New Issue
Block a user