Files
calminer/services/repositories.py
zwitschi 795a9f99f4 feat: Enhance currency handling and validation across scenarios
- Updated form template to prefill currency input with default value and added help text for clarity.
- Modified integration tests to assert more descriptive error messages for invalid currency codes.
- Introduced new tests for currency normalization and validation in various scenarios, including imports and exports.
- Added comprehensive tests for pricing calculations, ensuring defaults are respected and overrides function correctly.
- Implemented unit tests for pricing settings repository, ensuring CRUD operations and default settings are handled properly.
- Enhanced scenario pricing evaluation tests to validate currency handling and metadata defaults.
- Added simulation tests to ensure Monte Carlo runs are accurate and handle various distribution scenarios.
2025-11-11 18:29:59 +01:00

890 lines
31 KiB
Python

from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Mapping, Sequence
from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload, selectinload
from models import (
FinancialInput,
Project,
PricingImpuritySettings,
PricingMetalSettings,
PricingSettings,
ResourceType,
Role,
Scenario,
ScenarioStatus,
SimulationParameter,
User,
UserRole,
)
from services.exceptions import EntityConflictError, EntityNotFoundError
from services.export_query import ProjectExportFilters, ScenarioExportFilters
from services.pricing import PricingMetadata
class ProjectRepository:
    """Data-access helpers for :class:`Project` rows.

    Wraps an injected SQLAlchemy ``Session``. Writes are flushed, never
    committed, so transaction boundaries stay with the caller.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(
        self,
        *,
        with_children: bool = False,
        with_pricing: bool = False,
    ) -> Sequence[Project]:
        """Return every project ordered by ``created_at``.

        Args:
            with_children: eager-load ``Project.scenarios`` (``selectinload``).
            with_pricing: eager-load ``Project.pricing_settings`` (``selectinload``).
        """
        loaders = []
        if with_children:
            loaders.append(selectinload(Project.scenarios))
        if with_pricing:
            loaders.append(selectinload(Project.pricing_settings))

        query = select(Project).order_by(Project.created_at)
        if loaders:
            query = query.options(*loaders)
        return self.session.execute(query).scalars().all()

    def count(self) -> int:
        """Return how many projects are stored."""
        counter = select(func.count(Project.id))
        return self.session.execute(counter).scalar_one()

    def recent(self, limit: int = 5) -> Sequence[Project]:
        """Return at most ``limit`` projects, most recently updated first."""
        newest_first = Project.updated_at.desc()
        query = select(Project).order_by(newest_first).limit(limit)
        return self.session.execute(query).scalars().all()

    def get(
        self,
        project_id: int,
        *,
        with_children: bool = False,
        with_pricing: bool = False,
    ) -> Project:
        """Fetch a single project by primary key.

        Args:
            project_id: primary key to look up.
            with_children: eager-load scenarios via ``joinedload``.
            with_pricing: eager-load pricing settings via ``joinedload``.

        Raises:
            EntityNotFoundError: when no project has ``project_id``.
        """
        loaders = []
        if with_children:
            loaders.append(joinedload(Project.scenarios))
        if with_pricing:
            loaders.append(joinedload(Project.pricing_settings))

        query = select(Project).where(Project.id == project_id)
        if loaders:
            query = query.options(*loaders)

        rows = self.session.execute(query)
        # Joined eager loading of the scenarios collection repeats the parent
        # row once per scenario, so the result must be de-duplicated.
        if with_children:
            rows = rows.unique()
        project = rows.scalar_one_or_none()
        if project is None:
            raise EntityNotFoundError(f"Project {project_id} not found")
        return project

    def exists(self, project_id: int) -> bool:
        """Return ``True`` when a project with ``project_id`` is stored."""
        probe = select(Project.id).where(Project.id == project_id).limit(1)
        hit = self.session.execute(probe).scalar_one_or_none()
        return hit is not None

    def create(self, project: Project) -> Project:
        """Add ``project`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(project)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - reliance on DB constraints
            raise EntityConflictError(
                "Project violates uniqueness constraints"
            ) from exc
        return project

    def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]:
        """Look projects up by case-insensitive name.

        Empty or whitespace-only names are ignored. The result maps each
        matching project's lower-cased stored name to the project. When no
        usable names remain, an empty dict is returned without querying.
        """
        wanted = set()
        for raw in names:
            cleaned = raw.strip() if raw else ""
            if cleaned:
                wanted.add(cleaned.lower())
        if not wanted:
            return {}

        query = select(Project).where(func.lower(Project.name).in_(wanted))
        by_name = {}
        for project in self.session.execute(query).scalars().all():
            by_name[project.name.lower()] = project
        return by_name

    def filtered_for_export(
        self,
        filters: ProjectExportFilters | None = None,
        *,
        include_scenarios: bool = False,
        include_pricing: bool = False,
    ) -> Sequence[Project]:
        """Return projects matching ``filters``, ordered by name then id.

        All supplied criteria are ANDed together; falsy criteria are skipped.
        ``include_scenarios`` / ``include_pricing`` eager-load the related
        rows via ``selectinload``.
        """
        loaders = []
        if include_scenarios:
            loaders.append(selectinload(Project.scenarios))
        if include_pricing:
            loaders.append(selectinload(Project.pricing_settings))

        criteria = []
        if filters:
            ids = filters.normalised_ids()
            if ids:
                criteria.append(Project.id.in_(ids))
            names = filters.normalised_names()
            if names:
                criteria.append(func.lower(Project.name).in_(names))
            pattern = filters.name_search_pattern()
            if pattern:
                criteria.append(Project.name.ilike(pattern))
            locations = filters.normalised_locations()
            if locations:
                criteria.append(func.lower(Project.location).in_(locations))
            if filters.operation_types:
                criteria.append(
                    Project.operation_type.in_(filters.operation_types)
                )
            # (bound value, column, True for an inclusive lower bound)
            for bound, column, is_lower in (
                (filters.created_from, Project.created_at, True),
                (filters.created_to, Project.created_at, False),
                (filters.updated_from, Project.updated_at, True),
                (filters.updated_to, Project.updated_at, False),
            ):
                if bound:
                    criteria.append(
                        column >= bound if is_lower else column <= bound
                    )

        query = select(Project)
        if loaders:
            query = query.options(*loaders)
        if criteria:
            query = query.where(*criteria)
        query = query.order_by(Project.name, Project.id)
        return self.session.execute(query).scalars().all()

    def delete(self, project_id: int) -> None:
        """Mark the project for deletion.

        Raises:
            EntityNotFoundError: when no project has ``project_id``.
        """
        self.session.delete(self.get(project_id))

    def set_pricing_settings(
        self,
        project: Project,
        pricing_settings: PricingSettings | None,
    ) -> Project:
        """Link (or unlink, with ``None``) pricing settings and flush."""
        project.pricing_settings = pricing_settings
        project.pricing_settings_id = (
            None if pricing_settings is None else pricing_settings.id
        )
        self.session.flush()
        return project
class ScenarioRepository:
    """Persistence operations for Scenario entities.

    Wraps an injected SQLAlchemy ``Session``. Writes are flushed but never
    committed, so the caller owns the transaction.
    """

    def __init__(self, session: Session) -> None:
        self.session = session

    @staticmethod
    def _children_loaders() -> tuple:
        """Return loader options that eager-load a scenario's child collections."""
        # selectinload issues one extra IN-query per collection. Joined eager
        # loading of these two sibling collections in a single SELECT would
        # instead return len(financial_inputs) * len(simulation_parameters)
        # rows per scenario (a cartesian product).
        return (
            selectinload(Scenario.financial_inputs),
            selectinload(Scenario.simulation_parameters),
        )

    def list_for_project(
        self,
        project_id: int,
        *,
        with_children: bool = False,
    ) -> Sequence[Scenario]:
        """Return the scenarios of ``project_id`` ordered by ``created_at``.

        ``with_children`` eager-loads financial inputs and simulation
        parameters.
        """
        stmt = (
            select(Scenario)
            .where(Scenario.project_id == project_id)
            .order_by(Scenario.created_at)
        )
        if with_children:
            stmt = stmt.options(*self._children_loaders())
        return self.session.execute(stmt).scalars().all()

    def count(self) -> int:
        """Return the total number of scenarios."""
        stmt = select(func.count(Scenario.id))
        return self.session.execute(stmt).scalar_one()

    def count_by_status(self, status: ScenarioStatus) -> int:
        """Return how many scenarios currently have ``status``."""
        stmt = select(func.count(Scenario.id)).where(Scenario.status == status)
        return self.session.execute(stmt).scalar_one()

    def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]:
        """Return at most ``limit`` scenarios, most recently updated first.

        ``with_project`` eager-loads the owning project via ``joinedload``.
        """
        stmt = select(Scenario).order_by(
            Scenario.updated_at.desc()).limit(limit)
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        return self.session.execute(stmt).scalars().all()

    def list_by_status(
        self,
        status: ScenarioStatus,
        *,
        limit: int | None = None,
        with_project: bool = False,
    ) -> Sequence[Scenario]:
        """Return scenarios with ``status``, most recently updated first.

        ``limit`` caps the result size when given. ``with_project``
        eager-loads the owning project.
        """
        stmt = (
            select(Scenario)
            .where(Scenario.status == status)
            .order_by(Scenario.updated_at.desc())
        )
        if with_project:
            stmt = stmt.options(joinedload(Scenario.project))
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
        """Fetch one scenario by primary key.

        ``with_children`` eager-loads financial inputs and simulation
        parameters without multiplying result rows.

        Raises:
            EntityNotFoundError: when no scenario has ``scenario_id``.
        """
        stmt = select(Scenario).where(Scenario.id == scenario_id)
        if with_children:
            stmt = stmt.options(*self._children_loaders())
        scenario = self.session.execute(stmt).scalar_one_or_none()
        if scenario is None:
            raise EntityNotFoundError(f"Scenario {scenario_id} not found")
        return scenario

    def exists(self, scenario_id: int) -> bool:
        """Return ``True`` when a scenario with ``scenario_id`` is stored."""
        stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
        return self.session.execute(stmt).scalar_one_or_none() is not None

    def create(self, scenario: Scenario) -> Scenario:
        """Add ``scenario`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(scenario)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError("Scenario violates constraints") from exc
        return scenario

    def find_by_project_and_names(
        self,
        project_id: int,
        names: Iterable[str],
    ) -> Mapping[str, Scenario]:
        """Look up scenarios of ``project_id`` by case-insensitive name.

        Empty or whitespace-only names are ignored. The result maps each
        matching scenario's lower-cased name to the scenario.
        """
        normalised = {name.strip().lower()
                      for name in names if name and name.strip()}
        if not normalised:
            return {}
        stmt = (
            select(Scenario)
            .where(
                Scenario.project_id == project_id,
                func.lower(Scenario.name).in_(normalised),
            )
        )
        records = self.session.execute(stmt).scalars().all()
        return {scenario.name.lower(): scenario for scenario in records}

    def filtered_for_export(
        self,
        filters: ScenarioExportFilters | None = None,
        *,
        include_project: bool = True,
    ) -> Sequence[Scenario]:
        """Return scenarios matching ``filters``, ordered by name then id.

        All supplied criteria are ANDed together; falsy criteria are skipped.
        ``include_project`` (default ``True``) eager-loads the owning project.
        """
        stmt = select(Scenario)
        if include_project:
            stmt = stmt.options(joinedload(Scenario.project))
        if filters:
            scenario_ids = filters.normalised_ids()
            if scenario_ids:
                stmt = stmt.where(Scenario.id.in_(scenario_ids))
            project_ids = filters.normalised_project_ids()
            if project_ids:
                stmt = stmt.where(Scenario.project_id.in_(project_ids))
            project_names = filters.normalised_project_names()
            if project_names:
                # Resolve project names to ids via a correlated sub-select.
                project_id_select = select(Project.id).where(
                    func.lower(Project.name).in_(project_names)
                )
                stmt = stmt.where(Scenario.project_id.in_(project_id_select))
            name_pattern = filters.name_search_pattern()
            if name_pattern:
                stmt = stmt.where(Scenario.name.ilike(name_pattern))
            if filters.statuses:
                stmt = stmt.where(Scenario.status.in_(filters.statuses))
            if filters.start_date_from:
                stmt = stmt.where(Scenario.start_date >=
                                  filters.start_date_from)
            if filters.start_date_to:
                stmt = stmt.where(Scenario.start_date <= filters.start_date_to)
            if filters.end_date_from:
                stmt = stmt.where(Scenario.end_date >= filters.end_date_from)
            if filters.end_date_to:
                stmt = stmt.where(Scenario.end_date <= filters.end_date_to)
            if filters.created_from:
                stmt = stmt.where(Scenario.created_at >= filters.created_from)
            if filters.created_to:
                stmt = stmt.where(Scenario.created_at <= filters.created_to)
            if filters.updated_from:
                stmt = stmt.where(Scenario.updated_at >= filters.updated_from)
            if filters.updated_to:
                stmt = stmt.where(Scenario.updated_at <= filters.updated_to)
            currencies = filters.normalised_currencies()
            if currencies:
                # Compare upper-cased so stored currency casing does not matter.
                stmt = stmt.where(func.upper(
                    Scenario.currency).in_(currencies))
            if filters.primary_resources:
                stmt = stmt.where(Scenario.primary_resource.in_(
                    filters.primary_resources))
        stmt = stmt.order_by(Scenario.name, Scenario.id)
        return self.session.execute(stmt).scalars().all()

    def delete(self, scenario_id: int) -> None:
        """Mark the scenario for deletion.

        Raises:
            EntityNotFoundError: when no scenario has ``scenario_id``.
        """
        scenario = self.get(scenario_id)
        self.session.delete(scenario)
