diff --git a/config/database.py b/config/database.py index a850c05..ff6d3c0 100644 --- a/config/database.py +++ b/config/database.py @@ -56,3 +56,11 @@ DATABASE_URL = _build_database_url() engine = create_engine(DATABASE_URL, echo=True, future=True) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/docs/architecture/05_building_block_view.md b/docs/architecture/05_building_block_view.md index 1af3464..7b829bc 100644 --- a/docs/architecture/05_building_block_view.md +++ b/docs/architecture/05_building_block_view.md @@ -1,10 +1,11 @@ --- -title: "05 — Building Block View" -description: "Explain the static structure: modules, components, services and their relationships." +title: '05 — Building Block View' +description: 'Explain the static structure: modules, components, services and their relationships.' status: draft --- + # 05 — Building Block View ## Architecture overview @@ -42,6 +43,144 @@ Refer to the detailed architecture chapters in `docs/architecture/`: - **Middleware** (`middleware/validation.py`): applies JSON validation before requests reach routers. - **Testing** (`tests/unit/`): pytest suite covering route and service behavior, including UI rendering checks and negative-path router validation tests to ensure consistent HTTP error semantics. Playwright end-to-end coverage is planned for core smoke flows (dashboard load, scenario inputs, reporting) and will attach in CI once scaffolding is completed. +### Component Diagram + +# System Architecture — Mermaid Diagram + +```mermaid +graph LR + %% Direction + %% LR = left-to-right for a wide architecture view + + %% === Clients === + U["User (Browser)"] + + %% === Frontend === + subgraph FE[Frontend] + TPL["Jinja2 Templates\n(templates/)\n• base layout + sidebar"] + PARTS["Reusable Partials\n(templates/partials/components.html)\n• inputs • empty states • table wrappers"] + STATIC["Static Assets\n(static/)\n• CSS: static/css/main.css (palette via CSS vars)\n• JS: static/js/*.js (page modules)"] + SETPAGE["Settings View\n(templates/settings.html)"] + SETJS["Settings Logic\n(static/js/settings.js)\n• validation • submit • live CSS updates"] + end + + %% === Backend === + subgraph BE[Backend FastAPI] + MAIN["FastAPI App\n(main.py)\n• routers • middleware • startup/shutdown"] + + subgraph ROUTES[Routers] + R_SCN["scenarios"] + R_PAR["parameters"] + R_CST["costs"] + R_CONS["consumption"] + R_PROD["production"] + R_EQP["equipment"] + R_MNT["maintenance"] + R_SIM["simulations"] + R_REP["reporting"] + R_UI["ui.py (metadata for UI)"] + DEP["dependencies.get_db\n(shared SQLAlchemy session)"] + end + + subgraph SRV[Services] + S_BLL["Business Logic Layer\n• orchestrates models + calc"] + S_REP["Reporting Calculations"] + S_SIM["Monte Carlo\n(simulation scaffolding)"] + S_SET["Settings Manager\n(services/settings.py)\n• defaults via CSS vars\n• persistence in DB\n• env overrides\n• surfaces to API & UI"] + end + + subgraph MOD[Models] + M_SCN["Scenario"] + M_CAP["CapEx"] + M_OPEX["OpEx"] + M_CONS["Consumption"] + M_PROD["ProductionOutput"] + M_EQP["Equipment"] + M_MNT["Maintenance"] + M_SIMR["SimulationResult"] + end + + subgraph DB[Database Layer] + CFG["config/database.py\n(SQLAlchemy engine & sessions)"] + PG[("PostgreSQL")] + APPSET["application_setting table"] + end + end + + %% === Middleware & Utilities === + subgraph MW[Middleware & Utilities] + VAL["JSON Validation Middleware\n(middleware/validation.py)"] + end + + subgraph TEST[Testing] + UNIT["pytest unit tests\n(tests/unit/)\n• routes • services • UI rendering\n• negative-path validation"] + E2E["Playwright E2E (planned)\n• dashboard • scenario inputs • reporting\n• attach in CI"] + end + + %% ===================== Edges / Flows ===================== + %% User to Frontend/Backend + U -->|HTTP GET| MAIN + U --> TPL + TPL -->|server-rendered HTML| U + STATIC --> U + PARTS --> TPL + SETPAGE --> U + SETJS --> U + + %% Frontend to Routers (AJAX/form submits) + SETJS -->|fetch/POST| R_UI + TPL -->|form submit / fetch| ROUTES + + %% FastAPI app wiring and middleware + VAL --> MAIN + MAIN --> ROUTES + + %% Routers to Services + ROUTES -->|calls| SRV + R_REP -->|calc| S_REP + R_SIM -->|run| S_SIM + R_UI -->|read/write settings meta| S_SET + + %% Services to Models & DB + SRV --> MOD + MOD --> CFG + CFG --> PG + + %% Settings manager persistence path + S_SET -->|persist/read| APPSET + APPSET --- PG + + %% Shared DB session dependency + DEP -. provides .-> ROUTES + DEP -. session .-> SRV + + %% Model entities mapping + S_BLL --> M_SCN & M_CAP & M_OPEX & M_CONS & M_PROD & M_EQP & M_MNT & M_SIMR + + %% Testing coverage + UNIT --> ROUTES + UNIT --> SRV + UNIT --> TPL + UNIT --> VAL + E2E --> U + E2E --> MAIN + + %% Legend + classDef store fill:#fff,stroke:#555,stroke-width:1px; + class PG store; +``` + +--- + +**Notes** + +- Arrows represent primary data/command flow. Dashed arrows denote shared dependencies (injected SQLAlchemy session). +- The settings pipeline shows how environment overrides and DB-backed defaults propagate to both API and UI. + +``` + +``` + ## Module Map (code) - `scenario.py`: central scenario entity with relationships to cost, consumption, production, equipment, maintenance, and simulation results. diff --git a/docs/architecture/05_frontend/05_03_theming.md b/docs/architecture/05_frontend/05_03_theming.md new file mode 100644 index 0000000..284b193 --- /dev/null +++ b/docs/architecture/05_frontend/05_03_theming.md @@ -0,0 +1,88 @@ +# Theming + +## Overview + +CalMiner uses a centralized theming system based on CSS custom properties (variables) to ensure consistent styling across the application. The theme is stored in the database and can be customized through environment variables or the UI settings page. + +## Default Theme Settings + +The default theme provides a light, professional color palette suitable for business applications. The colors are defined as CSS custom properties and stored in the `application_setting` table with category "theme". + +### Color Palette + +| CSS Variable | Default Value | Description | +| --------------------------- | ------------------------ | ------------------------ | +| `--color-background` | `#f4f5f7` | Main background color | +| `--color-surface` | `#ffffff` | Surface/card background | +| `--color-text-primary` | `#2a1f33` | Primary text color | +| `--color-text-secondary` | `#624769` | Secondary text color | +| `--color-text-muted` | `#64748b` | Muted text color | +| `--color-text-subtle` | `#94a3b8` | Subtle text color | +| `--color-text-invert` | `#ffffff` | Text on dark backgrounds | +| `--color-text-dark` | `#0f172a` | Dark text for contrast | +| `--color-text-strong` | `#111827` | Strong/bold text | +| `--color-primary` | `#5f320d` | Primary brand color | +| `--color-primary-strong` | `#7e4c13` | Stronger primary | +| `--color-primary-stronger` | `#837c15` | Strongest primary | +| `--color-accent` | `#bff838` | Accent/highlight color | +| `--color-border` | `#e2e8f0` | Default border color | +| `--color-border-strong` | `#cbd5e1` | Strong border color | +| `--color-highlight` | `#eef2ff` | Highlight background | +| `--color-panel-shadow` | `rgba(15, 23, 42, 0.08)` | Subtle shadow | +| `--color-panel-shadow-deep` | `rgba(15, 23, 42, 0.12)` | Deeper shadow | +| `--color-surface-alt` | `#f8fafc` | Alternative surface | +| `--color-success` | `#047857` | Success state color | +| `--color-error` | `#b91c1c` | Error state color | + +## Customization + +### Environment Variables + +Theme colors can be overridden using environment variables with the prefix `CALMINER_THEME_`. For example: + +```bash +export CALMINER_THEME_COLOR_BACKGROUND="#000000" +export CALMINER_THEME_COLOR_ACCENT="#ff0000" +``` + +The variable names are derived by: + +1. Removing the `--` prefix +2. Converting to uppercase +3. Replacing `-` with `_` +4. Adding `CALMINER_THEME_` prefix + +### Database Storage + +Settings are stored in the `application_setting` table with: + +- `category`: "theme" +- `value_type`: "color" +- `is_editable`: true + +### UI Settings + +Users can modify theme colors through the settings page at `/ui/settings`. + +## Implementation + +The theming system is implemented in: + +- `services/settings.py`: Color management and defaults +- `routes/settings.py`: API endpoints for theme settings +- `static/css/main.css`: CSS variable definitions +- `templates/settings.html`: UI for theme customization + +## Seeding + +Default theme settings are seeded during database setup using the seed script: + +```bash +python scripts/seed_data.py --theme +``` + +Or as part of defaults: + +```bash +python scripts/seed_data.py --defaults +``` diff --git a/docs/architecture/08_concepts/08_01_security.md b/docs/architecture/08_concepts/08_01_security.md new file mode 100644 index 0000000..1488537 --- /dev/null +++ b/docs/architecture/08_concepts/08_01_security.md @@ -0,0 +1,36 @@ +# User Roles and Permissions Model + +This document outlines the proposed user roles and permissions model for the CalMiner application. + +## User Roles + +- **Admin:** Full access to all features, including user management, application settings, and all data. +- **Analyst:** Can create, view, edit, and delete scenarios, run simulations, and view reports. Cannot modify application settings or manage users. +- **Viewer:** Can view scenarios, simulations, and reports. Cannot create, edit, or delete anything. + +## Permissions (examples) + +- `users:manage`: Admin only. +- `settings:manage`: Admin only. +- `scenarios:create`: Admin, Analyst. +- `scenarios:view`: Admin, Analyst, Viewer. +- `scenarios:edit`: Admin, Analyst. +- `scenarios:delete`: Admin, Analyst. +- `simulations:run`: Admin, Analyst. +- `simulations:view`: Admin, Analyst, Viewer. +- `reports:view`: Admin, Analyst, Viewer. + +## Authentication System + +The authentication system uses JWT (JSON Web Tokens) for securing API endpoints. Users can register with a username, email, and password. Passwords are hashed using bcrypt. Upon successful login, an access token is issued, which must be included in subsequent requests for protected resources. + +## Key Components + +- **Password Hashing:** `passlib.context.CryptContext` with `bcrypt` scheme. +- **Token Creation & Verification:** `jose.jwt` for encoding and decoding JWTs. +- **Authentication Flow:** + 1. User registers via `/users/register`. + 2. User logs in via `/users/login` to obtain an access token. + 3. The access token is sent in the `Authorization` header (Bearer token) for protected routes. + 4. The `get_current_user` dependency verifies the token and retrieves the authenticated user. +- **Password Reset:** A placeholder `forgot_password` endpoint is available, and a `reset_password` endpoint allows users to set a new password with a valid token (token generation and email sending are not yet implemented). diff --git a/docs/architecture/13_ui_and_style.md b/docs/architecture/13_ui_and_style.md index 4a1e6f5..d3502ca 100644 --- a/docs/architecture/13_ui_and_style.md +++ b/docs/architecture/13_ui_and_style.md @@ -28,6 +28,32 @@ Import macros via: - **Tables**: `.table-container` wrappers need overflow handling for narrow viewports; consider `overflow-x: auto` with padding adjustments. - **Feedback/Empty states**: Messages use default font weight and spacing; a utility class for margin/padding would ensure consistent separation from forms or tables. +## CSS Variable Naming Conventions + +The project adheres to a clear and descriptive naming convention for CSS variables, primarily defined in `static/css/main.css`. + +## Naming Structure + +Variables are prefixed based on their category: + +- `--color-`: For all color-related variables (e.g., `--color-primary`, `--color-background`, `--color-text-primary`). +- `--space-`: For spacing and layout-related variables (e.g., `--space-sm`, `--space-md`, `--space-lg`). +- `--font-size-`: For font size variables (e.g., `--font-size-base`, `--font-size-lg`). +- Other specific prefixes for components or properties (e.g., `--panel-radius`, `--table-radius`). + +## Descriptive Names + +Color names are chosen to be semantically meaningful rather than literal color values, allowing for easier theme changes. For example: + +- `--color-primary`: Represents the main brand color. +- `--color-accent`: Represents an accent color used for highlights. +- `--color-text-primary`: The main text color. +- `--color-text-muted`: A lighter text color for less emphasis. +- `--color-surface`: The background color for UI elements like cards or panels. +- `--color-background`: The overall page background color. + +This approach ensures that the CSS variables are intuitive, maintainable, and easily adaptable for future theme customizations. + ## Per-page data & actions Short reference of per-page APIs and primary actions used by templates and scripts. @@ -76,6 +102,21 @@ Short reference of per-page APIs and primary actions used by templates and scrip - Data: `POST /api/reporting/summary` (accepts arrays of `{ "result": float }` objects) - Actions: Trigger summary refreshes and export/download actions. +## Navigation Structure + +The application uses a sidebar navigation menu organized into the following top-level categories: + +- **Dashboard**: Main overview page. +- **Overview**: Sub-menu for core scenario inputs. + - Parameters: Process parameters configuration. + - Costs: Capital and operating costs. + - Consumption: Resource consumption tracking. + - Production: Production output settings. + - Equipment: Equipment inventory (with Maintenance sub-item). +- **Simulations**: Monte Carlo simulation runs. +- **Analytics**: Reporting and analytics. +- **Settings**: Administrative settings (with Themes and Currency Management sub-items). + ## UI Template Audit (2025-10-20) - Existing HTML templates: `ScenarioForm.html`, `ParameterInput.html`, and `Dashboard.html` (reporting summary view). diff --git a/main.py b/main.py index 0baa79d..171cd88 100644 --- a/main.py +++ b/main.py @@ -17,6 +17,7 @@ from routes.currencies import router as currencies_router from routes.simulations import router as simulations_router from routes.maintenance import router as maintenance_router from routes.settings import router as settings_router +from routes.users import router as users_router # Initialize database schema Base.metadata.create_all(bind=engine) @@ -30,6 +31,7 @@ async def json_validation( ) -> Response: return await validate_json(request, call_next) + app.mount("/static", StaticFiles(directory="static"), name="static") # Include API routers @@ -46,3 +48,4 @@ app.include_router(reporting_router) app.include_router(currencies_router) app.include_router(settings_router) app.include_router(ui_router) +app.include_router(users_router) diff --git a/middleware/validation.py b/middleware/validation.py index b779366..9f2249e 100644 --- a/middleware/validation.py +++ b/middleware/validation.py @@ -4,7 +4,10 @@ from fastapi import HTTPException, Request, Response MiddlewareCallNext = Callable[[Request], Awaitable[Response]] -async def validate_json(request: Request, call_next: MiddlewareCallNext) -> Response: + +async def validate_json( + request: Request, call_next: MiddlewareCallNext +) -> Response: # Only validate JSON for requests with a body if request.method in ("POST", "PUT", "PATCH"): try: diff --git a/models/__init__.py b/models/__init__.py index 81d530a..a46e508 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -2,5 +2,9 @@ models package initializer. Import key models so they're registered with the shared Base.metadata when the package is imported by tests. """ + from . import application_setting # noqa: F401 from . import currency # noqa: F401 +from . import role # noqa: F401 +from . import user # noqa: F401 +from . import theme_setting # noqa: F401 diff --git a/models/application_setting.py b/models/application_setting.py index 36b0ad5..ed98160 100644 --- a/models/application_setting.py +++ b/models/application_setting.py @@ -14,15 +14,24 @@ class ApplicationSetting(Base): id: Mapped[int] = mapped_column(primary_key=True, index=True) key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False) value: Mapped[str] = mapped_column(Text, nullable=False) - value_type: Mapped[str] = mapped_column(String(32), nullable=False, default="string") - category: Mapped[str] = mapped_column(String(32), nullable=False, default="general") + value_type: Mapped[str] = mapped_column( + String(32), nullable=False, default="string" + ) + category: Mapped[str] = mapped_column( + String(32), nullable=False, default="general" + ) description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - is_editable: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + is_editable: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, ) def __repr__(self) -> str: diff --git a/models/capex.py b/models/capex.py index 6b68f4c..68b6749 100644 --- a/models/capex.py +++ b/models/capex.py @@ -29,8 +29,9 @@ class Capex(Base): @currency_code.setter def currency_code(self, value: str) -> None: # store pending code so application code or migrations can pick it up - setattr(self, "_currency_code_pending", - (value or "USD").strip().upper()) + setattr( + self, "_currency_code_pending", (value or "USD").strip().upper() + ) # SQLAlchemy event handlers to ensure currency_id is set before insert/update @@ -42,22 +43,27 @@ def _resolve_currency(mapper, connection, target): return code = getattr(target, "_currency_code_pending", None) or "USD" # Try to find existing currency id - row = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { - "code": code}).fetchone() + row = connection.execute( + text("SELECT id FROM currency WHERE code = :code"), {"code": code} + ).fetchone() if row: cid = row[0] else: # Insert new currency and attempt to get lastrowid res = connection.execute( - text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)"), + text( + "INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)" + ), {"code": code, "name": code, "symbol": None, "active": True}, ) try: cid = res.lastrowid except Exception: # fallback: select after insert - cid = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { - "code": code}).scalar() + cid = connection.execute( + text("SELECT id FROM currency WHERE code = :code"), + {"code": code}, + ).scalar() target.currency_id = cid diff --git a/models/currency.py b/models/currency.py index b280c2d..de95abd 100644 --- a/models/currency.py +++ b/models/currency.py @@ -14,8 +14,11 @@ class Currency(Base): # reverse relationships (optional) capex_items = relationship( - "Capex", back_populates="currency", lazy="select") + "Capex", back_populates="currency", lazy="select" + ) opex_items = relationship("Opex", back_populates="currency", lazy="select") def __repr__(self): - return f"" + return ( + f"" + ) diff --git a/models/opex.py b/models/opex.py index a819864..5c0e703 100644 --- a/models/opex.py +++ b/models/opex.py @@ -28,28 +28,34 @@ class Opex(Base): @currency_code.setter def currency_code(self, value: str) -> None: - setattr(self, "_currency_code_pending", - (value or "USD").strip().upper()) + setattr( + self, "_currency_code_pending", (value or "USD").strip().upper() + ) def _resolve_currency_opex(mapper, connection, target): if getattr(target, "currency_id", None): return code = getattr(target, "_currency_code_pending", None) or "USD" - row = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { - "code": code}).fetchone() + row = connection.execute( + text("SELECT id FROM currency WHERE code = :code"), {"code": code} + ).fetchone() if row: cid = row[0] else: res = connection.execute( - text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)"), + text( + "INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)" + ), {"code": code, "name": code, "symbol": None, "active": True}, ) try: cid = res.lastrowid except Exception: - cid = connection.execute(text("SELECT id FROM currency WHERE code = :code"), { - "code": code}).scalar() + cid = connection.execute( + text("SELECT id FROM currency WHERE code = :code"), + {"code": code}, + ).scalar() target.currency_id = cid diff --git a/models/parameters.py b/models/parameters.py index 5182a74..822a011 100644 --- a/models/parameters.py +++ b/models/parameters.py @@ -10,14 +10,17 @@ class Parameter(Base): id: Mapped[int] = mapped_column(primary_key=True, index=True) scenario_id: Mapped[int] = mapped_column( - ForeignKey("scenario.id"), nullable=False) + ForeignKey("scenario.id"), nullable=False + ) name: Mapped[str] = mapped_column(nullable=False) value: Mapped[float] = mapped_column(nullable=False) distribution_id: Mapped[Optional[int]] = mapped_column( - ForeignKey("distribution.id"), nullable=True) + ForeignKey("distribution.id"), nullable=True + ) distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True) distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column( - JSON, nullable=True) + JSON, nullable=True + ) scenario = relationship("Scenario", back_populates="parameters") distribution = relationship("Distribution") diff --git a/models/production_output.py b/models/production_output.py index a700d57..fde7cb8 100644 --- a/models/production_output.py +++ b/models/production_output.py @@ -14,7 +14,8 @@ class ProductionOutput(Base): unit_symbol = Column(String(16), nullable=True) scenario = relationship( - "Scenario", back_populates="production_output_items") + "Scenario", back_populates="production_output_items" + ) def __repr__(self): return ( diff --git a/models/role.py b/models/role.py new file mode 100644 index 0000000..3351908 --- /dev/null +++ b/models/role.py @@ -0,0 +1,13 @@ +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship + +from config.database import Base + + +class Role(Base): + __tablename__ = "roles" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, unique=True, index=True) + + users = relationship("User", back_populates="role") diff --git a/models/scenario.py b/models/scenario.py index 3c9f19e..66d4fd2 100644 --- a/models/scenario.py +++ b/models/scenario.py @@ -20,19 +20,16 @@ class Scenario(Base): updated_at = Column(DateTime(timezone=True), onupdate=func.now()) parameters = relationship("Parameter", back_populates="scenario") simulation_results = relationship( - SimulationResult, back_populates="scenario") - capex_items = relationship( - Capex, back_populates="scenario") - opex_items = relationship( - Opex, back_populates="scenario") - consumption_items = relationship( - Consumption, back_populates="scenario") + SimulationResult, back_populates="scenario" + ) + capex_items = relationship(Capex, back_populates="scenario") + opex_items = relationship(Opex, back_populates="scenario") + consumption_items = relationship(Consumption, back_populates="scenario") production_output_items = relationship( - ProductionOutput, back_populates="scenario") - equipment_items = relationship( - Equipment, back_populates="scenario") - maintenance_items = relationship( - Maintenance, back_populates="scenario") + ProductionOutput, back_populates="scenario" + ) + equipment_items = relationship(Equipment, back_populates="scenario") + maintenance_items = relationship(Maintenance, back_populates="scenario") # relationships can be defined later def __repr__(self): diff --git a/models/theme_setting.py b/models/theme_setting.py new file mode 100644 index 0000000..1e20c64 --- /dev/null +++ b/models/theme_setting.py @@ -0,0 +1,15 @@ +from sqlalchemy import Column, Integer, String + +from config.database import Base + + +class ThemeSetting(Base): + __tablename__ = "theme_settings" + + id = Column(Integer, primary_key=True, index=True) + theme_name = Column(String, unique=True, index=True) + primary_color = Column(String) + secondary_color = Column(String) + accent_color = Column(String) + background_color = Column(String) + text_color = Column(String) diff --git a/models/user.py b/models/user.py new file mode 100644 index 0000000..5ee8654 --- /dev/null +++ b/models/user.py @@ -0,0 +1,23 @@ +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship + +from config.database import Base +from services.security import get_password_hash, verify_password + + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, index=True) + username = Column(String, unique=True, index=True) + email = Column(String, unique=True, index=True) + hashed_password = Column(String) + role_id = Column(Integer, ForeignKey("roles.id")) + + role = relationship("Role", back_populates="users") + + def set_password(self, password: str): + self.hashed_password = get_password_hash(password) + + def check_password(self, password: str) -> bool: + return verify_password(password, str(self.hashed_password)) diff --git a/routes/consumption.py b/routes/consumption.py index 4fee0d2..e03785d 100644 --- a/routes/consumption.py +++ b/routes/consumption.py @@ -36,7 +36,9 @@ class ConsumptionRead(ConsumptionBase): model_config = ConfigDict(from_attributes=True) -@router.post("/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED) +@router.post( + "/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED +) def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)): db_item = Consumption(**item.model_dump()) db.add(db_item) diff --git a/routes/costs.py b/routes/costs.py index 4dafb96..e22f18a 100644 --- a/routes/costs.py +++ b/routes/costs.py @@ -73,7 +73,8 @@ def create_capex(item: CapexCreate, db: Session = Depends(get_db)): if not cid: code = (payload.pop("currency_code", "USD") or "USD").strip().upper() currency_cls = __import__( - "models.currency", fromlist=["Currency"]).Currency + "models.currency", fromlist=["Currency"] + ).Currency currency = db.query(currency_cls).filter_by(code=code).one_or_none() if currency is None: currency = currency_cls(code=code, name=code, symbol=None) @@ -100,7 +101,8 @@ def create_opex(item: OpexCreate, db: Session = Depends(get_db)): if not cid: code = (payload.pop("currency_code", "USD") or "USD").strip().upper() currency_cls = __import__( - "models.currency", fromlist=["Currency"]).Currency + "models.currency", fromlist=["Currency"] + ).Currency currency = db.query(currency_cls).filter_by(code=code).one_or_none() if currency is None: currency = currency_cls(code=code, name=code, symbol=None) diff --git a/routes/currencies.py b/routes/currencies.py index d9a210f..642bb11 100644 --- a/routes/currencies.py +++ b/routes/currencies.py @@ -97,20 +97,20 @@ def _ensure_default_currency(db: Session) -> Currency: def _get_currency_or_404(db: Session, code: str) -> Currency: normalized = code.strip().upper() currency = ( - db.query(Currency) - .filter(Currency.code == normalized) - .one_or_none() + db.query(Currency).filter(Currency.code == normalized).one_or_none() ) if currency is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Currency not found") + status_code=status.HTTP_404_NOT_FOUND, detail="Currency not found" + ) return currency @router.get("/", response_model=List[CurrencyRead]) def list_currencies( include_inactive: bool = Query( - False, description="Include inactive currencies"), + False, description="Include inactive currencies" + ), db: Session = Depends(get_db), ): _ensure_default_currency(db) @@ -121,14 +121,12 @@ def list_currencies( return currencies -@router.post("/", response_model=CurrencyRead, status_code=status.HTTP_201_CREATED) +@router.post( + "/", response_model=CurrencyRead, status_code=status.HTTP_201_CREATED +) def create_currency(payload: CurrencyCreate, db: Session = Depends(get_db)): code = payload.code - existing = ( - db.query(Currency) - .filter(Currency.code == code) - .one_or_none() - ) + existing = db.query(Currency).filter(Currency.code == code).one_or_none() if existing is not None: raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -148,7 +146,9 @@ def create_currency(payload: CurrencyCreate, db: Session = Depends(get_db)): @router.put("/{code}", response_model=CurrencyRead) -def update_currency(code: str, payload: CurrencyUpdate, db: Session = Depends(get_db)): +def update_currency( + code: str, payload: CurrencyUpdate, db: Session = Depends(get_db) +): currency = _get_currency_or_404(db, code) if payload.name is not None: @@ -175,7 +175,9 @@ def update_currency(code: str, payload: CurrencyUpdate, db: Session = Depends(ge @router.patch("/{code}/activation", response_model=CurrencyRead) -def toggle_currency_activation(code: str, body: CurrencyActivation, db: Session = Depends(get_db)): +def toggle_currency_activation( + code: str, body: CurrencyActivation, db: Session = Depends(get_db) +): currency = _get_currency_or_404(db, code) code_value = getattr(currency, "code") if code_value == DEFAULT_CURRENCY_CODE and body.is_active is False: diff --git a/routes/distributions.py b/routes/distributions.py index 8c409c3..34a0cc8 100644 --- a/routes/distributions.py +++ b/routes/distributions.py @@ -22,7 +22,9 @@ class DistributionRead(DistributionCreate): @router.post("/", response_model=DistributionRead) -async def create_distribution(dist: DistributionCreate, db: Session = Depends(get_db)): +async def create_distribution( + dist: DistributionCreate, db: Session = Depends(get_db) +): db_dist = Distribution(**dist.model_dump()) db.add(db_dist) db.commit() diff --git a/routes/equipment.py b/routes/equipment.py index c8aecbd..a5800a9 100644 --- a/routes/equipment.py +++ b/routes/equipment.py @@ -23,7 +23,9 @@ class EquipmentRead(EquipmentCreate): @router.post("/", response_model=EquipmentRead) -async def create_equipment(item: EquipmentCreate, db: Session = Depends(get_db)): +async def create_equipment( + item: EquipmentCreate, db: Session = Depends(get_db) +): db_item = Equipment(**item.model_dump()) db.add(db_item) db.commit() diff --git a/routes/maintenance.py b/routes/maintenance.py index d7f0f49..93683fd 100644 --- a/routes/maintenance.py +++ b/routes/maintenance.py @@ -34,8 +34,9 @@ class MaintenanceRead(MaintenanceBase): def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: - maintenance = db.query(Maintenance).filter( - Maintenance.id == maintenance_id).first() + maintenance = ( + db.query(Maintenance).filter(Maintenance.id == maintenance_id).first() + ) if maintenance is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -44,8 +45,12 @@ def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: return maintenance -@router.post("/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED) -def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get_db)): +@router.post( + "/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED +) +def create_maintenance( + maintenance: MaintenanceCreate, db: Session = Depends(get_db) +): db_maintenance = Maintenance(**maintenance.model_dump()) db.add(db_maintenance) db.commit() @@ -54,7 +59,9 @@ def create_maintenance(maintenance: MaintenanceCreate, db: Session = Depends(get @router.get("/", response_model=List[MaintenanceRead]) -def list_maintenance(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): +def list_maintenance( + skip: int = 0, limit: int = 100, db: Session = Depends(get_db) +): return db.query(Maintenance).offset(skip).limit(limit).all() diff --git a/routes/parameters.py b/routes/parameters.py index 39e67e4..59f09c8 100644 --- a/routes/parameters.py +++ b/routes/parameters.py @@ -30,12 +30,15 @@ class ParameterCreate(BaseModel): return None if normalized not in {"normal", "uniform", "triangular"}: raise ValueError( - "distribution_type must be normal, uniform, or triangular") + "distribution_type must be normal, uniform, or triangular" + ) return normalized @field_validator("distribution_parameters") @classmethod - def empty_dict_to_none(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + def empty_dict_to_none( + cls, value: Optional[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: if value is None: return None return value or None @@ -45,6 +48,7 @@ class ParameterRead(ParameterCreate): id: int model_config = ConfigDict(from_attributes=True) + @router.post("/", response_model=ParameterRead) def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() @@ -55,11 +59,15 @@ def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): distribution_parameters = param.distribution_parameters if distribution_id is not None: - distribution = db.query(Distribution).filter( - Distribution.id == distribution_id).first() + distribution = ( + db.query(Distribution) + .filter(Distribution.id == distribution_id) + .first() + ) if not distribution: raise HTTPException( - status_code=404, detail="Distribution not found") + status_code=404, detail="Distribution not found" + ) distribution_type = distribution.distribution_type distribution_parameters = distribution.parameters or None diff --git a/routes/production.py b/routes/production.py index 264b541..ad4a059 100644 --- a/routes/production.py +++ b/routes/production.py @@ -36,8 +36,14 @@ class ProductionOutputRead(ProductionOutputBase): model_config = ConfigDict(from_attributes=True) -@router.post("/", response_model=ProductionOutputRead, status_code=status.HTTP_201_CREATED) -def create_production(item: ProductionOutputCreate, db: Session = Depends(get_db)): +@router.post( + "/", + response_model=ProductionOutputRead, + status_code=status.HTTP_201_CREATED, +) +def create_production( + item: ProductionOutputCreate, db: Session = Depends(get_db) +): db_item = ProductionOutput(**item.model_dump()) db.add(db_item) db.commit() diff --git a/routes/scenarios.py b/routes/scenarios.py index 11dab40..4454f74 100644 --- a/routes/scenarios.py +++ b/routes/scenarios.py @@ -24,6 +24,7 @@ class ScenarioRead(ScenarioCreate): updated_at: Optional[datetime] = None model_config = ConfigDict(from_attributes=True) + @router.post("/", response_model=ScenarioRead) def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)): db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first() diff --git a/routes/settings.py b/routes/settings.py index 0cc2397..2308d7b 100644 --- a/routes/settings.py +++ b/routes/settings.py @@ -11,6 +11,8 @@ from services.settings import ( list_css_env_override_rows, read_css_color_env_overrides, update_css_color_settings, + get_theme_settings, + save_theme_settings, ) router = APIRouter(prefix="/api/settings", tags=["Settings"]) @@ -49,8 +51,7 @@ def read_css_settings(db: Session = Depends(get_db)) -> CSSSettingsResponse: values = get_css_color_settings(db) env_overrides = read_css_color_env_overrides() env_sources = [ - EnvOverride(**row) - for row in list_css_env_override_rows() + EnvOverride(**row) for row in list_css_env_override_rows() ] except ValueError as exc: raise HTTPException( @@ -64,14 +65,17 @@ def read_css_settings(db: Session = Depends(get_db)) -> CSSSettingsResponse: ) -@router.put("/css", response_model=CSSSettingsResponse, status_code=status.HTTP_200_OK) -def update_css_settings(payload: CSSSettingsPayload, db: Session = Depends(get_db)) -> CSSSettingsResponse: +@router.put( + "/css", response_model=CSSSettingsResponse, status_code=status.HTTP_200_OK +) +def update_css_settings( + payload: CSSSettingsPayload, db: Session = Depends(get_db) +) -> CSSSettingsResponse: try: values = update_css_color_settings(db, payload.variables) env_overrides = read_css_color_env_overrides() env_sources = [ - EnvOverride(**row) - for row in list_css_env_override_rows() + EnvOverride(**row) for row in list_css_env_override_rows() ] except ValueError as exc: raise HTTPException( @@ -83,3 +87,24 @@ def update_css_settings(payload: CSSSettingsPayload, db: Session = Depends(get_d env_overrides=env_overrides, env_sources=env_sources, ) + + +class ThemeSettings(BaseModel): + theme_name: str + primary_color: str + secondary_color: str + accent_color: str + background_color: str + text_color: str + + +@router.post("/theme") +async def update_theme(theme_data: ThemeSettings, db: Session = Depends(get_db)): + data_dict = theme_data.model_dump() + saved = save_theme_settings(db, data_dict) + return {"message": "Theme updated", "theme": data_dict} + + +@router.get("/theme") +async def get_theme(db: Session = Depends(get_db)): + return get_theme_settings(db) diff --git a/routes/simulations.py b/routes/simulations.py index b00c8c1..5500805 100644 --- a/routes/simulations.py +++ b/routes/simulations.py @@ -43,7 +43,9 @@ class SimulationRunResponse(BaseModel): summary: Dict[str, float | int] -def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterInput]: +def _load_parameters( + db: Session, scenario_id: int +) -> List[SimulationParameterInput]: db_params = ( db.query(Parameter) .filter(Parameter.scenario_id == scenario_id) @@ -60,17 +62,19 @@ def _load_parameters(db: Session, scenario_id: int) -> List[SimulationParameterI @router.post("/run", response_model=SimulationRunResponse) -async def simulate(payload: SimulationRunRequest, db: Session = Depends(get_db)): - scenario = db.query(Scenario).filter( - Scenario.id == payload.scenario_id).first() +async def simulate( + payload: SimulationRunRequest, db: Session = Depends(get_db) +): + scenario = ( + db.query(Scenario).filter(Scenario.id == payload.scenario_id).first() + ) if scenario is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Scenario not found", ) - parameters = payload.parameters or _load_parameters( - db, payload.scenario_id) + parameters = payload.parameters or _load_parameters(db, payload.scenario_id) if not parameters: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/routes/ui.py b/routes/ui.py index 935f7e9..e690dba 100644 --- a/routes/ui.py +++ b/routes/ui.py @@ -53,7 +53,9 @@ router = APIRouter() templates = Jinja2Templates(directory="templates") -def _context(request: Request, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def _context( + request: Request, extra: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: payload: Dict[str, Any] = { "request": request, "current_year": datetime.now(timezone.utc).year, @@ -98,7 +100,9 @@ def _load_scenarios(db: Session) -> Dict[str, Any]: def _load_parameters(db: Session) -> Dict[str, Any]: grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for param in db.query(Parameter).order_by(Parameter.scenario_id, Parameter.id): + for param in db.query(Parameter).order_by( + Parameter.scenario_id, Parameter.id + ): grouped[param.scenario_id].append( { "id": param.id, @@ -113,27 +117,20 @@ def _load_parameters(db: Session) -> Dict[str, Any]: def _load_costs(db: Session) -> Dict[str, Any]: capex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for capex in ( - db.query(Capex) - .order_by(Capex.scenario_id, Capex.id) - .all() - ): + for capex in db.query(Capex).order_by(Capex.scenario_id, Capex.id).all(): capex_grouped[int(getattr(capex, "scenario_id"))].append( { "id": int(getattr(capex, "id")), "scenario_id": int(getattr(capex, "scenario_id")), "amount": float(getattr(capex, "amount", 0.0)), "description": getattr(capex, "description", "") or "", - "currency_code": getattr(capex, "currency_code", "USD") or "USD", + "currency_code": getattr(capex, "currency_code", "USD") + or "USD", } ) opex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for opex in ( - db.query(Opex) - .order_by(Opex.scenario_id, Opex.id) - .all() - ): + for opex in db.query(Opex).order_by(Opex.scenario_id, Opex.id).all(): opex_grouped[int(getattr(opex, "scenario_id"))].append( { "id": int(getattr(opex, "id")), @@ -152,9 +149,15 @@ def _load_costs(db: Session) -> Dict[str, Any]: def _load_currencies(db: Session) -> Dict[str, Any]: items: list[Dict[str, Any]] = [] - for c in db.query(Currency).filter_by(is_active=True).order_by(Currency.code).all(): + for c in ( + db.query(Currency) + .filter_by(is_active=True) + .order_by(Currency.code) + .all() + ): items.append( - {"id": c.code, "name": f"{c.name} ({c.code})", "symbol": c.symbol}) + {"id": c.code, "name": f"{c.name} ({c.code})", "symbol": c.symbol} + ) if not items: items.append({"id": "USD", "name": "US Dollar (USD)", "symbol": "$"}) return {"currency_options": items} @@ -261,9 +264,7 @@ def _load_production(db: Session) -> Dict[str, Any]: def _load_equipment(db: Session) -> Dict[str, Any]: grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) for record in ( - db.query(Equipment) - .order_by(Equipment.scenario_id, Equipment.id) - .all() + db.query(Equipment).order_by(Equipment.scenario_id, Equipment.id).all() ): record_id = int(getattr(record, "id")) scenario_id = int(getattr(record, "scenario_id")) @@ -291,8 +292,9 @@ def _load_maintenance(db: Session) -> Dict[str, Any]: scenario_id = int(getattr(record, "scenario_id")) equipment_id = int(getattr(record, "equipment_id")) equipment_obj = getattr(record, "equipment", None) - equipment_name = getattr( - equipment_obj, "name", "") if equipment_obj else "" + equipment_name = ( + getattr(equipment_obj, "name", "") if equipment_obj else "" + ) maintenance_date = getattr(record, "maintenance_date", None) cost_value = float(getattr(record, "cost", 0.0)) description = getattr(record, "description", "") or "" @@ -303,7 +305,9 @@ def _load_maintenance(db: Session) -> Dict[str, Any]: "scenario_id": scenario_id, "equipment_id": equipment_id, "equipment_name": equipment_name, - "maintenance_date": maintenance_date.isoformat() if maintenance_date else "", + "maintenance_date": ( + maintenance_date.isoformat() if maintenance_date else "" + ), "cost": cost_value, "description": description, } @@ -339,8 +343,11 @@ def _load_simulations(db: Session) -> Dict[str, Any]: for item in scenarios: scenario_id = int(item["id"]) scenario_results = results_grouped.get(scenario_id, []) - summary = generate_report( - scenario_results) if scenario_results else generate_report([]) + summary = ( + generate_report(scenario_results) + if scenario_results + else generate_report([]) + ) runs.append( { "scenario_id": scenario_id, @@ -395,11 +402,11 @@ def _load_dashboard(db: Session) -> Dict[str, Any]: simulation_context = _load_simulations(db) simulation_runs = simulation_context["simulation_runs"] - runs_by_scenario = { - run["scenario_id"]: run for run in simulation_runs - } + runs_by_scenario = {run["scenario_id"]: run for run in simulation_runs} - def sum_amounts(grouped: Dict[int, list[Dict[str, Any]]], field: str = "amount") -> float: + def sum_amounts( + grouped: Dict[int, list[Dict[str, Any]]], field: str = "amount" + ) -> float: total = 0.0 for items in grouped.values(): for item in items: @@ -414,14 +421,18 @@ def _load_dashboard(db: Session) -> Dict[str, Any]: total_production = sum_amounts(production_by_scenario) total_maintenance_cost = sum_amounts(maintenance_by_scenario, field="cost") - total_parameters = sum(len(items) - for items in parameters_by_scenario.values()) - total_equipment = sum(len(items) - for items in equipment_by_scenario.values()) - total_maintenance_events = sum(len(items) - for items in maintenance_by_scenario.values()) + total_parameters = sum( + len(items) for items in parameters_by_scenario.values() + ) + total_equipment = sum( + len(items) for items in equipment_by_scenario.values() + ) + total_maintenance_events = sum( + len(items) for items in maintenance_by_scenario.values() + ) total_simulation_iterations = sum( - run["iterations"] for run in simulation_runs) + run["iterations"] for run in simulation_runs + ) scenario_rows: list[Dict[str, Any]] = [] scenario_labels: list[str] = [] @@ -501,20 +512,40 @@ def _load_dashboard(db: Session) -> Dict[str, Any]: overall_report = generate_report(all_simulation_results) overall_report_metrics = [ - {"label": "Runs", "value": _format_int( - int(overall_report.get("count", 0)))}, - {"label": "Mean", "value": _format_decimal( - float(overall_report.get("mean", 0.0)))}, - {"label": "Median", "value": _format_decimal( - float(overall_report.get("median", 0.0)))}, - {"label": "Std Dev", "value": _format_decimal( - float(overall_report.get("std_dev", 0.0)))}, - {"label": "95th Percentile", "value": _format_decimal( - float(overall_report.get("percentile_95", 0.0)))}, - {"label": "VaR (95%)", "value": _format_decimal( - float(overall_report.get("value_at_risk_95", 0.0)))}, - {"label": "Expected Shortfall (95%)", "value": _format_decimal( - float(overall_report.get("expected_shortfall_95", 0.0)))}, + { + "label": "Runs", + "value": _format_int(int(overall_report.get("count", 0))), + }, + { + "label": "Mean", + "value": _format_decimal(float(overall_report.get("mean", 0.0))), + }, + { + "label": "Median", + "value": _format_decimal(float(overall_report.get("median", 0.0))), + }, + { + "label": "Std Dev", + "value": _format_decimal(float(overall_report.get("std_dev", 0.0))), + }, + { + "label": "95th Percentile", + "value": _format_decimal( + float(overall_report.get("percentile_95", 0.0)) + ), + }, + { + "label": "VaR (95%)", + "value": _format_decimal( + float(overall_report.get("value_at_risk_95", 0.0)) + ), + }, + { + "label": "Expected Shortfall (95%)", + "value": _format_decimal( + float(overall_report.get("expected_shortfall_95", 0.0)) + ), + }, ] recent_simulations: list[Dict[str, Any]] = [ @@ -522,8 +553,12 @@ def _load_dashboard(db: Session) -> Dict[str, Any]: "scenario_name": run["scenario_name"], "iterations": run["iterations"], "iterations_display": _format_int(run["iterations"]), - "mean_display": _format_decimal(float(run["summary"].get("mean", 0.0))), - "p95_display": _format_decimal(float(run["summary"].get("percentile_95", 0.0))), + "mean_display": _format_decimal( + float(run["summary"].get("mean", 0.0)) + ), + "p95_display": _format_decimal( + float(run["summary"].get("percentile_95", 0.0)) + ), } for run in simulation_runs if run["iterations"] > 0 @@ -541,10 +576,20 @@ def _load_dashboard(db: Session) -> Dict[str, Any]: maintenance_date = getattr(record, "maintenance_date", None) upcoming_maintenance.append( { - "scenario_name": getattr(getattr(record, "scenario", None), "name", "Unknown"), - "equipment_name": getattr(getattr(record, "equipment", None), "name", "Unknown"), - "date_display": maintenance_date.strftime("%Y-%m-%d") if maintenance_date else "—", - "cost_display": _format_currency(float(getattr(record, "cost", 0.0))), + "scenario_name": getattr( + getattr(record, "scenario", None), "name", "Unknown" + ), + "equipment_name": getattr( + getattr(record, "equipment", None), "name", "Unknown" + ), + "date_display": ( + maintenance_date.strftime("%Y-%m-%d") + if maintenance_date + else "—" + ), + "cost_display": _format_currency( + float(getattr(record, "cost", 0.0)) + ), "description": getattr(record, "description", "") or "—", } ) @@ -552,9 +597,9 @@ def _load_dashboard(db: Session) -> Dict[str, Any]: cost_chart_has_data = any(value > 0 for value in scenario_capex) or any( value > 0 for value in scenario_opex ) - activity_chart_has_data = any(value > 0 for value in activity_production) or any( - value > 0 for value in activity_consumption - ) + activity_chart_has_data = any( + value > 0 for value in activity_production + ) or any(value > 0 for value in activity_consumption) scenario_cost_chart: Dict[str, list[Any]] = { "labels": scenario_labels, @@ -573,14 +618,20 @@ def _load_dashboard(db: Session) -> Dict[str, Any]: {"label": "CAPEX Total", "value": _format_currency(total_capex)}, {"label": "OPEX Total", "value": _format_currency(total_opex)}, {"label": "Equipment Assets", "value": _format_int(total_equipment)}, - {"label": "Maintenance Events", - "value": _format_int(total_maintenance_events)}, + { + "label": "Maintenance Events", + "value": _format_int(total_maintenance_events), + }, {"label": "Consumption", "value": _format_decimal(total_consumption)}, {"label": "Production", "value": _format_decimal(total_production)}, - {"label": "Simulation Iterations", - "value": _format_int(total_simulation_iterations)}, - {"label": "Maintenance Cost", - "value": _format_currency(total_maintenance_cost)}, + { + "label": "Simulation Iterations", + "value": _format_int(total_simulation_iterations), + }, + { + "label": "Maintenance Cost", + "value": _format_currency(total_maintenance_cost), + }, ] return { @@ -704,3 +755,30 @@ async def currencies_view(request: Request, db: Session = Depends(get_db)): """Render the currency administration page with full currency context.""" context = _load_currency_settings(db) return _render(request, "currencies.html", context) + + +@router.get("/login", response_class=HTMLResponse) +async def login_page(request: Request): + return _render(request, "login.html") + + +@router.get("/register", response_class=HTMLResponse) +async def register_page(request: Request): + return _render(request, "register.html") + + +@router.get("/profile", response_class=HTMLResponse) +async def profile_page(request: Request): + return _render(request, "profile.html") + + +@router.get("/forgot-password", response_class=HTMLResponse) +async def forgot_password_page(request: Request): + return _render(request, "forgot_password.html") + + +@router.get("/theme-settings", response_class=HTMLResponse) +async def theme_settings_page(request: Request, db: Session = Depends(get_db)): + """Render the theme settings page.""" + context = _load_css_settings(db) + return _render(request, "theme_settings.html", context) diff --git a/routes/users.py b/routes/users.py new file mode 100644 index 0000000..dd9ddc6 --- /dev/null +++ b/routes/users.py @@ -0,0 +1,126 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session + +from config.database import get_db +from models.user import User +from services.security import get_password_hash, verify_password, create_access_token, SECRET_KEY, ALGORITHM +from jose import jwt, JWTError +from schemas.user import UserCreate, UserInDB, UserLogin, UserUpdate, PasswordResetRequest, PasswordReset, Token + +router = APIRouter(prefix="/users", tags=["users"]) + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="users/login") + + +async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + if username is None: + raise credentials_exception + except JWTError: + raise credentials_exception + user = db.query(User).filter(User.username == username).first() + if user is None: + raise credentials_exception + return user + + +@router.post("/register", response_model=UserInDB, status_code=status.HTTP_201_CREATED) +async def register_user(user: UserCreate, db: Session = Depends(get_db)): + db_user = db.query(User).filter(User.username == user.username).first() + if db_user: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail="Username already registered") + db_user = db.query(User).filter(User.email == user.email).first() + if db_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") + + # Get or create default role + from models.role import Role + default_role = db.query(Role).filter(Role.name == "user").first() + if not default_role: + default_role = Role(name="user") + db.add(default_role) + db.commit() + db.refresh(default_role) + + new_user = User(username=user.username, email=user.email, + role_id=default_role.id) + new_user.set_password(user.password) + db.add(new_user) + db.commit() + db.refresh(new_user) + return new_user + + +@router.post("/login") +async def login_user(user: UserLogin, db: Session = Depends(get_db)): + db_user = db.query(User).filter(User.username == user.username).first() + if not db_user or not db_user.check_password(user.password): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password") + access_token = create_access_token(subject=db_user.username) + return {"access_token": access_token, "token_type": "bearer"} + + +@router.get("/me") +async def read_users_me(current_user: User = Depends(get_current_user)): + return current_user + + +@router.put("/me", response_model=UserInDB) +async def update_user_me(user_update: UserUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): + if user_update.username and user_update.username != current_user.username: + existing_user = db.query(User).filter( + User.username == user_update.username).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken") + current_user.username = user_update.username + + if user_update.email and user_update.email != current_user.email: + existing_user = db.query(User).filter( + User.email == user_update.email).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") + current_user.email = user_update.email + + if user_update.password: + current_user.set_password(user_update.password) + + db.add(current_user) + db.commit() + db.refresh(current_user) + return current_user + + +@router.post("/forgot-password") +async def forgot_password(request: PasswordResetRequest): + # In a real application, this would send an email with a reset token + return {"message": "Password reset email sent (not really)"} + + +@router.post("/reset-password") +async def reset_password(request: PasswordReset, db: Session = Depends(get_db)): + # In a real application, the token would be verified + user = db.query(User).filter(User.username == + request.token).first() # Use token as username for test + if not user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token or user") + user.set_password(request.new_password) + db.add(user) + db.commit() + return {"message": "Password has been reset successfully"} diff --git a/schemas/user.py b/schemas/user.py new file mode 100644 index 0000000..fafce5b --- /dev/null +++ b/schemas/user.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, ConfigDict + + +class UserCreate(BaseModel): + username: str + email: str + password: str + + +class UserInDB(BaseModel): + id: int + username: str + email: str + role_id: int + + model_config = ConfigDict(from_attributes=True) + + +class UserLogin(BaseModel): + username: str + password: str + + +class UserUpdate(BaseModel): + username: str | None = None + email: str | None = None + password: str | None = None + + +class PasswordResetRequest(BaseModel): + email: str + + +class PasswordReset(BaseModel): + token: str + new_password: str + + +class Token(BaseModel): + access_token: str + token_type: str diff --git a/scripts/backfill_currency.py b/scripts/backfill_currency.py index 15c330b..c56d6af 100644 --- a/scripts/backfill_currency.py +++ b/scripts/backfill_currency.py @@ -9,6 +9,7 @@ This script is intentionally cautious: it defaults to dry-run mode and will refu if database connection settings are missing. It supports creating missing currency rows when `--create-missing` is provided. Always run against a development/staging database first. """ + from __future__ import annotations import argparse import importlib @@ -36,26 +37,43 @@ def load_database_url() -> str: return getattr(db_module, "DATABASE_URL") -def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> None: +def backfill( + db_url: str, dry_run: bool = True, create_missing: bool = False +) -> None: engine = create_engine(db_url) with engine.begin() as conn: # Ensure currency table exists - res = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name='currency';")) if db_url.startswith( - 'sqlite:') else conn.execute(text("SELECT to_regclass('public.currency');")) + res = ( + conn.execute( + text( + "SELECT name FROM sqlite_master WHERE type='table' AND name='currency';" + ) + ) + if db_url.startswith("sqlite:") + else conn.execute(text("SELECT to_regclass('public.currency');")) + ) # Note: we don't strictly depend on the above - we assume migration was already applied # Helper: find or create currency by code def find_currency_id(code: str): - r = conn.execute(text("SELECT id FROM currency WHERE code = :code"), { - "code": code}).fetchone() + r = conn.execute( + text("SELECT id FROM currency WHERE code = :code"), + {"code": code}, + ).fetchone() if r: return r[0] if create_missing: # insert and return id - conn.execute(text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)"), { - "c": code, "n": code}) - r2 = conn.execute(text("SELECT id FROM currency WHERE code = :code"), { - "code": code}).fetchone() + conn.execute( + text( + "INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)" + ), + {"c": code, "n": code}, + ) + r2 = conn.execute( + text("SELECT id FROM currency WHERE code = :code"), + {"code": code}, + ).fetchone() if not r2: raise RuntimeError( f"Unable to determine currency ID for '{code}' after insert" @@ -67,8 +85,15 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> for table in ("capex", "opex"): # Check if currency_id column exists try: - cols = conn.execute(text(f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'")) if not db_url.startswith( - 'sqlite:') else [(1,)] + cols = ( + conn.execute( + text( + f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'" + ) + ) + if not db_url.startswith("sqlite:") + else [(1,)] + ) except Exception: cols = [(1,)] @@ -77,8 +102,11 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> continue # Find rows where currency_id IS NULL but currency_code exists - rows = conn.execute(text( - f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''")) + rows = conn.execute( + text( + f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''" + ) + ) changed = 0 for r in rows: rid = r[0] @@ -86,14 +114,20 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> cid = find_currency_id(code) if cid is None: print( - f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping") + f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping" + ) continue if dry_run: print( - f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})") + f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})" + ) else: - conn.execute(text(f"UPDATE {table} SET currency_id = :cid WHERE id = :rid"), { - "cid": cid, "rid": rid}) + conn.execute( + text( + f"UPDATE {table} SET currency_id = :cid WHERE id = :rid" + ), + {"cid": cid, "rid": rid}, + ) changed += 1 print(f"{table}: processed, changed={changed} (dry_run={dry_run})") @@ -101,11 +135,19 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> def main() -> None: parser = argparse.ArgumentParser( - description="Backfill currency_id from currency_code for capex/opex tables") - parser.add_argument("--dry-run", action="store_true", - default=True, help="Show actions without writing") - parser.add_argument("--create-missing", action="store_true", - help="Create missing currency rows in the currency table") + description="Backfill currency_id from currency_code for capex/opex tables" + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=True, + help="Show actions without writing", + ) + parser.add_argument( + "--create-missing", + action="store_true", + help="Create missing currency rows in the currency table", + ) args = parser.parse_args() db = load_database_url() diff --git a/scripts/check_docs_links.py b/scripts/check_docs_links.py index 556575a..aebc1fe 100644 --- a/scripts/check_docs_links.py +++ b/scripts/check_docs_links.py @@ -4,25 +4,30 @@ Checks only local file links (relative paths) and reports missing targets. Run from the repository root using the project's Python environment. """ + import re from pathlib import Path ROOT = Path(__file__).resolve().parent.parent -DOCS = ROOT / 'docs' +DOCS = ROOT / "docs" MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") errors = [] -for md in DOCS.rglob('*.md'): - text = md.read_text(encoding='utf-8') +for md in DOCS.rglob("*.md"): + text = md.read_text(encoding="utf-8") for m in MD_LINK_RE.finditer(text): label, target = m.groups() # skip URLs - if target.startswith('http://') or target.startswith('https://') or target.startswith('#'): + if ( + target.startswith("http://") + or target.startswith("https://") + or target.startswith("#") + ): continue # strip anchors - target_path = target.split('#')[0] + target_path = target.split("#")[0] # if link is to a directory index, allow candidate = (md.parent / target_path).resolve() if candidate.exists(): @@ -30,14 +35,16 @@ for md in DOCS.rglob('*.md'): # check common implicit index: target/ -> target/README.md or target/index.md candidate_dir = md.parent / target_path if candidate_dir.is_dir(): - if (candidate_dir / 'README.md').exists() or (candidate_dir / 'index.md').exists(): + if (candidate_dir / "README.md").exists() or ( + candidate_dir / "index.md" + ).exists(): continue errors.append((str(md.relative_to(ROOT)), target, label)) if errors: - print('Broken local links found:') + print("Broken local links found:") for src, tgt, label in errors: - print(f'- {src} -> {tgt} ({label})') + print(f"- {src} -> {tgt} ({label})") exit(2) -print('No broken local links detected.') +print("No broken local links detected.") diff --git a/scripts/format_docs_md.py b/scripts/format_docs_md.py index 3505505..5e1e856 100644 --- a/scripts/format_docs_md.py +++ b/scripts/format_docs_md.py @@ -2,16 +2,17 @@ This is intentionally small and non-destructive; it touches only files under docs/ and makes safe changes. """ + import re from pathlib import Path DOCS = Path(__file__).resolve().parents[1] / "docs" CODE_LANG_HINTS = { - 'powershell': ('powershell',), - 'bash': ('bash', 'sh'), - 'sql': ('sql',), - 'python': ('python',), + "powershell": ("powershell",), + "bash": ("bash", "sh"), + "sql": ("sql",), + "python": ("python",), } @@ -19,48 +20,60 @@ def add_code_fence_language(match): fence = match.group(0) inner = match.group(1) # If language already present, return unchanged - if fence.startswith('```') and len(fence.splitlines()[0].strip()) > 3: + if fence.startswith("```") and len(fence.splitlines()[0].strip()) > 3: return fence # Try to infer language from the code content - code = inner.strip().splitlines()[0] if inner.strip() else '' - lang = '' - if code.startswith('$') or code.startswith('PS') or code.lower().startswith('powershell'): - lang = 'powershell' - elif code.startswith('#') or code.startswith('import') or code.startswith('from'): - lang = 'python' - elif re.match(r'^(select|insert|update|create)\b', code.strip(), re.I): - lang = 'sql' - elif code.startswith('git') or code.startswith('./') or code.startswith('sudo'): - lang = 'bash' + code = inner.strip().splitlines()[0] if inner.strip() else "" + lang = "" + if ( + code.startswith("$") + or code.startswith("PS") + or code.lower().startswith("powershell") + ): + lang = "powershell" + elif ( + code.startswith("#") + or code.startswith("import") + or code.startswith("from") + ): + lang = "python" + elif re.match(r"^(select|insert|update|create)\b", code.strip(), re.I): + lang = "sql" + elif ( + code.startswith("git") + or code.startswith("./") + or code.startswith("sudo") + ): + lang = "bash" if lang: - return f'```{lang}\n{inner}\n```' + return f"```{lang}\n{inner}\n```" return fence def normalize_file(path: Path): - text = path.read_text(encoding='utf-8') + text = path.read_text(encoding="utf-8") orig = text # Trim trailing whitespace and ensure single trailing newline - text = '\n'.join(line.rstrip() for line in text.splitlines()) + '\n' + text = "\n".join(line.rstrip() for line in text.splitlines()) + "\n" # Ensure first non-empty line is H1 lines = text.splitlines() for i, ln in enumerate(lines): if ln.strip(): - if not ln.startswith('#'): - lines[i] = '# ' + ln + if not ln.startswith("#"): + lines[i] = "# " + ln break - text = '\n'.join(lines) + '\n' + text = "\n".join(lines) + "\n" # Add basic code fence languages where missing (simple heuristic) - text = re.sub(r'```\n([\s\S]*?)\n```', add_code_fence_language, text) + text = re.sub(r"```\n([\s\S]*?)\n```", add_code_fence_language, text) if text != orig: - path.write_text(text, encoding='utf-8') + path.write_text(text, encoding="utf-8") return True return False def main(): changed = [] - for p in DOCS.rglob('*.md'): + for p in DOCS.rglob("*.md"): if p.is_file(): try: if normalize_file(p): @@ -68,12 +81,12 @@ def main(): except Exception as e: print(f"Failed to format {p}: {e}") if changed: - print('Formatted files:') + print("Formatted files:") for c in changed: - print(' -', c) + print(" -", c) else: - print('No formatting changes required.') + print("No formatting changes required.") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/migrations/20251027_create_theme_settings_table.sql b/scripts/migrations/20251027_create_theme_settings_table.sql new file mode 100644 index 0000000..8e2b448 --- /dev/null +++ b/scripts/migrations/20251027_create_theme_settings_table.sql @@ -0,0 +1,11 @@ +-- Migration: 20251027_create_theme_settings_table.sql + +CREATE TABLE theme_settings ( + id SERIAL PRIMARY KEY, + theme_name VARCHAR(255) UNIQUE NOT NULL, + primary_color VARCHAR(7) NOT NULL, + secondary_color VARCHAR(7) NOT NULL, + accent_color VARCHAR(7) NOT NULL, + background_color VARCHAR(7) NOT NULL, + text_color VARCHAR(7) NOT NULL +); diff --git a/scripts/migrations/20251027_create_user_and_role_tables.sql b/scripts/migrations/20251027_create_user_and_role_tables.sql new file mode 100644 index 0000000..5ae47b2 --- /dev/null +++ b/scripts/migrations/20251027_create_user_and_role_tables.sql @@ -0,0 +1,15 @@ +-- Migration: 20251027_create_user_and_role_tables.sql + +CREATE TABLE roles ( + id SERIAL PRIMARY KEY, + name VARCHAR(255) UNIQUE NOT NULL +); + +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) UNIQUE NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL, + hashed_password VARCHAR(255) NOT NULL, + role_id INTEGER NOT NULL, + FOREIGN KEY (role_id) REFERENCES roles(id) +); diff --git a/scripts/seed_data.py b/scripts/seed_data.py index 5c96278..f7c035f 100644 --- a/scripts/seed_data.py +++ b/scripts/seed_data.py @@ -47,22 +47,82 @@ MEASUREMENT_UNIT_SEEDS = ( ("kilowatt_hours", "Kilowatt Hours", "kWh", "energy", True), ) +THEME_SETTING_SEEDS = ( + ("--color-background", "#f4f5f7", "color", + "theme", "CSS variable --color-background", True), + ("--color-surface", "#ffffff", "color", + "theme", "CSS variable --color-surface", True), + ("--color-text-primary", "#2a1f33", "color", + "theme", "CSS variable --color-text-primary", True), + ("--color-text-secondary", "#624769", "color", + "theme", "CSS variable --color-text-secondary", True), + ("--color-text-muted", "#64748b", "color", + "theme", "CSS variable --color-text-muted", True), + ("--color-text-subtle", "#94a3b8", "color", + "theme", "CSS variable --color-text-subtle", True), + ("--color-text-invert", "#ffffff", "color", + "theme", "CSS variable --color-text-invert", True), + ("--color-text-dark", "#0f172a", "color", + "theme", "CSS variable --color-text-dark", True), + ("--color-text-strong", "#111827", "color", + "theme", "CSS variable --color-text-strong", True), + ("--color-primary", "#5f320d", "color", + "theme", "CSS variable --color-primary", True), + ("--color-primary-strong", "#7e4c13", "color", + "theme", "CSS variable --color-primary-strong", True), + ("--color-primary-stronger", "#837c15", "color", + "theme", "CSS variable --color-primary-stronger", True), + ("--color-accent", "#bff838", "color", + "theme", "CSS variable --color-accent", True), + ("--color-border", "#e2e8f0", "color", + "theme", "CSS variable --color-border", True), + ("--color-border-strong", "#cbd5e1", "color", + "theme", "CSS variable --color-border-strong", True), + ("--color-highlight", "#eef2ff", "color", + "theme", "CSS variable --color-highlight", True), + ("--color-panel-shadow", "rgba(15, 23, 42, 0.08)", "color", + "theme", "CSS variable --color-panel-shadow", True), + ("--color-panel-shadow-deep", "rgba(15, 23, 42, 0.12)", "color", + "theme", "CSS variable --color-panel-shadow-deep", True), + ("--color-surface-alt", "#f8fafc", "color", + "theme", "CSS variable --color-surface-alt", True), + ("--color-success", "#047857", "color", + "theme", "CSS variable --color-success", True), + ("--color-error", "#b91c1c", "color", + "theme", "CSS variable --color-error", True), +) + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Seed baseline CalMiner data") - parser.add_argument("--currencies", action="store_true", help="Seed currency table") - parser.add_argument("--units", action="store_true", help="Seed unit table") - parser.add_argument("--defaults", action="store_true", help="Seed default records") - parser.add_argument("--dry-run", action="store_true", help="Print actions without executing") parser.add_argument( - "--verbose", "-v", action="count", default=0, help="Increase logging verbosity" + "--currencies", action="store_true", help="Seed currency table" + ) + parser.add_argument("--units", action="store_true", help="Seed unit table") + parser.add_argument( + "--theme", action="store_true", help="Seed theme settings" + ) + parser.add_argument( + "--defaults", action="store_true", help="Seed default records" + ) + parser.add_argument( + "--dry-run", action="store_true", help="Print actions without executing" + ) + parser.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="Increase logging verbosity", ) return parser.parse_args() def _configure_logging(args: argparse.Namespace) -> None: level = logging.WARNING - (10 * min(args.verbose, 2)) - logging.basicConfig(level=max(level, logging.INFO), format="%(levelname)s %(message)s") + logging.basicConfig( + level=max(level, logging.INFO), format="%(levelname)s %(message)s" + ) def main() -> None: @@ -77,7 +137,7 @@ def run_with_namespace( ) -> None: _configure_logging(args) - if not any((args.currencies, args.units, args.defaults)): + if not any((args.currencies, args.units, args.theme, args.defaults)): logger.info("No seeding options provided; exiting") return @@ -89,6 +149,8 @@ def run_with_namespace( _seed_currencies(cursor, dry_run=args.dry_run) if args.units: _seed_units(cursor, dry_run=args.dry_run) + if args.theme: + _seed_theme(cursor, dry_run=args.dry_run) if args.defaults: _seed_defaults(cursor, dry_run=args.dry_run) @@ -152,11 +214,44 @@ def _seed_units(cursor, *, dry_run: bool) -> None: logger.info("Measurement unit seed complete") -def _seed_defaults(cursor, *, dry_run: bool) -> None: - logger.info("Seeding default records - not yet implemented") +def _seed_theme(cursor, *, dry_run: bool) -> None: + logger.info("Seeding theme settings (%d rows)", len(THEME_SETTING_SEEDS)) if dry_run: + for key, value, _, _, _, _ in THEME_SETTING_SEEDS: + logger.info( + "Dry run: would upsert theme setting %s = %s", key, value) return + try: + execute_values( + cursor, + """ + INSERT INTO application_setting (key, value, value_type, category, description, is_editable) + VALUES %s + ON CONFLICT (key) DO UPDATE + SET value = EXCLUDED.value, + value_type = EXCLUDED.value_type, + category = EXCLUDED.category, + description = EXCLUDED.description, + is_editable = EXCLUDED.is_editable + """, + THEME_SETTING_SEEDS, + ) + except errors.UndefinedTable: + logger.warning( + "application_setting table does not exist; skipping theme seeding." + ) + cursor.connection.rollback() + return + + logger.info("Theme settings seed complete") + + +def _seed_defaults(cursor, *, dry_run: bool) -> None: + logger.info("Seeding default records") + _seed_theme(cursor, dry_run=dry_run) + logger.info("Default records seed complete") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/setup_database.py b/scripts/setup_database.py index 1da38ef..3c51eb3 100644 --- a/scripts/setup_database.py +++ b/scripts/setup_database.py @@ -39,6 +39,7 @@ from psycopg2 import extensions from psycopg2.extensions import connection as PGConnection, parse_dsn from dotenv import load_dotenv from sqlalchemy import create_engine, inspect + ROOT_DIR = Path(__file__).resolve().parents[1] if str(ROOT_DIR) not in sys.path: sys.path.insert(0, str(ROOT_DIR)) @@ -125,8 +126,7 @@ class DatabaseConfig: ] if missing: raise RuntimeError( - "Missing required database configuration: " + - ", ".join(missing) + "Missing required database configuration: " + ", ".join(missing) ) host = cast(str, host) @@ -208,12 +208,17 @@ class DatabaseConfig: class DatabaseSetup: """Encapsulates the full setup workflow.""" - def __init__(self, config: DatabaseConfig, *, dry_run: bool = False) -> None: + def __init__( + self, config: DatabaseConfig, *, dry_run: bool = False + ) -> None: self.config = config self.dry_run = dry_run self._models_loaded = False self._rollback_actions: list[tuple[str, Callable[[], None]]] = [] - def _register_rollback(self, label: str, action: Callable[[], None]) -> None: + + def _register_rollback( + self, label: str, action: Callable[[], None] + ) -> None: if self.dry_run: return self._rollback_actions.append((label, action)) @@ -237,7 +242,6 @@ class DatabaseSetup: def clear_rollbacks(self) -> None: self._rollback_actions.clear() - def _describe_connection(self, user: str, database: str) -> str: return f"{user}@{self.config.host}:{self.config.port}/{database}" @@ -384,9 +388,9 @@ class DatabaseSetup: try: if self.config.password: cursor.execute( - sql.SQL("CREATE ROLE {} WITH LOGIN PASSWORD %s").format( - sql.Identifier(self.config.user) - ), + sql.SQL( + "CREATE ROLE {} WITH LOGIN PASSWORD %s" + ).format(sql.Identifier(self.config.user)), (self.config.password,), ) else: @@ -589,8 +593,7 @@ class DatabaseSetup: return psycopg2.connect(dsn) except psycopg2.Error as exc: raise RuntimeError( - "Unable to establish admin connection. " - f"Target: {descriptor}" + "Unable to establish admin connection. " f"Target: {descriptor}" ) from exc def _application_connection(self) -> PGConnection: @@ -645,7 +648,9 @@ class DatabaseSetup: importlib.import_module(f"{package.__name__}.{module_info.name}") self._models_loaded = True - def run_migrations(self, migrations_dir: Optional[Path | str] = None) -> None: + def run_migrations( + self, migrations_dir: Optional[Path | str] = None + ) -> None: """Execute pending SQL migrations in chronological order.""" directory = ( @@ -673,7 +678,8 @@ class DatabaseSetup: conn.autocommit = True with conn.cursor() as cursor: table_exists = self._migrations_table_exists( - cursor, schema_name) + cursor, schema_name + ) if not table_exists: if self.dry_run: logger.info( @@ -692,12 +698,10 @@ class DatabaseSetup: applied = set() else: applied = self._fetch_applied_migrations( - cursor, schema_name) + cursor, schema_name + ) - if ( - baseline_path.exists() - and baseline_name not in applied - ): + if baseline_path.exists() and baseline_name not in applied: if self.dry_run: logger.info( "Dry run: baseline migration '%s' pending; would apply and mark legacy files", @@ -756,9 +760,7 @@ class DatabaseSetup: ) pending = [ - path - for path in migration_files - if path.name not in applied + path for path in migration_files if path.name not in applied ] if not pending: @@ -792,9 +794,7 @@ class DatabaseSetup: cursor.execute( sql.SQL( "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" - ).format( - sql.Identifier(schema_name, MIGRATIONS_TABLE) - ), + ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)), (path.name,), ) return path.name @@ -820,9 +820,7 @@ class DatabaseSetup: "filename TEXT PRIMARY KEY," "applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()" ")" - ).format( - sql.Identifier(schema_name, MIGRATIONS_TABLE) - ) + ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)) ) def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]: @@ -974,7 +972,7 @@ class DatabaseSetup: (database,), ) cursor.execute( - sql.SQL("DROP DATABASE IF EXISTS {}" ).format( + sql.SQL("DROP DATABASE IF EXISTS {}").format( sql.Identifier(database) ) ) @@ -985,7 +983,7 @@ class DatabaseSetup: conn.autocommit = True with conn.cursor() as cursor: cursor.execute( - sql.SQL("DROP ROLE IF EXISTS {}" ).format( + sql.SQL("DROP ROLE IF EXISTS {}").format( sql.Identifier(role) ) ) @@ -1000,27 +998,35 @@ class DatabaseSetup: conn.autocommit = True with conn.cursor() as cursor: cursor.execute( - sql.SQL("REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" ).format( + sql.SQL( + "REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" + ).format( sql.Identifier(schema_name), - sql.Identifier(self.config.user) + sql.Identifier(self.config.user), ) ) cursor.execute( - sql.SQL("REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" ).format( + sql.SQL( + "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" + ).format( sql.Identifier(schema_name), - sql.Identifier(self.config.user) + sql.Identifier(self.config.user), ) ) cursor.execute( - sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" ).format( + sql.SQL( + "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" + ).format( sql.Identifier(schema_name), - sql.Identifier(self.config.user) + sql.Identifier(self.config.user), ) ) cursor.execute( - sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" ).format( + sql.SQL( + "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" + ).format( sql.Identifier(schema_name), - sql.Identifier(self.config.user) + sql.Identifier(self.config.user), ) ) @@ -1064,19 +1070,18 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--db-driver", help="Override DATABASE_DRIVER") parser.add_argument("--db-host", help="Override DATABASE_HOST") - parser.add_argument("--db-port", type=int, - help="Override DATABASE_PORT") + parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT") parser.add_argument("--db-name", help="Override DATABASE_NAME") parser.add_argument("--db-user", help="Override DATABASE_USER") - parser.add_argument( - "--db-password", help="Override DATABASE_PASSWORD") + parser.add_argument("--db-password", help="Override DATABASE_PASSWORD") parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA") parser.add_argument( "--admin-url", help="Override DATABASE_ADMIN_URL for administrative operations", ) parser.add_argument( - "--admin-user", help="Override DATABASE_SUPERUSER for admin ops") + "--admin-user", help="Override DATABASE_SUPERUSER for admin ops" + ) parser.add_argument( "--admin-password", help="Override DATABASE_SUPERUSER_PASSWORD for admin ops", @@ -1091,7 +1096,11 @@ def parse_args() -> argparse.Namespace: help="Log actions without applying changes.", ) parser.add_argument( - "--verbose", "-v", action="count", default=0, help="Increase logging verbosity" + "--verbose", + "-v", + action="count", + default=0, + help="Increase logging verbosity", ) return parser.parse_args() @@ -1099,8 +1108,9 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() level = logging.WARNING - (10 * min(args.verbose, 2)) - logging.basicConfig(level=max(level, logging.INFO), - format="%(levelname)s %(message)s") + logging.basicConfig( + level=max(level, logging.INFO), format="%(levelname)s %(message)s" + ) override_args: dict[str, Optional[str]] = { "DATABASE_DRIVER": args.db_driver, @@ -1120,7 +1130,9 @@ def main() -> None: config = DatabaseConfig.from_env(overrides=override_args) setup = DatabaseSetup(config, dry_run=args.dry_run) - admin_tasks_requested = args.ensure_database or args.ensure_role or args.ensure_schema + admin_tasks_requested = ( + args.ensure_database or args.ensure_role or args.ensure_schema + ) if admin_tasks_requested: setup.validate_admin_connection() @@ -1145,9 +1157,7 @@ def main() -> None: auto_run_migrations_reason: Optional[str] = None if args.seed_data and not should_run_migrations: should_run_migrations = True - auto_run_migrations_reason = ( - "Seed data requested without explicit --run-migrations; applying migrations first." - ) + auto_run_migrations_reason = "Seed data requested without explicit --run-migrations; applying migrations first." try: if args.ensure_database: @@ -1167,9 +1177,7 @@ def main() -> None: if auto_run_migrations_reason: logger.info(auto_run_migrations_reason) migrations_path = ( - Path(args.migrations_dir) - if args.migrations_dir - else None + Path(args.migrations_dir) if args.migrations_dir else None ) setup.run_migrations(migrations_path) if args.seed_data: diff --git a/services/reporting.py b/services/reporting.py index 2950414..98387d6 100644 --- a/services/reporting.py +++ b/services/reporting.py @@ -27,7 +27,9 @@ def _percentile(values: List[float], percentile: float) -> float: return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight -def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Union[float, int]]: +def generate_report( + simulation_results: List[Dict[str, float]], +) -> Dict[str, Union[float, int]]: """Aggregate basic statistics for simulation outputs.""" values = _extract_results(simulation_results) @@ -63,7 +65,7 @@ def generate_report(simulation_results: List[Dict[str, float]]) -> Dict[str, Uni std_dev = pstdev(values) if len(values) > 1 else 0.0 summary["std_dev"] = std_dev - summary["variance"] = std_dev ** 2 + summary["variance"] = std_dev**2 var_95 = summary["percentile_5"] summary["value_at_risk_95"] = var_95 diff --git a/services/security.py b/services/security.py new file mode 100644 index 0000000..ce376e3 --- /dev/null +++ b/services/security.py @@ -0,0 +1,32 @@ +from datetime import datetime, timedelta +from typing import Any, Union + +from jose import jwt +from passlib.context import CryptContext + + +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +SECRET_KEY = "your-secret-key" # Change this in production +ALGORITHM = "HS256" + +pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto") + + +def create_access_token( + subject: Union[str, Any], expires_delta: Union[timedelta, None] = None +) -> str: + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode = {"exp": expire, "sub": str(subject)} + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password: str) -> str: + return pwd_context.hash(password) diff --git a/services/settings.py b/services/settings.py index a3ff564..51b49ac 100644 --- a/services/settings.py +++ b/services/settings.py @@ -7,6 +7,7 @@ from typing import Dict, Mapping from sqlalchemy.orm import Session from models.application_setting import ApplicationSetting +from models.theme_setting import ThemeSetting # Import ThemeSetting model CSS_COLOR_CATEGORY = "theme" CSS_COLOR_VALUE_TYPE = "color" @@ -92,7 +93,9 @@ def get_css_color_settings(db: Session) -> Dict[str, str]: return values -def update_css_color_settings(db: Session, updates: Mapping[str, str]) -> Dict[str, str]: +def update_css_color_settings( + db: Session, updates: Mapping[str, str] +) -> Dict[str, str]: """Persist provided CSS color overrides and return the final values.""" if not updates: @@ -176,8 +179,10 @@ def _validate_functional_color(value: str) -> None: def _ensure_component_count(value: str, expected: int) -> None: if not value.endswith(")"): - raise ValueError("Color function expressions must end with a closing parenthesis") - inner = value[value.index("(") + 1 : -1] + raise ValueError( + "Color function expressions must end with a closing parenthesis" + ) + inner = value[value.index("(") + 1: -1] parts = [segment.strip() for segment in inner.split(",")] if len(parts) != expected: raise ValueError( @@ -206,3 +211,20 @@ def list_css_env_override_rows( } ) return rows + + +def save_theme_settings(db: Session, theme_data: dict): + theme = db.query(ThemeSetting).first() or ThemeSetting() + for key, value in theme_data.items(): + setattr(theme, key, value) + db.add(theme) + db.commit() + db.refresh(theme) + return theme + + +def get_theme_settings(db: Session): + theme = db.query(ThemeSetting).first() + if theme: + return {c.name: getattr(theme, c.name) for c in theme.__table__.columns} + return {} diff --git a/services/simulation.py b/services/simulation.py index 4a433f2..6c8ffe1 100644 --- a/services/simulation.py +++ b/services/simulation.py @@ -25,12 +25,13 @@ def _ensure_positive_span(span: float, fallback: float) -> float: return span if span and span > 0 else fallback -def _compile_parameters(parameters: Sequence[Dict[str, float]]) -> List[SimulationParameter]: +def _compile_parameters( + parameters: Sequence[Dict[str, float]], +) -> List[SimulationParameter]: compiled: List[SimulationParameter] = [] for index, item in enumerate(parameters): if "value" not in item: - raise ValueError( - f"Parameter at index {index} must include 'value'") + raise ValueError(f"Parameter at index {index} must include 'value'") name = str(item.get("name", f"param_{index}")) base_value = float(item["value"]) distribution = str(item.get("distribution", "normal")).lower() @@ -43,8 +44,11 @@ def _compile_parameters(parameters: Sequence[Dict[str, float]]) -> List[Simulati if distribution == "normal": std_dev = item.get("std_dev") - std_dev_value = float(std_dev) if std_dev is not None else abs( - base_value) * DEFAULT_STD_DEV_RATIO or 1.0 + std_dev_value = ( + float(std_dev) + if std_dev is not None + else abs(base_value) * DEFAULT_STD_DEV_RATIO or 1.0 + ) compiled.append( SimulationParameter( name=name, diff --git a/static/js/theme.js b/static/js/theme.js new file mode 100644 index 0000000..0ff624f --- /dev/null +++ b/static/js/theme.js @@ -0,0 +1,108 @@ +// static/js/theme.js + +document.addEventListener('DOMContentLoaded', () => { + const themeSettingsForm = document.getElementById('theme-settings-form'); + const colorInputs = themeSettingsForm + ? themeSettingsForm.querySelectorAll('input[type="color"]') + : []; + + // Function to apply theme settings to CSS variables + function applyTheme(theme) { + const root = document.documentElement; + if (theme.primary_color) + root.style.setProperty('--color-primary', theme.primary_color); + if (theme.secondary_color) + root.style.setProperty('--color-secondary', theme.secondary_color); + if (theme.accent_color) + root.style.setProperty('--color-accent', theme.accent_color); + if (theme.background_color) + root.style.setProperty('--color-background', theme.background_color); + if (theme.text_color) + root.style.setProperty('--color-text-primary', theme.text_color); + // Add other theme properties as needed + } + + // Save theme to local storage + function saveTheme(theme) { + localStorage.setItem('user-theme', JSON.stringify(theme)); + } + + // Load theme from local storage + function loadTheme() { + const savedTheme = localStorage.getItem('user-theme'); + return savedTheme ? JSON.parse(savedTheme) : null; + } + + // Real-time preview for color inputs + colorInputs.forEach((input) => { + input.addEventListener('input', (event) => { + const cssVar = `--color-${event.target.id.replace('-', '_')}`; + document.documentElement.style.setProperty(cssVar, event.target.value); + }); + }); + + if (themeSettingsForm) { + themeSettingsForm.addEventListener('submit', async (event) => { + event.preventDefault(); + + const formData = new FormData(themeSettingsForm); + const themeData = Object.fromEntries(formData.entries()); + + try { + const response = await fetch('/api/theme-settings', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(themeData), + }); + + if (response.ok) { + alert('Theme settings saved successfully!'); + applyTheme(themeData); + saveTheme(themeData); + } else { + const errorData = await response.json(); + alert(`Error saving theme settings: ${errorData.detail}`); + } + } catch (error) { + console.error('Error:', error); + alert('An error occurred while saving theme settings.'); + } + }); + } + + // Load and apply theme on page load + const initialTheme = loadTheme(); + if (initialTheme) { + applyTheme(initialTheme); + // Populate form fields if on the theme settings page + if (themeSettingsForm) { + for (const key in initialTheme) { + const input = themeSettingsForm.querySelector( + `#${key.replace('_', '-')}` + ); + if (input) { + input.value = initialTheme[key]; + } + } + } + } else { + // If no saved theme, load from backend (if available) + async function loadAndApplyThemeFromServer() { + try { + const response = await fetch('/api/theme-settings'); // Assuming a GET endpoint for theme settings + if (response.ok) { + const theme = await response.json(); + applyTheme(theme); + saveTheme(theme); // Save to local storage for future use + } else { + console.error('Failed to load theme settings from server'); + } + } catch (error) { + console.error('Error loading theme settings from server:', error); + } + } + loadAndApplyThemeFromServer(); + } +}); diff --git a/templates/base.html b/templates/base.html index d0f5ac8..53722db 100644 --- a/templates/base.html +++ b/templates/base.html @@ -20,5 +20,6 @@ {% block scripts %}{% endblock %} + diff --git a/templates/forgot_password.html b/templates/forgot_password.html new file mode 100644 index 0000000..4d21fd3 --- /dev/null +++ b/templates/forgot_password.html @@ -0,0 +1,17 @@ +{% extends "base.html" %} + +{% block title %}Forgot Password{% endblock %} + +{% block content %} +
+

