Compare commits
25 Commits
c6a0eb2588
...
e0fa3861a6
| Author | SHA1 | Date | |
|---|---|---|---|
| e0fa3861a6 | |||
| ab328b1a0b | |||
| 24cb3c2f57 | |||
| 118657491c | |||
| 0f79864188 | |||
| 27262bdfa3 | |||
| 3601c2e422 | |||
| 53879a411f | |||
| 2d848c2e09 | |||
| dad862e48e | |||
| 400f85c907 | |||
| 7f5ed6a42d | |||
| 053da332ac | |||
| 02da881d3e | |||
| c39dde3198 | |||
| faea6777a0 | |||
| d36611606d | |||
| 191500aeb7 | |||
| 61b42b3041 | |||
| 8bf46b80c8 | |||
| c69f933684 | |||
| c6fdc2d923 | |||
| dc3ebfbba5 | |||
| 32a96a27c5 | |||
| 203a5d08f2 |
@@ -9,6 +9,3 @@ DATABASE_PASSWORD=<password>
|
||||
DATABASE_NAME=calminer
|
||||
# Optional: set a schema (comma-separated for multiple entries)
|
||||
# DATABASE_SCHEMA=public
|
||||
|
||||
# Legacy fallback (still supported, but granular settings are preferred)
|
||||
# DATABASE_URL=postgresql://<user>:<password>@localhost:5432/calminer
|
||||
35
alembic.ini
Normal file
35
alembic.ini
Normal file
@@ -0,0 +1,35 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
sqlalchemy.url = %(DATABASE_URL)s
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
63
alembic/env.py
Normal file
63
alembic/env.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from logging.config import fileConfig
|
||||
from typing import Iterable
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from config.database import Base, DATABASE_URL
|
||||
from models import * # noqa: F401,F403 - ensure models are imported for metadata registration
|
||||
|
||||
# this is the Alembic Config object, which provides access to the values within the .ini file.
|
||||
config = context.config
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
config.set_main_option("sqlalchemy.url", DATABASE_URL)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations() -> None:
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
|
||||
|
||||
run_migrations()
|
||||
17
alembic/script.py.mako
Normal file
17
alembic/script.py.mako
Normal file
@@ -0,0 +1,17 @@
|
||||
"""${message}"""
|
||||
|
||||
revision = ${repr(revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
220
alembic/versions/20251109_01_initial_schema.py
Normal file
220
alembic/versions/20251109_01_initial_schema.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Initial domain schema"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "20251109_01"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
mining_operation_type = sa.Enum(
|
||||
"open_pit",
|
||||
"underground",
|
||||
"in_situ_leach",
|
||||
"placer",
|
||||
"quarry",
|
||||
"mountaintop_removal",
|
||||
"other",
|
||||
name="miningoperationtype",
|
||||
)
|
||||
|
||||
scenario_status = sa.Enum(
|
||||
"draft",
|
||||
"active",
|
||||
"archived",
|
||||
name="scenariostatus",
|
||||
)
|
||||
|
||||
financial_category = sa.Enum(
|
||||
"capex",
|
||||
"opex",
|
||||
"revenue",
|
||||
"contingency",
|
||||
"other",
|
||||
name="financialcategory",
|
||||
)
|
||||
|
||||
cost_bucket = sa.Enum(
|
||||
"capital_initial",
|
||||
"capital_sustaining",
|
||||
"operating_fixed",
|
||||
"operating_variable",
|
||||
"maintenance",
|
||||
"reclamation",
|
||||
"royalties",
|
||||
"general_admin",
|
||||
name="costbucket",
|
||||
)
|
||||
|
||||
distribution_type = sa.Enum(
|
||||
"normal",
|
||||
"triangular",
|
||||
"uniform",
|
||||
"lognormal",
|
||||
"custom",
|
||||
name="distributiontype",
|
||||
)
|
||||
|
||||
stochastic_variable = sa.Enum(
|
||||
"ore_grade",
|
||||
"recovery_rate",
|
||||
"metal_price",
|
||||
"operating_cost",
|
||||
"capital_cost",
|
||||
"discount_rate",
|
||||
"throughput",
|
||||
name="stochasticvariable",
|
||||
)
|
||||
|
||||
resource_type = sa.Enum(
|
||||
"diesel",
|
||||
"electricity",
|
||||
"water",
|
||||
"explosives",
|
||||
"reagents",
|
||||
"labor",
|
||||
"equipment_hours",
|
||||
"tailings_capacity",
|
||||
name="resourcetype",
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
mining_operation_type.create(bind, checkfirst=True)
|
||||
scenario_status.create(bind, checkfirst=True)
|
||||
financial_category.create(bind, checkfirst=True)
|
||||
cost_bucket.create(bind, checkfirst=True)
|
||||
distribution_type.create(bind, checkfirst=True)
|
||||
stochastic_variable.create(bind, checkfirst=True)
|
||||
resource_type.create(bind, checkfirst=True)
|
||||
|
||||
op.create_table(
|
||||
"projects",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("location", sa.String(length=255), nullable=True),
|
||||
sa.Column("operation_type", mining_operation_type, nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
op.create_index(op.f("ix_projects_id"), "projects", ["id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"scenarios",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("project_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("status", scenario_status, nullable=False),
|
||||
sa.Column("start_date", sa.Date(), nullable=True),
|
||||
sa.Column("end_date", sa.Date(), nullable=True),
|
||||
sa.Column("discount_rate", sa.Numeric(
|
||||
precision=5, scale=2), nullable=True),
|
||||
sa.Column("currency", sa.String(length=3), nullable=True),
|
||||
sa.Column("primary_resource", resource_type, nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_scenarios_id"), "scenarios", ["id"], unique=False)
|
||||
op.create_index(op.f("ix_scenarios_project_id"),
|
||||
"scenarios", ["project_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"financial_inputs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("scenario_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("category", financial_category, nullable=False),
|
||||
sa.Column("cost_bucket", cost_bucket, nullable=True),
|
||||
sa.Column("amount", sa.Numeric(precision=18, scale=2), nullable=False),
|
||||
sa.Column("currency", sa.String(length=3), nullable=True),
|
||||
sa.Column("effective_date", sa.Date(), nullable=True),
|
||||
sa.Column("notes", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["scenario_id"], ["scenarios.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_financial_inputs_id"),
|
||||
"financial_inputs", ["id"], unique=False)
|
||||
op.create_index(op.f("ix_financial_inputs_scenario_id"),
|
||||
"financial_inputs", ["scenario_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"simulation_parameters",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("scenario_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("distribution", distribution_type, nullable=False),
|
||||
sa.Column("variable", stochastic_variable, nullable=True),
|
||||
sa.Column("resource_type", resource_type, nullable=True),
|
||||
sa.Column("mean_value", sa.Numeric(
|
||||
precision=18, scale=4), nullable=True),
|
||||
sa.Column("standard_deviation", sa.Numeric(
|
||||
precision=18, scale=4), nullable=True),
|
||||
sa.Column("minimum_value", sa.Numeric(
|
||||
precision=18, scale=4), nullable=True),
|
||||
sa.Column("maximum_value", sa.Numeric(
|
||||
precision=18, scale=4), nullable=True),
|
||||
sa.Column("unit", sa.String(length=32), nullable=True),
|
||||
sa.Column("configuration", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["scenario_id"], ["scenarios.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_simulation_parameters_id"),
|
||||
"simulation_parameters", ["id"], unique=False)
|
||||
op.create_index(op.f("ix_simulation_parameters_scenario_id"),
|
||||
"simulation_parameters", ["scenario_id"], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_simulation_parameters_scenario_id"),
|
||||
table_name="simulation_parameters")
|
||||
op.drop_index(op.f("ix_simulation_parameters_id"),
|
||||
table_name="simulation_parameters")
|
||||
op.drop_table("simulation_parameters")
|
||||
|
||||
op.drop_index(op.f("ix_financial_inputs_scenario_id"),
|
||||
table_name="financial_inputs")
|
||||
op.drop_index(op.f("ix_financial_inputs_id"),
|
||||
table_name="financial_inputs")
|
||||
op.drop_table("financial_inputs")
|
||||
|
||||
op.drop_index(op.f("ix_scenarios_project_id"), table_name="scenarios")
|
||||
op.drop_index(op.f("ix_scenarios_id"), table_name="scenarios")
|
||||
op.drop_table("scenarios")
|
||||
|
||||
op.drop_index(op.f("ix_projects_id"), table_name="projects")
|
||||
op.drop_table("projects")
|
||||
|
||||
resource_type.drop(op.get_bind(), checkfirst=True)
|
||||
stochastic_variable.drop(op.get_bind(), checkfirst=True)
|
||||
distribution_type.drop(op.get_bind(), checkfirst=True)
|
||||
cost_bucket.drop(op.get_bind(), checkfirst=True)
|
||||
financial_category.drop(op.get_bind(), checkfirst=True)
|
||||
scenario_status.drop(op.get_bind(), checkfirst=True)
|
||||
mining_operation_type.drop(op.get_bind(), checkfirst=True)
|
||||
210
alembic/versions/20251109_02_add_auth_tables.py
Normal file
210
alembic/versions/20251109_02_add_auth_tables.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Add authentication and RBAC tables"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "20251109_02"
|
||||
down_revision = "20251109_01"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("email", sa.String(length=255), nullable=False),
|
||||
sa.Column("username", sa.String(length=128), nullable=False),
|
||||
sa.Column("password_hash", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
sa.Column(
|
||||
"is_superuser",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
sa.Column("last_login_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.UniqueConstraint("email", name="uq_users_email"),
|
||||
sa.UniqueConstraint("username", name="uq_users_username"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_users_active_superuser",
|
||||
"users",
|
||||
["is_active", "is_superuser"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"roles",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("name", sa.String(length=64), nullable=False),
|
||||
sa.Column("display_name", sa.String(length=128), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.UniqueConstraint("name", name="uq_roles_name"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"user_roles",
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.Column("role_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"granted_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column("granted_by", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["role_id"],
|
||||
["roles.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["granted_by"],
|
||||
["users.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("user_id", "role_id"),
|
||||
sa.UniqueConstraint("user_id", "role_id",
|
||||
name="uq_user_roles_user_role"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_roles_role_id",
|
||||
"user_roles",
|
||||
["role_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Seed default roles
|
||||
roles_table = table(
|
||||
"roles",
|
||||
column("id", sa.Integer()),
|
||||
column("name", sa.String()),
|
||||
column("display_name", sa.String()),
|
||||
column("description", sa.Text()),
|
||||
)
|
||||
|
||||
op.bulk_insert(
|
||||
roles_table,
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"name": "admin",
|
||||
"display_name": "Administrator",
|
||||
"description": "Full platform access with user management rights.",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "project_manager",
|
||||
"display_name": "Project Manager",
|
||||
"description": "Manage projects, scenarios, and associated data.",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "analyst",
|
||||
"display_name": "Analyst",
|
||||
"description": "Review dashboards and scenario outputs.",
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "viewer",
|
||||
"display_name": "Viewer",
|
||||
"description": "Read-only access to assigned projects and reports.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
admin_password_hash = password_context.hash("ChangeMe123!")
|
||||
|
||||
users_table = table(
|
||||
"users",
|
||||
column("id", sa.Integer()),
|
||||
column("email", sa.String()),
|
||||
column("username", sa.String()),
|
||||
column("password_hash", sa.String()),
|
||||
column("is_active", sa.Boolean()),
|
||||
column("is_superuser", sa.Boolean()),
|
||||
)
|
||||
|
||||
op.bulk_insert(
|
||||
users_table,
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"email": "admin@calminer.local",
|
||||
"username": "admin",
|
||||
"password_hash": admin_password_hash,
|
||||
"is_active": True,
|
||||
"is_superuser": True,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
user_roles_table = table(
|
||||
"user_roles",
|
||||
column("user_id", sa.Integer()),
|
||||
column("role_id", sa.Integer()),
|
||||
column("granted_by", sa.Integer()),
|
||||
)
|
||||
|
||||
op.bulk_insert(
|
||||
user_roles_table,
|
||||
[
|
||||
{
|
||||
"user_id": 1,
|
||||
"role_id": 1,
|
||||
"granted_by": 1,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_user_roles_role_id", table_name="user_roles")
|
||||
op.drop_table("user_roles")
|
||||
|
||||
op.drop_table("roles")
|
||||
|
||||
op.drop_index("ix_users_active_superuser", table_name="users")
|
||||
op.drop_table("users")
|
||||
BIN
alembic_test.db
Normal file
BIN
alembic_test.db
Normal file
Binary file not shown.
32
changelog.md
Normal file
32
changelog.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# Changelog
|
||||
|
||||
## 2025-11-09
|
||||
|
||||
- Captured current implementation status, requirements coverage, missing features, and prioritized roadmap in `calminer-docs/implementation_status.md` to guide future development.
|
||||
- Added core SQLAlchemy domain models, shared metadata descriptors, and Alembic migration setup (with initial schema snapshot) to establish the persistence layer foundation.
|
||||
- Introduced repository and unit-of-work helpers for projects, scenarios, financial inputs, and simulation parameters to support service-layer operations.
|
||||
- Added SQLite-backed pytest coverage for repository and unit-of-work behaviours to validate persistence interactions.
|
||||
- Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application.
|
||||
- Connected project and scenario routers to new Jinja2 list/detail/edit views with HTML forms and redirects.
|
||||
- Implemented FR-009 client-side enhancements with responsive navigation toggle, mobile-first scenario tables, and shared asset loading across templates.
|
||||
- Added scenario comparison validator, FastAPI comparison endpoint, and comprehensive unit tests to enforce FR-009 validation rules through API errors.
|
||||
- Delivered a new dashboard experience with `templates/dashboard.html`, dedicated styling, and a FastAPI route supplying real project/scenario metrics via repository helpers.
|
||||
- Extended repositories with count/recency utilities and added pytest coverage, including a dashboard rendering smoke test validating empty-state messaging.
|
||||
- Brought project and scenario detail pages plus their forms in line with the dashboard visuals, adding metric cards, layout grids, and refreshed CTA styles.
|
||||
- Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints.
|
||||
- Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views.
|
||||
- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates.
|
||||
- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios.
|
||||
- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows.
|
||||
- Implemented cookie-based authentication session middleware with automatic access token refresh, logout handling, navigation adjustments, and documentation/test updates capturing the new behaviour.
|
||||
- Delivered idempotent seeding utilities with `scripts/initial_data.py`, entry-point runner `scripts/00_initial_data.py`, documentation updates, and pytest coverage to verify role/admin provisioning.
|
||||
- Secured project and scenario routers with RBAC guard dependencies, enforced repository access checks via helper utilities, and aligned template routes with FastAPI dependency injection patterns.
|
||||
|
||||
## 2025-11-10
|
||||
|
||||
- Extended authorization helper layer with project/scenario ownership lookups, integrated them into FastAPI dependencies, refreshed pytest fixtures to keep the suite authenticated, and documented the new patterns across RBAC plan and security guides.
|
||||
- Added dedicated pytest coverage for guard dependencies, exercising success plus failure paths (missing session, inactive user, missing roles, project/scenario access errors) via `tests/test_dependencies_guards.py`.
|
||||
- Added integration tests in `tests/test_authorization_integration.py` verifying anonymous 401 responses, role-based 403s, and authorized project manager flows across API and UI endpoints.
|
||||
- Implemented environment-driven admin bootstrap settings, wired the `bootstrap_admin` helper into FastAPI startup, added pytest coverage for creation/idempotency/reset logic, and documented operational guidance in the RBAC plan and security concept.
|
||||
- Retired the legacy authentication RBAC implementation plan document after migrating its guidance into live documentation and synchronized the contributor instructions to reflect the removal.
|
||||
- Completed the Authentication & RBAC checklist by shipping the new models, migrations, repositories, guard dependencies, and integration tests.
|
||||
1
config/__init__.py
Normal file
1
config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Configuration package."""
|
||||
188
config/settings.py
Normal file
188
config/settings.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from services.security import JWTSettings
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AdminBootstrapSettings:
|
||||
"""Default administrator bootstrap configuration."""
|
||||
|
||||
email: str
|
||||
username: str
|
||||
password: str
|
||||
roles: tuple[str, ...]
|
||||
force_reset: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SessionSettings:
|
||||
"""Cookie and header configuration for session token transport."""
|
||||
|
||||
access_cookie_name: str
|
||||
refresh_cookie_name: str
|
||||
cookie_secure: bool
|
||||
cookie_domain: Optional[str]
|
||||
cookie_path: str
|
||||
header_name: str
|
||||
header_prefix: str
|
||||
allow_header_fallback: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Settings:
|
||||
"""Application configuration sourced from environment variables."""
|
||||
|
||||
jwt_secret_key: str = "change-me"
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_access_token_minutes: int = 15
|
||||
jwt_refresh_token_days: int = 7
|
||||
session_access_cookie_name: str = "calminer_access_token"
|
||||
session_refresh_cookie_name: str = "calminer_refresh_token"
|
||||
session_cookie_secure: bool = False
|
||||
session_cookie_domain: Optional[str] = None
|
||||
session_cookie_path: str = "/"
|
||||
session_header_name: str = "Authorization"
|
||||
session_header_prefix: str = "Bearer"
|
||||
session_allow_header_fallback: bool = True
|
||||
admin_email: str = "admin@calminer.local"
|
||||
admin_username: str = "admin"
|
||||
admin_password: str = "ChangeMe123!"
|
||||
admin_roles: tuple[str, ...] = ("admin",)
|
||||
admin_force_reset: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_environment(cls) -> "Settings":
|
||||
"""Construct settings from environment variables."""
|
||||
|
||||
return cls(
|
||||
jwt_secret_key=os.getenv("CALMINER_JWT_SECRET", "change-me"),
|
||||
jwt_algorithm=os.getenv("CALMINER_JWT_ALGORITHM", "HS256"),
|
||||
jwt_access_token_minutes=cls._int_from_env(
|
||||
"CALMINER_JWT_ACCESS_MINUTES", 15
|
||||
),
|
||||
jwt_refresh_token_days=cls._int_from_env(
|
||||
"CALMINER_JWT_REFRESH_DAYS", 7
|
||||
),
|
||||
session_access_cookie_name=os.getenv(
|
||||
"CALMINER_SESSION_ACCESS_COOKIE", "calminer_access_token"
|
||||
),
|
||||
session_refresh_cookie_name=os.getenv(
|
||||
"CALMINER_SESSION_REFRESH_COOKIE", "calminer_refresh_token"
|
||||
),
|
||||
session_cookie_secure=cls._bool_from_env(
|
||||
"CALMINER_SESSION_COOKIE_SECURE", False
|
||||
),
|
||||
session_cookie_domain=os.getenv("CALMINER_SESSION_COOKIE_DOMAIN"),
|
||||
session_cookie_path=os.getenv("CALMINER_SESSION_COOKIE_PATH", "/"),
|
||||
session_header_name=os.getenv(
|
||||
"CALMINER_SESSION_HEADER_NAME", "Authorization"
|
||||
),
|
||||
session_header_prefix=os.getenv(
|
||||
"CALMINER_SESSION_HEADER_PREFIX", "Bearer"
|
||||
),
|
||||
session_allow_header_fallback=cls._bool_from_env(
|
||||
"CALMINER_SESSION_ALLOW_HEADER_FALLBACK", True
|
||||
),
|
||||
admin_email=os.getenv(
|
||||
"CALMINER_SEED_ADMIN_EMAIL", "admin@calminer.local"
|
||||
),
|
||||
admin_username=os.getenv(
|
||||
"CALMINER_SEED_ADMIN_USERNAME", "admin"
|
||||
),
|
||||
admin_password=os.getenv(
|
||||
"CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!"
|
||||
),
|
||||
admin_roles=cls._parse_admin_roles(
|
||||
os.getenv("CALMINER_SEED_ADMIN_ROLES")
|
||||
),
|
||||
admin_force_reset=cls._bool_from_env(
|
||||
"CALMINER_SEED_FORCE", False
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _int_from_env(name: str, default: int) -> int:
|
||||
raw_value = os.getenv(name)
|
||||
if raw_value is None:
|
||||
return default
|
||||
try:
|
||||
return int(raw_value)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _bool_from_env(name: str, default: bool) -> bool:
|
||||
raw_value = os.getenv(name)
|
||||
if raw_value is None:
|
||||
return default
|
||||
lowered = raw_value.strip().lower()
|
||||
if lowered in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _parse_admin_roles(raw_value: str | None) -> tuple[str, ...]:
|
||||
if not raw_value:
|
||||
return ("admin",)
|
||||
parts = [segment.strip()
|
||||
for segment in raw_value.split(",") if segment.strip()]
|
||||
if "admin" not in parts:
|
||||
parts.insert(0, "admin")
|
||||
seen: set[str] = set()
|
||||
ordered: list[str] = []
|
||||
for role_name in parts:
|
||||
if role_name not in seen:
|
||||
ordered.append(role_name)
|
||||
seen.add(role_name)
|
||||
return tuple(ordered)
|
||||
|
||||
def jwt_settings(self) -> JWTSettings:
|
||||
"""Build runtime JWT settings compatible with token helpers."""
|
||||
|
||||
return JWTSettings(
|
||||
secret_key=self.jwt_secret_key,
|
||||
algorithm=self.jwt_algorithm,
|
||||
access_token_ttl=timedelta(minutes=self.jwt_access_token_minutes),
|
||||
refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days),
|
||||
)
|
||||
|
||||
def session_settings(self) -> SessionSettings:
|
||||
"""Provide transport configuration for session tokens."""
|
||||
|
||||
return SessionSettings(
|
||||
access_cookie_name=self.session_access_cookie_name,
|
||||
refresh_cookie_name=self.session_refresh_cookie_name,
|
||||
cookie_secure=self.session_cookie_secure,
|
||||
cookie_domain=self.session_cookie_domain,
|
||||
cookie_path=self.session_cookie_path,
|
||||
header_name=self.session_header_name,
|
||||
header_prefix=self.session_header_prefix,
|
||||
allow_header_fallback=self.session_allow_header_fallback,
|
||||
)
|
||||
|
||||
def admin_bootstrap_settings(self) -> AdminBootstrapSettings:
|
||||
"""Return configured admin bootstrap settings."""
|
||||
|
||||
return AdminBootstrapSettings(
|
||||
email=self.admin_email,
|
||||
username=self.admin_username,
|
||||
password=self.admin_password,
|
||||
roles=self.admin_roles,
|
||||
force_reset=self.admin_force_reset,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
"""Return cached application settings."""
|
||||
|
||||
return Settings.from_environment()
|
||||
245
dependencies.py
Normal file
245
dependencies.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable, Generator
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from models import Project, Role, Scenario, User
|
||||
from services.authorization import (
|
||||
ensure_project_access as ensure_project_access_helper,
|
||||
ensure_scenario_access as ensure_scenario_access_helper,
|
||||
ensure_scenario_in_project as ensure_scenario_in_project_helper,
|
||||
)
|
||||
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||
from services.security import JWTSettings
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
SessionStrategy,
|
||||
SessionTokens,
|
||||
build_session_strategy,
|
||||
extract_session_tokens,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
def get_unit_of_work() -> Generator[UnitOfWork, None, None]:
|
||||
"""FastAPI dependency yielding a unit-of-work instance."""
|
||||
|
||||
with UnitOfWork() as uow:
|
||||
yield uow
|
||||
|
||||
|
||||
def get_application_settings() -> Settings:
|
||||
"""Provide cached application settings instance."""
|
||||
|
||||
return get_settings()
|
||||
|
||||
|
||||
def get_jwt_settings() -> JWTSettings:
|
||||
"""Provide JWT runtime configuration derived from settings."""
|
||||
|
||||
return get_settings().jwt_settings()
|
||||
|
||||
|
||||
def get_session_strategy(
|
||||
settings: Settings = Depends(get_application_settings),
|
||||
) -> SessionStrategy:
|
||||
"""Yield configured session transport strategy."""
|
||||
|
||||
return build_session_strategy(settings.session_settings())
|
||||
|
||||
|
||||
def get_session_tokens(
|
||||
request: Request,
|
||||
strategy: SessionStrategy = Depends(get_session_strategy),
|
||||
) -> SessionTokens:
|
||||
"""Extract raw session tokens from the incoming request."""
|
||||
|
||||
existing = getattr(request.state, "auth_session", None)
|
||||
if isinstance(existing, AuthSession):
|
||||
return existing.tokens
|
||||
|
||||
tokens = extract_session_tokens(request, strategy)
|
||||
request.state.auth_session = AuthSession(tokens=tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
def get_auth_session(
|
||||
request: Request,
|
||||
tokens: SessionTokens = Depends(get_session_tokens),
|
||||
) -> AuthSession:
|
||||
"""Provide authentication session context for the current request."""
|
||||
|
||||
existing = getattr(request.state, "auth_session", None)
|
||||
if isinstance(existing, AuthSession):
|
||||
return existing
|
||||
|
||||
if tokens.is_empty:
|
||||
session = AuthSession.anonymous()
|
||||
else:
|
||||
session = AuthSession(tokens=tokens)
|
||||
request.state.auth_session = session
|
||||
return session
|
||||
|
||||
|
||||
def get_current_user(
|
||||
session: AuthSession = Depends(get_auth_session),
|
||||
) -> User | None:
|
||||
"""Return the current authenticated user if present."""
|
||||
|
||||
return session.user
|
||||
|
||||
|
||||
def require_current_user(
|
||||
session: AuthSession = Depends(get_auth_session),
|
||||
) -> User:
|
||||
"""Ensure that a request is authenticated and return the user context."""
|
||||
|
||||
if session.user is None or session.tokens.is_empty:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required.",
|
||||
)
|
||||
return session.user
|
||||
|
||||
|
||||
def require_authenticated_user(
|
||||
user: User = Depends(require_current_user),
|
||||
) -> User:
|
||||
"""Ensure the current user account is active."""
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User account is disabled.",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def _user_role_names(user: User) -> set[str]:
|
||||
roles: Iterable[Role] = getattr(user, "roles", []) or []
|
||||
return {role.name for role in roles}
|
||||
|
||||
|
||||
def require_roles(*roles: str) -> Callable[[User], User]:
|
||||
"""Dependency factory enforcing membership in one of the given roles."""
|
||||
|
||||
required = tuple(role.strip() for role in roles if role.strip())
|
||||
if not required:
|
||||
raise ValueError("require_roles requires at least one role name")
|
||||
|
||||
def _dependency(user: User = Depends(require_authenticated_user)) -> User:
|
||||
if user.is_superuser:
|
||||
return user
|
||||
|
||||
role_names = _user_role_names(user)
|
||||
if not any(role in role_names for role in required):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions for this action.",
|
||||
)
|
||||
return user
|
||||
|
||||
return _dependency
|
||||
|
||||
|
||||
def require_any_role(*roles: str) -> Callable[[User], User]:
|
||||
"""Alias of require_roles for readability in some contexts."""
|
||||
|
||||
return require_roles(*roles)
|
||||
|
||||
|
||||
def require_project_resource(*, require_manage: bool = False) -> Callable[[int], Project]:
|
||||
"""Dependency factory that resolves a project with authorization checks."""
|
||||
|
||||
def _dependency(
|
||||
project_id: int,
|
||||
user: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> Project:
|
||||
try:
|
||||
return ensure_project_access_helper(
|
||||
uow,
|
||||
project_id=project_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
except AuthorizationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
return _dependency
|
||||
|
||||
|
||||
def require_scenario_resource(
|
||||
*, require_manage: bool = False, with_children: bool = False
|
||||
) -> Callable[[int], Scenario]:
|
||||
"""Dependency factory that resolves a scenario with authorization checks."""
|
||||
|
||||
def _dependency(
|
||||
scenario_id: int,
|
||||
user: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> Scenario:
|
||||
try:
|
||||
return ensure_scenario_access_helper(
|
||||
uow,
|
||||
scenario_id=scenario_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
with_children=with_children,
|
||||
)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
except AuthorizationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
return _dependency
|
||||
|
||||
|
||||
def require_project_scenario_resource(
|
||||
*, require_manage: bool = False, with_children: bool = False
|
||||
) -> Callable[[int, int], Scenario]:
|
||||
"""Dependency factory ensuring a scenario belongs to the given project and is accessible."""
|
||||
|
||||
def _dependency(
|
||||
project_id: int,
|
||||
scenario_id: int,
|
||||
user: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> Scenario:
|
||||
try:
|
||||
return ensure_scenario_in_project_helper(
|
||||
uow,
|
||||
project_id=project_id,
|
||||
scenario_id=scenario_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
with_children=with_children,
|
||||
)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
except AuthorizationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
return _dependency
|
||||
75
main.py
75
main.py
@@ -1,29 +1,34 @@
|
||||
from routes.distributions import router as distributions_router
|
||||
from routes.ui import router as ui_router
|
||||
from routes.parameters import router as parameters_router
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from middleware.validation import validate_json
|
||||
from config.database import Base, engine
|
||||
from routes.scenarios import router as scenarios_router
|
||||
from routes.costs import router as costs_router
|
||||
from routes.consumption import router as consumption_router
|
||||
from routes.production import router as production_router
|
||||
from routes.equipment import router as equipment_router
|
||||
from routes.reporting import router as reporting_router
|
||||
from routes.currencies import router as currencies_router
|
||||
from routes.simulations import router as simulations_router
|
||||
from routes.maintenance import router as maintenance_router
|
||||
from routes.settings import router as settings_router
|
||||
from routes.users import router as users_router
|
||||
|
||||
# Initialize database schema
|
||||
from config.database import Base, engine
|
||||
from config.settings import get_settings
|
||||
from middleware.auth_session import AuthSessionMiddleware
|
||||
from middleware.validation import validate_json
|
||||
from models import (
|
||||
FinancialInput,
|
||||
Project,
|
||||
Scenario,
|
||||
SimulationParameter,
|
||||
)
|
||||
from routes.auth import router as auth_router
|
||||
from routes.dashboard import router as dashboard_router
|
||||
from routes.projects import router as projects_router
|
||||
from routes.scenarios import router as scenarios_router
|
||||
from services.bootstrap import bootstrap_admin
|
||||
|
||||
# Initialize database schema (imports above ensure models are registered)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(AuthSessionMiddleware)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def json_validation(
|
||||
@@ -37,20 +42,26 @@ async def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
@app.on_event("startup")
|
||||
async def ensure_admin_bootstrap() -> None:
|
||||
settings = get_settings().admin_bootstrap_settings()
|
||||
try:
|
||||
role_result, admin_result = bootstrap_admin(settings=settings)
|
||||
logger.info(
|
||||
"Admin bootstrap completed: roles=%s created=%s updated=%s rotated=%s assigned=%s",
|
||||
role_result.ensured,
|
||||
admin_result.created_user,
|
||||
admin_result.updated_user,
|
||||
admin_result.password_rotated,
|
||||
admin_result.roles_granted,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive logging
|
||||
logger.exception("Failed to bootstrap administrator account")
|
||||
|
||||
# Include API routers
|
||||
|
||||
app.include_router(dashboard_router)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(projects_router)
|
||||
app.include_router(scenarios_router)
|
||||
app.include_router(parameters_router)
|
||||
app.include_router(distributions_router)
|
||||
app.include_router(costs_router)
|
||||
app.include_router(consumption_router)
|
||||
app.include_router(simulations_router)
|
||||
app.include_router(production_router)
|
||||
app.include_router(equipment_router)
|
||||
app.include_router(maintenance_router)
|
||||
app.include_router(reporting_router)
|
||||
app.include_router(currencies_router)
|
||||
app.include_router(settings_router)
|
||||
app.include_router(ui_router)
|
||||
app.include_router(users_router)
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
177
middleware/auth_session.py
Normal file
177
middleware/auth_session.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from models import User
|
||||
from services.exceptions import EntityNotFoundError
|
||||
from services.security import (
|
||||
JWTSettings,
|
||||
TokenDecodeError,
|
||||
TokenError,
|
||||
TokenExpiredError,
|
||||
TokenTypeMismatchError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
)
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
SessionStrategy,
|
||||
SessionTokens,
|
||||
build_session_strategy,
|
||||
clear_session_cookies,
|
||||
extract_session_tokens,
|
||||
set_session_cookies,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
_AUTH_SCOPE = "auth"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _ResolutionResult:
|
||||
session: AuthSession
|
||||
strategy: SessionStrategy
|
||||
jwt_settings: JWTSettings
|
||||
|
||||
|
||||
class AuthSessionMiddleware(BaseHTTPMiddleware):
|
||||
"""Resolve authenticated users from session cookies and refresh tokens."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
*,
|
||||
settings_provider: Callable[[], Settings] = get_settings,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
|
||||
refresh_scopes: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__(app)
|
||||
self._settings_provider = settings_provider
|
||||
self._unit_of_work_factory = unit_of_work_factory
|
||||
self._refresh_scopes = tuple(
|
||||
refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
resolved = self._resolve_session(request)
|
||||
response = await call_next(request)
|
||||
self._apply_session(response, resolved)
|
||||
return response
|
||||
|
||||
def _resolve_session(self, request: Request) -> _ResolutionResult:
|
||||
settings = self._settings_provider()
|
||||
jwt_settings = settings.jwt_settings()
|
||||
strategy = build_session_strategy(settings.session_settings())
|
||||
|
||||
tokens = extract_session_tokens(request, strategy)
|
||||
session = AuthSession(tokens=tokens)
|
||||
request.state.auth_session = session
|
||||
|
||||
if tokens.access_token:
|
||||
if self._try_access_token(session, tokens, jwt_settings):
|
||||
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||
|
||||
if tokens.refresh_token:
|
||||
self._try_refresh_token(
|
||||
session, tokens.refresh_token, jwt_settings)
|
||||
|
||||
return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings)
|
||||
|
||||
def _try_access_token(
|
||||
self,
|
||||
session: AuthSession,
|
||||
tokens: SessionTokens,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> bool:
|
||||
try:
|
||||
payload = decode_access_token(
|
||||
tokens.access_token or "", jwt_settings)
|
||||
except TokenExpiredError:
|
||||
return False
|
||||
except (TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||
session.mark_cleared()
|
||||
return False
|
||||
|
||||
user = self._load_user(payload.sub)
|
||||
if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes:
|
||||
session.mark_cleared()
|
||||
return False
|
||||
|
||||
session.user = user
|
||||
session.scopes = tuple(payload.scopes)
|
||||
return True
|
||||
|
||||
def _try_refresh_token(
|
||||
self,
|
||||
session: AuthSession,
|
||||
refresh_token: str,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> None:
|
||||
try:
|
||||
payload = decode_refresh_token(refresh_token, jwt_settings)
|
||||
except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError):
|
||||
session.mark_cleared()
|
||||
return
|
||||
|
||||
user = self._load_user(payload.sub)
|
||||
if not user or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes):
|
||||
session.mark_cleared()
|
||||
return
|
||||
|
||||
session.user = user
|
||||
session.scopes = tuple(payload.scopes)
|
||||
|
||||
access_token = create_access_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=payload.scopes,
|
||||
)
|
||||
new_refresh = create_refresh_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=payload.scopes,
|
||||
)
|
||||
session.issue_tokens(access_token=access_token,
|
||||
refresh_token=new_refresh)
|
||||
|
||||
def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
|
||||
candidate_scopes = set(scopes)
|
||||
return any(scope in candidate_scopes for scope in self._refresh_scopes)
|
||||
|
||||
def _load_user(self, subject: str) -> Optional[User]:
|
||||
try:
|
||||
user_id = int(subject)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
with self._unit_of_work_factory() as uow:
|
||||
if not uow.users:
|
||||
return None
|
||||
try:
|
||||
user = uow.users.get(user_id, with_roles=True)
|
||||
except EntityNotFoundError:
|
||||
return None
|
||||
return user
|
||||
|
||||
def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None:
|
||||
session = resolved.session
|
||||
if session.clear_cookies:
|
||||
clear_session_cookies(response, resolved.strategy)
|
||||
return
|
||||
|
||||
if session.issued_access_token:
|
||||
refresh_token = session.issued_refresh_token or session.tokens.refresh_token
|
||||
set_session_cookies(
|
||||
response,
|
||||
access_token=session.issued_access_token,
|
||||
refresh_token=refresh_token,
|
||||
strategy=resolved.strategy,
|
||||
jwt_settings=resolved.jwt_settings,
|
||||
)
|
||||
40
models/__init__.py
Normal file
40
models/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Database models and shared metadata for the CalMiner domain."""
|
||||
|
||||
from .financial_input import FinancialCategory, FinancialInput
|
||||
from .metadata import (
|
||||
COST_BUCKET_METADATA,
|
||||
RESOURCE_METADATA,
|
||||
STOCHASTIC_VARIABLE_METADATA,
|
||||
CostBucket,
|
||||
ResourceDescriptor,
|
||||
ResourceType,
|
||||
StochasticVariable,
|
||||
StochasticVariableDescriptor,
|
||||
)
|
||||
from .project import MiningOperationType, Project
|
||||
from .scenario import Scenario, ScenarioStatus
|
||||
from .simulation_parameter import DistributionType, SimulationParameter
|
||||
from .user import Role, User, UserRole, password_context
|
||||
|
||||
__all__ = [
|
||||
"FinancialCategory",
|
||||
"FinancialInput",
|
||||
"MiningOperationType",
|
||||
"Project",
|
||||
"Scenario",
|
||||
"ScenarioStatus",
|
||||
"DistributionType",
|
||||
"SimulationParameter",
|
||||
"ResourceType",
|
||||
"CostBucket",
|
||||
"StochasticVariable",
|
||||
"RESOURCE_METADATA",
|
||||
"COST_BUCKET_METADATA",
|
||||
"STOCHASTIC_VARIABLE_METADATA",
|
||||
"ResourceDescriptor",
|
||||
"StochasticVariableDescriptor",
|
||||
"User",
|
||||
"Role",
|
||||
"UserRole",
|
||||
"password_context",
|
||||
]
|
||||
88
models/financial_input.py
Normal file
88
models/financial_input.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
|
||||
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from config.database import Base
|
||||
from .metadata import CostBucket
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .scenario import Scenario
|
||||
|
||||
|
||||
class FinancialCategory(str, Enum):
|
||||
"""Enumeration of cost and revenue classifications."""
|
||||
|
||||
CAPITAL_EXPENDITURE = "capex"
|
||||
OPERATING_EXPENDITURE = "opex"
|
||||
REVENUE = "revenue"
|
||||
CONTINGENCY = "contingency"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class FinancialInput(Base):
|
||||
"""Line-item financial assumption attached to a scenario."""
|
||||
|
||||
__tablename__ = "financial_inputs"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
scenario_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("scenarios.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[FinancialCategory] = mapped_column(
|
||||
SQLEnum(FinancialCategory), nullable=False
|
||||
)
|
||||
cost_bucket: Mapped[CostBucket | None] = mapped_column(
|
||||
SQLEnum(CostBucket), nullable=True
|
||||
)
|
||||
amount: Mapped[float] = mapped_column(Numeric(18, 2), nullable=False)
|
||||
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
||||
effective_date: Mapped[date | None] = mapped_column(Date, nullable=True)
|
||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
scenario: Mapped["Scenario"] = relationship("Scenario", back_populates="financial_inputs")
|
||||
|
||||
@validates("currency")
|
||||
def _validate_currency(self, key: str, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
value = value.upper()
|
||||
if len(value) != 3:
|
||||
raise ValueError("Currency code must be a 3-letter ISO 4217 value")
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return f"FinancialInput(id={self.id!r}, scenario_id={self.scenario_id!r}, name={self.name!r})"
|
||||
146
models/metadata.py
Normal file
146
models/metadata.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
|
||||
"""Primary consumables and resources used in mining operations."""
|
||||
|
||||
DIESEL = "diesel"
|
||||
ELECTRICITY = "electricity"
|
||||
WATER = "water"
|
||||
EXPLOSIVES = "explosives"
|
||||
REAGENTS = "reagents"
|
||||
LABOR = "labor"
|
||||
EQUIPMENT_HOURS = "equipment_hours"
|
||||
TAILINGS_CAPACITY = "tailings_capacity"
|
||||
|
||||
|
||||
class CostBucket(str, Enum):
|
||||
"""Granular cost buckets aligned with project accounting."""
|
||||
|
||||
CAPITAL_INITIAL = "capital_initial"
|
||||
CAPITAL_SUSTAINING = "capital_sustaining"
|
||||
OPERATING_FIXED = "operating_fixed"
|
||||
OPERATING_VARIABLE = "operating_variable"
|
||||
MAINTENANCE = "maintenance"
|
||||
RECLAMATION = "reclamation"
|
||||
ROYALTIES = "royalties"
|
||||
GENERAL_ADMIN = "general_admin"
|
||||
|
||||
|
||||
class StochasticVariable(str, Enum):
|
||||
"""Domain variables that typically require probabilistic modelling."""
|
||||
|
||||
ORE_GRADE = "ore_grade"
|
||||
RECOVERY_RATE = "recovery_rate"
|
||||
METAL_PRICE = "metal_price"
|
||||
OPERATING_COST = "operating_cost"
|
||||
CAPITAL_COST = "capital_cost"
|
||||
DISCOUNT_RATE = "discount_rate"
|
||||
THROUGHPUT = "throughput"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResourceDescriptor:
|
||||
"""Describes canonical metadata for a resource type."""
|
||||
|
||||
unit: str
|
||||
description: str
|
||||
|
||||
|
||||
RESOURCE_METADATA: dict[ResourceType, ResourceDescriptor] = {
|
||||
ResourceType.DIESEL: ResourceDescriptor(unit="L", description="Diesel fuel consumption"),
|
||||
ResourceType.ELECTRICITY: ResourceDescriptor(unit="kWh", description="Electrical power usage"),
|
||||
ResourceType.WATER: ResourceDescriptor(unit="m3", description="Process and dust suppression water"),
|
||||
ResourceType.EXPLOSIVES: ResourceDescriptor(unit="kg", description="Blasting agent consumption"),
|
||||
ResourceType.REAGENTS: ResourceDescriptor(unit="kg", description="Processing reagents"),
|
||||
ResourceType.LABOR: ResourceDescriptor(unit="hours", description="Direct labor hours"),
|
||||
ResourceType.EQUIPMENT_HOURS: ResourceDescriptor(unit="hours", description="Mobile equipment operating hours"),
|
||||
ResourceType.TAILINGS_CAPACITY: ResourceDescriptor(unit="m3", description="Tailings storage usage"),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CostBucketDescriptor:
|
||||
"""Describes reporting label and guidance for a cost bucket."""
|
||||
|
||||
label: str
|
||||
description: str
|
||||
|
||||
|
||||
COST_BUCKET_METADATA: dict[CostBucket, CostBucketDescriptor] = {
|
||||
CostBucket.CAPITAL_INITIAL: CostBucketDescriptor(
|
||||
label="Initial Capital",
|
||||
description="Pre-production capital required to construct the mine",
|
||||
),
|
||||
CostBucket.CAPITAL_SUSTAINING: CostBucketDescriptor(
|
||||
label="Sustaining Capital",
|
||||
description="Ongoing capital investments to maintain operations",
|
||||
),
|
||||
CostBucket.OPERATING_FIXED: CostBucketDescriptor(
|
||||
label="Fixed Operating",
|
||||
description="Fixed operating costs independent of production rate",
|
||||
),
|
||||
CostBucket.OPERATING_VARIABLE: CostBucketDescriptor(
|
||||
label="Variable Operating",
|
||||
description="Costs that scale with throughput or production",
|
||||
),
|
||||
CostBucket.MAINTENANCE: CostBucketDescriptor(
|
||||
label="Maintenance",
|
||||
description="Maintenance and repair expenditures",
|
||||
),
|
||||
CostBucket.RECLAMATION: CostBucketDescriptor(
|
||||
label="Reclamation",
|
||||
description="Mine closure and reclamation liabilities",
|
||||
),
|
||||
CostBucket.ROYALTIES: CostBucketDescriptor(
|
||||
label="Royalties",
|
||||
description="Royalty and streaming obligations",
|
||||
),
|
||||
CostBucket.GENERAL_ADMIN: CostBucketDescriptor(
|
||||
label="G&A",
|
||||
description="Corporate and site general and administrative costs",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StochasticVariableDescriptor:
|
||||
"""Metadata describing how a stochastic variable is typically modelled."""
|
||||
|
||||
unit: str
|
||||
description: str
|
||||
|
||||
|
||||
STOCHASTIC_VARIABLE_METADATA: dict[StochasticVariable, StochasticVariableDescriptor] = {
|
||||
StochasticVariable.ORE_GRADE: StochasticVariableDescriptor(
|
||||
unit="g/t",
|
||||
description="Head grade variability across the ore body",
|
||||
),
|
||||
StochasticVariable.RECOVERY_RATE: StochasticVariableDescriptor(
|
||||
unit="%",
|
||||
description="Metallurgical recovery uncertainty",
|
||||
),
|
||||
StochasticVariable.METAL_PRICE: StochasticVariableDescriptor(
|
||||
unit="$/unit",
|
||||
description="Commodity price fluctuations",
|
||||
),
|
||||
StochasticVariable.OPERATING_COST: StochasticVariableDescriptor(
|
||||
unit="$/t",
|
||||
description="Operating cost per tonne volatility",
|
||||
),
|
||||
StochasticVariable.CAPITAL_COST: StochasticVariableDescriptor(
|
||||
unit="$",
|
||||
description="Capital cost overrun/underrun potential",
|
||||
),
|
||||
StochasticVariable.DISCOUNT_RATE: StochasticVariableDescriptor(
|
||||
unit="%",
|
||||
description="Discount rate sensitivity",
|
||||
),
|
||||
StochasticVariable.THROUGHPUT: StochasticVariableDescriptor(
|
||||
unit="t/d",
|
||||
description="Plant throughput variability",
|
||||
),
|
||||
}
|
||||
56
models/project.py
Normal file
56
models/project.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from sqlalchemy import DateTime, Enum as SQLEnum, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from config.database import Base
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .scenario import Scenario
|
||||
|
||||
|
||||
class MiningOperationType(str, Enum):
|
||||
"""Supported mining operation categories."""
|
||||
|
||||
OPEN_PIT = "open_pit"
|
||||
UNDERGROUND = "underground"
|
||||
IN_SITU_LEACH = "in_situ_leach"
|
||||
PLACER = "placer"
|
||||
QUARRY = "quarry"
|
||||
MOUNTAINTOP_REMOVAL = "mountaintop_removal"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""Top-level mining project grouping multiple scenarios."""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
location: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
operation_type: Mapped[MiningOperationType] = mapped_column(
|
||||
SQLEnum(MiningOperationType), nullable=False, default=MiningOperationType.OTHER
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
scenarios: Mapped[List["Scenario"]] = relationship(
|
||||
"Scenario",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - helpful for debugging
|
||||
return f"Project(id={self.id!r}, name={self.name!r})"
|
||||
80
models/scenario.py
Normal file
80
models/scenario.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from config.database import Base
|
||||
from .metadata import ResourceType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .financial_input import FinancialInput
|
||||
from .project import Project
|
||||
from .simulation_parameter import SimulationParameter
|
||||
|
||||
|
||||
class ScenarioStatus(str, Enum):
|
||||
"""Lifecycle states for project scenarios."""
|
||||
|
||||
DRAFT = "draft"
|
||||
ACTIVE = "active"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class Scenario(Base):
|
||||
"""A specific configuration of assumptions for a project."""
|
||||
|
||||
__tablename__ = "scenarios"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[ScenarioStatus] = mapped_column(
|
||||
SQLEnum(ScenarioStatus), nullable=False, default=ScenarioStatus.DRAFT
|
||||
)
|
||||
start_date: Mapped[date | None] = mapped_column(Date, nullable=True)
|
||||
end_date: Mapped[date | None] = mapped_column(Date, nullable=True)
|
||||
discount_rate: Mapped[float | None] = mapped_column(Numeric(5, 2), nullable=True)
|
||||
currency: Mapped[str | None] = mapped_column(String(3), nullable=True)
|
||||
primary_resource: Mapped[ResourceType | None] = mapped_column(
|
||||
SQLEnum(ResourceType), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
project: Mapped["Project"] = relationship("Project", back_populates="scenarios")
|
||||
financial_inputs: Mapped[List["FinancialInput"]] = relationship(
|
||||
"FinancialInput",
|
||||
back_populates="scenario",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
simulation_parameters: Mapped[List["SimulationParameter"]] = relationship(
|
||||
"SimulationParameter",
|
||||
back_populates="scenario",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return f"Scenario(id={self.id!r}, name={self.name!r}, project_id={self.project_id!r})"
|
||||
80
models/simulation_parameter.py
Normal file
80
models/simulation_parameter.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from config.database import Base
|
||||
from .metadata import ResourceType, StochasticVariable
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .scenario import Scenario
|
||||
|
||||
|
||||
class DistributionType(str, Enum):
|
||||
"""Supported stochastic distribution families for simulations."""
|
||||
|
||||
NORMAL = "normal"
|
||||
TRIANGULAR = "triangular"
|
||||
UNIFORM = "uniform"
|
||||
LOGNORMAL = "lognormal"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class SimulationParameter(Base):
|
||||
"""Probability distribution settings for scenario simulations."""
|
||||
|
||||
__tablename__ = "simulation_parameters"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
scenario_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("scenarios.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
distribution: Mapped[DistributionType] = mapped_column(
|
||||
SQLEnum(DistributionType), nullable=False
|
||||
)
|
||||
variable: Mapped[StochasticVariable | None] = mapped_column(
|
||||
SQLEnum(StochasticVariable), nullable=True
|
||||
)
|
||||
resource_type: Mapped[ResourceType | None] = mapped_column(
|
||||
SQLEnum(ResourceType), nullable=True
|
||||
)
|
||||
mean_value: Mapped[float | None] = mapped_column(
|
||||
Numeric(18, 4), nullable=True)
|
||||
standard_deviation: Mapped[float | None] = mapped_column(
|
||||
Numeric(18, 4), nullable=True)
|
||||
minimum_value: Mapped[float | None] = mapped_column(
|
||||
Numeric(18, 4), nullable=True)
|
||||
maximum_value: Mapped[float | None] = mapped_column(
|
||||
Numeric(18, 4), nullable=True)
|
||||
unit: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
configuration: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
scenario: Mapped["Scenario"] = relationship(
|
||||
"Scenario", back_populates="simulation_parameters"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return (
|
||||
f"SimulationParameter(id={self.id!r}, scenario_id={self.scenario_id!r}, "
|
||||
f"name={self.name!r})"
|
||||
)
|
||||
176
models/user.py
Normal file
176
models/user.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from passlib.context import CryptContext
|
||||
|
||||
try: # pragma: no cover - defensive compatibility shim
|
||||
import importlib.metadata as importlib_metadata
|
||||
import argon2 # type: ignore
|
||||
|
||||
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
|
||||
except Exception:
|
||||
pass
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from config.database import Base
|
||||
|
||||
# Configure password hashing strategy. Argon2 provides strong resistance against
|
||||
# GPU-based cracking attempts, aligning with the security plan.
|
||||
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Authenticated platform user with optional elevated privileges."""
|
||||
|
||||
__tablename__ = "users"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("email", name="uq_users_email"),
|
||||
UniqueConstraint("username", name="uq_users_username"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
username: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True)
|
||||
is_superuser: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
role_assignments: Mapped[List["UserRole"]] = relationship(
|
||||
"UserRole",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="UserRole.user_id",
|
||||
)
|
||||
roles: Mapped[List["Role"]] = relationship(
|
||||
"Role",
|
||||
secondary="user_roles",
|
||||
primaryjoin="User.id == UserRole.user_id",
|
||||
secondaryjoin="Role.id == UserRole.role_id",
|
||||
viewonly=True,
|
||||
back_populates="users",
|
||||
)
|
||||
|
||||
def set_password(self, raw_password: str) -> None:
|
||||
"""Hash and store a password for the user."""
|
||||
|
||||
self.password_hash = self.hash_password(raw_password)
|
||||
|
||||
@staticmethod
|
||||
def hash_password(raw_password: str) -> str:
|
||||
"""Return the Argon2 hash for a clear-text password."""
|
||||
|
||||
return password_context.hash(raw_password)
|
||||
|
||||
def verify_password(self, candidate_password: str) -> bool:
|
||||
"""Validate a password against the stored hash."""
|
||||
|
||||
if not self.password_hash:
|
||||
return False
|
||||
return password_context.verify(candidate_password, self.password_hash)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - helpful for debugging
|
||||
return f"User(id={self.id!r}, email={self.email!r})"
|
||||
|
||||
|
||||
class Role(Base):
|
||||
"""Role encapsulating a set of permissions."""
|
||||
|
||||
__tablename__ = "roles"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
|
||||
display_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
assignments: Mapped[List["UserRole"]] = relationship(
|
||||
"UserRole",
|
||||
back_populates="role",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="UserRole.role_id",
|
||||
)
|
||||
users: Mapped[List["User"]] = relationship(
|
||||
"User",
|
||||
secondary="user_roles",
|
||||
primaryjoin="Role.id == UserRole.role_id",
|
||||
secondaryjoin="User.id == UserRole.user_id",
|
||||
viewonly=True,
|
||||
back_populates="roles",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - helpful for debugging
|
||||
return f"Role(id={self.id!r}, name={self.name!r})"
|
||||
|
||||
|
||||
class UserRole(Base):
|
||||
"""Association between users and roles with assignment metadata."""
|
||||
|
||||
__tablename__ = "user_roles"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "role_id", name="uq_user_roles_user_role"),
|
||||
)
|
||||
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
role_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
granted_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
granted_by: Mapped[Optional[int]] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
foreign_keys=[user_id],
|
||||
back_populates="role_assignments",
|
||||
)
|
||||
role: Mapped["Role"] = relationship(
|
||||
"Role",
|
||||
foreign_keys=[role_id],
|
||||
back_populates="assignments",
|
||||
)
|
||||
granted_by_user: Mapped[Optional["User"]] = relationship(
|
||||
"User",
|
||||
foreign_keys=[granted_by],
|
||||
)
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - debugging helper
|
||||
return f"UserRole(user_id={self.user_id!r}, role_id={self.role_id!r})"
|
||||
@@ -14,3 +14,6 @@ exclude = '''
|
||||
)/
|
||||
'''
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["."]
|
||||
|
||||
|
||||
2
requirements-dev.txt
Normal file
2
requirements-dev.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
-r requirements.txt
|
||||
alembic
|
||||
@@ -9,4 +9,6 @@ jinja2
|
||||
pandas
|
||||
numpy
|
||||
passlib
|
||||
argon2-cffi
|
||||
python-jose
|
||||
python-multipart
|
||||
1
routes/__init__.py
Normal file
1
routes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API route registrations."""
|
||||
484
routes/auth.py
Normal file
484
routes/auth.py
Normal file
@@ -0,0 +1,484 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Iterable
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import ValidationError
|
||||
from starlette.datastructures import FormData
|
||||
|
||||
from dependencies import (
|
||||
get_auth_session,
|
||||
get_jwt_settings,
|
||||
get_session_strategy,
|
||||
get_unit_of_work,
|
||||
require_current_user,
|
||||
)
|
||||
from models import Role, User
|
||||
from schemas.auth import (
|
||||
LoginForm,
|
||||
PasswordResetForm,
|
||||
PasswordResetRequestForm,
|
||||
RegistrationForm,
|
||||
)
|
||||
from services.exceptions import EntityConflictError
|
||||
from services.security import (
|
||||
JWTSettings,
|
||||
TokenDecodeError,
|
||||
TokenExpiredError,
|
||||
TokenTypeMismatchError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from services.session import (
|
||||
AuthSession,
|
||||
SessionStrategy,
|
||||
clear_session_cookies,
|
||||
set_session_cookies,
|
||||
)
|
||||
from services.repositories import RoleRepository, UserRepository
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
router = APIRouter(tags=["Authentication"])
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
_PASSWORD_RESET_SCOPE = "password-reset"
|
||||
_AUTH_SCOPE = "auth"
|
||||
|
||||
|
||||
def _template(
|
||||
request: Request,
|
||||
template_name: str,
|
||||
context: dict[str, Any],
|
||||
*,
|
||||
status_code: int = status.HTTP_200_OK,
|
||||
) -> HTMLResponse:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
template_name,
|
||||
context,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
def _validation_errors(exc: ValidationError) -> list[str]:
|
||||
return [error.get("msg", "Invalid input.") for error in exc.errors()]
|
||||
|
||||
|
||||
def _scopes(include: Iterable[str]) -> list[str]:
|
||||
return list(include)
|
||||
|
||||
|
||||
def _normalise_form_data(form_data: FormData) -> dict[str, str]:
|
||||
normalised: dict[str, str] = {}
|
||||
for key, value in form_data.multi_items():
|
||||
if isinstance(value, UploadFile):
|
||||
str_value = value.filename or ""
|
||||
else:
|
||||
str_value = str(value)
|
||||
normalised[key] = str_value
|
||||
return normalised
|
||||
|
||||
|
||||
def _require_users_repo(uow: UnitOfWork) -> UserRepository:
|
||||
if not uow.users:
|
||||
raise RuntimeError("User repository is not initialised")
|
||||
return uow.users
|
||||
|
||||
|
||||
def _require_roles_repo(uow: UnitOfWork) -> RoleRepository:
|
||||
if not uow.roles:
|
||||
raise RuntimeError("Role repository is not initialised")
|
||||
return uow.roles
|
||||
|
||||
|
||||
@router.get("/login", response_class=HTMLResponse, include_in_schema=False, name="auth.login_form")
|
||||
def login_form(request: Request) -> HTMLResponse:
|
||||
return _template(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.login_submit"),
|
||||
"errors": [],
|
||||
"username": "",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", include_in_schema=False, name="auth.login_submit")
|
||||
async def login_submit(
|
||||
request: Request,
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||
session_strategy: SessionStrategy = Depends(get_session_strategy),
|
||||
):
|
||||
form_data = _normalise_form_data(await request.form())
|
||||
try:
|
||||
form = LoginForm(**form_data)
|
||||
except ValidationError as exc:
|
||||
return _template(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.login_submit"),
|
||||
"errors": _validation_errors(exc),
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
identifier = form.username
|
||||
users_repo = _require_users_repo(uow)
|
||||
user = _lookup_user(users_repo, identifier)
|
||||
errors: list[str] = []
|
||||
|
||||
if not user or not verify_password(form.password, user.password_hash):
|
||||
errors.append("Invalid username or password.")
|
||||
elif not user.is_active:
|
||||
errors.append("Account is inactive. Contact an administrator.")
|
||||
|
||||
if errors:
|
||||
return _template(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.login_submit"),
|
||||
"errors": errors,
|
||||
"username": identifier,
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
assert user is not None # mypy hint - guarded above
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
|
||||
access_token = create_access_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=_scopes((_AUTH_SCOPE,)),
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=_scopes((_AUTH_SCOPE,)),
|
||||
)
|
||||
|
||||
response = RedirectResponse(
|
||||
request.url_for("dashboard.home"),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
set_session_cookies(
|
||||
response,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
strategy=session_strategy,
|
||||
jwt_settings=jwt_settings,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/logout", include_in_schema=False, name="auth.logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
_: User = Depends(require_current_user),
|
||||
session: AuthSession = Depends(get_auth_session),
|
||||
session_strategy: SessionStrategy = Depends(get_session_strategy),
|
||||
) -> RedirectResponse:
|
||||
session.mark_cleared()
|
||||
redirect_url = request.url_for(
|
||||
"auth.login_form").include_query_params(logout="1")
|
||||
response = RedirectResponse(
|
||||
redirect_url,
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
clear_session_cookies(response, session_strategy)
|
||||
return response
|
||||
|
||||
|
||||
def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None:
|
||||
if "@" in identifier:
|
||||
return users_repo.get_by_email(identifier.lower(), with_roles=True)
|
||||
return users_repo.get_by_username(identifier, with_roles=True)
|
||||
|
||||
|
||||
@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form")
|
||||
def register_form(request: Request) -> HTMLResponse:
|
||||
return _template(
|
||||
request,
|
||||
"register.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.register_submit"),
|
||||
"errors": [],
|
||||
"form_data": None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", include_in_schema=False, name="auth.register_submit")
|
||||
async def register_submit(
|
||||
request: Request,
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
form_data = _normalise_form_data(await request.form())
|
||||
try:
|
||||
form = RegistrationForm(**form_data)
|
||||
except ValidationError as exc:
|
||||
return _registration_error_response(request, _validation_errors(exc))
|
||||
|
||||
errors: list[str] = []
|
||||
users_repo = _require_users_repo(uow)
|
||||
roles_repo = _require_roles_repo(uow)
|
||||
uow.ensure_default_roles()
|
||||
|
||||
if users_repo.get_by_email(form.email):
|
||||
errors.append("Email is already registered.")
|
||||
if users_repo.get_by_username(form.username):
|
||||
errors.append("Username is already taken.")
|
||||
|
||||
if errors:
|
||||
return _registration_error_response(request, errors, form)
|
||||
|
||||
user = User(
|
||||
email=form.email,
|
||||
username=form.username,
|
||||
password_hash=hash_password(form.password),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
try:
|
||||
created = users_repo.create(user)
|
||||
except EntityConflictError:
|
||||
return _registration_error_response(
|
||||
request,
|
||||
["An account with this username or email already exists."],
|
||||
form,
|
||||
)
|
||||
|
||||
viewer_role = _ensure_viewer_role(roles_repo)
|
||||
if viewer_role is not None:
|
||||
users_repo.assign_role(
|
||||
user_id=created.id,
|
||||
role_id=viewer_role.id,
|
||||
granted_by=created.id,
|
||||
)
|
||||
|
||||
redirect_url = request.url_for(
|
||||
"auth.login_form").include_query_params(registered="1")
|
||||
return RedirectResponse(
|
||||
redirect_url,
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
|
||||
|
||||
def _registration_error_response(
|
||||
request: Request,
|
||||
errors: list[str],
|
||||
form: RegistrationForm | None = None,
|
||||
) -> HTMLResponse:
|
||||
context = {
|
||||
"form_action": request.url_for("auth.register_submit"),
|
||||
"errors": errors,
|
||||
"form_data": form.model_dump(exclude={"password", "confirm_password"}) if form else None,
|
||||
}
|
||||
return _template(
|
||||
request,
|
||||
"register.html",
|
||||
context,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
def _ensure_viewer_role(roles_repo: RoleRepository) -> Role | None:
|
||||
viewer = roles_repo.get_by_name("viewer")
|
||||
if viewer:
|
||||
return viewer
|
||||
return roles_repo.get_by_name("viewer")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/forgot-password",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="auth.password_reset_request_form",
|
||||
)
|
||||
def password_reset_request_form(request: Request) -> HTMLResponse:
|
||||
return _template(
|
||||
request,
|
||||
"forgot_password.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.password_reset_request_submit"),
|
||||
"errors": [],
|
||||
"message": None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/forgot-password",
|
||||
include_in_schema=False,
|
||||
name="auth.password_reset_request_submit",
|
||||
)
|
||||
async def password_reset_request_submit(
|
||||
request: Request,
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||
):
|
||||
form_data = _normalise_form_data(await request.form())
|
||||
try:
|
||||
form = PasswordResetRequestForm(**form_data)
|
||||
except ValidationError as exc:
|
||||
return _template(
|
||||
request,
|
||||
"forgot_password.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.password_reset_request_submit"),
|
||||
"errors": _validation_errors(exc),
|
||||
"message": None,
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
users_repo = _require_users_repo(uow)
|
||||
user = users_repo.get_by_email(form.email)
|
||||
if not user:
|
||||
return _template(
|
||||
request,
|
||||
"forgot_password.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.password_reset_request_submit"),
|
||||
"errors": [],
|
||||
"message": "If an account exists, a reset link has been sent.",
|
||||
},
|
||||
)
|
||||
|
||||
token = create_access_token(
|
||||
str(user.id),
|
||||
jwt_settings,
|
||||
scopes=_scopes((_PASSWORD_RESET_SCOPE,)),
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
|
||||
reset_url = request.url_for(
|
||||
"auth.password_reset_form").include_query_params(token=token)
|
||||
return RedirectResponse(reset_url, status_code=status.HTTP_303_SEE_OTHER)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/reset-password",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="auth.password_reset_form",
|
||||
)
|
||||
def password_reset_form(
|
||||
request: Request,
|
||||
token: str | None = None,
|
||||
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||
) -> HTMLResponse:
|
||||
errors: list[str] = []
|
||||
if not token:
|
||||
errors.append("Missing password reset token.")
|
||||
else:
|
||||
try:
|
||||
payload = decode_access_token(token, jwt_settings)
|
||||
if _PASSWORD_RESET_SCOPE not in payload.scopes:
|
||||
errors.append("Invalid token scope.")
|
||||
except TokenExpiredError:
|
||||
errors.append(
|
||||
"Token has expired. Please request a new password reset.")
|
||||
except (TokenDecodeError, TokenTypeMismatchError):
|
||||
errors.append("Invalid password reset token.")
|
||||
|
||||
return _template(
|
||||
request,
|
||||
"reset_password.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.password_reset_submit"),
|
||||
"token": token,
|
||||
"errors": errors,
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST if errors else status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/reset-password",
|
||||
include_in_schema=False,
|
||||
name="auth.password_reset_submit",
|
||||
)
|
||||
async def password_reset_submit(
|
||||
request: Request,
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
jwt_settings: JWTSettings = Depends(get_jwt_settings),
|
||||
):
|
||||
form_data = _normalise_form_data(await request.form())
|
||||
try:
|
||||
form = PasswordResetForm(**form_data)
|
||||
except ValidationError as exc:
|
||||
return _template(
|
||||
request,
|
||||
"reset_password.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.password_reset_submit"),
|
||||
"token": form_data.get("token"),
|
||||
"errors": _validation_errors(exc),
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
try:
|
||||
payload = decode_access_token(form.token, jwt_settings)
|
||||
except TokenExpiredError:
|
||||
return _reset_error_response(
|
||||
request,
|
||||
form.token,
|
||||
"Token has expired. Please request a new password reset.",
|
||||
)
|
||||
except (TokenDecodeError, TokenTypeMismatchError):
|
||||
return _reset_error_response(
|
||||
request,
|
||||
form.token,
|
||||
"Invalid password reset token.",
|
||||
)
|
||||
|
||||
if _PASSWORD_RESET_SCOPE not in payload.scopes:
|
||||
return _reset_error_response(
|
||||
request,
|
||||
form.token,
|
||||
"Invalid password reset token scope.",
|
||||
)
|
||||
|
||||
users_repo = _require_users_repo(uow)
|
||||
user_id = int(payload.sub)
|
||||
user = users_repo.get(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
user.set_password(form.password)
|
||||
if not user.is_active:
|
||||
user.is_active = True
|
||||
|
||||
redirect_url = request.url_for(
|
||||
"auth.login_form").include_query_params(reset="1")
|
||||
return RedirectResponse(
|
||||
redirect_url,
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
|
||||
|
||||
def _reset_error_response(request: Request, token: str, message: str) -> HTMLResponse:
|
||||
return _template(
|
||||
request,
|
||||
"reset_password.html",
|
||||
{
|
||||
"form_action": request.url_for("auth.password_reset_submit"),
|
||||
"token": token,
|
||||
"errors": [message],
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
124
routes/dashboard.py
Normal file
124
routes/dashboard.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from dependencies import get_unit_of_work, require_authenticated_user
|
||||
from models import User
|
||||
from models import ScenarioStatus
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
router = APIRouter(tags=["Dashboard"])
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
|
||||
def _format_timestamp(moment: datetime | None) -> str | None:
|
||||
if moment is None:
|
||||
return None
|
||||
return moment.strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def _format_timestamp_with_time(moment: datetime | None) -> str | None:
|
||||
if moment is None:
|
||||
return None
|
||||
return moment.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
|
||||
def _load_metrics(uow: UnitOfWork) -> dict[str, object]:
|
||||
if not uow.projects or not uow.scenarios or not uow.financial_inputs:
|
||||
raise RuntimeError("UnitOfWork repositories not initialised")
|
||||
total_projects = uow.projects.count()
|
||||
active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE)
|
||||
pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT)
|
||||
last_import_at = uow.financial_inputs.latest_created_at()
|
||||
return {
|
||||
"total_projects": total_projects,
|
||||
"active_scenarios": active_scenarios,
|
||||
"pending_simulations": pending_simulations,
|
||||
"last_import": _format_timestamp(last_import_at),
|
||||
}
|
||||
|
||||
|
||||
def _load_recent_projects(uow: UnitOfWork) -> list:
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return list(uow.projects.recent(limit=5))
|
||||
|
||||
|
||||
def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]:
|
||||
updates: list[dict[str, object]] = []
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
scenarios = uow.scenarios.recent(limit=5, with_project=True)
|
||||
for scenario in scenarios:
|
||||
project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}"
|
||||
timestamp_label = _format_timestamp_with_time(scenario.updated_at)
|
||||
updates.append(
|
||||
{
|
||||
"title": f"{scenario.name} · {scenario.status.value.title()}",
|
||||
"description": f"Latest update recorded for {project_name}.",
|
||||
"timestamp": scenario.updated_at,
|
||||
"timestamp_label": timestamp_label,
|
||||
}
|
||||
)
|
||||
return updates
|
||||
|
||||
|
||||
def _load_scenario_alerts(
|
||||
request: Request, uow: UnitOfWork
|
||||
) -> list[dict[str, object]]:
|
||||
alerts: list[dict[str, object]] = []
|
||||
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
|
||||
drafts = uow.scenarios.list_by_status(
|
||||
ScenarioStatus.DRAFT, limit=3, with_project=True
|
||||
)
|
||||
for scenario in drafts:
|
||||
project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}"
|
||||
alerts.append(
|
||||
{
|
||||
"title": f"Draft scenario: {scenario.name}",
|
||||
"message": f"{project_name} has a scenario awaiting validation.",
|
||||
"link": request.url_for(
|
||||
"projects.view_project", project_id=scenario.project_id
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if not alerts:
|
||||
archived = uow.scenarios.list_by_status(
|
||||
ScenarioStatus.ARCHIVED, limit=3, with_project=True
|
||||
)
|
||||
for scenario in archived:
|
||||
project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}"
|
||||
alerts.append(
|
||||
{
|
||||
"title": f"Archived scenario: {scenario.name}",
|
||||
"message": f"Review archived scenario insights for {project_name}.",
|
||||
"link": request.url_for(
|
||||
"scenarios.view_scenario", scenario_id=scenario.id
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return alerts
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse, include_in_schema=False, name="dashboard.home")
|
||||
def dashboard_home(
|
||||
request: Request,
|
||||
_: User = Depends(require_authenticated_user),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
context = {
|
||||
"metrics": _load_metrics(uow),
|
||||
"recent_projects": _load_recent_projects(uow),
|
||||
"simulation_updates": _load_simulation_updates(uow),
|
||||
"scenario_alerts": _load_scenario_alerts(request, uow),
|
||||
}
|
||||
return templates.TemplateResponse(request, "dashboard.html", context)
|
||||
319
routes/projects.py
Normal file
319
routes/projects.py
Normal file
@@ -0,0 +1,319 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from dependencies import (
|
||||
get_unit_of_work,
|
||||
require_any_role,
|
||||
require_project_resource,
|
||||
require_roles,
|
||||
)
|
||||
from models import MiningOperationType, Project, ScenarioStatus, User
|
||||
from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
|
||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
router = APIRouter(prefix="/projects", tags=["Projects"])
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
|
||||
MANAGE_ROLES = ("project_manager", "admin")
|
||||
|
||||
|
||||
def _to_read_model(project: Project) -> ProjectRead:
|
||||
return ProjectRead.model_validate(project)
|
||||
|
||||
|
||||
def _require_project_repo(uow: UnitOfWork):
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return uow.projects
|
||||
|
||||
|
||||
def _operation_type_choices() -> list[tuple[str, str]]:
|
||||
return [
|
||||
(op.value, op.name.replace("_", " ").title()) for op in MiningOperationType
|
||||
]
|
||||
|
||||
|
||||
@router.get("", response_model=List[ProjectRead])
|
||||
def list_projects(
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> List[ProjectRead]:
|
||||
projects = _require_project_repo(uow).list()
|
||||
return [_to_read_model(project) for project in projects]
|
||||
|
||||
|
||||
@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_project(
|
||||
payload: ProjectCreate,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ProjectRead:
|
||||
project = Project(**payload.model_dump())
|
||||
try:
|
||||
created = _require_project_repo(uow).create(project)
|
||||
except EntityConflictError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail=str(exc)
|
||||
) from exc
|
||||
return _to_read_model(created)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/ui",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="projects.project_list_page",
|
||||
)
|
||||
def project_list_page(
|
||||
request: Request,
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
projects = _require_project_repo(uow).list(with_children=True)
|
||||
for project in projects:
|
||||
setattr(project, "scenario_count", len(project.scenarios))
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/list.html",
|
||||
{
|
||||
"projects": projects,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/create",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="projects.create_project_form",
|
||||
)
|
||||
def create_project_form(
|
||||
request: Request, _: User = Depends(require_roles(*MANAGE_ROLES))
|
||||
) -> HTMLResponse:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/form.html",
|
||||
{
|
||||
"project": None,
|
||||
"operation_types": _operation_type_choices(),
|
||||
"form_action": request.url_for("projects.create_project_submit"),
|
||||
"cancel_url": request.url_for("projects.project_list_page"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/create",
|
||||
include_in_schema=False,
|
||||
name="projects.create_project_submit",
|
||||
)
|
||||
def create_project_submit(
|
||||
request: Request,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
name: str = Form(...),
|
||||
location: str | None = Form(None),
|
||||
operation_type: str = Form(...),
|
||||
description: str | None = Form(None),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
def _normalise(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = value.strip()
|
||||
return value or None
|
||||
|
||||
try:
|
||||
op_type = MiningOperationType(operation_type)
|
||||
except ValueError as exc:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/form.html",
|
||||
{
|
||||
"project": None,
|
||||
"operation_types": _operation_type_choices(),
|
||||
"form_action": request.url_for("projects.create_project_submit"),
|
||||
"cancel_url": request.url_for("projects.project_list_page"),
|
||||
"error": "Invalid operation type.",
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
project = Project(
|
||||
name=name.strip(),
|
||||
location=_normalise(location),
|
||||
operation_type=op_type,
|
||||
description=_normalise(description),
|
||||
)
|
||||
try:
|
||||
_require_project_repo(uow).create(project)
|
||||
except EntityConflictError as exc:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/form.html",
|
||||
{
|
||||
"project": project,
|
||||
"operation_types": _operation_type_choices(),
|
||||
"form_action": request.url_for("projects.create_project_submit"),
|
||||
"cancel_url": request.url_for("projects.project_list_page"),
|
||||
"error": "Project with this name already exists.",
|
||||
},
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
return RedirectResponse(
|
||||
request.url_for("projects.project_list_page"),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ProjectRead)
|
||||
def get_project(project: Project = Depends(require_project_resource())) -> ProjectRead:
|
||||
return _to_read_model(project)
|
||||
|
||||
|
||||
@router.put("/{project_id}", response_model=ProjectRead)
|
||||
def update_project(
|
||||
payload: ProjectUpdate,
|
||||
project: Project = Depends(
|
||||
require_project_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ProjectRead:
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(project, field, value)
|
||||
|
||||
uow.flush()
|
||||
return _to_read_model(project)
|
||||
|
||||
|
||||
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_project(
|
||||
project: Project = Depends(require_project_resource(require_manage=True)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> None:
|
||||
_require_project_repo(uow).delete(project.id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/view",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="projects.view_project",
|
||||
)
|
||||
def view_project(
|
||||
request: Request,
|
||||
project: Project = Depends(require_project_resource()),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
project = _require_project_repo(uow).get(project.id, with_children=True)
|
||||
|
||||
scenarios = sorted(project.scenarios, key=lambda s: s.created_at)
|
||||
scenario_stats = {
|
||||
"total": len(scenarios),
|
||||
"active": sum(1 for scenario in scenarios if scenario.status == ScenarioStatus.ACTIVE),
|
||||
"draft": sum(1 for scenario in scenarios if scenario.status == ScenarioStatus.DRAFT),
|
||||
"archived": sum(1 for scenario in scenarios if scenario.status == ScenarioStatus.ARCHIVED),
|
||||
"latest_update": max(
|
||||
(scenario.updated_at for scenario in scenarios if scenario.updated_at),
|
||||
default=None,
|
||||
),
|
||||
}
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/detail.html",
|
||||
{
|
||||
"project": project,
|
||||
"scenarios": scenarios,
|
||||
"scenario_stats": scenario_stats,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/edit",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="projects.edit_project_form",
|
||||
)
|
||||
def edit_project_form(
|
||||
request: Request,
|
||||
project: Project = Depends(
|
||||
require_project_resource(require_manage=True)
|
||||
),
|
||||
) -> HTMLResponse:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/form.html",
|
||||
{
|
||||
"project": project,
|
||||
"operation_types": _operation_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"projects.edit_project_submit", project_id=project.id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"projects.view_project", project_id=project.id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{project_id}/edit",
|
||||
include_in_schema=False,
|
||||
name="projects.edit_project_submit",
|
||||
)
|
||||
def edit_project_submit(
|
||||
request: Request,
|
||||
project: Project = Depends(
|
||||
require_project_resource(require_manage=True)
|
||||
),
|
||||
name: str = Form(...),
|
||||
location: str | None = Form(None),
|
||||
operation_type: str | None = Form(None),
|
||||
description: str | None = Form(None),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
def _normalise(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = value.strip()
|
||||
return value or None
|
||||
|
||||
project.name = name.strip()
|
||||
project.location = _normalise(location)
|
||||
if operation_type:
|
||||
try:
|
||||
project.operation_type = MiningOperationType(operation_type)
|
||||
except ValueError as exc:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"projects/form.html",
|
||||
{
|
||||
"project": project,
|
||||
"operation_types": _operation_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"projects.edit_project_submit", project_id=project.id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"projects.view_project", project_id=project.id
|
||||
),
|
||||
"error": "Invalid operation type.",
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
project.description = _normalise(description)
|
||||
|
||||
uow.flush()
|
||||
|
||||
return RedirectResponse(
|
||||
request.url_for("projects.view_project", project_id=project.id),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
465
routes/scenarios.py
Normal file
465
routes/scenarios.py
Normal file
@@ -0,0 +1,465 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from dependencies import (
|
||||
get_unit_of_work,
|
||||
require_any_role,
|
||||
require_roles,
|
||||
require_scenario_resource,
|
||||
)
|
||||
from models import ResourceType, Scenario, ScenarioStatus, User
|
||||
from schemas.scenario import (
|
||||
ScenarioComparisonRequest,
|
||||
ScenarioComparisonResponse,
|
||||
ScenarioCreate,
|
||||
ScenarioRead,
|
||||
ScenarioUpdate,
|
||||
)
|
||||
from services.exceptions import (
|
||||
EntityConflictError,
|
||||
EntityNotFoundError,
|
||||
ScenarioValidationError,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
router = APIRouter(tags=["Scenarios"])
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
READ_ROLES = ("viewer", "analyst", "project_manager", "admin")
|
||||
MANAGE_ROLES = ("project_manager", "admin")
|
||||
|
||||
|
||||
def _to_read_model(scenario: Scenario) -> ScenarioRead:
|
||||
return ScenarioRead.model_validate(scenario)
|
||||
|
||||
|
||||
def _resource_type_choices() -> list[tuple[str, str]]:
|
||||
return [
|
||||
(resource.value, resource.value.replace("_", " ").title())
|
||||
for resource in ResourceType
|
||||
]
|
||||
|
||||
|
||||
def _scenario_status_choices() -> list[tuple[str, str]]:
|
||||
return [
|
||||
(status.value, status.value.title()) for status in ScenarioStatus
|
||||
]
|
||||
|
||||
|
||||
def _require_project_repo(uow: UnitOfWork):
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return uow.projects
|
||||
|
||||
|
||||
def _require_scenario_repo(uow: UnitOfWork):
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
return uow.scenarios
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/scenarios",
|
||||
response_model=List[ScenarioRead],
|
||||
)
|
||||
def list_scenarios_for_project(
|
||||
project_id: int,
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> List[ScenarioRead]:
|
||||
project_repo = _require_project_repo(uow)
|
||||
scenario_repo = _require_scenario_repo(uow)
|
||||
try:
|
||||
project_repo.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
|
||||
scenarios = scenario_repo.list_for_project(project_id)
|
||||
return [_to_read_model(scenario) for scenario in scenarios]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/scenarios/compare",
|
||||
response_model=ScenarioComparisonResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
def compare_scenarios(
|
||||
project_id: int,
|
||||
payload: ScenarioComparisonRequest,
|
||||
_: User = Depends(require_any_role(*READ_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ScenarioComparisonResponse:
|
||||
try:
|
||||
_require_project_repo(uow).get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
try:
|
||||
scenarios = uow.validate_scenarios_for_comparison(payload.scenario_ids)
|
||||
if any(scenario.project_id != project_id for scenario in scenarios):
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_PROJECT_MISMATCH",
|
||||
message="Selected scenarios do not belong to the same project.",
|
||||
scenario_ids=[
|
||||
scenario.id for scenario in scenarios if scenario.id is not None
|
||||
],
|
||||
)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ScenarioValidationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
detail={
|
||||
"code": exc.code,
|
||||
"message": exc.message,
|
||||
"scenario_ids": list(exc.scenario_ids or []),
|
||||
},
|
||||
) from exc
|
||||
|
||||
return ScenarioComparisonResponse(
|
||||
project_id=project_id,
|
||||
scenarios=[_to_read_model(scenario) for scenario in scenarios],
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/scenarios",
|
||||
response_model=ScenarioRead,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
def create_scenario_for_project(
|
||||
project_id: int,
|
||||
payload: ScenarioCreate,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ScenarioRead:
|
||||
project_repo = _require_project_repo(uow)
|
||||
scenario_repo = _require_scenario_repo(uow)
|
||||
try:
|
||||
project_repo.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
|
||||
scenario = Scenario(project_id=project_id, **payload.model_dump())
|
||||
|
||||
try:
|
||||
created = scenario_repo.create(scenario)
|
||||
except EntityConflictError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||
return _to_read_model(created)
|
||||
|
||||
|
||||
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
||||
def get_scenario(
|
||||
scenario: Scenario = Depends(require_scenario_resource()),
|
||||
) -> ScenarioRead:
|
||||
return _to_read_model(scenario)
|
||||
|
||||
|
||||
@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
||||
def update_scenario(
|
||||
payload: ScenarioUpdate,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> ScenarioRead:
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(scenario, field, value)
|
||||
|
||||
uow.flush()
|
||||
return _to_read_model(scenario)
|
||||
|
||||
|
||||
@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_scenario(
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> None:
|
||||
_require_scenario_repo(uow).delete(scenario.id)
|
||||
|
||||
|
||||
def _normalise(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
value = value.strip()
|
||||
return value or None
|
||||
|
||||
|
||||
def _parse_date(value: str | None) -> date | None:
|
||||
value = _normalise(value)
|
||||
if not value:
|
||||
return None
|
||||
return date.fromisoformat(value)
|
||||
|
||||
|
||||
def _parse_discount_rate(value: str | None) -> float | None:
|
||||
value = _normalise(value)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/scenarios/new",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="scenarios.create_scenario_form",
|
||||
)
|
||||
def create_scenario_form(
|
||||
project_id: int,
|
||||
request: Request,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
try:
|
||||
project = _require_project_repo(uow).get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"scenarios/form.html",
|
||||
{
|
||||
"project": project,
|
||||
"scenario": None,
|
||||
"scenario_statuses": _scenario_status_choices(),
|
||||
"resource_types": _resource_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"scenarios.create_scenario_submit", project_id=project_id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"projects.view_project", project_id=project_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/scenarios/new",
|
||||
include_in_schema=False,
|
||||
name="scenarios.create_scenario_submit",
|
||||
)
|
||||
def create_scenario_submit(
|
||||
project_id: int,
|
||||
request: Request,
|
||||
_: User = Depends(require_roles(*MANAGE_ROLES)),
|
||||
name: str = Form(...),
|
||||
description: str | None = Form(None),
|
||||
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
||||
start_date: str | None = Form(None),
|
||||
end_date: str | None = Form(None),
|
||||
discount_rate: str | None = Form(None),
|
||||
currency: str | None = Form(None),
|
||||
primary_resource: str | None = Form(None),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
project_repo = _require_project_repo(uow)
|
||||
scenario_repo = _require_scenario_repo(uow)
|
||||
try:
|
||||
project = project_repo.get(project_id)
|
||||
except EntityNotFoundError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
try:
|
||||
status_enum = ScenarioStatus(status_value)
|
||||
except ValueError:
|
||||
status_enum = ScenarioStatus.DRAFT
|
||||
|
||||
resource_enum = None
|
||||
if primary_resource:
|
||||
try:
|
||||
resource_enum = ResourceType(primary_resource)
|
||||
except ValueError:
|
||||
resource_enum = None
|
||||
|
||||
currency_value = _normalise(currency)
|
||||
currency_value = currency_value.upper() if currency_value else None
|
||||
|
||||
scenario = Scenario(
|
||||
project_id=project_id,
|
||||
name=name.strip(),
|
||||
description=_normalise(description),
|
||||
status=status_enum,
|
||||
start_date=_parse_date(start_date),
|
||||
end_date=_parse_date(end_date),
|
||||
discount_rate=_parse_discount_rate(discount_rate),
|
||||
currency=currency_value,
|
||||
primary_resource=resource_enum,
|
||||
)
|
||||
|
||||
try:
|
||||
scenario_repo.create(scenario)
|
||||
except EntityConflictError as exc:
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"scenarios/form.html",
|
||||
{
|
||||
"project": project,
|
||||
"scenario": scenario,
|
||||
"scenario_statuses": _scenario_status_choices(),
|
||||
"resource_types": _resource_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"scenarios.create_scenario_submit", project_id=project_id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"projects.view_project", project_id=project_id
|
||||
),
|
||||
"error": "Scenario could not be created.",
|
||||
},
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
return RedirectResponse(
|
||||
request.url_for("projects.view_project", project_id=project_id),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scenarios/{scenario_id}/view",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="scenarios.view_scenario",
|
||||
)
|
||||
def view_scenario(
|
||||
request: Request,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(with_children=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
project = _require_project_repo(uow).get(scenario.project_id)
|
||||
financial_inputs = sorted(
|
||||
scenario.financial_inputs, key=lambda item: item.created_at
|
||||
)
|
||||
simulation_parameters = sorted(
|
||||
scenario.simulation_parameters, key=lambda item: item.created_at
|
||||
)
|
||||
|
||||
scenario_metrics = {
|
||||
"financial_count": len(financial_inputs),
|
||||
"parameter_count": len(simulation_parameters),
|
||||
"currency": scenario.currency,
|
||||
"primary_resource": scenario.primary_resource.value.replace('_', ' ').title() if scenario.primary_resource else None,
|
||||
}
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"scenarios/detail.html",
|
||||
{
|
||||
"project": project,
|
||||
"scenario": scenario,
|
||||
"scenario_metrics": scenario_metrics,
|
||||
"financial_inputs": financial_inputs,
|
||||
"simulation_parameters": simulation_parameters,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scenarios/{scenario_id}/edit",
|
||||
response_class=HTMLResponse,
|
||||
include_in_schema=False,
|
||||
name="scenarios.edit_scenario_form",
|
||||
)
|
||||
def edit_scenario_form(
|
||||
request: Request,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
) -> HTMLResponse:
|
||||
project = _require_project_repo(uow).get(scenario.project_id)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"scenarios/form.html",
|
||||
{
|
||||
"project": project,
|
||||
"scenario": scenario,
|
||||
"scenario_statuses": _scenario_status_choices(),
|
||||
"resource_types": _resource_type_choices(),
|
||||
"form_action": request.url_for(
|
||||
"scenarios.edit_scenario_submit", scenario_id=scenario.id
|
||||
),
|
||||
"cancel_url": request.url_for(
|
||||
"scenarios.view_scenario", scenario_id=scenario.id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scenarios/{scenario_id}/edit",
|
||||
include_in_schema=False,
|
||||
name="scenarios.edit_scenario_submit",
|
||||
)
|
||||
def edit_scenario_submit(
|
||||
request: Request,
|
||||
scenario: Scenario = Depends(
|
||||
require_scenario_resource(require_manage=True)
|
||||
),
|
||||
name: str = Form(...),
|
||||
description: str | None = Form(None),
|
||||
status_value: str = Form(ScenarioStatus.DRAFT.value),
|
||||
start_date: str | None = Form(None),
|
||||
end_date: str | None = Form(None),
|
||||
discount_rate: str | None = Form(None),
|
||||
currency: str | None = Form(None),
|
||||
primary_resource: str | None = Form(None),
|
||||
uow: UnitOfWork = Depends(get_unit_of_work),
|
||||
):
|
||||
project = _require_project_repo(uow).get(scenario.project_id)
|
||||
|
||||
scenario.name = name.strip()
|
||||
scenario.description = _normalise(description)
|
||||
try:
|
||||
scenario.status = ScenarioStatus(status_value)
|
||||
except ValueError:
|
||||
scenario.status = ScenarioStatus.DRAFT
|
||||
scenario.start_date = _parse_date(start_date)
|
||||
scenario.end_date = _parse_date(end_date)
|
||||
|
||||
scenario.discount_rate = _parse_discount_rate(discount_rate)
|
||||
|
||||
currency_value = _normalise(currency)
|
||||
scenario.currency = currency_value.upper() if currency_value else None
|
||||
|
||||
resource_enum = None
|
||||
if primary_resource:
|
||||
try:
|
||||
resource_enum = ResourceType(primary_resource)
|
||||
except ValueError:
|
||||
resource_enum = None
|
||||
scenario.primary_resource = resource_enum
|
||||
|
||||
uow.flush()
|
||||
|
||||
return RedirectResponse(
|
||||
request.url_for("scenarios.view_scenario", scenario_id=scenario.id),
|
||||
status_code=status.HTTP_303_SEE_OTHER,
|
||||
)
|
||||
67
schemas/auth.py
Normal file
67
schemas/auth.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
|
||||
class FormModel(BaseModel):
|
||||
"""Base Pydantic model for HTML form submissions."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
||||
|
||||
|
||||
class RegistrationForm(FormModel):
|
||||
username: str = Field(min_length=3, max_length=128)
|
||||
email: str = Field(min_length=5, max_length=255)
|
||||
password: str = Field(min_length=8, max_length=256)
|
||||
confirm_password: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
if "@" not in value or value.startswith("@") or value.endswith("@"):
|
||||
raise ValueError("Invalid email address.")
|
||||
local, domain = value.split("@", 1)
|
||||
if not local or "." not in domain:
|
||||
raise ValueError("Invalid email address.")
|
||||
return value.lower()
|
||||
|
||||
@field_validator("confirm_password")
|
||||
@classmethod
|
||||
def passwords_match(cls, value: str, info: ValidationInfo) -> str:
|
||||
password = info.data.get("password")
|
||||
if password != value:
|
||||
raise ValueError("Passwords do not match.")
|
||||
return value
|
||||
|
||||
|
||||
class LoginForm(FormModel):
|
||||
username: str = Field(min_length=1, max_length=255)
|
||||
password: str = Field(min_length=1, max_length=256)
|
||||
|
||||
|
||||
class PasswordResetRequestForm(FormModel):
|
||||
email: str = Field(min_length=5, max_length=255)
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
if "@" not in value or value.startswith("@") or value.endswith("@"):
|
||||
raise ValueError("Invalid email address.")
|
||||
local, domain = value.split("@", 1)
|
||||
if not local or "." not in domain:
|
||||
raise ValueError("Invalid email address.")
|
||||
return value.lower()
|
||||
|
||||
|
||||
class PasswordResetForm(FormModel):
|
||||
token: str = Field(min_length=1)
|
||||
password: str = Field(min_length=8, max_length=256)
|
||||
confirm_password: str
|
||||
|
||||
@field_validator("confirm_password")
|
||||
@classmethod
|
||||
def reset_passwords_match(cls, value: str, info: ValidationInfo) -> str:
|
||||
password = info.data.get("password")
|
||||
if password != value:
|
||||
raise ValueError("Passwords do not match.")
|
||||
return value
|
||||
37
schemas/project.py
Normal file
37
schemas/project.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from models import MiningOperationType
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
name: str
|
||||
location: str | None = None
|
||||
operation_type: MiningOperationType
|
||||
description: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
pass
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
location: str | None = None
|
||||
operation_type: MiningOperationType | None = None
|
||||
description: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ProjectRead(ProjectBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
87
schemas/scenario.py
Normal file
87
schemas/scenario.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||
|
||||
from models import ResourceType, ScenarioStatus
|
||||
|
||||
|
||||
class ScenarioBase(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
status: ScenarioStatus = ScenarioStatus.DRAFT
|
||||
start_date: date | None = None
|
||||
end_date: date | None = None
|
||||
discount_rate: float | None = None
|
||||
currency: str | None = None
|
||||
primary_resource: ResourceType | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("currency")
|
||||
@classmethod
|
||||
def normalise_currency(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
value = value.upper()
|
||||
if len(value) != 3:
|
||||
raise ValueError("Currency code must be a 3-letter ISO value")
|
||||
return value
|
||||
|
||||
|
||||
class ScenarioCreate(ScenarioBase):
|
||||
pass
|
||||
|
||||
|
||||
class ScenarioUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: ScenarioStatus | None = None
|
||||
start_date: date | None = None
|
||||
end_date: date | None = None
|
||||
discount_rate: float | None = None
|
||||
currency: str | None = None
|
||||
primary_resource: ResourceType | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("currency")
|
||||
@classmethod
|
||||
def normalise_currency(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
value = value.upper()
|
||||
if len(value) != 3:
|
||||
raise ValueError("Currency code must be a 3-letter ISO value")
|
||||
return value
|
||||
|
||||
|
||||
class ScenarioRead(ScenarioBase):
|
||||
id: int
|
||||
project_id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ScenarioComparisonRequest(BaseModel):
|
||||
scenario_ids: list[int]
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_minimum_ids(self) -> "ScenarioComparisonRequest":
|
||||
unique_ids: list[int] = list(dict.fromkeys(self.scenario_ids))
|
||||
if len(unique_ids) < 2:
|
||||
raise ValueError("At least two unique scenario identifiers are required for comparison.")
|
||||
self.scenario_ids = unique_ids
|
||||
return self
|
||||
|
||||
|
||||
class ScenarioComparisonResponse(BaseModel):
|
||||
project_id: int
|
||||
scenarios: list[ScenarioRead]
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
20
scripts/00_initial_data.py
Normal file
20
scripts/00_initial_data.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from scripts.initial_data import load_config, seed_initial_data
|
||||
|
||||
|
||||
def main() -> int:
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
try:
|
||||
config = load_config()
|
||||
seed_initial_data(config)
|
||||
except Exception as exc: # pragma: no cover - operational guard
|
||||
logging.exception("Seeding failed: %s", exc)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
1
scripts/__init__.py
Normal file
1
scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utility scripts for CalMiner maintenance tasks."""
|
||||
183
scripts/initial_data.py
Normal file
183
scripts/initial_data.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from models import Role, User
|
||||
from services.repositories import DEFAULT_ROLE_DEFINITIONS, RoleRepository, UserRepository
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeedConfig:
|
||||
admin_email: str
|
||||
admin_username: str
|
||||
admin_password: str
|
||||
admin_roles: tuple[str, ...]
|
||||
force_reset: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoleSeedResult:
|
||||
created: int
|
||||
updated: int
|
||||
total: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminSeedResult:
|
||||
created_user: bool
|
||||
updated_user: bool
|
||||
password_rotated: bool
|
||||
roles_granted: int
|
||||
|
||||
|
||||
def parse_bool(value: str | None) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def normalise_role_list(raw_value: str | None) -> tuple[str, ...]:
|
||||
if not raw_value:
|
||||
return ("admin",)
|
||||
parts = [segment.strip() for segment in raw_value.split(",") if segment.strip()]
|
||||
if "admin" not in parts:
|
||||
parts.insert(0, "admin")
|
||||
seen: set[str] = set()
|
||||
ordered: list[str] = []
|
||||
for role_name in parts:
|
||||
if role_name not in seen:
|
||||
ordered.append(role_name)
|
||||
seen.add(role_name)
|
||||
return tuple(ordered)
|
||||
|
||||
|
||||
def load_config() -> SeedConfig:
|
||||
load_dotenv()
|
||||
admin_email = os.getenv("CALMINER_SEED_ADMIN_EMAIL", "admin@calminer.local")
|
||||
admin_username = os.getenv("CALMINER_SEED_ADMIN_USERNAME", "admin")
|
||||
admin_password = os.getenv("CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!")
|
||||
admin_roles = normalise_role_list(os.getenv("CALMINER_SEED_ADMIN_ROLES"))
|
||||
force_reset = parse_bool(os.getenv("CALMINER_SEED_FORCE"))
|
||||
return SeedConfig(
|
||||
admin_email=admin_email,
|
||||
admin_username=admin_username,
|
||||
admin_password=admin_password,
|
||||
admin_roles=admin_roles,
|
||||
force_reset=force_reset,
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_roles(
|
||||
role_repo: RoleRepository,
|
||||
definitions: Iterable[dict[str, str]] = DEFAULT_ROLE_DEFINITIONS,
|
||||
) -> RoleSeedResult:
|
||||
created = 0
|
||||
updated = 0
|
||||
total = 0
|
||||
for definition in definitions:
|
||||
total += 1
|
||||
existing = role_repo.get_by_name(definition["name"])
|
||||
if existing is None:
|
||||
role_repo.create(Role(**definition))
|
||||
created += 1
|
||||
continue
|
||||
changed = False
|
||||
if existing.display_name != definition["display_name"]:
|
||||
existing.display_name = definition["display_name"]
|
||||
changed = True
|
||||
if existing.description != definition["description"]:
|
||||
existing.description = definition["description"]
|
||||
changed = True
|
||||
if changed:
|
||||
updated += 1
|
||||
role_repo.session.flush()
|
||||
return RoleSeedResult(created=created, updated=updated, total=total)
|
||||
|
||||
|
||||
def ensure_admin_user(
|
||||
user_repo: UserRepository,
|
||||
role_repo: RoleRepository,
|
||||
config: SeedConfig,
|
||||
) -> AdminSeedResult:
|
||||
created_user = False
|
||||
updated_user = False
|
||||
password_rotated = False
|
||||
roles_granted = 0
|
||||
|
||||
user = user_repo.get_by_email(config.admin_email, with_roles=True)
|
||||
if user is None:
|
||||
user = User(
|
||||
email=config.admin_email,
|
||||
username=config.admin_username,
|
||||
password_hash=User.hash_password(config.admin_password),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
user_repo.create(user)
|
||||
created_user = True
|
||||
else:
|
||||
if user.username != config.admin_username:
|
||||
user.username = config.admin_username
|
||||
updated_user = True
|
||||
if not user.is_active:
|
||||
user.is_active = True
|
||||
updated_user = True
|
||||
if not user.is_superuser:
|
||||
user.is_superuser = True
|
||||
updated_user = True
|
||||
if config.force_reset:
|
||||
user.password_hash = User.hash_password(config.admin_password)
|
||||
password_rotated = True
|
||||
updated_user = True
|
||||
user_repo.session.flush()
|
||||
|
||||
for role_name in config.admin_roles:
|
||||
role = role_repo.get_by_name(role_name)
|
||||
if role is None:
|
||||
logging.warning("Role '%s' is not defined and will be skipped", role_name)
|
||||
continue
|
||||
already_assigned = any(assignment.role_id == role.id for assignment in user.role_assignments)
|
||||
if already_assigned:
|
||||
continue
|
||||
user_repo.assign_role(user_id=user.id, role_id=role.id, granted_by=user.id)
|
||||
roles_granted += 1
|
||||
|
||||
return AdminSeedResult(
|
||||
created_user=created_user,
|
||||
updated_user=updated_user,
|
||||
password_rotated=password_rotated,
|
||||
roles_granted=roles_granted,
|
||||
)
|
||||
|
||||
|
||||
def seed_initial_data(
|
||||
config: SeedConfig,
|
||||
*,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork] | None = None,
|
||||
) -> None:
|
||||
logging.info("Starting initial data seeding")
|
||||
factory = unit_of_work_factory or UnitOfWork
|
||||
with factory() as uow:
|
||||
assert uow.roles is not None and uow.users is not None
|
||||
role_result = ensure_default_roles(uow.roles)
|
||||
admin_result = ensure_admin_user(uow.users, uow.roles, config)
|
||||
logging.info(
|
||||
"Roles processed: %s total, %s created, %s updated",
|
||||
role_result.total,
|
||||
role_result.created,
|
||||
role_result.updated,
|
||||
)
|
||||
logging.info(
|
||||
"Admin user: created=%s updated=%s password_rotated=%s roles_granted=%s",
|
||||
admin_result.created_user,
|
||||
admin_result.updated_user,
|
||||
admin_result.password_rotated,
|
||||
admin_result.roles_granted,
|
||||
)
|
||||
logging.info("Initial data seeding completed successfully")
|
||||
1
services/__init__.py
Normal file
1
services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Service layer utilities."""
|
||||
104
services/authorization.py
Normal file
104
services/authorization.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from models import Project, Role, Scenario, User
|
||||
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||
from services.repositories import ProjectRepository, ScenarioRepository
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
READ_ROLES: frozenset[str] = frozenset(
|
||||
{"viewer", "analyst", "project_manager", "admin"}
|
||||
)
|
||||
MANAGE_ROLES: frozenset[str] = frozenset({"project_manager", "admin"})
|
||||
|
||||
|
||||
def _user_role_names(user: User) -> set[str]:
|
||||
roles: Iterable[Role] = getattr(user, "roles", []) or []
|
||||
return {role.name for role in roles}
|
||||
|
||||
|
||||
def _require_project_repo(uow: UnitOfWork) -> ProjectRepository:
|
||||
if not uow.projects:
|
||||
raise RuntimeError("Project repository not initialised")
|
||||
return uow.projects
|
||||
|
||||
|
||||
def _require_scenario_repo(uow: UnitOfWork) -> ScenarioRepository:
|
||||
if not uow.scenarios:
|
||||
raise RuntimeError("Scenario repository not initialised")
|
||||
return uow.scenarios
|
||||
|
||||
|
||||
def _assert_user_can_access(user: User, *, require_manage: bool) -> None:
|
||||
if not user.is_active:
|
||||
raise AuthorizationError("User account is disabled.")
|
||||
if user.is_superuser:
|
||||
return
|
||||
|
||||
allowed = MANAGE_ROLES if require_manage else READ_ROLES
|
||||
if not _user_role_names(user) & allowed:
|
||||
raise AuthorizationError(
|
||||
"Insufficient role permissions for this action.")
|
||||
|
||||
|
||||
def ensure_project_access(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
project_id: int,
|
||||
user: User,
|
||||
require_manage: bool = False,
|
||||
) -> Project:
|
||||
"""Resolve a project and ensure the user holds the required permissions."""
|
||||
|
||||
repo = _require_project_repo(uow)
|
||||
project = repo.get(project_id)
|
||||
_assert_user_can_access(user, require_manage=require_manage)
|
||||
return project
|
||||
|
||||
|
||||
def ensure_scenario_access(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
scenario_id: int,
|
||||
user: User,
|
||||
require_manage: bool = False,
|
||||
with_children: bool = False,
|
||||
) -> Scenario:
|
||||
"""Resolve a scenario and ensure the user holds the required permissions."""
|
||||
|
||||
repo = _require_scenario_repo(uow)
|
||||
scenario = repo.get(scenario_id, with_children=with_children)
|
||||
_assert_user_can_access(user, require_manage=require_manage)
|
||||
return scenario
|
||||
|
||||
|
||||
def ensure_scenario_in_project(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
project_id: int,
|
||||
scenario_id: int,
|
||||
user: User,
|
||||
require_manage: bool = False,
|
||||
with_children: bool = False,
|
||||
) -> Scenario:
|
||||
"""Resolve a scenario ensuring it belongs to the project and the user may access it."""
|
||||
|
||||
project = ensure_project_access(
|
||||
uow,
|
||||
project_id=project_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
)
|
||||
scenario = ensure_scenario_access(
|
||||
uow,
|
||||
scenario_id=scenario_id,
|
||||
user=user,
|
||||
require_manage=require_manage,
|
||||
with_children=with_children,
|
||||
)
|
||||
if scenario.project_id != project.id:
|
||||
raise EntityNotFoundError(
|
||||
f"Scenario {scenario_id} does not belong to project {project_id}."
|
||||
)
|
||||
return scenario
|
||||
129
services/bootstrap.py
Normal file
129
services/bootstrap.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
from config.settings import AdminBootstrapSettings
|
||||
from models import User
|
||||
from services.repositories import ensure_default_roles
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RoleBootstrapResult:
|
||||
created: int
|
||||
ensured: int
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AdminBootstrapResult:
|
||||
created_user: bool
|
||||
updated_user: bool
|
||||
password_rotated: bool
|
||||
roles_granted: int
|
||||
|
||||
|
||||
def bootstrap_admin(
|
||||
*,
|
||||
settings: AdminBootstrapSettings,
|
||||
unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork,
|
||||
) -> tuple[RoleBootstrapResult, AdminBootstrapResult]:
|
||||
"""Ensure default roles and administrator account exist."""
|
||||
|
||||
with unit_of_work_factory() as uow:
|
||||
assert uow.roles is not None and uow.users is not None
|
||||
|
||||
role_result = _bootstrap_roles(uow)
|
||||
admin_result = _bootstrap_admin_user(uow, settings)
|
||||
|
||||
logger.info(
|
||||
"Admin bootstrap result: created_user=%s updated_user=%s password_rotated=%s roles_granted=%s",
|
||||
admin_result.created_user,
|
||||
admin_result.updated_user,
|
||||
admin_result.password_rotated,
|
||||
admin_result.roles_granted,
|
||||
)
|
||||
return role_result, admin_result
|
||||
|
||||
|
||||
def _bootstrap_roles(uow: UnitOfWork) -> RoleBootstrapResult:
|
||||
assert uow.roles is not None
|
||||
before = {role.name for role in uow.roles.list()}
|
||||
ensure_default_roles(uow.roles)
|
||||
after = {role.name for role in uow.roles.list()}
|
||||
created = len(after - before)
|
||||
return RoleBootstrapResult(created=created, ensured=len(after))
|
||||
|
||||
|
||||
def _bootstrap_admin_user(
|
||||
uow: UnitOfWork,
|
||||
settings: AdminBootstrapSettings,
|
||||
) -> AdminBootstrapResult:
|
||||
assert uow.users is not None and uow.roles is not None
|
||||
|
||||
created_user = False
|
||||
updated_user = False
|
||||
password_rotated = False
|
||||
roles_granted = 0
|
||||
|
||||
user = uow.users.get_by_email(settings.email, with_roles=True)
|
||||
if user is None:
|
||||
user = User(
|
||||
email=settings.email,
|
||||
username=settings.username,
|
||||
password_hash=User.hash_password(settings.password),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
uow.users.create(user)
|
||||
created_user = True
|
||||
else:
|
||||
if user.username != settings.username:
|
||||
user.username = settings.username
|
||||
updated_user = True
|
||||
if not user.is_active:
|
||||
user.is_active = True
|
||||
updated_user = True
|
||||
if not user.is_superuser:
|
||||
user.is_superuser = True
|
||||
updated_user = True
|
||||
if settings.force_reset:
|
||||
user.password_hash = User.hash_password(settings.password)
|
||||
password_rotated = True
|
||||
updated_user = True
|
||||
uow.users.session.flush()
|
||||
|
||||
user = uow.users.get(user.id, with_roles=True)
|
||||
assert user is not None
|
||||
|
||||
existing_roles = {role.name for role in user.roles}
|
||||
for role_name in settings.roles:
|
||||
role = uow.roles.get_by_name(role_name)
|
||||
if role is None:
|
||||
logger.warning(
|
||||
"Bootstrap admin role '%s' is not defined; skipping assignment",
|
||||
role_name,
|
||||
)
|
||||
continue
|
||||
if role.name in existing_roles:
|
||||
continue
|
||||
uow.users.assign_role(
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
granted_by=user.id,
|
||||
)
|
||||
roles_granted += 1
|
||||
existing_roles.add(role.name)
|
||||
|
||||
uow.users.session.flush()
|
||||
|
||||
return AdminBootstrapResult(
|
||||
created_user=created_user,
|
||||
updated_user=updated_user,
|
||||
password_rotated=password_rotated,
|
||||
roles_granted=roles_granted,
|
||||
)
|
||||
28
services/exceptions.py
Normal file
28
services/exceptions.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Domain-level exceptions for service and repository layers."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
class EntityNotFoundError(Exception):
|
||||
"""Raised when a requested entity cannot be located."""
|
||||
|
||||
|
||||
class EntityConflictError(Exception):
|
||||
"""Raised when attempting to create or update an entity that violates uniqueness."""
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Raised when a user lacks permission to perform an action."""
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ScenarioValidationError(Exception):
|
||||
"""Raised when scenarios fail comparison validation rules."""
|
||||
|
||||
code: str
|
||||
message: str
|
||||
scenario_ids: Sequence[int] | None = None
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover - mirrors message for logging
|
||||
return self.message
|
||||
447
services/repositories.py
Normal file
447
services/repositories.py
Normal file
@@ -0,0 +1,447 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
|
||||
from models import (
|
||||
FinancialInput,
|
||||
Project,
|
||||
Role,
|
||||
Scenario,
|
||||
ScenarioStatus,
|
||||
SimulationParameter,
|
||||
User,
|
||||
UserRole,
|
||||
)
|
||||
from services.exceptions import EntityConflictError, EntityNotFoundError
|
||||
|
||||
|
||||
class ProjectRepository:
|
||||
"""Persistence operations for Project entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self, *, with_children: bool = False) -> Sequence[Project]:
|
||||
stmt = select(Project).order_by(Project.created_at)
|
||||
if with_children:
|
||||
stmt = stmt.options(selectinload(Project.scenarios))
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def count(self) -> int:
|
||||
stmt = select(func.count(Project.id))
|
||||
return self.session.execute(stmt).scalar_one()
|
||||
|
||||
def recent(self, limit: int = 5) -> Sequence[Project]:
|
||||
stmt = (
|
||||
select(Project)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, project_id: int, *, with_children: bool = False) -> Project:
|
||||
stmt = select(Project).where(Project.id == project_id)
|
||||
if with_children:
|
||||
stmt = stmt.options(joinedload(Project.scenarios))
|
||||
result = self.session.execute(stmt)
|
||||
if with_children:
|
||||
result = result.unique()
|
||||
project = result.scalar_one_or_none()
|
||||
if project is None:
|
||||
raise EntityNotFoundError(f"Project {project_id} not found")
|
||||
return project
|
||||
|
||||
def exists(self, project_id: int) -> bool:
|
||||
stmt = select(Project.id).where(Project.id == project_id).limit(1)
|
||||
return self.session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def create(self, project: Project) -> Project:
|
||||
self.session.add(project)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - reliance on DB constraints
|
||||
raise EntityConflictError(
|
||||
"Project violates uniqueness constraints") from exc
|
||||
return project
|
||||
|
||||
def delete(self, project_id: int) -> None:
|
||||
project = self.get(project_id)
|
||||
self.session.delete(project)
|
||||
|
||||
|
||||
class ScenarioRepository:
|
||||
"""Persistence operations for Scenario entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_project(self, project_id: int) -> Sequence[Scenario]:
|
||||
stmt = (
|
||||
select(Scenario)
|
||||
.where(Scenario.project_id == project_id)
|
||||
.order_by(Scenario.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def count(self) -> int:
|
||||
stmt = select(func.count(Scenario.id))
|
||||
return self.session.execute(stmt).scalar_one()
|
||||
|
||||
def count_by_status(self, status: ScenarioStatus) -> int:
|
||||
stmt = select(func.count(Scenario.id)).where(Scenario.status == status)
|
||||
return self.session.execute(stmt).scalar_one()
|
||||
|
||||
def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]:
|
||||
stmt = select(Scenario).order_by(
|
||||
Scenario.updated_at.desc()).limit(limit)
|
||||
if with_project:
|
||||
stmt = stmt.options(joinedload(Scenario.project))
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def list_by_status(
|
||||
self,
|
||||
status: ScenarioStatus,
|
||||
*,
|
||||
limit: int | None = None,
|
||||
with_project: bool = False,
|
||||
) -> Sequence[Scenario]:
|
||||
stmt = (
|
||||
select(Scenario)
|
||||
.where(Scenario.status == status)
|
||||
.order_by(Scenario.updated_at.desc())
|
||||
)
|
||||
if with_project:
|
||||
stmt = stmt.options(joinedload(Scenario.project))
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario:
|
||||
stmt = select(Scenario).where(Scenario.id == scenario_id)
|
||||
if with_children:
|
||||
stmt = stmt.options(
|
||||
joinedload(Scenario.financial_inputs),
|
||||
joinedload(Scenario.simulation_parameters),
|
||||
)
|
||||
result = self.session.execute(stmt)
|
||||
if with_children:
|
||||
result = result.unique()
|
||||
scenario = result.scalar_one_or_none()
|
||||
if scenario is None:
|
||||
raise EntityNotFoundError(f"Scenario {scenario_id} not found")
|
||||
return scenario
|
||||
|
||||
def exists(self, scenario_id: int) -> bool:
|
||||
stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1)
|
||||
return self.session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def create(self, scenario: Scenario) -> Scenario:
|
||||
self.session.add(scenario)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError("Scenario violates constraints") from exc
|
||||
return scenario
|
||||
|
||||
def delete(self, scenario_id: int) -> None:
|
||||
scenario = self.get(scenario_id)
|
||||
self.session.delete(scenario)
|
||||
|
||||
|
||||
class FinancialInputRepository:
|
||||
"""Persistence operations for FinancialInput entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]:
|
||||
stmt = (
|
||||
select(FinancialInput)
|
||||
.where(FinancialInput.scenario_id == scenario_id)
|
||||
.order_by(FinancialInput.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]:
|
||||
entities = list(inputs)
|
||||
self.session.add_all(entities)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError(
|
||||
"Financial input violates constraints") from exc
|
||||
return entities
|
||||
|
||||
def delete(self, input_id: int) -> None:
|
||||
stmt = select(FinancialInput).where(FinancialInput.id == input_id)
|
||||
entity = self.session.execute(stmt).scalar_one_or_none()
|
||||
if entity is None:
|
||||
raise EntityNotFoundError(f"Financial input {input_id} not found")
|
||||
self.session.delete(entity)
|
||||
|
||||
def latest_created_at(self) -> datetime | None:
|
||||
stmt = (
|
||||
select(FinancialInput.created_at)
|
||||
.order_by(FinancialInput.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return self.session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
class SimulationParameterRepository:
|
||||
"""Persistence operations for SimulationParameter entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]:
|
||||
stmt = (
|
||||
select(SimulationParameter)
|
||||
.where(SimulationParameter.scenario_id == scenario_id)
|
||||
.order_by(SimulationParameter.created_at)
|
||||
)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def bulk_upsert(
|
||||
self, parameters: Iterable[SimulationParameter]
|
||||
) -> Sequence[SimulationParameter]:
|
||||
entities = list(parameters)
|
||||
self.session.add_all(entities)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover
|
||||
raise EntityConflictError(
|
||||
"Simulation parameter violates constraints") from exc
|
||||
return entities
|
||||
|
||||
def delete(self, parameter_id: int) -> None:
|
||||
stmt = select(SimulationParameter).where(
|
||||
SimulationParameter.id == parameter_id)
|
||||
entity = self.session.execute(stmt).scalar_one_or_none()
|
||||
if entity is None:
|
||||
raise EntityNotFoundError(
|
||||
f"Simulation parameter {parameter_id} not found")
|
||||
self.session.delete(entity)
|
||||
|
||||
|
||||
class RoleRepository:
|
||||
"""Persistence operations for Role entities."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self) -> Sequence[Role]:
|
||||
stmt = select(Role).order_by(Role.name)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def get(self, role_id: int) -> Role:
|
||||
stmt = select(Role).where(Role.id == role_id)
|
||||
role = self.session.execute(stmt).scalar_one_or_none()
|
||||
if role is None:
|
||||
raise EntityNotFoundError(f"Role {role_id} not found")
|
||||
return role
|
||||
|
||||
def get_by_name(self, name: str) -> Role | None:
|
||||
stmt = select(Role).where(Role.name == name)
|
||||
return self.session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
def create(self, role: Role) -> Role:
|
||||
self.session.add(role)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||
raise EntityConflictError(
|
||||
"Role violates uniqueness constraints") from exc
|
||||
return role
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Persistence operations for User entities and their role assignments."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def list(self, *, with_roles: bool = False) -> Sequence[User]:
|
||||
stmt = select(User).order_by(User.created_at)
|
||||
if with_roles:
|
||||
stmt = stmt.options(selectinload(User.roles))
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
def _apply_role_option(self, stmt, with_roles: bool):
|
||||
if with_roles:
|
||||
stmt = stmt.options(
|
||||
joinedload(User.role_assignments).joinedload(UserRole.role),
|
||||
selectinload(User.roles),
|
||||
)
|
||||
return stmt
|
||||
|
||||
def get(self, user_id: int, *, with_roles: bool = False) -> User:
|
||||
stmt = select(User).where(User.id == user_id).execution_options(
|
||||
populate_existing=True)
|
||||
stmt = self._apply_role_option(stmt, with_roles)
|
||||
result = self.session.execute(stmt)
|
||||
if with_roles:
|
||||
result = result.unique()
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise EntityNotFoundError(f"User {user_id} not found")
|
||||
return user
|
||||
|
||||
def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None:
|
||||
stmt = select(User).where(User.email == email).execution_options(
|
||||
populate_existing=True)
|
||||
stmt = self._apply_role_option(stmt, with_roles)
|
||||
result = self.session.execute(stmt)
|
||||
if with_roles:
|
||||
result = result.unique()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None:
|
||||
stmt = select(User).where(User.username ==
|
||||
username).execution_options(populate_existing=True)
|
||||
stmt = self._apply_role_option(stmt, with_roles)
|
||||
result = self.session.execute(stmt)
|
||||
if with_roles:
|
||||
result = result.unique()
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def create(self, user: User) -> User:
|
||||
self.session.add(user)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||
raise EntityConflictError(
|
||||
"User violates uniqueness constraints") from exc
|
||||
return user
|
||||
|
||||
def assign_role(
|
||||
self,
|
||||
*,
|
||||
user_id: int,
|
||||
role_id: int,
|
||||
granted_by: int | None = None,
|
||||
) -> UserRole:
|
||||
stmt = select(UserRole).where(
|
||||
UserRole.user_id == user_id,
|
||||
UserRole.role_id == role_id,
|
||||
)
|
||||
assignment = self.session.execute(stmt).scalar_one_or_none()
|
||||
if assignment:
|
||||
return assignment
|
||||
|
||||
assignment = UserRole(
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
granted_by=granted_by,
|
||||
)
|
||||
self.session.add(assignment)
|
||||
try:
|
||||
self.session.flush()
|
||||
except IntegrityError as exc: # pragma: no cover - DB constraint enforcement
|
||||
raise EntityConflictError(
|
||||
"Assignment violates constraints") from exc
|
||||
return assignment
|
||||
|
||||
def revoke_role(self, *, user_id: int, role_id: int) -> None:
|
||||
stmt = select(UserRole).where(
|
||||
UserRole.user_id == user_id,
|
||||
UserRole.role_id == role_id,
|
||||
)
|
||||
assignment = self.session.execute(stmt).scalar_one_or_none()
|
||||
if assignment is None:
|
||||
raise EntityNotFoundError(
|
||||
f"Role {role_id} not assigned to user {user_id}")
|
||||
self.session.delete(assignment)
|
||||
self.session.flush()
|
||||
|
||||
|
||||
DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = (
|
||||
{
|
||||
"name": "admin",
|
||||
"display_name": "Administrator",
|
||||
"description": "Full platform access with user management rights.",
|
||||
},
|
||||
{
|
||||
"name": "project_manager",
|
||||
"display_name": "Project Manager",
|
||||
"description": "Manage projects, scenarios, and associated data.",
|
||||
},
|
||||
{
|
||||
"name": "analyst",
|
||||
"display_name": "Analyst",
|
||||
"description": "Review dashboards and scenario outputs.",
|
||||
},
|
||||
{
|
||||
"name": "viewer",
|
||||
"display_name": "Viewer",
|
||||
"description": "Read-only access to assigned projects and reports.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_roles(role_repo: RoleRepository) -> list[Role]:
|
||||
"""Ensure standard roles exist, creating missing ones.
|
||||
|
||||
Returns all current role records in creation order.
|
||||
"""
|
||||
|
||||
roles: list[Role] = []
|
||||
for definition in DEFAULT_ROLE_DEFINITIONS:
|
||||
existing = role_repo.get_by_name(definition["name"])
|
||||
if existing:
|
||||
roles.append(existing)
|
||||
continue
|
||||
role = Role(**definition)
|
||||
roles.append(role_repo.create(role))
|
||||
return roles
|
||||
|
||||
|
||||
def ensure_admin_user(
|
||||
user_repo: UserRepository,
|
||||
role_repo: RoleRepository,
|
||||
*,
|
||||
email: str,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> User:
|
||||
"""Ensure an administrator user exists and holds the admin role."""
|
||||
|
||||
user = user_repo.get_by_email(email, with_roles=True)
|
||||
if user is None:
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash=User.hash_password(password),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
user_repo.create(user)
|
||||
else:
|
||||
if not user.is_active:
|
||||
user.is_active = True
|
||||
if not user.is_superuser:
|
||||
user.is_superuser = True
|
||||
user_repo.session.flush()
|
||||
|
||||
admin_role = role_repo.get_by_name("admin")
|
||||
if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called
|
||||
admin_role = role_repo.create(
|
||||
Role(
|
||||
name="admin",
|
||||
display_name="Administrator",
|
||||
description="Full platform access with user management rights.",
|
||||
)
|
||||
)
|
||||
|
||||
user_repo.assign_role(
|
||||
user_id=user.id,
|
||||
role_id=admin_role.id,
|
||||
granted_by=user.id,
|
||||
)
|
||||
return user
|
||||
106
services/scenario_validation.py
Normal file
106
services/scenario_validation.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
from models import Scenario, ScenarioStatus
|
||||
from services.exceptions import ScenarioValidationError
|
||||
|
||||
ALLOWED_STATUSES: frozenset[ScenarioStatus] = frozenset(
|
||||
{ScenarioStatus.DRAFT, ScenarioStatus.ACTIVE}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ValidationContext:
|
||||
scenarios: Sequence[Scenario]
|
||||
|
||||
@property
|
||||
def scenario_ids(self) -> list[int]:
|
||||
return [scenario.id for scenario in self.scenarios if scenario.id is not None]
|
||||
|
||||
|
||||
class ScenarioComparisonValidator:
|
||||
"""Validates scenarios prior to comparison workflows."""
|
||||
|
||||
def validate(self, scenarios: Sequence[Scenario] | Iterable[Scenario]) -> None:
|
||||
scenario_list = list(scenarios)
|
||||
if len(scenario_list) < 2:
|
||||
# Nothing to validate when fewer than two scenarios are provided.
|
||||
return
|
||||
|
||||
context = _ValidationContext(scenario_list)
|
||||
|
||||
self._ensure_same_project(context)
|
||||
self._ensure_allowed_status(context)
|
||||
self._ensure_shared_currency(context)
|
||||
self._ensure_timeline_overlap(context)
|
||||
self._ensure_shared_primary_resource(context)
|
||||
|
||||
def _ensure_same_project(self, context: _ValidationContext) -> None:
|
||||
project_ids = {scenario.project_id for scenario in context.scenarios}
|
||||
if len(project_ids) > 1:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_PROJECT_MISMATCH",
|
||||
message="Selected scenarios do not belong to the same project.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
|
||||
def _ensure_allowed_status(self, context: _ValidationContext) -> None:
|
||||
disallowed = [
|
||||
scenario
|
||||
for scenario in context.scenarios
|
||||
if scenario.status not in ALLOWED_STATUSES
|
||||
]
|
||||
if disallowed:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_STATUS_INVALID",
|
||||
message="Archived scenarios cannot be compared.",
|
||||
scenario_ids=[
|
||||
scenario.id for scenario in disallowed if scenario.id is not None],
|
||||
)
|
||||
|
||||
def _ensure_shared_currency(self, context: _ValidationContext) -> None:
|
||||
currencies = {
|
||||
scenario.currency
|
||||
for scenario in context.scenarios
|
||||
if scenario.currency is not None
|
||||
}
|
||||
if len(currencies) > 1:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_CURRENCY_MISMATCH",
|
||||
message="Scenarios use different currencies and cannot be compared.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
|
||||
def _ensure_timeline_overlap(self, context: _ValidationContext) -> None:
|
||||
ranges = [
|
||||
(scenario.start_date, scenario.end_date)
|
||||
for scenario in context.scenarios
|
||||
if scenario.start_date and scenario.end_date
|
||||
]
|
||||
if len(ranges) < 2:
|
||||
return
|
||||
|
||||
latest_start: date = max(start for start, _ in ranges)
|
||||
earliest_end: date = min(end for _, end in ranges)
|
||||
if latest_start > earliest_end:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_TIMELINE_DISJOINT",
|
||||
message="Scenario timelines do not overlap; adjust the comparison window.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
|
||||
def _ensure_shared_primary_resource(self, context: _ValidationContext) -> None:
|
||||
resources = {
|
||||
scenario.primary_resource
|
||||
for scenario in context.scenarios
|
||||
if scenario.primary_resource is not None
|
||||
}
|
||||
if len(resources) > 1:
|
||||
raise ScenarioValidationError(
|
||||
code="SCENARIO_RESOURCE_MISMATCH",
|
||||
message="Scenarios target different primary resources and cannot be compared.",
|
||||
scenario_ids=context.scenario_ids,
|
||||
)
|
||||
213
services/security.py
Normal file
213
services/security.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Iterable, Literal, Type
|
||||
|
||||
from jose import ExpiredSignatureError, JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
try: # pragma: no cover - compatibility shim for passlib/argon2 warning
|
||||
import importlib.metadata as importlib_metadata
|
||||
import argon2 # type: ignore
|
||||
|
||||
setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi"))
|
||||
except Exception: # pragma: no cover - executed only when metadata lookup fails
|
||||
pass
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
password_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Derive a secure hash for a plain-text password."""
|
||||
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(candidate: str, hashed: str) -> bool:
|
||||
"""Verify that a candidate password matches a stored hash."""
|
||||
|
||||
try:
|
||||
return password_context.verify(candidate, hashed)
|
||||
except ValueError:
|
||||
# Raised when the stored hash is malformed or uses an unknown scheme.
|
||||
return False
|
||||
|
||||
|
||||
class TokenError(Exception):
|
||||
"""Base class for token encoding/decoding issues."""
|
||||
|
||||
|
||||
class TokenDecodeError(TokenError):
|
||||
"""Raised when a token cannot be decoded or validated."""
|
||||
|
||||
|
||||
class TokenExpiredError(TokenError):
|
||||
"""Raised when a token has expired."""
|
||||
|
||||
|
||||
class TokenTypeMismatchError(TokenError):
|
||||
"""Raised when a token type does not match the expected flavour."""
|
||||
|
||||
|
||||
TokenKind = Literal["access", "refresh"]
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""Shared fields for CalMiner JWT payloads."""
|
||||
|
||||
sub: str
|
||||
exp: int
|
||||
type: TokenKind
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def expires_at(self) -> datetime:
|
||||
return datetime.fromtimestamp(self.exp, tz=timezone.utc)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class JWTSettings:
|
||||
"""Runtime configuration for JWT encoding and validation."""
|
||||
|
||||
secret_key: str
|
||||
algorithm: str = "HS256"
|
||||
access_token_ttl: timedelta = field(
|
||||
default_factory=lambda: timedelta(minutes=15))
|
||||
refresh_token_ttl: timedelta = field(
|
||||
default_factory=lambda: timedelta(days=7))
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: str,
|
||||
settings: JWTSettings,
|
||||
*,
|
||||
scopes: Iterable[str] | None = None,
|
||||
expires_delta: timedelta | None = None,
|
||||
extra_claims: Dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Issue a signed access token for the provided subject."""
|
||||
|
||||
lifetime = expires_delta or settings.access_token_ttl
|
||||
return _create_token(
|
||||
subject=subject,
|
||||
token_type="access",
|
||||
settings=settings,
|
||||
lifetime=lifetime,
|
||||
scopes=scopes,
|
||||
extra_claims=extra_claims,
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: str,
|
||||
settings: JWTSettings,
|
||||
*,
|
||||
scopes: Iterable[str] | None = None,
|
||||
expires_delta: timedelta | None = None,
|
||||
extra_claims: Dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Issue a signed refresh token for the provided subject."""
|
||||
|
||||
lifetime = expires_delta or settings.refresh_token_ttl
|
||||
return _create_token(
|
||||
subject=subject,
|
||||
token_type="refresh",
|
||||
settings=settings,
|
||||
lifetime=lifetime,
|
||||
scopes=scopes,
|
||||
extra_claims=extra_claims,
|
||||
)
|
||||
|
||||
|
||||
def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload:
|
||||
"""Validate and decode an access token."""
|
||||
|
||||
return _decode_token(token, settings, expected_type="access")
|
||||
|
||||
|
||||
def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload:
|
||||
"""Validate and decode a refresh token."""
|
||||
|
||||
return _decode_token(token, settings, expected_type="refresh")
|
||||
|
||||
|
||||
def _create_token(
|
||||
*,
|
||||
subject: str,
|
||||
token_type: TokenKind,
|
||||
settings: JWTSettings,
|
||||
lifetime: timedelta,
|
||||
scopes: Iterable[str] | None,
|
||||
extra_claims: Dict[str, Any] | None,
|
||||
) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
expire = now + lifetime
|
||||
payload: Dict[str, Any] = {
|
||||
"sub": subject,
|
||||
"type": token_type,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(expire.timestamp()),
|
||||
}
|
||||
if scopes:
|
||||
payload["scopes"] = list(scopes)
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
|
||||
def _decode_token(
|
||||
token: str,
|
||||
settings: JWTSettings,
|
||||
expected_type: TokenKind,
|
||||
) -> TokenPayload:
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
settings.secret_key,
|
||||
algorithms=[settings.algorithm],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path
|
||||
raise TokenExpiredError("Token has expired") from exc
|
||||
except JWTError as exc: # pragma: no cover - jose error bubble
|
||||
raise TokenDecodeError("Unable to decode token") from exc
|
||||
|
||||
try:
|
||||
payload = _model_validate(TokenPayload, decoded)
|
||||
except ValidationError as exc:
|
||||
raise TokenDecodeError("Token payload validation failed") from exc
|
||||
|
||||
if payload.type != expected_type:
|
||||
raise TokenTypeMismatchError(
|
||||
f"Expected a {expected_type} token but received '{payload.type}'."
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload:
|
||||
if hasattr(model, "model_validate"):
|
||||
return model.model_validate(data) # type: ignore[attr-defined]
|
||||
return model.parse_obj(data) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"JWTSettings",
|
||||
"TokenDecodeError",
|
||||
"TokenError",
|
||||
"TokenExpiredError",
|
||||
"TokenKind",
|
||||
"TokenPayload",
|
||||
"TokenTypeMismatchError",
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"decode_access_token",
|
||||
"decode_refresh_token",
|
||||
"hash_password",
|
||||
"password_context",
|
||||
"verify_password",
|
||||
]
|
||||
192
services/session.py
Normal file
192
services/session.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from config.settings import SessionSettings
|
||||
from services.security import JWTSettings
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static typing
|
||||
from models import User
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionStrategy:
|
||||
"""Describe how authentication tokens are transported with requests."""
|
||||
|
||||
access_cookie_name: str
|
||||
refresh_cookie_name: str
|
||||
cookie_secure: bool
|
||||
cookie_domain: Optional[str]
|
||||
cookie_path: str
|
||||
header_name: str
|
||||
header_prefix: str
|
||||
allow_header_fallback: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls, settings: SessionSettings) -> "SessionStrategy":
|
||||
return cls(
|
||||
access_cookie_name=settings.access_cookie_name,
|
||||
refresh_cookie_name=settings.refresh_cookie_name,
|
||||
cookie_secure=settings.cookie_secure,
|
||||
cookie_domain=settings.cookie_domain,
|
||||
cookie_path=settings.cookie_path,
|
||||
header_name=settings.header_name,
|
||||
header_prefix=settings.header_prefix,
|
||||
allow_header_fallback=settings.allow_header_fallback,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionTokens:
|
||||
"""Raw access and refresh tokens extracted from the transport layer."""
|
||||
|
||||
access_token: Optional[str]
|
||||
refresh_token: Optional[str]
|
||||
access_token_source: Literal["cookie", "header", "none"] = "none"
|
||||
|
||||
@property
|
||||
def has_access(self) -> bool:
|
||||
return bool(self.access_token)
|
||||
|
||||
@property
|
||||
def has_refresh(self) -> bool:
|
||||
return bool(self.refresh_token)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self.has_access and not self.has_refresh
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AuthSession:
|
||||
"""Holds authenticated user context resolved from session tokens."""
|
||||
|
||||
tokens: SessionTokens
|
||||
user: Optional["User"] = None
|
||||
scopes: tuple[str, ...] = ()
|
||||
issued_access_token: Optional[str] = None
|
||||
issued_refresh_token: Optional[str] = None
|
||||
clear_cookies: bool = False
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return self.user is not None
|
||||
|
||||
@classmethod
|
||||
def anonymous(cls) -> "AuthSession":
|
||||
return cls(tokens=SessionTokens(access_token=None, refresh_token=None))
|
||||
|
||||
def issue_tokens(
|
||||
self,
|
||||
*,
|
||||
access_token: str,
|
||||
refresh_token: Optional[str] = None,
|
||||
access_source: Literal["cookie", "header", "none"] = "cookie",
|
||||
) -> None:
|
||||
self.issued_access_token = access_token
|
||||
if refresh_token is not None:
|
||||
self.issued_refresh_token = refresh_token
|
||||
self.tokens = SessionTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token if refresh_token is not None else self.tokens.refresh_token,
|
||||
access_token_source=access_source,
|
||||
)
|
||||
|
||||
def mark_cleared(self) -> None:
|
||||
self.clear_cookies = True
|
||||
self.tokens = SessionTokens(access_token=None, refresh_token=None)
|
||||
self.user = None
|
||||
self.scopes = ()
|
||||
|
||||
|
||||
def extract_session_tokens(request: Request, strategy: SessionStrategy) -> SessionTokens:
|
||||
"""Pull tokens from cookies or headers according to configured strategy."""
|
||||
|
||||
access_token: Optional[str] = None
|
||||
refresh_token: Optional[str] = None
|
||||
access_source: Literal["cookie", "header", "none"] = "none"
|
||||
|
||||
if strategy.access_cookie_name in request.cookies:
|
||||
access_token = request.cookies.get(strategy.access_cookie_name) or None
|
||||
if access_token:
|
||||
access_source = "cookie"
|
||||
|
||||
if strategy.refresh_cookie_name in request.cookies:
|
||||
refresh_token = request.cookies.get(
|
||||
strategy.refresh_cookie_name) or None
|
||||
|
||||
if not access_token and strategy.allow_header_fallback:
|
||||
header_value = request.headers.get(strategy.header_name)
|
||||
if header_value:
|
||||
candidate = header_value.strip()
|
||||
prefix = f"{strategy.header_prefix} " if strategy.header_prefix else ""
|
||||
if prefix and candidate.lower().startswith(prefix.lower()):
|
||||
candidate = candidate[len(prefix):].strip()
|
||||
if candidate:
|
||||
access_token = candidate
|
||||
access_source = "header"
|
||||
|
||||
return SessionTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
access_token_source=access_source,
|
||||
)
|
||||
|
||||
|
||||
def build_session_strategy(settings: SessionSettings) -> SessionStrategy:
|
||||
"""Create a session strategy object from settings configuration."""
|
||||
|
||||
return SessionStrategy.from_settings(settings)
|
||||
|
||||
|
||||
def set_session_cookies(
|
||||
response: Response,
|
||||
*,
|
||||
access_token: str,
|
||||
refresh_token: Optional[str],
|
||||
strategy: SessionStrategy,
|
||||
jwt_settings: JWTSettings,
|
||||
) -> None:
|
||||
"""Persist session cookies on an outgoing response."""
|
||||
|
||||
access_ttl = int(jwt_settings.access_token_ttl.total_seconds())
|
||||
refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds())
|
||||
response.set_cookie(
|
||||
strategy.access_cookie_name,
|
||||
access_token,
|
||||
httponly=True,
|
||||
secure=strategy.cookie_secure,
|
||||
samesite="lax",
|
||||
max_age=max(access_ttl, 0) or None,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
if refresh_token is not None:
|
||||
response.set_cookie(
|
||||
strategy.refresh_cookie_name,
|
||||
refresh_token,
|
||||
httponly=True,
|
||||
secure=strategy.cookie_secure,
|
||||
samesite="lax",
|
||||
max_age=max(refresh_ttl, 0) or None,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
|
||||
|
||||
def clear_session_cookies(response: Response, strategy: SessionStrategy) -> None:
|
||||
"""Remove session cookies from the client."""
|
||||
|
||||
response.delete_cookie(
|
||||
strategy.access_cookie_name,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
response.delete_cookie(
|
||||
strategy.refresh_cookie_name,
|
||||
domain=strategy.cookie_domain,
|
||||
path=strategy.cookie_path,
|
||||
)
|
||||
118
services/unit_of_work.py
Normal file
118
services/unit_of_work.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config.database import SessionLocal
|
||||
from models import Role, Scenario
|
||||
from services.repositories import (
|
||||
FinancialInputRepository,
|
||||
ProjectRepository,
|
||||
RoleRepository,
|
||||
ScenarioRepository,
|
||||
SimulationParameterRepository,
|
||||
UserRepository,
|
||||
ensure_admin_user as ensure_admin_user_record,
|
||||
ensure_default_roles,
|
||||
)
|
||||
from services.scenario_validation import ScenarioComparisonValidator
|
||||
|
||||
|
||||
class UnitOfWork(AbstractContextManager["UnitOfWork"]):
|
||||
"""Simple unit-of-work wrapper around SQLAlchemy sessions."""
|
||||
|
||||
def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None:
|
||||
self._session_factory = session_factory
|
||||
self.session: Session | None = None
|
||||
self._scenario_validator: ScenarioComparisonValidator | None = None
|
||||
self.projects: ProjectRepository | None = None
|
||||
self.scenarios: ScenarioRepository | None = None
|
||||
self.financial_inputs: FinancialInputRepository | None = None
|
||||
self.simulation_parameters: SimulationParameterRepository | None = None
|
||||
self.users: UserRepository | None = None
|
||||
self.roles: RoleRepository | None = None
|
||||
|
||||
def __enter__(self) -> "UnitOfWork":
|
||||
self.session = self._session_factory()
|
||||
self.projects = ProjectRepository(self.session)
|
||||
self.scenarios = ScenarioRepository(self.session)
|
||||
self.financial_inputs = FinancialInputRepository(self.session)
|
||||
self.simulation_parameters = SimulationParameterRepository(
|
||||
self.session)
|
||||
self.users = UserRepository(self.session)
|
||||
self.roles = RoleRepository(self.session)
|
||||
self._scenario_validator = ScenarioComparisonValidator()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
assert self.session is not None
|
||||
if exc_type is None:
|
||||
self.session.commit()
|
||||
else:
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
self._scenario_validator = None
|
||||
self.projects = None
|
||||
self.scenarios = None
|
||||
self.financial_inputs = None
|
||||
self.simulation_parameters = None
|
||||
self.users = None
|
||||
self.roles = None
|
||||
|
||||
def flush(self) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self.session.flush()
|
||||
|
||||
def commit(self) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self.session.commit()
|
||||
|
||||
def rollback(self) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self.session.rollback()
|
||||
|
||||
def validate_scenarios_for_comparison(
|
||||
self, scenario_ids: Sequence[int]
|
||||
) -> list[Scenario]:
|
||||
if not self.session or not self._scenario_validator or not self.scenarios:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
|
||||
scenarios = [self.scenarios.get(scenario_id)
|
||||
for scenario_id in scenario_ids]
|
||||
self._scenario_validator.validate(scenarios)
|
||||
return scenarios
|
||||
|
||||
def validate_scenario_models_for_comparison(
|
||||
self, scenarios: Sequence[Scenario]
|
||||
) -> None:
|
||||
if not self._scenario_validator:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
self._scenario_validator.validate(scenarios)
|
||||
|
||||
def ensure_default_roles(self) -> list[Role]:
|
||||
if not self.roles:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
return ensure_default_roles(self.roles)
|
||||
|
||||
def ensure_admin_user(
|
||||
self,
|
||||
*,
|
||||
email: str,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> None:
|
||||
if not self.users or not self.roles:
|
||||
raise RuntimeError("UnitOfWork session is not initialised")
|
||||
ensure_default_roles(self.roles)
|
||||
ensure_admin_user_record(
|
||||
self.users,
|
||||
self.roles,
|
||||
email=email,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
150
static/css/dashboard.css
Normal file
150
static/css/dashboard.css
Normal file
@@ -0,0 +1,150 @@
|
||||
:root {
|
||||
--dashboard-gap: 1.5rem;
|
||||
}
|
||||
|
||||
.dashboard-header {
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.dashboard-metrics {
|
||||
display: grid;
|
||||
gap: var(--dashboard-gap);
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.metric-card {
|
||||
background: var(--card);
|
||||
border-radius: var(--radius);
|
||||
padding: 1.5rem;
|
||||
box-shadow: var(--shadow);
|
||||
border: 1px solid var(--color-border);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.metric-card h2 {
|
||||
margin: 0;
|
||||
font-size: 1rem;
|
||||
color: var(--muted);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.08em;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 2rem;
|
||||
font-weight: 700;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.metric-caption {
|
||||
color: var(--color-text-subtle);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.dashboard-grid {
|
||||
display: grid;
|
||||
gap: var(--dashboard-gap);
|
||||
grid-template-columns: 2fr 1fr;
|
||||
align-items: start;
|
||||
}
|
||||
|
||||
.grid-main {
|
||||
display: grid;
|
||||
gap: var(--dashboard-gap);
|
||||
}
|
||||
|
||||
.grid-sidebar {
|
||||
display: grid;
|
||||
gap: var(--dashboard-gap);
|
||||
}
|
||||
|
||||
.table-link {
|
||||
color: var(--brand-2);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.table-link:hover,
|
||||
.table-link:focus {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.timeline {
|
||||
list-style: none;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.timeline-label {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-subtle);
|
||||
display: block;
|
||||
margin-bottom: 0.35rem;
|
||||
}
|
||||
|
||||
.alerts-list,
|
||||
.links-list {
|
||||
list-style: none;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.alerts-list li {
|
||||
padding: 0.75rem;
|
||||
border-radius: var(--radius-sm);
|
||||
background: rgba(209, 75, 75, 0.16);
|
||||
border: 1px solid rgba(209, 75, 75, 0.3);
|
||||
}
|
||||
|
||||
.links-list a {
|
||||
color: var(--brand-3);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.links-list a:hover,
|
||||
.links-list a:focus {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
.dashboard-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.grid-sidebar {
|
||||
grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
justify-content: flex-start;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.metric-card {
|
||||
padding: 1.25rem;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 1.75rem;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
flex-direction: column;
|
||||
align-items: stretch;
|
||||
}
|
||||
}
|
||||
@@ -52,8 +52,8 @@ body {
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: ui-sans-serif, system-ui, -apple-system, 'Segoe UI', 'Roboto',
|
||||
Helvetica, Arial, 'Apple Color Emoji', 'Segoe UI Emoji';
|
||||
font-family: ui-sans-serif, system-ui, -apple-system, "Segoe UI", "Roboto",
|
||||
Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
|
||||
color: var(--text);
|
||||
background: linear-gradient(180deg, var(--bg) 0%, var(--bg-2) 100%);
|
||||
line-height: 1.45;
|
||||
@@ -337,7 +337,7 @@ a {
|
||||
gap: var(--space-sm);
|
||||
font-weight: 600;
|
||||
color: var(--text);
|
||||
font-family: 'Fira Code', 'Consolas', 'Courier New', monospace;
|
||||
font-family: "Fira Code", "Consolas", "Courier New", monospace;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
@@ -366,7 +366,7 @@ a {
|
||||
}
|
||||
|
||||
.color-value-input {
|
||||
font-family: 'Fira Code', 'Consolas', 'Courier New', monospace;
|
||||
font-family: "Fira Code", "Consolas", "Courier New", monospace;
|
||||
}
|
||||
|
||||
.color-value-input[disabled] {
|
||||
@@ -395,7 +395,7 @@ a {
|
||||
}
|
||||
|
||||
.env-overrides-table code {
|
||||
font-family: 'Fira Code', 'Consolas', 'Courier New', monospace;
|
||||
font-family: "Fira Code", "Consolas", "Courier New", monospace;
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
@@ -550,7 +550,7 @@ a {
|
||||
}
|
||||
|
||||
.btn.is-loading::after {
|
||||
content: '';
|
||||
content: "";
|
||||
width: 0.85rem;
|
||||
height: 0.85rem;
|
||||
border: 2px solid rgba(255, 255, 255, 0.6);
|
||||
@@ -656,14 +656,14 @@ a {
|
||||
color: var(--color-surface-alt);
|
||||
padding: 1rem;
|
||||
border-radius: 8px;
|
||||
font-family: 'Fira Code', 'Consolas', 'Courier New', monospace;
|
||||
font-family: "Fira Code", "Consolas", "Courier New", monospace;
|
||||
overflow-x: auto;
|
||||
margin-top: 1.5rem;
|
||||
}
|
||||
|
||||
.monospace-input {
|
||||
width: 100%;
|
||||
font-family: 'Fira Code', 'Consolas', 'Courier New', monospace;
|
||||
font-family: "Fira Code", "Consolas", "Courier New", monospace;
|
||||
min-height: 120px;
|
||||
}
|
||||
|
||||
@@ -740,6 +740,72 @@ tbody tr:nth-child(even) {
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.sidebar-toggle {
|
||||
display: none;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
padding: 0.55rem 1rem;
|
||||
border-radius: 999px;
|
||||
border: none;
|
||||
background: linear-gradient(135deg, var(--brand-2), var(--brand));
|
||||
color: var(--color-text-dark);
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
box-shadow: 0 6px 16px rgba(0, 0, 0, 0.25);
|
||||
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
.sidebar-toggle:hover,
|
||||
.sidebar-toggle:focus-visible {
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 8px 20px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
.sidebar-toggle:focus-visible {
|
||||
outline: 2px solid rgba(255, 255, 255, 0.65);
|
||||
outline-offset: 3px;
|
||||
}
|
||||
|
||||
.sidebar-toggle-icon {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
width: 18px;
|
||||
height: 2px;
|
||||
background-color: currentColor;
|
||||
}
|
||||
|
||||
.sidebar-toggle-icon::before,
|
||||
.sidebar-toggle-icon::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
left: 0;
|
||||
width: 18px;
|
||||
height: 2px;
|
||||
background-color: currentColor;
|
||||
}
|
||||
|
||||
.sidebar-toggle-icon::before {
|
||||
top: -6px;
|
||||
}
|
||||
|
||||
.sidebar-toggle-icon::after {
|
||||
top: 6px;
|
||||
}
|
||||
|
||||
.sidebar-toggle-label {
|
||||
font-size: 0.95rem;
|
||||
}
|
||||
|
||||
.sidebar-overlay {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
background: rgba(7, 11, 17, 0.6);
|
||||
z-index: 800;
|
||||
opacity: 0;
|
||||
pointer-events: none;
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
.app-sidebar {
|
||||
width: 240px;
|
||||
@@ -790,4 +856,39 @@ tbody tr:nth-child(even) {
|
||||
.dashboard-columns {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.sidebar-toggle {
|
||||
display: inline-flex;
|
||||
margin: 1rem auto 1.5rem;
|
||||
}
|
||||
|
||||
body.sidebar-collapsed .app-sidebar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
body.sidebar-open {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
body.sidebar-open .app-sidebar {
|
||||
display: block;
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: min(320px, 82vw);
|
||||
height: 100vh;
|
||||
overflow-y: auto;
|
||||
z-index: 900;
|
||||
box-shadow: 0 12px 30px rgba(8, 14, 25, 0.4);
|
||||
}
|
||||
|
||||
body.sidebar-open .sidebar-overlay {
|
||||
opacity: 1;
|
||||
pointer-events: auto;
|
||||
}
|
||||
|
||||
body.sidebar-open .app-main {
|
||||
position: relative;
|
||||
z-index: 950;
|
||||
}
|
||||
}
|
||||
|
||||
200
static/css/projects.css
Normal file
200
static/css/projects.css
Normal file
@@ -0,0 +1,200 @@
|
||||
:root {
|
||||
--card-bg: rgba(21, 27, 35, 0.8);
|
||||
--card-border: rgba(255, 255, 255, 0.08);
|
||||
--hover-highlight: rgba(241, 178, 26, 0.12);
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.project-metrics {
|
||||
display: grid;
|
||||
gap: 1.5rem;
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.metric-card {
|
||||
background: var(--card-bg);
|
||||
border-radius: var(--radius);
|
||||
padding: 1.5rem;
|
||||
box-shadow: var(--shadow);
|
||||
border: 1px solid var(--card-border);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.metric-card h2 {
|
||||
margin: 0;
|
||||
font-size: 1rem;
|
||||
color: var(--muted);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.08em;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 2rem;
|
||||
font-weight: 700;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.metric-caption {
|
||||
color: var(--color-text-subtle);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.project-form {
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--card-border);
|
||||
border-radius: var(--radius);
|
||||
box-shadow: var(--shadow);
|
||||
padding: 1.75rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.definition-list {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(240px, 1fr));
|
||||
gap: 1.25rem 2rem;
|
||||
}
|
||||
|
||||
.definition-list dt {
|
||||
font-weight: 600;
|
||||
color: var(--muted);
|
||||
margin-bottom: 0.2rem;
|
||||
text-transform: uppercase;
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
|
||||
.definition-list dd {
|
||||
margin: 0;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.card {
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--card-border);
|
||||
box-shadow: var(--shadow);
|
||||
border-radius: var(--radius);
|
||||
padding: 1.5rem;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.card-header h2 {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.project-layout {
|
||||
display: grid;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.table-responsive {
|
||||
overflow-x: auto;
|
||||
border-radius: var(--table-radius);
|
||||
}
|
||||
|
||||
.table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
border-radius: var(--table-radius);
|
||||
overflow: hidden;
|
||||
box-shadow: var(--shadow);
|
||||
}
|
||||
|
||||
.table th,
|
||||
.table td {
|
||||
padding: 0.75rem 1rem;
|
||||
border-bottom: 1px solid var(--card-border);
|
||||
background: rgba(21, 27, 35, 0.85);
|
||||
}
|
||||
|
||||
.table tbody tr:hover {
|
||||
background: var(--hover-highlight);
|
||||
}
|
||||
|
||||
.table-link {
|
||||
color: var(--brand-2);
|
||||
text-decoration: none;
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
|
||||
.table-link:hover,
|
||||
.table-link:focus {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.text-right {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
@media (min-width: 960px) {
|
||||
.project-layout {
|
||||
grid-template-columns: 1.1fr 1.9fr;
|
||||
align-items: start;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
justify-content: flex-start;
|
||||
}
|
||||
}
|
||||
|
||||
.alert {
|
||||
padding: 0.75rem 1rem;
|
||||
border-radius: var(--radius-sm);
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.alert-error {
|
||||
background: rgba(209, 75, 75, 0.2);
|
||||
border: 1px solid rgba(209, 75, 75, 0.4);
|
||||
color: var(--color-text-invert);
|
||||
}
|
||||
|
||||
.form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.25rem;
|
||||
}
|
||||
|
||||
.form-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
gap: 1.25rem;
|
||||
}
|
||||
|
||||
.form-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.form-group input,
|
||||
.form-group select,
|
||||
.form-group textarea {
|
||||
padding: 0.75rem 0.85rem;
|
||||
border-radius: var(--radius-sm);
|
||||
border: 1px solid var(--card-border);
|
||||
background: rgba(8, 12, 19, 0.75);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.form-actions {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
197
static/css/scenarios.css
Normal file
197
static/css/scenarios.css
Normal file
@@ -0,0 +1,197 @@
|
||||
.scenario-meta {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
gap: 1.25rem;
|
||||
}
|
||||
|
||||
.table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
border-radius: var(--table-radius);
|
||||
overflow: hidden;
|
||||
box-shadow: var(--shadow);
|
||||
}
|
||||
|
||||
.table th,
|
||||
.table td {
|
||||
padding: 0.75rem 1rem;
|
||||
border-bottom: 1px solid var(--color-border);
|
||||
background: rgba(21, 27, 35, 0.85);
|
||||
}
|
||||
|
||||
.table tbody tr:hover {
|
||||
background: rgba(43, 165, 143, 0.12);
|
||||
}
|
||||
|
||||
.breadcrumb {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.9rem;
|
||||
color: var(--muted);
|
||||
margin-bottom: 1.2rem;
|
||||
}
|
||||
|
||||
.breadcrumb a {
|
||||
color: var(--brand-2);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.scenario-metrics {
|
||||
display: grid;
|
||||
gap: 1.5rem;
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.metric-card {
|
||||
background: rgba(21, 27, 35, 0.85);
|
||||
border-radius: var(--radius);
|
||||
padding: 1.5rem;
|
||||
box-shadow: var(--shadow);
|
||||
border: 1px solid var(--color-border);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.metric-card h2 {
|
||||
margin: 0;
|
||||
font-size: 1rem;
|
||||
color: var(--muted);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.08em;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 2rem;
|
||||
font-weight: 700;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.metric-caption {
|
||||
color: var(--color-text-subtle);
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.scenario-filters {
|
||||
display: grid;
|
||||
gap: 0.75rem;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.scenario-filters .filter-field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.scenario-filters .filter-actions {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.scenario-filters input,
|
||||
.scenario-filters select {
|
||||
width: 100%;
|
||||
padding: 0.6rem 0.75rem;
|
||||
border-radius: var(--radius-sm);
|
||||
border: 1px solid var(--color-border);
|
||||
background: rgba(8, 12, 19, 0.75);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.scenario-form {
|
||||
background: rgba(21, 27, 35, 0.85);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--radius);
|
||||
box-shadow: var(--shadow);
|
||||
padding: 1.75rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.table-responsive {
|
||||
width: 100%;
|
||||
overflow-x: auto;
|
||||
-webkit-overflow-scrolling: touch;
|
||||
border-radius: var(--table-radius);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.table-responsive .table {
|
||||
min-width: 640px;
|
||||
}
|
||||
|
||||
.table-responsive::-webkit-scrollbar {
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
.table-responsive::-webkit-scrollbar-thumb {
|
||||
background: rgba(255, 255, 255, 0.2);
|
||||
border-radius: 999px;
|
||||
}
|
||||
|
||||
@media (min-width: 720px) {
|
||||
.scenario-filters {
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
align-items: end;
|
||||
}
|
||||
|
||||
.scenario-filters .filter-actions {
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.table-responsive .table {
|
||||
min-width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.breadcrumb {
|
||||
flex-wrap: wrap;
|
||||
gap: 0.35rem;
|
||||
}
|
||||
|
||||
.table th,
|
||||
.table td {
|
||||
padding: 0.55rem 0.65rem;
|
||||
font-size: 0.9rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.table tbody tr {
|
||||
border-radius: var(--radius-sm);
|
||||
}
|
||||
}
|
||||
|
||||
.scenario-layout {
|
||||
display: grid;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.empty-state {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
@media (min-width: 960px) {
|
||||
.header-actions {
|
||||
justify-content: flex-start;
|
||||
}
|
||||
|
||||
.scenario-layout {
|
||||
grid-template-columns: 1.1fr 1.9fr;
|
||||
align-items: start;
|
||||
}
|
||||
}
|
||||
BIN
static/img/logo_big.png
Normal file
BIN
static/img/logo_big.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 MiB |
116
static/js/projects.js
Normal file
116
static/js/projects.js
Normal file
@@ -0,0 +1,116 @@
|
||||
document.addEventListener("DOMContentLoaded", () => {
|
||||
const table = document.querySelector("[data-project-table]");
|
||||
const rows = table ? Array.from(table.querySelectorAll("tbody tr")) : [];
|
||||
const filterInput = document.querySelector("[data-project-filter]");
|
||||
|
||||
if (table && filterInput) {
|
||||
filterInput.addEventListener("input", () => {
|
||||
const query = filterInput.value.trim().toLowerCase();
|
||||
rows.forEach((row) => {
|
||||
const match = row.textContent.toLowerCase().includes(query);
|
||||
row.style.display = match ? "" : "none";
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const sidebar = document.querySelector(".app-sidebar");
|
||||
const appMain = document.querySelector(".app-main");
|
||||
if (!sidebar || !appMain) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = document.body;
|
||||
const mobileQuery = window.matchMedia("(max-width: 900px)");
|
||||
let toggleButton = document.querySelector("[data-sidebar-toggle]");
|
||||
|
||||
if (!toggleButton) {
|
||||
toggleButton = document.createElement("button");
|
||||
toggleButton.type = "button";
|
||||
toggleButton.className = "sidebar-toggle";
|
||||
toggleButton.setAttribute("data-sidebar-toggle", "");
|
||||
toggleButton.setAttribute("aria-expanded", "false");
|
||||
toggleButton.setAttribute("aria-label", "Toggle primary navigation");
|
||||
toggleButton.hidden = true;
|
||||
toggleButton.innerHTML = [
|
||||
'<span class="sidebar-toggle-icon" aria-hidden="true"></span>',
|
||||
'<span class="sidebar-toggle-label">Menu</span>',
|
||||
].join("");
|
||||
appMain.insertBefore(toggleButton, appMain.firstChild);
|
||||
}
|
||||
|
||||
let overlay = document.querySelector("[data-sidebar-overlay]");
|
||||
if (!overlay) {
|
||||
overlay = document.createElement("div");
|
||||
overlay.className = "sidebar-overlay";
|
||||
overlay.setAttribute("data-sidebar-overlay", "");
|
||||
overlay.setAttribute("aria-hidden", "true");
|
||||
document.body.appendChild(overlay);
|
||||
}
|
||||
|
||||
const primaryNav = document.querySelector(".sidebar-nav");
|
||||
if (primaryNav) {
|
||||
if (!primaryNav.id) {
|
||||
primaryNav.id = "primary-navigation";
|
||||
}
|
||||
toggleButton.setAttribute("aria-controls", primaryNav.id);
|
||||
}
|
||||
|
||||
const openSidebar = () => {
|
||||
body.classList.remove("sidebar-collapsed");
|
||||
body.classList.add("sidebar-open");
|
||||
toggleButton.setAttribute("aria-expanded", "true");
|
||||
overlay.setAttribute("aria-hidden", "false");
|
||||
};
|
||||
|
||||
const closeSidebar = (focusToggle = false) => {
|
||||
body.classList.add("sidebar-collapsed");
|
||||
body.classList.remove("sidebar-open");
|
||||
toggleButton.setAttribute("aria-expanded", "false");
|
||||
overlay.setAttribute("aria-hidden", "true");
|
||||
if (focusToggle) {
|
||||
toggleButton.focus({ preventScroll: true });
|
||||
}
|
||||
};
|
||||
|
||||
const toggleSidebar = () => {
|
||||
if (body.classList.contains("sidebar-open")) {
|
||||
closeSidebar();
|
||||
} else {
|
||||
openSidebar();
|
||||
sidebar.setAttribute("aria-hidden", "false");
|
||||
}
|
||||
};
|
||||
|
||||
const applyResponsiveState = (mql) => {
|
||||
if (!mql.matches) {
|
||||
toggleButton.hidden = true;
|
||||
body.classList.remove("sidebar-open", "sidebar-collapsed");
|
||||
sidebar.setAttribute("aria-hidden", "true");
|
||||
overlay.setAttribute("aria-hidden", "true");
|
||||
sidebar.removeAttribute("aria-hidden");
|
||||
return;
|
||||
}
|
||||
|
||||
toggleButton.hidden = false;
|
||||
if (!body.classList.contains("sidebar-open")) {
|
||||
body.classList.add("sidebar-collapsed");
|
||||
sidebar.setAttribute("aria-hidden", "true");
|
||||
}
|
||||
};
|
||||
|
||||
toggleButton.addEventListener("click", toggleSidebar);
|
||||
overlay.addEventListener("click", () => closeSidebar());
|
||||
|
||||
document.addEventListener("keydown", (event) => {
|
||||
if (event.key === "Escape" && body.classList.contains("sidebar-open")) {
|
||||
closeSidebar(true);
|
||||
}
|
||||
});
|
||||
|
||||
applyResponsiveState(mobileQuery);
|
||||
if (typeof mobileQuery.addEventListener === "function") {
|
||||
mobileQuery.addEventListener("change", applyResponsiveState);
|
||||
} else if (typeof mobileQuery.addListener === "function") {
|
||||
mobileQuery.addListener(applyResponsiveState);
|
||||
}
|
||||
});
|
||||
132
templates/Dashboard.html
Normal file
132
templates/Dashboard.html
Normal file
@@ -0,0 +1,132 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}Dashboard · CalMiner{% endblock %}
|
||||
|
||||
{% block head_extra %}
|
||||
<link rel="stylesheet" href="/static/css/dashboard.css" />
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<section class="page-header dashboard-header">
|
||||
<div>
|
||||
<h1>Welcome back</h1>
|
||||
<p class="page-subtitle">Monitor project progress and scenario insights at a glance.</p>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<a class="btn primary" href="{{ url_for('projects.create_project_form') }}">New Project</a>
|
||||
<a class="btn" href="#">Import Data</a>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section class="dashboard-metrics">
|
||||
<article class="metric-card">
|
||||
<h2>Total Projects</h2>
|
||||
<p class="metric-value">{{ metrics.total_projects }}</p>
|
||||
<span class="metric-caption">Across all operation types</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Active Scenarios</h2>
|
||||
<p class="metric-value">{{ metrics.active_scenarios }}</p>
|
||||
<span class="metric-caption">Ready for analysis</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Pending Simulations</h2>
|
||||
<p class="metric-value">{{ metrics.pending_simulations }}</p>
|
||||
<span class="metric-caption">Awaiting execution</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Last Data Import</h2>
|
||||
<p class="metric-value">{{ metrics.last_import or '—' }}</p>
|
||||
<span class="metric-caption">UTC timestamp</span>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<section class="dashboard-grid">
|
||||
<div class="grid-main">
|
||||
<div class="card">
|
||||
<header class="card-header">
|
||||
<h2>Recent Projects</h2>
|
||||
<a class="btn btn-link" href="{{ url_for('projects.project_list_page') }}">View all</a>
|
||||
</header>
|
||||
{% if recent_projects %}
|
||||
<table class="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Project</th>
|
||||
<th>Operation</th>
|
||||
<th>Updated</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for project in recent_projects %}
|
||||
<tr>
|
||||
<td>
|
||||
<a class="table-link" href="{{ url_for('projects.view_project', project_id=project.id) }}">{{ project.name }}</a>
|
||||
</td>
|
||||
<td>{{ project.operation_type.value.replace('_', ' ') | title }}</td>
|
||||
<td>{{ project.updated_at.strftime('%Y-%m-%d') if project.updated_at else '—' }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<p class="empty-state">No recent projects. <a href="{{ url_for('projects.create_project_form') }}">Create one now.</a></p>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<header class="card-header">
|
||||
<h2>Simulation Pipeline</h2>
|
||||
</header>
|
||||
{% if simulation_updates %}
|
||||
<ul class="timeline">
|
||||
{% for update in simulation_updates %}
|
||||
<li>
|
||||
<span class="timeline-label">{{ update.timestamp_label or '—' }}</span>
|
||||
<div>
|
||||
<strong>{{ update.title }}</strong>
|
||||
<p>{{ update.description }}</p>
|
||||
</div>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
{% else %}
|
||||
<p class="empty-state">No simulation runs yet. Configure a scenario to start simulations.</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<aside class="grid-sidebar">
|
||||
<div class="card">
|
||||
<header class="card-header">
|
||||
<h2>Scenario Alerts</h2>
|
||||
</header>
|
||||
{% if scenario_alerts %}
|
||||
<ul class="alerts-list">
|
||||
{% for alert in scenario_alerts %}
|
||||
<li>
|
||||
<strong>{{ alert.title }}</strong>
|
||||
<p>{{ alert.message }}</p>
|
||||
{% if alert.link %}
|
||||
<a class="btn btn-link" href="{{ alert.link }}">Review</a>
|
||||
{% endif %}
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
{% else %}
|
||||
<p class="empty-state">All scenarios look good. We'll highlight issues here.</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<header class="card-header">
|
||||
<h2>Resources</h2>
|
||||
</header>
|
||||
<ul class="links-list">
|
||||
<li><a href="https://github.com/" target="_blank">CalMiner Repository</a></li>
|
||||
<li><a href="https://example.com/docs" target="_blank">Documentation</a></li>
|
||||
<li><a href="mailto:support@example.com">Contact Support</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
</aside>
|
||||
</section>
|
||||
{% endblock %}
|
||||
@@ -20,6 +20,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{% block scripts %}{% endblock %}
|
||||
<script src="/static/js/projects.js" defer></script>
|
||||
<script src="/static/js/theme.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Forgot Password{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
{% extends "base.html" %} {% block title %}Forgot Password{% endblock %} {%
|
||||
block content %}
|
||||
<div class="container">
|
||||
<h1>Forgot Password</h1>
|
||||
<form id="forgot-password-form">
|
||||
{% if errors %}
|
||||
<div class="alert alert-error">
|
||||
<ul>
|
||||
{% for error in errors %}
|
||||
<li>{{ error }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
{% endif %} {% if message %}
|
||||
<div class="alert alert-info">{{ message }}</div>
|
||||
{% endif %}
|
||||
<form id="forgot-password-form" method="post" action="{{ form_action }}">
|
||||
<div class="form-group">
|
||||
<label for="email">Email:</label>
|
||||
<input type="email" id="email" name="email" required>
|
||||
<input type="email" id="email" name="email" required />
|
||||
</div>
|
||||
<button type="submit">Reset Password</button>
|
||||
</form>
|
||||
|
||||
@@ -1,18 +1,30 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Login{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
{% extends "base.html" %} {% block title %}Login{% endblock %} {% block content
|
||||
%}
|
||||
<div class="container">
|
||||
<h1>Login</h1>
|
||||
<form id="login-form">
|
||||
{% if errors %}
|
||||
<div class="alert alert-error">
|
||||
<ul>
|
||||
{% for error in errors %}
|
||||
<li>{{ error }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
{% endif %}
|
||||
<form id="login-form" method="post" action="{{ form_action }}">
|
||||
<div class="form-group">
|
||||
<label for="username">Username:</label>
|
||||
<input type="text" id="username" name="username" required>
|
||||
<input
|
||||
type="text"
|
||||
id="username"
|
||||
name="username"
|
||||
value="{{ username | default('') }}"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="password">Password:</label>
|
||||
<input type="password" id="password" name="password" required>
|
||||
<input type="password" id="password" name="password" required />
|
||||
</div>
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
<div class="sidebar-inner">
|
||||
<div class="sidebar-brand">
|
||||
<a class="sidebar-brand" href="{{ request.url_for('dashboard.home') }}">
|
||||
<span class="brand-logo" aria-hidden="true">CM</span>
|
||||
<div class="brand-text">
|
||||
<span class="brand-title">CalMiner</span>
|
||||
<span class="brand-subtitle">Mining Planner</span>
|
||||
</div>
|
||||
</div>
|
||||
</a>
|
||||
{% include "partials/sidebar_nav.html" %}
|
||||
</div>
|
||||
|
||||
@@ -1,41 +1,88 @@
|
||||
{% set nav_groups = [ { "label": "Dashboard", "links": [ {"href": "/", "label":
|
||||
"Dashboard"}, ], }, { "label": "Overview", "links": [ {"href": "/ui/parameters",
|
||||
"label": "Parameters"}, {"href": "/ui/costs", "label": "Costs"}, {"href":
|
||||
"/ui/consumption", "label": "Consumption"}, {"href": "/ui/production", "label":
|
||||
"Production"}, { "href": "/ui/equipment", "label": "Equipment", "children": [
|
||||
{"href": "/ui/maintenance", "label": "Maintenance"}, ], }, ], }, { "label":
|
||||
"Simulations", "links": [ {"href": "/ui/simulations", "label": "Simulations"},
|
||||
], }, { "label": "Analytics", "links": [ {"href": "/ui/reporting", "label":
|
||||
"Reporting"}, ], }, { "label": "Settings", "links": [ { "href": "/ui/settings",
|
||||
"label": "Settings", "children": [ {"href": "/theme-settings", "label":
|
||||
"Themes"}, {"href": "/ui/currencies", "label": "Currency Management"}, ], }, ],
|
||||
}, ] %}
|
||||
{% set dashboard_href = request.url_for('dashboard.home') if request else '/' %}
|
||||
{% set projects_href = request.url_for('projects.project_list_page') if request else '/projects/ui' %}
|
||||
{% set project_create_href = request.url_for('projects.create_project_form') if request else '/projects/create' %}
|
||||
{% set auth_session = request.state.auth_session if request else None %}
|
||||
{% set is_authenticated = auth_session and auth_session.is_authenticated %}
|
||||
|
||||
{% if is_authenticated %}
|
||||
{% set logout_href = request.url_for('auth.logout') if request else '/logout' %}
|
||||
{% set account_links = [
|
||||
{"href": logout_href, "label": "Logout", "match_prefix": "/logout"}
|
||||
] %}
|
||||
{% else %}
|
||||
{% set login_href = request.url_for('auth.login_form') if request else '/login' %}
|
||||
{% set register_href = request.url_for('auth.register_form') if request else '/register' %}
|
||||
{% set forgot_href = request.url_for('auth.password_reset_request_form') if request else '/forgot-password' %}
|
||||
{% set account_links = [
|
||||
{"href": login_href, "label": "Login", "match_prefix": "/login"},
|
||||
{"href": register_href, "label": "Register", "match_prefix": "/register"},
|
||||
{"href": forgot_href, "label": "Forgot Password", "match_prefix": "/forgot-password"}
|
||||
] %}
|
||||
{% endif %}
|
||||
{% set nav_groups = [
|
||||
{
|
||||
"label": "Workspace",
|
||||
"links": [
|
||||
{"href": dashboard_href, "label": "Dashboard", "match_prefix": "/"},
|
||||
{"href": projects_href, "label": "Projects", "match_prefix": "/projects"},
|
||||
{"href": project_create_href, "label": "New Project", "match_prefix": "/projects/create"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "Insights",
|
||||
"links": [
|
||||
{"href": "/ui/simulations", "label": "Simulations"},
|
||||
{"href": "/ui/reporting", "label": "Reporting"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "Configuration",
|
||||
"links": [
|
||||
{
|
||||
"href": "/ui/settings",
|
||||
"label": "Settings",
|
||||
"children": [
|
||||
{"href": "/theme-settings", "label": "Themes"},
|
||||
{"href": "/ui/currencies", "label": "Currency Management"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "Account",
|
||||
"links": account_links
|
||||
}
|
||||
] %}
|
||||
|
||||
<nav class="sidebar-nav" aria-label="Primary navigation">
|
||||
{% set current_path = request.url.path if request else "" %} {% for group in
|
||||
nav_groups %}
|
||||
{% set current_path = request.url.path if request else '' %}
|
||||
{% for group in nav_groups %}
|
||||
{% if group.links %}
|
||||
<div class="sidebar-section">
|
||||
<div class="sidebar-section-label">{{ group.label }}</div>
|
||||
<div class="sidebar-section-links">
|
||||
{% for link in group.links %} {% set href = link.href %} {% if href == "/"
|
||||
%} {% set is_active = current_path == "/" %} {% else %} {% set is_active =
|
||||
current_path.startswith(href) %} {% endif %}
|
||||
{% for link in group.links %}
|
||||
{% set href = link.href %}
|
||||
{% set match_prefix = link.get('match_prefix', href) %}
|
||||
{% if match_prefix == '/' %}
|
||||
{% set is_active = current_path == '/' %}
|
||||
{% else %}
|
||||
{% set is_active = current_path.startswith(match_prefix) %}
|
||||
{% endif %}
|
||||
<div class="sidebar-link-block">
|
||||
<a
|
||||
href="{{ href }}"
|
||||
class="sidebar-link{% if is_active %} is-active{% endif %}"
|
||||
>
|
||||
<a href="{{ href }}" class="sidebar-link{% if is_active %} is-active{% endif %}">
|
||||
{{ link.label }}
|
||||
</a>
|
||||
{% if link.children %}
|
||||
<div class="sidebar-sublinks">
|
||||
{% for child in link.children %} {% if child.href == "/" %} {% set
|
||||
child_active = current_path == "/" %} {% else %} {% set child_active =
|
||||
current_path.startswith(child.href) %} {% endif %}
|
||||
<a
|
||||
href="{{ child.href }}"
|
||||
class="sidebar-sublink{% if child_active %} is-active{% endif %}"
|
||||
>
|
||||
{% for child in link.children %}
|
||||
{% set child_prefix = child.get('match_prefix', child.href) %}
|
||||
{% if child_prefix == '/' %}
|
||||
{% set child_active = current_path == '/' %}
|
||||
{% else %}
|
||||
{% set child_active = current_path.startswith(child_prefix) %}
|
||||
{% endif %}
|
||||
<a href="{{ child.href }}" class="sidebar-sublink{% if child_active %} is-active{% endif %}">
|
||||
{{ child.label }}
|
||||
</a>
|
||||
{% endfor %}
|
||||
@@ -45,5 +92,6 @@
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</nav>
|
||||
|
||||
113
templates/projects/detail.html
Normal file
113
templates/projects/detail.html
Normal file
@@ -0,0 +1,113 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}{{ project.name }} · Project · CalMiner{% endblock %}
|
||||
|
||||
{% block head_extra %}
|
||||
<link rel="stylesheet" href="/static/css/projects.css" />
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<nav class="breadcrumb">
|
||||
<a href="{{ url_for('projects.project_list_page') }}">Projects</a>
|
||||
<span aria-current="page">{{ project.name }}</span>
|
||||
</nav>
|
||||
|
||||
<header class="page-header">
|
||||
<div>
|
||||
<h1>{{ project.name }}</h1>
|
||||
<p class="text-muted">{{ project.operation_type.value.replace('_', ' ') | title }}</p>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<a class="btn" href="{{ url_for('projects.edit_project_form', project_id=project.id) }}">Edit Project</a>
|
||||
<a class="btn primary" href="{{ url_for('scenarios.create_scenario_form', project_id=project.id) }}">New Scenario</a>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<section class="project-metrics">
|
||||
<article class="metric-card">
|
||||
<h2>Total Scenarios</h2>
|
||||
<p class="metric-value">{{ scenario_stats.total }}</p>
|
||||
<span class="metric-caption">Across this project</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Active</h2>
|
||||
<p class="metric-value">{{ scenario_stats.active }}</p>
|
||||
<span class="metric-caption">Currently live analyses</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Draft</h2>
|
||||
<p class="metric-value">{{ scenario_stats.draft }}</p>
|
||||
<span class="metric-caption">Awaiting validation</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Archived</h2>
|
||||
<p class="metric-value">{{ scenario_stats.archived }}</p>
|
||||
<span class="metric-caption">Historical references</span>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<div class="project-layout">
|
||||
<section class="card">
|
||||
<h2>Project Overview</h2>
|
||||
<dl class="definition-list">
|
||||
<div>
|
||||
<dt>Location</dt>
|
||||
<dd>{{ project.location or '—' }}</dd>
|
||||
</div>
|
||||
<div>
|
||||
<dt>Description</dt>
|
||||
<dd>{{ project.description or 'No description provided.' }}</dd>
|
||||
</div>
|
||||
<div>
|
||||
<dt>Created</dt>
|
||||
<dd>{{ project.created_at.strftime('%Y-%m-%d %H:%M') }}</dd>
|
||||
</div>
|
||||
<div>
|
||||
<dt>Updated</dt>
|
||||
<dd>{{ project.updated_at.strftime('%Y-%m-%d %H:%M') }}</dd>
|
||||
</div>
|
||||
<div>
|
||||
<dt>Latest Scenario Update</dt>
|
||||
<dd>{{ scenario_stats.latest_update.strftime('%Y-%m-%d %H:%M') if scenario_stats.latest_update else '—' }}</dd>
|
||||
</div>
|
||||
</dl>
|
||||
</section>
|
||||
|
||||
<section class="card">
|
||||
<header class="card-header">
|
||||
<h2>Scenarios</h2>
|
||||
<a class="btn" href="{{ url_for('scenarios.create_scenario_form', project_id=project.id) }}">Add Scenario</a>
|
||||
</header>
|
||||
{% if scenarios %}
|
||||
<div class="table-responsive">
|
||||
<table class="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Status</th>
|
||||
<th>Currency</th>
|
||||
<th>Primary Resource</th>
|
||||
<th class="text-right">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for scenario in scenarios %}
|
||||
<tr>
|
||||
<td>{{ scenario.name }}</td>
|
||||
<td>{{ scenario.status.value.title() }}</td>
|
||||
<td>{{ scenario.currency or '—' }}</td>
|
||||
<td>{{ scenario.primary_resource.value.replace('_', ' ') | title if scenario.primary_resource else '—' }}</td>
|
||||
<td class="text-right">
|
||||
<a class="table-link" href="{{ url_for('scenarios.view_scenario', scenario_id=scenario.id) }}">View</a>
|
||||
<a class="table-link" href="{{ url_for('scenarios.edit_scenario_form', scenario_id=scenario.id) }}">Edit</a>
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{% else %}
|
||||
<p class="empty-state">No scenarios yet. <a href="{{ url_for('scenarios.create_scenario_form', project_id=project.id) }}">Create the first scenario.</a></p>
|
||||
{% endif %}
|
||||
</section>
|
||||
</div>
|
||||
{% endblock %}
|
||||
70
templates/projects/form.html
Normal file
70
templates/projects/form.html
Normal file
@@ -0,0 +1,70 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}{% if project %}Edit {{ project.name }}{% else %}New Project{% endif %} · CalMiner{% endblock %}
|
||||
|
||||
{% block head_extra %}
|
||||
<link rel="stylesheet" href="/static/css/projects.css" />
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<nav class="breadcrumb">
|
||||
<a href="{{ url_for('projects.project_list_page') }}">Projects</a>
|
||||
{% if project %}
|
||||
<a href="{{ url_for('projects.view_project', project_id=project.id) }}">{{ project.name }}</a>
|
||||
<span aria-current="page">Edit</span>
|
||||
{% else %}
|
||||
<span aria-current="page">New</span>
|
||||
{% endif %}
|
||||
</nav>
|
||||
|
||||
<header class="page-header">
|
||||
<div>
|
||||
<h1>{% if project %}Edit Project{% else %}Create Project{% endif %}</h1>
|
||||
<p class="text-muted">Provide core information about the mining project.</p>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<a class="btn" href="{{ cancel_url }}">Cancel</a>
|
||||
<button class="btn primary" type="submit">Save Project</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
{% if error %}
|
||||
<div class="alert alert-error">{{ error }}</div>
|
||||
{% endif %}
|
||||
|
||||
{% if error %}
|
||||
<div class="alert alert-error">{{ error }}</div>
|
||||
{% endif %}
|
||||
|
||||
<form class="form project-form" method="post" action="{{ form_action }}">
|
||||
<div class="form-grid">
|
||||
<div class="form-group">
|
||||
<label for="name">Name</label>
|
||||
<input id="name" name="name" type="text" required value="{{ project.name if project else '' }}" />
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="location">Location</label>
|
||||
<input id="location" name="location" type="text" value="{{ project.location if project else '' }}" />
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="operation_type">Operation Type</label>
|
||||
<select id="operation_type" name="operation_type" required>
|
||||
{% for value, label in operation_types %}
|
||||
<option value="{{ value }}" {% if project and project.operation_type.value == value %}selected{% endif %}>{{ label }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="description">Description</label>
|
||||
<textarea id="description" name="description" rows="5">{{ project.description if project else '' }}</textarea>
|
||||
</div>
|
||||
|
||||
<div class="form-actions">
|
||||
<a class="btn" href="{{ cancel_url }}">Cancel</a>
|
||||
<button class="btn primary" type="submit">Save Project</button>
|
||||
</div>
|
||||
</form>
|
||||
{% endblock %}
|
||||
54
templates/projects/list.html
Normal file
54
templates/projects/list.html
Normal file
@@ -0,0 +1,54 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}Projects · CalMiner{% endblock %}
|
||||
|
||||
{% block head_extra %}
|
||||
<link rel="stylesheet" href="/static/css/projects.css" />
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<section class="page-header">
|
||||
<div>
|
||||
<h1>Projects</h1>
|
||||
<p class="text-muted">Manage mining projects and explore their scenarios.</p>
|
||||
</div>
|
||||
<div class="actions">
|
||||
<input
|
||||
type="search"
|
||||
class="form-control"
|
||||
placeholder="Filter projects..."
|
||||
data-project-filter
|
||||
/>
|
||||
<a class="btn btn-primary" href="{{ url_for('projects.create_project_form') }}">New Project</a>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{% if projects %}
|
||||
<table class="projects-table" data-project-table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Location</th>
|
||||
<th>Type</th>
|
||||
<th>Scenarios</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for project in projects %}
|
||||
<tr>
|
||||
<td>{{ project.name }}</td>
|
||||
<td>{{ project.location or '—' }}</td>
|
||||
<td>{{ project.operation_type.value.replace('_', ' ') | title }}</td>
|
||||
<td>{{ project.scenario_count }}</td>
|
||||
<td class="text-right">
|
||||
<a class="btn btn-link" href="{{ url_for('projects.view_project', project_id=project.id) }}">View</a>
|
||||
<a class="btn btn-link" href="{{ url_for('projects.edit_project_form', project_id=project.id) }}">Edit</a>
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<p>No projects yet. <a href="{{ url_for('projects.create_project_form') }}">Create your first project.</a></p>
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
@@ -1,22 +1,40 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Register{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
{% extends "base.html" %} {% block title %}Register{% endblock %} {% block
|
||||
content %}
|
||||
<div class="container">
|
||||
<h1>Register</h1>
|
||||
<form id="register-form">
|
||||
{% if errors %}
|
||||
<div class="alert alert-error">
|
||||
<ul>
|
||||
{% for error in errors %}
|
||||
<li>{{ error }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
{% endif %}
|
||||
<form id="register-form" method="post" action="{{ form_action }}">
|
||||
<div class="form-group">
|
||||
<label for="username">Username:</label>
|
||||
<input type="text" id="username" name="username" required>
|
||||
<input
|
||||
type="text"
|
||||
id="username"
|
||||
name="username"
|
||||
value="{{ form_data.username if form_data else '' }}"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="email">Email:</label>
|
||||
<input type="email" id="email" name="email" required>
|
||||
<input
|
||||
type="email"
|
||||
id="email"
|
||||
name="email"
|
||||
value="{{ form_data.email if form_data else '' }}"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="password">Password:</label>
|
||||
<input type="password" id="password" name="password" required>
|
||||
<input type="password" id="password" name="password" required />
|
||||
</div>
|
||||
<button type="submit">Register</button>
|
||||
</form>
|
||||
|
||||
36
templates/reset_password.html
Normal file
36
templates/reset_password.html
Normal file
@@ -0,0 +1,36 @@
|
||||
{% extends "base.html" %} {% block title %}Reset Password{% endblock %} {% block
|
||||
content %}
|
||||
<div class="container">
|
||||
<h1>Reset Password</h1>
|
||||
{% if errors %}
|
||||
<div class="alert alert-error">
|
||||
<ul>
|
||||
{% for error in errors %}
|
||||
<li>{{ error }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
{% endif %}
|
||||
<form id="reset-password-form" method="post" action="{{ form_action }}">
|
||||
<input type="hidden" name="token" value="{{ token | default('') }}" />
|
||||
<div class="form-group">
|
||||
<label for="password">New Password:</label>
|
||||
<input type="password" id="password" name="password" required />
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="confirm_password">Confirm Password:</label>
|
||||
<input
|
||||
type="password"
|
||||
id="confirm_password"
|
||||
name="confirm_password"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<button type="submit">Update Password</button>
|
||||
</form>
|
||||
<p>
|
||||
Remembered your password?
|
||||
<a href="{{ request.url_for('auth.login_form') }}">Return to login</a>
|
||||
</p>
|
||||
</div>
|
||||
{% endblock %}
|
||||
134
templates/scenarios/detail.html
Normal file
134
templates/scenarios/detail.html
Normal file
@@ -0,0 +1,134 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}{{ scenario.name }} · Scenario · CalMiner{% endblock %}
|
||||
|
||||
{% block head_extra %}
|
||||
<link rel="stylesheet" href="/static/css/scenarios.css" />
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<nav class="breadcrumb">
|
||||
<a href="{{ url_for('projects.project_list_page') }}">Projects</a>
|
||||
<a href="{{ url_for('projects.view_project', project_id=scenario.project_id) }}">{{ project.name }}</a>
|
||||
<span aria-current="page">{{ scenario.name }}</span>
|
||||
</nav>
|
||||
|
||||
<header class="page-header">
|
||||
<div>
|
||||
<h1>{{ scenario.name }}</h1>
|
||||
<p class="text-muted">Status: {{ scenario.status.value.title() }}</p>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<a class="btn" href="{{ url_for('projects.view_project', project_id=project.id) }}">Back to Project</a>
|
||||
<a class="btn primary" href="{{ url_for('scenarios.edit_scenario_form', scenario_id=scenario.id) }}">Edit Scenario</a>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<section class="scenario-metrics">
|
||||
<article class="metric-card">
|
||||
<h2>Financial Inputs</h2>
|
||||
<p class="metric-value">{{ scenario_metrics.financial_count }}</p>
|
||||
<span class="metric-caption">Line items captured</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Simulation Parameters</h2>
|
||||
<p class="metric-value">{{ scenario_metrics.parameter_count }}</p>
|
||||
<span class="metric-caption">Inputs driving forecasts</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Currency</h2>
|
||||
<p class="metric-value">{{ scenario_metrics.currency or '—' }}</p>
|
||||
<span class="metric-caption">Financial reporting</span>
|
||||
</article>
|
||||
<article class="metric-card">
|
||||
<h2>Primary Resource</h2>
|
||||
<p class="metric-value">{{ scenario_metrics.primary_resource or '—' }}</p>
|
||||
<span class="metric-caption">Scenario focus</span>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<div class="scenario-layout">
|
||||
<section class="card">
|
||||
<h2>Scenario Details</h2>
|
||||
<dl class="definition-list">
|
||||
<div>
|
||||
<dt>Description</dt>
|
||||
<dd>{{ scenario.description or 'No description provided.' }}</dd>
|
||||
</div>
|
||||
<div>
|
||||
<dt>Timeline</dt>
|
||||
<dd>
|
||||
{{ scenario.start_date or '—' }} → {{ scenario.end_date or '—' }}
|
||||
</dd>
|
||||
</div>
|
||||
<div>
|
||||
<dt>Discount Rate</dt>
|
||||
<dd>{{ scenario.discount_rate or '—' }}</dd>
|
||||
</div>
|
||||
<div>
|
||||
<dt>Last Updated</dt>
|
||||
<dd>{{ scenario.updated_at.strftime('%Y-%m-%d %H:%M') if scenario.updated_at else '—' }}</dd>
|
||||
</div>
|
||||
</dl>
|
||||
</section>
|
||||
|
||||
<section class="card">
|
||||
<h2>Financial Inputs</h2>
|
||||
{% if financial_inputs %}
|
||||
<div class="table-responsive">
|
||||
<table class="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Category</th>
|
||||
<th>Amount</th>
|
||||
<th>Currency</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for item in financial_inputs %}
|
||||
<tr>
|
||||
<td>{{ item.name }}</td>
|
||||
<td>{{ item.category.value.title() }}</td>
|
||||
<td>{{ '{:,.2f}'.format(item.amount) }}</td>
|
||||
<td>{{ item.currency or '—' }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{% else %}
|
||||
<p class="empty-state">No financial inputs recorded yet.</p>
|
||||
{% endif %}
|
||||
</section>
|
||||
|
||||
<section class="card">
|
||||
<h2>Simulation Parameters</h2>
|
||||
{% if simulation_parameters %}
|
||||
<div class="table-responsive">
|
||||
<table class="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Distribution</th>
|
||||
<th>Variable</th>
|
||||
<th>Resource</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for param in simulation_parameters %}
|
||||
<tr>
|
||||
<td>{{ param.name }}</td>
|
||||
<td>{{ param.distribution.value.title() }}</td>
|
||||
<td>{{ param.variable.value.replace('_', ' ') | title if param.variable else '—' }}</td>
|
||||
<td>{{ param.resource_type.value.replace('_', ' ') | title if param.resource_type else '—' }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{% else %}
|
||||
<p class="empty-state">No simulation parameters defined.</p>
|
||||
{% endif %}
|
||||
</section>
|
||||
</div>
|
||||
{% endblock %}
|
||||
91
templates/scenarios/form.html
Normal file
91
templates/scenarios/form.html
Normal file
@@ -0,0 +1,91 @@
|
||||
{% extends "base.html" %}
|
||||
{% block title %}{% if scenario %}Edit {{ scenario.name }}{% else %}New Scenario{% endif %} · CalMiner{% endblock %}
|
||||
|
||||
{% block head_extra %}
|
||||
<link rel="stylesheet" href="/static/css/scenarios.css" />
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<nav class="breadcrumb">
|
||||
<a href="{{ url_for('projects.project_list_page') }}">Projects</a>
|
||||
<a href="{{ url_for('projects.view_project', project_id=project.id) }}">{{ project.name }}</a>
|
||||
{% if scenario %}
|
||||
<span aria-current="page">Edit Scenario</span>
|
||||
{% else %}
|
||||
<span aria-current="page">New Scenario</span>
|
||||
{% endif %}
|
||||
</nav>
|
||||
|
||||
<header class="page-header">
|
||||
<div>
|
||||
<h1>{% if scenario %}Edit Scenario{% else %}Create Scenario{% endif %}</h1>
|
||||
<p class="text-muted">Configure assumptions and metadata for this scenario.</p>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<a class="btn" href="{{ cancel_url }}">Cancel</a>
|
||||
<button class="btn primary" type="submit">Save Scenario</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
{% if error %}
|
||||
<div class="alert alert-error">{{ error }}</div>
|
||||
{% endif %}
|
||||
|
||||
<form class="form scenario-form" method="post" action="{{ form_action }}">
|
||||
<div class="form-grid">
|
||||
<div class="form-group">
|
||||
<label for="name">Name</label>
|
||||
<input id="name" name="name" type="text" required value="{{ scenario.name if scenario else '' }}" />
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="status">Status</label>
|
||||
<select id="status" name="status">
|
||||
{% for value, label in scenario_statuses %}
|
||||
<option value="{{ value }}" {% if scenario and scenario.status.value == value %}selected{% endif %}>{{ label }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="currency">Currency</label>
|
||||
<input id="currency" name="currency" type="text" maxlength="3" value="{{ scenario.currency if scenario else '' }}" />
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="primary_resource">Primary Resource</label>
|
||||
<select id="primary_resource" name="primary_resource">
|
||||
<option value="">—</option>
|
||||
{% for value, label in resource_types %}
|
||||
<option value="{{ value }}" {% if scenario and scenario.primary_resource and scenario.primary_resource.value == value %}selected{% endif %}>{{ label }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-grid">
|
||||
<div class="form-group">
|
||||
<label for="start_date">Start Date</label>
|
||||
<input id="start_date" name="start_date" type="date" value="{{ scenario.start_date if scenario else '' }}" />
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="end_date">End Date</label>
|
||||
<input id="end_date" name="end_date" type="date" value="{{ scenario.end_date if scenario else '' }}" />
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="discount_rate">Discount Rate (%)</label>
|
||||
<input id="discount_rate" name="discount_rate" type="number" step="0.01" value="{{ scenario.discount_rate if scenario else '' }}" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="description">Description</label>
|
||||
<textarea id="description" name="description" rows="5">{{ scenario.description if scenario else '' }}</textarea>
|
||||
</div>
|
||||
|
||||
<div class="form-actions">
|
||||
<a class="btn" href="{{ cancel_url }}">Cancel</a>
|
||||
<button class="btn primary" type="submit">Save Scenario</button>
|
||||
</div>
|
||||
</form>
|
||||
{% endblock %}
|
||||
99
tests/conftest.py
Normal file
99
tests/conftest.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from config.database import Base
|
||||
from dependencies import get_auth_session, get_unit_of_work
|
||||
from models import User
|
||||
from routes.auth import router as auth_router
|
||||
from routes.dashboard import router as dashboard_router
|
||||
from routes.projects import router as projects_router
|
||||
from routes.scenarios import router as scenarios_router
|
||||
from services.unit_of_work import UnitOfWork
|
||||
from services.session import AuthSession, SessionTokens
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine() -> Iterator[Engine]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
future=True,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
|
||||
testing_session = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
yield testing_session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(session_factory: sessionmaker) -> FastAPI:
|
||||
application = FastAPI()
|
||||
application.include_router(auth_router)
|
||||
application.include_router(dashboard_router)
|
||||
application.include_router(projects_router)
|
||||
application.include_router(scenarios_router)
|
||||
|
||||
def _override_uow() -> Iterator[UnitOfWork]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
yield uow
|
||||
|
||||
application.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.users is not None
|
||||
uow.ensure_default_roles()
|
||||
user = User(
|
||||
email="test-superuser@example.com",
|
||||
username="test-superuser",
|
||||
password_hash=User.hash_password("test-password"),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
uow.users.create(user)
|
||||
user = uow.users.get(user.id, with_roles=True)
|
||||
|
||||
def _override_auth_session(request: Request) -> AuthSession:
|
||||
session = AuthSession(tokens=SessionTokens(
|
||||
access_token="test", refresh_token="test"))
|
||||
session.user = user
|
||||
request.state.auth_session = session
|
||||
return session
|
||||
|
||||
application.dependency_overrides[get_auth_session] = _override_auth_session
|
||||
return application
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app: FastAPI) -> Iterator[TestClient]:
|
||||
test_client = TestClient(app)
|
||||
try:
|
||||
yield test_client
|
||||
finally:
|
||||
test_client.close()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def unit_of_work_factory(session_factory: sessionmaker) -> Callable[[], UnitOfWork]:
|
||||
def _factory() -> UnitOfWork:
|
||||
return UnitOfWork(session_factory=session_factory)
|
||||
|
||||
return _factory
|
||||
30
tests/integration/README.md
Normal file
30
tests/integration/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Lifecycle Integration Test Plan
|
||||
|
||||
## Coverage Goals
|
||||
|
||||
- Exercise end-to-end creation, update, and deletion flows for projects through both API endpoints and HTML form submissions to ensure consistency across interfaces.
|
||||
- Validate scenario lifecycle interactions (create, update, archive) including business rule enforcement and status transitions when performed via API and UI routes.
|
||||
- Confirm that redirect chains and rendered templates match expected destinations after each lifecycle operation.
|
||||
- Verify repository-backed statistics (counts, recency metadata) update appropriately after lifecycle actions to maintain dashboard accuracy.
|
||||
- Guard against regressions in unit-of-work transaction handling by asserting database state after success and failure paths within integration flows.
|
||||
|
||||
## Happy-Path Journeys
|
||||
|
||||
1. **Project Management Flow**
|
||||
|
||||
- Navigate to the project list UI and confirm empty-state messaging.
|
||||
- Submit the new project form with valid data and verify redirect to the list page with the project present.
|
||||
- Perform an API-based update to adjust project metadata and check the UI detail view reflects changes.
|
||||
- Delete the project through the API and ensure the list UI reverts to the empty state.
|
||||
|
||||
2. **Scenario Lifecycle Flow**
|
||||
|
||||
- From an existing project, create a new scenario via the API and verify the project detail page renders the scenario entry.
|
||||
- Update the scenario through the UI form, ensuring redirect to the scenario detail view with updated fields.
|
||||
- Trigger a validation rule (e.g., duplicate scenario name within a project) to confirm error messaging without data loss.
|
||||
- Archive the scenario using the API and confirm status badges and scenario counts update across dashboard and project detail views.
|
||||
|
||||
3. **Dashboard Consistency Flow**
|
||||
- Seed multiple projects and scenarios through combined API/UI interactions.
|
||||
- Visit the dashboard and ensure metric cards reflect the latest counts, active/draft status breakdowns, and recent activity timestamps after each mutation.
|
||||
- Run the lifecycle flows sequentially to confirm cumulative effects propagate to dashboard summaries and navigation badges.
|
||||
66
tests/integration/test_project_lifecycle.py
Normal file
66
tests/integration/test_project_lifecycle.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestProjectLifecycle:
|
||||
def test_project_create_update_delete_flow(self, client: TestClient) -> None:
|
||||
# Initial state: no projects listed on the UI page
|
||||
response = client.get("/projects/ui")
|
||||
assert response.status_code == 200
|
||||
assert "No projects yet" in response.text
|
||||
|
||||
# Create a project via the HTML form submission
|
||||
create_payload = {
|
||||
"name": "Lifecycle Mine",
|
||||
"location": "Nevada",
|
||||
"operation_type": "open_pit",
|
||||
"description": "Initial description",
|
||||
}
|
||||
response = client.post(
|
||||
"/projects/create",
|
||||
data=create_payload,
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert response.status_code == 303
|
||||
assert response.headers["Location"].endswith("/projects/ui")
|
||||
|
||||
# Project should now appear on the list page
|
||||
response = client.get("/projects/ui")
|
||||
assert response.status_code == 200
|
||||
assert "Lifecycle Mine" in response.text
|
||||
assert "Nevada" in response.text
|
||||
|
||||
# Fetch the project via API to obtain its identifier
|
||||
response = client.get("/projects")
|
||||
assert response.status_code == 200
|
||||
projects = response.json()
|
||||
assert len(projects) == 1
|
||||
project_id = projects[0]["id"]
|
||||
|
||||
# Update the project using the API endpoint
|
||||
update_payload = {
|
||||
"location": "Arizona",
|
||||
"description": "Updated description",
|
||||
}
|
||||
response = client.put(f"/projects/{project_id}", json=update_payload)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["location"] == "Arizona"
|
||||
assert body["description"] == "Updated description"
|
||||
|
||||
# Verify the UI detail page reflects the updates
|
||||
response = client.get(f"/projects/{project_id}/view")
|
||||
assert response.status_code == 200
|
||||
assert "Arizona" in response.text
|
||||
assert "Updated description" in response.text
|
||||
|
||||
# Delete the project using the API endpoint
|
||||
response = client.delete(f"/projects/{project_id}")
|
||||
assert response.status_code == 204
|
||||
|
||||
# Ensure the list view returns to the empty state
|
||||
response = client.get("/projects/ui")
|
||||
assert response.status_code == 200
|
||||
assert "No projects yet" in response.text
|
||||
assert "Lifecycle Mine" not in response.text
|
||||
106
tests/integration/test_scenario_lifecycle.py
Normal file
106
tests/integration/test_scenario_lifecycle.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestScenarioLifecycle:
|
||||
def test_scenario_lifecycle_flow(self, client: TestClient) -> None:
|
||||
# Create a project to attach scenarios to
|
||||
project_response = client.post(
|
||||
"/projects",
|
||||
json={
|
||||
"name": "Scenario Host Project",
|
||||
"location": "Ontario",
|
||||
"operation_type": "open_pit",
|
||||
"description": "Project for scenario lifecycle testing",
|
||||
},
|
||||
)
|
||||
assert project_response.status_code == 201
|
||||
project_id = project_response.json()["id"]
|
||||
|
||||
# Create a scenario via the API for the project
|
||||
scenario_response = client.post(
|
||||
f"/projects/{project_id}/scenarios",
|
||||
json={
|
||||
"name": "Lifecycle Scenario",
|
||||
"description": "Initial scenario description",
|
||||
"status": "draft",
|
||||
"currency": "usd",
|
||||
"primary_resource": "diesel",
|
||||
},
|
||||
)
|
||||
assert scenario_response.status_code == 201
|
||||
scenario_id = scenario_response.json()["id"]
|
||||
|
||||
# Project detail page should list the new scenario in draft state
|
||||
project_detail = client.get(f"/projects/{project_id}/view")
|
||||
assert project_detail.status_code == 200
|
||||
assert "Lifecycle Scenario" in project_detail.text
|
||||
assert "<td>Draft</td>" in project_detail.text
|
||||
|
||||
# Update the scenario through the HTML form
|
||||
form_response = client.post(
|
||||
f"/scenarios/{scenario_id}/edit",
|
||||
data={
|
||||
"name": "Lifecycle Scenario Revised",
|
||||
"description": "Revised scenario assumptions",
|
||||
"status_value": "active",
|
||||
"start_date": "2025-01-01",
|
||||
"end_date": "2025-12-31",
|
||||
"discount_rate": "5.5",
|
||||
"currency": "cad",
|
||||
"primary_resource": "electricity",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert form_response.status_code == 303
|
||||
assert form_response.headers["Location"].endswith(
|
||||
f"/scenarios/{scenario_id}/view")
|
||||
|
||||
# Scenario detail page should reflect the updated information
|
||||
scenario_detail = client.get(f"/scenarios/{scenario_id}/view")
|
||||
assert scenario_detail.status_code == 200
|
||||
assert "Lifecycle Scenario Revised" in scenario_detail.text
|
||||
assert "Status: Active" in scenario_detail.text
|
||||
assert "CAD" in scenario_detail.text
|
||||
assert "Electricity" in scenario_detail.text
|
||||
assert "Revised scenario assumptions" in scenario_detail.text
|
||||
|
||||
# Project detail page should show the scenario as active with updated currency/resource
|
||||
project_detail = client.get(f"/projects/{project_id}/view")
|
||||
assert "<td>Active</td>" in project_detail.text
|
||||
assert "<td>CAD</td>" in project_detail.text
|
||||
assert "<td>Electricity</td>" in project_detail.text
|
||||
|
||||
# Attempt to update the scenario with invalid currency to trigger validation error
|
||||
invalid_update = client.put(
|
||||
f"/scenarios/{scenario_id}",
|
||||
json={"currency": "ca"},
|
||||
)
|
||||
assert invalid_update.status_code == 422
|
||||
assert (
|
||||
invalid_update.json()["detail"][0]["msg"]
|
||||
== "Value error, Currency code must be a 3-letter ISO value"
|
||||
)
|
||||
|
||||
# Scenario detail should still show the previous (valid) currency
|
||||
scenario_detail = client.get(f"/scenarios/{scenario_id}/view")
|
||||
assert "CAD" in scenario_detail.text
|
||||
|
||||
# Archive the scenario through the API
|
||||
archive_response = client.put(
|
||||
f"/scenarios/{scenario_id}",
|
||||
json={"status": "archived"},
|
||||
)
|
||||
assert archive_response.status_code == 200
|
||||
assert archive_response.json()["status"] == "archived"
|
||||
|
||||
# Scenario detail reflects archived status
|
||||
scenario_detail = client.get(f"/scenarios/{scenario_id}/view")
|
||||
assert "Status: Archived" in scenario_detail.text
|
||||
|
||||
# Project detail metrics and table entries reflect the archived state
|
||||
project_detail = client.get(f"/projects/{project_id}/view")
|
||||
assert "<h2>Archived</h2>" in project_detail.text
|
||||
assert '<p class="metric-value">1</p>' in project_detail.text
|
||||
assert "<td>Archived</td>" in project_detail.text
|
||||
156
tests/scripts/test_initial_data_seed.py
Normal file
156
tests/scripts/test_initial_data_seed.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from scripts import initial_data
|
||||
from scripts.initial_data import AdminSeedResult, RoleSeedResult, SeedConfig
|
||||
from services.repositories import DEFAULT_ROLE_DEFINITIONS
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def in_memory_session_factory() -> Callable[[], Session]:
|
||||
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
|
||||
Base.metadata.create_all(engine)
|
||||
factory = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
|
||||
|
||||
def _session_factory() -> Session:
|
||||
return factory()
|
||||
|
||||
return _session_factory
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def uow(in_memory_session_factory: Callable[[], Session]) -> UnitOfWork:
|
||||
return UnitOfWork(session_factory=in_memory_session_factory)
|
||||
|
||||
|
||||
def test_ensure_default_roles_idempotent(uow: UnitOfWork) -> None:
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
result_first = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
assert result_first == RoleSeedResult(created=4, updated=0, total=4)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
result_second = initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
assert result_second == RoleSeedResult(created=0, updated=0, total=4)
|
||||
|
||||
|
||||
def test_ensure_admin_user_creates_and_assigns_roles(uow: UnitOfWork) -> None:
|
||||
config = SeedConfig(
|
||||
admin_email="admin@example.com",
|
||||
admin_username="admin",
|
||||
admin_password="secret",
|
||||
admin_roles=("admin", "viewer"),
|
||||
force_reset=False,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
result = initial_data.ensure_admin_user(working.users, working.roles, config)
|
||||
assert result == AdminSeedResult(
|
||||
created_user=True,
|
||||
updated_user=False,
|
||||
password_rotated=False,
|
||||
roles_granted=2,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
result_again = initial_data.ensure_admin_user(working.users, working.roles, config)
|
||||
assert result_again == AdminSeedResult(
|
||||
created_user=False,
|
||||
updated_user=False,
|
||||
password_rotated=False,
|
||||
roles_granted=0,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.users is not None
|
||||
user = working.users.get_by_email("admin@example.com", with_roles=True)
|
||||
assert user is not None
|
||||
assert user.is_active is True
|
||||
assert user.is_superuser is True
|
||||
role_names = {role.name for role in user.roles}
|
||||
assert role_names == {"admin", "viewer"}
|
||||
|
||||
|
||||
def test_ensure_admin_user_force_reset_rotates_password(uow: UnitOfWork) -> None:
|
||||
base_config = SeedConfig(
|
||||
admin_email="admin@example.com",
|
||||
admin_username="admin",
|
||||
admin_password="first",
|
||||
admin_roles=("admin",),
|
||||
force_reset=False,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
initial_data.ensure_default_roles(working.roles, DEFAULT_ROLE_DEFINITIONS)
|
||||
initial_data.ensure_admin_user(working.users, working.roles, base_config)
|
||||
|
||||
rotate_config = SeedConfig(
|
||||
admin_email="admin@example.com",
|
||||
admin_username="admin",
|
||||
admin_password="second",
|
||||
admin_roles=("admin",),
|
||||
force_reset=True,
|
||||
)
|
||||
|
||||
with uow as working:
|
||||
assert working.users is not None
|
||||
user_before = working.users.get_by_email("admin@example.com")
|
||||
assert user_before is not None
|
||||
old_hash = user_before.password_hash
|
||||
|
||||
with uow as working:
|
||||
assert working.roles is not None
|
||||
assert working.users is not None
|
||||
result = initial_data.ensure_admin_user(working.users, working.roles, rotate_config)
|
||||
assert result.password_rotated is True
|
||||
|
||||
with uow as working:
|
||||
assert working.users is not None
|
||||
user_after = working.users.get_by_email("admin@example.com")
|
||||
assert user_after is not None
|
||||
assert user_after.password_hash != old_hash
|
||||
|
||||
|
||||
def test_seed_initial_data_logs_results(
|
||||
caplog,
|
||||
in_memory_session_factory: Callable[[], Session],
|
||||
) -> None:
|
||||
caplog.set_level(logging.INFO)
|
||||
config = SeedConfig(
|
||||
admin_email="seed@example.com",
|
||||
admin_username="seed",
|
||||
admin_password="seed-pass",
|
||||
admin_roles=("admin",),
|
||||
force_reset=False,
|
||||
)
|
||||
|
||||
initial_data.seed_initial_data(
|
||||
config,
|
||||
unit_of_work_factory=lambda: UnitOfWork(session_factory=in_memory_session_factory),
|
||||
)
|
||||
|
||||
assert "Starting initial data seeding" in caplog.text
|
||||
assert "Initial data seeding completed successfully" in caplog.text
|
||||
|
||||
with UnitOfWork(session_factory=in_memory_session_factory) as check_uow:
|
||||
assert check_uow.users is not None
|
||||
assert check_uow.roles is not None
|
||||
user = check_uow.users.get_by_email("seed@example.com")
|
||||
assert user is not None
|
||||
assert check_uow.roles.get_by_name("admin") is not None
|
||||
135
tests/test_auth_repositories.py
Normal file
135
tests/test_auth_repositories.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from models import Role, User
|
||||
from services.repositories import (
|
||||
RoleRepository,
|
||||
UserRepository,
|
||||
ensure_admin_user,
|
||||
ensure_default_roles,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine() -> Iterator:
|
||||
engine = create_engine("sqlite:///:memory:", future=True)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session(engine) -> Iterator[Session]:
|
||||
TestingSession = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
db = TestingSession()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_role_repository_create_and_lookup(session: Session) -> None:
|
||||
repo = RoleRepository(session)
|
||||
role = Role(name="custom", display_name="Custom",
|
||||
description="Custom role")
|
||||
repo.create(role)
|
||||
|
||||
retrieved = repo.get(role.id)
|
||||
assert retrieved.name == "custom"
|
||||
assert repo.get_by_name("custom") is retrieved
|
||||
assert repo.list()[0].name == "custom"
|
||||
|
||||
|
||||
def test_user_repository_assign_and_revoke_role(session: Session) -> None:
|
||||
role_repo = RoleRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
analyst = role_repo.create(
|
||||
Role(name="analyst", display_name="Analyst", description="Analyzes data")
|
||||
)
|
||||
user = User(
|
||||
email="user@example.com",
|
||||
username="user",
|
||||
password_hash=User.hash_password("secret"),
|
||||
)
|
||||
user_repo.create(user)
|
||||
|
||||
assignment = user_repo.assign_role(
|
||||
user_id=user.id, role_id=analyst.id, granted_by=None)
|
||||
assert assignment.role_id == analyst.id
|
||||
|
||||
refreshed = user_repo.get(user.id, with_roles=True)
|
||||
assert refreshed.roles[0].name == "analyst"
|
||||
|
||||
user_repo.revoke_role(user_id=user.id, role_id=analyst.id)
|
||||
refreshed = user_repo.get(user.id, with_roles=True)
|
||||
assert refreshed.roles == []
|
||||
|
||||
|
||||
def test_default_role_and_admin_helpers(session: Session) -> None:
|
||||
role_repo = RoleRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
roles = ensure_default_roles(role_repo)
|
||||
assert {role.name for role in roles} == {
|
||||
"admin", "project_manager", "analyst", "viewer"}
|
||||
|
||||
ensure_admin_user(
|
||||
user_repo,
|
||||
role_repo,
|
||||
email="admin@example.com",
|
||||
username="admin",
|
||||
password="SecurePass1!",
|
||||
)
|
||||
|
||||
admin = user_repo.get_by_email("admin@example.com", with_roles=True)
|
||||
assert admin is not None
|
||||
assert admin.is_superuser
|
||||
assert {role.name for role in admin.roles} >= {"admin"}
|
||||
|
||||
# Idempotent behaviour on subsequent calls
|
||||
ensure_admin_user(
|
||||
user_repo,
|
||||
role_repo,
|
||||
email="admin@example.com",
|
||||
username="admin",
|
||||
password="SecurePass1!",
|
||||
)
|
||||
admin_again = user_repo.get_by_email("admin@example.com", with_roles=True)
|
||||
assert admin_again is not None
|
||||
assert len(admin_again.roles) == len(
|
||||
{role.name for role in admin_again.roles})
|
||||
|
||||
|
||||
def test_unit_of_work_exposes_auth_repositories(engine) -> None:
|
||||
TestingSession = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
|
||||
with UnitOfWork(session_factory=TestingSession) as uow:
|
||||
assert uow.users is not None
|
||||
assert uow.roles is not None
|
||||
|
||||
roles = uow.ensure_default_roles()
|
||||
assert any(role.name == "admin" for role in roles)
|
||||
|
||||
uow.ensure_admin_user(
|
||||
email="uow-admin@example.com",
|
||||
username="uow-admin",
|
||||
password="AnotherSecret1!",
|
||||
)
|
||||
|
||||
admin = uow.users.get_by_email(
|
||||
"uow-admin@example.com", with_roles=True)
|
||||
assert admin is not None
|
||||
assert admin.is_superuser
|
||||
assert any(role.name == "admin" for role in admin.roles)
|
||||
286
tests/test_auth_routes.py
Normal file
286
tests/test_auth_routes.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models import Role, User, UserRole
|
||||
from dependencies import get_auth_session, require_current_user
|
||||
from services.security import hash_password
|
||||
from services.session import AuthSession, SessionTokens
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session(session_factory: sessionmaker) -> Iterator[Session]:
|
||||
session = session_factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def _get_user(session: Session, *, email: str | None = None, username: str | None = None) -> User | None:
|
||||
stmt = select(User)
|
||||
if email is not None:
|
||||
stmt = stmt.where(User.email == email)
|
||||
if username is not None:
|
||||
stmt = stmt.where(User.username == username)
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
class TestRegistrationFlow:
|
||||
def test_register_creates_user_and_assigns_role(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
response = client.post(
|
||||
"/register",
|
||||
data={
|
||||
"username": "newuser",
|
||||
"email": "newuser@example.com",
|
||||
"password": "ComplexP@ss1",
|
||||
"confirm_password": "ComplexP@ss1",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 303
|
||||
location = response.headers.get("location")
|
||||
assert location
|
||||
parsed = urlparse(location)
|
||||
assert parsed.path == "/login"
|
||||
assert parse_qs(parsed.query).get("registered") == ["1"]
|
||||
|
||||
created = _get_user(db_session, email="newuser@example.com")
|
||||
assert created is not None
|
||||
assert created.is_active
|
||||
|
||||
role_stmt = select(Role).where(Role.name == "viewer")
|
||||
viewer_role = db_session.execute(role_stmt).scalar_one_or_none()
|
||||
assert viewer_role is not None
|
||||
|
||||
assignments = db_session.execute(
|
||||
select(UserRole).where(
|
||||
UserRole.user_id == created.id,
|
||||
UserRole.role_id == viewer_role.id,
|
||||
)
|
||||
).scalars().all()
|
||||
assert len(assignments) == 1
|
||||
|
||||
def test_register_duplicate_email_shows_error(
|
||||
self,
|
||||
client: TestClient,
|
||||
) -> None:
|
||||
first = client.post(
|
||||
"/register",
|
||||
data={
|
||||
"username": "existing",
|
||||
"email": "existing@example.com",
|
||||
"password": "ComplexP@ss1",
|
||||
"confirm_password": "ComplexP@ss1",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert first.status_code == 303
|
||||
|
||||
second = client.post(
|
||||
"/register",
|
||||
data={
|
||||
"username": "existing",
|
||||
"email": "existing@example.com",
|
||||
"password": "ComplexP@ss1",
|
||||
"confirm_password": "ComplexP@ss1",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert second.status_code == 400
|
||||
assert "Email is already registered" in second.text
|
||||
|
||||
|
||||
class TestLoginFlow:
|
||||
def test_login_sets_tokens_and_updates_last_login(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
password = "MySecur3Pass!"
|
||||
user = User(
|
||||
email="login@example.com",
|
||||
username="loginuser",
|
||||
password_hash=hash_password(password),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
response = client.post(
|
||||
"/login",
|
||||
data={"username": "loginuser", "password": password},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 303
|
||||
assert response.headers.get("location") == "http://testserver/"
|
||||
set_cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "calminer_access_token=" in set_cookie_header
|
||||
assert "calminer_refresh_token=" in set_cookie_header
|
||||
|
||||
updated = _get_user(db_session, username="loginuser")
|
||||
assert updated is not None and updated.last_login_at is not None
|
||||
|
||||
def test_login_invalid_credentials_returns_error(self, client: TestClient) -> None:
|
||||
response = client.post(
|
||||
"/login",
|
||||
data={"username": "unknown", "password": "bad"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid username or password" in response.text
|
||||
|
||||
|
||||
class TestPasswordResetFlow:
|
||||
def test_password_reset_round_trip(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
user = User(
|
||||
email="reset@example.com",
|
||||
username="resetuser",
|
||||
password_hash=hash_password("OldP@ssword1"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
request_response = client.post(
|
||||
"/forgot-password",
|
||||
data={"email": "reset@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert request_response.status_code == 303
|
||||
reset_location = request_response.headers.get("location")
|
||||
assert reset_location is not None
|
||||
parsed = urlparse(reset_location)
|
||||
assert parsed.path == "/reset-password"
|
||||
token = parse_qs(parsed.query).get("token", [None])[0]
|
||||
assert token
|
||||
|
||||
form_response = client.get(reset_location)
|
||||
assert form_response.status_code == 200
|
||||
|
||||
submit_response = client.post(
|
||||
"/reset-password",
|
||||
data={
|
||||
"token": token,
|
||||
"password": "N3wP@ssword!",
|
||||
"confirm_password": "N3wP@ssword!",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert submit_response.status_code == 303
|
||||
assert "reset=1" in (submit_response.headers.get("location") or "")
|
||||
|
||||
db_session.refresh(user)
|
||||
assert user.verify_password("N3wP@ssword!")
|
||||
|
||||
def test_password_reset_with_unknown_email_shows_generic_message(
|
||||
self,
|
||||
client: TestClient,
|
||||
) -> None:
|
||||
response = client.post(
|
||||
"/forgot-password",
|
||||
data={"email": "doesnotexist@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "If an account exists" in response.text
|
||||
|
||||
def test_password_reset_mismatched_passwords_return_error(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
user = User(
|
||||
email="mismatch@example.com",
|
||||
username="mismatch",
|
||||
password_hash=hash_password("OldP@ssword1"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
request_response = client.post(
|
||||
"/forgot-password",
|
||||
data={"email": "mismatch@example.com"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
token = parse_qs(urlparse(request_response.headers["location"]).query)[
|
||||
"token"][0]
|
||||
|
||||
submit_response = client.post(
|
||||
"/reset-password",
|
||||
data={
|
||||
"token": token,
|
||||
"password": "NewPass123!",
|
||||
"confirm_password": "Different123!",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert submit_response.status_code == 400
|
||||
assert "Passwords do not match" in submit_response.text
|
||||
|
||||
|
||||
class TestLogoutFlow:
|
||||
def test_logout_clears_cookies_and_redirects(
|
||||
self,
|
||||
client: TestClient,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
user = User(
|
||||
email="logout@example.com",
|
||||
username="logoutuser",
|
||||
password_hash=hash_password("SecureP@ss1"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
session = AuthSession(
|
||||
tokens=SessionTokens(
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
access_token_source="cookie",
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
|
||||
app = cast(FastAPI, client.app)
|
||||
app.dependency_overrides[require_current_user] = lambda: user
|
||||
app.dependency_overrides[get_auth_session] = lambda: session
|
||||
|
||||
try:
|
||||
response = client.get("/logout", follow_redirects=False)
|
||||
finally:
|
||||
app.dependency_overrides.pop(require_current_user, None)
|
||||
app.dependency_overrides.pop(get_auth_session, None)
|
||||
|
||||
assert response.status_code == 303
|
||||
location = response.headers.get("location")
|
||||
assert location and location.startswith("http://testserver/login")
|
||||
set_cookie_header = response.headers.get("set-cookie") or ""
|
||||
assert "calminer_access_token=" in set_cookie_header
|
||||
assert "Max-Age=0" in set_cookie_header or "expires=" in set_cookie_header.lower()
|
||||
111
tests/test_auth_session_middleware.py
Normal file
111
tests/test_auth_session_middleware.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from config.settings import get_settings
|
||||
from dependencies import get_unit_of_work, require_current_user
|
||||
from middleware.auth_session import AuthSessionMiddleware
|
||||
from models import User
|
||||
from services.security import create_access_token, create_refresh_token, hash_password
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_app(session_factory: sessionmaker) -> Iterator[TestClient]:
|
||||
app = FastAPI()
|
||||
|
||||
def _override_uow() -> Iterator[UnitOfWork]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
yield uow
|
||||
|
||||
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
|
||||
@app.get("/me")
|
||||
def read_me(user: User = Depends(require_current_user)) -> JSONResponse:
|
||||
return JSONResponse({"id": user.id, "username": user.username})
|
||||
|
||||
app.add_middleware(
|
||||
AuthSessionMiddleware,
|
||||
unit_of_work_factory=lambda: UnitOfWork(
|
||||
session_factory=session_factory),
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
def _create_user(session_factory: sessionmaker) -> User:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.users is not None
|
||||
user = User(
|
||||
email="jane@example.com",
|
||||
username="jane",
|
||||
password_hash=hash_password("secret"),
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
uow.users.create(user)
|
||||
return user
|
||||
|
||||
|
||||
def _issue_tokens(user: User) -> tuple[str, str]:
|
||||
settings = get_settings().jwt_settings()
|
||||
access = create_access_token(str(user.id), settings, scopes=["auth"])
|
||||
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||
return access, refresh
|
||||
|
||||
|
||||
def test_middleware_populates_current_user(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||
user = _create_user(session_factory)
|
||||
access, refresh = _issue_tokens(user)
|
||||
|
||||
auth_app.cookies.set("calminer_access_token", access)
|
||||
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["id"] == user.id
|
||||
assert payload["username"] == user.username
|
||||
|
||||
|
||||
def test_middleware_refreshes_expired_access_token(auth_app: TestClient, session_factory: sessionmaker) -> None:
|
||||
user = _create_user(session_factory)
|
||||
settings = get_settings().jwt_settings()
|
||||
expired = create_access_token(
|
||||
str(user.id),
|
||||
settings,
|
||||
scopes=["auth"],
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
refresh = create_refresh_token(str(user.id), settings, scopes=["auth"])
|
||||
|
||||
auth_app.cookies.set("calminer_access_token", expired)
|
||||
auth_app.cookies.set("calminer_refresh_token", refresh)
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 200
|
||||
new_access = response.cookies.get("calminer_access_token")
|
||||
new_refresh = response.cookies.get("calminer_refresh_token")
|
||||
assert new_access is not None and new_access != expired
|
||||
assert new_refresh is not None
|
||||
|
||||
|
||||
def test_middleware_blocks_invalid_tokens(auth_app: TestClient) -> None:
|
||||
auth_app.cookies.set("calminer_access_token", "invalid-token")
|
||||
auth_app.cookies.set("calminer_refresh_token", "invalid-token")
|
||||
|
||||
response = auth_app.get("/me")
|
||||
assert response.status_code == 401
|
||||
set_cookies = response.headers.get_list("set-cookie")
|
||||
assert any("calminer_access_token=" in value for value in set_cookies)
|
||||
165
tests/test_authorization_helpers.py
Normal file
165
tests/test_authorization_helpers.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from models import MiningOperationType, Project, Scenario, ScenarioStatus, User
|
||||
from services.authorization import (
|
||||
ensure_project_access,
|
||||
ensure_scenario_access,
|
||||
ensure_scenario_in_project,
|
||||
)
|
||||
from services.exceptions import AuthorizationError, EntityNotFoundError
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
def _create_user_with_roles(
|
||||
uow: UnitOfWork,
|
||||
*,
|
||||
email: str,
|
||||
username: str,
|
||||
roles: set[str],
|
||||
) -> User:
|
||||
assert uow.users is not None
|
||||
assert uow.roles is not None
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash=User.hash_password("secret"),
|
||||
is_active=True,
|
||||
)
|
||||
uow.users.create(user)
|
||||
uow.ensure_default_roles()
|
||||
|
||||
for role_name in roles:
|
||||
role = uow.roles.get_by_name(role_name)
|
||||
assert role is not None, f"Role {role_name} should exist"
|
||||
uow.users.assign_role(user_id=user.id, role_id=role.id)
|
||||
|
||||
return uow.users.get(user.id, with_roles=True)
|
||||
|
||||
|
||||
def _create_project(uow: UnitOfWork, name: str) -> Project:
|
||||
assert uow.projects is not None
|
||||
project = Project(
|
||||
name=name,
|
||||
location=None,
|
||||
operation_type=MiningOperationType.OTHER,
|
||||
description=None,
|
||||
)
|
||||
uow.projects.create(project)
|
||||
return project
|
||||
|
||||
|
||||
def _create_scenario(uow: UnitOfWork, project: Project, name: str) -> Scenario:
|
||||
assert uow.scenarios is not None
|
||||
scenario = Scenario(
|
||||
project_id=project.id,
|
||||
name=name,
|
||||
description=None,
|
||||
status=ScenarioStatus.DRAFT,
|
||||
)
|
||||
uow.scenarios.create(scenario)
|
||||
return scenario
|
||||
|
||||
|
||||
def test_ensure_project_access_allows_view_roles(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow, "Project A")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="viewer@example.com",
|
||||
username="viewer",
|
||||
roles={"viewer"},
|
||||
)
|
||||
|
||||
resolved = ensure_project_access(
|
||||
uow,
|
||||
project_id=project.id,
|
||||
user=user,
|
||||
)
|
||||
assert resolved.id == project.id
|
||||
|
||||
with pytest.raises(AuthorizationError):
|
||||
ensure_project_access(
|
||||
uow,
|
||||
project_id=project.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_project_access_allows_manage_roles(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow, "Project B")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="manager@example.com",
|
||||
username="manager",
|
||||
roles={"project_manager"},
|
||||
)
|
||||
|
||||
resolved = ensure_project_access(
|
||||
uow,
|
||||
project_id=project.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
assert resolved.id == project.id
|
||||
|
||||
|
||||
def test_ensure_scenario_access(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow, "Project C")
|
||||
scenario = _create_scenario(uow, project, "Scenario C1")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="analyst@example.com",
|
||||
username="analyst",
|
||||
roles={"analyst"},
|
||||
)
|
||||
|
||||
resolved = ensure_scenario_access(
|
||||
uow,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
)
|
||||
assert resolved.id == scenario.id
|
||||
|
||||
with pytest.raises(AuthorizationError):
|
||||
ensure_scenario_access(
|
||||
uow,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_scenario_in_project_validates_membership(unit_of_work_factory) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project_one = _create_project(uow, "Project D")
|
||||
project_two = _create_project(uow, "Project E")
|
||||
scenario = _create_scenario(uow, project_one, "Scenario D1")
|
||||
user = _create_user_with_roles(
|
||||
uow,
|
||||
email="manager2@example.com",
|
||||
username="manager2",
|
||||
roles={"project_manager"},
|
||||
)
|
||||
|
||||
resolved = ensure_scenario_in_project(
|
||||
uow,
|
||||
project_id=project_one.id,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
require_manage=True,
|
||||
)
|
||||
assert resolved.id == scenario.id
|
||||
|
||||
with pytest.raises(EntityNotFoundError):
|
||||
ensure_scenario_in_project(
|
||||
uow,
|
||||
project_id=project_two.id,
|
||||
scenario_id=scenario.id,
|
||||
user=user,
|
||||
)
|
||||
208
tests/test_authorization_integration.py
Normal file
208
tests/test_authorization_integration.py
Normal file
@@ -0,0 +1,208 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import Request, status
|
||||
|
||||
from dependencies import get_auth_session
|
||||
from models import MiningOperationType, Project, User
|
||||
from services.session import AuthSession, SessionTokens
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_session_context(client):
|
||||
"""Allow tests to swap the current auth session for the test app."""
|
||||
|
||||
original_override = client.app.dependency_overrides[get_auth_session]
|
||||
|
||||
@contextmanager
|
||||
def _use(user: User | None):
|
||||
def _override(request: Request) -> AuthSession:
|
||||
if user is None:
|
||||
session = AuthSession.anonymous()
|
||||
else:
|
||||
session = AuthSession(tokens=SessionTokens(
|
||||
access_token="token", refresh_token="refresh"))
|
||||
session.user = user
|
||||
request.state.auth_session = session
|
||||
return session
|
||||
|
||||
client.app.dependency_overrides[get_auth_session] = _override
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
client.app.dependency_overrides[get_auth_session] = original_override
|
||||
|
||||
return _use
|
||||
|
||||
|
||||
def _unique(name: str) -> str:
|
||||
return f"{name}-{uuid4().hex}"
|
||||
|
||||
|
||||
def _create_user(uow, *, roles: tuple[str, ...] = (), is_superuser: bool = False) -> User:
|
||||
assert uow.users and uow.roles
|
||||
uow.ensure_default_roles()
|
||||
user = User(
|
||||
email=f"{_unique('user')}@example.com",
|
||||
username=_unique('user'),
|
||||
password_hash=User.hash_password("password"),
|
||||
is_active=True,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
uow.users.create(user)
|
||||
uow.flush()
|
||||
|
||||
for role_name in roles:
|
||||
role = uow.roles.get_by_name(role_name)
|
||||
if role is None: # pragma: no cover - defensive guard
|
||||
raise AssertionError(f"Role {role_name} not found")
|
||||
uow.users.assign_role(user_id=user.id, role_id=role.id)
|
||||
return uow.users.get(user.id, with_roles=True)
|
||||
|
||||
|
||||
def _create_project(uow) -> Project:
|
||||
assert uow.projects
|
||||
project = Project(
|
||||
name=_unique('project'),
|
||||
location="Integration Site",
|
||||
operation_type=MiningOperationType.OPEN_PIT,
|
||||
)
|
||||
uow.projects.create(project)
|
||||
uow.flush()
|
||||
return project
|
||||
|
||||
|
||||
class TestAuthenticationRequirements:
|
||||
def test_api_projects_list_requires_login(self, client, auth_session_context):
|
||||
with auth_session_context(None):
|
||||
response = client.get("/projects")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert response.json()["detail"] == "Authentication required."
|
||||
|
||||
def test_ui_project_list_requires_login(self, client, auth_session_context):
|
||||
with auth_session_context(None):
|
||||
response = client.get("/projects/ui")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestRoleRestrictions:
|
||||
def test_api_projects_create_forbidden_for_viewer(
|
||||
self,
|
||||
client,
|
||||
auth_session_context,
|
||||
unit_of_work_factory,
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
viewer = _create_user(uow, roles=("viewer",))
|
||||
|
||||
payload = {
|
||||
"name": _unique("project"),
|
||||
"location": "Restricted",
|
||||
"operation_type": MiningOperationType.OPEN_PIT.value,
|
||||
"description": "Test restriction",
|
||||
}
|
||||
with auth_session_context(viewer):
|
||||
response = client.post("/projects", json=payload)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert response.json()[
|
||||
"detail"] == "Insufficient permissions for this action."
|
||||
|
||||
def test_api_projects_create_allows_project_manager(
|
||||
self,
|
||||
client,
|
||||
auth_session_context,
|
||||
unit_of_work_factory,
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
manager = _create_user(uow, roles=("project_manager",))
|
||||
|
||||
payload = {
|
||||
"name": _unique("managed-project"),
|
||||
"location": "Permitted",
|
||||
"operation_type": MiningOperationType.OPEN_PIT.value,
|
||||
"description": "Authorized creation",
|
||||
}
|
||||
with auth_session_context(manager):
|
||||
response = client.post("/projects", json=payload)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["name"] == payload["name"]
|
||||
|
||||
def test_api_projects_update_forbidden_for_viewer(
|
||||
self,
|
||||
client,
|
||||
auth_session_context,
|
||||
unit_of_work_factory,
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow)
|
||||
viewer = _create_user(uow, roles=("viewer",))
|
||||
|
||||
with auth_session_context(viewer):
|
||||
response = client.put(
|
||||
f"/projects/{project.id}",
|
||||
json={"description": "Updated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert response.json()[
|
||||
"detail"] == "Insufficient role permissions for this action."
|
||||
|
||||
def test_api_projects_update_allows_manager(
|
||||
self,
|
||||
client,
|
||||
auth_session_context,
|
||||
unit_of_work_factory,
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow)
|
||||
manager = _create_user(uow, roles=("project_manager",))
|
||||
|
||||
updated_description = "Manager updated description"
|
||||
with auth_session_context(manager):
|
||||
response = client.put(
|
||||
f"/projects/{project.id}",
|
||||
json={"description": updated_description},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["description"] == updated_description
|
||||
|
||||
def test_ui_project_edit_forbidden_for_viewer(
|
||||
self,
|
||||
client,
|
||||
auth_session_context,
|
||||
unit_of_work_factory,
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow)
|
||||
viewer = _create_user(uow, roles=("viewer",))
|
||||
|
||||
with auth_session_context(viewer):
|
||||
response = client.get(f"/projects/{project.id}/edit")
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert response.json()[
|
||||
"detail"] == "Insufficient role permissions for this action."
|
||||
|
||||
def test_ui_project_edit_accessible_to_manager(
|
||||
self,
|
||||
client,
|
||||
auth_session_context,
|
||||
unit_of_work_factory,
|
||||
) -> None:
|
||||
with unit_of_work_factory() as uow:
|
||||
project = _create_project(uow)
|
||||
manager = _create_user(uow, roles=("project_manager",))
|
||||
|
||||
with auth_session_context(manager):
|
||||
response = client.get(f"/projects/{project.id}/edit")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.template.name == "projects/form.html"
|
||||
116
tests/test_bootstrap.py
Normal file
116
tests/test_bootstrap.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from config.settings import AdminBootstrapSettings
|
||||
from services.bootstrap import AdminBootstrapResult, RoleBootstrapResult, bootstrap_admin
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_factory() -> Callable[[], Session]:
|
||||
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
|
||||
Base.metadata.create_all(engine)
|
||||
factory = sessionmaker(bind=engine, expire_on_commit=False, future=True)
|
||||
|
||||
def _factory() -> Session:
|
||||
return factory()
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def unit_of_work_factory(session_factory: Callable[[], Session]) -> Callable[[], UnitOfWork]:
|
||||
def _factory() -> UnitOfWork:
|
||||
return UnitOfWork(session_factory=session_factory)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
def _settings(**overrides: Any) -> AdminBootstrapSettings:
|
||||
defaults: dict[str, Any] = {
|
||||
"email": "admin@example.com",
|
||||
"username": "admin",
|
||||
"password": "changeme",
|
||||
"roles": ("admin", "viewer"),
|
||||
"force_reset": False,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return AdminBootstrapSettings(
|
||||
email=str(defaults["email"]),
|
||||
username=str(defaults["username"]),
|
||||
password=str(defaults["password"]),
|
||||
roles=tuple(defaults["roles"]),
|
||||
force_reset=bool(defaults["force_reset"]),
|
||||
)
|
||||
|
||||
|
||||
def test_bootstrap_creates_admin_and_roles(unit_of_work_factory: Callable[[], UnitOfWork]) -> None:
|
||||
settings = _settings()
|
||||
|
||||
role_result, admin_result = bootstrap_admin(
|
||||
settings=settings,
|
||||
unit_of_work_factory=unit_of_work_factory,
|
||||
)
|
||||
|
||||
assert role_result == RoleBootstrapResult(created=4, ensured=4)
|
||||
assert admin_result == AdminBootstrapResult(
|
||||
created_user=True,
|
||||
updated_user=False,
|
||||
password_rotated=False,
|
||||
roles_granted=2,
|
||||
)
|
||||
|
||||
with unit_of_work_factory() as uow:
|
||||
users_repo = uow.users
|
||||
assert users_repo is not None
|
||||
user = users_repo.get_by_email(settings.email, with_roles=True)
|
||||
assert user is not None
|
||||
assert user.is_superuser is True
|
||||
assert {role.name for role in user.roles} == {"admin", "viewer"}
|
||||
|
||||
|
||||
def test_bootstrap_is_idempotent(unit_of_work_factory: Callable[[], UnitOfWork]) -> None:
|
||||
settings = _settings()
|
||||
|
||||
bootstrap_admin(settings=settings,
|
||||
unit_of_work_factory=unit_of_work_factory)
|
||||
role_result, admin_result = bootstrap_admin(
|
||||
settings=settings,
|
||||
unit_of_work_factory=unit_of_work_factory,
|
||||
)
|
||||
|
||||
assert role_result.created == 0
|
||||
assert role_result.ensured == 4
|
||||
assert admin_result.created_user is False
|
||||
assert admin_result.updated_user is False
|
||||
assert admin_result.roles_granted == 0
|
||||
|
||||
|
||||
def test_bootstrap_respects_force_reset(unit_of_work_factory: Callable[[], UnitOfWork]) -> None:
|
||||
base_settings = _settings(password="initial")
|
||||
bootstrap_admin(settings=base_settings,
|
||||
unit_of_work_factory=unit_of_work_factory)
|
||||
|
||||
rotated_settings = _settings(password="rotated", force_reset=True)
|
||||
_, admin_result = bootstrap_admin(
|
||||
settings=rotated_settings,
|
||||
unit_of_work_factory=unit_of_work_factory,
|
||||
)
|
||||
|
||||
assert admin_result.password_rotated is True
|
||||
assert admin_result.updated_user is True
|
||||
|
||||
with unit_of_work_factory() as uow:
|
||||
users_repo = uow.users
|
||||
assert users_repo is not None
|
||||
user = users_repo.get_by_email(rotated_settings.email)
|
||||
assert user is not None
|
||||
assert user.verify_password("rotated")
|
||||
27
tests/test_dashboard_route.py
Normal file
27
tests/test_dashboard_route.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestDashboardRoute:
|
||||
def test_renders_empty_state(self, client: TestClient) -> None:
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
html = response.text
|
||||
|
||||
assert "No recent projects" in html
|
||||
assert "No simulation runs yet" in html
|
||||
assert "All scenarios look good" in html
|
||||
assert "—" in html # Last data import placeholder
|
||||
|
||||
|
||||
class TestProjectUIRoutes:
|
||||
def test_projects_ui_page_resolves(self, client: TestClient) -> None:
|
||||
response = client.get("/projects/ui")
|
||||
assert response.status_code == 200
|
||||
assert "Projects" in response.text
|
||||
|
||||
def test_projects_create_form_resolves(self, client: TestClient) -> None:
|
||||
response = client.get("/projects/create")
|
||||
assert response.status_code == 200
|
||||
assert "Create Project" in response.text
|
||||
267
tests/test_dependencies_guards.py
Normal file
267
tests/test_dependencies_guards.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from dependencies import (
|
||||
require_any_role,
|
||||
require_authenticated_user,
|
||||
require_current_user,
|
||||
require_project_resource,
|
||||
require_project_scenario_resource,
|
||||
require_roles,
|
||||
require_scenario_resource,
|
||||
)
|
||||
from models import Project, Scenario, User
|
||||
from services.session import AuthSession, SessionTokens
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def uow(unit_of_work_factory):
|
||||
with unit_of_work_factory() as uow:
|
||||
assert uow.users and uow.roles and uow.projects and uow.scenarios
|
||||
uow.ensure_default_roles()
|
||||
yield uow
|
||||
|
||||
|
||||
def _unique(prefix: str) -> str:
|
||||
return f"{prefix}-{uuid4().hex}"
|
||||
|
||||
|
||||
def _create_user(
|
||||
uow,
|
||||
*,
|
||||
roles: tuple[str, ...] = (),
|
||||
is_active: bool = True,
|
||||
is_superuser: bool = False,
|
||||
) -> User:
|
||||
user = User(
|
||||
email=f"{_unique('user')}@example.com",
|
||||
username=_unique('user'),
|
||||
password_hash=User.hash_password("password"),
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
assert uow.users and uow.roles
|
||||
uow.users.create(user)
|
||||
uow.flush()
|
||||
|
||||
for role_name in roles:
|
||||
role = uow.roles.get_by_name(role_name)
|
||||
if role is None: # pragma: no cover - defensive for missing roles
|
||||
raise AssertionError(f"Role {role_name} expected in test database")
|
||||
uow.users.assign_role(user_id=user.id, role_id=role.id)
|
||||
return uow.users.get(user.id, with_roles=True)
|
||||
|
||||
|
||||
def _create_project(uow) -> Project:
|
||||
assert uow.projects
|
||||
project = Project(name=_unique('project'), location="Test Site")
|
||||
uow.projects.create(project)
|
||||
uow.flush()
|
||||
return project
|
||||
|
||||
|
||||
def _create_scenario(uow, project: Project) -> Scenario:
|
||||
assert uow.scenarios
|
||||
scenario = Scenario(project_id=project.id, name=_unique('scenario'))
|
||||
uow.scenarios.create(scenario)
|
||||
uow.flush()
|
||||
return scenario
|
||||
|
||||
|
||||
def test_require_current_user_returns_authenticated_user():
|
||||
user = User(
|
||||
email="user@example.com",
|
||||
username="user",
|
||||
password_hash=User.hash_password("password"),
|
||||
is_active=True,
|
||||
)
|
||||
session = AuthSession(tokens=SessionTokens(
|
||||
access_token="token", refresh_token=None))
|
||||
session.user = user
|
||||
|
||||
result = require_current_user(session=session)
|
||||
|
||||
assert result is user
|
||||
|
||||
|
||||
def test_require_current_user_raises_when_session_missing():
|
||||
anonymous_session = AuthSession.anonymous()
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
require_current_user(session=anonymous_session)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc.value.detail == "Authentication required."
|
||||
|
||||
|
||||
def test_require_authenticated_user_blocks_inactive_users():
|
||||
user = User(
|
||||
email="user@example.com",
|
||||
username="user",
|
||||
password_hash=User.hash_password("password"),
|
||||
is_active=False,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
require_authenticated_user(user=user)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert exc.value.detail == "User account is disabled."
|
||||
|
||||
|
||||
def test_require_roles_accepts_user_with_role(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
|
||||
dependency = require_roles("viewer")
|
||||
result = dependency(user=user)
|
||||
|
||||
assert result is user
|
||||
|
||||
|
||||
def test_require_roles_allows_superuser_without_matching_role(uow):
|
||||
user = _create_user(uow, roles=(), is_superuser=True)
|
||||
|
||||
dependency = require_roles("project_manager")
|
||||
result = dependency(user=user)
|
||||
|
||||
assert result is user
|
||||
|
||||
|
||||
def test_require_roles_rejects_user_missing_required_role(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
|
||||
dependency = require_any_role("project_manager", "admin")
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(user=user)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert exc.value.detail == "Insufficient permissions for this action."
|
||||
|
||||
|
||||
def test_require_roles_raises_value_error_when_no_roles_provided():
|
||||
with pytest.raises(ValueError) as exc:
|
||||
require_roles()
|
||||
|
||||
assert str(exc.value) == "require_roles requires at least one role name"
|
||||
|
||||
|
||||
def test_require_project_resource_returns_project_for_authorized_user(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
project = _create_project(uow)
|
||||
|
||||
dependency = require_project_resource()
|
||||
result = dependency(project.id, user=user, uow=uow)
|
||||
|
||||
assert result.id == project.id
|
||||
|
||||
|
||||
def test_require_project_resource_enforces_manage_requirement(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
project = _create_project(uow)
|
||||
|
||||
dependency = require_project_resource(require_manage=True)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(project.id, user=user, uow=uow)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert exc.value.detail == "Insufficient role permissions for this action."
|
||||
|
||||
|
||||
def test_require_project_resource_raises_not_found_for_missing_project(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
|
||||
dependency = require_project_resource()
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(9999, user=user, uow=uow)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert exc.value.detail == "Project 9999 not found"
|
||||
|
||||
|
||||
def test_require_project_resource_rejects_inactive_user(uow):
|
||||
user = _create_user(uow, roles=("viewer",), is_active=False)
|
||||
project = _create_project(uow)
|
||||
|
||||
dependency = require_project_resource()
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(project.id, user=user, uow=uow)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert exc.value.detail == "User account is disabled."
|
||||
|
||||
|
||||
def test_require_scenario_resource_returns_scenario_for_authorized_user(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
project = _create_project(uow)
|
||||
scenario = _create_scenario(uow, project)
|
||||
|
||||
dependency = require_scenario_resource()
|
||||
result = dependency(scenario.id, user=user, uow=uow)
|
||||
|
||||
assert result.id == scenario.id
|
||||
|
||||
|
||||
def test_require_scenario_resource_requires_manage_role(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
project = _create_project(uow)
|
||||
scenario = _create_scenario(uow, project)
|
||||
|
||||
dependency = require_scenario_resource(require_manage=True)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(scenario.id, user=user, uow=uow)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert exc.value.detail == "Insufficient role permissions for this action."
|
||||
|
||||
|
||||
def test_require_scenario_resource_raises_not_found(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
|
||||
dependency = require_scenario_resource()
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(12345, user=user, uow=uow)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert exc.value.detail == "Scenario 12345 not found"
|
||||
|
||||
|
||||
def test_require_project_scenario_resource_returns_scenario_when_linked(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
project = _create_project(uow)
|
||||
scenario = _create_scenario(uow, project)
|
||||
|
||||
dependency = require_project_scenario_resource()
|
||||
result = dependency(project.id, scenario.id, user=user, uow=uow)
|
||||
|
||||
assert result.id == scenario.id
|
||||
|
||||
|
||||
def test_require_project_scenario_resource_raises_when_scenario_not_in_project(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
project = _create_project(uow)
|
||||
other_project = _create_project(uow)
|
||||
scenario = _create_scenario(uow, other_project)
|
||||
|
||||
dependency = require_project_scenario_resource()
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(project.id, scenario.id, user=user, uow=uow)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "does not belong" in exc.value.detail
|
||||
|
||||
|
||||
def test_require_project_scenario_resource_requires_manage_role(uow):
|
||||
user = _create_user(uow, roles=("viewer",))
|
||||
project = _create_project(uow)
|
||||
scenario = _create_scenario(uow, project)
|
||||
|
||||
dependency = require_project_scenario_resource(require_manage=True)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
dependency(project.id, scenario.id, user=user, uow=uow)
|
||||
|
||||
assert exc.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert exc.value.detail == "Insufficient role permissions for this action."
|
||||
244
tests/test_repositories.py
Normal file
244
tests/test_repositories.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from models import (
|
||||
DistributionType,
|
||||
FinancialCategory,
|
||||
FinancialInput,
|
||||
MiningOperationType,
|
||||
Project,
|
||||
Scenario,
|
||||
ScenarioStatus,
|
||||
SimulationParameter,
|
||||
StochasticVariable,
|
||||
)
|
||||
from services.repositories import (
|
||||
FinancialInputRepository,
|
||||
ProjectRepository,
|
||||
ScenarioRepository,
|
||||
SimulationParameterRepository,
|
||||
)
|
||||
from services.unit_of_work import UnitOfWork
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine():
|
||||
engine = create_engine("sqlite:///:memory:", future=True)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session(engine) -> Iterator[Session]:
|
||||
TestingSession = sessionmaker(bind=engine, expire_on_commit=False, future=True)
|
||||
session = TestingSession()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def test_project_repository_create_and_list(session: Session) -> None:
|
||||
repo = ProjectRepository(session)
|
||||
project = Project(name="Project Alpha", operation_type=MiningOperationType.OPEN_PIT)
|
||||
repo.create(project)
|
||||
|
||||
projects = repo.list()
|
||||
|
||||
assert len(projects) == 1
|
||||
assert projects[0].name == "Project Alpha"
|
||||
|
||||
|
||||
def test_scenario_repository_get_with_children(session: Session) -> None:
|
||||
project = Project(name="Project Beta", operation_type=MiningOperationType.UNDERGROUND)
|
||||
scenario = Scenario(name="Scenario 1", project=project, status=ScenarioStatus.ACTIVE)
|
||||
scenario.financial_inputs.append(
|
||||
FinancialInput(
|
||||
name="Lease Payment",
|
||||
category=FinancialCategory.OPERATING_EXPENDITURE,
|
||||
amount=10000,
|
||||
currency="usd",
|
||||
)
|
||||
)
|
||||
scenario.simulation_parameters.append(
|
||||
SimulationParameter(
|
||||
name="Copper Price",
|
||||
distribution=DistributionType.NORMAL,
|
||||
mean_value=3.5,
|
||||
variable=StochasticVariable.METAL_PRICE,
|
||||
)
|
||||
)
|
||||
|
||||
session.add(project)
|
||||
session.flush()
|
||||
|
||||
repo = ScenarioRepository(session)
|
||||
retrieved = repo.get(scenario.id, with_children=True)
|
||||
|
||||
assert retrieved.project.name == "Project Beta"
|
||||
assert len(retrieved.financial_inputs) == 1
|
||||
assert retrieved.financial_inputs[0].currency == "USD"
|
||||
assert len(retrieved.simulation_parameters) == 1
|
||||
assert (
|
||||
retrieved.simulation_parameters[0].variable
|
||||
== StochasticVariable.METAL_PRICE
|
||||
)
|
||||
|
||||
param_repo = SimulationParameterRepository(session)
|
||||
params = param_repo.list_for_scenario(scenario.id)
|
||||
assert len(params) == 1
|
||||
|
||||
|
||||
def test_financial_input_repository_bulk_upsert(session: Session) -> None:
|
||||
project = Project(name="Project Gamma", operation_type=MiningOperationType.QUARRY)
|
||||
scenario = Scenario(name="Scenario Bulk", project=project)
|
||||
session.add(project)
|
||||
session.flush()
|
||||
|
||||
repo = FinancialInputRepository(session)
|
||||
created = repo.bulk_upsert(
|
||||
[
|
||||
FinancialInput(
|
||||
scenario_id=scenario.id,
|
||||
name="Explosives",
|
||||
category=FinancialCategory.OPERATING_EXPENDITURE,
|
||||
amount=5000,
|
||||
currency="cad",
|
||||
),
|
||||
FinancialInput(
|
||||
scenario_id=scenario.id,
|
||||
name="Equipment Lease",
|
||||
category=FinancialCategory.OPERATING_EXPENDITURE,
|
||||
amount=12000,
|
||||
currency="cad",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert len(created) == 2
|
||||
stored = repo.list_for_scenario(scenario.id)
|
||||
assert len(stored) == 2
|
||||
assert all(item.currency == "CAD" for item in stored)
|
||||
|
||||
|
||||
def test_unit_of_work_commit_and_rollback(engine) -> None:
|
||||
TestingSession = sessionmaker(bind=engine, expire_on_commit=False, future=True)
|
||||
|
||||
# Commit path
|
||||
with UnitOfWork(session_factory=TestingSession) as uow:
|
||||
uow.projects.create(
|
||||
Project(name="Project Delta", operation_type=MiningOperationType.PLACER)
|
||||
)
|
||||
|
||||
with TestingSession() as session:
|
||||
projects = ProjectRepository(session).list()
|
||||
assert len(projects) == 1
|
||||
|
||||
# Rollback path
|
||||
with pytest.raises(RuntimeError):
|
||||
with UnitOfWork(session_factory=TestingSession) as uow:
|
||||
uow.projects.create(
|
||||
Project(name="Project Epsilon", operation_type=MiningOperationType.OTHER)
|
||||
)
|
||||
raise RuntimeError("trigger rollback")
|
||||
|
||||
with TestingSession() as session:
|
||||
projects = ProjectRepository(session).list()
|
||||
assert len(projects) == 1
|
||||
|
||||
|
||||
def test_project_repository_count_and_recent(session: Session) -> None:
|
||||
repo = ProjectRepository(session)
|
||||
project_alpha = Project(name="Alpha", operation_type=MiningOperationType.OPEN_PIT)
|
||||
project_bravo = Project(name="Bravo", operation_type=MiningOperationType.UNDERGROUND)
|
||||
|
||||
repo.create(project_alpha)
|
||||
repo.create(project_bravo)
|
||||
|
||||
project_alpha.updated_at = datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||
project_bravo.updated_at = datetime(2025, 1, 2, tzinfo=timezone.utc)
|
||||
session.flush()
|
||||
|
||||
assert repo.count() == 2
|
||||
recent = repo.recent(limit=1)
|
||||
assert len(recent) == 1
|
||||
assert recent[0].name == "Bravo"
|
||||
|
||||
|
||||
def test_scenario_repository_counts_and_filters(session: Session) -> None:
|
||||
project = Project(name="Project", operation_type=MiningOperationType.PLACER)
|
||||
session.add(project)
|
||||
session.flush()
|
||||
|
||||
repo = ScenarioRepository(session)
|
||||
draft = Scenario(name="Draft", project_id=project.id,
|
||||
status=ScenarioStatus.DRAFT)
|
||||
active = Scenario(name="Active", project_id=project.id,
|
||||
status=ScenarioStatus.ACTIVE)
|
||||
|
||||
repo.create(draft)
|
||||
repo.create(active)
|
||||
|
||||
draft.updated_at = datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||
active.updated_at = datetime(2025, 1, 3, tzinfo=timezone.utc)
|
||||
session.flush()
|
||||
|
||||
assert repo.count() == 2
|
||||
assert repo.count_by_status(ScenarioStatus.ACTIVE) == 1
|
||||
|
||||
recent = repo.recent(limit=1, with_project=True)
|
||||
assert len(recent) == 1
|
||||
assert recent[0].name == "Active"
|
||||
assert recent[0].project.name == "Project"
|
||||
|
||||
drafts = repo.list_by_status(ScenarioStatus.DRAFT, limit=2, with_project=True)
|
||||
assert len(drafts) == 1
|
||||
assert drafts[0].name == "Draft"
|
||||
assert drafts[0].project_id == project.id
|
||||
|
||||
|
||||
def test_financial_input_repository_latest_created_at(session: Session) -> None:
|
||||
project = Project(name="Project FI", operation_type=MiningOperationType.OTHER)
|
||||
scenario = Scenario(name="Scenario FI", project=project)
|
||||
session.add(project)
|
||||
session.flush()
|
||||
|
||||
repo = FinancialInputRepository(session)
|
||||
old_timestamp = datetime(2024, 12, 31, 15, 0)
|
||||
new_timestamp = datetime(2025, 1, 2, 8, 30)
|
||||
|
||||
repo.bulk_upsert(
|
||||
[
|
||||
FinancialInput(
|
||||
scenario_id=scenario.id,
|
||||
name="Legacy Entry",
|
||||
category=FinancialCategory.OPERATING_EXPENDITURE,
|
||||
amount=1000,
|
||||
currency="usd",
|
||||
created_at=old_timestamp,
|
||||
updated_at=old_timestamp,
|
||||
),
|
||||
FinancialInput(
|
||||
scenario_id=scenario.id,
|
||||
name="New Entry",
|
||||
category=FinancialCategory.OPERATING_EXPENDITURE,
|
||||
amount=2000,
|
||||
currency="usd",
|
||||
created_at=new_timestamp,
|
||||
updated_at=new_timestamp,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
latest = repo.latest_created_at()
|
||||
assert latest == new_timestamp
|
||||
280
tests/test_scenario_validation.py
Normal file
280
tests/test_scenario_validation.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from config.database import Base
|
||||
from dependencies import get_auth_session, get_unit_of_work
|
||||
from models import (
|
||||
MiningOperationType,
|
||||
Project,
|
||||
ResourceType,
|
||||
Scenario,
|
||||
ScenarioStatus,
|
||||
User,
|
||||
)
|
||||
from schemas.scenario import (
|
||||
ScenarioComparisonRequest,
|
||||
ScenarioComparisonResponse,
|
||||
)
|
||||
from services.exceptions import ScenarioValidationError
|
||||
from services.scenario_validation import ScenarioComparisonValidator
|
||||
from services.unit_of_work import UnitOfWork
|
||||
from services.session import AuthSession, SessionTokens
|
||||
from routes.scenarios import router as scenarios_router
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def validator() -> ScenarioComparisonValidator:
|
||||
return ScenarioComparisonValidator()
|
||||
|
||||
|
||||
def _make_scenario(**overrides) -> Scenario:
|
||||
project_id: int = int(overrides.get("project_id", 1))
|
||||
name: str = str(overrides.get("name", "Scenario"))
|
||||
status = cast(ScenarioStatus, overrides.get(
|
||||
"status", ScenarioStatus.DRAFT))
|
||||
start_date = overrides.get("start_date", date(2025, 1, 1))
|
||||
end_date = overrides.get("end_date", date(2025, 12, 31))
|
||||
currency = cast(str, overrides.get("currency", "USD"))
|
||||
primary_resource = cast(ResourceType, overrides.get(
|
||||
"primary_resource", ResourceType.DIESEL))
|
||||
|
||||
scenario = Scenario(
|
||||
project_id=project_id,
|
||||
name=name,
|
||||
status=status,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
currency=currency,
|
||||
primary_resource=primary_resource,
|
||||
)
|
||||
|
||||
if "id" in overrides:
|
||||
scenario.id = overrides["id"]
|
||||
|
||||
return scenario
|
||||
|
||||
|
||||
class TestScenarioComparisonValidator:
|
||||
def test_validate_allows_matching_scenarios(self, validator: ScenarioComparisonValidator) -> None:
|
||||
scenario_a = _make_scenario(id=1)
|
||||
scenario_b = _make_scenario(id=2)
|
||||
|
||||
validator.validate([scenario_a, scenario_b])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs_a, kwargs_b, expected_code",
|
||||
[
|
||||
({"project_id": 1}, {"project_id": 2}, "SCENARIO_PROJECT_MISMATCH"),
|
||||
({"status": ScenarioStatus.ARCHIVED}, {}, "SCENARIO_STATUS_INVALID"),
|
||||
({"currency": "USD"}, {"currency": "CAD"},
|
||||
"SCENARIO_CURRENCY_MISMATCH"),
|
||||
(
|
||||
{"start_date": date(2025, 1, 1), "end_date": date(2025, 6, 1)},
|
||||
{"start_date": date(2025, 7, 1),
|
||||
"end_date": date(2025, 12, 31)},
|
||||
"SCENARIO_TIMELINE_DISJOINT",
|
||||
),
|
||||
({"primary_resource": ResourceType.DIESEL}, {
|
||||
"primary_resource": ResourceType.ELECTRICITY}, "SCENARIO_RESOURCE_MISMATCH"),
|
||||
],
|
||||
)
|
||||
def test_validate_raises_for_conflicts(
|
||||
self,
|
||||
validator: ScenarioComparisonValidator,
|
||||
kwargs_a: dict[str, object],
|
||||
kwargs_b: dict[str, object],
|
||||
expected_code: str,
|
||||
) -> None:
|
||||
scenario_a = _make_scenario(id=10, **kwargs_a)
|
||||
scenario_b = _make_scenario(id=20, **kwargs_b)
|
||||
|
||||
with pytest.raises(ScenarioValidationError) as exc_info:
|
||||
validator.validate([scenario_a, scenario_b])
|
||||
|
||||
assert exc_info.value.code == expected_code
|
||||
|
||||
def test_timeline_rule_skips_when_insufficient_ranges(
|
||||
self, validator: ScenarioComparisonValidator
|
||||
) -> None:
|
||||
scenario_a = _make_scenario(id=1, start_date=None, end_date=None)
|
||||
scenario_b = _make_scenario(id=2, start_date=date(
|
||||
2025, 1, 1), end_date=date(2025, 12, 31))
|
||||
|
||||
validator.validate([scenario_a, scenario_b])
|
||||
|
||||
|
||||
class TestScenarioComparisonRequest:
|
||||
def test_requires_two_unique_identifiers(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ScenarioComparisonRequest.model_validate({"scenario_ids": [1]})
|
||||
|
||||
def test_deduplicates_ids_preserving_order(self) -> None:
|
||||
payload = ScenarioComparisonRequest.model_validate(
|
||||
{"scenario_ids": [1, 1, 2, 2, 3]})
|
||||
|
||||
assert payload.scenario_ids == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine() -> Iterator[Engine]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
future=True,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_factory(engine: Engine) -> Iterator[sessionmaker]:
|
||||
testing_session = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
yield testing_session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def api_client(session_factory) -> Iterator[TestClient]:
|
||||
app = FastAPI()
|
||||
app.include_router(scenarios_router)
|
||||
|
||||
def _override_uow() -> Iterator[UnitOfWork]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
yield uow
|
||||
|
||||
app.dependency_overrides[get_unit_of_work] = _override_uow
|
||||
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.users is not None
|
||||
uow.ensure_default_roles()
|
||||
user = User(
|
||||
email="test-scenarios@example.com",
|
||||
username="scenario-tester",
|
||||
password_hash=User.hash_password("password"),
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
uow.users.create(user)
|
||||
user = uow.users.get(user.id, with_roles=True)
|
||||
|
||||
def _override_auth_session(request: Request) -> AuthSession:
|
||||
session = AuthSession(tokens=SessionTokens(
|
||||
access_token="test", refresh_token="test"))
|
||||
session.user = user
|
||||
request.state.auth_session = session
|
||||
return session
|
||||
|
||||
app.dependency_overrides[get_auth_session] = _override_auth_session
|
||||
client = TestClient(app)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
def _create_project_with_scenarios(
|
||||
session_factory: sessionmaker,
|
||||
scenario_overrides: list[dict[str, object]],
|
||||
) -> tuple[int, list[int]]:
|
||||
with UnitOfWork(session_factory=session_factory) as uow:
|
||||
assert uow.projects is not None
|
||||
assert uow.scenarios is not None
|
||||
project_name = f"Project {uuid4()}"
|
||||
project = Project(name=project_name,
|
||||
operation_type=MiningOperationType.OPEN_PIT)
|
||||
uow.projects.create(project)
|
||||
|
||||
scenario_ids: list[int] = []
|
||||
for index, overrides in enumerate(scenario_overrides, start=1):
|
||||
scenario = Scenario(
|
||||
project_id=project.id,
|
||||
name=f"Scenario {index}",
|
||||
status=overrides.get("status", ScenarioStatus.DRAFT),
|
||||
start_date=overrides.get("start_date", date(2025, 1, 1)),
|
||||
end_date=overrides.get("end_date", date(2025, 12, 31)),
|
||||
currency=overrides.get("currency", "USD"),
|
||||
primary_resource=overrides.get(
|
||||
"primary_resource", ResourceType.DIESEL),
|
||||
)
|
||||
uow.scenarios.create(scenario)
|
||||
scenario_ids.append(scenario.id)
|
||||
|
||||
return project.id, scenario_ids
|
||||
|
||||
|
||||
class TestScenarioComparisonEndpoint:
|
||||
def test_returns_scenarios_when_validation_passes(
|
||||
self, api_client: TestClient, session_factory: sessionmaker
|
||||
) -> None:
|
||||
project_id, scenario_ids = _create_project_with_scenarios(
|
||||
session_factory,
|
||||
[
|
||||
{},
|
||||
{"start_date": date(2025, 6, 1),
|
||||
"end_date": date(2025, 12, 31)},
|
||||
],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
f"/projects/{project_id}/scenarios/compare",
|
||||
json={"scenario_ids": scenario_ids},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = ScenarioComparisonResponse.model_validate(response.json())
|
||||
assert payload.project_id == project_id
|
||||
assert {scenario.id for scenario in payload.scenarios} == set(
|
||||
scenario_ids)
|
||||
|
||||
def test_returns_422_when_currency_mismatch(
|
||||
self, api_client: TestClient, session_factory: sessionmaker
|
||||
) -> None:
|
||||
project_id, scenario_ids = _create_project_with_scenarios(
|
||||
session_factory,
|
||||
[{"currency": "USD"}, {"currency": "CAD"}],
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
f"/projects/{project_id}/scenarios/compare",
|
||||
json={"scenario_ids": scenario_ids},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
detail = response.json()["detail"]
|
||||
assert detail["code"] == "SCENARIO_CURRENCY_MISMATCH"
|
||||
|
||||
def test_returns_422_when_second_scenario_from_other_project(
|
||||
self, api_client: TestClient, session_factory: sessionmaker
|
||||
) -> None:
|
||||
project_a_id, scenario_ids_a = _create_project_with_scenarios(
|
||||
session_factory, [{}])
|
||||
project_b_id, scenario_ids_b = _create_project_with_scenarios(
|
||||
session_factory, [{}])
|
||||
|
||||
response = api_client.post(
|
||||
f"/projects/{project_a_id}/scenarios/compare",
|
||||
json={"scenario_ids": [scenario_ids_a[0], scenario_ids_b[0]]},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
detail = response.json()["detail"]
|
||||
assert detail["code"] == "SCENARIO_PROJECT_MISMATCH"
|
||||
assert project_a_id != project_b_id
|
||||
76
tests/test_security.py
Normal file
76
tests/test_security.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from services.security import (
|
||||
JWTSettings,
|
||||
TokenDecodeError,
|
||||
TokenExpiredError,
|
||||
TokenTypeMismatchError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
def test_hash_password_round_trip() -> None:
|
||||
hashed = hash_password("super-secret")
|
||||
|
||||
assert hashed != "super-secret"
|
||||
assert verify_password("super-secret", hashed)
|
||||
assert not verify_password("incorrect", hashed)
|
||||
|
||||
|
||||
def test_verify_password_handles_malformed_hash() -> None:
|
||||
assert not verify_password("secret", "not-a-valid-hash")
|
||||
|
||||
|
||||
def test_access_token_roundtrip() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
|
||||
token = create_access_token(
|
||||
"user-id-123",
|
||||
settings,
|
||||
scopes=("read", "write"),
|
||||
extra_claims={"custom": "value"},
|
||||
)
|
||||
|
||||
payload = decode_access_token(token, settings)
|
||||
|
||||
assert payload.sub == "user-id-123"
|
||||
assert payload.type == "access"
|
||||
assert payload.scopes == ["read", "write"]
|
||||
|
||||
|
||||
def test_refresh_token_type_mismatch() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
token = create_refresh_token("user-id-456", settings)
|
||||
|
||||
with pytest.raises(TokenTypeMismatchError):
|
||||
decode_access_token(token, settings)
|
||||
|
||||
|
||||
def test_decode_expired_token() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
expired_token = create_access_token(
|
||||
"user-id-789",
|
||||
settings,
|
||||
expires_delta=timedelta(seconds=-5),
|
||||
)
|
||||
|
||||
with pytest.raises(TokenExpiredError):
|
||||
decode_access_token(expired_token, settings)
|
||||
|
||||
|
||||
def test_decode_tampered_token() -> None:
|
||||
settings = JWTSettings(secret_key="unit-test-secret")
|
||||
token = create_access_token("user-id-321", settings)
|
||||
tampered = token[:-1] + ("a" if token[-1] != "a" else "b")
|
||||
|
||||
with pytest.raises(TokenDecodeError):
|
||||
decode_access_token(tampered, settings)
|
||||
83
tests/test_user_model.py
Normal file
83
tests/test_user_model.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from config.database import Base
|
||||
from models import Role, User, UserRole
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine() -> Iterator:
|
||||
engine = create_engine("sqlite:///:memory:", future=True)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session(engine) -> Iterator[Session]:
|
||||
TestingSession = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, future=True)
|
||||
session = TestingSession()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def test_user_password_helpers() -> None:
|
||||
user = User(
|
||||
email="user@example.com",
|
||||
username="example",
|
||||
password_hash=User.hash_password("initial"),
|
||||
)
|
||||
|
||||
user.set_password("new-secret")
|
||||
|
||||
assert user.password_hash != "new-secret"
|
||||
assert user.verify_password("new-secret")
|
||||
assert not user.verify_password("wrong")
|
||||
|
||||
|
||||
def test_user_role_assignment(session: Session) -> None:
|
||||
grantor = User(
|
||||
email="admin@example.com",
|
||||
username="admin",
|
||||
password_hash=User.hash_password("admin-secret"),
|
||||
is_superuser=True,
|
||||
)
|
||||
analyst_role = Role(
|
||||
name="analyst",
|
||||
display_name="Analyst",
|
||||
description="Can review project dashboards",
|
||||
)
|
||||
analyst = User(
|
||||
email="analyst@example.com",
|
||||
username="analyst",
|
||||
password_hash=User.hash_password("analyst-secret"),
|
||||
)
|
||||
|
||||
assignment = UserRole(user=analyst, role=analyst_role,
|
||||
granted_by_user=grantor)
|
||||
|
||||
session.add_all([grantor, analyst_role, analyst, assignment])
|
||||
session.commit()
|
||||
|
||||
session.refresh(analyst)
|
||||
session.refresh(analyst_role)
|
||||
|
||||
# Relationship wrapper exposes the role without needing to traverse assignments manually
|
||||
assert len(analyst.role_assignments) == 1
|
||||
assert analyst.role_assignments[0].granted_by_user is grantor
|
||||
assert len(analyst.roles) == 1
|
||||
assert analyst.roles[0].name == "analyst"
|
||||
|
||||
# Ensure reverse relationship exposes the user
|
||||
assert len(analyst_role.assignments) == 1
|
||||
assert analyst_role.users[0].username == "analyst"
|
||||
Reference in New Issue
Block a user