class FinancialInputRepository:
    """Data access for :class:`FinancialInput` rows attached to scenarios."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
        """Return the scenario's financial inputs ordered by ``created_at``."""
        belongs_to_scenario = FinancialInput.scenario_id == scenario_id
        query = (
            select(FinancialInput)
            .where(belongs_to_scenario)
            .order_by(FinancialInput.created_at)
        )
        rows = self.session.execute(query)
        return rows.scalars().all()

    def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
        """Stage every input with ``Session.add_all`` and flush.

        No dedicated upsert SQL is issued; the materialised list of inputs is
        returned.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        staged = [*inputs]
        self.session.add_all(staged)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Financial input violates constraints"
            ) from exc
        return staged

    def delete(self, input_id: int) -> None:
        """Mark one financial input for deletion.

        Raises:
            EntityNotFoundError: when no input has ``input_id``.
        """
        query = select(FinancialInput).where(FinancialInput.id == input_id)
        if (target := self.session.execute(query).scalar_one_or_none()) is None:
            raise EntityNotFoundError(f"Financial input {input_id} not found")
        self.session.delete(target)

    def latest_created_at(self) -> datetime | None:
        """Return the ``created_at`` of the first row when sorted descending.

        Returns ``None`` when the table is empty.
        """
        newest_first = FinancialInput.created_at.desc()
        query = select(FinancialInput.created_at).order_by(newest_first).limit(1)
        return self.session.execute(query).scalar_one_or_none()
class SimulationParameterRepository:
    """Data access for :class:`SimulationParameter` rows attached to scenarios."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
        """Return the scenario's simulation parameters ordered by ``created_at``."""
        belongs_to_scenario = SimulationParameter.scenario_id == scenario_id
        query = (
            select(SimulationParameter)
            .where(belongs_to_scenario)
            .order_by(SimulationParameter.created_at)
        )
        rows = self.session.execute(query)
        return rows.scalars().all()

    def bulk_upsert(
        self, parameters: Iterable[SimulationParameter]
    ) -> Sequence[SimulationParameter]:
        """Stage every parameter with ``Session.add_all`` and flush.

        No dedicated upsert SQL is issued; the materialised list of
        parameters is returned.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        staged = [*parameters]
        self.session.add_all(staged)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover
            raise EntityConflictError(
                "Simulation parameter violates constraints"
            ) from exc
        return staged

    def delete(self, parameter_id: int) -> None:
        """Mark one simulation parameter for deletion.

        Raises:
            EntityNotFoundError: when no parameter has ``parameter_id``.
        """
        query = select(SimulationParameter).where(
            SimulationParameter.id == parameter_id
        )
        if (target := self.session.execute(query).scalar_one_or_none()) is None:
            raise EntityNotFoundError(
                f"Simulation parameter {parameter_id} not found"
            )
        self.session.delete(target)
class PricingSettingsRepository:
    """Data access for :class:`PricingSettings` and their override rows."""

    def __init__(self, session: Session) -> None:
        self.session = session

    @staticmethod
    def _with_overrides(stmt, include_children: bool):
        """Attach selectin loaders for metal and impurity overrides if requested."""
        if not include_children:
            return stmt
        return stmt.options(
            selectinload(PricingSettings.metal_overrides),
            selectinload(PricingSettings.impurity_overrides),
        )

    def _execute(self, stmt, include_children: bool):
        """Run ``stmt`` (optionally eager-loading overrides) and return the result."""
        rows = self.session.execute(self._with_overrides(stmt, include_children))
        return rows.unique() if include_children else rows

    def list(self, *, include_children: bool = False) -> Sequence[PricingSettings]:
        """Return all pricing settings ordered by ``created_at``."""
        query = select(PricingSettings).order_by(PricingSettings.created_at)
        return self._execute(query, include_children).scalars().all()

    def get(self, settings_id: int, *, include_children: bool = False) -> PricingSettings:
        """Fetch pricing settings by primary key.

        Raises:
            EntityNotFoundError: when no record has ``settings_id``.
        """
        query = select(PricingSettings).where(PricingSettings.id == settings_id)
        found = self._execute(query, include_children).scalar_one_or_none()
        if found is None:
            raise EntityNotFoundError(
                f"Pricing settings {settings_id} not found")
        return found

    def find_by_slug(
        self,
        slug: str,
        *,
        include_children: bool = False,
    ) -> PricingSettings | None:
        """Return the settings whose slug equals ``slug`` stripped and lower-cased.

        Returns ``None`` when there is no match.
        """
        wanted = slug.strip().lower()
        query = select(PricingSettings).where(PricingSettings.slug == wanted)
        return self._execute(query, include_children).scalar_one_or_none()

    def get_by_slug(self, slug: str, *, include_children: bool = False) -> PricingSettings:
        """Like :meth:`find_by_slug` but raise when nothing matches.

        Raises:
            EntityNotFoundError: when no record matches ``slug``.
        """
        found = self.find_by_slug(slug, include_children=include_children)
        if found is not None:
            return found
        raise EntityNotFoundError(
            f"Pricing settings slug '{slug}' not found"
        )

    def create(self, settings: PricingSettings) -> PricingSettings:
        """Add ``settings`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(settings)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - relies on DB constraints
            raise EntityConflictError(
                "Pricing settings violates constraints") from exc
        return settings

    def delete(self, settings_id: int) -> None:
        """Load the record with its overrides, then mark it for deletion."""
        self.session.delete(self.get(settings_id, include_children=True))

    def _attach(self, collection, override):
        """Append ``override`` to ``collection``, add it to the session, flush."""
        collection.append(override)
        self.session.add(override)
        self.session.flush()
        return override

    def attach_metal_override(
        self,
        settings: PricingSettings,
        override: PricingMetalSettings,
    ) -> PricingMetalSettings:
        """Attach a metal override to ``settings`` and flush it."""
        return self._attach(settings.metal_overrides, override)

    def attach_impurity_override(
        self,
        settings: PricingSettings,
        override: PricingImpuritySettings,
    ) -> PricingImpuritySettings:
        """Attach an impurity override to ``settings`` and flush it."""
        return self._attach(settings.impurity_overrides, override)
class RoleRepository:
    """Data access for :class:`Role` rows."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def _scalar(self, stmt):
        """Execute ``stmt`` and return its single scalar result or ``None``."""
        return self.session.execute(stmt).scalar_one_or_none()

    def list(self) -> Sequence[Role]:
        """Return every role ordered by ``name``."""
        rows = self.session.execute(select(Role).order_by(Role.name))
        return rows.scalars().all()

    def get(self, role_id: int) -> Role:
        """Fetch a role by primary key.

        Raises:
            EntityNotFoundError: when no role has ``role_id``.
        """
        found = self._scalar(select(Role).where(Role.id == role_id))
        if found is None:
            raise EntityNotFoundError(f"Role {role_id} not found")
        return found

    def get_by_name(self, name: str) -> Role | None:
        """Return the role whose ``name`` equals ``name`` exactly, or ``None``."""
        return self._scalar(select(Role).where(Role.name == name))

    def create(self, role: Role) -> Role:
        """Add ``role`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(role)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Role violates uniqueness constraints"
            ) from exc
        return role