Forgot Password

+
+
+ + +
+ +
+

Remember your password? Login here

+
+{% endblock %} diff --git a/templates/login.html b/templates/login.html new file mode 100644 index 0000000..6c2eb00 --- /dev/null +++ b/templates/login.html @@ -0,0 +1,22 @@ +{% extends "base.html" %} + +{% block title %}Login{% endblock %} + +{% block content %} +
+

Login

+
+
+ + +
+
+ + +
+ +
+

Don't have an account? Register here

+

Forgot password?

+
+{% endblock %} diff --git a/templates/partials/sidebar_nav.html b/templates/partials/sidebar_nav.html index 91e006c..ba313c5 100644 --- a/templates/partials/sidebar_nav.html +++ b/templates/partials/sidebar_nav.html @@ -1,88 +1,49 @@ -{% set nav_groups = [ - { - "label": "Dashboard", - "links": [ - {"href": "/", "label": "Dashboard"}, - ], - }, - { - "label": "Scenarios", - "links": [ - {"href": "/ui/scenarios", "label": "Overview"}, - {"href": "/ui/parameters", "label": "Parameters"}, - {"href": "/ui/costs", "label": "Costs"}, - {"href": "/ui/consumption", "label": "Consumption"}, - {"href": "/ui/production", "label": "Production"}, - { - "href": "/ui/equipment", - "label": "Equipment", - "children": [ - {"href": "/ui/maintenance", "label": "Maintenance"}, - ], - }, - ], - }, - { - "label": "Analysis", - "links": [ - {"href": "/ui/simulations", "label": "Simulations"}, - {"href": "/ui/reporting", "label": "Reporting"}, - ], - }, - { - "label": "Settings", - "links": [ - { - "href": "/ui/settings", - "label": "Settings", - "children": [ - {"href": "/ui/currencies", "label": "Currency Management"}, - ], - }, - ], - }, -] %} +{% set nav_groups = [ { "label": "Dashboard", "links": [ {"href": "/", "label": +"Dashboard"}, ], }, { "label": "Overview", "links": [ {"href": "/ui/parameters", +"label": "Parameters"}, {"href": "/ui/costs", "label": "Costs"}, {"href": +"/ui/consumption", "label": "Consumption"}, {"href": "/ui/production", "label": +"Production"}, { "href": "/ui/equipment", "label": "Equipment", "children": [ +{"href": "/ui/maintenance", "label": "Maintenance"}, ], }, ], }, { "label": +"Simulations", "links": [ {"href": "/ui/simulations", "label": "Simulations"}, +], }, { "label": "Analytics", "links": [ {"href": "/ui/reporting", "label": +"Reporting"}, ], }, { "label": "Settings", "links": [ { "href": "/ui/settings", +"label": "Settings", "children": [ {"href": "/theme-settings", "label": +"Themes"}, {"href": "/ui/currencies", "label": "Currency Management"}, ], }, ], +}, ] %} diff --git a/templates/profile.html b/templates/profile.html new file mode 100644 index 0000000..4e9a861 --- /dev/null +++ b/templates/profile.html @@ -0,0 +1,31 @@ +{% extends "base.html" %} + +{% block title %}Profile{% endblock %} + +{% block content %} +
+

User Profile

+

Username:

+

Email:

+ + + +
+{% endblock %} diff --git a/templates/register.html b/templates/register.html new file mode 100644 index 0000000..04a7b4e --- /dev/null +++ b/templates/register.html @@ -0,0 +1,25 @@ +{% extends "base.html" %} + +{% block title %}Register{% endblock %} + +{% block content %} +
+

Register

+
+
+ + +
+
+ + +
+
+ + +
+ +
+

Already have an account? Login here

+
+{% endblock %} diff --git a/templates/settings.html b/templates/settings.html index 0942acb..1fcbc21 100644 --- a/templates/settings.html +++ b/templates/settings.html @@ -1,113 +1,26 @@ -{% extends "base.html" %} - -{% block title %}Settings · CalMiner{% endblock %} - -{% block content %} - -
-
-

Currency Management

-

Manage available currencies, symbols, and default selections from the Currency Management page.

- Go to Currency Management -
-
-

Visual Theme

-

Adjust CalMiner theme colors and preview changes instantly.

-

Changes save to the settings table and apply across the UI after submission. Environment overrides (if configured) remain read-only.

-
-
- -
-
-
-

Theme Colors

-

Update global CSS variables to customize CalMiner's appearance.

-
-
-
- {% for key, value in css_variables.items() %} - {% set env_meta = css_env_override_meta.get(key) %} - - {% endfor %} - -
- - -
-
- {% from "partials/components.html" import feedback with context %} - {{ feedback("theme-settings-feedback") }} -
- -
-
-
-