class UserRepository:
    """Data access for :class:`User` rows and their :class:`UserRole` links."""

    def __init__(self, session: Session) -> None:
        self.session = session

    def list(self, *, with_roles: bool = False) -> Sequence[User]:
        """Return all users ordered by ``created_at``.

        ``with_roles`` eager-loads ``User.roles`` via ``selectinload``.
        """
        query = select(User).order_by(User.created_at)
        if with_roles:
            query = query.options(selectinload(User.roles))
        rows = self.session.execute(query)
        return rows.scalars().all()

    def _apply_role_option(self, stmt, with_roles: bool):
        """Attach role eager-loading options to ``stmt`` when ``with_roles``."""
        if not with_roles:
            return stmt
        return stmt.options(
            joinedload(User.role_assignments).joinedload(UserRole.role),
            selectinload(User.roles),
        )

    def _one_or_none(self, stmt, with_roles: bool) -> User | None:
        """Run a single-user query and return the match or ``None``.

        ``populate_existing`` makes the query overwrite any copy of the user
        already held in the session's identity map.
        """
        stmt = self._apply_role_option(
            stmt.execution_options(populate_existing=True), with_roles
        )
        rows = self.session.execute(stmt)
        if with_roles:
            # The joined role_assignments collection repeats the user row per
            # assignment; collapse duplicates before taking the scalar.
            rows = rows.unique()
        return rows.scalar_one_or_none()

    def get(self, user_id: int, *, with_roles: bool = False) -> User:
        """Fetch a user by primary key.

        Raises:
            EntityNotFoundError: when no user has ``user_id``.
        """
        found = self._one_or_none(select(User).where(User.id == user_id), with_roles)
        if found is None:
            raise EntityNotFoundError(f"User {user_id} not found")
        return found

    def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose email equals ``email`` exactly, or ``None``."""
        return self._one_or_none(select(User).where(User.email == email), with_roles)

    def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
        """Return the user whose username equals ``username`` exactly, or ``None``."""
        return self._one_or_none(
            select(User).where(User.username == username), with_roles
        )

    def create(self, user: User) -> User:
        """Add ``user`` to the session and flush it.

        Raises:
            EntityConflictError: when the flush raises ``IntegrityError``.
        """
        self.session.add(user)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "User violates uniqueness constraints"
            ) from exc
        return user

    def _find_assignment(self, user_id: int, role_id: int) -> UserRole | None:
        """Return the ``UserRole`` row linking ``user_id`` to ``role_id``, if any."""
        query = select(UserRole).where(
            UserRole.user_id == user_id,
            UserRole.role_id == role_id,
        )
        return self.session.execute(query).scalar_one_or_none()

    def assign_role(
        self,
        *,
        user_id: int,
        role_id: int,
        granted_by: int | None = None,
    ) -> UserRole:
        """Grant ``role_id`` to ``user_id``; return the existing link if present.

        Raises:
            EntityConflictError: when flushing a new link raises ``IntegrityError``.
        """
        existing = self._find_assignment(user_id, role_id)
        if existing:
            return existing

        link = UserRole(
            user_id=user_id,
            role_id=role_id,
            granted_by=granted_by,
        )
        self.session.add(link)
        try:
            self.session.flush()
        except IntegrityError as exc:  # pragma: no cover - DB constraint enforcement
            raise EntityConflictError(
                "Assignment violates constraints"
            ) from exc
        return link

    def revoke_role(self, *, user_id: int, role_id: int) -> None:
        """Delete the link between ``user_id`` and ``role_id`` and flush.

        Raises:
            EntityNotFoundError: when the user does not hold that role.
        """
        link = self._find_assignment(user_id, role_id)
        if link is None:
            raise EntityNotFoundError(
                f"Role {role_id} not assigned to user {user_id}")
        self.session.delete(link)
        self.session.flush()
# Name and description applied to the seeded pricing settings record by
# ensure_default_pricing_settings when the caller passes none of its own.
DEFAULT_PRICING_SETTINGS_NAME = "Default Pricing Settings"
DEFAULT_PRICING_SETTINGS_DESCRIPTION = (
    "Default pricing configuration generated "
    "from environment metadata."
)
@dataclass(slots=True)
class PricingSettingsSeedResult:
    """Outcome of seeding/reconciling the default pricing settings record."""

    # The pricing settings record that was created or reconciled.
    settings: PricingSettings
    # True when no record existed for the slug and a new one was inserted.
    created: bool
    # Number of scalar fields rewritten plus existing impurity overrides changed.
    updated_fields: int
    # Number of new impurity override rows attached.
    impurity_upserts: int
def _normalise_impurity_values(
    values: Mapping[str, float] | None,
) -> dict[str, float]:
    """Re-key ``values`` by stripped, upper-cased impurity code, as floats."""
    # Same key handling as pricing_settings_to_metadata: non-string and
    # blank codes are skipped rather than crashing the seed.
    return {
        code.strip().upper(): float(value)
        for code, value in (values or {}).items()
        if isinstance(code, str) and code.strip()
    }


def _sync_scalar_defaults(
    settings: PricingSettings,
    metadata: PricingMetadata,
    *,
    name: str,
    description: str,
) -> int:
    """Overwrite drifted top-level fields on ``settings``; return how many changed."""
    changed = 0
    if settings.name != name:
        settings.name = name
        changed += 1
    if description and settings.description != description:
        settings.description = description
        changed += 1
    if settings.default_currency != metadata.default_currency:
        settings.default_currency = metadata.default_currency
        changed += 1
    # Compare numerics as floats so an equal value held in a different numeric
    # type (e.g. Decimal vs float) is not counted as a change.
    for field in (
        "default_payable_pct",
        "moisture_threshold_pct",
        "moisture_penalty_per_pct",
    ):
        target = getattr(metadata, field)
        if float(getattr(settings, field)) != float(target):
            setattr(settings, field, target)
            changed += 1
    return changed


def _sync_impurity_overrides(
    repo: PricingSettingsRepository,
    settings: PricingSettings,
    metadata: PricingMetadata,
) -> tuple[int, int]:
    """Create/update impurity overrides from metadata; return (updated, created)."""
    thresholds = _normalise_impurity_values(metadata.impurity_thresholds)
    penalties = _normalise_impurity_values(metadata.impurity_penalty_per_ppm)
    if not thresholds and not penalties:
        return 0, 0

    # Key existing rows by the same normalised code as the targets. Using the
    # raw stored code would miss e.g. "fe" and attach a duplicate "FE" row.
    existing_by_code = {
        override.impurity_code.strip().upper(): override
        for override in settings.impurity_overrides
    }
    updated = 0
    created = 0
    for code in sorted(set(thresholds) | set(penalties)):
        # A code present in only one mapping gets 0.0 for the other value.
        threshold_value = thresholds.get(code, 0.0)
        penalty_value = penalties.get(code, 0.0)
        existing = existing_by_code.get(code)
        if existing is None:
            repo.attach_impurity_override(
                settings,
                PricingImpuritySettings(
                    impurity_code=code,
                    threshold_ppm=threshold_value,
                    penalty_per_ppm=penalty_value,
                ),
            )
            created += 1
            continue
        changed = False
        if float(existing.threshold_ppm) != float(threshold_value):
            existing.threshold_ppm = threshold_value
            changed = True
        if float(existing.penalty_per_ppm) != float(penalty_value):
            existing.penalty_per_ppm = penalty_value
            changed = True
        if changed:
            updated += 1
    return updated, created


def ensure_default_pricing_settings(
    repo: PricingSettingsRepository,
    *,
    metadata: PricingMetadata,
    slug: str = "default",
    name: str | None = None,
    description: str | None = None,
) -> PricingSettingsSeedResult:
    """Ensure a baseline pricing settings record exists and matches metadata defaults.

    Args:
        repo: repository used to load, create and attach overrides.
        metadata: source of default currency, payable/moisture values and
            impurity thresholds/penalties.
        slug: record slug; stripped and lower-cased, falling back to
            ``"default"`` when blank.
        name: display name; ``DEFAULT_PRICING_SETTINGS_NAME`` when falsy.
        description: description; ``DEFAULT_PRICING_SETTINGS_DESCRIPTION``
            when falsy.

    Returns:
        A :class:`PricingSettingsSeedResult` describing what was created or
        changed. The session is flushed (not committed) when anything changed.
    """
    normalised_slug = (slug or "default").strip().lower() or "default"
    target_name = name or DEFAULT_PRICING_SETTINGS_NAME
    target_description = description or DEFAULT_PRICING_SETTINGS_DESCRIPTION
    updated_fields = 0

    try:
        settings = repo.get_by_slug(normalised_slug, include_children=True)
    except EntityNotFoundError:
        settings = PricingSettings(
            name=target_name,
            slug=normalised_slug,
            description=target_description,
            default_currency=metadata.default_currency,
            default_payable_pct=metadata.default_payable_pct,
            moisture_threshold_pct=metadata.moisture_threshold_pct,
            moisture_penalty_per_pct=metadata.moisture_penalty_per_pct,
        )
        settings.metadata_payload = None
        settings = repo.create(settings)
        created = True
    else:
        created = False
        updated_fields += _sync_scalar_defaults(
            settings,
            metadata,
            name=target_name,
            description=target_description,
        )

    override_updates, impurity_upserts = _sync_impurity_overrides(
        repo, settings, metadata
    )
    updated_fields += override_updates

    if updated_fields > 0 or impurity_upserts > 0:
        repo.session.flush()
    return PricingSettingsSeedResult(
        settings=settings,
        created=created,
        updated_fields=updated_fields,
        impurity_upserts=impurity_upserts,
    )
def pricing_settings_to_metadata(settings: PricingSettings) -> PricingMetadata:
    """Build a :class:`PricingMetadata` snapshot from a stored settings record.

    Impurity values are merged from two sources: first the
    ``impurity_thresholds`` / ``impurity_penalty_per_ppm`` entries of
    ``metadata_payload``, then the record's impurity overrides, which win on
    conflicting codes. Codes are stripped and upper-cased; non-string or blank
    payload codes are skipped.
    """

    def _clean(raw):
        cleaned = {}
        for code, value in raw.items():
            if not isinstance(code, str):
                continue
            stripped = code.strip()
            if stripped:
                cleaned[stripped.upper()] = float(value)
        return cleaned

    payload = settings.metadata_payload or {}
    thresholds = _clean(payload.get("impurity_thresholds") or {})
    penalties = _clean(payload.get("impurity_penalty_per_ppm") or {})

    for override in settings.impurity_overrides:
        key = override.impurity_code.strip().upper()
        thresholds[key] = float(override.threshold_ppm)
        penalties[key] = float(override.penalty_per_ppm)

    return PricingMetadata(
        default_payable_pct=float(settings.default_payable_pct),
        default_currency=settings.default_currency,
        moisture_threshold_pct=float(settings.moisture_threshold_pct),
        moisture_penalty_per_pct=float(settings.moisture_penalty_per_pct),
        impurity_thresholds=thresholds,
        impurity_penalty_per_ppm=penalties,
    )
# Built-in roles seeded by ensure_default_roles, as Role(**definition) kwargs.
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = tuple(
    {"name": role_name, "display_name": label, "description": summary}
    for role_name, label, summary in (
        (
            "admin",
            "Administrator",
            "Full platform access with user management rights.",
        ),
        (
            "project_manager",
            "Project Manager",
            "Manage projects, scenarios, and associated data.",
        ),
        (
            "analyst",
            "Analyst",
            "Review dashboards and scenario outputs.",
        ),
        (
            "viewer",
            "Viewer",
            "Read-only access to assigned projects and reports.",
        ),
    )
)
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
    """Make sure every role in ``DEFAULT_ROLE_DEFINITIONS`` is stored.

    Roles are looked up by name; missing ones are created via the repository
    (which flushes). Returns one record per definition, in the same order as
    ``DEFAULT_ROLE_DEFINITIONS``.
    """

    def _resolve(definition: dict[str, str]) -> Role:
        found = role_repo.get_by_name(definition["name"])
        return found if found else role_repo.create(Role(**definition))

    return [_resolve(definition) for definition in DEFAULT_ROLE_DEFINITIONS]
def ensure_admin_user(
    user_repo: UserRepository,
    role_repo: RoleRepository,
    *,
    email: str,
    username: str,
    password: str,
) -> User:
    """Guarantee an active superuser account that holds the ``admin`` role.

    The account is looked up by ``email``. A missing account is created with
    ``username`` and a hash of ``password``. An existing account is
    re-activated and promoted to superuser if needed; its username and
    password are left untouched. The ``admin`` role is created if absent and
    then assigned (a no-op when already held), with the user recorded as its
    own grantor.
    """
    admin = user_repo.get_by_email(email, with_roles=True)
    if admin is not None:
        if not admin.is_active:
            admin.is_active = True
        if not admin.is_superuser:
            admin.is_superuser = True
        user_repo.session.flush()
    else:
        admin = User(
            email=email,
            username=username,
            password_hash=User.hash_password(password),
            is_active=True,
            is_superuser=True,
        )
        user_repo.create(admin)

    admin_role = role_repo.get_by_name("admin")
    if admin_role is None:  # pragma: no cover - safety if ensure_default_roles wasn't called
        admin_role = role_repo.create(
            Role(
                name="admin",
                display_name="Administrator",
                description="Full platform access with user management rights.",
            )
        )

    user_repo.assign_role(
        user_id=admin.id,
        role_id=admin_role.id,
        granted_by=admin.id,
    )
    return admin