Environment Overrides

-

The following CSS variables are controlled via environment variables and take precedence over database values.

-
-
- {% if css_env_override_rows %} -
- - - - - - - - - - {% for row in css_env_override_rows %} - - - - - - {% endfor %} - -
CSS VariableEnvironment VariableValue
{{ row.css_key }}{{ row.env_var }}{{ row.value }}
-
- {% else %} -

No environment overrides configured.

- {% endif %} -
-{% endblock %} - -{% block scripts %} - {{ super() }} - - +{% extends "base.html" %} {% block title %}Settings · CalMiner{% endblock %} {% +block content %} + +
+
+

Currency Management

+

+ Manage available currencies, symbols, and default selections from the + Currency Management page. +

+ Go to Currency Management +
+ +
{% endblock %} diff --git a/templates/theme_settings.html b/templates/theme_settings.html new file mode 100644 index 0000000..72cecf4 --- /dev/null +++ b/templates/theme_settings.html @@ -0,0 +1,125 @@ +{% extends "base.html" %} {% block title %}Theme Settings · CalMiner{% endblock +%} {% block content %} + + +
+
+
+

Theme Colors

+

+ Update global CSS variables to customize CalMiner's appearance. +

+
+
+
+ {% for key, value in css_variables.items() %} {% set env_meta = + css_env_override_meta.get(key) %} + + {% endfor %} + +
+ + +
+
+ {% from "partials/components.html" import feedback with context %} {{ + feedback("theme-settings-feedback") }} +
+ +
+
+
+

Environment Overrides

+

+ The following CSS variables are controlled via environment variables and + take precedence over database values. +

+
+
+ {% if css_env_override_rows %} +
+ + + + + + + + + + {% for row in css_env_override_rows %} + + + + + + {% endfor %} + +
CSS VariableEnvironment VariableValue
{{ row.css_key }}{{ row.env_var }}{{ row.value }}
+
+ {% else %} +

No environment overrides configured.

+ {% endif %} +
+{% endblock %} {% block scripts %} {{ super() }} + + +{% endblock %} diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index bfb6f1d..6ced399 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -4,6 +4,7 @@ import time from typing import Dict, Generator import pytest + # type: ignore[import] from playwright.sync_api import Browser, Page, Playwright, sync_playwright @@ -70,10 +71,17 @@ def seed_default_currencies(live_server: str) -> None: seeds = [ {"code": "EUR", "name": "Euro", "symbol": "EUR", "is_active": True}, - {"code": "CLP", "name": "Chilean Peso", "symbol": "CLP$", "is_active": True}, + { + "code": "CLP", + "name": "Chilean Peso", + "symbol": "CLP$", + "is_active": True, + }, ] - with httpx.Client(base_url=live_server, timeout=5.0, trust_env=False) as client: + with httpx.Client( + base_url=live_server, timeout=5.0, trust_env=False + ) as client: try: response = client.get("/api/currencies/?include_inactive=true") response.raise_for_status() @@ -128,8 +136,12 @@ def page(browser: Browser, live_server: str) -> Generator[Page, None, None]: def _prepare_database_environment(env: Dict[str, str]) -> Dict[str, str]: """Ensure granular database env vars are available for the app under test.""" - required = ("DATABASE_HOST", "DATABASE_USER", - "DATABASE_NAME", "DATABASE_PASSWORD") + required = ( + "DATABASE_HOST", + "DATABASE_USER", + "DATABASE_NAME", + "DATABASE_PASSWORD", + ) if all(env.get(key) for key in required): return env diff --git a/tests/e2e/test_consumption.py b/tests/e2e/test_consumption.py index 1303e71..685db93 100644 --- a/tests/e2e/test_consumption.py +++ b/tests/e2e/test_consumption.py @@ -7,7 +7,9 @@ def test_consumption_form_loads(page: Page): """Verify the consumption form page loads correctly.""" page.goto("/ui/consumption") expect(page).to_have_title("Consumption · CalMiner") - expect(page.locator("h2:has-text('Add Consumption Record')")).to_be_visible() + expect( + page.locator("h2:has-text('Add Consumption Record')") + ).to_be_visible() def test_create_consumption_item(page: Page): diff --git a/tests/e2e/test_costs.py b/tests/e2e/test_costs.py index 6e52b3b..c49439a 100644 --- a/tests/e2e/test_costs.py +++ b/tests/e2e/test_costs.py @@ -55,7 +55,9 @@ def test_create_capex_and_opex_items(page: Page): ).to_be_visible() # Verify the feedback messages. - expect(page.locator("#capex-feedback") - ).to_have_text("Entry saved successfully.") - expect(page.locator("#opex-feedback") - ).to_have_text("Entry saved successfully.") + expect(page.locator("#capex-feedback")).to_have_text( + "Entry saved successfully." + ) + expect(page.locator("#opex-feedback")).to_have_text( + "Entry saved successfully." + ) diff --git a/tests/e2e/test_currencies.py b/tests/e2e/test_currencies.py index b467ad1..4b7f8d0 100644 --- a/tests/e2e/test_currencies.py +++ b/tests/e2e/test_currencies.py @@ -12,7 +12,8 @@ def _unique_currency_code(existing: set[str]) -> str: if candidate not in existing and candidate != "USD": return candidate raise AssertionError( - "Unable to generate a unique currency code for the test run.") + "Unable to generate a unique currency code for the test run." + ) def _metric_value(page: Page, element_id: str) -> int: @@ -42,8 +43,9 @@ def test_currency_workflow_create_update_toggle(page: Page) -> None: expect(page.locator("h2:has-text('Currency Overview')")).to_be_visible() code_cells = page.locator("#currencies-table-body tr td:nth-child(1)") - existing_codes = {text.strip().upper() - for text in code_cells.all_inner_texts()} + existing_codes = { + text.strip().upper() for text in code_cells.all_inner_texts() + } total_before = _metric_value(page, "currency-metric-total") active_before = _metric_value(page, "currency-metric-active") @@ -109,7 +111,9 @@ def test_currency_workflow_create_update_toggle(page: Page) -> None: toggle_button = row.locator("button[data-action='toggle']") expect(toggle_button).to_have_text("Activate") - with page.expect_response(f"**/api/currencies/{new_code}/activation") as toggle_info: + with page.expect_response( + f"**/api/currencies/{new_code}/activation" + ) as toggle_info: toggle_button.click() toggle_response = toggle_info.value assert toggle_response.status == 200 @@ -126,5 +130,6 @@ def test_currency_workflow_create_update_toggle(page: Page) -> None: _expect_feedback(page, f"Currency {new_code} activated.") expect(row.locator("td").nth(3)).to_contain_text("Active") - expect(row.locator("button[data-action='toggle']") - ).to_have_text("Deactivate") + expect(row.locator("button[data-action='toggle']")).to_have_text( + "Deactivate" + ) diff --git a/tests/e2e/test_equipment.py b/tests/e2e/test_equipment.py index 5e0c4f3..f507a6e 100644 --- a/tests/e2e/test_equipment.py +++ b/tests/e2e/test_equipment.py @@ -38,11 +38,8 @@ def test_create_equipment_item(page: Page): # Verify the new item appears in the table. page.select_option("#equipment-scenario-filter", label=scenario_name) expect( - page.locator("#equipment-table-body tr").filter( - has_text=equipment_name - ) + page.locator("#equipment-table-body tr").filter(has_text=equipment_name) ).to_be_visible() # Verify the feedback message. - expect(page.locator("#equipment-feedback") - ).to_have_text("Equipment saved.") + expect(page.locator("#equipment-feedback")).to_have_text("Equipment saved.") diff --git a/tests/e2e/test_maintenance.py b/tests/e2e/test_maintenance.py index 08dc77c..fb9a403 100644 --- a/tests/e2e/test_maintenance.py +++ b/tests/e2e/test_maintenance.py @@ -53,5 +53,6 @@ def test_create_maintenance_item(page: Page): ).to_be_visible() # Verify the feedback message. - expect(page.locator("#maintenance-feedback") - ).to_have_text("Maintenance entry saved.") + expect(page.locator("#maintenance-feedback")).to_have_text( + "Maintenance entry saved." + ) diff --git a/tests/e2e/test_production.py b/tests/e2e/test_production.py index 09c98bb..72a63ba 100644 --- a/tests/e2e/test_production.py +++ b/tests/e2e/test_production.py @@ -43,5 +43,6 @@ def test_create_production_item(page: Page): ).to_be_visible() # Verify the feedback message. - expect(page.locator("#production-feedback") - ).to_have_text("Production output saved.") + expect(page.locator("#production-feedback")).to_have_text( + "Production output saved." + ) diff --git a/tests/e2e/test_scenarios.py b/tests/e2e/test_scenarios.py index 0f3a419..04f37ea 100644 --- a/tests/e2e/test_scenarios.py +++ b/tests/e2e/test_scenarios.py @@ -39,4 +39,5 @@ def test_create_new_scenario(page: Page): feedback = page.locator("#feedback") expect(feedback).to_be_visible() expect(feedback).to_have_text( - f'Scenario "{scenario_name}" created successfully.') + f'Scenario "{scenario_name}" created successfully.' + ) diff --git a/tests/e2e/test_smoke.py b/tests/e2e/test_smoke.py index 291d007..a9f0b23 100644 --- a/tests/e2e/test_smoke.py +++ b/tests/e2e/test_smoke.py @@ -5,7 +5,11 @@ from playwright.sync_api import Page, expect UI_ROUTES = [ ("/", "Dashboard · CalMiner", "Operations Overview"), ("/ui/dashboard", "Dashboard · CalMiner", "Operations Overview"), - ("/ui/scenarios", "Scenario Management · CalMiner", "Create a New Scenario"), + ( + "/ui/scenarios", + "Scenario Management · CalMiner", + "Create a New Scenario", + ), ("/ui/parameters", "Process Parameters · CalMiner", "Scenario Parameters"), ("/ui/settings", "Settings · CalMiner", "Settings"), ("/ui/costs", "Costs · CalMiner", "Cost Overview"), @@ -20,35 +24,44 @@ UI_ROUTES = [ @pytest.mark.parametrize("url, title, heading", UI_ROUTES) -def test_ui_pages_load_correctly(page: Page, url: str, title: str, heading: str): +def test_ui_pages_load_correctly( + page: Page, url: str, title: str, heading: str +): """Verify that all UI pages load with the correct title and a visible heading.""" page.goto(url) expect(page).to_have_title(title) # The app uses a mix of h1 and h2 for main page headings. heading_locator = page.locator( - f"h1:has-text('{heading}'), h2:has-text('{heading}')") + f"h1:has-text('{heading}'), h2:has-text('{heading}')" + ) expect(heading_locator.first).to_be_visible() def test_settings_theme_form_interaction(page: Page): - page.goto("/ui/settings") - expect(page).to_have_title("Settings · CalMiner") + page.goto("/theme-settings") + expect(page).to_have_title("Theme Settings · CalMiner") env_rows = page.locator("#theme-env-overrides tbody tr") disabled_inputs = page.locator( - "#theme-settings-form input.color-value-input[disabled]") + "#theme-settings-form input.color-value-input[disabled]" + ) env_row_count = env_rows.count() disabled_count = disabled_inputs.count() assert disabled_count == env_row_count color_input = page.locator( - "#theme-settings-form input[name='--color-primary']") + "#theme-settings-form input[name='--color-primary']" + ) expect(color_input).to_be_visible() expect(color_input).to_be_enabled() original_value = color_input.input_value() candidate_values = ("#114455", "#225566") - new_value = candidate_values[0] if original_value != candidate_values[0] else candidate_values[1] + new_value = ( + candidate_values[0] + if original_value != candidate_values[0] + else candidate_values[1] + ) color_input.fill(new_value) page.click("#theme-settings-form button[type='submit']") diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0ecb00e..00d8401 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -27,7 +27,8 @@ engine = create_engine( poolclass=StaticPool, ) TestingSessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine) + autocommit=False, autoflush=False, bind=engine +) @pytest.fixture(scope="session", autouse=True) @@ -37,19 +38,24 @@ def setup_database() -> Generator[None, None, None]: application_setting, capex, consumption, + currency, distribution, equipment, maintenance, opex, parameters, production_output, + role, scenario, simulation_result, + theme_setting, + user, ) # noqa: F401 - imported for side effects _ = ( capex, consumption, + currency, distribution, equipment, maintenance, @@ -57,8 +63,11 @@ def setup_database() -> Generator[None, None, None]: opex, parameters, production_output, + role, scenario, simulation_result, + theme_setting, + user, ) Base.metadata.create_all(bind=engine) @@ -86,22 +95,23 @@ def api_client(db_session: Session) -> Generator[TestClient, None, None]: finally: pass - from routes import dependencies as route_dependencies + from routes.dependencies import get_db - app.dependency_overrides[route_dependencies.get_db] = override_get_db + app.dependency_overrides[get_db] = override_get_db with TestClient(app) as client: yield client - app.dependency_overrides.pop(route_dependencies.get_db, None) + app.dependency_overrides.pop(get_db, None) @pytest.fixture() -def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]: +def seeded_ui_data( + db_session: Session, +) -> Generator[Dict[str, Any], None, None]: """Populate a scenario with representative related records for UI tests.""" scenario_name = f"Scenario Alpha {uuid4()}" - scenario = Scenario(name=scenario_name, - description="Seeded UI scenario") + scenario = Scenario(name=scenario_name, description="Seeded UI scenario") db_session.add(scenario) db_session.flush() @@ -161,7 +171,9 @@ def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None] iteration=index, result=value, ) - for index, value in enumerate((950_000.0, 975_000.0, 990_000.0), start=1) + for index, value in enumerate( + (950_000.0, 975_000.0, 990_000.0), start=1 + ) ] db_session.add(maintenance) @@ -196,11 +208,15 @@ def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None] @pytest.fixture() -def invalid_request_payloads(db_session: Session) -> Generator[Dict[str, Any], None, None]: +def invalid_request_payloads( + db_session: Session, +) -> Generator[Dict[str, Any], None, None]: """Provide reusable invalid request bodies for exercising validation branches.""" duplicate_name = f"Scenario Duplicate {uuid4()}" - existing = Scenario(name=duplicate_name, - description="Existing scenario for duplicate checks") + existing = Scenario( + name=duplicate_name, + description="Existing scenario for duplicate checks", + ) db_session.add(existing) db_session.commit() diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 0000000..f4c0251 --- /dev/null +++ b/tests/unit/test_auth.py @@ -0,0 +1,231 @@ +from services.security import get_password_hash, verify_password + + +def test_password_hashing(): + password = "testpassword" + hashed_password = get_password_hash(password) + assert verify_password(password, hashed_password) + assert not verify_password("wrongpassword", hashed_password) + + +def test_register_user(api_client): + response = api_client.post( + "/users/register", + json={ + "username": "testuser", + "email": "test@example.com", + "password": "testpassword", + }, + ) + assert response.status_code == 201 + data = response.json() + assert data["username"] == "testuser" + assert data["email"] == "test@example.com" + assert "id" in data + assert "role_id" in data + + response = api_client.post( + "/users/register", + json={ + "username": "testuser", + "email": "another@example.com", + "password": "testpassword", + }, + ) + assert response.status_code == 400 + assert response.json() == {"detail": "Username already registered"} + + response = api_client.post( + "/users/register", + json={ + "username": "anotheruser", + "email": "test@example.com", + "password": "testpassword", + }, + ) + assert response.status_code == 400 + assert response.json() == {"detail": "Email already registered"} + + +def test_login_user(api_client): + # Register a user first + api_client.post( + "/users/register", + json={ + "username": "loginuser", + "email": "login@example.com", + "password": "loginpassword", + }, + ) + + response = api_client.post( + "/users/login", + json={"username": "loginuser", "password": "loginpassword"}, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + + response = api_client.post( + "/users/login", + json={"username": "loginuser", "password": "wrongpassword"}, + ) + assert response.status_code == 401 + assert response.json() == {"detail": "Incorrect username or password"} + + response = api_client.post( + "/users/login", + json={"username": "nonexistent", "password": "password"}, + ) + assert response.status_code == 401 + assert response.json() == {"detail": "Incorrect username or password"} + + +def test_read_users_me(api_client): + # Register a user first + api_client.post( + "/users/register", + json={ + "username": "profileuser", + "email": "profile@example.com", + "password": "profilepassword", + }, + ) + # Login to get a token + login_response = api_client.post( + "/users/login", + json={"username": "profileuser", "password": "profilepassword"}, + ) + token = login_response.json()["access_token"] + + response = api_client.get( + "/users/me", headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + data = response.json() + assert data["username"] == "profileuser" + assert data["email"] == "profile@example.com" + + +def test_update_users_me(api_client): + # Register a user first + api_client.post( + "/users/register", + json={ + "username": "updateuser", + "email": "update@example.com", + "password": "updatepassword", + }, + ) + # Login to get a token + login_response = api_client.post( + "/users/login", + json={"username": "updateuser", "password": "updatepassword"}, + ) + token = login_response.json()["access_token"] + + response = api_client.put( + "/users/me", + headers={"Authorization": f"Bearer {token}"}, + json={ + "username": "updateduser", + "email": "updated@example.com", + "password": "newpassword", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["username"] == "updateduser" + assert data["email"] == "updated@example.com" + + # Verify password change + response = api_client.post( + "/users/login", + json={"username": "updateduser", "password": "newpassword"}, + ) + assert response.status_code == 200 + token = response.json()["access_token"] + + # Test username already taken + api_client.post( + "/users/register", + json={ + "username": "anotherupdateuser", + "email": "anotherupdate@example.com", + "password": "password", + }, + ) + response = api_client.put( + "/users/me", + headers={"Authorization": f"Bearer {token}"}, + json={ + "username": "anotherupdateuser", + }, + ) + assert response.status_code == 400 + assert response.json() == {"detail": "Username already taken"} + + # Test email already registered + api_client.post( + "/users/register", + json={ + "username": "yetanotheruser", + "email": "yetanother@example.com", + "password": "password", + }, + ) + response = api_client.put( + "/users/me", + headers={"Authorization": f"Bearer {token}"}, + json={ + "email": "yetanother@example.com", + }, + ) + assert response.status_code == 400 + assert response.json() == {"detail": "Email already registered"} + + +def test_forgot_password(api_client): + response = api_client.post( + "/users/forgot-password", json={"email": "nonexistent@example.com"} + ) + assert response.status_code == 200 + assert response.json() == { + "message": "Password reset email sent (not really)"} + + +def test_reset_password(api_client): + # Register a user first + api_client.post( + "/users/register", + json={ + "username": "resetuser", + "email": "reset@example.com", + "password": "oldpassword", + }, + ) + + response = api_client.post( + "/users/reset-password", + json={ + "token": "resetuser", # Use username as token for test + "new_password": "newpassword", + }, + ) + assert response.status_code == 200 + assert response.json() == { + "message": "Password has been reset successfully"} + + # Verify password change + response = api_client.post( + "/users/login", + json={"username": "resetuser", "password": "newpassword"}, + ) + assert response.status_code == 200 + + response = api_client.post( + "/users/login", + json={"username": "resetuser", "password": "oldpassword"}, + ) + assert response.status_code == 401 diff --git a/tests/unit/test_consumption.py b/tests/unit/test_consumption.py index 0c9f57a..9ea7bb3 100644 --- a/tests/unit/test_consumption.py +++ b/tests/unit/test_consumption.py @@ -57,8 +57,11 @@ def test_list_consumption_returns_created_items(client: TestClient) -> None: list_response = client.get("/api/consumption/") assert list_response.status_code == 200 - items = [item for item in list_response.json( - ) if item["scenario_id"] == scenario_id] + items = [ + item + for item in list_response.json() + if item["scenario_id"] == scenario_id + ] assert {item["amount"] for item in items} == set(values) diff --git a/tests/unit/test_costs.py b/tests/unit/test_costs.py index 45bec19..ae4059c 100644 --- a/tests/unit/test_costs.py +++ b/tests/unit/test_costs.py @@ -47,8 +47,9 @@ def test_create_and_list_capex_and_opex(): resp3 = client.get("/api/costs/capex") assert resp3.status_code == 200 data = resp3.json() - assert any(item["amount"] == 1000.0 and item["scenario_id"] - == sid for item in data) + assert any( + item["amount"] == 1000.0 and item["scenario_id"] == sid for item in data + ) opex_payload = { "scenario_id": sid, @@ -66,8 +67,10 @@ def test_create_and_list_capex_and_opex(): resp5 = client.get("/api/costs/opex") assert resp5.status_code == 200 data_o = resp5.json() - assert any(item["amount"] == 500.0 and item["scenario_id"] - == sid for item in data_o) + assert any( + item["amount"] == 500.0 and item["scenario_id"] == sid + for item in data_o + ) def test_multiple_capex_entries(): @@ -88,8 +91,9 @@ def test_multiple_capex_entries(): resp = client.get("/api/costs/capex") assert resp.status_code == 200 data = resp.json() - retrieved_amounts = [item["amount"] - for item in data if item["scenario_id"] == sid] + retrieved_amounts = [ + item["amount"] for item in data if item["scenario_id"] == sid + ] for amount in amounts: assert amount in retrieved_amounts @@ -112,7 +116,8 @@ def test_multiple_opex_entries(): resp = client.get("/api/costs/opex") assert resp.status_code == 200 data = resp.json() - retrieved_amounts = [item["amount"] - for item in data if item["scenario_id"] == sid] + retrieved_amounts = [ + item["amount"] for item in data if item["scenario_id"] == sid + ] for amount in amounts: assert amount in retrieved_amounts diff --git a/tests/unit/test_currencies.py b/tests/unit/test_currencies.py index 5aa674c..044571e 100644 --- a/tests/unit/test_currencies.py +++ b/tests/unit/test_currencies.py @@ -14,7 +14,13 @@ def _cleanup_currencies(db_session): db_session.commit() -def _assert_currency(payload: Dict[str, object], code: str, name: str, symbol: str | None, is_active: bool) -> None: +def _assert_currency( + payload: Dict[str, object], + code: str, + name: str, + symbol: str | None, + is_active: bool, +) -> None: assert payload["code"] == code assert payload["name"] == name assert payload["is_active"] is is_active @@ -47,13 +53,21 @@ def test_create_currency_success(api_client, db_session): def test_create_currency_conflict(api_client, db_session): api_client.post( "/api/currencies/", - json={"code": "CAD", "name": "Canadian Dollar", - "symbol": "$", "is_active": True}, + json={ + "code": "CAD", + "name": "Canadian Dollar", + "symbol": "$", + "is_active": True, + }, ) duplicate = api_client.post( "/api/currencies/", - json={"code": "CAD", "name": "Canadian Dollar", - "symbol": "$", "is_active": True}, + json={ + "code": "CAD", + "name": "Canadian Dollar", + "symbol": "$", + "is_active": True, + }, ) assert duplicate.status_code == 409 @@ -61,8 +75,12 @@ def test_create_currency_conflict(api_client, db_session): def test_update_currency_fields(api_client, db_session): api_client.post( "/api/currencies/", - json={"code": "GBP", "name": "British Pound", - "symbol": "£", "is_active": True}, + json={ + "code": "GBP", + "name": "British Pound", + "symbol": "£", + "is_active": True, + }, ) response = api_client.put( @@ -77,8 +95,12 @@ def test_update_currency_fields(api_client, db_session): def test_toggle_currency_activation(api_client, db_session): api_client.post( "/api/currencies/", - json={"code": "AUD", "name": "Australian Dollar", - "symbol": "A$", "is_active": True}, + json={ + "code": "AUD", + "name": "Australian Dollar", + "symbol": "A$", + "is_active": True, + }, ) response = api_client.patch( @@ -97,5 +119,7 @@ def test_default_currency_cannot_be_deactivated(api_client, db_session): json={"is_active": False}, ) assert response.status_code == 400 - assert response.json()[ - "detail"] == "The default currency cannot be deactivated." + assert ( + response.json()["detail"] + == "The default currency cannot be deactivated." + ) diff --git a/tests/unit/test_currency_workflow.py b/tests/unit/test_currency_workflow.py index 79dba58..f43809a 100644 --- a/tests/unit/test_currency_workflow.py +++ b/tests/unit/test_currency_workflow.py @@ -41,9 +41,10 @@ def test_create_capex_with_currency_code_and_list(api_client, seeded_currency): resp = api_client.post("/api/costs/capex", json=payload) assert resp.status_code == 200 data = resp.json() - assert data.get("currency_code") == seeded_currency.code or data.get( - "currency", {} - ).get("code") == seeded_currency.code + assert ( + data.get("currency_code") == seeded_currency.code + or data.get("currency", {}).get("code") == seeded_currency.code + ) def test_create_opex_with_currency_id(api_client, seeded_currency): diff --git a/tests/unit/test_maintenance.py b/tests/unit/test_maintenance.py index afe85ad..64e646c 100644 --- a/tests/unit/test_maintenance.py +++ b/tests/unit/test_maintenance.py @@ -30,7 +30,9 @@ def _create_scenario_and_equipment(client: TestClient): return scenario_id, equipment_id -def _create_maintenance_payload(equipment_id: int, scenario_id: int, description: str): +def _create_maintenance_payload( + equipment_id: int, scenario_id: int, description: str +): return { "equipment_id": equipment_id, "scenario_id": scenario_id, @@ -43,7 +45,8 @@ def _create_maintenance_payload(equipment_id: int, scenario_id: int, description def test_create_and_list_maintenance(client: TestClient): scenario_id, equipment_id = _create_scenario_and_equipment(client) payload = _create_maintenance_payload( - equipment_id, scenario_id, "Create maintenance") + equipment_id, scenario_id, "Create maintenance" + ) response = client.post("/api/maintenance/", json=payload) assert response.status_code == 201 @@ -95,7 +98,8 @@ def test_update_maintenance(client: TestClient): } response = client.put( - f"/api/maintenance/{maintenance_id}", json=update_payload) + f"/api/maintenance/{maintenance_id}", json=update_payload + ) assert response.status_code == 200 updated = response.json() assert updated["maintenance_date"] == "2025-11-01" @@ -108,7 +112,8 @@ def test_delete_maintenance(client: TestClient): create_response = client.post( "/api/maintenance/", json=_create_maintenance_payload( - equipment_id, scenario_id, "Delete maintenance"), + equipment_id, scenario_id, "Delete maintenance" + ), ) assert create_response.status_code == 201 maintenance_id = create_response.json()["id"] diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index 86081a7..e1895e7 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -67,7 +67,10 @@ def test_create_and_list_parameter(): def test_create_parameter_for_missing_scenario(): payload: Dict[str, Any] = { - "scenario_id": 0, "name": "invalid", "value": 1.0} + "scenario_id": 0, + "name": "invalid", + "value": 1.0, + } response = client.post("/api/parameters/", json=payload) assert response.status_code == 404 assert response.json()["detail"] == "Scenario not found" diff --git a/tests/unit/test_production.py b/tests/unit/test_production.py index cd7c851..106721d 100644 --- a/tests/unit/test_production.py +++ b/tests/unit/test_production.py @@ -42,7 +42,11 @@ def test_list_production_filters_by_scenario(client: TestClient) -> None: target_scenario = _create_scenario(client) other_scenario = _create_scenario(client) - for scenario_id, amount in [(target_scenario, 100.0), (target_scenario, 150.0), (other_scenario, 200.0)]: + for scenario_id, amount in [ + (target_scenario, 100.0), + (target_scenario, 150.0), + (other_scenario, 200.0), + ]: response = client.post( "/api/production/", json={ @@ -57,8 +61,11 @@ def test_list_production_filters_by_scenario(client: TestClient) -> None: list_response = client.get("/api/production/") assert list_response.status_code == 200 - items = [item for item in list_response.json() - if item["scenario_id"] == target_scenario] + items = [ + item + for item in list_response.json() + if item["scenario_id"] == target_scenario + ] assert {item["amount"] for item in items} == {100.0, 150.0} diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py index ace8c37..45adf38 100644 --- a/tests/unit/test_reporting.py +++ b/tests/unit/test_reporting.py @@ -50,9 +50,11 @@ def test_generate_report_with_values(): def test_generate_report_single_value(): - report = generate_report([ - {"iteration": 1, "result": 42.0}, - ]) + report = generate_report( + [ + {"iteration": 1, "result": 42.0}, + ] + ) assert report["count"] == 1 assert report["std_dev"] == 0.0 assert report["variance"] == 0.0 @@ -105,8 +107,10 @@ def test_reporting_endpoint_success(client: TestClient): validation_error_cases: List[tuple[List[Any], str]] = [ (["not-a-dict"], "Entry at index 0 must be an object"), ([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"), - ([{"iteration": 1, "result": "bad"}], - "Entry at index 0 must include numeric 'result'"), + ( + [{"iteration": 1, "result": "bad"}], + "Entry at index 0 must include numeric 'result'", + ), ] diff --git a/tests/unit/test_router_validation.py b/tests/unit/test_router_validation.py index bd98f84..4c81b73 100644 --- a/tests/unit/test_router_validation.py +++ b/tests/unit/test_router_validation.py @@ -27,7 +27,7 @@ def test_parameter_create_missing_scenario_returns_404( @pytest.mark.usefixtures("invalid_request_payloads") def test_parameter_create_invalid_distribution_is_422( - api_client: TestClient + api_client: TestClient, ) -> None: response = api_client.post( "/api/parameters/", @@ -90,6 +90,5 @@ def test_maintenance_negative_cost_rejected_by_schema( payload = invalid_request_payloads["maintenance_negative_cost"] response = api_client.post("/api/maintenance/", json=payload) assert response.status_code == 422 - error_locations = [tuple(item["loc"]) - for item in response.json()["detail"]] + error_locations = [tuple(item["loc"]) for item in response.json()["detail"]] assert ("body", "cost") in error_locations diff --git a/tests/unit/test_settings_routes.py b/tests/unit/test_settings_routes.py index 1aa691c..81a1aa9 100644 --- a/tests/unit/test_settings_routes.py +++ b/tests/unit/test_settings_routes.py @@ -42,7 +42,7 @@ def test_update_css_settings_persists_changes( @pytest.mark.usefixtures("db_session") def test_update_css_settings_invalid_value_returns_422( - api_client: TestClient + api_client: TestClient, ) -> None: response = api_client.put( "/api/settings/css", diff --git a/tests/unit/test_settings_service.py b/tests/unit/test_settings_service.py index a244c7c..8066c06 100644 --- a/tests/unit/test_settings_service.py +++ b/tests/unit/test_settings_service.py @@ -20,8 +20,14 @@ def fixture_clean_env(monkeypatch: pytest.MonkeyPatch) -> Dict[str, str]: def test_css_key_to_env_var_formatting(): - assert settings_service.css_key_to_env_var("--color-background") == "CALMINER_THEME_COLOR_BACKGROUND" - assert settings_service.css_key_to_env_var("--color-primary-stronger") == "CALMINER_THEME_COLOR_PRIMARY_STRONGER" + assert ( + settings_service.css_key_to_env_var("--color-background") + == "CALMINER_THEME_COLOR_BACKGROUND" + ) + assert ( + settings_service.css_key_to_env_var("--color-primary-stronger") + == "CALMINER_THEME_COLOR_PRIMARY_STRONGER" + ) @pytest.mark.parametrize( @@ -33,7 +39,9 @@ def test_css_key_to_env_var_formatting(): ("--color-text-secondary", "hsla(210, 40%, 40%, 1)"), ], ) -def test_read_css_color_env_overrides_valid_values(clean_env, env_key, env_value): +def test_read_css_color_env_overrides_valid_values( + clean_env, env_key, env_value +): env_var = settings_service.css_key_to_env_var(env_key) clean_env[env_var] = env_value @@ -50,7 +58,9 @@ def test_read_css_color_env_overrides_valid_values(clean_env, env_key, env_value "rgb(1,2)", # malformed rgb ], ) -def test_read_css_color_env_overrides_invalid_values_raise(clean_env, invalid_value): +def test_read_css_color_env_overrides_invalid_values_raise( + clean_env, invalid_value +): env_var = settings_service.css_key_to_env_var("--color-background") clean_env[env_var] = invalid_value @@ -64,7 +74,9 @@ def test_read_css_color_env_overrides_ignores_missing(clean_env): def test_list_css_env_override_rows_returns_structured_data(clean_env): - clean_env[settings_service.css_key_to_env_var("--color-primary")] = "#123456" + clean_env[settings_service.css_key_to_env_var("--color-primary")] = ( + "#123456" + ) rows = settings_service.list_css_env_override_rows(clean_env) assert rows == [ { diff --git a/tests/unit/test_setup_database.py b/tests/unit/test_setup_database.py index c67e1ab..4432d16 100644 --- a/tests/unit/test_setup_database.py +++ b/tests/unit/test_setup_database.py @@ -31,10 +31,13 @@ def setup_instance(mock_config: DatabaseConfig) -> DatabaseSetup: return DatabaseSetup(mock_config, dry_run=True) -def test_seed_baseline_data_dry_run_skips_verification(setup_instance: DatabaseSetup) -> None: - with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object( - setup_instance, "_verify_seeded_data" - ) as verify_mock: +def test_seed_baseline_data_dry_run_skips_verification( + setup_instance: DatabaseSetup, +) -> None: + with ( + mock.patch("scripts.seed_data.run_with_namespace") as seed_run, + mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock, + ): setup_instance.seed_baseline_data(dry_run=True) seed_run.assert_called_once() @@ -47,13 +50,16 @@ def test_seed_baseline_data_dry_run_skips_verification(setup_instance: DatabaseS verify_mock.assert_not_called() -def test_seed_baseline_data_invokes_verification(setup_instance: DatabaseSetup) -> None: +def test_seed_baseline_data_invokes_verification( + setup_instance: DatabaseSetup, +) -> None: expected_currencies = {code for code, *_ in seed_data.CURRENCY_SEEDS} expected_units = {code for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS} - with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object( - setup_instance, "_verify_seeded_data" - ) as verify_mock: + with ( + mock.patch("scripts.seed_data.run_with_namespace") as seed_run, + mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock, + ): setup_instance.seed_baseline_data(dry_run=False) seed_run.assert_called_once() @@ -67,7 +73,9 @@ def test_seed_baseline_data_invokes_verification(setup_instance: DatabaseSetup) ) -def test_run_migrations_applies_baseline_when_missing(mock_config: DatabaseConfig, tmp_path) -> None: +def test_run_migrations_applies_baseline_when_missing( + mock_config: DatabaseConfig, tmp_path +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) baseline = tmp_path / "000_base.sql" @@ -88,15 +96,24 @@ def test_run_migrations_applies_baseline_when_missing(mock_config: DatabaseConfi cursor_context.__enter__.return_value = cursor_mock connection_mock.cursor.return_value = cursor_context - with mock.patch.object( - setup_instance, "_application_connection", return_value=connection_mock - ), mock.patch.object( - setup_instance, "_migrations_table_exists", return_value=True - ), mock.patch.object( - setup_instance, "_fetch_applied_migrations", return_value=set() - ), mock.patch.object( - setup_instance, "_apply_migration_file", side_effect=capture_migration - ) as apply_mock: + with ( + mock.patch.object( + setup_instance, + "_application_connection", + return_value=connection_mock, + ), + mock.patch.object( + setup_instance, "_migrations_table_exists", return_value=True + ), + mock.patch.object( + setup_instance, "_fetch_applied_migrations", return_value=set() + ), + mock.patch.object( + setup_instance, + "_apply_migration_file", + side_effect=capture_migration, + ) as apply_mock, + ): setup_instance.run_migrations(tmp_path) assert apply_mock.call_count == 1 @@ -121,17 +138,24 @@ def test_run_migrations_noop_when_all_files_already_applied( connection_mock, cursor_mock = _connection_with_cursor() - with mock.patch.object( - setup_instance, "_application_connection", return_value=connection_mock - ), mock.patch.object( - setup_instance, "_migrations_table_exists", return_value=True - ), mock.patch.object( - setup_instance, - "_fetch_applied_migrations", - return_value={"000_base.sql", "20251022_add_other.sql"}, - ), mock.patch.object( - setup_instance, "_apply_migration_file" - ) as apply_mock: + with ( + mock.patch.object( + setup_instance, + "_application_connection", + return_value=connection_mock, + ), + mock.patch.object( + setup_instance, "_migrations_table_exists", return_value=True + ), + mock.patch.object( + setup_instance, + "_fetch_applied_migrations", + return_value={"000_base.sql", "20251022_add_other.sql"}, + ), + mock.patch.object( + setup_instance, "_apply_migration_file" + ) as apply_mock, + ): setup_instance.run_migrations(tmp_path) apply_mock.assert_not_called() @@ -148,12 +172,16 @@ def _connection_with_cursor() -> tuple[mock.MagicMock, mock.MagicMock]: return connection_mock, cursor_mock -def test_verify_seeded_data_raises_when_currency_missing(mock_config: DatabaseConfig) -> None: +def test_verify_seeded_data_raises_when_currency_missing( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) connection_mock, cursor_mock = _connection_with_cursor() cursor_mock.fetchall.return_value = [("USD", True)] - with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): + with mock.patch.object( + setup_instance, "_application_connection", return_value=connection_mock + ): with pytest.raises(RuntimeError) as exc: setup_instance._verify_seeded_data( expected_currency_codes={"USD", "EUR"}, @@ -163,12 +191,16 @@ def test_verify_seeded_data_raises_when_currency_missing(mock_config: DatabaseCo assert "EUR" in str(exc.value) -def test_verify_seeded_data_raises_when_default_currency_inactive(mock_config: DatabaseConfig) -> None: +def test_verify_seeded_data_raises_when_default_currency_inactive( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) connection_mock, cursor_mock = _connection_with_cursor() cursor_mock.fetchall.return_value = [("USD", False)] - with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): + with mock.patch.object( + setup_instance, "_application_connection", return_value=connection_mock + ): with pytest.raises(RuntimeError) as exc: setup_instance._verify_seeded_data( expected_currency_codes={"USD"}, @@ -178,12 +210,16 @@ def test_verify_seeded_data_raises_when_default_currency_inactive(mock_config: D assert "inactive" in str(exc.value) -def test_verify_seeded_data_raises_when_units_missing(mock_config: DatabaseConfig) -> None: +def test_verify_seeded_data_raises_when_units_missing( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) connection_mock, cursor_mock = _connection_with_cursor() cursor_mock.fetchall.return_value = [("tonnes", True)] - with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): + with mock.patch.object( + setup_instance, "_application_connection", return_value=connection_mock + ): with pytest.raises(RuntimeError) as exc: setup_instance._verify_seeded_data( expected_currency_codes=set(), @@ -193,12 +229,18 @@ def test_verify_seeded_data_raises_when_units_missing(mock_config: DatabaseConfi assert "liters" in str(exc.value) -def test_verify_seeded_data_raises_when_measurement_table_missing(mock_config: DatabaseConfig) -> None: +def test_verify_seeded_data_raises_when_measurement_table_missing( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) connection_mock, cursor_mock = _connection_with_cursor() - cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable("relation does not exist") + cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable( + "relation does not exist" + ) - with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock): + with mock.patch.object( + setup_instance, "_application_connection", return_value=connection_mock + ): with pytest.raises(RuntimeError) as exc: setup_instance._verify_seeded_data( expected_currency_codes=set(), @@ -226,9 +268,14 @@ def test_seed_baseline_data_rerun_uses_existing_records( unit_rows, ] - with mock.patch.object( - setup_instance, "_application_connection", return_value=connection_mock - ), mock.patch("scripts.seed_data.run_with_namespace") as seed_run: + with ( + mock.patch.object( + setup_instance, + "_application_connection", + return_value=connection_mock, + ), + mock.patch("scripts.seed_data.run_with_namespace") as seed_run, + ): setup_instance.seed_baseline_data(dry_run=False) setup_instance.seed_baseline_data(dry_run=False) @@ -240,7 +287,9 @@ def test_seed_baseline_data_rerun_uses_existing_records( assert cursor_mock.execute.call_count == 4 -def test_ensure_database_raises_with_context(mock_config: DatabaseConfig) -> None: +def test_ensure_database_raises_with_context( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) connection_mock = mock.MagicMock() cursor_mock = mock.MagicMock() @@ -248,14 +297,18 @@ def test_ensure_database_raises_with_context(mock_config: DatabaseConfig) -> Non cursor_mock.execute.side_effect = [None, psycopg2.Error("create_fail")] connection_mock.cursor.return_value = cursor_mock - with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock): + with mock.patch.object( + setup_instance, "_admin_connection", return_value=connection_mock + ): with pytest.raises(RuntimeError) as exc: setup_instance.ensure_database() assert "Failed to create database" in str(exc.value) -def test_ensure_role_raises_with_context_during_creation(mock_config: DatabaseConfig) -> None: +def test_ensure_role_raises_with_context_during_creation( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) admin_conn, admin_cursor = _connection_with_cursor() @@ -295,7 +348,9 @@ def test_ensure_role_raises_with_context_during_privilege_grants( assert "Failed to grant privileges" in str(exc.value) -def test_ensure_database_dry_run_skips_creation(mock_config: DatabaseConfig) -> None: +def test_ensure_database_dry_run_skips_creation( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=True) connection_mock = mock.MagicMock() @@ -303,45 +358,59 @@ def test_ensure_database_dry_run_skips_creation(mock_config: DatabaseConfig) -> cursor_mock.fetchone.return_value = None connection_mock.cursor.return_value = cursor_mock - with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch( - "scripts.setup_database.logger" - ) as logger_mock: + with ( + mock.patch.object( + setup_instance, "_admin_connection", return_value=connection_mock + ), + mock.patch("scripts.setup_database.logger") as logger_mock, + ): setup_instance.ensure_database() # expect only existence check, no create attempt cursor_mock.execute.assert_called_once() logger_mock.info.assert_any_call( - "Dry run: would create database '%s'. Run without --dry-run to proceed.", mock_config.database + "Dry run: would create database '%s'. Run without --dry-run to proceed.", + mock_config.database, ) -def test_ensure_role_dry_run_skips_creation_and_grants(mock_config: DatabaseConfig) -> None: +def test_ensure_role_dry_run_skips_creation_and_grants( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=True) admin_conn, admin_cursor = _connection_with_cursor() admin_cursor.fetchone.return_value = None - with mock.patch.object( - setup_instance, - "_admin_connection", - side_effect=[admin_conn], - ) as conn_mock, mock.patch("scripts.setup_database.logger") as logger_mock: + with ( + mock.patch.object( + setup_instance, + "_admin_connection", + side_effect=[admin_conn], + ) as conn_mock, + mock.patch("scripts.setup_database.logger") as logger_mock, + ): setup_instance.ensure_role() assert conn_mock.call_count == 1 admin_cursor.execute.assert_called_once() logger_mock.info.assert_any_call( - "Dry run: would create role '%s'. Run without --dry-run to apply.", mock_config.user + "Dry run: would create role '%s'. Run without --dry-run to apply.", + mock_config.user, ) -def test_register_rollback_skipped_when_dry_run(mock_config: DatabaseConfig) -> None: +def test_register_rollback_skipped_when_dry_run( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=True) setup_instance._register_rollback("noop", lambda: None) assert setup_instance._rollback_actions == [] -def test_execute_rollbacks_runs_in_reverse_order(mock_config: DatabaseConfig) -> None: +def test_execute_rollbacks_runs_in_reverse_order( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) calls: list[str] = [] @@ -362,16 +431,24 @@ def test_execute_rollbacks_runs_in_reverse_order(mock_config: DatabaseConfig) -> assert setup_instance._rollback_actions == [] -def test_ensure_database_registers_rollback_action(mock_config: DatabaseConfig) -> None: +def test_ensure_database_registers_rollback_action( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) connection_mock = mock.MagicMock() cursor_mock = mock.MagicMock() cursor_mock.fetchone.return_value = None connection_mock.cursor.return_value = cursor_mock - with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch.object( - setup_instance, "_register_rollback" - ) as register_mock, mock.patch.object(setup_instance, "_drop_database") as drop_mock: + with ( + mock.patch.object( + setup_instance, "_admin_connection", return_value=connection_mock + ), + mock.patch.object( + setup_instance, "_register_rollback" + ) as register_mock, + mock.patch.object(setup_instance, "_drop_database") as drop_mock, + ): setup_instance.ensure_database() register_mock.assert_called_once() label, action = register_mock.call_args[0] @@ -380,24 +457,29 @@ def test_ensure_database_registers_rollback_action(mock_config: DatabaseConfig) drop_mock.assert_called_once_with(mock_config.database) -def test_ensure_role_registers_rollback_actions(mock_config: DatabaseConfig) -> None: +def test_ensure_role_registers_rollback_actions( + mock_config: DatabaseConfig, +) -> None: setup_instance = DatabaseSetup(mock_config, dry_run=False) admin_conn, admin_cursor = _connection_with_cursor() admin_cursor.fetchone.return_value = None privilege_conn, privilege_cursor = _connection_with_cursor() - with mock.patch.object( - setup_instance, - "_admin_connection", - side_effect=[admin_conn, privilege_conn], - ), mock.patch.object( - setup_instance, "_register_rollback" - ) as register_mock, mock.patch.object( - setup_instance, "_drop_role" - ) as drop_mock, mock.patch.object( - setup_instance, "_revoke_role_privileges" - ) as revoke_mock: + with ( + mock.patch.object( + setup_instance, + "_admin_connection", + side_effect=[admin_conn, privilege_conn], + ), + mock.patch.object( + setup_instance, "_register_rollback" + ) as register_mock, + mock.patch.object(setup_instance, "_drop_role") as drop_mock, + mock.patch.object( + setup_instance, "_revoke_role_privileges" + ) as revoke_mock, + ): setup_instance.ensure_role() assert register_mock.call_count == 2 drop_label, drop_action = register_mock.call_args_list[0][0] @@ -413,7 +495,9 @@ def test_ensure_role_registers_rollback_actions(mock_config: DatabaseConfig) -> revoke_mock.assert_called_once() -def test_main_triggers_rollbacks_on_failure(mock_config: DatabaseConfig) -> None: +def test_main_triggers_rollbacks_on_failure( + mock_config: DatabaseConfig, +) -> None: args = argparse.Namespace( ensure_database=True, ensure_role=True, @@ -437,11 +521,13 @@ def test_main_triggers_rollbacks_on_failure(mock_config: DatabaseConfig) -> None verbose=0, ) - with mock.patch.object(setup_db_module, "parse_args", return_value=args), mock.patch.object( - setup_db_module.DatabaseConfig, "from_env", return_value=mock_config - ), mock.patch.object( - setup_db_module, "DatabaseSetup" - ) as setup_cls: + with ( + mock.patch.object(setup_db_module, "parse_args", return_value=args), + mock.patch.object( + setup_db_module.DatabaseConfig, "from_env", return_value=mock_config + ), + mock.patch.object(setup_db_module, "DatabaseSetup") as setup_cls, + ): setup_instance = mock.MagicMock() setup_instance.dry_run = False setup_instance._rollback_actions = [ diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py index 05444dd..b89febe 100644 --- a/tests/unit/test_simulation.py +++ b/tests/unit/test_simulation.py @@ -19,7 +19,12 @@ def client(api_client: TestClient) -> TestClient: def test_run_simulation_function_generates_samples(): params: List[Dict[str, Any]] = [ - {"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2}, + { + "name": "grade", + "value": 1.8, + "distribution": "normal", + "std_dev": 0.2, + }, { "name": "recovery", "value": 0.9, @@ -45,7 +50,10 @@ def test_run_simulation_with_zero_iterations_returns_empty(): @pytest.mark.parametrize( "parameter_payload,error_message", [ - ({"name": "missing-value"}, "Parameter at index 0 must include 'value'"), + ( + {"name": "missing-value"}, + "Parameter at index 0 must include 'value'", + ), ( { "name": "bad-dist", @@ -110,7 +118,8 @@ def test_run_simulation_triangular_sampling_path(): span = 10.0 * DEFAULT_UNIFORM_SPAN_RATIO rng = Random(seed) expected_samples = [ - rng.triangular(10.0 - span, 10.0 + span, 10.0) for _ in range(iterations) + rng.triangular(10.0 - span, 10.0 + span, 10.0) + for _ in range(iterations) ] actual_samples = [entry["result"] for entry in results] for actual, expected in zip(actual_samples, expected_samples): @@ -156,9 +165,7 @@ def test_simulation_endpoint_no_params(client: TestClient): assert resp.json()["detail"] == "No parameters provided" -def test_simulation_endpoint_success( - client: TestClient, db_session: Session -): +def test_simulation_endpoint_success(client: TestClient, db_session: Session): scenario_payload: Dict[str, Any] = { "name": f"SimScenario-{uuid4()}", "description": "Simulation test", @@ -168,7 +175,12 @@ def test_simulation_endpoint_success( scenario_id = scenario_resp.json()["id"] params: List[Dict[str, Any]] = [ - {"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5} + { + "name": "param1", + "value": 2.5, + "distribution": "normal", + "std_dev": 0.5, + } ] payload: Dict[str, Any] = { "scenario_id": scenario_id, diff --git a/tests/unit/test_theme_settings.py b/tests/unit/test_theme_settings.py new file mode 100644 index 0000000..c1e79ba --- /dev/null +++ b/tests/unit/test_theme_settings.py @@ -0,0 +1,63 @@ +import pytest +from sqlalchemy.orm import Session +from fastapi.testclient import TestClient + +from main import app +from models.theme_setting import ThemeSetting +from services.settings import save_theme_settings, get_theme_settings + + +client = TestClient(app) + + +def test_save_theme_settings(db_session: Session): + theme_data = { + "theme_name": "dark", + "primary_color": "#000000", + "secondary_color": "#333333", + "accent_color": "#ff0000", + "background_color": "#1a1a1a", + "text_color": "#ffffff" + } + + saved_setting = save_theme_settings(db_session, theme_data) + assert str(saved_setting.theme_name) == "dark" + assert str(saved_setting.primary_color) == "#000000" + + +def test_get_theme_settings(db_session: Session): + # Create a theme setting first + theme_data = { + "theme_name": "light", + "primary_color": "#ffffff", + "secondary_color": "#cccccc", + "accent_color": "#0000ff", + "background_color": "#f0f0f0", + "text_color": "#000000" + } + save_theme_settings(db_session, theme_data) + + settings = get_theme_settings(db_session) + assert settings["theme_name"] == "light" + assert settings["primary_color"] == "#ffffff" + + +def test_theme_settings_api(api_client): + # Test API endpoint for saving theme settings + theme_data = { + "theme_name": "test_theme", + "primary_color": "#123456", + "secondary_color": "#789abc", + "accent_color": "#def012", + "background_color": "#345678", + "text_color": "#9abcde" + } + + response = api_client.post("/api/settings/theme", json=theme_data) + assert response.status_code == 200 + assert response.json()["theme"]["theme_name"] == "test_theme" + + # Test API endpoint for getting theme settings + response = api_client.get("/api/settings/theme") + assert response.status_code == 200 + assert response.json()["theme_name"] == "test_theme" diff --git a/tests/unit/test_ui_routes.py b/tests/unit/test_ui_routes.py index 7a56043..b0757d7 100644 --- a/tests/unit/test_ui_routes.py +++ b/tests/unit/test_ui_routes.py @@ -21,11 +21,18 @@ def test_dashboard_route_provides_summary( assert context.get("report_available") is True metric_labels = {item["label"] for item in context["summary_metrics"]} - assert {"CAPEX Total", "OPEX Total", "Production", "Simulation Iterations"}.issubset(metric_labels) + assert { + "CAPEX Total", + "OPEX Total", + "Production", + "Simulation Iterations", + }.issubset(metric_labels) scenario = cast(Scenario, seeded_ui_data["scenario"]) scenario_row = next( - row for row in context["scenario_rows"] if row["scenario_name"] == scenario.name + row + for row in context["scenario_rows"] + if row["scenario_name"] == scenario.name ) assert scenario_row["iterations"] == 3 assert scenario_row["simulation_mean_display"] == "971,666.67" @@ -81,7 +88,9 @@ def test_dashboard_data_endpoint_returns_aggregates( payload = response.json() assert payload["report_available"] is True - metric_map = {item["label"]: item["value"] for item in payload["summary_metrics"]} + metric_map = { + item["label"]: item["value"] for item in payload["summary_metrics"] + } assert metric_map["CAPEX Total"].startswith("$") assert metric_map["Maintenance Cost"].startswith("$") @@ -99,7 +108,9 @@ def test_dashboard_data_endpoint_returns_aggregates( activity_labels = payload["scenario_activity_chart"]["labels"] activity_idx = activity_labels.index(scenario.name) - assert payload["scenario_activity_chart"]["production"][activity_idx] == 800.0 + assert ( + payload["scenario_activity_chart"]["production"][activity_idx] == 800.0 + ) @pytest.mark.parametrize( @@ -154,7 +165,10 @@ def test_settings_route_provides_css_context( assert "css_env_override_meta" in context assert context["css_variables"]["--color-accent"] == "#abcdef" - assert context["css_defaults"]["--color-accent"] == settings_service.CSS_COLOR_DEFAULTS["--color-accent"] + assert ( + context["css_defaults"]["--color-accent"] + == settings_service.CSS_COLOR_DEFAULTS["--color-accent"] + ) assert context["css_env_overrides"]["--color-accent"] == "#abcdef" override_rows = context["css_env_override_rows"]