diff --git a/.prettierrc b/.prettierrc
deleted file mode 100644
index 0ca3806..0000000
--- a/.prettierrc
+++ /dev/null
@@ -1,8 +0,0 @@
-{
-  "semi": true,
-  "singleQuote": true,
-  "trailingComma": "es5",
-  "printWidth": 80,
-  "tabWidth": 2,
-  "useTabs": false
-}
diff --git a/config/setup_production.env.example b/config/setup_production.env.example
deleted file mode 100644
index fefd6f2..0000000
--- a/config/setup_production.env.example
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copy this file to config/setup_production.env and replace values with production secrets
-
-# Container image and runtime configuration
-CALMINER_IMAGE=registry.example.com/calminer/api:latest
-CALMINER_DOMAIN=calminer.example.com
-TRAEFIK_ACME_EMAIL=ops@example.com
-CALMINER_API_PORT=8000
-UVICORN_WORKERS=4
-UVICORN_LOG_LEVEL=info
-CALMINER_NETWORK=calminer_backend
-API_LIMIT_CPUS=1.0
-API_LIMIT_MEMORY=1g
-API_RESERVATION_MEMORY=512m
-TRAEFIK_LIMIT_CPUS=0.5
-TRAEFIK_LIMIT_MEMORY=512m
-POSTGRES_LIMIT_CPUS=1.0
-POSTGRES_LIMIT_MEMORY=2g
-POSTGRES_RESERVATION_MEMORY=1g
-
-# Application database connection
-DATABASE_DRIVER=postgresql+psycopg2
-DATABASE_HOST=production-db.internal
-DATABASE_PORT=5432
-DATABASE_NAME=calminer
-DATABASE_USER=calminer_app
-DATABASE_PASSWORD=ChangeMe123!
-DATABASE_SCHEMA=public
-
-# Optional consolidated SQLAlchemy URL (overrides granular settings when set)
-# DATABASE_URL=postgresql+psycopg2://calminer_app:ChangeMe123!@production-db.internal:5432/calminer
-
-# Superuser credentials used by scripts/setup_database.py for migrations/seed data
-DATABASE_SUPERUSER=postgres
-DATABASE_SUPERUSER_PASSWORD=ChangeMeSuper123!
-DATABASE_SUPERUSER_DB=postgres
diff --git a/config/setup_staging.env.example b/config/setup_staging.env.example
deleted file mode 100644
index a166e1f..0000000
--- a/config/setup_staging.env.example
+++ /dev/null
@@ -1,11 +0,0 @@
-# Sample environment configuration for staging deployment
-DATABASE_HOST=staging-db.internal
-DATABASE_PORT=5432
-DATABASE_NAME=calminer_staging
-DATABASE_USER=calminer_app
-DATABASE_PASSWORD=
-
-# Admin connection used for provisioning database and roles
-DATABASE_SUPERUSER=postgres
-DATABASE_SUPERUSER_PASSWORD=
-DATABASE_SUPERUSER_DB=postgres
diff --git a/config/setup_test.env.example b/config/setup_test.env.example
deleted file mode 100644
index 2228373..0000000
--- a/config/setup_test.env.example
+++ /dev/null
@@ -1,14 +0,0 @@
-# Sample environment configuration for running scripts/setup_database.py against a test instance
-DATABASE_DRIVER=postgresql
-DATABASE_HOST=postgres
-DATABASE_PORT=5432
-DATABASE_NAME=calminer_test
-DATABASE_USER=calminer_test
-DATABASE_PASSWORD=
-# optional: specify schema if different from 'public'
-#DATABASE_SCHEMA=public
-
-# Admin connection used for provisioning database and roles
-DATABASE_SUPERUSER=postgres
-DATABASE_SUPERUSER_PASSWORD=
-DATABASE_SUPERUSER_DB=postgres
diff --git a/models/__init__.py b/models/__init__.py
deleted file mode 100644
index a46e508..0000000
--- a/models/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-"""
-models package initializer. Import key models so they're registered
-with the shared Base.metadata when the package is imported by tests.
-"""
-
-from . import application_setting  # noqa: F401
-from . import currency  # noqa: F401
-from . import role  # noqa: F401
-from . import user  # noqa: F401
-from . import theme_setting  # noqa: F401
diff --git a/models/application_setting.py b/models/application_setting.py
deleted file mode 100644
index ed98160..0000000
--- a/models/application_setting.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from datetime import datetime
-from typing import Optional
-
-from sqlalchemy import Boolean, DateTime, String, Text
-from sqlalchemy.orm import Mapped, mapped_column
-from sqlalchemy.sql import func
-
-from config.database import Base
-
-
-class ApplicationSetting(Base):
-    __tablename__ = "application_setting"
-
-    id: Mapped[int] = mapped_column(primary_key=True, index=True)
-    key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False)
-    value: Mapped[str] = mapped_column(Text, nullable=False)
-    value_type: Mapped[str] = mapped_column(
-        String(32), nullable=False, default="string"
-    )
-    category: Mapped[str] = mapped_column(
-        String(32), nullable=False, default="general"
-    )
-    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
-    is_editable: Mapped[bool] = mapped_column(
-        Boolean, nullable=False, default=True
-    )
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), nullable=False
-    )
-    updated_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True),
-        server_default=func.now(),
-        onupdate=func.now(),
-        nullable=False,
-    )
-
-    def __repr__(self) -> str:
-        return f""
diff --git a/models/capex.py b/models/capex.py
deleted file mode 100644
index 68b6749..0000000
--- a/models/capex.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from sqlalchemy import event, text
-from sqlalchemy import Column, Integer, Float, String, ForeignKey
-from sqlalchemy.orm import relationship
-from config.database import Base
-
-
-class Capex(Base):
-    __tablename__ = "capex"
-
-    id = Column(Integer, primary_key=True, index=True)
-    scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False)
-    amount = Column(Float, nullable=False)
-    description = Column(String, nullable=True)
-    currency_id = Column(Integer, ForeignKey("currency.id"), nullable=False)
-
-    scenario = relationship("Scenario", back_populates="capex_items")
-    currency = relationship("Currency", back_populates="capex_items")
-
-    def __repr__(self):
-        return (
-            f""
-        )
-
-    @property
-    def currency_code(self) -> str:
-        return self.currency.code if self.currency else None
-
-    @currency_code.setter
-    def currency_code(self, value: str) -> None:
-        # store pending code so application code or migrations can pick it up
-        setattr(
-            self, "_currency_code_pending", (value or "USD").strip().upper()
-        )
-
-
-# SQLAlchemy event handlers to ensure currency_id is set before insert/update
-
-
-def _resolve_currency(mapper, connection, target):
-    # If currency_id already set, nothing to do
-    if getattr(target, "currency_id", None):
-        return
-    code = getattr(target, "_currency_code_pending", None) or "USD"
-    # Try to find existing currency id
-    row = connection.execute(
-        text("SELECT id FROM currency WHERE code = :code"), {"code": code}
-    ).fetchone()
-    if row:
-        cid = row[0]
-    else:
-        # Insert new currency and attempt to get lastrowid
-        res = connection.execute(
-            text(
-                "INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)"
-            ),
-            {"code": code, "name": code, "symbol": None, "active": True},
-        )
-        try:
-            cid = res.lastrowid
-        except Exception:
-            # fallback: select after insert
-            cid = connection.execute(
-                text("SELECT id FROM currency WHERE code = :code"),
-                {"code": code},
-            ).scalar()
-    target.currency_id = cid
-
-
-event.listen(Capex, "before_insert", _resolve_currency)
-event.listen(Capex, "before_update", _resolve_currency)
diff --git a/models/consumption.py b/models/consumption.py
deleted file mode 100644
index c5239bc..0000000
--- a/models/consumption.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from sqlalchemy import Column, Integer, Float, String, ForeignKey
-from sqlalchemy.orm import relationship
-from config.database import Base
-
-
-class Consumption(Base):
-    __tablename__ = "consumption"
-
-    id = Column(Integer, primary_key=True, index=True)
-    scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False)
-    amount = Column(Float, nullable=False)
-    description = Column(String, nullable=True)
-    unit_name = Column(String(64), nullable=True)
-    unit_symbol = Column(String(16), nullable=True)
-
-    scenario = relationship("Scenario", back_populates="consumption_items")
-
-    def __repr__(self):
-        return (
-            f""
-        )
diff --git a/models/currency.py b/models/currency.py
deleted file mode 100644
index de95abd..0000000
--- a/models/currency.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from sqlalchemy import Column, Integer, String, Boolean
-from sqlalchemy.orm import relationship
-from config.database import Base
-
-
-class Currency(Base):
-    __tablename__ = "currency"
-
-    id = Column(Integer, primary_key=True, index=True)
-    code = Column(String(3), nullable=False, unique=True, index=True)
-    name = Column(String(128), nullable=False)
-    symbol = Column(String(8), nullable=True)
-    is_active = Column(Boolean, nullable=False, default=True)
-
-    # reverse relationships (optional)
-    capex_items = relationship(
-        "Capex", back_populates="currency", lazy="select"
-    )
-    opex_items = relationship("Opex", back_populates="currency", lazy="select")
-
-    def __repr__(self):
-        return (
-            f""
-        )
diff --git a/models/distribution.py b/models/distribution.py
deleted file mode 100644
index 9f9832a..0000000
--- a/models/distribution.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from sqlalchemy import Column, Integer, String, JSON
-from config.database import Base
-
-
-class Distribution(Base):
-    __tablename__ = "distribution"
-
-    id = Column(Integer, primary_key=True, index=True)
-    name = Column(String, nullable=False)
-    distribution_type = Column(String, nullable=False)
-    parameters = Column(JSON, nullable=True)
-
-    def __repr__(self):
-        return f""
diff --git a/models/equipment.py b/models/equipment.py
deleted file mode 100644
index e431891..0000000
--- a/models/equipment.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from sqlalchemy import Column, Integer, String, ForeignKey
-from sqlalchemy.orm import relationship
-from config.database import Base
-
-
-class Equipment(Base):
-    __tablename__ = "equipment"
-
-    id = Column(Integer, primary_key=True, index=True)
-    scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False)
-    name = Column(String, nullable=False)
-    description = Column(String, nullable=True)
-
-    scenario = relationship("Scenario", back_populates="equipment_items")
-
-    def __repr__(self):
-        return f""
diff --git a/models/maintenance.py b/models/maintenance.py
deleted file mode 100644
index 43a7aea..0000000
--- a/models/maintenance.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from sqlalchemy import Column, Date, Float, ForeignKey, Integer, String
-from sqlalchemy.orm import relationship
-from config.database import Base
-
-
-class Maintenance(Base):
-    __tablename__ = "maintenance"
-
-    id = Column(Integer, primary_key=True, index=True)
-    equipment_id = Column(Integer, ForeignKey("equipment.id"), nullable=False)
-    scenario_id =
Column(Integer, ForeignKey("scenario.id"), nullable=False) - maintenance_date = Column(Date, nullable=False) - description = Column(String, nullable=True) - cost = Column(Float, nullable=False) - - equipment = relationship("Equipment") - scenario = relationship("Scenario", back_populates="maintenance_items") - - def __repr__(self) -> str: - return ( - f"" - ) diff --git a/models/opex.py b/models/opex.py deleted file mode 100644 index 5c0e703..0000000 --- a/models/opex.py +++ /dev/null @@ -1,63 +0,0 @@ -from sqlalchemy import event, text -from sqlalchemy import Column, Integer, Float, String, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class Opex(Base): - __tablename__ = "opex" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - amount = Column(Float, nullable=False) - description = Column(String, nullable=True) - currency_id = Column(Integer, ForeignKey("currency.id"), nullable=False) - - scenario = relationship("Scenario", back_populates="opex_items") - currency = relationship("Currency", back_populates="opex_items") - - def __repr__(self): - return ( - f"" - ) - - @property - def currency_code(self) -> str: - return self.currency.code if self.currency else None - - @currency_code.setter - def currency_code(self, value: str) -> None: - setattr( - self, "_currency_code_pending", (value or "USD").strip().upper() - ) - - -def _resolve_currency_opex(mapper, connection, target): - if getattr(target, "currency_id", None): - return - code = getattr(target, "_currency_code_pending", None) or "USD" - row = connection.execute( - text("SELECT id FROM currency WHERE code = :code"), {"code": code} - ).fetchone() - if row: - cid = row[0] - else: - res = connection.execute( - text( - "INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)" - ), - {"code": code, "name": code, "symbol": None, "active": True}, - ) - try: - cid = res.lastrowid - except Exception: - cid = connection.execute( - text("SELECT id FROM currency WHERE code = :code"), - {"code": code}, - ).scalar() - target.currency_id = cid - - -event.listen(Opex, "before_insert", _resolve_currency_opex) -event.listen(Opex, "before_update", _resolve_currency_opex) diff --git a/models/parameters.py b/models/parameters.py deleted file mode 100644 index 822a011..0000000 --- a/models/parameters.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Any, Dict, Optional - -from sqlalchemy import ForeignKey, JSON -from sqlalchemy.orm import Mapped, mapped_column, relationship -from config.database import Base - - -class Parameter(Base): - __tablename__ = "parameter" - - id: Mapped[int] = mapped_column(primary_key=True, index=True) - scenario_id: Mapped[int] = mapped_column( - ForeignKey("scenario.id"), nullable=False - ) - name: Mapped[str] = mapped_column(nullable=False) - value: Mapped[float] = mapped_column(nullable=False) - distribution_id: Mapped[Optional[int]] = mapped_column( - ForeignKey("distribution.id"), nullable=True - ) - distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True) - distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column( - JSON, nullable=True - ) - - scenario = relationship("Scenario", back_populates="parameters") - distribution = relationship("Distribution") - - def __repr__(self): - return f"" diff --git a/models/production_output.py b/models/production_output.py deleted file mode 100644 index fde7cb8..0000000 --- 
a/models/production_output.py +++ /dev/null @@ -1,24 +0,0 @@ -from sqlalchemy import Column, Integer, Float, String, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class ProductionOutput(Base): - __tablename__ = "production_output" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - amount = Column(Float, nullable=False) - description = Column(String, nullable=True) - unit_name = Column(String(64), nullable=True) - unit_symbol = Column(String(16), nullable=True) - - scenario = relationship( - "Scenario", back_populates="production_output_items" - ) - - def __repr__(self): - return ( - f"" - ) diff --git a/models/role.py b/models/role.py deleted file mode 100644 index 3351908..0000000 --- a/models/role.py +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy import Column, Integer, String -from sqlalchemy.orm import relationship - -from config.database import Base - - -class Role(Base): - __tablename__ = "roles" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, unique=True, index=True) - - users = relationship("User", back_populates="role") diff --git a/models/scenario.py b/models/scenario.py deleted file mode 100644 index 66d4fd2..0000000 --- a/models/scenario.py +++ /dev/null @@ -1,36 +0,0 @@ -from sqlalchemy import Column, Integer, String, DateTime, func -from sqlalchemy.orm import relationship -from models.simulation_result import SimulationResult -from models.capex import Capex -from models.opex import Opex -from models.consumption import Consumption -from models.production_output import ProductionOutput -from models.equipment import Equipment -from models.maintenance import Maintenance -from config.database import Base - - -class Scenario(Base): - __tablename__ = "scenario" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, unique=True, nullable=False) - description = Column(String) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - parameters = relationship("Parameter", back_populates="scenario") - simulation_results = relationship( - SimulationResult, back_populates="scenario" - ) - capex_items = relationship(Capex, back_populates="scenario") - opex_items = relationship(Opex, back_populates="scenario") - consumption_items = relationship(Consumption, back_populates="scenario") - production_output_items = relationship( - ProductionOutput, back_populates="scenario" - ) - equipment_items = relationship(Equipment, back_populates="scenario") - maintenance_items = relationship(Maintenance, back_populates="scenario") - - # relationships can be defined later - def __repr__(self): - return f"" diff --git a/models/simulation_result.py b/models/simulation_result.py deleted file mode 100644 index c5edac7..0000000 --- a/models/simulation_result.py +++ /dev/null @@ -1,14 +0,0 @@ -from sqlalchemy import Column, Integer, Float, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class SimulationResult(Base): - __tablename__ = "simulation_result" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - iteration = Column(Integer, nullable=False) - result = Column(Float, nullable=False) - - scenario = relationship("Scenario", back_populates="simulation_results") diff --git a/models/theme_setting.py b/models/theme_setting.py 
deleted file mode 100644 index 1e20c64..0000000 --- a/models/theme_setting.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy import Column, Integer, String - -from config.database import Base - - -class ThemeSetting(Base): - __tablename__ = "theme_settings" - - id = Column(Integer, primary_key=True, index=True) - theme_name = Column(String, unique=True, index=True) - primary_color = Column(String) - secondary_color = Column(String) - accent_color = Column(String) - background_color = Column(String) - text_color = Column(String) diff --git a/models/user.py b/models/user.py deleted file mode 100644 index 5ee8654..0000000 --- a/models/user.py +++ /dev/null @@ -1,23 +0,0 @@ -from sqlalchemy import Column, Integer, String, ForeignKey -from sqlalchemy.orm import relationship - -from config.database import Base -from services.security import get_password_hash, verify_password - - -class User(Base): - __tablename__ = "users" - - id = Column(Integer, primary_key=True, index=True) - username = Column(String, unique=True, index=True) - email = Column(String, unique=True, index=True) - hashed_password = Column(String) - role_id = Column(Integer, ForeignKey("roles.id")) - - role = relationship("Role", back_populates="users") - - def set_password(self, password: str): - self.hashed_password = get_password_hash(password) - - def check_password(self, password: str) -> bool: - return verify_password(password, str(self.hashed_password)) diff --git a/requirements-test.txt b/requirements-test.txt index b2ac481..1e96b46 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,7 +1,7 @@ -playwright pytest pytest-cov pytest-httpx -pytest-playwright python-jose ruff +black +mypy \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0f27fee..e07bb5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ fastapi -pydantic>=2.0,<3.0 +pydantic uvicorn sqlalchemy psycopg2-binary diff --git a/routes/consumption.py b/routes/consumption.py deleted file mode 100644 index e03785d..0000000 --- a/routes/consumption.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends, status -from pydantic import BaseModel, ConfigDict, PositiveFloat, field_validator -from sqlalchemy.orm import Session - -from models.consumption import Consumption -from routes.dependencies import get_db - - -router = APIRouter(prefix="/api/consumption", tags=["Consumption"]) - - -class ConsumptionBase(BaseModel): - scenario_id: int - amount: PositiveFloat - description: Optional[str] = None - unit_name: Optional[str] = None - unit_symbol: Optional[str] = None - - @field_validator("unit_name", "unit_symbol") - @classmethod - def _normalize_text(cls, value: Optional[str]) -> Optional[str]: - if value is None: - return None - stripped = value.strip() - return stripped or None - - -class ConsumptionCreate(ConsumptionBase): - pass - - -class ConsumptionRead(ConsumptionBase): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post( - "/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED -) -def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)): - db_item = Consumption(**item.model_dump()) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/", response_model=List[ConsumptionRead]) -def list_consumption(db: Session = Depends(get_db)): - return db.query(Consumption).all() diff --git a/routes/costs.py b/routes/costs.py deleted file mode 100644 index 
e22f18a..0000000 --- a/routes/costs.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends -from pydantic import BaseModel, ConfigDict, field_validator -from sqlalchemy.orm import Session - -from models.capex import Capex -from models.opex import Opex -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/costs", tags=["Costs"]) -# Pydantic schemas for CAPEX and OPEX - - -class _CostBase(BaseModel): - scenario_id: int - amount: float - description: Optional[str] = None - currency_code: Optional[str] = "USD" - currency_id: Optional[int] = None - - @field_validator("currency_code") - @classmethod - def _normalize_currency(cls, value: Optional[str]) -> str: - code = (value or "USD").strip().upper() - return code[:3] if len(code) > 3 else code - - -class CapexCreate(_CostBase): - pass - - -class CapexRead(_CostBase): - id: int - # use from_attributes so Pydantic reads attributes off SQLAlchemy model - model_config = ConfigDict(from_attributes=True) - - # optionally include nested currency info - currency: Optional["CurrencyRead"] = None - - -class OpexCreate(_CostBase): - pass - - -class OpexRead(_CostBase): - id: int - model_config = ConfigDict(from_attributes=True) - currency: Optional["CurrencyRead"] = None - - -class CurrencyRead(BaseModel): - id: int - code: str - name: Optional[str] = None - symbol: Optional[str] = None - is_active: Optional[bool] = True - - model_config = ConfigDict(from_attributes=True) - - -# forward refs -CapexRead.model_rebuild() -OpexRead.model_rebuild() - - -# Capex endpoints -@router.post("/capex", response_model=CapexRead) -def create_capex(item: CapexCreate, db: Session = Depends(get_db)): - payload = item.model_dump() - # Prefer explicit currency_id if supplied - cid = payload.get("currency_id") - if not cid: - code = (payload.pop("currency_code", "USD") or "USD").strip().upper() - currency_cls = __import__( - "models.currency", fromlist=["Currency"] - ).Currency - currency = db.query(currency_cls).filter_by(code=code).one_or_none() - if currency is None: - currency = currency_cls(code=code, name=code, symbol=None) - db.add(currency) - db.flush() - payload["currency_id"] = currency.id - db_item = Capex(**payload) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/capex", response_model=List[CapexRead]) -def list_capex(db: Session = Depends(get_db)): - return db.query(Capex).all() - - -# Opex endpoints -@router.post("/opex", response_model=OpexRead) -def create_opex(item: OpexCreate, db: Session = Depends(get_db)): - payload = item.model_dump() - cid = payload.get("currency_id") - if not cid: - code = (payload.pop("currency_code", "USD") or "USD").strip().upper() - currency_cls = __import__( - "models.currency", fromlist=["Currency"] - ).Currency - currency = db.query(currency_cls).filter_by(code=code).one_or_none() - if currency is None: - currency = currency_cls(code=code, name=code, symbol=None) - db.add(currency) - db.flush() - payload["currency_id"] = currency.id - db_item = Opex(**payload) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/opex", response_model=List[OpexRead]) -def list_opex(db: Session = Depends(get_db)): - return db.query(Opex).all() diff --git a/routes/currencies.py b/routes/currencies.py deleted file mode 100644 index 8899366..0000000 --- a/routes/currencies.py +++ /dev/null @@ -1,193 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends, 
HTTPException, Query, status -from pydantic import BaseModel, ConfigDict, Field, field_validator -from sqlalchemy.orm import Session -from sqlalchemy.exc import IntegrityError - -from models.currency import Currency -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/currencies", tags=["Currencies"]) - - -DEFAULT_CURRENCY_CODE = "USD" -DEFAULT_CURRENCY_NAME = "US Dollar" -DEFAULT_CURRENCY_SYMBOL = "$" - - -class CurrencyBase(BaseModel): - name: str = Field(..., min_length=1, max_length=128) - symbol: Optional[str] = Field(default=None, max_length=8) - - @staticmethod - def _normalize_symbol(value: Optional[str]) -> Optional[str]: - if value is None: - return None - value = value.strip() - return value or None - - @field_validator("name") - @classmethod - def _strip_name(cls, value: str) -> str: - return value.strip() - - @field_validator("symbol") - @classmethod - def _strip_symbol(cls, value: Optional[str]) -> Optional[str]: - return cls._normalize_symbol(value) - - -class CurrencyCreate(CurrencyBase): - code: str = Field(..., min_length=3, max_length=3) - is_active: bool = True - - @field_validator("code") - @classmethod - def _normalize_code(cls, value: str) -> str: - return value.strip().upper() - - -class CurrencyUpdate(CurrencyBase): - is_active: Optional[bool] = None - - -class CurrencyActivation(BaseModel): - is_active: bool - - -class CurrencyRead(CurrencyBase): - id: int - code: str - is_active: bool - - model_config = ConfigDict(from_attributes=True) - - -def _ensure_default_currency(db: Session) -> Currency: - existing = ( - db.query(Currency) - .filter(Currency.code == DEFAULT_CURRENCY_CODE) - .one_or_none() - ) - if existing: - return existing - - default_currency = Currency( - code=DEFAULT_CURRENCY_CODE, - name=DEFAULT_CURRENCY_NAME, - symbol=DEFAULT_CURRENCY_SYMBOL, - is_active=True, - ) - db.add(default_currency) - try: - db.commit() - except IntegrityError: - db.rollback() - existing = ( - db.query(Currency) - .filter(Currency.code == DEFAULT_CURRENCY_CODE) - .one() - ) - return existing - db.refresh(default_currency) - return default_currency - - -def _get_currency_or_404(db: Session, code: str) -> Currency: - normalized = code.strip().upper() - currency = ( - db.query(Currency).filter(Currency.code == normalized).one_or_none() - ) - if currency is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Currency not found" - ) - return currency - - -@router.get("/", response_model=List[CurrencyRead]) -def list_currencies( - include_inactive: bool = Query( - False, description="Include inactive currencies" - ), - db: Session = Depends(get_db), -): - _ensure_default_currency(db) - query = db.query(Currency) - if not include_inactive: - query = query.filter(Currency.is_active.is_(True)) - currencies = query.order_by(Currency.code).all() - return currencies - - -@router.post( - "/", response_model=CurrencyRead, status_code=status.HTTP_201_CREATED -) -def create_currency(payload: CurrencyCreate, db: Session = Depends(get_db)): - code = payload.code - existing = db.query(Currency).filter(Currency.code == code).one_or_none() - if existing is not None: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=f"Currency '{code}' already exists", - ) - - currency = Currency( - code=code, - name=payload.name, - symbol=CurrencyBase._normalize_symbol(payload.symbol), - is_active=payload.is_active, - ) - db.add(currency) - db.commit() - db.refresh(currency) - return currency - - -@router.put("/{code}", 
response_model=CurrencyRead) -def update_currency( - code: str, payload: CurrencyUpdate, db: Session = Depends(get_db) -): - currency = _get_currency_or_404(db, code) - - if payload.name is not None: - setattr(currency, "name", payload.name) - if payload.symbol is not None or payload.symbol == "": - setattr( - currency, - "symbol", - CurrencyBase._normalize_symbol(payload.symbol), - ) - if payload.is_active is not None: - code_value = getattr(currency, "code") - if code_value == DEFAULT_CURRENCY_CODE and payload.is_active is False: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The default currency cannot be deactivated.", - ) - setattr(currency, "is_active", payload.is_active) - - db.add(currency) - db.commit() - db.refresh(currency) - return currency - - -@router.patch("/{code}/activation", response_model=CurrencyRead) -def toggle_currency_activation( - code: str, body: CurrencyActivation, db: Session = Depends(get_db) -): - currency = _get_currency_or_404(db, code) - code_value = getattr(currency, "code") - if code_value == DEFAULT_CURRENCY_CODE and body.is_active is False: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The default currency cannot be deactivated.", - ) - - setattr(currency, "is_active", body.is_active) - db.add(currency) - db.commit() - db.refresh(currency) - return currency diff --git a/routes/dependencies.py b/routes/dependencies.py deleted file mode 100644 index 0afc871..0000000 --- a/routes/dependencies.py +++ /dev/null @@ -1,13 +0,0 @@ -from collections.abc import Generator - -from sqlalchemy.orm import Session - -from config.database import SessionLocal - - -def get_db() -> Generator[Session, None, None]: - db = SessionLocal() - try: - yield db - finally: - db.close() diff --git a/routes/distributions.py b/routes/distributions.py deleted file mode 100644 index 34a0cc8..0000000 --- a/routes/distributions.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Dict, List - -from fastapi import APIRouter, Depends -from pydantic import BaseModel, ConfigDict -from sqlalchemy.orm import Session - -from models.distribution import Distribution -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/distributions", tags=["Distributions"]) - - -class DistributionCreate(BaseModel): - name: str - distribution_type: str - parameters: Dict[str, float | int] - - -class DistributionRead(DistributionCreate): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post("/", response_model=DistributionRead) -async def create_distribution( - dist: DistributionCreate, db: Session = Depends(get_db) -): - db_dist = Distribution(**dist.model_dump()) - db.add(db_dist) - db.commit() - db.refresh(db_dist) - return db_dist - - -@router.get("/", response_model=List[DistributionRead]) -async def list_distributions(db: Session = Depends(get_db)): - dists = db.query(Distribution).all() - return dists diff --git a/routes/equipment.py b/routes/equipment.py deleted file mode 100644 index a5800a9..0000000 --- a/routes/equipment.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends -from pydantic import BaseModel, ConfigDict -from sqlalchemy.orm import Session - -from models.equipment import Equipment -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/equipment", tags=["Equipment"]) -# Pydantic schemas - - -class EquipmentCreate(BaseModel): - scenario_id: int - name: str - description: Optional[str] = None - - -class 
EquipmentRead(EquipmentCreate): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post("/", response_model=EquipmentRead) -async def create_equipment( - item: EquipmentCreate, db: Session = Depends(get_db) -): - db_item = Equipment(**item.model_dump()) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/", response_model=List[EquipmentRead]) -async def list_equipment(db: Session = Depends(get_db)): - return db.query(Equipment).all() diff --git a/routes/maintenance.py b/routes/maintenance.py deleted file mode 100644 index 93683fd..0000000 --- a/routes/maintenance.py +++ /dev/null @@ -1,91 +0,0 @@ -from datetime import date -from typing import List, Optional - -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, ConfigDict, PositiveFloat -from sqlalchemy.orm import Session - -from models.maintenance import Maintenance -from routes.dependencies import get_db - - -router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"]) - - -class MaintenanceBase(BaseModel): - equipment_id: int - scenario_id: int - maintenance_date: date - description: Optional[str] = None - cost: PositiveFloat - - -class MaintenanceCreate(MaintenanceBase): - pass - - -class MaintenanceUpdate(MaintenanceBase): - pass - - -class MaintenanceRead(MaintenanceBase): - id: int - model_config = ConfigDict(from_attributes=True) - - -def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: - maintenance = ( - db.query(Maintenance).filter(Maintenance.id == maintenance_id).first() - ) - if maintenance is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Maintenance record {maintenance_id} not found", - ) - return maintenance - - -@router.post( - "/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED -) -def create_maintenance( - maintenance: MaintenanceCreate, db: Session = Depends(get_db) -): - db_maintenance = Maintenance(**maintenance.model_dump()) - db.add(db_maintenance) - db.commit() - db.refresh(db_maintenance) - return db_maintenance - - -@router.get("/", response_model=List[MaintenanceRead]) -def list_maintenance( - skip: int = 0, limit: int = 100, db: Session = Depends(get_db) -): - return db.query(Maintenance).offset(skip).limit(limit).all() - - -@router.get("/{maintenance_id}", response_model=MaintenanceRead) -def get_maintenance(maintenance_id: int, db: Session = Depends(get_db)): - return _get_maintenance_or_404(db, maintenance_id) - - -@router.put("/{maintenance_id}", response_model=MaintenanceRead) -def update_maintenance( - maintenance_id: int, - payload: MaintenanceUpdate, - db: Session = Depends(get_db), -): - db_maintenance = _get_maintenance_or_404(db, maintenance_id) - for field, value in payload.model_dump().items(): - setattr(db_maintenance, field, value) - db.commit() - db.refresh(db_maintenance) - return db_maintenance - - -@router.delete("/{maintenance_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_maintenance(maintenance_id: int, db: Session = Depends(get_db)): - db_maintenance = _get_maintenance_or_404(db, maintenance_id) - db.delete(db_maintenance) - db.commit() diff --git a/routes/parameters.py b/routes/parameters.py deleted file mode 100644 index 59f09c8..0000000 --- a/routes/parameters.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Any, Dict, List, Optional - -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, ConfigDict, field_validator -from sqlalchemy.orm import 
Session - -from models.distribution import Distribution -from models.parameters import Parameter -from models.scenario import Scenario -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/parameters", tags=["parameters"]) - - -class ParameterCreate(BaseModel): - scenario_id: int - name: str - value: float - distribution_id: Optional[int] = None - distribution_type: Optional[str] = None - distribution_parameters: Optional[Dict[str, Any]] = None - - @field_validator("distribution_type") - @classmethod - def normalize_type(cls, value: Optional[str]) -> Optional[str]: - if value is None: - return value - normalized = value.strip().lower() - if not normalized: - return None - if normalized not in {"normal", "uniform", "triangular"}: - raise ValueError( - "distribution_type must be normal, uniform, or triangular" - ) - return normalized - - @field_validator("distribution_parameters") - @classmethod - def empty_dict_to_none( - cls, value: Optional[Dict[str, Any]] - ) -> Optional[Dict[str, Any]]: - if value is None: - return None - return value or None - - -class ParameterRead(ParameterCreate): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post("/", response_model=ParameterRead) -def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): - scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() - if not scen: - raise HTTPException(status_code=404, detail="Scenario not found") - distribution_id = param.distribution_id - distribution_type = param.distribution_type - distribution_parameters = param.distribution_parameters - - if distribution_id is not None: - distribution = ( - db.query(Distribution) - .filter(Distribution.id == distribution_id) - .first() - ) - if not distribution: - raise HTTPException( - status_code=404, detail="Distribution not found" - ) - distribution_type = distribution.distribution_type - distribution_parameters = distribution.parameters or None - - new_param = Parameter( - scenario_id=param.scenario_id, - name=param.name, - value=param.value, - distribution_id=distribution_id, - distribution_type=distribution_type, - distribution_parameters=distribution_parameters, - ) - db.add(new_param) - db.commit() - db.refresh(new_param) - return new_param - - -@router.get("/", response_model=List[ParameterRead]) -def list_parameters(db: Session = Depends(get_db)): - return db.query(Parameter).all() diff --git a/routes/production.py b/routes/production.py deleted file mode 100644 index ad4a059..0000000 --- a/routes/production.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends, status -from pydantic import BaseModel, ConfigDict, PositiveFloat, field_validator -from sqlalchemy.orm import Session - -from models.production_output import ProductionOutput -from routes.dependencies import get_db - - -router = APIRouter(prefix="/api/production", tags=["Production"]) - - -class ProductionOutputBase(BaseModel): - scenario_id: int - amount: PositiveFloat - description: Optional[str] = None - unit_name: Optional[str] = None - unit_symbol: Optional[str] = None - - @field_validator("unit_name", "unit_symbol") - @classmethod - def _normalize_text(cls, value: Optional[str]) -> Optional[str]: - if value is None: - return None - stripped = value.strip() - return stripped or None - - -class ProductionOutputCreate(ProductionOutputBase): - pass - - -class ProductionOutputRead(ProductionOutputBase): - id: int - model_config = ConfigDict(from_attributes=True) - - 
-@router.post( - "/", - response_model=ProductionOutputRead, - status_code=status.HTTP_201_CREATED, -) -def create_production( - item: ProductionOutputCreate, db: Session = Depends(get_db) -): - db_item = ProductionOutput(**item.model_dump()) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/", response_model=List[ProductionOutputRead]) -def list_production(db: Session = Depends(get_db)): - return db.query(ProductionOutput).all() diff --git a/routes/reporting.py b/routes/reporting.py deleted file mode 100644 index 09a9417..0000000 --- a/routes/reporting.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Any, Dict, List, cast - -from fastapi import APIRouter, HTTPException, Request, status -from pydantic import BaseModel - -from services.reporting import generate_report - - -router = APIRouter(prefix="/api/reporting", tags=["Reporting"]) - - -def _validate_payload(payload: Any) -> List[Dict[str, float]]: - if not isinstance(payload, list): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid input format", - ) - - typed_payload = cast(List[Any], payload) - - validated: List[Dict[str, float]] = [] - for index, item in enumerate(typed_payload): - if not isinstance(item, dict): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Entry at index {index} must be an object", - ) - value = cast(Dict[str, Any], item).get("result") - if not isinstance(value, (int, float)): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Entry at index {index} must include numeric 'result'", - ) - validated.append({"result": float(value)}) - return validated - - -class ReportSummary(BaseModel): - count: int - mean: float - median: float - min: float - max: float - std_dev: float - variance: float - percentile_10: float - percentile_90: float - percentile_5: float - percentile_95: float - value_at_risk_95: float - expected_shortfall_95: float - - -@router.post("/summary", response_model=ReportSummary) -async def summary_report(request: Request): - payload = await request.json() - validated_payload = _validate_payload(payload) - summary = generate_report(validated_payload) - return ReportSummary( - count=int(summary["count"]), - mean=float(summary["mean"]), - median=float(summary["median"]), - min=float(summary["min"]), - max=float(summary["max"]), - std_dev=float(summary["std_dev"]), - variance=float(summary["variance"]), - percentile_10=float(summary["percentile_10"]), - percentile_90=float(summary["percentile_90"]), - percentile_5=float(summary["percentile_5"]), - percentile_95=float(summary["percentile_95"]), - value_at_risk_95=float(summary["value_at_risk_95"]), - expected_shortfall_95=float(summary["expected_shortfall_95"]), - ) diff --git a/routes/scenarios.py b/routes/scenarios.py deleted file mode 100644 index 4454f74..0000000 --- a/routes/scenarios.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime -from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, ConfigDict -from sqlalchemy.orm import Session - -from models.scenario import Scenario -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/scenarios", tags=["scenarios"]) - -# Pydantic schemas - - -class ScenarioCreate(BaseModel): - name: str - description: Optional[str] = None - - -class ScenarioRead(ScenarioCreate): - id: int - created_at: datetime - updated_at: Optional[datetime] = None - model_config = 
ConfigDict(from_attributes=True) - - -@router.post("/", response_model=ScenarioRead) -def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)): - db_s = db.query(Scenario).filter(Scenario.name == scenario.name).first() - if db_s: - raise HTTPException(status_code=400, detail="Scenario already exists") - new_s = Scenario(name=scenario.name, description=scenario.description) - db.add(new_s) - db.commit() - db.refresh(new_s) - return new_s - - -@router.get("/", response_model=list[ScenarioRead]) -def list_scenarios(db: Session = Depends(get_db)): - return db.query(Scenario).all() diff --git a/routes/settings.py b/routes/settings.py deleted file mode 100644 index ed06fb5..0000000 --- a/routes/settings.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Dict, List - -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, Field, model_validator -from sqlalchemy.orm import Session - -from routes.dependencies import get_db -from services.settings import ( - CSS_COLOR_DEFAULTS, - get_css_color_settings, - list_css_env_override_rows, - read_css_color_env_overrides, - update_css_color_settings, - get_theme_settings, - save_theme_settings, -) - -router = APIRouter(prefix="/api/settings", tags=["Settings"]) - - -class CSSSettingsPayload(BaseModel): - variables: Dict[str, str] = Field(default_factory=dict) - - @model_validator(mode="after") - def _validate_allowed_keys(self) -> "CSSSettingsPayload": - invalid = set(self.variables.keys()) - set(CSS_COLOR_DEFAULTS.keys()) - if invalid: - invalid_keys = ", ".join(sorted(invalid)) - raise ValueError( - f"Unsupported CSS variables: {invalid_keys}." - " Accepted keys align with the default theme variables." - ) - return self - - -class EnvOverride(BaseModel): - css_key: str - env_var: str - value: str - - -class CSSSettingsResponse(BaseModel): - variables: Dict[str, str] - env_overrides: Dict[str, str] = Field(default_factory=dict) - env_sources: List[EnvOverride] = Field(default_factory=list) - - -@router.get("/css", response_model=CSSSettingsResponse) -def read_css_settings(db: Session = Depends(get_db)) -> CSSSettingsResponse: - try: - values = get_css_color_settings(db) - env_overrides = read_css_color_env_overrides() - env_sources = [ - EnvOverride(**row) for row in list_css_env_override_rows() - ] - except ValueError as exc: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(exc), - ) from exc - return CSSSettingsResponse( - variables=values, - env_overrides=env_overrides, - env_sources=env_sources, - ) - - -@router.put( - "/css", response_model=CSSSettingsResponse, status_code=status.HTTP_200_OK -) -def update_css_settings( - payload: CSSSettingsPayload, db: Session = Depends(get_db) -) -> CSSSettingsResponse: - try: - values = update_css_color_settings(db, payload.variables) - env_overrides = read_css_color_env_overrides() - env_sources = [ - EnvOverride(**row) for row in list_css_env_override_rows() - ] - except ValueError as exc: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, - detail=str(exc), - ) from exc - return CSSSettingsResponse( - variables=values, - env_overrides=env_overrides, - env_sources=env_sources, - ) - - -class ThemeSettings(BaseModel): - theme_name: str - primary_color: str - secondary_color: str - accent_color: str - background_color: str - text_color: str - - -@router.post("/theme") -async def update_theme(theme_data: ThemeSettings, db: Session = Depends(get_db)): - data_dict = theme_data.model_dump() 
- save_theme_settings(db, data_dict) - return {"message": "Theme updated", "theme": data_dict} - - -@router.get("/theme") -async def get_theme(db: Session = Depends(get_db)): - return get_theme_settings(db) diff --git a/routes/simulations.py b/routes/simulations.py deleted file mode 100644 index 5500805..0000000 --- a/routes/simulations.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import Dict, List, Optional - -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, PositiveInt -from sqlalchemy.orm import Session - -from models.parameters import Parameter -from models.scenario import Scenario -from models.simulation_result import SimulationResult -from routes.dependencies import get_db -from services.reporting import generate_report -from services.simulation import run_simulation - -router = APIRouter(prefix="/api/simulations", tags=["Simulations"]) - - -class SimulationParameterInput(BaseModel): - name: str - value: float - distribution: Optional[str] = "normal" - std_dev: Optional[float] = None - min: Optional[float] = None - max: Optional[float] = None - mode: Optional[float] = None - - -class SimulationRunRequest(BaseModel): - scenario_id: int - iterations: PositiveInt = 1000 - parameters: Optional[List[SimulationParameterInput]] = None - seed: Optional[int] = None - - -class SimulationResultItem(BaseModel): - iteration: int - result: float - - -class SimulationRunResponse(BaseModel): - scenario_id: int - iterations: int - results: List[SimulationResultItem] - summary: Dict[str, float | int] - - -def _load_parameters( - db: Session, scenario_id: int -) -> List[SimulationParameterInput]: - db_params = ( - db.query(Parameter) - .filter(Parameter.scenario_id == scenario_id) - .order_by(Parameter.id) - .all() - ) - return [ - SimulationParameterInput( - name=item.name, - value=item.value, - ) - for item in db_params - ] - - -@router.post("/run", response_model=SimulationRunResponse) -async def simulate( - payload: SimulationRunRequest, db: Session = Depends(get_db) -): - scenario = ( - db.query(Scenario).filter(Scenario.id == payload.scenario_id).first() - ) - if scenario is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Scenario not found", - ) - - parameters = payload.parameters or _load_parameters(db, payload.scenario_id) - if not parameters: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="No parameters provided", - ) - - raw_results = run_simulation( - [param.model_dump(exclude_none=True) for param in parameters], - iterations=payload.iterations, - seed=payload.seed, - ) - - if not raw_results: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Simulation produced no results", - ) - - # Persist results (replace existing values for scenario) - db.query(SimulationResult).filter( - SimulationResult.scenario_id == payload.scenario_id - ).delete() - db.bulk_save_objects( - [ - SimulationResult( - scenario_id=payload.scenario_id, - iteration=item["iteration"], - result=item["result"], - ) - for item in raw_results - ] - ) - db.commit() - - summary = generate_report(raw_results) - - response = SimulationRunResponse( - scenario_id=payload.scenario_id, - iterations=payload.iterations, - results=[ - SimulationResultItem( - iteration=int(item["iteration"]), - result=float(item["result"]), - ) - for item in raw_results - ], - summary=summary, - ) - return response diff --git a/routes/ui.py b/routes/ui.py deleted file mode 100644 index e690dba..0000000 --- a/routes/ui.py 
+++ /dev/null @@ -1,784 +0,0 @@ -from collections import defaultdict -from datetime import datetime, timezone -from typing import Any, Dict, Optional - -from fastapi import APIRouter, Depends, Request -from fastapi.responses import HTMLResponse, JSONResponse -from fastapi.templating import Jinja2Templates -from sqlalchemy.orm import Session - -from models.capex import Capex -from models.consumption import Consumption -from models.equipment import Equipment -from models.maintenance import Maintenance -from models.opex import Opex -from models.parameters import Parameter -from models.production_output import ProductionOutput -from models.scenario import Scenario -from models.simulation_result import SimulationResult -from routes.dependencies import get_db -from services.reporting import generate_report -from models.currency import Currency -from routes.currencies import DEFAULT_CURRENCY_CODE, _ensure_default_currency -from services.settings import ( - CSS_COLOR_DEFAULTS, - get_css_color_settings, - list_css_env_override_rows, - read_css_color_env_overrides, -) - - -CURRENCY_CHOICES: list[Dict[str, Any]] = [ - {"id": "USD", "name": "US Dollar (USD)"}, - {"id": "EUR", "name": "Euro (EUR)"}, - {"id": "CLP", "name": "Chilean Peso (CLP)"}, - {"id": "RMB", "name": "Chinese Yuan (RMB)"}, - {"id": "GBP", "name": "British Pound (GBP)"}, - {"id": "CAD", "name": "Canadian Dollar (CAD)"}, - {"id": "AUD", "name": "Australian Dollar (AUD)"}, -] - -MEASUREMENT_UNITS: list[Dict[str, Any]] = [ - {"id": "tonnes", "name": "Tonnes", "symbol": "t"}, - {"id": "kilograms", "name": "Kilograms", "symbol": "kg"}, - {"id": "pounds", "name": "Pounds", "symbol": "lb"}, - {"id": "liters", "name": "Liters", "symbol": "L"}, - {"id": "cubic_meters", "name": "Cubic Meters", "symbol": "m3"}, - {"id": "kilowatt_hours", "name": "Kilowatt Hours", "symbol": "kWh"}, -] - -router = APIRouter() - -# Set up Jinja2 templates directory -templates = Jinja2Templates(directory="templates") - - -def _context( - request: Request, extra: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - payload: Dict[str, Any] = { - "request": request, - "current_year": datetime.now(timezone.utc).year, - } - if extra: - payload.update(extra) - return payload - - -def _render( - request: Request, - template_name: str, - extra: Optional[Dict[str, Any]] = None, -): - context = _context(request, extra) - return templates.TemplateResponse(request, template_name, context) - - -def _format_currency(value: float) -> str: - return f"${value:,.2f}" - - -def _format_decimal(value: float) -> str: - return f"{value:,.2f}" - - -def _format_int(value: int) -> str: - return f"{value:,}" - - -def _load_scenarios(db: Session) -> Dict[str, Any]: - scenarios: list[Dict[str, Any]] = [ - { - "id": item.id, - "name": item.name, - "description": item.description, - } - for item in db.query(Scenario).order_by(Scenario.name).all() - ] - return {"scenarios": scenarios} - - -def _load_parameters(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for param in db.query(Parameter).order_by( - Parameter.scenario_id, Parameter.id - ): - grouped[param.scenario_id].append( - { - "id": param.id, - "name": param.name, - "value": param.value, - "distribution_type": param.distribution_type, - "distribution_parameters": param.distribution_parameters, - } - ) - return {"parameters_by_scenario": dict(grouped)} - - -def _load_costs(db: Session) -> Dict[str, Any]: - capex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - 
for capex in db.query(Capex).order_by(Capex.scenario_id, Capex.id).all(): - capex_grouped[int(getattr(capex, "scenario_id"))].append( - { - "id": int(getattr(capex, "id")), - "scenario_id": int(getattr(capex, "scenario_id")), - "amount": float(getattr(capex, "amount", 0.0)), - "description": getattr(capex, "description", "") or "", - "currency_code": getattr(capex, "currency_code", "USD") - or "USD", - } - ) - - opex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for opex in db.query(Opex).order_by(Opex.scenario_id, Opex.id).all(): - opex_grouped[int(getattr(opex, "scenario_id"))].append( - { - "id": int(getattr(opex, "id")), - "scenario_id": int(getattr(opex, "scenario_id")), - "amount": float(getattr(opex, "amount", 0.0)), - "description": getattr(opex, "description", "") or "", - "currency_code": getattr(opex, "currency_code", "USD") or "USD", - } - ) - - return { - "capex_by_scenario": dict(capex_grouped), - "opex_by_scenario": dict(opex_grouped), - } - - -def _load_currencies(db: Session) -> Dict[str, Any]: - items: list[Dict[str, Any]] = [] - for c in ( - db.query(Currency) - .filter_by(is_active=True) - .order_by(Currency.code) - .all() - ): - items.append( - {"id": c.code, "name": f"{c.name} ({c.code})", "symbol": c.symbol} - ) - if not items: - items.append({"id": "USD", "name": "US Dollar (USD)", "symbol": "$"}) - return {"currency_options": items} - - -def _load_currency_settings(db: Session) -> Dict[str, Any]: - _ensure_default_currency(db) - records = db.query(Currency).order_by(Currency.code).all() - currencies: list[Dict[str, Any]] = [] - for record in records: - code_value = getattr(record, "code") - currencies.append( - { - "id": int(getattr(record, "id")), - "code": code_value, - "name": getattr(record, "name"), - "symbol": getattr(record, "symbol"), - "is_active": bool(getattr(record, "is_active", True)), - "is_default": code_value == DEFAULT_CURRENCY_CODE, - } - ) - - active_count = sum(1 for item in currencies if item["is_active"]) - inactive_count = len(currencies) - active_count - - return { - "currencies": currencies, - "currency_stats": { - "total": len(currencies), - "active": active_count, - "inactive": inactive_count, - }, - "default_currency_code": DEFAULT_CURRENCY_CODE, - "currency_api_base": "/api/currencies", - } - - -def _load_css_settings(db: Session) -> Dict[str, Any]: - variables = get_css_color_settings(db) - env_overrides = read_css_color_env_overrides() - env_rows = list_css_env_override_rows() - env_meta = {row["css_key"]: row for row in env_rows} - return { - "css_variables": variables, - "css_defaults": CSS_COLOR_DEFAULTS, - "css_env_overrides": env_overrides, - "css_env_override_rows": env_rows, - "css_env_override_meta": env_meta, - } - - -def _load_consumption(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(Consumption) - .order_by(Consumption.scenario_id, Consumption.id) - .all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - amount_value = float(getattr(record, "amount", 0.0)) - description = getattr(record, "description", "") or "" - unit_name = getattr(record, "unit_name", None) - unit_symbol = getattr(record, "unit_symbol", None) - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "amount": amount_value, - "description": description, - "unit_name": unit_name, - "unit_symbol": unit_symbol, - } - ) - return {"consumption_by_scenario": dict(grouped)} - - 
-def _load_production(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(ProductionOutput) - .order_by(ProductionOutput.scenario_id, ProductionOutput.id) - .all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - amount_value = float(getattr(record, "amount", 0.0)) - description = getattr(record, "description", "") or "" - unit_name = getattr(record, "unit_name", None) - unit_symbol = getattr(record, "unit_symbol", None) - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "amount": amount_value, - "description": description, - "unit_name": unit_name, - "unit_symbol": unit_symbol, - } - ) - return {"production_by_scenario": dict(grouped)} - - -def _load_equipment(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(Equipment).order_by(Equipment.scenario_id, Equipment.id).all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - name_value = getattr(record, "name", "") or "" - description = getattr(record, "description", "") or "" - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "name": name_value, - "description": description, - } - ) - return {"equipment_by_scenario": dict(grouped)} - - -def _load_maintenance(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(Maintenance) - .order_by(Maintenance.scenario_id, Maintenance.maintenance_date) - .all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - equipment_id = int(getattr(record, "equipment_id")) - equipment_obj = getattr(record, "equipment", None) - equipment_name = ( - getattr(equipment_obj, "name", "") if equipment_obj else "" - ) - maintenance_date = getattr(record, "maintenance_date", None) - cost_value = float(getattr(record, "cost", 0.0)) - description = getattr(record, "description", "") or "" - - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "equipment_id": equipment_id, - "equipment_name": equipment_name, - "maintenance_date": ( - maintenance_date.isoformat() if maintenance_date else "" - ), - "cost": cost_value, - "description": description, - } - ) - return {"maintenance_by_scenario": dict(grouped)} - - -def _load_simulations(db: Session) -> Dict[str, Any]: - scenarios: list[Dict[str, Any]] = [ - { - "id": item.id, - "name": item.name, - } - for item in db.query(Scenario).order_by(Scenario.name).all() - ] - - results_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(SimulationResult) - .order_by(SimulationResult.scenario_id, SimulationResult.iteration) - .all() - ): - scenario_id = int(getattr(record, "scenario_id")) - results_grouped[scenario_id].append( - { - "iteration": int(getattr(record, "iteration")), - "result": float(getattr(record, "result", 0.0)), - } - ) - - runs: list[Dict[str, Any]] = [] - sample_limit = 20 - for item in scenarios: - scenario_id = int(item["id"]) - scenario_results = results_grouped.get(scenario_id, []) - summary = ( - generate_report(scenario_results) - if scenario_results - else generate_report([]) - ) - runs.append( - { - "scenario_id": scenario_id, - "scenario_name": item["name"], - "iterations": int(summary.get("count", 0)), - "summary": summary, - 
"sample_results": scenario_results[:sample_limit], - } - ) - - return { - "simulation_scenarios": scenarios, - "simulation_runs": runs, - } - - -def _load_reporting(db: Session) -> Dict[str, Any]: - scenarios = _load_scenarios(db)["scenarios"] - runs = _load_simulations(db)["simulation_runs"] - - summaries: list[Dict[str, Any]] = [] - runs_by_scenario = {run["scenario_id"]: run for run in runs} - - for scenario in scenarios: - scenario_id = scenario["id"] - run = runs_by_scenario.get(scenario_id) - summary = run["summary"] if run else generate_report([]) - summaries.append( - { - "scenario_id": scenario_id, - "scenario_name": scenario["name"], - "summary": summary, - "iterations": run["iterations"] if run else 0, - } - ) - - return { - "report_summaries": summaries, - } - - -def _load_dashboard(db: Session) -> Dict[str, Any]: - scenarios = _load_scenarios(db)["scenarios"] - parameters_by_scenario = _load_parameters(db)["parameters_by_scenario"] - costs_context = _load_costs(db) - capex_by_scenario = costs_context["capex_by_scenario"] - opex_by_scenario = costs_context["opex_by_scenario"] - consumption_by_scenario = _load_consumption(db)["consumption_by_scenario"] - production_by_scenario = _load_production(db)["production_by_scenario"] - equipment_by_scenario = _load_equipment(db)["equipment_by_scenario"] - maintenance_by_scenario = _load_maintenance(db)["maintenance_by_scenario"] - simulation_context = _load_simulations(db) - simulation_runs = simulation_context["simulation_runs"] - - runs_by_scenario = {run["scenario_id"]: run for run in simulation_runs} - - def sum_amounts( - grouped: Dict[int, list[Dict[str, Any]]], field: str = "amount" - ) -> float: - total = 0.0 - for items in grouped.values(): - for item in items: - value = item.get(field, 0.0) - if isinstance(value, (int, float)): - total += float(value) - return total - - total_capex = sum_amounts(capex_by_scenario) - total_opex = sum_amounts(opex_by_scenario) - total_consumption = sum_amounts(consumption_by_scenario) - total_production = sum_amounts(production_by_scenario) - total_maintenance_cost = sum_amounts(maintenance_by_scenario, field="cost") - - total_parameters = sum( - len(items) for items in parameters_by_scenario.values() - ) - total_equipment = sum( - len(items) for items in equipment_by_scenario.values() - ) - total_maintenance_events = sum( - len(items) for items in maintenance_by_scenario.values() - ) - total_simulation_iterations = sum( - run["iterations"] for run in simulation_runs - ) - - scenario_rows: list[Dict[str, Any]] = [] - scenario_labels: list[str] = [] - scenario_capex: list[float] = [] - scenario_opex: list[float] = [] - activity_labels: list[str] = [] - activity_production: list[float] = [] - activity_consumption: list[float] = [] - - for scenario in scenarios: - scenario_id = scenario["id"] - scenario_name = scenario["name"] - param_count = len(parameters_by_scenario.get(scenario_id, [])) - equipment_count = len(equipment_by_scenario.get(scenario_id, [])) - maintenance_count = len(maintenance_by_scenario.get(scenario_id, [])) - - capex_total = sum( - float(item.get("amount", 0.0)) - for item in capex_by_scenario.get(scenario_id, []) - ) - opex_total = sum( - float(item.get("amount", 0.0)) - for item in opex_by_scenario.get(scenario_id, []) - ) - consumption_total = sum( - float(item.get("amount", 0.0)) - for item in consumption_by_scenario.get(scenario_id, []) - ) - production_total = sum( - float(item.get("amount", 0.0)) - for item in production_by_scenario.get(scenario_id, []) - ) - - run = 
runs_by_scenario.get(scenario_id) - summary = run["summary"] if run else generate_report([]) - iterations = run["iterations"] if run else 0 - mean_value = float(summary.get("mean", 0.0)) - - scenario_rows.append( - { - "scenario_name": scenario_name, - "parameter_count": param_count, - "parameter_display": _format_int(param_count), - "equipment_count": equipment_count, - "equipment_display": _format_int(equipment_count), - "capex_total": capex_total, - "capex_display": _format_currency(capex_total), - "opex_total": opex_total, - "opex_display": _format_currency(opex_total), - "production_total": production_total, - "production_display": _format_decimal(production_total), - "consumption_total": consumption_total, - "consumption_display": _format_decimal(consumption_total), - "maintenance_count": maintenance_count, - "maintenance_display": _format_int(maintenance_count), - "iterations": iterations, - "iterations_display": _format_int(iterations), - "simulation_mean": mean_value, - "simulation_mean_display": _format_decimal(mean_value), - } - ) - - scenario_labels.append(scenario_name) - scenario_capex.append(capex_total) - scenario_opex.append(opex_total) - - activity_labels.append(scenario_name) - activity_production.append(production_total) - activity_consumption.append(consumption_total) - - scenario_rows.sort(key=lambda row: row["scenario_name"].lower()) - - all_simulation_results = [ - {"result": float(getattr(item, "result", 0.0))} - for item in db.query(SimulationResult).all() - ] - overall_report = generate_report(all_simulation_results) - - overall_report_metrics = [ - { - "label": "Runs", - "value": _format_int(int(overall_report.get("count", 0))), - }, - { - "label": "Mean", - "value": _format_decimal(float(overall_report.get("mean", 0.0))), - }, - { - "label": "Median", - "value": _format_decimal(float(overall_report.get("median", 0.0))), - }, - { - "label": "Std Dev", - "value": _format_decimal(float(overall_report.get("std_dev", 0.0))), - }, - { - "label": "95th Percentile", - "value": _format_decimal( - float(overall_report.get("percentile_95", 0.0)) - ), - }, - { - "label": "VaR (95%)", - "value": _format_decimal( - float(overall_report.get("value_at_risk_95", 0.0)) - ), - }, - { - "label": "Expected Shortfall (95%)", - "value": _format_decimal( - float(overall_report.get("expected_shortfall_95", 0.0)) - ), - }, - ] - - recent_simulations: list[Dict[str, Any]] = [ - { - "scenario_name": run["scenario_name"], - "iterations": run["iterations"], - "iterations_display": _format_int(run["iterations"]), - "mean_display": _format_decimal( - float(run["summary"].get("mean", 0.0)) - ), - "p95_display": _format_decimal( - float(run["summary"].get("percentile_95", 0.0)) - ), - } - for run in simulation_runs - if run["iterations"] > 0 - ] - recent_simulations.sort(key=lambda item: item["iterations"], reverse=True) - recent_simulations = recent_simulations[:5] - - upcoming_maintenance: list[Dict[str, Any]] = [] - for record in ( - db.query(Maintenance) - .order_by(Maintenance.maintenance_date.asc()) - .limit(5) - .all() - ): - maintenance_date = getattr(record, "maintenance_date", None) - upcoming_maintenance.append( - { - "scenario_name": getattr( - getattr(record, "scenario", None), "name", "Unknown" - ), - "equipment_name": getattr( - getattr(record, "equipment", None), "name", "Unknown" - ), - "date_display": ( - maintenance_date.strftime("%Y-%m-%d") - if maintenance_date - else "—" - ), - "cost_display": _format_currency( - float(getattr(record, "cost", 0.0)) - ), - "description": 
getattr(record, "description", "") or "—", - } - ) - - cost_chart_has_data = any(value > 0 for value in scenario_capex) or any( - value > 0 for value in scenario_opex - ) - activity_chart_has_data = any( - value > 0 for value in activity_production - ) or any(value > 0 for value in activity_consumption) - - scenario_cost_chart: Dict[str, list[Any]] = { - "labels": scenario_labels, - "capex": scenario_capex, - "opex": scenario_opex, - } - scenario_activity_chart: Dict[str, list[Any]] = { - "labels": activity_labels, - "production": activity_production, - "consumption": activity_consumption, - } - - summary_metrics = [ - {"label": "Active Scenarios", "value": _format_int(len(scenarios))}, - {"label": "Parameters", "value": _format_int(total_parameters)}, - {"label": "CAPEX Total", "value": _format_currency(total_capex)}, - {"label": "OPEX Total", "value": _format_currency(total_opex)}, - {"label": "Equipment Assets", "value": _format_int(total_equipment)}, - { - "label": "Maintenance Events", - "value": _format_int(total_maintenance_events), - }, - {"label": "Consumption", "value": _format_decimal(total_consumption)}, - {"label": "Production", "value": _format_decimal(total_production)}, - { - "label": "Simulation Iterations", - "value": _format_int(total_simulation_iterations), - }, - { - "label": "Maintenance Cost", - "value": _format_currency(total_maintenance_cost), - }, - ] - - return { - "summary_metrics": summary_metrics, - "scenario_rows": scenario_rows, - "overall_report_metrics": overall_report_metrics, - "recent_simulations": recent_simulations, - "upcoming_maintenance": upcoming_maintenance, - "scenario_cost_chart": scenario_cost_chart, - "scenario_activity_chart": scenario_activity_chart, - "cost_chart_has_data": cost_chart_has_data, - "activity_chart_has_data": activity_chart_has_data, - "report_available": overall_report.get("count", 0) > 0, - } - - -@router.get("/", response_class=HTMLResponse) -async def dashboard_root(request: Request, db: Session = Depends(get_db)): - """Render the primary dashboard landing page.""" - return _render(request, "Dashboard.html", _load_dashboard(db)) - - -@router.get("/ui/dashboard", response_class=HTMLResponse) -async def dashboard(request: Request, db: Session = Depends(get_db)): - """Render the legacy dashboard route for backward compatibility.""" - return _render(request, "Dashboard.html", _load_dashboard(db)) - - -@router.get("/ui/dashboard/data", response_class=JSONResponse) -async def dashboard_data(db: Session = Depends(get_db)) -> JSONResponse: - """Expose dashboard aggregates as JSON for client-side refreshes.""" - return JSONResponse(_load_dashboard(db)) - - -@router.get("/ui/scenarios", response_class=HTMLResponse) -async def scenario_form(request: Request, db: Session = Depends(get_db)): - """Render the scenario creation form.""" - context = _load_scenarios(db) - return _render(request, "ScenarioForm.html", context) - - -@router.get("/ui/parameters", response_class=HTMLResponse) -async def parameter_form(request: Request, db: Session = Depends(get_db)): - """Render the parameter input form.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_parameters(db)) - return _render(request, "ParameterInput.html", context) - - -@router.get("/ui/costs", response_class=HTMLResponse) -async def costs_view(request: Request, db: Session = Depends(get_db)): - """Render the costs view with CAPEX and OPEX data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - 
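For reference, the aggregation helper defined inside _load_dashboard above reduces those per-scenario buckets to grand totals. A standalone copy with a small usage example follows; the sample values are made up purely for illustration.

```python
# Standalone copy of the sum_amounts helper from _load_dashboard above: it
# walks a {scenario_id: [row_dict, ...]} mapping and sums one numeric field,
# skipping anything that is not an int or float.
from typing import Any, Dict, List


def sum_amounts(
    grouped: Dict[int, List[Dict[str, Any]]], field: str = "amount"
) -> float:
    total = 0.0
    for items in grouped.values():
        for item in items:
            value = item.get(field, 0.0)
            if isinstance(value, (int, float)):
                total += float(value)
    return total


# Illustrative data only: CAPEX rows carry "amount", maintenance rows carry
# "cost", which is why the dashboard calls sum_amounts(..., field="cost").
capex_by_scenario = {1: [{"amount": 100.0}, {"amount": 50.5}], 2: [{"amount": 10.0}]}
maintenance_by_scenario = {1: [{"cost": 25.0}], 2: [{"cost": 5.0}]}
assert sum_amounts(capex_by_scenario) == 160.5
assert sum_amounts(maintenance_by_scenario, field="cost") == 30.0
```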
context.update(_load_costs(db)) - context.update(_load_currencies(db)) - return _render(request, "costs.html", context) - - -@router.get("/ui/consumption", response_class=HTMLResponse) -async def consumption_view(request: Request, db: Session = Depends(get_db)): - """Render the consumption view with scenario consumption data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_consumption(db)) - context["unit_options"] = MEASUREMENT_UNITS - return _render(request, "consumption.html", context) - - -@router.get("/ui/production", response_class=HTMLResponse) -async def production_view(request: Request, db: Session = Depends(get_db)): - """Render the production view with scenario production data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_production(db)) - context["unit_options"] = MEASUREMENT_UNITS - return _render(request, "production.html", context) - - -@router.get("/ui/equipment", response_class=HTMLResponse) -async def equipment_view(request: Request, db: Session = Depends(get_db)): - """Render the equipment view with scenario equipment data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_equipment(db)) - return _render(request, "equipment.html", context) - - -@router.get("/ui/maintenance", response_class=HTMLResponse) -async def maintenance_view(request: Request, db: Session = Depends(get_db)): - """Render the maintenance view with scenario maintenance data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_equipment(db)) - context.update(_load_maintenance(db)) - return _render(request, "maintenance.html", context) - - -@router.get("/ui/simulations", response_class=HTMLResponse) -async def simulations_view(request: Request, db: Session = Depends(get_db)): - """Render the simulations view with scenario information and recent runs.""" - return _render(request, "simulations.html", _load_simulations(db)) - - -@router.get("/ui/reporting", response_class=HTMLResponse) -async def reporting_view(request: Request, db: Session = Depends(get_db)): - """Render the reporting view with scenario KPI summaries.""" - return _render(request, "reporting.html", _load_reporting(db)) - - -@router.get("/ui/settings", response_class=HTMLResponse) -async def settings_view(request: Request, db: Session = Depends(get_db)): - """Render the settings landing page.""" - context = _load_css_settings(db) - return _render(request, "settings.html", context) - - -@router.get("/ui/currencies", response_class=HTMLResponse) -async def currencies_view(request: Request, db: Session = Depends(get_db)): - """Render the currency administration page with full currency context.""" - context = _load_currency_settings(db) - return _render(request, "currencies.html", context) - - -@router.get("/login", response_class=HTMLResponse) -async def login_page(request: Request): - return _render(request, "login.html") - - -@router.get("/register", response_class=HTMLResponse) -async def register_page(request: Request): - return _render(request, "register.html") - - -@router.get("/profile", response_class=HTMLResponse) -async def profile_page(request: Request): - return _render(request, "profile.html") - - -@router.get("/forgot-password", response_class=HTMLResponse) -async def forgot_password_page(request: Request): - return _render(request, "forgot_password.html") - - -@router.get("/theme-settings", response_class=HTMLResponse) -async def 
theme_settings_page(request: Request, db: Session = Depends(get_db)): - """Render the theme settings page.""" - context = _load_css_settings(db) - return _render(request, "theme_settings.html", context) diff --git a/routes/users.py b/routes/users.py deleted file mode 100644 index 5de7092..0000000 --- a/routes/users.py +++ /dev/null @@ -1,107 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session - -from config.database import get_db -from models.user import User -from services.security import create_access_token, get_current_user -from schemas.user import ( - PasswordReset, - PasswordResetRequest, - UserCreate, - UserInDB, - UserLogin, - UserUpdate, -) - -router = APIRouter(prefix="/users", tags=["users"]) - - -@router.post("/register", response_model=UserInDB, status_code=status.HTTP_201_CREATED) -async def register_user(user: UserCreate, db: Session = Depends(get_db)): - db_user = db.query(User).filter(User.username == user.username).first() - if db_user: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered") - db_user = db.query(User).filter(User.email == user.email).first() - if db_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") - - # Get or create default role - from models.role import Role - default_role = db.query(Role).filter(Role.name == "user").first() - if not default_role: - default_role = Role(name="user") - db.add(default_role) - db.commit() - db.refresh(default_role) - - new_user = User(username=user.username, email=user.email, - role_id=default_role.id) - new_user.set_password(user.password) - db.add(new_user) - db.commit() - db.refresh(new_user) - return new_user - - -@router.post("/login") -async def login_user(user: UserLogin, db: Session = Depends(get_db)): - db_user = db.query(User).filter(User.username == user.username).first() - if not db_user or not db_user.check_password(user.password): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password") - access_token = create_access_token(subject=db_user.username) - return {"access_token": access_token, "token_type": "bearer"} - - -@router.get("/me") -async def read_users_me(current_user: User = Depends(get_current_user)): - return current_user - - -@router.put("/me", response_model=UserInDB) -async def update_user_me(user_update: UserUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): - if user_update.username and user_update.username != current_user.username: - existing_user = db.query(User).filter( - User.username == user_update.username).first() - if existing_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken") - setattr(current_user, "username", user_update.username) - - if user_update.email and user_update.email != current_user.email: - existing_user = db.query(User).filter( - User.email == user_update.email).first() - if existing_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") - setattr(current_user, "email", user_update.email) - - if user_update.password: - current_user.set_password(user_update.password) - - db.add(current_user) - db.commit() - db.refresh(current_user) - return current_user - - -@router.post("/forgot-password") -async def forgot_password(request: PasswordResetRequest): - # In a real application, this would send an email with a reset token - return 
{"message": "Password reset email sent (not really)"} - - -@router.post("/reset-password") -async def reset_password(request: PasswordReset, db: Session = Depends(get_db)): - # In a real application, the token would be verified - user = db.query(User).filter(User.username == - request.token).first() # Use token as username for test - if not user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token or user") - user.set_password(request.new_password) - db.add(user) - db.commit() - return {"message": "Password has been reset successfully"} diff --git a/schemas/user.py b/schemas/user.py deleted file mode 100644 index fafce5b..0000000 --- a/schemas/user.py +++ /dev/null @@ -1,41 +0,0 @@ -from pydantic import BaseModel, ConfigDict - - -class UserCreate(BaseModel): - username: str - email: str - password: str - - -class UserInDB(BaseModel): - id: int - username: str - email: str - role_id: int - - model_config = ConfigDict(from_attributes=True) - - -class UserLogin(BaseModel): - username: str - password: str - - -class UserUpdate(BaseModel): - username: str | None = None - email: str | None = None - password: str | None = None - - -class PasswordResetRequest(BaseModel): - email: str - - -class PasswordReset(BaseModel): - token: str - new_password: str - - -class Token(BaseModel): - access_token: str - token_type: str diff --git a/scripts/backfill_currency.py b/scripts/backfill_currency.py deleted file mode 100644 index 4651021..0000000 --- a/scripts/backfill_currency.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Backfill script to populate currency_id for capex and opex rows using existing currency_code. - -Usage: - python scripts/backfill_currency.py --dry-run - python scripts/backfill_currency.py --create-missing - -This script is intentionally cautious: it defaults to dry-run mode and will refuse to run -if database connection settings are missing. It supports creating missing currency rows when `--create-missing` -is provided. Always run against a development/staging database first. -""" - -from __future__ import annotations -import argparse -import importlib -import sys -from pathlib import Path - -from sqlalchemy import text, create_engine - - -PROJECT_ROOT = Path(__file__).resolve().parent.parent -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - - -def load_database_url() -> str: - try: - db_module = importlib.import_module("config.database") - except RuntimeError as exc: - raise RuntimeError( - "Database configuration missing: set DATABASE_URL or provide granular " - "variables (DATABASE_DRIVER, DATABASE_HOST, DATABASE_PORT, DATABASE_USER, " - "DATABASE_PASSWORD, DATABASE_NAME, optional DATABASE_SCHEMA)." 
- ) from exc - - return getattr(db_module, "DATABASE_URL") - - -def backfill( - db_url: str, dry_run: bool = True, create_missing: bool = False -) -> None: - engine = create_engine(db_url) - with engine.begin() as conn: - # Ensure currency table exists - if db_url.startswith("sqlite:"): - conn.execute( - text( - "SELECT name FROM sqlite_master WHERE type='table' AND name='currency';" - ) - ) - else: - conn.execute(text("SELECT to_regclass('public.currency');")) - # Note: we don't strictly depend on the above - we assume migration was already applied - - # Helper: find or create currency by code - def find_currency_id(code: str): - r = conn.execute( - text("SELECT id FROM currency WHERE code = :code"), - {"code": code}, - ).fetchone() - if r: - return r[0] - if create_missing: - # insert and return id - conn.execute( - text( - "INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)" - ), - {"c": code, "n": code}, - ) - r2 = conn.execute( - text("SELECT id FROM currency WHERE code = :code"), - {"code": code}, - ).fetchone() - if not r2: - raise RuntimeError( - f"Unable to determine currency ID for '{code}' after insert" - ) - return r2[0] - return None - - # Process tables capex and opex - for table in ("capex", "opex"): - # Check if currency_id column exists - try: - cols = ( - conn.execute( - text( - f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'" - ) - ) - if not db_url.startswith("sqlite:") - else [(1,)] - ) - except Exception: - cols = [(1,)] - - if not cols: - print(f"Skipping {table}: no currency_id column found") - continue - - # Find rows where currency_id IS NULL but currency_code exists - rows = conn.execute( - text( - f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''" - ) - ) - changed = 0 - for r in rows: - rid = r[0] - code = (r[1] or "USD").strip().upper() - cid = find_currency_id(code) - if cid is None: - print( - f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping" - ) - continue - if dry_run: - print( - f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})" - ) - else: - conn.execute( - text( - f"UPDATE {table} SET currency_id = :cid WHERE id = :rid" - ), - {"cid": cid, "rid": rid}, - ) - changed += 1 - - print(f"{table}: processed, changed={changed} (dry_run={dry_run})") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Backfill currency_id from currency_code for capex/opex tables" - ) - parser.add_argument( - "--dry-run", - action="store_true", - default=True, - help="Show actions without writing", - ) - parser.add_argument( - "--create-missing", - action="store_true", - help="Create missing currency rows in the currency table", - ) - args = parser.parse_args() - - db = load_database_url() - backfill(db, dry_run=args.dry_run, create_missing=args.create_missing) - - -if __name__ == "__main__": - main() diff --git a/scripts/check_docs_links.py b/scripts/check_docs_links.py deleted file mode 100644 index aebc1fe..0000000 --- a/scripts/check_docs_links.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Simple Markdown link checker for local docs/ files. - -Checks only local file links (relative paths) and reports missing targets. - -Run from the repository root using the project's Python environment. 
-""" - -import re -from pathlib import Path - -ROOT = Path(__file__).resolve().parent.parent -DOCS = ROOT / "docs" - -MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") - -errors = [] - -for md in DOCS.rglob("*.md"): - text = md.read_text(encoding="utf-8") - for m in MD_LINK_RE.finditer(text): - label, target = m.groups() - # skip URLs - if ( - target.startswith("http://") - or target.startswith("https://") - or target.startswith("#") - ): - continue - # strip anchors - target_path = target.split("#")[0] - # if link is to a directory index, allow - candidate = (md.parent / target_path).resolve() - if candidate.exists(): - continue - # check common implicit index: target/ -> target/README.md or target/index.md - candidate_dir = md.parent / target_path - if candidate_dir.is_dir(): - if (candidate_dir / "README.md").exists() or ( - candidate_dir / "index.md" - ).exists(): - continue - errors.append((str(md.relative_to(ROOT)), target, label)) - -if errors: - print("Broken local links found:") - for src, tgt, label in errors: - print(f"- {src} -> {tgt} ({label})") - exit(2) - -print("No broken local links detected.") diff --git a/scripts/format_docs_md.py b/scripts/format_docs_md.py deleted file mode 100644 index 5e1e856..0000000 --- a/scripts/format_docs_md.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Lightweight Markdown formatter: normalizes first-line H1, adds code-fence language hints for common shebangs, trims trailing whitespace. - -This is intentionally small and non-destructive; it touches only files under docs/ and makes safe changes. -""" - -import re -from pathlib import Path - -DOCS = Path(__file__).resolve().parents[1] / "docs" - -CODE_LANG_HINTS = { - "powershell": ("powershell",), - "bash": ("bash", "sh"), - "sql": ("sql",), - "python": ("python",), -} - - -def add_code_fence_language(match): - fence = match.group(0) - inner = match.group(1) - # If language already present, return unchanged - if fence.startswith("```") and len(fence.splitlines()[0].strip()) > 3: - return fence - # Try to infer language from the code content - code = inner.strip().splitlines()[0] if inner.strip() else "" - lang = "" - if ( - code.startswith("$") - or code.startswith("PS") - or code.lower().startswith("powershell") - ): - lang = "powershell" - elif ( - code.startswith("#") - or code.startswith("import") - or code.startswith("from") - ): - lang = "python" - elif re.match(r"^(select|insert|update|create)\b", code.strip(), re.I): - lang = "sql" - elif ( - code.startswith("git") - or code.startswith("./") - or code.startswith("sudo") - ): - lang = "bash" - if lang: - return f"```{lang}\n{inner}\n```" - return fence - - -def normalize_file(path: Path): - text = path.read_text(encoding="utf-8") - orig = text - # Trim trailing whitespace and ensure single trailing newline - text = "\n".join(line.rstrip() for line in text.splitlines()) + "\n" - # Ensure first non-empty line is H1 - lines = text.splitlines() - for i, ln in enumerate(lines): - if ln.strip(): - if not ln.startswith("#"): - lines[i] = "# " + ln - break - text = "\n".join(lines) + "\n" - # Add basic code fence languages where missing (simple heuristic) - text = re.sub(r"```\n([\s\S]*?)\n```", add_code_fence_language, text) - if text != orig: - path.write_text(text, encoding="utf-8") - return True - return False - - -def main(): - changed = [] - for p in DOCS.rglob("*.md"): - if p.is_file(): - try: - if normalize_file(p): - changed.append(str(p.relative_to(Path.cwd()))) - except Exception as e: - print(f"Failed to format {p}: {e}") - if changed: - 
print("Formatted files:") - for c in changed: - print(" -", c) - else: - print("No formatting changes required.") - - -if __name__ == "__main__": - main() diff --git a/scripts/migrations/000_base.sql b/scripts/migrations/000_base.sql deleted file mode 100644 index 11f9358..0000000 --- a/scripts/migrations/000_base.sql +++ /dev/null @@ -1,189 +0,0 @@ --- Baseline migration for CalMiner database schema --- Date: 2025-10-25 --- Purpose: Consolidate foundational tables and reference data - -BEGIN; - --- Currency reference table -CREATE TABLE IF NOT EXISTS currency ( - id SERIAL PRIMARY KEY, - code VARCHAR(3) NOT NULL UNIQUE, - name VARCHAR(128) NOT NULL, - symbol VARCHAR(8), - is_active BOOLEAN NOT NULL DEFAULT TRUE -); - -INSERT INTO currency (code, name, symbol, is_active) -VALUES - ('USD', 'United States Dollar', 'USD$', TRUE), - ('EUR', 'Euro', 'EUR', TRUE), - ('CLP', 'Chilean Peso', 'CLP$', TRUE), - ('RMB', 'Chinese Yuan', 'RMB', TRUE), - ('GBP', 'British Pound', 'GBP', TRUE), - ('CAD', 'Canadian Dollar', 'CAD$', TRUE), - ('AUD', 'Australian Dollar', 'AUD$', TRUE) -ON CONFLICT (code) DO UPDATE -SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - is_active = EXCLUDED.is_active; - --- Application-level settings table -CREATE TABLE IF NOT EXISTS application_setting ( - id SERIAL PRIMARY KEY, - key VARCHAR(128) NOT NULL UNIQUE, - value TEXT NOT NULL, - value_type VARCHAR(32) NOT NULL DEFAULT 'string', - category VARCHAR(32) NOT NULL DEFAULT 'general', - description TEXT, - is_editable BOOLEAN NOT NULL DEFAULT TRUE, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE UNIQUE INDEX IF NOT EXISTS ux_application_setting_key - ON application_setting (key); - -CREATE INDEX IF NOT EXISTS ix_application_setting_category - ON application_setting (category); - --- Measurement unit reference table -CREATE TABLE IF NOT EXISTS measurement_unit ( - id SERIAL PRIMARY KEY, - code VARCHAR(64) NOT NULL UNIQUE, - name VARCHAR(128) NOT NULL, - symbol VARCHAR(16), - unit_type VARCHAR(32) NOT NULL, - is_active BOOLEAN NOT NULL DEFAULT TRUE, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -INSERT INTO measurement_unit (code, name, symbol, unit_type, is_active) -VALUES - ('tonnes', 'Tonnes', 't', 'mass', TRUE), - ('kilograms', 'Kilograms', 'kg', 'mass', TRUE), - ('pounds', 'Pounds', 'lb', 'mass', TRUE), - ('liters', 'Liters', 'L', 'volume', TRUE), - ('cubic_meters', 'Cubic Meters', 'm3', 'volume', TRUE), - ('kilowatt_hours', 'Kilowatt Hours', 'kWh', 'energy', TRUE) -ON CONFLICT (code) DO UPDATE -SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - unit_type = EXCLUDED.unit_type, - is_active = EXCLUDED.is_active; - --- Consumption and production measurement metadata -ALTER TABLE consumption - ADD COLUMN IF NOT EXISTS unit_name VARCHAR(64); -ALTER TABLE consumption - ADD COLUMN IF NOT EXISTS unit_symbol VARCHAR(16); - -ALTER TABLE production_output - ADD COLUMN IF NOT EXISTS unit_name VARCHAR(64); -ALTER TABLE production_output - ADD COLUMN IF NOT EXISTS unit_symbol VARCHAR(16); - --- Currency integration for CAPEX and OPEX -ALTER TABLE capex - ADD COLUMN IF NOT EXISTS currency_id INTEGER; -ALTER TABLE opex - ADD COLUMN IF NOT EXISTS currency_id INTEGER; - -DO $$ -DECLARE - usd_id INTEGER; -BEGIN - -- Ensure currency_id columns align with legacy currency_code values when present - IF EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_name = 'capex' AND column_name = 'currency_code' - ) 
THEN - UPDATE capex AS c - SET currency_id = cur.id - FROM currency AS cur - WHERE c.currency_code = cur.code - AND (c.currency_id IS DISTINCT FROM cur.id); - END IF; - - IF EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_name = 'opex' AND column_name = 'currency_code' - ) THEN - UPDATE opex AS o - SET currency_id = cur.id - FROM currency AS cur - WHERE o.currency_code = cur.code - AND (o.currency_id IS DISTINCT FROM cur.id); - END IF; - - SELECT id INTO usd_id FROM currency WHERE code = 'USD'; - IF usd_id IS NOT NULL THEN - UPDATE capex SET currency_id = usd_id WHERE currency_id IS NULL; - UPDATE opex SET currency_id = usd_id WHERE currency_id IS NULL; - END IF; -END $$; - -ALTER TABLE capex - ALTER COLUMN currency_id SET NOT NULL; -ALTER TABLE opex - ALTER COLUMN currency_id SET NOT NULL; - -DO $$ -BEGIN - IF NOT EXISTS ( - SELECT 1 FROM information_schema.table_constraints - WHERE table_schema = current_schema() - AND table_name = 'capex' - AND constraint_name = 'fk_capex_currency' - ) THEN - ALTER TABLE capex - ADD CONSTRAINT fk_capex_currency FOREIGN KEY (currency_id) - REFERENCES currency (id) ON DELETE RESTRICT; - END IF; - - IF NOT EXISTS ( - SELECT 1 FROM information_schema.table_constraints - WHERE table_schema = current_schema() - AND table_name = 'opex' - AND constraint_name = 'fk_opex_currency' - ) THEN - ALTER TABLE opex - ADD CONSTRAINT fk_opex_currency FOREIGN KEY (currency_id) - REFERENCES currency (id) ON DELETE RESTRICT; - END IF; -END $$; - -ALTER TABLE capex - DROP COLUMN IF EXISTS currency_code; -ALTER TABLE opex - DROP COLUMN IF EXISTS currency_code; - --- Role-based access control tables -CREATE TABLE IF NOT EXISTS roles ( - id SERIAL PRIMARY KEY, - name VARCHAR(255) UNIQUE NOT NULL -); - -CREATE TABLE IF NOT EXISTS users ( - id SERIAL PRIMARY KEY, - username VARCHAR(255) UNIQUE NOT NULL, - email VARCHAR(255) UNIQUE NOT NULL, - hashed_password VARCHAR(255) NOT NULL, - role_id INTEGER NOT NULL REFERENCES roles (id) ON DELETE RESTRICT -); - -CREATE INDEX IF NOT EXISTS ix_users_username ON users (username); -CREATE INDEX IF NOT EXISTS ix_users_email ON users (email); - --- Theme settings configuration table -CREATE TABLE IF NOT EXISTS theme_settings ( - id SERIAL PRIMARY KEY, - theme_name VARCHAR(255) UNIQUE NOT NULL, - primary_color VARCHAR(7) NOT NULL, - secondary_color VARCHAR(7) NOT NULL, - accent_color VARCHAR(7) NOT NULL, - background_color VARCHAR(7) NOT NULL, - text_color VARCHAR(7) NOT NULL -); - -COMMIT; diff --git a/scripts/seed_data.py b/scripts/seed_data.py deleted file mode 100644 index b762d04..0000000 --- a/scripts/seed_data.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Seed baseline data for CalMiner in an idempotent manner. 
- -Usage examples --------------- - -```powershell -# Use existing environment variables (or load from setup_test.env.example) -python scripts/seed_data.py --currencies --units --defaults - -# Dry-run to preview actions -python scripts/seed_data.py --currencies --dry-run -``` -""" - -from __future__ import annotations - -import argparse -import logging -from typing import Optional - -import psycopg2 -from psycopg2 import errors -from psycopg2.extras import execute_values - -from scripts.setup_database import DatabaseConfig - - -logger = logging.getLogger(__name__) - -CURRENCY_SEEDS = ( - ("USD", "United States Dollar", "USD$", True), - ("EUR", "Euro", "EUR", True), - ("CLP", "Chilean Peso", "CLP$", True), - ("RMB", "Chinese Yuan", "RMB", True), - ("GBP", "British Pound", "GBP", True), - ("CAD", "Canadian Dollar", "CAD$", True), - ("AUD", "Australian Dollar", "AUD$", True), -) - -MEASUREMENT_UNIT_SEEDS = ( - ("tonnes", "Tonnes", "t", "mass", True), - ("kilograms", "Kilograms", "kg", "mass", True), - ("pounds", "Pounds", "lb", "mass", True), - ("liters", "Liters", "L", "volume", True), - ("cubic_meters", "Cubic Meters", "m3", "volume", True), - ("kilowatt_hours", "Kilowatt Hours", "kWh", "energy", True), -) - -THEME_SETTING_SEEDS = ( - ("--color-background", "#f4f5f7", "color", - "theme", "CSS variable --color-background", True), - ("--color-surface", "#ffffff", "color", - "theme", "CSS variable --color-surface", True), - ("--color-text-primary", "#2a1f33", "color", - "theme", "CSS variable --color-text-primary", True), - ("--color-text-secondary", "#624769", "color", - "theme", "CSS variable --color-text-secondary", True), - ("--color-text-muted", "#64748b", "color", - "theme", "CSS variable --color-text-muted", True), - ("--color-text-subtle", "#94a3b8", "color", - "theme", "CSS variable --color-text-subtle", True), - ("--color-text-invert", "#ffffff", "color", - "theme", "CSS variable --color-text-invert", True), - ("--color-text-dark", "#0f172a", "color", - "theme", "CSS variable --color-text-dark", True), - ("--color-text-strong", "#111827", "color", - "theme", "CSS variable --color-text-strong", True), - ("--color-primary", "#5f320d", "color", - "theme", "CSS variable --color-primary", True), - ("--color-primary-strong", "#7e4c13", "color", - "theme", "CSS variable --color-primary-strong", True), - ("--color-primary-stronger", "#837c15", "color", - "theme", "CSS variable --color-primary-stronger", True), - ("--color-accent", "#bff838", "color", - "theme", "CSS variable --color-accent", True), - ("--color-border", "#e2e8f0", "color", - "theme", "CSS variable --color-border", True), - ("--color-border-strong", "#cbd5e1", "color", - "theme", "CSS variable --color-border-strong", True), - ("--color-highlight", "#eef2ff", "color", - "theme", "CSS variable --color-highlight", True), - ("--color-panel-shadow", "rgba(15, 23, 42, 0.08)", "color", - "theme", "CSS variable --color-panel-shadow", True), - ("--color-panel-shadow-deep", "rgba(15, 23, 42, 0.12)", "color", - "theme", "CSS variable --color-panel-shadow-deep", True), - ("--color-surface-alt", "#f8fafc", "color", - "theme", "CSS variable --color-surface-alt", True), - ("--color-success", "#047857", "color", - "theme", "CSS variable --color-success", True), - ("--color-error", "#b91c1c", "color", - "theme", "CSS variable --color-error", True), -) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Seed baseline CalMiner data") - parser.add_argument( - "--currencies", action="store_true", help="Seed 
currency table" - ) - parser.add_argument("--units", action="store_true", help="Seed unit table") - parser.add_argument( - "--theme", action="store_true", help="Seed theme settings" - ) - parser.add_argument( - "--defaults", action="store_true", help="Seed default records" - ) - parser.add_argument( - "--dry-run", action="store_true", help="Print actions without executing" - ) - parser.add_argument( - "--verbose", - "-v", - action="count", - default=0, - help="Increase logging verbosity", - ) - return parser.parse_args() - - -def _configure_logging(args: argparse.Namespace) -> None: - level = logging.WARNING - (10 * min(args.verbose, 2)) - logging.basicConfig( - level=max(level, logging.INFO), format="%(levelname)s %(message)s" - ) - - -def main() -> None: - args = parse_args() - run_with_namespace(args) - - -def run_with_namespace( - args: argparse.Namespace, - *, - config: Optional[DatabaseConfig] = None, -) -> None: - if not hasattr(args, "verbose"): - args.verbose = 0 - if not hasattr(args, "dry_run"): - args.dry_run = False - - _configure_logging(args) - - currencies = bool(getattr(args, "currencies", False)) - units = bool(getattr(args, "units", False)) - theme = bool(getattr(args, "theme", False)) - defaults = bool(getattr(args, "defaults", False)) - dry_run = bool(getattr(args, "dry_run", False)) - - if not any((currencies, units, theme, defaults)): - logger.info("No seeding options provided; exiting") - return - - config = config or DatabaseConfig.from_env() - - with psycopg2.connect(config.application_dsn()) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - if currencies: - _seed_currencies(cursor, dry_run=dry_run) - if units: - _seed_units(cursor, dry_run=dry_run) - if theme: - _seed_theme(cursor, dry_run=dry_run) - if defaults: - _seed_defaults(cursor, dry_run=dry_run) - - -def _seed_currencies(cursor, *, dry_run: bool) -> None: - logger.info("Seeding currency table (%d rows)", len(CURRENCY_SEEDS)) - if dry_run: - for code, name, symbol, active in CURRENCY_SEEDS: - logger.info("Dry run: would upsert currency %s (%s)", code, name) - return - - execute_values( - cursor, - """ - INSERT INTO currency (code, name, symbol, is_active) - VALUES %s - ON CONFLICT (code) DO UPDATE - SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - is_active = EXCLUDED.is_active - """, - CURRENCY_SEEDS, - ) - logger.info("Currency seed complete") - - -def _seed_units(cursor, *, dry_run: bool) -> None: - total = len(MEASUREMENT_UNIT_SEEDS) - logger.info("Seeding measurement_unit table (%d rows)", total) - if dry_run: - for code, name, symbol, unit_type, _ in MEASUREMENT_UNIT_SEEDS: - logger.info( - "Dry run: would upsert measurement unit %s (%s - %s)", - code, - name, - unit_type, - ) - return - - try: - execute_values( - cursor, - """ - INSERT INTO measurement_unit (code, name, symbol, unit_type, is_active) - VALUES %s - ON CONFLICT (code) DO UPDATE - SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - unit_type = EXCLUDED.unit_type, - is_active = EXCLUDED.is_active - """, - MEASUREMENT_UNIT_SEEDS, - ) - except errors.UndefinedTable: - logger.warning( - "measurement_unit table does not exist; skipping unit seeding." 
- ) - cursor.connection.rollback() - return - - logger.info("Measurement unit seed complete") - - -def _seed_theme(cursor, *, dry_run: bool) -> None: - logger.info("Seeding theme settings (%d rows)", len(THEME_SETTING_SEEDS)) - if dry_run: - for key, value, _, _, _, _ in THEME_SETTING_SEEDS: - logger.info( - "Dry run: would upsert theme setting %s = %s", key, value) - return - - try: - execute_values( - cursor, - """ - INSERT INTO application_setting (key, value, value_type, category, description, is_editable) - VALUES %s - ON CONFLICT (key) DO UPDATE - SET value = EXCLUDED.value, - value_type = EXCLUDED.value_type, - category = EXCLUDED.category, - description = EXCLUDED.description, - is_editable = EXCLUDED.is_editable - """, - THEME_SETTING_SEEDS, - ) - except errors.UndefinedTable: - logger.warning( - "application_setting table does not exist; skipping theme seeding." - ) - cursor.connection.rollback() - return - - logger.info("Theme settings seed complete") - - -def _seed_defaults(cursor, *, dry_run: bool) -> None: - logger.info("Seeding default records") - _seed_theme(cursor, dry_run=dry_run) - logger.info("Default records seed complete") - - -if __name__ == "__main__": - main() diff --git a/scripts/setup_database.py b/scripts/setup_database.py deleted file mode 100644 index 918d1e6..0000000 --- a/scripts/setup_database.py +++ /dev/null @@ -1,1233 +0,0 @@ -"""Utilities to bootstrap the CalMiner PostgreSQL database. - -This script is designed to be idempotent. Each step checks the existing -state before attempting to modify it so repeated executions are safe. - -Environment variables (with defaults) used when establishing connections: - -* ``DATABASE_DRIVER`` (``postgresql``) -* ``DATABASE_HOST`` (required) -* ``DATABASE_PORT`` (``5432``) -* ``DATABASE_NAME`` (required) -* ``DATABASE_USER`` (required) -* ``DATABASE_PASSWORD`` (optional, required for password auth) -* ``DATABASE_SCHEMA`` (``public``) -* ``DATABASE_ADMIN_URL`` (overrides individual admin settings) -* ``DATABASE_SUPERUSER`` (falls back to ``DATABASE_USER`` or ``postgres``) -* ``DATABASE_SUPERUSER_PASSWORD`` (falls back to ``DATABASE_PASSWORD``) -* ``DATABASE_SUPERUSER_DB`` (``postgres``) - -Set ``DATABASE_URL`` if other parts of the application rely on a single -connection string; this script will still honor the granular inputs above. 
-""" - -from __future__ import annotations -from config.database import Base -import argparse -import importlib -import logging -import os -import pkgutil -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Optional, cast -from urllib.parse import quote_plus, urlencode -import psycopg2 -from psycopg2 import errors -from psycopg2 import sql -from psycopg2 import extensions -from psycopg2.extensions import connection as PGConnection, parse_dsn -from dotenv import load_dotenv -from sqlalchemy import create_engine, inspect - -ROOT_DIR = Path(__file__).resolve().parents[1] -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - - -logger = logging.getLogger(__name__) - -SCRIPTS_DIR = Path(__file__).resolve().parent -DEFAULT_MIGRATIONS_DIR = SCRIPTS_DIR / "migrations" -MIGRATIONS_TABLE = "schema_migrations" - - -@dataclass(slots=True) -class DatabaseConfig: - """Configuration required to manage the application database.""" - - driver: str - host: str - port: int - database: str - user: str - password: Optional[str] - schema: Optional[str] - - admin_user: str - admin_password: Optional[str] - admin_database: str = "postgres" - - @classmethod - def from_env( - cls, - overrides: Optional[dict[str, Optional[str]]] = None, - ) -> "DatabaseConfig": - load_dotenv() - - override_map: dict[str, Optional[str]] = dict(overrides or {}) - - def _get(name: str, default: Optional[str] = None) -> Optional[str]: - if name in override_map and override_map[name] is not None: - return override_map[name] - env_value = os.getenv(name) - if env_value is not None: - return env_value - return default - - driver = _get("DATABASE_DRIVER", "postgresql") - host = _get("DATABASE_HOST") - port_value = _get("DATABASE_PORT", "5432") - database = _get("DATABASE_NAME") - user = _get("DATABASE_USER") - password = _get("DATABASE_PASSWORD") - schema = _get("DATABASE_SCHEMA", "public") - - try: - port = int(port_value) if port_value is not None else 5432 - except ValueError as exc: - raise RuntimeError( - "Invalid DATABASE_PORT value: expected integer, got" - f" '{port_value}'" - ) from exc - - admin_url = _get("DATABASE_ADMIN_URL") - if admin_url: - admin_conninfo = parse_dsn(admin_url) - admin_user = admin_conninfo.get("user") or user or "postgres" - admin_password = admin_conninfo.get("password") - admin_database = admin_conninfo.get("dbname") or "postgres" - host = admin_conninfo.get("host") or host - port = int(admin_conninfo.get("port") or port) - else: - admin_user = _get("DATABASE_SUPERUSER", user or "postgres") - admin_password = _get("DATABASE_SUPERUSER_PASSWORD", password) - admin_database = _get("DATABASE_SUPERUSER_DB", "postgres") - - missing = [ - name - for name, value in ( - ("DATABASE_HOST", host), - ("DATABASE_NAME", database), - ("DATABASE_USER", user), - ) - if not value - ] - if missing: - raise RuntimeError( - "Missing required database configuration: " + - ", ".join(missing) - ) - - host = cast(str, host) - database = cast(str, database) - user = cast(str, user) - driver = cast(str, driver) - admin_user = cast(str, admin_user) - admin_database = cast(str, admin_database) - - return cls( - driver=driver, - host=host, - port=port, - database=database, - user=user, - password=password, - schema=schema, - admin_user=admin_user, - admin_password=admin_password, - admin_database=admin_database, - ) - - def admin_dsn(self, database: Optional[str] = None) -> str: - target_db = database or self.admin_database - return self._compose_url( - 
user=self.admin_user, - password=self.admin_password, - database=target_db, - schema=None, - ) - - def application_dsn(self) -> str: - """Return a SQLAlchemy URL for connecting as the application role.""" - - return self._compose_url( - user=self.user, - password=self.password, - database=self.database, - schema=self.schema, - ) - - def _compose_url( - self, - *, - user: Optional[str], - password: Optional[str], - database: str, - schema: Optional[str], - ) -> str: - auth = "" - if user: - encoded_user = quote_plus(user) - if password: - encoded_pass = quote_plus(password) - auth = f"{encoded_user}:{encoded_pass}@" - else: - auth = f"{encoded_user}@" - - host = self.host - if ":" in host and not host.startswith("["): - host = f"[{host}]" - - host_port = host - if self.port: - host_port = f"{host}:{self.port}" - - url = f"{self.driver}://{auth}{host_port}/{database}" - - params = {} - if schema and schema.strip() and schema != "public": - params["options"] = f"-csearch_path={schema}" - - if params: - url = f"{url}?{urlencode(params, quote_via=quote_plus)}" - - return url - - -class DatabaseSetup: - """Encapsulates the full setup workflow.""" - - def __init__( - self, config: DatabaseConfig, *, dry_run: bool = False - ) -> None: - self.config = config - self.dry_run = dry_run - self._models_loaded = False - self._rollback_actions: list[tuple[str, Callable[[], None]]] = [] - - def _register_rollback( - self, label: str, action: Callable[[], None] - ) -> None: - if self.dry_run: - return - self._rollback_actions.append((label, action)) - - def execute_rollbacks(self) -> None: - if not self._rollback_actions: - logger.info("No rollback actions registered; nothing to undo.") - return - - logger.warning( - "Attempting rollback of %d action(s)", len(self._rollback_actions) - ) - for label, action in reversed(self._rollback_actions): - try: - logger.warning("Rollback step: %s", label) - action() - except Exception: - logger.exception("Rollback action '%s' failed", label) - self._rollback_actions.clear() - - def clear_rollbacks(self) -> None: - self._rollback_actions.clear() - - def _describe_connection(self, user: str, database: str) -> str: - return f"{user}@{self.config.host}:{self.config.port}/{database}" - - def validate_admin_connection(self) -> None: - descriptor = self._describe_connection( - self.config.admin_user, self.config.admin_database - ) - logger.info("[CONNECT] Validating admin connection (%s)", descriptor) - try: - with self._admin_connection(self.config.admin_database) as conn: - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - except psycopg2.Error as exc: - raise RuntimeError( - "Unable to connect with admin credentials. " - "Check DATABASE_ADMIN_URL or DATABASE_SUPERUSER settings." - f" Target: {descriptor}" - ) from exc - logger.info("[CONNECT] Admin connection verified (%s)", descriptor) - - def validate_application_connection(self) -> None: - descriptor = self._describe_connection( - self.config.user, self.config.database - ) - logger.info( - "[CONNECT] Validating application connection (%s)", descriptor) - try: - with self._application_connection() as conn: - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - except psycopg2.Error as exc: - raise RuntimeError( - "Unable to connect using application credentials. " - "Ensure the role exists and credentials are correct. 
" - f"Target: {descriptor}" - ) from exc - logger.info( - "[CONNECT] Application connection verified (%s)", descriptor) - - def ensure_database(self) -> None: - """Create the target database when it does not already exist.""" - - logger.info("Ensuring database '%s' exists", self.config.database) - try: - conn = self._admin_connection(self.config.admin_database) - except RuntimeError: - logger.error( - "Could not connect to admin database '%s' while creating '%s'.", - self.config.admin_database, - self.config.database, - ) - raise - try: - conn.autocommit = True - conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) - cursor = conn.cursor() - try: - try: - cursor.execute( - "SELECT 1 FROM pg_database WHERE datname = %s", - (self.config.database,), - ) - except psycopg2.Error as exc: - message = ( - "Unable to inspect existing databases while ensuring '%s'." - " Verify admin permissions." - ) % self.config.database - logger.error(message) - raise RuntimeError(message) from exc - - exists = cursor.fetchone() is not None - if exists: - logger.info( - "Database '%s' already present", self.config.database - ) - return - - if self.dry_run: - logger.info( - "Dry run: would create database '%s'. Run without --dry-run to proceed.", - self.config.database, - ) - return - - try: - cursor.execute( - sql.SQL("CREATE DATABASE {} ENCODING 'UTF8'").format( - sql.Identifier(self.config.database) - ) - ) - except psycopg2.Error as exc: - message = ( - "Failed to create database '%s'. Rerun with --dry-run for diagnostics" - ) % self.config.database - logger.error(message) - raise RuntimeError(message) from exc - else: - rollback_label = f"drop database {self.config.database}" - self._register_rollback( - rollback_label, - lambda db=self.config.database: self._drop_database( - db), - ) - logger.info("Created database '%s'", self.config.database) - finally: - cursor.close() - finally: - conn.close() - - def ensure_role(self) -> None: - """Create the application role and assign privileges when missing.""" - - logger.info("Ensuring role '%s' exists", self.config.user) - try: - admin_conn = self._admin_connection(self.config.admin_database) - except RuntimeError: - logger.error( - "Unable to connect with admin credentials while ensuring role '%s'", - self.config.user, - ) - raise - - with admin_conn as conn: - conn.autocommit = True - with conn.cursor() as cursor: - try: - cursor.execute( - "SELECT 1 FROM pg_roles WHERE rolname = %s", - (self.config.user,), - ) - except psycopg2.Error as exc: - message = ( - "Unable to inspect existing roles while ensuring role '%s'." - " Verify admin permissions." - ) % self.config.user - logger.error(message) - raise RuntimeError(message) from exc - role_exists = cursor.fetchone() is not None - if not role_exists: - logger.info("Creating role '%s'", self.config.user) - if self.dry_run: - logger.info( - "Dry run: would create role '%s'. Run without --dry-run to apply.", - self.config.user, - ) - return - try: - if self.config.password: - cursor.execute( - sql.SQL( - "CREATE ROLE {} WITH LOGIN PASSWORD %s" - ).format(sql.Identifier(self.config.user)), - (self.config.password,), - ) - else: - cursor.execute( - sql.SQL("CREATE ROLE {} WITH LOGIN").format( - sql.Identifier(self.config.user) - ) - ) - except psycopg2.Error as exc: - message = ( - "Failed to create role '%s'. Review admin privileges and rerun." 
- ) % self.config.user - logger.error(message) - raise RuntimeError(message) from exc - else: - rollback_label = f"drop role {self.config.user}" - self._register_rollback( - rollback_label, - lambda role=self.config.user: self._drop_role( - role), - ) - else: - logger.info("Role '%s' already present", self.config.user) - - try: - role_conn = self._admin_connection(self.config.database) - except RuntimeError: - logger.error( - "Unable to connect to application database '%s' while granting privileges to role '%s'", - self.config.database, - self.config.user, - ) - raise - - if self.dry_run: - logger.info( - "Dry run: would grant privileges on schema/database to role '%s'.", - self.config.user, - ) - return - - with role_conn as conn: - conn.autocommit = True - with conn.cursor() as cursor: - schema_name = self.config.schema or "public" - schema_identifier = sql.Identifier(schema_name) - role_identifier = sql.Identifier(self.config.user) - - try: - cursor.execute( - sql.SQL("GRANT CONNECT ON DATABASE {} TO {}").format( - sql.Identifier(self.config.database), - role_identifier, - ) - ) - cursor.execute( - sql.SQL("GRANT USAGE ON SCHEMA {} TO {}").format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL("GRANT CREATE ON SCHEMA {} TO {}").format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT USAGE, SELECT ON SEQUENCES TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - except psycopg2.Error as exc: - message = ( - "Failed to grant privileges to role '%s' in schema '%s'." - " Rerun with --dry-run for more context." 
- ) % (self.config.user, schema_name) - logger.error(message) - raise RuntimeError(message) from exc - logger.info( - "Granted privileges on schema '%s' to role '%s'", - schema_name, - self.config.user, - ) - rollback_label = f"revoke privileges for {self.config.user}" - self._register_rollback( - rollback_label, - lambda schema=schema_name: self._revoke_role_privileges( - schema_name=schema - ), - ) - - def ensure_schema(self) -> None: - """Create the configured schema when it does not exist.""" - - schema_name = self.config.schema - if not schema_name or schema_name == "public": - logger.info("Using default schema 'public'; nothing to ensure") - return - - logger.info("Ensuring schema '%s' exists", schema_name) - with self._admin_connection(self.config.database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - sql.SQL( - "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s" - ), - (schema_name,), - ) - exists = cursor.fetchone() is not None - if not exists: - if self.dry_run: - logger.info( - "Dry run: would create schema '%s'", - schema_name, - ) - else: - cursor.execute( - sql.SQL("CREATE SCHEMA {}").format( - sql.Identifier(schema_name) - ) - ) - logger.info("Created schema '%s'", schema_name) - try: - if self.dry_run: - logger.info( - "Dry run: would set schema '%s' owner to '%s'", - schema_name, - self.config.user, - ) - else: - cursor.execute( - sql.SQL("ALTER SCHEMA {} OWNER TO {}").format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - except errors.UndefinedObject: - logger.warning( - "Role '%s' not found when assigning ownership to schema '%s'." - " Run --ensure-role after creating the schema.", - self.config.user, - schema_name, - ) - - def application_role_exists(self) -> bool: - try: - with self._admin_connection(self.config.admin_database) as conn: - with conn.cursor() as cursor: - try: - cursor.execute( - "SELECT 1 FROM pg_roles WHERE rolname = %s", - (self.config.user,), - ) - except psycopg2.Error as exc: - message = ( - "Unable to inspect existing roles while checking for role '%s'." - " Verify admin permissions." - ) % self.config.user - logger.error(message) - raise RuntimeError(message) from exc - return cursor.fetchone() is not None - except RuntimeError: - raise - - def _connect(self, dsn: str, descriptor: str) -> PGConnection: - try: - return psycopg2.connect(dsn) - except psycopg2.Error as exc: - raise RuntimeError( - f"Unable to establish connection. 
Target: {descriptor}" - ) from exc - - def _admin_connection(self, database: Optional[str] = None) -> PGConnection: - target_db = database or self.config.admin_database - dsn = self.config.admin_dsn(database) - descriptor = self._describe_connection( - self.config.admin_user, target_db - ) - return self._connect(dsn, descriptor) - - def _application_connection(self) -> PGConnection: - dsn = self.config.application_dsn() - descriptor = self._describe_connection( - self.config.user, self.config.database - ) - return self._connect(dsn, descriptor) - - def initialize_schema(self) -> None: - """Create database objects from SQLAlchemy metadata if missing.""" - - self._ensure_models_loaded() - logger.info("Ensuring SQLAlchemy metadata is reflected in database") - engine = create_engine(self.config.application_dsn(), future=True) - try: - inspector = inspect(engine) - existing_tables = set( - inspector.get_table_names(schema=self.config.schema) - ) - metadata_tables = set(Base.metadata.tables.keys()) - missing_tables = sorted(metadata_tables - existing_tables) - - if missing_tables: - logger.info("Pending tables: %s", ", ".join(missing_tables)) - else: - logger.info("All tables already exist") - - if self.dry_run: - if missing_tables: - logger.info("Dry run: skipping creation of pending tables") - return - - Base.metadata.create_all(bind=engine, checkfirst=True) - finally: - engine.dispose() - - logger.info("Schema initialization complete") - - def _ensure_models_loaded(self) -> None: - if self._models_loaded: - return - - package = importlib.import_module("models") - for module_info in pkgutil.iter_modules(package.__path__): - importlib.import_module(f"{package.__name__}.{module_info.name}") - self._models_loaded = True - - def run_migrations( - self, migrations_dir: Optional[Path | str] = None - ) -> None: - """Execute pending SQL migrations in chronological order.""" - - directory = ( - Path(migrations_dir) - if migrations_dir is not None - else DEFAULT_MIGRATIONS_DIR - ) - directory = directory.resolve() - - if not directory.exists(): - logger.warning("Migrations directory '%s' not found", directory) - return - - migration_files = sorted(directory.glob("*.sql")) - if not migration_files: - logger.info("No migration scripts found in '%s'", directory) - return - - baseline_name = "000_base.sql" - baseline_path = directory / baseline_name - - schema_name = self.config.schema or "public" - - with self._application_connection() as conn: - conn.autocommit = True - with conn.cursor() as cursor: - table_exists = self._migrations_table_exists( - cursor, schema_name - ) - if not table_exists: - if self.dry_run: - logger.info( - "Dry run: would create migration history table %s.%s", - schema_name, - MIGRATIONS_TABLE, - ) - applied: set[str] = set() - else: - self._create_migrations_table(cursor, schema_name) - logger.info( - "Created migration history table %s.%s", - schema_name, - MIGRATIONS_TABLE, - ) - applied = set() - else: - applied = self._fetch_applied_migrations( - cursor, schema_name - ) - - self._handle_baseline_migration( - cursor, schema_name, baseline_path, baseline_name, migration_files, applied - ) - - pending = [ - path for path in migration_files if path.name not in applied - ] - - if not pending: - logger.info("No pending migrations") - return - - logger.info( - "Pending migrations: %s", - ", ".join(path.name for path in pending), - ) - - if self.dry_run: - logger.info("Dry run: skipping migration execution") - return - - for path in pending: - self._apply_migration_file(cursor, 
schema_name, path) - - logger.info("Applied %d migrations", len(pending)) - - def _handle_baseline_migration( - self, - cursor: extensions.cursor, - schema_name: str, - baseline_path: Path, - baseline_name: str, - migration_files: list[Path], - applied: set[str], - ) -> None: - if baseline_path.exists() and baseline_name not in applied: - if self.dry_run: - logger.info( - "Dry run: baseline migration '%s' pending; would apply and mark legacy files", - baseline_name, - ) - else: - logger.info( - "[MIGRATE] Baseline migration '%s' pending; applying and marking older migrations", - baseline_name, - ) - try: - baseline_applied = self._apply_migration_file( - cursor, schema_name, baseline_path - ) - except Exception: - logger.error( - "Failed while applying baseline migration '%s'." - " Review the migration contents and rerun with --dry-run for diagnostics.", - baseline_name, - exc_info=True, - ) - raise - applied.add(baseline_applied) - self._mark_legacy_migrations_as_applied( - cursor, schema_name, migration_files, baseline_name, applied - ) - - def _mark_legacy_migrations_as_applied( - self, - cursor: extensions.cursor, - schema_name: str, - migration_files: list[Path], - baseline_name: str, - applied: set[str], - ) -> None: - legacy_files = [ - path - for path in migration_files - if path.name != baseline_name - ] - for legacy in legacy_files: - if legacy.name not in applied: - try: - cursor.execute( - sql.SQL( - "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" - ).format( - sql.Identifier( - schema_name, - MIGRATIONS_TABLE, - ) - ), - (legacy.name,), - ) - except Exception: - logger.error( - "Unable to record legacy migration '%s' after baseline application." - " Check schema_migrations table in schema '%s' for partial state.", - legacy.name, - schema_name, - exc_info=True, - ) - raise - applied.add(legacy.name) - logger.info( - "Marked legacy migration '%s' as applied via baseline", - legacy.name, - ) - - def _apply_migration_file( - self, - cursor, - schema_name: str, - path: Path, - ) -> str: - logger.info("Applying migration '%s'", path.name) - sql_text = path.read_text(encoding="utf-8") - try: - cursor.execute(sql_text) - cursor.execute( - sql.SQL( - "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" - ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)), - (path.name,), - ) - return path.name - except Exception: - logger.exception("Failed to apply migration '%s'", path.name) - raise - - def _migrations_table_exists(self, cursor, schema_name: str) -> bool: - cursor.execute( - """ - SELECT 1 - FROM information_schema.tables - WHERE table_schema = %s AND table_name = %s - """, - (schema_name, MIGRATIONS_TABLE), - ) - return cursor.fetchone() is not None - - def _create_migrations_table(self, cursor, schema_name: str) -> None: - cursor.execute( - sql.SQL( - "CREATE TABLE IF NOT EXISTS {} (" - "filename TEXT PRIMARY KEY," - "applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()" - ")" - ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)) - ) - - def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]: - cursor.execute( - sql.SQL("SELECT filename FROM {} ORDER BY filename").format( - sql.Identifier(schema_name, MIGRATIONS_TABLE) - ) - ) - return {row[0] for row in cursor.fetchall()} - - def seed_baseline_data(self, *, dry_run: bool) -> None: - """Seed reference data such as currencies.""" - - from scripts import seed_data - - seed_args = argparse.Namespace( - currencies=True, - units=True, - theme=True, - defaults=False, - dry_run=dry_run, - verbose=0, - 
) - try: - seed_data.run_with_namespace(seed_args, config=self.config) - except Exception: - logger.error( - "[SEED] Failed during baseline data seeding. " - "Review seed_data.py and rerun with --dry-run for diagnostics.", - exc_info=True, - ) - raise - - if dry_run: - logger.info("[SEED] Dry run: skipped seed verification") - return - - expected_currencies = { - code for code, *_ in getattr(seed_data, "CURRENCY_SEEDS", ()) - } - expected_units = { - code - for code, *_ in getattr(seed_data, "MEASUREMENT_UNIT_SEEDS", ()) - } - self._verify_seeded_data( - expected_currency_codes=expected_currencies, - expected_unit_codes=expected_units, - ) - - def _verify_seeded_data( - self, - *, - expected_currency_codes: set[str], - expected_unit_codes: set[str], - ) -> None: - if not expected_currency_codes and not expected_unit_codes: - logger.info("No seed datasets configured for verification") - return - - with self._application_connection() as conn: - with conn.cursor() as cursor: - if expected_currency_codes: - cursor.execute( - "SELECT code, is_active FROM currency WHERE code = ANY(%s)", - (list(expected_currency_codes),), - ) - rows = cursor.fetchall() - found_codes = {row[0] for row in rows} - missing_codes = sorted( - expected_currency_codes - found_codes - ) - if missing_codes: - message = ( - "Missing expected currencies after seeding: %s. " - "Run scripts/seed_data.py --currencies to restore them." - ) % ", ".join(missing_codes) - logger.error(message) - raise RuntimeError(message) - - logger.info( - "[VERIFY] Verified %d seeded currencies present", - len(found_codes), - ) - - default_status = next( - (row[1] for row in rows if row[0] == "USD"), None - ) - if default_status is False: - message = ( - "Default currency 'USD' is inactive after seeding. " - "Reactivate it or rerun the seeding command." - ) - logger.error(message) - raise RuntimeError(message) - elif default_status is None: - message = ( - "Default currency 'USD' not found after seeding. " - "Ensure baseline migration 000_base.sql ran successfully." - ) - logger.error(message) - raise RuntimeError(message) - else: - logger.info( - "[VERIFY] Verified default currency 'USD' active") - - if expected_unit_codes: - try: - cursor.execute( - "SELECT code, is_active FROM measurement_unit WHERE code = ANY(%s)", - (list(expected_unit_codes),), - ) - except errors.UndefinedTable: - conn.rollback() - message = ( - "measurement_unit table not found during seed verification. " - "Ensure baseline migration 000_base.sql has been applied." - ) - logger.error(message) - raise RuntimeError(message) - else: - rows = cursor.fetchall() - found_units = {row[0] for row in rows} - missing_units = sorted( - expected_unit_codes - found_units - ) - if missing_units: - message = ( - "Missing expected measurement units after seeding: %s. " - "Run scripts/seed_data.py --units to restore them." - ) % ", ".join(missing_units) - logger.error(message) - raise RuntimeError(message) - - inactive_units = sorted( - row[0] for row in rows if not bool(row[1]) - ) - if inactive_units: - message = ( - "Measurement units inactive after seeding: %s. " - "Reactivate them or rerun unit seeding." 
- ) % ", ".join(inactive_units) - logger.error(message) - raise RuntimeError(message) - - logger.info( - "Verified %d measurement units present", - len(found_units), - ) - - logger.info("Seed verification complete") - - def _drop_database(self, database: str) -> None: - logger.warning("Rollback: dropping database '%s'", database) - with self._admin_connection(self.config.admin_database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s", - (database,), - ) - cursor.execute( - sql.SQL("DROP DATABASE IF EXISTS {}").format( - sql.Identifier(database) - ) - ) - - def _drop_role(self, role: str) -> None: - logger.warning("Rollback: dropping role '%s'", role) - with self._admin_connection(self.config.admin_database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - sql.SQL("DROP ROLE IF EXISTS {}").format( - sql.Identifier(role) - ) - ) - - def _revoke_role_privileges(self, *, schema_name: str) -> None: - logger.warning( - "Rollback: revoking privileges on schema '%s' for role '%s'", - schema_name, - self.config.user, - ) - with self._admin_connection(self.config.database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - sql.SQL( - "REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - cursor.execute( - sql.SQL( - "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Bootstrap CalMiner database") - parser.add_argument( - "--ensure-database", - action="store_true", - help="Create the application database when it does not already exist.", - ) - parser.add_argument( - "--ensure-role", - action="store_true", - help="Create the application role and grant necessary privileges.", - ) - parser.add_argument( - "--ensure-schema", - action="store_true", - help="Create the configured schema if it does not exist.", - ) - parser.add_argument( - "--initialize-schema", - action="store_true", - help="Create missing tables based on SQLAlchemy models.", - ) - parser.add_argument( - "--run-migrations", - action="store_true", - help="Execute pending SQL migrations.", - ) - parser.add_argument( - "--seed-data", - action="store_true", - help="Seed baseline reference data (currencies, etc.).", - ) - parser.add_argument( - "--migrations-dir", - default=None, - help="Override the default migrations directory.", - ) - parser.add_argument("--db-driver", help="Override DATABASE_DRIVER") - parser.add_argument("--db-host", help="Override DATABASE_HOST") - parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT") - parser.add_argument("--db-name", help="Override DATABASE_NAME") - parser.add_argument("--db-user", help="Override DATABASE_USER") - parser.add_argument("--db-password", help="Override DATABASE_PASSWORD") - 
parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA") - parser.add_argument( - "--admin-url", - help="Override DATABASE_ADMIN_URL for administrative operations", - ) - parser.add_argument( - "--admin-user", help="Override DATABASE_SUPERUSER for admin ops" - ) - parser.add_argument( - "--admin-password", - help="Override DATABASE_SUPERUSER_PASSWORD for admin ops", - ) - parser.add_argument( - "--admin-db", - help="Override DATABASE_SUPERUSER_DB for admin ops", - ) - parser.add_argument( - "--dry-run", - action="store_true", - help="Log actions without applying changes.", - ) - parser.add_argument( - "--verbose", - "-v", - action="count", - default=0, - help="Increase logging verbosity", - ) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - level = logging.WARNING - (10 * min(args.verbose, 2)) - logging.basicConfig( - level=max(level, logging.INFO), format="%(levelname)s %(message)s" - ) - - override_args: dict[str, Optional[str]] = { - "DATABASE_DRIVER": args.db_driver, - "DATABASE_HOST": args.db_host, - "DATABASE_NAME": args.db_name, - "DATABASE_USER": args.db_user, - "DATABASE_PASSWORD": args.db_password, - "DATABASE_SCHEMA": args.db_schema, - "DATABASE_ADMIN_URL": args.admin_url, - "DATABASE_SUPERUSER": args.admin_user, - "DATABASE_SUPERUSER_PASSWORD": args.admin_password, - "DATABASE_SUPERUSER_DB": args.admin_db, - } - if args.db_port is not None: - override_args["DATABASE_PORT"] = str(args.db_port) - - config = DatabaseConfig.from_env(overrides=override_args) - setup = DatabaseSetup(config, dry_run=args.dry_run) - - admin_tasks_requested = ( - args.ensure_database or args.ensure_role or args.ensure_schema - ) - if admin_tasks_requested: - setup.validate_admin_connection() - - app_validated = False - - def ensure_application_connection_for(operation: str) -> bool: - nonlocal app_validated - if app_validated: - return True - if setup.dry_run and not setup.application_role_exists(): - logger.info( - "Dry run: skipping %s because application role '%s' does not exist yet.", - operation, - setup.config.user, - ) - return False - setup.validate_application_connection() - app_validated = True - return True - - should_run_migrations = args.run_migrations - auto_run_migrations_reason: Optional[str] = None - if args.seed_data and not should_run_migrations: - should_run_migrations = True - auto_run_migrations_reason = "Seed data requested without explicit --run-migrations; applying migrations first." 
- - try: - if args.ensure_database: - setup.ensure_database() - if args.ensure_role: - setup.ensure_role() - if args.ensure_schema: - setup.ensure_schema() - - if args.initialize_schema: - if ensure_application_connection_for( - "SQLAlchemy schema initialization" - ): - setup.initialize_schema() - if should_run_migrations: - if ensure_application_connection_for("migration execution"): - if auto_run_migrations_reason: - logger.info(auto_run_migrations_reason) - migrations_path = ( - Path(args.migrations_dir) if args.migrations_dir else None - ) - setup.run_migrations(migrations_path) - if args.seed_data: - if ensure_application_connection_for("baseline data seeding"): - setup.seed_baseline_data(dry_run=args.dry_run) - except Exception: - if not setup.dry_run: - setup.execute_rollbacks() - raise - finally: - if not setup.dry_run: - setup.clear_rollbacks() - - -if __name__ == "__main__": - main() diff --git a/services/reporting.py b/services/reporting.py deleted file mode 100644 index 98387d6..0000000 --- a/services/reporting.py +++ /dev/null @@ -1,79 +0,0 @@ -from statistics import mean, median, pstdev -from typing import Any, Dict, Iterable, List, Mapping, Union, cast - - -def _extract_results(simulation_results: Iterable[object]) -> List[float]: - values: List[float] = [] - for item in simulation_results: - if not isinstance(item, Mapping): - continue - mapping_item = cast(Mapping[str, Any], item) - value = mapping_item.get("result") - if isinstance(value, (int, float)): - values.append(float(value)) - return values - - -def _percentile(values: List[float], percentile: float) -> float: - if not values: - return 0.0 - sorted_values = sorted(values) - if len(sorted_values) == 1: - return sorted_values[0] - index = (percentile / 100) * (len(sorted_values) - 1) - lower = int(index) - upper = min(lower + 1, len(sorted_values) - 1) - weight = index - lower - return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight - - -def generate_report( - simulation_results: List[Dict[str, float]], -) -> Dict[str, Union[float, int]]: - """Aggregate basic statistics for simulation outputs.""" - - values = _extract_results(simulation_results) - - if not values: - return { - "count": 0, - "mean": 0.0, - "median": 0.0, - "min": 0.0, - "max": 0.0, - "std_dev": 0.0, - "variance": 0.0, - "percentile_10": 0.0, - "percentile_90": 0.0, - "percentile_5": 0.0, - "percentile_95": 0.0, - "value_at_risk_95": 0.0, - "expected_shortfall_95": 0.0, - } - - summary: Dict[str, Union[float, int]] = { - "count": len(values), - "mean": mean(values), - "median": median(values), - "min": min(values), - "max": max(values), - "percentile_10": _percentile(values, 10), - "percentile_90": _percentile(values, 90), - "percentile_5": _percentile(values, 5), - "percentile_95": _percentile(values, 95), - } - - std_dev = pstdev(values) if len(values) > 1 else 0.0 - summary["std_dev"] = std_dev - summary["variance"] = std_dev**2 - - var_95 = summary["percentile_5"] - summary["value_at_risk_95"] = var_95 - - tail_values = [value for value in values if value <= var_95] - if tail_values: - summary["expected_shortfall_95"] = mean(tail_values) - else: - summary["expected_shortfall_95"] = var_95 - - return summary diff --git a/services/security.py b/services/security.py deleted file mode 100644 index 24782c5..0000000 --- a/services/security.py +++ /dev/null @@ -1,59 +0,0 @@ -from datetime import datetime, timedelta -from typing import Any, Union - -from fastapi import HTTPException, status, Depends -from fastapi.security import 
OAuth2PasswordBearer -from jose import jwt, JWTError -from passlib.context import CryptContext -from sqlalchemy.orm import Session - -from config.database import get_db - - -ACCESS_TOKEN_EXPIRE_MINUTES = 30 -SECRET_KEY = "your-secret-key" # Change this in production -ALGORITHM = "HS256" - -pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto") - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="users/login") - - -def create_access_token( - subject: Union[str, Any], expires_delta: Union[timedelta, None] = None -) -> str: - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) - - -async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): - from models.user import User - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username = payload.get("sub") - if username is None: - raise credentials_exception - except JWTError: - raise credentials_exception - user = db.query(User).filter(User.username == username).first() - if user is None: - raise credentials_exception - return user diff --git a/services/settings.py b/services/settings.py deleted file mode 100644 index 51b49ac..0000000 --- a/services/settings.py +++ /dev/null @@ -1,230 +0,0 @@ -from __future__ import annotations - -import os -import re -from typing import Dict, Mapping - -from sqlalchemy.orm import Session - -from models.application_setting import ApplicationSetting -from models.theme_setting import ThemeSetting # Import ThemeSetting model - -CSS_COLOR_CATEGORY = "theme" -CSS_COLOR_VALUE_TYPE = "color" -CSS_ENV_PREFIX = "CALMINER_THEME_" - -CSS_COLOR_DEFAULTS: Dict[str, str] = { - "--color-background": "#f4f5f7", - "--color-surface": "#ffffff", - "--color-text-primary": "#2a1f33", - "--color-text-secondary": "#624769", - "--color-text-muted": "#64748b", - "--color-text-subtle": "#94a3b8", - "--color-text-invert": "#ffffff", - "--color-text-dark": "#0f172a", - "--color-text-strong": "#111827", - "--color-primary": "#5f320d", - "--color-primary-strong": "#7e4c13", - "--color-primary-stronger": "#837c15", - "--color-accent": "#bff838", - "--color-border": "#e2e8f0", - "--color-border-strong": "#cbd5e1", - "--color-highlight": "#eef2ff", - "--color-panel-shadow": "rgba(15, 23, 42, 0.08)", - "--color-panel-shadow-deep": "rgba(15, 23, 42, 0.12)", - "--color-surface-alt": "#f8fafc", - "--color-success": "#047857", - "--color-error": "#b91c1c", -} - -_COLOR_VALUE_PATTERN = re.compile( - r"^(#([0-9a-fA-F]{3}|[0-9a-fA-F]{6}|[0-9a-fA-F]{8})|rgba?\([^)]+\)|hsla?\([^)]+\))$", - re.IGNORECASE, -) - - -def ensure_css_color_settings(db: Session) -> Dict[str, ApplicationSetting]: - """Ensure the CSS color defaults exist in the settings table.""" - - existing = ( - db.query(ApplicationSetting) - .filter(ApplicationSetting.key.in_(CSS_COLOR_DEFAULTS.keys())) - .all() - ) - by_key = {setting.key: setting for setting in existing} - - created = False - 
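A note on the override convention in this module: each CSS custom property maps to a CALMINER_THEME_* environment variable (leading dashes dropped, remaining dashes converted to underscores and upper-cased), and values must match the hex/rgb/rgba/hsl/hsla pattern above; named colors are rejected. A small sketch using css_key_to_env_var and read_css_color_env_overrides, which are defined further down in this module; the color value is made up:

# Hedged sketch of the CSS-variable <-> environment-variable mapping.
from services.settings import css_key_to_env_var, read_css_color_env_overrides

assert css_key_to_env_var("--color-primary") == "CALMINER_THEME_COLOR_PRIMARY"

env = {"CALMINER_THEME_COLOR_PRIMARY": "#123456"}
assert read_css_color_env_overrides(env) == {"--color-primary": "#123456"}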
for key, default_value in CSS_COLOR_DEFAULTS.items(): - if key in by_key: - continue - setting = ApplicationSetting( - key=key, - value=default_value, - value_type=CSS_COLOR_VALUE_TYPE, - category=CSS_COLOR_CATEGORY, - description=f"CSS variable {key}", - is_editable=True, - ) - db.add(setting) - by_key[key] = setting - created = True - - if created: - db.commit() - for key, setting in by_key.items(): - db.refresh(setting) - - return by_key - - -def get_css_color_settings(db: Session) -> Dict[str, str]: - """Return CSS color variables, filling missing values with defaults.""" - - settings = ensure_css_color_settings(db) - values: Dict[str, str] = { - key: settings[key].value if key in settings else default - for key, default in CSS_COLOR_DEFAULTS.items() - } - - env_overrides = read_css_color_env_overrides(os.environ) - if env_overrides: - values.update(env_overrides) - - return values - - -def update_css_color_settings( - db: Session, updates: Mapping[str, str] -) -> Dict[str, str]: - """Persist provided CSS color overrides and return the final values.""" - - if not updates: - return get_css_color_settings(db) - - invalid_keys = sorted(set(updates.keys()) - set(CSS_COLOR_DEFAULTS.keys())) - if invalid_keys: - invalid_list = ", ".join(invalid_keys) - raise ValueError(f"Unsupported CSS variables: {invalid_list}") - - normalized: Dict[str, str] = {} - for key, value in updates.items(): - normalized[key] = _normalize_color_value(value) - - settings = ensure_css_color_settings(db) - changed = False - - for key, value in normalized.items(): - setting = settings[key] - if setting.value != value: - setting.value = value - changed = True - if setting.value_type != CSS_COLOR_VALUE_TYPE: - setting.value_type = CSS_COLOR_VALUE_TYPE - changed = True - if setting.category != CSS_COLOR_CATEGORY: - setting.category = CSS_COLOR_CATEGORY - changed = True - if not setting.is_editable: - setting.is_editable = True - changed = True - - if changed: - db.commit() - for key in normalized.keys(): - db.refresh(settings[key]) - - return get_css_color_settings(db) - - -def read_css_color_env_overrides( - env: Mapping[str, str] | None = None, -) -> Dict[str, str]: - """Return validated CSS overrides sourced from environment variables.""" - - if env is None: - env = os.environ - - overrides: Dict[str, str] = {} - for css_key in CSS_COLOR_DEFAULTS.keys(): - env_name = css_key_to_env_var(css_key) - raw_value = env.get(env_name) - if raw_value is None: - continue - overrides[css_key] = _normalize_color_value(raw_value) - - return overrides - - -def _normalize_color_value(value: str) -> str: - if not isinstance(value, str): - raise ValueError("Color value must be a string") - trimmed = value.strip() - if not trimmed: - raise ValueError("Color value cannot be empty") - if not _COLOR_VALUE_PATTERN.match(trimmed): - raise ValueError( - "Color value must be a hex code or an rgb/rgba/hsl/hsla expression" - ) - _validate_functional_color(trimmed) - return trimmed - - -def _validate_functional_color(value: str) -> None: - lowered = value.lower() - if lowered.startswith("rgb(") or lowered.startswith("hsl("): - _ensure_component_count(value, expected=3) - elif lowered.startswith("rgba(") or lowered.startswith("hsla("): - _ensure_component_count(value, expected=4) - - -def _ensure_component_count(value: str, expected: int) -> None: - if not value.endswith(")"): - raise ValueError( - "Color function expressions must end with a closing parenthesis" - ) - inner = value[value.index("(") + 1: -1] - parts = [segment.strip() for segment 
in inner.split(",")] - if len(parts) != expected: - raise ValueError( - "Color function expressions must provide the expected number of components" - ) - if any(not component for component in parts): - raise ValueError("Color function components cannot be empty") - - -def css_key_to_env_var(css_key: str) -> str: - sanitized = css_key.lstrip("-").replace("-", "_").upper() - return f"{CSS_ENV_PREFIX}{sanitized}" - - -def list_css_env_override_rows( - env: Mapping[str, str] | None = None, -) -> list[Dict[str, str]]: - overrides = read_css_color_env_overrides(env) - rows: list[Dict[str, str]] = [] - for css_key, value in overrides.items(): - rows.append( - { - "css_key": css_key, - "env_var": css_key_to_env_var(css_key), - "value": value, - } - ) - return rows - - -def save_theme_settings(db: Session, theme_data: dict): - theme = db.query(ThemeSetting).first() or ThemeSetting() - for key, value in theme_data.items(): - setattr(theme, key, value) - db.add(theme) - db.commit() - db.refresh(theme) - return theme - - -def get_theme_settings(db: Session): - theme = db.query(ThemeSetting).first() - if theme: - return {c.name: getattr(theme, c.name) for c in theme.__table__.columns} - return {} diff --git a/services/simulation.py b/services/simulation.py deleted file mode 100644 index 6c8ffe1..0000000 --- a/services/simulation.py +++ /dev/null @@ -1,144 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from random import Random -from typing import Dict, List, Literal, Optional, Sequence - - -DEFAULT_STD_DEV_RATIO = 0.1 -DEFAULT_UNIFORM_SPAN_RATIO = 0.15 -DistributionType = Literal["normal", "uniform", "triangular"] - - -@dataclass -class SimulationParameter: - name: str - base_value: float - distribution: DistributionType - std_dev: Optional[float] = None - minimum: Optional[float] = None - maximum: Optional[float] = None - mode: Optional[float] = None - - -def _ensure_positive_span(span: float, fallback: float) -> float: - return span if span and span > 0 else fallback - - -def _compile_parameters( - parameters: Sequence[Dict[str, float]], -) -> List[SimulationParameter]: - compiled: List[SimulationParameter] = [] - for index, item in enumerate(parameters): - if "value" not in item: - raise ValueError(f"Parameter at index {index} must include 'value'") - name = str(item.get("name", f"param_{index}")) - base_value = float(item["value"]) - distribution = str(item.get("distribution", "normal")).lower() - if distribution not in {"normal", "uniform", "triangular"}: - raise ValueError( - f"Parameter '{name}' has unsupported distribution '{distribution}'" - ) - - span_default = abs(base_value) * DEFAULT_UNIFORM_SPAN_RATIO or 1.0 - - if distribution == "normal": - std_dev = item.get("std_dev") - std_dev_value = ( - float(std_dev) - if std_dev is not None - else abs(base_value) * DEFAULT_STD_DEV_RATIO or 1.0 - ) - compiled.append( - SimulationParameter( - name=name, - base_value=base_value, - distribution="normal", - std_dev=_ensure_positive_span(std_dev_value, 1.0), - ) - ) - continue - - minimum = item.get("min") - maximum = item.get("max") - if minimum is None or maximum is None: - minimum = base_value - span_default - maximum = base_value + span_default - minimum = float(minimum) - maximum = float(maximum) - if minimum >= maximum: - raise ValueError( - f"Parameter '{name}' requires 'min' < 'max' for {distribution} distribution" - ) - - if distribution == "uniform": - compiled.append( - SimulationParameter( - name=name, - base_value=base_value, - distribution="uniform", - 
minimum=minimum, - maximum=maximum, - ) - ) - else: # triangular - mode = item.get("mode") - if mode is None: - mode = base_value - mode_value = float(mode) - if not (minimum <= mode_value <= maximum): - raise ValueError( - f"Parameter '{name}' mode must be within min/max bounds for triangular distribution" - ) - compiled.append( - SimulationParameter( - name=name, - base_value=base_value, - distribution="triangular", - minimum=minimum, - maximum=maximum, - mode=mode_value, - ) - ) - return compiled - - -def _sample_parameter(rng: Random, param: SimulationParameter) -> float: - if param.distribution == "normal": - assert param.std_dev is not None - return rng.normalvariate(param.base_value, param.std_dev) - if param.distribution == "uniform": - assert param.minimum is not None and param.maximum is not None - return rng.uniform(param.minimum, param.maximum) - # triangular - assert ( - param.minimum is not None - and param.maximum is not None - and param.mode is not None - ) - return rng.triangular(param.minimum, param.maximum, param.mode) - - -def run_simulation( - parameters: Sequence[Dict[str, float]], - iterations: int = 1000, - seed: Optional[int] = None, -) -> List[Dict[str, float]]: - """Run a lightweight Monte Carlo simulation using configurable distributions.""" - - if iterations <= 0: - return [] - - compiled_params = _compile_parameters(parameters) - if not compiled_params: - return [] - - rng = Random(seed) - results: List[Dict[str, float]] = [] - for iteration in range(1, iterations + 1): - total = 0.0 - for param in compiled_params: - sample = _sample_parameter(rng, param) - total += sample - results.append({"iteration": iteration, "result": total}) - return results diff --git a/templates/Dashboard.html b/templates/Dashboard.html deleted file mode 100644 index ef2dfcb..0000000 --- a/templates/Dashboard.html +++ /dev/null @@ -1,94 +0,0 @@ -{% extends "base.html" %} {% block title %}Dashboard · CalMiner{% endblock %} {% -block content %} -
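The simulation and reporting services deleted above are designed to chain: run_simulation emits per-iteration totals and generate_report aggregates them into summary statistics, percentiles, and tail-risk figures. A minimal sketch with invented parameter values, showing how the two fit together:

# Hedged sketch: Monte Carlo run followed by KPI aggregation.
from services.reporting import generate_report
from services.simulation import run_simulation

results = run_simulation(
    parameters=[
        {"name": "ore_grade", "value": 1.2, "distribution": "normal", "std_dev": 0.1},
        {"name": "recovery", "value": 0.85, "distribution": "triangular", "min": 0.8, "max": 0.9},
    ],
    iterations=500,
    seed=42,
)
report = generate_report(results)
print(report["mean"], report["percentile_95"], report["value_at_risk_95"])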
-
-

Operations Overview

-

- Unified insight across scenarios, costs, production, maintenance, and - simulations. -

-
-
- -
-
- - - -
-
- {% for metric in summary_metrics %} -
- {{ metric.label }} - {{ metric.value }} -
- {% endfor %} -
-

- 0 %} hidden{% endif %}> Add project inputs to populate summary metrics. -

-
- -
-
-
-
-

Scenario Cost Mix

-

CAPEX vs OPEX totals per scenario

-
-
- -

- Add CAPEX or OPEX entries to display this chart. -

-
-
-
-
-

Production vs Consumption

-

Throughput comparison by scenario

-
-
- -
-
-{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/ParameterInput.html b/templates/ParameterInput.html deleted file mode 100644 index 10f30b1..0000000 --- a/templates/ParameterInput.html +++ /dev/null @@ -1,51 +0,0 @@ -{% extends "base.html" %} {% block title %}Process Parameters · CalMiner{% -endblock %} {% block content %} -
-

Scenario Parameters

- {% if scenarios %} -
- - - - -
- -
- - - - - - - - - - -
ParameterValueDistributionDetails
-
- {% else %} -

- No scenarios available. Create a scenario before - adding parameters. -

- {% endif %} -
-{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/ScenarioForm.html b/templates/ScenarioForm.html deleted file mode 100644 index fc5e6ab..0000000 --- a/templates/ScenarioForm.html +++ /dev/null @@ -1,53 +0,0 @@ -{% extends "base.html" %} {% block title %}Scenario Management · CalMiner{% -endblock %} {% block content %} -
-

Create a New Scenario

-
- - - -
- -
- {% if scenarios %} - - - - - - - - - {% for scenario in scenarios %} - - - - - {% endfor %} - -
NameDescription
{{ scenario.name }}{{ scenario.description or "—" }}
- {% else %} -

- No scenarios yet. Create one to get started. -

- - - - - - - - - - {% endif %} -
-
-{% endblock %} {% block scripts %} {{ super() }} - -{% endblock %} diff --git a/templates/consumption.html b/templates/consumption.html deleted file mode 100644 index 63c1198..0000000 --- a/templates/consumption.html +++ /dev/null @@ -1,76 +0,0 @@ -{% extends "base.html" %} {% from "partials/components.html" import -select_field, feedback, empty_state, table_container with context %} {% block -title %}Consumption · CalMiner{% endblock %} {% block content %} -
-

Consumption Tracking

-
- {{ select_field( "Scenario filter", "consumption-scenario-filter", - options=scenarios, placeholder="Select a scenario" ) }} -
- {{ empty_state( "consumption-empty", "Choose a scenario to review its - consumption records." ) }} {% call table_container( - "consumption-table-wrapper", hidden=True, aria_label="Scenario consumption - records" ) %} - - - Amount - Description - - - - {% endcall %} -
- -
-

Add Consumption Record

- {% if scenarios %} -
- {{ select_field( "Scenario", "consumption-form-scenario", - name="scenario_id", options=scenarios, required=True, placeholder="Select a - scenario", placeholder_disabled=True ) }} - - - - - -
- {{ feedback("consumption-feedback") }} {% else %} -

- Create a scenario before adding consumption records. -

- {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/costs.html b/templates/costs.html deleted file mode 100644 index 417de97..0000000 --- a/templates/costs.html +++ /dev/null @@ -1,129 +0,0 @@ -{% extends "base.html" %} {% from "partials/components.html" import -select_field, feedback, empty_state, table_container with context %} {% block -title %}Costs · CalMiner{% endblock %} {% block content %} -
-

Cost Overview

- {% if scenarios %} -
- {{ select_field( "Scenario filter", "costs-scenario-filter", - options=scenarios, placeholder="Select a scenario" ) }} -
- {% else %} {{ empty_state( "costs-scenario-empty", "Create a scenario to - review cost information." ) }} {% endif %} {{ empty_state( "costs-empty", - "Choose a scenario to review CAPEX and OPEX details." ) }} - - -
- -
-

Add CAPEX Entry

- {% if scenarios %} -
- {{ select_field( "Scenario", "capex-form-scenario", name="scenario_id", - options=scenarios, required=True, placeholder="Select a scenario", - placeholder_disabled=True ) }} {{ select_field( "Currency", - "capex-form-currency", name="currency_code", options=currency_options, - required=True, placeholder="Select currency", placeholder_disabled=True, - value_attr="id", label_attr="name" ) }} - - - -
- {{ feedback("capex-feedback") }} {% else %} {{ empty_state( - "capex-form-empty", "Create a scenario before adding CAPEX entries." ) }} {% - endif %} -
- -
-

Add OPEX Entry

- {% if scenarios %} -
- {{ select_field( "Scenario", "opex-form-scenario", name="scenario_id", - options=scenarios, required=True, placeholder="Select a scenario", - placeholder_disabled=True ) }} {{ select_field( "Currency", - "opex-form-currency", name="currency_code", options=currency_options, - required=True, placeholder="Select currency", placeholder_disabled=True, - value_attr="id", label_attr="name" ) }} - - - -
- {{ feedback("opex-feedback") }} {% else %} {{ empty_state( "opex-form-empty", - "Create a scenario before adding OPEX entries." ) }} {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/currencies.html b/templates/currencies.html deleted file mode 100644 index 6c99515..0000000 --- a/templates/currencies.html +++ /dev/null @@ -1,131 +0,0 @@ -{% extends "base.html" %} -{% from "partials/components.html" import select_field, feedback, empty_state, table_container with context %} - -{% block title %}Currencies · CalMiner{% endblock %} - -{% block content %} -
-
-
-

Currency Overview

-

- Current availability of currencies for project inputs. -

-
-
- - {% if currency_stats %} -
-
- Total Currencies - {{ currency_stats.total }} -
-
- Active - {{ currency_stats.active }} -
-
- Inactive - {{ currency_stats.inactive }} -
-
- {% else %} {{ empty_state("currencies-overview-empty", "No currency data - available yet.") }} {% endif %} {% call table_container( - "currencies-table-container", aria_label="Configured currencies", - heading="Configured Currencies" ) %} - - - Code - Name - Symbol - Status - Actions - - - - {% endcall %} {{ empty_state( "currencies-table-empty", "No currencies - configured yet.", hidden=currencies|length > 0 ) }} -
- -
-
-
-

Manage Currencies

-

- Create new currencies or update existing configurations inline. -

-
-
- - {% set status_options = [ {"id": "true", "name": "Active"}, {"id": "false", - "name": "Inactive"} ] %} - -
- {{ select_field( "Currency to update (leave blank for new)", - "currency-form-existing", name="existing_code", options=currencies, - placeholder="Create a new currency", value_attr="code", label_attr="name" ) - }} - - - - - - - - {{ select_field( "Status", "currency-form-status", name="is_active", - options=status_options, include_blank=False ) }} - -
- - -
-
- {{ feedback("currency-form-feedback") }} -
-{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/equipment.html b/templates/equipment.html deleted file mode 100644 index ed02537..0000000 --- a/templates/equipment.html +++ /dev/null @@ -1,78 +0,0 @@ -{% extends "base.html" %} {% block title %}Equipment · CalMiner{% endblock %} {% -block content %} -
-

Equipment Inventory

- {% if scenarios %} -
- -
- {% else %} -

- Create a scenario to view equipment inventory. -

- {% endif %} -
- Choose a scenario to review the equipment list. -
- -
- -
-

Add Equipment

- {% if scenarios %} -
- - - - -
- - {% else %} -

- Create a scenario before managing equipment. -

- {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/maintenance.html b/templates/maintenance.html deleted file mode 100644 index 51b0449..0000000 --- a/templates/maintenance.html +++ /dev/null @@ -1,111 +0,0 @@ -{% extends "base.html" %} {% block title %}Maintenance · CalMiner{% endblock %} -{% block content %} -
-

Maintenance Schedule

- {% if scenarios %} -
- -
- {% else %} -

- Create a scenario to view maintenance entries. -

- {% endif %} -
- Choose a scenario to review upcoming or completed maintenance. -
- -
- -
-

Add Maintenance Entry

- {% if scenarios %} -
- - - - - - - -
- - {% else %} -

- Create a scenario before managing maintenance - entries. -

- {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/production.html b/templates/production.html deleted file mode 100644 index 1de5b25..0000000 --- a/templates/production.html +++ /dev/null @@ -1,97 +0,0 @@ -{% extends "base.html" %} {% block title %}Production · CalMiner{% endblock %} -{% block content %} -
-

Production Output

- {% if scenarios %} -
- -
- {% else %} -

- Create a scenario to view production output data. -

- {% endif %} -
- Choose a scenario to review its production output. -
- -
- -
-

Add Production Output

- {% if scenarios %} -
- - - - - - -
- - {% else %} -

- Create a scenario before adding production output. -

- {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/reporting.html b/templates/reporting.html deleted file mode 100644 index aa9a635..0000000 --- a/templates/reporting.html +++ /dev/null @@ -1,41 +0,0 @@ -{% extends "base.html" %} {% block title %}Reporting · CalMiner{% endblock %} {% -block content %} -
-

Scenario KPI Summary

-
- -
- - - - - -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/settings.html b/templates/settings.html deleted file mode 100644 index 1fcbc21..0000000 --- a/templates/settings.html +++ /dev/null @@ -1,26 +0,0 @@ -{% extends "base.html" %} {% block title %}Settings · CalMiner{% endblock %} {% -block content %} - -
-
-

Currency Management

-

- Manage available currencies, symbols, and default selections from the - Currency Management page. -

- Go to Currency Management -
- -
-{% endblock %} diff --git a/templates/simulations.html b/templates/simulations.html deleted file mode 100644 index 4e96743..0000000 --- a/templates/simulations.html +++ /dev/null @@ -1,41 +0,0 @@ -{% extends "base.html" %} {% block title %}Simulations · CalMiner{% endblock %} -{% block content %} -
-

Monte Carlo Simulations

- {% if simulation_scenarios %} -
- -
- {% else %} -

Create a scenario before running simulations.

- {% endif %} - -
-

Scenario Run History

- - - - - - - - - - - {% endblock %} {% block scripts %} {{ super() }} - - - {% endblock %} diff --git a/templates/theme_settings.html b/templates/theme_settings.html deleted file mode 100644 index 72cecf4..0000000 --- a/templates/theme_settings.html +++ /dev/null @@ -1,125 +0,0 @@ -{% extends "base.html" %} {% block title %}Theme Settings · CalMiner{% endblock -%} {% block content %} - - -
-
-
-

Theme Colors

-

- Update global CSS variables to customize CalMiner's appearance. -

-
-
-
- {% for key, value in css_variables.items() %} {% set env_meta = - css_env_override_meta.get(key) %} - - {% endfor %} - -
- - -
- - {% from "partials/components.html" import feedback with context %} {{ - feedback("theme-settings-feedback") }} -
- -
-
-
-

Environment Overrides

-

- The following CSS variables are controlled via environment variables and - take precedence over database values. -

-
-
- {% if css_env_override_rows %} -
-
ScenarioIterationsMean Result
- - - - - - - - - {% for row in css_env_override_rows %} - - - - - - {% endfor %} - -
CSS VariableEnvironment VariableValue
{{ row.css_key }}{{ row.env_var }}{{ row.value }}
-
- {% else %} -

No environment overrides configured.

- {% endif %} -
-{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py deleted file mode 100644 index 6ced399..0000000 --- a/tests/e2e/conftest.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import subprocess -import time -from typing import Dict, Generator - -import pytest - -# type: ignore[import] -from playwright.sync_api import Browser, Page, Playwright, sync_playwright - -import httpx -from sqlalchemy.engine import make_url - -# Use a different port for the test server to avoid conflicts -TEST_PORT = 8001 -BASE_URL = f"http://localhost:{TEST_PORT}" - - -@pytest.fixture(scope="session", autouse=True) -def live_server() -> Generator[str, None, None]: - """Launch a live test server in a separate process.""" - env = _prepare_database_environment(os.environ.copy()) - - process = subprocess.Popen( - [ - "uvicorn", - "main:app", - "--host", - "127.0.0.1", - f"--port={TEST_PORT}", - ], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - env=env, - ) - - deadline = time.perf_counter() + 30 - last_error: Exception | None = None - while time.perf_counter() < deadline: - if process.poll() is not None: - raise RuntimeError("uvicorn server exited before becoming ready") - try: - response = httpx.get(BASE_URL, timeout=1.0, trust_env=False) - if response.status_code < 500: - break - except Exception as exc: # noqa: BLE001 - last_error = exc - time.sleep(0.5) - else: - process.terminate() - process.wait(timeout=5) - raise TimeoutError( - "Timed out waiting for uvicorn test server to start" - ) from last_error - - try: - yield BASE_URL - finally: - if process.poll() is None: - process.terminate() - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - process.kill() - process.wait(timeout=5) - - -@pytest.fixture(scope="session", autouse=True) -def seed_default_currencies(live_server: str) -> None: - """Ensure a baseline set of currencies exists for UI flows.""" - - seeds = [ - {"code": "EUR", "name": "Euro", "symbol": "EUR", "is_active": True}, - { - "code": "CLP", - "name": "Chilean Peso", - "symbol": "CLP$", - "is_active": True, - }, - ] - - with httpx.Client( - base_url=live_server, timeout=5.0, trust_env=False - ) as client: - try: - response = client.get("/api/currencies/?include_inactive=true") - response.raise_for_status() - existing_codes = { - str(item.get("code")) - for item in response.json() - if isinstance(item, dict) and item.get("code") - } - except httpx.HTTPError as exc: # noqa: BLE001 - raise RuntimeError("Failed to read existing currencies") from exc - - for payload in seeds: - if payload["code"] in existing_codes: - continue - try: - create_response = client.post("/api/currencies/", json=payload) - except httpx.HTTPError as exc: # noqa: BLE001 - raise RuntimeError("Failed to seed currencies") from exc - - if create_response.status_code == 409: - continue - create_response.raise_for_status() - - -@pytest.fixture(scope="session") -def playwright_instance() -> Generator[Playwright, None, None]: - """Provide a Playwright instance for the test session.""" - with sync_playwright() as p: - yield p - - -@pytest.fixture(scope="session") -def browser( - playwright_instance: Playwright, -) -> Generator[Browser, None, None]: - """Provide a browser instance for the test session.""" - browser = playwright_instance.chromium.launch() - yield browser - browser.close() - - -@pytest.fixture() -def page(browser: Browser, 
live_server: str) -> Generator[Page, None, None]: - """Provide a new page for each test.""" - page = browser.new_page(base_url=live_server) - page.goto("/") - page.wait_for_load_state("networkidle") - yield page - page.close() - - -def _prepare_database_environment(env: Dict[str, str]) -> Dict[str, str]: - """Ensure granular database env vars are available for the app under test.""" - - required = ( - "DATABASE_HOST", - "DATABASE_USER", - "DATABASE_NAME", - "DATABASE_PASSWORD", - ) - if all(env.get(key) for key in required): - return env - - legacy_url = env.get("DATABASE_URL") - if not legacy_url: - return env - - url = make_url(legacy_url) - env.setdefault("DATABASE_DRIVER", url.drivername) - if url.host: - env.setdefault("DATABASE_HOST", url.host) - if url.port: - env.setdefault("DATABASE_PORT", str(url.port)) - if url.username: - env.setdefault("DATABASE_USER", url.username) - if url.password: - env.setdefault("DATABASE_PASSWORD", url.password) - if url.database: - env.setdefault("DATABASE_NAME", url.database) - - query_options = dict(url.query) if url.query else {} - options = query_options.get("options") - if isinstance(options, str) and "search_path=" in options: - env.setdefault("DATABASE_SCHEMA", options.split("search_path=")[-1]) - - return env diff --git a/tests/e2e/test_consumption.py b/tests/e2e/test_consumption.py deleted file mode 100644 index 685db93..0000000 --- a/tests/e2e/test_consumption.py +++ /dev/null @@ -1,50 +0,0 @@ -from uuid import uuid4 - -from playwright.sync_api import Page, expect - - -def test_consumption_form_loads(page: Page): - """Verify the consumption form page loads correctly.""" - page.goto("/ui/consumption") - expect(page).to_have_title("Consumption · CalMiner") - expect( - page.locator("h2:has-text('Add Consumption Record')") - ).to_be_visible() - - -def test_create_consumption_item(page: Page): - """Test creating a new consumption item through the UI.""" - # First, create a scenario to associate the consumption with. - page.goto("/ui/scenarios") - scenario_name = f"Consumption Test Scenario {uuid4()}" - page.fill("input[name='name']", scenario_name) - page.click("button[type='submit']") - with page.expect_response("**/api/scenarios/"): - pass # Wait for the scenario to be created - - # Now, navigate to the consumption page and add an item. - page.goto("/ui/consumption") - - # Create a consumption item. - consumption_desc = "Diesel for generators" - page.select_option("#consumption-form-scenario", label=scenario_name) - page.fill("textarea[name='description']", consumption_desc) - page.fill("input[name='amount']", "5000") - page.click("button[type='submit']") - - with page.expect_response("**/api/consumption/") as response_info: - pass - assert response_info.value.status == 201 - - # Verify the new item appears in the table. - page.select_option("#consumption-scenario-filter", label=scenario_name) - expect( - page.locator("#consumption-table-body tr").filter( - has_text=consumption_desc - ) - ).to_be_visible() - - # Verify the feedback message. - expect(page.locator("#consumption-feedback")).to_have_text( - "Consumption record saved." 
- ) diff --git a/tests/e2e/test_costs.py b/tests/e2e/test_costs.py deleted file mode 100644 index c49439a..0000000 --- a/tests/e2e/test_costs.py +++ /dev/null @@ -1,63 +0,0 @@ -from uuid import uuid4 - -from playwright.sync_api import Page, expect - - -def test_costs_form_loads(page: Page): - """Verify the costs form page loads correctly.""" - page.goto("/ui/costs") - expect(page).to_have_title("Costs · CalMiner") - expect(page.locator("h2:has-text('Add CAPEX Entry')")).to_be_visible() - - -def test_create_capex_and_opex_items(page: Page): - """Test creating new CAPEX and OPEX items through the UI.""" - # First, create a scenario to associate the costs with. - page.goto("/ui/scenarios") - scenario_name = f"Cost Test Scenario {uuid4()}" - page.fill("input[name='name']", scenario_name) - page.click("button[type='submit']") - with page.expect_response("**/api/scenarios/"): - pass # Wait for the scenario to be created - - # Now, navigate to the costs page and add CAPEX and OPEX items. - page.goto("/ui/costs") - - # Create a CAPEX item. - capex_desc = "Initial drilling equipment" - page.select_option("#capex-form-scenario", label=scenario_name) - page.fill("#capex-form-description", capex_desc) - page.fill("#capex-form-amount", "150000") - page.click("#capex-form button[type='submit']") - - with page.expect_response("**/api/costs/capex") as response_info: - pass - assert response_info.value.status == 200 - - # Create an OPEX item. - opex_desc = "Monthly fuel costs" - page.select_option("#opex-form-scenario", label=scenario_name) - page.fill("#opex-form-description", opex_desc) - page.fill("#opex-form-amount", "25000") - page.click("#opex-form button[type='submit']") - - with page.expect_response("**/api/costs/opex") as response_info: - pass - assert response_info.value.status == 200 - - # Verify the new items appear in their respective tables. - page.select_option("#costs-scenario-filter", label=scenario_name) - expect( - page.locator("#capex-table-body tr").filter(has_text=capex_desc) - ).to_be_visible() - expect( - page.locator("#opex-table-body tr").filter(has_text=opex_desc) - ).to_be_visible() - - # Verify the feedback messages. - expect(page.locator("#capex-feedback")).to_have_text( - "Entry saved successfully." - ) - expect(page.locator("#opex-feedback")).to_have_text( - "Entry saved successfully." - ) diff --git a/tests/e2e/test_currencies.py b/tests/e2e/test_currencies.py deleted file mode 100644 index 4b7f8d0..0000000 --- a/tests/e2e/test_currencies.py +++ /dev/null @@ -1,135 +0,0 @@ -import random -import string - -from playwright.sync_api import Page, expect - - -def _unique_currency_code(existing: set[str]) -> str: - """Generate a unique three-letter code not present in *existing*.""" - alphabet = string.ascii_uppercase - for _ in range(100): - candidate = "".join(random.choices(alphabet, k=3)) - if candidate not in existing and candidate != "USD": - return candidate - raise AssertionError( - "Unable to generate a unique currency code for the test run." 
- ) - - -def _metric_value(page: Page, element_id: str) -> int: - locator = page.locator(f"#{element_id}") - expect(locator).to_be_visible() - return int(locator.inner_text().strip()) - - -def _expect_feedback(page: Page, expected_text: str) -> None: - page.wait_for_function( - "expected => {" - " const el = document.getElementById('currency-form-feedback');" - " if (!el) return false;" - " const text = (el.textContent || '').trim();" - " return !el.classList.contains('hidden') && text === expected;" - "}", - arg=expected_text, - ) - feedback = page.locator("#currency-form-feedback") - expect(feedback).to_have_text(expected_text) - - -def test_currency_workflow_create_update_toggle(page: Page) -> None: - """Exercise create, update, and toggle flows on the currency settings page.""" - page.goto("/ui/currencies") - expect(page).to_have_title("Currencies · CalMiner") - expect(page.locator("h2:has-text('Currency Overview')")).to_be_visible() - - code_cells = page.locator("#currencies-table-body tr td:nth-child(1)") - existing_codes = { - text.strip().upper() for text in code_cells.all_inner_texts() - } - - total_before = _metric_value(page, "currency-metric-total") - active_before = _metric_value(page, "currency-metric-active") - inactive_before = _metric_value(page, "currency-metric-inactive") - - new_code = _unique_currency_code(existing_codes) - new_name = f"Test Currency {new_code}" - new_symbol = new_code[0] - - page.fill("#currency-form-code", new_code) - page.fill("#currency-form-name", new_name) - page.fill("#currency-form-symbol", new_symbol) - page.select_option("#currency-form-status", "true") - - with page.expect_response("**/api/currencies/") as create_info: - page.click("button[type='submit']") - create_response = create_info.value - assert create_response.status == 201 - - _expect_feedback(page, "Currency created successfully.") - - page.wait_for_function( - "expected => Number(document.getElementById('currency-metric-total').textContent.trim()) === expected", - arg=total_before + 1, - ) - page.wait_for_function( - "expected => Number(document.getElementById('currency-metric-active').textContent.trim()) === expected", - arg=active_before + 1, - ) - - row = page.locator("#currencies-table-body tr").filter(has_text=new_code) - expect(row).to_be_visible() - expect(row.locator("td").nth(3)).to_have_text("Active") - - # Switch to update mode using the existing currency option. 
- page.select_option("#currency-form-existing", new_code) - updated_name = f"{new_name} Updated" - updated_symbol = f"{new_symbol}$" - page.fill("#currency-form-name", updated_name) - page.fill("#currency-form-symbol", updated_symbol) - page.select_option("#currency-form-status", "false") - - with page.expect_response(f"**/api/currencies/{new_code}") as update_info: - page.click("button[type='submit']") - update_response = update_info.value - assert update_response.status == 200 - - _expect_feedback(page, "Currency updated successfully.") - - page.wait_for_function( - "expected => Number(document.getElementById('currency-metric-active').textContent.trim()) === expected", - arg=active_before, - ) - page.wait_for_function( - "expected => Number(document.getElementById('currency-metric-inactive').textContent.trim()) === expected", - arg=inactive_before + 1, - ) - - expect(row.locator("td").nth(1)).to_have_text(updated_name) - expect(row.locator("td").nth(2)).to_have_text(updated_symbol) - expect(row.locator("td").nth(3)).to_contain_text("Inactive") - - toggle_button = row.locator("button[data-action='toggle']") - expect(toggle_button).to_have_text("Activate") - - with page.expect_response( - f"**/api/currencies/{new_code}/activation" - ) as toggle_info: - toggle_button.click() - toggle_response = toggle_info.value - assert toggle_response.status == 200 - - page.wait_for_function( - "expected => Number(document.getElementById('currency-metric-active').textContent.trim()) === expected", - arg=active_before + 1, - ) - page.wait_for_function( - "expected => Number(document.getElementById('currency-metric-inactive').textContent.trim()) === expected", - arg=inactive_before, - ) - - _expect_feedback(page, f"Currency {new_code} activated.") - - expect(row.locator("td").nth(3)).to_contain_text("Active") - expect(row.locator("button[data-action='toggle']")).to_have_text( - "Deactivate" - ) diff --git a/tests/e2e/test_dashboard.py b/tests/e2e/test_dashboard.py deleted file mode 100644 index 198b04d..0000000 --- a/tests/e2e/test_dashboard.py +++ /dev/null @@ -1,17 +0,0 @@ -from playwright.sync_api import Page, expect - - -def test_dashboard_loads_and_has_title(page: Page): - """Verify the dashboard page loads and the title is correct.""" - expect(page).to_have_title("Dashboard · CalMiner") - - -def test_dashboard_shows_summary_metrics_panel(page: Page): - """Check that the summary metrics panel is visible.""" - expect(page.locator("h2:has-text('Operations Overview')")).to_be_visible() - - -def test_dashboard_renders_cost_chart(page: Page): - """Ensure the scenario cost chart canvas is present.""" - expect(page.locator("#cost-chart")).to_be_attached() - expect(page.locator("#cost-chart-empty")).to_be_visible() diff --git a/tests/e2e/test_equipment.py b/tests/e2e/test_equipment.py deleted file mode 100644 index f507a6e..0000000 --- a/tests/e2e/test_equipment.py +++ /dev/null @@ -1,45 +0,0 @@ -from uuid import uuid4 - -from playwright.sync_api import Page, expect - - -def test_equipment_form_loads(page: Page): - """Verify the equipment form page loads correctly.""" - page.goto("/ui/equipment") - expect(page).to_have_title("Equipment · CalMiner") - expect(page.locator("h2:has-text('Add Equipment')")).to_be_visible() - - -def test_create_equipment_item(page: Page): - """Test creating a new equipment item through the UI.""" - # First, create a scenario to associate the equipment with. 
- page.goto("/ui/scenarios") - scenario_name = f"Equipment Test Scenario {uuid4()}" - page.fill("input[name='name']", scenario_name) - page.click("button[type='submit']") - with page.expect_response("**/api/scenarios/"): - pass # Wait for the scenario to be created - - # Now, navigate to the equipment page and add an item. - page.goto("/ui/equipment") - - # Create an equipment item. - equipment_name = "Haul Truck HT-05" - equipment_desc = "Primary haul truck for ore transport." - page.select_option("#equipment-form-scenario", label=scenario_name) - page.fill("#equipment-form-name", equipment_name) - page.fill("#equipment-form-description", equipment_desc) - page.click("button[type='submit']") - - with page.expect_response("**/api/equipment/") as response_info: - pass - assert response_info.value.status == 200 - - # Verify the new item appears in the table. - page.select_option("#equipment-scenario-filter", label=scenario_name) - expect( - page.locator("#equipment-table-body tr").filter(has_text=equipment_name) - ).to_be_visible() - - # Verify the feedback message. - expect(page.locator("#equipment-feedback")).to_have_text("Equipment saved.") diff --git a/tests/e2e/test_maintenance.py b/tests/e2e/test_maintenance.py deleted file mode 100644 index fb9a403..0000000 --- a/tests/e2e/test_maintenance.py +++ /dev/null @@ -1,58 +0,0 @@ -from uuid import uuid4 - -from playwright.sync_api import Page, expect - - -def test_maintenance_form_loads(page: Page): - """Verify the maintenance form page loads correctly.""" - page.goto("/ui/maintenance") - expect(page).to_have_title("Maintenance · CalMiner") - expect(page.locator("h2:has-text('Add Maintenance Entry')")).to_be_visible() - - -def test_create_maintenance_item(page: Page): - """Test creating a new maintenance item through the UI.""" - # First, create a scenario and an equipment item. - page.goto("/ui/scenarios") - scenario_name = f"Maintenance Test Scenario {uuid4()}" - page.fill("input[name='name']", scenario_name) - page.click("button[type='submit']") - with page.expect_response("**/api/scenarios/"): - pass - - page.goto("/ui/equipment") - equipment_name = f"Excavator EX-12 {uuid4()}" - page.select_option("#equipment-form-scenario", label=scenario_name) - page.fill("#equipment-form-name", equipment_name) - page.click("button[type='submit']") - with page.expect_response("**/api/equipment/"): - pass - - # Now, navigate to the maintenance page and add an item. - page.goto("/ui/maintenance") - - # Create a maintenance item. - maintenance_desc = "Scheduled engine overhaul" - page.select_option("#maintenance-form-scenario", label=scenario_name) - page.select_option("#maintenance-form-equipment", label=equipment_name) - page.fill("#maintenance-form-date", "2025-12-01") - page.fill("#maintenance-form-description", maintenance_desc) - page.fill("#maintenance-form-cost", "12000") - page.click("button[type='submit']") - - with page.expect_response("**/api/maintenance/") as response_info: - pass - assert response_info.value.status == 201 - - # Verify the new item appears in the table. - page.select_option("#maintenance-scenario-filter", label=scenario_name) - expect( - page.locator("#maintenance-table-body tr").filter( - has_text=maintenance_desc - ) - ).to_be_visible() - - # Verify the feedback message. - expect(page.locator("#maintenance-feedback")).to_have_text( - "Maintenance entry saved." 
- ) diff --git a/tests/e2e/test_production.py b/tests/e2e/test_production.py deleted file mode 100644 index 72a63ba..0000000 --- a/tests/e2e/test_production.py +++ /dev/null @@ -1,48 +0,0 @@ -from uuid import uuid4 - -from playwright.sync_api import Page, expect - - -def test_production_form_loads(page: Page): - """Verify the production form page loads correctly.""" - page.goto("/ui/production") - expect(page).to_have_title("Production · CalMiner") - expect(page.locator("h2:has-text('Add Production Output')")).to_be_visible() - - -def test_create_production_item(page: Page): - """Test creating a new production item through the UI.""" - # First, create a scenario to associate the production with. - page.goto("/ui/scenarios") - scenario_name = f"Production Test Scenario {uuid4()}" - page.fill("input[name='name']", scenario_name) - # Submit and wait for the scenario to be created. - with page.expect_response("**/api/scenarios/"): - page.click("button[type='submit']") - - # Now, navigate to the production page and add an item. - page.goto("/ui/production") - - # Create a production item. - production_desc = "Ore extracted - Grade A" - page.select_option("#production-form-scenario", label=scenario_name) - page.fill("#production-form-description", production_desc) - page.fill("#production-form-amount", "1500") - - # Submitting the form triggers the POST we wait for here. - with page.expect_response("**/api/production/") as response_info: - page.click("button[type='submit']") - assert response_info.value.status == 201 - - # Verify the new item appears in the table. - page.select_option("#production-scenario-filter", label=scenario_name) - expect( - page.locator("#production-table-body tr").filter( - has_text=production_desc - ) - ).to_be_visible() - - # Verify the feedback message. - expect(page.locator("#production-feedback")).to_have_text( - "Production output saved." - ) diff --git a/tests/e2e/test_reporting.py b/tests/e2e/test_reporting.py deleted file mode 100644 index 04cee12..0000000 --- a/tests/e2e/test_reporting.py +++ /dev/null @@ -1,9 +0,0 @@ -from playwright.sync_api import Page, expect - - -def test_reporting_view_loads(page: Page): - """Verify the reporting view page loads correctly.""" - page.get_by_role("link", name="Reporting").click() - expect(page).to_have_url("http://localhost:8001/ui/reporting") - expect(page).to_have_title("Reporting · CalMiner") - expect(page.locator("h2:has-text('Scenario KPI Summary')")).to_be_visible() diff --git a/tests/e2e/test_scenarios.py b/tests/e2e/test_scenarios.py deleted file mode 100644 index 04f37ea..0000000 --- a/tests/e2e/test_scenarios.py +++ /dev/null @@ -1,43 +0,0 @@ -from uuid import uuid4 - -from playwright.sync_api import Page, expect - - -def test_scenario_form_loads(page: Page): - """Verify the scenario form page loads correctly.""" - page.goto("/ui/scenarios") - expect(page).to_have_url( - "http://localhost:8001/ui/scenarios" - ) # Updated port - expect(page.locator("h2:has-text('Create a New Scenario')")).to_be_visible() - - -def test_create_new_scenario(page: Page): - """Test creating a new scenario via the UI form.""" - page.goto("/ui/scenarios") - - scenario_name = f"E2E Test Scenario {uuid4()}" - scenario_desc = "A scenario created during an end-to-end test." - - page.fill("input[name='name']", scenario_name) - page.fill("input[name='description']", scenario_desc) - - # Expect a network response from the POST request after clicking the submit button. 
- with page.expect_response("**/api/scenarios/") as response_info: - page.click("button[type='submit']") - - response = response_info.value - assert response.status == 200 - - # After a successful submission, the new scenario should be visible in the table. - # The table is dynamically updated, so we might need to wait for it to appear. - new_row = page.locator(f"tr:has-text('{scenario_name}')") - expect(new_row).to_be_visible() - expect(new_row.locator("td").nth(1)).to_have_text(scenario_desc) - - # Verify the feedback message. - feedback = page.locator("#feedback") - expect(feedback).to_be_visible() - expect(feedback).to_have_text( - f'Scenario "{scenario_name}" created successfully.' - ) diff --git a/tests/e2e/test_smoke.py b/tests/e2e/test_smoke.py deleted file mode 100644 index a9f0b23..0000000 --- a/tests/e2e/test_smoke.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest -from playwright.sync_api import Page, expect - -# A list of UI routes to check, with their URL, expected title, and a key heading text. -UI_ROUTES = [ - ("/", "Dashboard · CalMiner", "Operations Overview"), - ("/ui/dashboard", "Dashboard · CalMiner", "Operations Overview"), - ( - "/ui/scenarios", - "Scenario Management · CalMiner", - "Create a New Scenario", - ), - ("/ui/parameters", "Process Parameters · CalMiner", "Scenario Parameters"), - ("/ui/settings", "Settings · CalMiner", "Settings"), - ("/ui/costs", "Costs · CalMiner", "Cost Overview"), - ("/ui/consumption", "Consumption · CalMiner", "Consumption Tracking"), - ("/ui/production", "Production · CalMiner", "Production Output"), - ("/ui/equipment", "Equipment · CalMiner", "Equipment Inventory"), - ("/ui/maintenance", "Maintenance · CalMiner", "Maintenance Schedule"), - ("/ui/simulations", "Simulations · CalMiner", "Monte Carlo Simulations"), - ("/ui/reporting", "Reporting · CalMiner", "Scenario KPI Summary"), - ("/ui/currencies", "Currencies · CalMiner", "Currency Overview"), -] - - -@pytest.mark.parametrize("url, title, heading", UI_ROUTES) -def test_ui_pages_load_correctly( - page: Page, url: str, title: str, heading: str -): - """Verify that all UI pages load with the correct title and a visible heading.""" - page.goto(url) - expect(page).to_have_title(title) - # The app uses a mix of h1 and h2 for main page headings. 
- heading_locator = page.locator( - f"h1:has-text('{heading}'), h2:has-text('{heading}')" - ) - expect(heading_locator.first).to_be_visible() - - -def test_settings_theme_form_interaction(page: Page): - page.goto("/theme-settings") - expect(page).to_have_title("Theme Settings · CalMiner") - - env_rows = page.locator("#theme-env-overrides tbody tr") - disabled_inputs = page.locator( - "#theme-settings-form input.color-value-input[disabled]" - ) - env_row_count = env_rows.count() - disabled_count = disabled_inputs.count() - assert disabled_count == env_row_count - - color_input = page.locator( - "#theme-settings-form input[name='--color-primary']" - ) - expect(color_input).to_be_visible() - expect(color_input).to_be_enabled() - - original_value = color_input.input_value() - candidate_values = ("#114455", "#225566") - new_value = ( - candidate_values[0] - if original_value != candidate_values[0] - else candidate_values[1] - ) - - color_input.fill(new_value) - page.click("#theme-settings-form button[type='submit']") - - feedback = page.locator("#theme-settings-feedback") - expect(feedback).to_contain_text("updated successfully") - - computed_color = page.evaluate( - "() => getComputedStyle(document.documentElement).getPropertyValue('--color-primary').trim()" - ) - assert computed_color.lower() == new_value.lower() - - page.reload() - expect(color_input).to_have_value(new_value) - - color_input.fill(original_value) - page.click("#theme-settings-form button[type='submit']") - expect(feedback).to_contain_text("updated successfully") - - page.reload() - expect(color_input).to_have_value(original_value) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py deleted file mode 100644 index 00d8401..0000000 --- a/tests/unit/conftest.py +++ /dev/null @@ -1,266 +0,0 @@ -from datetime import date -from typing import Any, Dict, Generator -from uuid import uuid4 - -import pytest -from fastapi.testclient import TestClient -from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker -from sqlalchemy.pool import StaticPool - -from config.database import Base -from main import app -from models.capex import Capex -from models.consumption import Consumption -from models.equipment import Equipment -from models.maintenance import Maintenance -from models.opex import Opex -from models.parameters import Parameter -from models.production_output import ProductionOutput -from models.scenario import Scenario -from models.simulation_result import SimulationResult - -SQLALCHEMY_TEST_URL = "sqlite:///:memory:" -engine = create_engine( - SQLALCHEMY_TEST_URL, - connect_args={"check_same_thread": False}, - poolclass=StaticPool, -) -TestingSessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine -) - - -@pytest.fixture(scope="session", autouse=True) -def setup_database() -> Generator[None, None, None]: - # Ensure all model metadata is registered before creating tables - from models import ( - application_setting, - capex, - consumption, - currency, - distribution, - equipment, - maintenance, - opex, - parameters, - production_output, - role, - scenario, - simulation_result, - theme_setting, - user, - ) # noqa: F401 - imported for side effects - - _ = ( - capex, - consumption, - currency, - distribution, - equipment, - maintenance, - application_setting, - opex, - parameters, - production_output, - role, - scenario, - simulation_result, - theme_setting, - user, - ) - 
- Base.metadata.create_all(bind=engine) - yield - Base.metadata.drop_all(bind=engine) - - -@pytest.fixture() -def db_session() -> Generator[Session, None, None]: - Base.metadata.drop_all(bind=engine) - Base.metadata.create_all(bind=engine) - session = TestingSessionLocal() - try: - yield session - finally: - session.rollback() - session.close() - - -@pytest.fixture() -def api_client(db_session: Session) -> Generator[TestClient, None, None]: - def override_get_db(): - try: - yield db_session - finally: - pass - - from routes.dependencies import get_db - - app.dependency_overrides[get_db] = override_get_db - - with TestClient(app) as client: - yield client - - app.dependency_overrides.pop(get_db, None) - - -@pytest.fixture() -def seeded_ui_data( - db_session: Session, -) -> Generator[Dict[str, Any], None, None]: - """Populate a scenario with representative related records for UI tests.""" - scenario_name = f"Scenario Alpha {uuid4()}" - scenario = Scenario(name=scenario_name, description="Seeded UI scenario") - db_session.add(scenario) - db_session.flush() - - parameter = Parameter( - scenario_id=scenario.id, - name="Ore Grade", - value=1.5, - distribution_type="normal", - distribution_parameters={"mean": 1.5, "std_dev": 0.1}, - ) - capex = Capex( - scenario_id=scenario.id, - amount=1_000_000.0, - description="Drill purchase", - currency_code="USD", - ) - opex = Opex( - scenario_id=scenario.id, - amount=250_000.0, - description="Fuel spend", - currency_code="USD", - ) - consumption = Consumption( - scenario_id=scenario.id, - amount=1_200.0, - description="Diesel (L)", - unit_name="Liters", - unit_symbol="L", - ) - production = ProductionOutput( - scenario_id=scenario.id, - amount=800.0, - description="Ore (tonnes)", - unit_name="Tonnes", - unit_symbol="t", - ) - equipment = Equipment( - scenario_id=scenario.id, - name="Excavator 42", - description="Primary loader", - ) - db_session.add_all( - [parameter, capex, opex, consumption, production, equipment] - ) - db_session.flush() - - maintenance = Maintenance( - scenario_id=scenario.id, - equipment_id=equipment.id, - maintenance_date=date(2025, 1, 15), - description="Hydraulic service", - cost=15_000.0, - ) - simulation_results = [ - SimulationResult( - scenario_id=scenario.id, - iteration=index, - result=value, - ) - for index, value in enumerate( - (950_000.0, 975_000.0, 990_000.0), start=1 - ) - ] - - db_session.add(maintenance) - db_session.add_all(simulation_results) - db_session.commit() - - try: - yield { - "scenario": scenario, - "equipment": equipment, - "simulation_results": simulation_results, - } - finally: - db_session.query(SimulationResult).filter_by( - scenario_id=scenario.id - ).delete() - db_session.query(Maintenance).filter_by( - scenario_id=scenario.id - ).delete() - db_session.query(Equipment).filter_by(id=equipment.id).delete() - db_session.query(ProductionOutput).filter_by( - scenario_id=scenario.id - ).delete() - db_session.query(Consumption).filter_by( - scenario_id=scenario.id - ).delete() - db_session.query(Opex).filter_by(scenario_id=scenario.id).delete() - db_session.query(Capex).filter_by(scenario_id=scenario.id).delete() - db_session.query(Parameter).filter_by(scenario_id=scenario.id).delete() - db_session.query(Scenario).filter_by(id=scenario.id).delete() - db_session.commit() - - -@pytest.fixture() -def invalid_request_payloads( - db_session: Session, -) -> Generator[Dict[str, Any], None, None]: - """Provide reusable invalid request bodies for exercising validation branches.""" - duplicate_name = f"Scenario 
Duplicate {uuid4()}" - existing = Scenario( - name=duplicate_name, - description="Existing scenario for duplicate checks", - ) - db_session.add(existing) - db_session.commit() - - payloads: Dict[str, Any] = { - "existing_scenario": existing, - "scenario_duplicate": { - "name": duplicate_name, - "description": "Second scenario should fail with duplicate name", - }, - "parameter_missing_scenario": { - "scenario_id": existing.id + 99, - "name": "Invalid Parameter", - "value": 1.0, - }, - "parameter_invalid_distribution": { - "scenario_id": existing.id, - "name": "Weird Dist", - "value": 2.5, - "distribution_type": "invalid", - }, - "simulation_unknown_scenario": { - "scenario_id": existing.id + 99, - "iterations": 10, - "parameters": [ - {"name": "grade", "value": 1.2, "distribution": "normal"} - ], - }, - "simulation_missing_parameters": { - "scenario_id": existing.id, - "iterations": 5, - "parameters": [], - }, - "reporting_non_list_payload": {"result": 10.0}, - "reporting_missing_result": [{"value": 12.0}], - "maintenance_negative_cost": { - "equipment_id": 1, - "scenario_id": existing.id, - "maintenance_date": "2025-01-15", - "cost": -500.0, - }, - } - - try: - yield payloads - finally: - db_session.query(Scenario).filter_by(id=existing.id).delete() - db_session.commit() diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py deleted file mode 100644 index f4c0251..0000000 --- a/tests/unit/test_auth.py +++ /dev/null @@ -1,231 +0,0 @@ -from services.security import get_password_hash, verify_password - - -def test_password_hashing(): - password = "testpassword" - hashed_password = get_password_hash(password) - assert verify_password(password, hashed_password) - assert not verify_password("wrongpassword", hashed_password) - - -def test_register_user(api_client): - response = api_client.post( - "/users/register", - json={ - "username": "testuser", - "email": "test@example.com", - "password": "testpassword", - }, - ) - assert response.status_code == 201 - data = response.json() - assert data["username"] == "testuser" - assert data["email"] == "test@example.com" - assert "id" in data - assert "role_id" in data - - response = api_client.post( - "/users/register", - json={ - "username": "testuser", - "email": "another@example.com", - "password": "testpassword", - }, - ) - assert response.status_code == 400 - assert response.json() == {"detail": "Username already registered"} - - response = api_client.post( - "/users/register", - json={ - "username": "anotheruser", - "email": "test@example.com", - "password": "testpassword", - }, - ) - assert response.status_code == 400 - assert response.json() == {"detail": "Email already registered"} - - -def test_login_user(api_client): - # Register a user first - api_client.post( - "/users/register", - json={ - "username": "loginuser", - "email": "login@example.com", - "password": "loginpassword", - }, - ) - - response = api_client.post( - "/users/login", - json={"username": "loginuser", "password": "loginpassword"}, - ) - assert response.status_code == 200 - data = response.json() - assert "access_token" in data - assert data["token_type"] == "bearer" - - response = api_client.post( - "/users/login", - json={"username": "loginuser", "password": "wrongpassword"}, - ) - assert response.status_code == 401 - assert response.json() == {"detail": "Incorrect username or password"} - - response = api_client.post( - "/users/login", - json={"username": "nonexistent", "password": "password"}, - ) - assert response.status_code == 401 - assert response.json() == 
{"detail": "Incorrect username or password"} - - -def test_read_users_me(api_client): - # Register a user first - api_client.post( - "/users/register", - json={ - "username": "profileuser", - "email": "profile@example.com", - "password": "profilepassword", - }, - ) - # Login to get a token - login_response = api_client.post( - "/users/login", - json={"username": "profileuser", "password": "profilepassword"}, - ) - token = login_response.json()["access_token"] - - response = api_client.get( - "/users/me", headers={"Authorization": f"Bearer {token}"} - ) - assert response.status_code == 200 - data = response.json() - assert data["username"] == "profileuser" - assert data["email"] == "profile@example.com" - - -def test_update_users_me(api_client): - # Register a user first - api_client.post( - "/users/register", - json={ - "username": "updateuser", - "email": "update@example.com", - "password": "updatepassword", - }, - ) - # Login to get a token - login_response = api_client.post( - "/users/login", - json={"username": "updateuser", "password": "updatepassword"}, - ) - token = login_response.json()["access_token"] - - response = api_client.put( - "/users/me", - headers={"Authorization": f"Bearer {token}"}, - json={ - "username": "updateduser", - "email": "updated@example.com", - "password": "newpassword", - }, - ) - assert response.status_code == 200 - data = response.json() - assert data["username"] == "updateduser" - assert data["email"] == "updated@example.com" - - # Verify password change - response = api_client.post( - "/users/login", - json={"username": "updateduser", "password": "newpassword"}, - ) - assert response.status_code == 200 - token = response.json()["access_token"] - - # Test username already taken - api_client.post( - "/users/register", - json={ - "username": "anotherupdateuser", - "email": "anotherupdate@example.com", - "password": "password", - }, - ) - response = api_client.put( - "/users/me", - headers={"Authorization": f"Bearer {token}"}, - json={ - "username": "anotherupdateuser", - }, - ) - assert response.status_code == 400 - assert response.json() == {"detail": "Username already taken"} - - # Test email already registered - api_client.post( - "/users/register", - json={ - "username": "yetanotheruser", - "email": "yetanother@example.com", - "password": "password", - }, - ) - response = api_client.put( - "/users/me", - headers={"Authorization": f"Bearer {token}"}, - json={ - "email": "yetanother@example.com", - }, - ) - assert response.status_code == 400 - assert response.json() == {"detail": "Email already registered"} - - -def test_forgot_password(api_client): - response = api_client.post( - "/users/forgot-password", json={"email": "nonexistent@example.com"} - ) - assert response.status_code == 200 - assert response.json() == { - "message": "Password reset email sent (not really)"} - - -def test_reset_password(api_client): - # Register a user first - api_client.post( - "/users/register", - json={ - "username": "resetuser", - "email": "reset@example.com", - "password": "oldpassword", - }, - ) - - response = api_client.post( - "/users/reset-password", - json={ - "token": "resetuser", # Use username as token for test - "new_password": "newpassword", - }, - ) - assert response.status_code == 200 - assert response.json() == { - "message": "Password has been reset successfully"} - - # Verify password change - response = api_client.post( - "/users/login", - json={"username": "resetuser", "password": "newpassword"}, - ) - assert response.status_code == 200 - - response = 
api_client.post( - "/users/login", - json={"username": "resetuser", "password": "oldpassword"}, - ) - assert response.status_code == 401 diff --git a/tests/unit/test_consumption.py b/tests/unit/test_consumption.py deleted file mode 100644 index 9ea7bb3..0000000 --- a/tests/unit/test_consumption.py +++ /dev/null @@ -1,77 +0,0 @@ -from uuid import uuid4 - -import pytest -from fastapi.testclient import TestClient - - -@pytest.fixture -def client(api_client: TestClient) -> TestClient: - return api_client - - -def _create_scenario(client: TestClient) -> int: - payload = { - "name": f"Consumption Scenario {uuid4()}", - "description": "Scenario for consumption tests", - } - response = client.post("/api/scenarios/", json=payload) - assert response.status_code == 200 - return response.json()["id"] - - -def test_create_consumption(client: TestClient) -> None: - scenario_id = _create_scenario(client) - payload = { - "scenario_id": scenario_id, - "amount": 125.5, - "description": "Fuel usage baseline", - "unit_name": "Liters", - "unit_symbol": "L", - } - - response = client.post("/api/consumption/", json=payload) - assert response.status_code == 201 - body = response.json() - assert body["id"] > 0 - assert body["scenario_id"] == scenario_id - assert body["amount"] == pytest.approx(125.5) - assert body["description"] == "Fuel usage baseline" - assert body["unit_symbol"] == "L" - - -def test_list_consumption_returns_created_items(client: TestClient) -> None: - scenario_id = _create_scenario(client) - values = [50.0, 80.75] - for amount in values: - response = client.post( - "/api/consumption/", - json={ - "scenario_id": scenario_id, - "amount": amount, - "description": f"Consumption {amount}", - "unit_name": "Tonnes", - "unit_symbol": "t", - }, - ) - assert response.status_code == 201 - - list_response = client.get("/api/consumption/") - assert list_response.status_code == 200 - items = [ - item - for item in list_response.json() - if item["scenario_id"] == scenario_id - ] - assert {item["amount"] for item in items} == set(values) - - -def test_create_consumption_rejects_negative_amount(client: TestClient) -> None: - scenario_id = _create_scenario(client) - payload = { - "scenario_id": scenario_id, - "amount": -10, - "description": "Invalid negative amount", - } - - response = client.post("/api/consumption/", json=payload) - assert response.status_code == 422 diff --git a/tests/unit/test_costs.py b/tests/unit/test_costs.py deleted file mode 100644 index ae4059c..0000000 --- a/tests/unit/test_costs.py +++ /dev/null @@ -1,123 +0,0 @@ -from uuid import uuid4 - -from fastapi.testclient import TestClient - -from config.database import Base, engine -from main import app - - -def setup_module(module): - Base.metadata.drop_all(bind=engine) - Base.metadata.create_all(bind=engine) - - -def teardown_module(module): - Base.metadata.drop_all(bind=engine) - - -client = TestClient(app) - - -def _create_scenario() -> int: - payload = { - "name": f"CostScenario-{uuid4()}", - "description": "Cost tracking test scenario", - } - response = client.post("/api/scenarios/", json=payload) - assert response.status_code == 200 - return response.json()["id"] - - -def test_create_and_list_capex_and_opex(): - sid = _create_scenario() - - capex_payload = { - "scenario_id": sid, - "amount": 1000.0, - "description": "Initial capex", - "currency_code": "USD", - } - resp2 = client.post("/api/costs/capex", json=capex_payload) - assert resp2.status_code == 200 - capex = resp2.json() - assert capex["scenario_id"] == sid - assert 
capex["amount"] == 1000.0 - assert capex["currency_code"] == "USD" - - resp3 = client.get("/api/costs/capex") - assert resp3.status_code == 200 - data = resp3.json() - assert any( - item["amount"] == 1000.0 and item["scenario_id"] == sid for item in data - ) - - opex_payload = { - "scenario_id": sid, - "amount": 500.0, - "description": "Recurring opex", - "currency_code": "USD", - } - resp4 = client.post("/api/costs/opex", json=opex_payload) - assert resp4.status_code == 200 - opex = resp4.json() - assert opex["scenario_id"] == sid - assert opex["amount"] == 500.0 - assert opex["currency_code"] == "USD" - - resp5 = client.get("/api/costs/opex") - assert resp5.status_code == 200 - data_o = resp5.json() - assert any( - item["amount"] == 500.0 and item["scenario_id"] == sid - for item in data_o - ) - - -def test_multiple_capex_entries(): - sid = _create_scenario() - amounts = [250.0, 750.0] - for amount in amounts: - resp = client.post( - "/api/costs/capex", - json={ - "scenario_id": sid, - "amount": amount, - "description": f"Capex {amount}", - "currency_code": "EUR", - }, - ) - assert resp.status_code == 200 - - resp = client.get("/api/costs/capex") - assert resp.status_code == 200 - data = resp.json() - retrieved_amounts = [ - item["amount"] for item in data if item["scenario_id"] == sid - ] - for amount in amounts: - assert amount in retrieved_amounts - - -def test_multiple_opex_entries(): - sid = _create_scenario() - amounts = [120.0, 340.0] - for amount in amounts: - resp = client.post( - "/api/costs/opex", - json={ - "scenario_id": sid, - "amount": amount, - "description": f"Opex {amount}", - "currency_code": "CAD", - }, - ) - assert resp.status_code == 200 - - resp = client.get("/api/costs/opex") - assert resp.status_code == 200 - data = resp.json() - retrieved_amounts = [ - item["amount"] for item in data if item["scenario_id"] == sid - ] - for amount in amounts: - assert amount in retrieved_amounts diff --git a/tests/unit/test_currencies.py b/tests/unit/test_currencies.py deleted file mode 100644 index 044571e..0000000 --- a/tests/unit/test_currencies.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import Dict - -import pytest - -from models.currency import Currency - - -@pytest.fixture(autouse=True) -def _cleanup_currencies(db_session): - db_session.query(Currency).delete() - db_session.commit() - yield - db_session.query(Currency).delete() - db_session.commit() - - -def _assert_currency( - payload: Dict[str, object], - code: str, - name: str, - symbol: str | None, - is_active: bool, -) -> None: - assert payload["code"] == code - assert payload["name"] == name - assert payload["is_active"] is is_active - if symbol is None: - assert payload["symbol"] is None - else: - assert payload["symbol"] == symbol - - -def test_list_returns_default_currency(api_client, db_session): - response = api_client.get("/api/currencies/") - assert response.status_code == 200 - data = response.json() - assert any(item["code"] == "USD" for item in data) - - -def test_create_currency_success(api_client, db_session): - payload = {"code": "EUR", "name": "Euro", "symbol": "€", "is_active": True} - response = api_client.post("/api/currencies/", json=payload) - assert response.status_code == 201 - data = response.json() - _assert_currency(data, "EUR", "Euro", "€", True) - - stored = db_session.query(Currency).filter_by(code="EUR").one() - assert stored.name == "Euro" - assert stored.symbol == "€" - assert stored.is_active is True - - -def test_create_currency_conflict(api_client, db_session): - api_client.post( 
- "/api/currencies/", - json={ - "code": "CAD", - "name": "Canadian Dollar", - "symbol": "$", - "is_active": True, - }, - ) - duplicate = api_client.post( - "/api/currencies/", - json={ - "code": "CAD", - "name": "Canadian Dollar", - "symbol": "$", - "is_active": True, - }, - ) - assert duplicate.status_code == 409 - - -def test_update_currency_fields(api_client, db_session): - api_client.post( - "/api/currencies/", - json={ - "code": "GBP", - "name": "British Pound", - "symbol": "£", - "is_active": True, - }, - ) - - response = api_client.put( - "/api/currencies/GBP", - json={"name": "Pound Sterling", "symbol": "£", "is_active": False}, - ) - assert response.status_code == 200 - data = response.json() - _assert_currency(data, "GBP", "Pound Sterling", "£", False) - - -def test_toggle_currency_activation(api_client, db_session): - api_client.post( - "/api/currencies/", - json={ - "code": "AUD", - "name": "Australian Dollar", - "symbol": "A$", - "is_active": True, - }, - ) - - response = api_client.patch( - "/api/currencies/AUD/activation", - json={"is_active": False}, - ) - assert response.status_code == 200 - data = response.json() - _assert_currency(data, "AUD", "Australian Dollar", "A$", False) - - -def test_default_currency_cannot_be_deactivated(api_client, db_session): - api_client.get("/api/currencies/") - response = api_client.patch( - "/api/currencies/USD/activation", - json={"is_active": False}, - ) - assert response.status_code == 400 - assert ( - response.json()["detail"] - == "The default currency cannot be deactivated." - ) diff --git a/tests/unit/test_currency_workflow.py b/tests/unit/test_currency_workflow.py deleted file mode 100644 index f43809a..0000000 --- a/tests/unit/test_currency_workflow.py +++ /dev/null @@ -1,75 +0,0 @@ -from uuid import uuid4 - -import pytest - -from models.currency import Currency - - -@pytest.fixture -def seeded_currency(db_session): - currency = Currency(code="GBP", name="British Pound", symbol="GBP") - db_session.add(currency) - db_session.commit() - db_session.refresh(currency) - - try: - yield currency - finally: - db_session.delete(currency) - db_session.commit() - - -def _create_scenario(api_client): - payload = { - "name": f"CurrencyScenario-{uuid4()}", - "description": "Currency workflow scenario", - } - resp = api_client.post("/api/scenarios/", json=payload) - assert resp.status_code == 200 - return resp.json()["id"] - - -def test_create_capex_with_currency_code_and_list(api_client, seeded_currency): - sid = _create_scenario(api_client) - - payload = { - "scenario_id": sid, - "amount": 500.0, - "description": "Capex with GBP", - "currency_code": seeded_currency.code, - } - resp = api_client.post("/api/costs/capex", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert ( - data.get("currency_code") == seeded_currency.code - or data.get("currency", {}).get("code") == seeded_currency.code - ) - - -def test_create_opex_with_currency_id(api_client, seeded_currency): - sid = _create_scenario(api_client) - - resp = api_client.get("/api/currencies/") - assert resp.status_code == 200 - currencies = resp.json() - assert any(c["id"] == seeded_currency.id for c in currencies) - - payload = { - "scenario_id": sid, - "amount": 120.0, - "description": "Opex with explicit id", - "currency_id": seeded_currency.id, - } - resp = api_client.post("/api/costs/opex", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert data["currency_id"] == seeded_currency.id - - -def test_list_currencies_endpoint(api_client, 
seeded_currency): - resp = api_client.get("/api/currencies/") - assert resp.status_code == 200 - data = resp.json() - assert isinstance(data, list) - assert any(c["id"] == seeded_currency.id for c in data) diff --git a/tests/unit/test_distribution.py b/tests/unit/test_distribution.py deleted file mode 100644 index 1dbb98b..0000000 --- a/tests/unit/test_distribution.py +++ /dev/null @@ -1,71 +0,0 @@ -from uuid import uuid4 - -from fastapi.testclient import TestClient - -from config.database import Base, engine -from main import app - - -def setup_module(module): - Base.metadata.create_all(bind=engine) - - -def teardown_module(module): - Base.metadata.drop_all(bind=engine) - - -client = TestClient(app) - - -def test_create_and_list_distribution(): - dist_name = f"NormalDist-{uuid4()}" - payload = { - "name": dist_name, - "distribution_type": "normal", - "parameters": {"mu": 0, "sigma": 1}, - } - resp = client.post("/api/distributions/", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert data["name"] == dist_name - - resp2 = client.get("/api/distributions/") - assert resp2.status_code == 200 - data2 = resp2.json() - assert any(d["name"] == dist_name for d in data2) - - -def test_duplicate_distribution_name_allowed(): - dist_name = f"DupDist-{uuid4()}" - payload = { - "name": dist_name, - "distribution_type": "uniform", - "parameters": {"min": 0, "max": 1}, - } - first = client.post("/api/distributions/", json=payload) - assert first.status_code == 200 - - duplicate = client.post("/api/distributions/", json=payload) - assert duplicate.status_code == 200 - - resp = client.get("/api/distributions/") - assert resp.status_code == 200 - matching = [item for item in resp.json() if item["name"] == dist_name] - assert len(matching) >= 2 - - -def test_list_distributions_returns_all(): - names = {f"ListDist-{uuid4()}" for _ in range(2)} - for name in names: - payload = { - "name": name, - "distribution_type": "triangular", - "parameters": {"min": 0, "max": 10, "mode": 5}, - } - resp = client.post("/api/distributions/", json=payload) - assert resp.status_code == 200 - - resp = client.get("/api/distributions/") - assert resp.status_code == 200 - found_names = {item["name"] for item in resp.json()} - assert names.issubset(found_names) diff --git a/tests/unit/test_equipment.py b/tests/unit/test_equipment.py deleted file mode 100644 index 2069b53..0000000 --- a/tests/unit/test_equipment.py +++ /dev/null @@ -1,77 +0,0 @@ -from uuid import uuid4 - -import pytest -from fastapi.testclient import TestClient - - -@pytest.fixture -def client(api_client: TestClient) -> TestClient: - return api_client - - -def _create_scenario(client: TestClient) -> int: - payload = { - "name": f"Equipment Scenario {uuid4()}", - "description": "Scenario for equipment tests", - } - response = client.post("/api/scenarios/", json=payload) - assert response.status_code == 200 - return response.json()["id"] - - -def test_create_equipment(client: TestClient) -> None: - scenario_id = _create_scenario(client) - payload = { - "scenario_id": scenario_id, - "name": "Excavator", - "description": "Heavy machinery", - } - - response = client.post("/api/equipment/", json=payload) - assert response.status_code == 200 - created = response.json() - assert created["id"] > 0 - assert created["scenario_id"] == scenario_id - assert created["name"] == "Excavator" - assert created["description"] == "Heavy machinery" - - -def test_list_equipment_filters_by_scenario(client: TestClient) -> None: - target_scenario = 
_create_scenario(client) - other_scenario = _create_scenario(client) - - for scenario_id, name in [ - (target_scenario, "Bulldozer"), - (target_scenario, "Loader"), - (other_scenario, "Conveyor"), - ]: - response = client.post( - "/api/equipment/", - json={ - "scenario_id": scenario_id, - "name": name, - "description": f"Equipment {name}", - }, - ) - assert response.status_code == 200 - - list_response = client.get("/api/equipment/") - assert list_response.status_code == 200 - items = [ - item - for item in list_response.json() - if item["scenario_id"] == target_scenario - ] - assert {item["name"] for item in items} == {"Bulldozer", "Loader"} - - -def test_create_equipment_requires_name(client: TestClient) -> None: - scenario_id = _create_scenario(client) - response = client.post( - "/api/equipment/", - json={ - "scenario_id": scenario_id, - "description": "Missing name", - }, - ) - assert response.status_code == 422 diff --git a/tests/unit/test_maintenance.py b/tests/unit/test_maintenance.py deleted file mode 100644 index 64e646c..0000000 --- a/tests/unit/test_maintenance.py +++ /dev/null @@ -1,125 +0,0 @@ -from uuid import uuid4 - -import pytest - -from fastapi.testclient import TestClient - - -@pytest.fixture -def client(api_client: TestClient) -> TestClient: - return api_client - - -def _create_scenario_and_equipment(client: TestClient): - scenario_payload = { - "name": f"Test Scenario {uuid4()}", - "description": "Scenario for maintenance tests", - } - scenario_response = client.post("/api/scenarios/", json=scenario_payload) - assert scenario_response.status_code == 200 - scenario_id = scenario_response.json()["id"] - - equipment_payload = { - "scenario_id": scenario_id, - "name": f"Test Equipment {uuid4()}", - "description": "Equipment linked to maintenance", - } - equipment_response = client.post("/api/equipment/", json=equipment_payload) - assert equipment_response.status_code == 200 - equipment_id = equipment_response.json()["id"] - return scenario_id, equipment_id - - -def _create_maintenance_payload( - equipment_id: int, scenario_id: int, description: str -): - return { - "equipment_id": equipment_id, - "scenario_id": scenario_id, - "maintenance_date": "2025-10-20", - "description": description, - "cost": 100.0, - } - - -def test_create_and_list_maintenance(client: TestClient): - scenario_id, equipment_id = _create_scenario_and_equipment(client) - payload = _create_maintenance_payload( - equipment_id, scenario_id, "Create maintenance" - ) - - response = client.post("/api/maintenance/", json=payload) - assert response.status_code == 201 - created = response.json() - assert created["equipment_id"] == equipment_id - assert created["scenario_id"] == scenario_id - assert created["description"] == "Create maintenance" - - list_response = client.get("/api/maintenance/") - assert list_response.status_code == 200 - items = list_response.json() - assert any(item["id"] == created["id"] for item in items) - - -def test_get_maintenance(client: TestClient): - scenario_id, equipment_id = _create_scenario_and_equipment(client) - payload = _create_maintenance_payload( - equipment_id, scenario_id, "Retrieve maintenance" - ) - create_response = client.post("/api/maintenance/", json=payload) - assert create_response.status_code == 201 - maintenance_id = create_response.json()["id"] - - response = client.get(f"/api/maintenance/{maintenance_id}") - assert response.status_code == 200 - data = response.json() - assert data["id"] == maintenance_id - assert data["equipment_id"] == equipment_id - assert 
data["description"] == "Retrieve maintenance" - - -def test_update_maintenance(client: TestClient): - scenario_id, equipment_id = _create_scenario_and_equipment(client) - create_response = client.post( - "/api/maintenance/", - json=_create_maintenance_payload( - equipment_id, scenario_id, "Maintenance before update" - ), - ) - assert create_response.status_code == 201 - maintenance_id = create_response.json()["id"] - - update_payload = { - "equipment_id": equipment_id, - "scenario_id": scenario_id, - "maintenance_date": "2025-11-01", - "description": "Maintenance after update", - "cost": 250.0, - } - - response = client.put( - f"/api/maintenance/{maintenance_id}", json=update_payload - ) - assert response.status_code == 200 - updated = response.json() - assert updated["maintenance_date"] == "2025-11-01" - assert updated["description"] == "Maintenance after update" - assert updated["cost"] == 250.0 - - -def test_delete_maintenance(client: TestClient): - scenario_id, equipment_id = _create_scenario_and_equipment(client) - create_response = client.post( - "/api/maintenance/", - json=_create_maintenance_payload( - equipment_id, scenario_id, "Delete maintenance" - ), - ) - assert create_response.status_code == 201 - maintenance_id = create_response.json()["id"] - - delete_response = client.delete(f"/api/maintenance/{maintenance_id}") - assert delete_response.status_code == 204 - - get_response = client.get(f"/api/maintenance/{maintenance_id}") - assert get_response.status_code == 404 diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py deleted file mode 100644 index e1895e7..0000000 --- a/tests/unit/test_parameters.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import Any, Dict, List -from uuid import uuid4 - -from fastapi.testclient import TestClient - -from config.database import Base, engine -from main import app - - -def setup_module(module: object) -> None: - Base.metadata.create_all(bind=engine) - - -def teardown_module(module: object) -> None: - Base.metadata.drop_all(bind=engine) - - -def _create_scenario(name: str | None = None) -> int: - payload: Dict[str, Any] = { - "name": name or f"ParamScenario-{uuid4()}", - "description": "Parameter test scenario", - } - response = TestClient(app).post("/api/scenarios/", json=payload) - assert response.status_code == 200 - return response.json()["id"] - - -def _create_distribution() -> int: - payload: Dict[str, Any] = { - "name": f"NormalDist-{uuid4()}", - "distribution_type": "normal", - "parameters": {"mu": 10, "sigma": 2}, - } - response = TestClient(app).post("/api/distributions/", json=payload) - assert response.status_code == 200 - return response.json()["id"] - - -client = TestClient(app) - - -def test_create_and_list_parameter(): - scenario_id = _create_scenario() - distribution_id = _create_distribution() - parameter_payload: Dict[str, Any] = { - "scenario_id": scenario_id, - "name": f"param-{uuid4()}", - "value": 3.14, - "distribution_id": distribution_id, - } - - create_response = client.post("/api/parameters/", json=parameter_payload) - assert create_response.status_code == 200 - created = create_response.json() - assert created["scenario_id"] == scenario_id - assert created["name"] == parameter_payload["name"] - assert created["value"] == parameter_payload["value"] - assert created["distribution_id"] == distribution_id - assert created["distribution_type"] == "normal" - assert created["distribution_parameters"] == {"mu": 10, "sigma": 2} - - list_response = client.get("/api/parameters/") - assert 
list_response.status_code == 200 - params = list_response.json() - assert any(p["id"] == created["id"] for p in params) - - -def test_create_parameter_for_missing_scenario(): - payload: Dict[str, Any] = { - "scenario_id": 0, - "name": "invalid", - "value": 1.0, - } - response = client.post("/api/parameters/", json=payload) - assert response.status_code == 404 - assert response.json()["detail"] == "Scenario not found" - - -def test_multiple_parameters_listed(): - scenario_id = _create_scenario() - payloads: List[Dict[str, Any]] = [ - {"scenario_id": scenario_id, "name": f"alpha-{i}", "value": float(i)} - for i in range(2) - ] - - for payload in payloads: - resp = client.post("/api/parameters/", json=payload) - assert resp.status_code == 200 - - list_response = client.get("/api/parameters/") - assert list_response.status_code == 200 - names = {item["name"] for item in list_response.json()} - for payload in payloads: - assert payload["name"] in names - - -def test_parameter_inline_distribution_metadata(): - scenario_id = _create_scenario() - payload: Dict[str, Any] = { - "scenario_id": scenario_id, - "name": "inline-param", - "value": 7.5, - "distribution_type": "uniform", - "distribution_parameters": {"min": 5, "max": 10}, - } - - response = client.post("/api/parameters/", json=payload) - assert response.status_code == 200 - created = response.json() - assert created["distribution_id"] is None - assert created["distribution_type"] == "uniform" - assert created["distribution_parameters"] == {"min": 5, "max": 10} - - -def test_parameter_with_missing_distribution_reference(): - scenario_id = _create_scenario() - payload: Dict[str, Any] = { - "scenario_id": scenario_id, - "name": "missing-dist", - "value": 1.0, - "distribution_id": 9999, - } - - response = client.post("/api/parameters/", json=payload) - assert response.status_code == 404 - assert response.json()["detail"] == "Distribution not found" diff --git a/tests/unit/test_production.py b/tests/unit/test_production.py deleted file mode 100644 index 106721d..0000000 --- a/tests/unit/test_production.py +++ /dev/null @@ -1,82 +0,0 @@ -from uuid import uuid4 - -import pytest -from fastapi.testclient import TestClient - - -@pytest.fixture -def client(api_client: TestClient) -> TestClient: - return api_client - - -def _create_scenario(client: TestClient) -> int: - payload = { - "name": f"Production Scenario {uuid4()}", - "description": "Scenario for production tests", - } - response = client.post("/api/scenarios/", json=payload) - assert response.status_code == 200 - return response.json()["id"] - - -def test_create_production_record(client: TestClient) -> None: - scenario_id = _create_scenario(client) - payload: dict[str, any] = { - "scenario_id": scenario_id, - "amount": 475.25, - "description": "Daily output", - "unit_name": "Tonnes", - "unit_symbol": "t", - } - - response = client.post("/api/production/", json=payload) - assert response.status_code == 201 - created = response.json() - assert created["scenario_id"] == scenario_id - assert created["amount"] == pytest.approx(475.25) - assert created["description"] == "Daily output" - assert created["unit_symbol"] == "t" - - -def test_list_production_filters_by_scenario(client: TestClient) -> None: - target_scenario = _create_scenario(client) - other_scenario = _create_scenario(client) - - for scenario_id, amount in [ - (target_scenario, 100.0), - (target_scenario, 150.0), - (other_scenario, 200.0), - ]: - response = client.post( - "/api/production/", - json={ - "scenario_id": scenario_id, - 
"amount": amount, - "description": f"Output {amount}", - "unit_name": "Kilograms", - "unit_symbol": "kg", - }, - ) - assert response.status_code == 201 - - list_response = client.get("/api/production/") - assert list_response.status_code == 200 - items = [ - item - for item in list_response.json() - if item["scenario_id"] == target_scenario - ] - assert {item["amount"] for item in items} == {100.0, 150.0} - - -def test_create_production_rejects_negative_amount(client: TestClient) -> None: - scenario_id = _create_scenario(client) - response = client.post( - "/api/production/", - json={ - "scenario_id": scenario_id, - "amount": -5, - "description": "Invalid output", - }, - ) - assert response.status_code == 422 diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py deleted file mode 100644 index 45adf38..0000000 --- a/tests/unit/test_reporting.py +++ /dev/null @@ -1,123 +0,0 @@ -import math -from typing import Any, Dict, List - -import pytest - -from fastapi.testclient import TestClient - -from services.reporting import generate_report - - -def test_generate_report_empty(): - report = generate_report([]) - assert report == { - "count": 0, - "mean": 0.0, - "median": 0.0, - "min": 0.0, - "max": 0.0, - "std_dev": 0.0, - "variance": 0.0, - "percentile_10": 0.0, - "percentile_90": 0.0, - "percentile_5": 0.0, - "percentile_95": 0.0, - "value_at_risk_95": 0.0, - "expected_shortfall_95": 0.0, - } - - -def test_generate_report_with_values(): - values: List[Dict[str, float]] = [ - {"iteration": 1, "result": 10.0}, - {"iteration": 2, "result": 20.0}, - {"iteration": 3, "result": 30.0}, - ] - report = generate_report(values) - assert report["count"] == 3 - assert math.isclose(float(report["mean"]), 20.0) - assert math.isclose(float(report["median"]), 20.0) - assert math.isclose(float(report["min"]), 10.0) - assert math.isclose(float(report["max"]), 30.0) - assert math.isclose(float(report["std_dev"]), 8.1649658, rel_tol=1e-6) - assert math.isclose(float(report["variance"]), 66.6666666, rel_tol=1e-6) - assert math.isclose(float(report["percentile_10"]), 12.0) - assert math.isclose(float(report["percentile_90"]), 28.0) - assert math.isclose(float(report["percentile_5"]), 11.0) - assert math.isclose(float(report["percentile_95"]), 29.0) - assert math.isclose(float(report["value_at_risk_95"]), 11.0) - assert math.isclose(float(report["expected_shortfall_95"]), 10.0) - - -def test_generate_report_single_value(): - report = generate_report( - [ - {"iteration": 1, "result": 42.0}, - ] - ) - assert report["count"] == 1 - assert report["std_dev"] == 0.0 - assert report["variance"] == 0.0 - assert report["percentile_10"] == 42.0 - assert report["expected_shortfall_95"] == 42.0 - - -def test_generate_report_ignores_invalid_entries(): - raw_values: List[Any] = [ - {"iteration": 1, "result": 10.0}, - "not-a-mapping", - {"iteration": 2}, - {"iteration": 3, "result": None}, - {"iteration": 4, "result": 20}, - ] - report = generate_report(raw_values) - assert report["count"] == 2 - assert math.isclose(float(report["mean"]), 15.0) - assert math.isclose(float(report["min"]), 10.0) - assert math.isclose(float(report["max"]), 20.0) - - -@pytest.fixture -def client(api_client: TestClient) -> TestClient: - return api_client - - -def test_reporting_endpoint_invalid_input(client: TestClient): - resp = client.post("/api/reporting/summary", json={}) - assert resp.status_code == 400 - assert resp.json()["detail"] == "Invalid input format" - - -def test_reporting_endpoint_success(client: TestClient): - 
input_data: List[Dict[str, float]] = [
-        {"iteration": 1, "result": 10.0},
-        {"iteration": 2, "result": 20.0},
-        {"iteration": 3, "result": 30.0},
-    ]
-    resp = client.post("/api/reporting/summary", json=input_data)
-    assert resp.status_code == 200
-    data: Dict[str, Any] = resp.json()
-    assert data["count"] == 3
-    assert math.isclose(float(data["mean"]), 20.0)
-    assert math.isclose(float(data["variance"]), 66.6666666, rel_tol=1e-6)
-    assert math.isclose(float(data["value_at_risk_95"]), 11.0)
-    assert math.isclose(float(data["expected_shortfall_95"]), 10.0)
-
-
-validation_error_cases: List[tuple[List[Any], str]] = [
-    (["not-a-dict"], "Entry at index 0 must be an object"),
-    ([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"),
-    (
-        [{"iteration": 1, "result": "bad"}],
-        "Entry at index 0 must include numeric 'result'",
-    ),
-]
-
-
-@pytest.mark.parametrize("payload,expected_detail", validation_error_cases)
-def test_reporting_endpoint_validation_errors(
-    client: TestClient, payload: List[Any], expected_detail: str
-):
-    resp = client.post("/api/reporting/summary", json=payload)
-    assert resp.status_code == 400
-    assert resp.json()["detail"] == expected_detail
diff --git a/tests/unit/test_router_validation.py b/tests/unit/test_router_validation.py
deleted file mode 100644
index 4c81b73..0000000
--- a/tests/unit/test_router_validation.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from typing import Any, Dict
-
-import pytest
-from fastapi.testclient import TestClient
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_duplicate_scenario_returns_400(
-    api_client: TestClient, invalid_request_payloads: Dict[str, Any]
-) -> None:
-    payload = invalid_request_payloads["scenario_duplicate"]
-    response = api_client.post("/api/scenarios/", json=payload)
-    assert response.status_code == 400
-    body = response.json()
-    assert body["detail"] == "Scenario already exists"
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_parameter_create_missing_scenario_returns_404(
-    api_client: TestClient, invalid_request_payloads: Dict[str, Any]
-) -> None:
-    payload = invalid_request_payloads["parameter_missing_scenario"]
-    response = api_client.post("/api/parameters/", json=payload)
-    assert response.status_code == 404
-    assert response.json()["detail"] == "Scenario not found"
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_parameter_create_invalid_distribution_is_422(
-    api_client: TestClient,
-) -> None:
-    response = api_client.post(
-        "/api/parameters/",
-        json={
-            "scenario_id": 1,
-            "name": "Bad Dist",
-            "value": 2.0,
-            "distribution_type": "invalid",
-        },
-    )
-    assert response.status_code == 422
-    errors = response.json()["detail"]
-    assert any("distribution_type" in err["loc"] for err in errors)
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_simulation_unknown_scenario_returns_404(
-    api_client: TestClient, invalid_request_payloads: Dict[str, Any]
-) -> None:
-    payload = invalid_request_payloads["simulation_unknown_scenario"]
-    response = api_client.post("/api/simulations/run", json=payload)
-    assert response.status_code == 404
-    assert response.json()["detail"] == "Scenario not found"
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_simulation_missing_parameters_returns_400(
-    api_client: TestClient, invalid_request_payloads: Dict[str, Any]
-) -> None:
-    payload = invalid_request_payloads["simulation_missing_parameters"]
-    response = api_client.post("/api/simulations/run", json=payload)
-    assert response.status_code == 400
-    assert response.json()["detail"] == "No parameters provided"
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_reporting_summary_rejects_non_list_payload(
-    api_client: TestClient, invalid_request_payloads: Dict[str, Any]
-) -> None:
-    payload = invalid_request_payloads["reporting_non_list_payload"]
-    response = api_client.post("/api/reporting/summary", json=payload)
-    assert response.status_code == 400
-    assert response.json()["detail"] == "Invalid input format"
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_reporting_summary_requires_result_field(
-    api_client: TestClient, invalid_request_payloads: Dict[str, Any]
-) -> None:
-    payload = invalid_request_payloads["reporting_missing_result"]
-    response = api_client.post("/api/reporting/summary", json=payload)
-    assert response.status_code == 400
-    assert "must include numeric 'result'" in response.json()["detail"]
-
-
-@pytest.mark.usefixtures("invalid_request_payloads")
-def test_maintenance_negative_cost_rejected_by_schema(
-    api_client: TestClient, invalid_request_payloads: Dict[str, Any]
-) -> None:
-    payload = invalid_request_payloads["maintenance_negative_cost"]
-    response = api_client.post("/api/maintenance/", json=payload)
-    assert response.status_code == 422
-    error_locations = [tuple(item["loc"]) for item in response.json()["detail"]]
-    assert ("body", "cost") in error_locations
diff --git a/tests/unit/test_scenario.py b/tests/unit/test_scenario.py
deleted file mode 100644
index fce4a28..0000000
--- a/tests/unit/test_scenario.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from uuid import uuid4
-
-from fastapi.testclient import TestClient
-
-from config.database import Base, engine
-from main import app
-
-
-def setup_module(module):
-    Base.metadata.create_all(bind=engine)
-
-
-def teardown_module(module):
-    Base.metadata.drop_all(bind=engine)
-
-
-client = TestClient(app)
-
-
-def test_create_and_list_scenario():
-    scenario_name = f"Scenario-{uuid4()}"
-    response = client.post(
-        "/api/scenarios/",
-        json={"name": scenario_name, "description": "Integration test"},
-    )
-    assert response.status_code == 200
-    data = response.json()
-    assert data["name"] == scenario_name
-
-    response2 = client.get("/api/scenarios/")
-    assert response2.status_code == 200
-    data2 = response2.json()
-    assert any(s["name"] == scenario_name for s in data2)
-
-
-def test_create_duplicate_scenario_rejected():
-    scenario_name = f"Duplicate-{uuid4()}"
-    payload = {"name": scenario_name, "description": "Primary"}
-
-    first_resp = client.post("/api/scenarios/", json=payload)
-    assert first_resp.status_code == 200
-
-    second_resp = client.post("/api/scenarios/", json=payload)
-    assert second_resp.status_code == 400
-    assert second_resp.json()["detail"] == "Scenario already exists"
diff --git a/tests/unit/test_seed_data.py b/tests/unit/test_seed_data.py
deleted file mode 100644
index 87094b9..0000000
--- a/tests/unit/test_seed_data.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import argparse
-from unittest import mock
-
-import scripts.seed_data as seed_data
-from scripts.seed_data import DatabaseConfig
-
-
-def test_run_with_namespace_handles_missing_theme_flag_without_actions() -> None:
-    args = argparse.Namespace(currencies=False, units=False, defaults=False)
-    config = mock.create_autospec(DatabaseConfig)
-    config.application_dsn.return_value = "postgresql://example"
-
-    with (
-        mock.patch("scripts.seed_data._configure_logging") as configure_logging,
-        mock.patch("scripts.seed_data.psycopg2.connect") as connect_mock,
-        mock.patch.object(seed_data.logger, "info") as info_mock,
-    ):
-        seed_data.run_with_namespace(args, config=config)
-
-    configure_logging.assert_called_once()
-    connect_mock.assert_not_called()
-    info_mock.assert_called_with("No seeding options provided; exiting")
-
-
-def test_run_with_namespace_seeds_defaults_without_theme_flag() -> None:
-    args = argparse.Namespace(
-        currencies=False, units=False, defaults=True, dry_run=False)
-    config = mock.create_autospec(DatabaseConfig)
-    config.application_dsn.return_value = "postgresql://example"
-
-    connection_mock = mock.MagicMock()
-    cursor_context = mock.MagicMock()
-    cursor_mock = mock.MagicMock()
-    connection_mock.__enter__.return_value = connection_mock
-    connection_mock.cursor.return_value = cursor_context
-    cursor_context.__enter__.return_value = cursor_mock
-
-    with (
-        mock.patch("scripts.seed_data._configure_logging"),
-        mock.patch("scripts.seed_data.psycopg2.connect", return_value=connection_mock) as connect_mock,
-        mock.patch("scripts.seed_data._seed_defaults") as seed_defaults,
-    ):
-        seed_data.run_with_namespace(args, config=config)
-
-    connect_mock.assert_called_once_with(config.application_dsn())
-    seed_defaults.assert_called_once_with(cursor_mock, dry_run=False)
diff --git a/tests/unit/test_settings_routes.py b/tests/unit/test_settings_routes.py
deleted file mode 100644
index 81a1aa9..0000000
--- a/tests/unit/test_settings_routes.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import pytest
-from fastapi.testclient import TestClient
-from sqlalchemy.orm import Session
-
-from services import settings as settings_service
-
-
-@pytest.mark.usefixtures("db_session")
-def test_read_css_settings_reflects_env_overrides(
-    api_client: TestClient, monkeypatch: pytest.MonkeyPatch
-) -> None:
-    env_var = settings_service.css_key_to_env_var("--color-background")
-    monkeypatch.setenv(env_var, "#123456")
-
-    response = api_client.get("/api/settings/css")
-    assert response.status_code == 200
-    body = response.json()
-
-    assert body["variables"]["--color-background"] == "#123456"
-    assert body["env_overrides"]["--color-background"] == "#123456"
-    assert any(
-        source["env_var"] == env_var and source["value"] == "#123456"
-        for source in body["env_sources"]
-    )
-
-
-@pytest.mark.usefixtures("db_session")
-def test_update_css_settings_persists_changes(
-    api_client: TestClient, db_session: Session
-) -> None:
-    payload = {"variables": {"--color-primary": "#112233"}}
-
-    response = api_client.put("/api/settings/css", json=payload)
-    assert response.status_code == 200
-    body = response.json()
-
-    assert body["variables"]["--color-primary"] == "#112233"
-
-    persisted = settings_service.get_css_color_settings(db_session)
-    assert persisted["--color-primary"] == "#112233"
-
-
-@pytest.mark.usefixtures("db_session")
-def test_update_css_settings_invalid_value_returns_422(
-    api_client: TestClient,
-) -> None:
-    response = api_client.put(
-        "/api/settings/css",
-        json={"variables": {"--color-primary": "not-a-color"}},
-    )
-    assert response.status_code == 422
-    body = response.json()
-    assert "color" in body["detail"].lower()
diff --git a/tests/unit/test_settings_service.py b/tests/unit/test_settings_service.py
deleted file mode 100644
index 8066c06..0000000
--- a/tests/unit/test_settings_service.py
+++ /dev/null
@@ -1,149 +0,0 @@
-from types import SimpleNamespace
-from typing import Dict
-
-import pytest
-
-from sqlalchemy.orm import Session
-
-from models.application_setting import ApplicationSetting
-from services import settings as
settings_service -from services.settings import CSS_COLOR_DEFAULTS - - -@pytest.fixture(name="clean_env") -def fixture_clean_env(monkeypatch: pytest.MonkeyPatch) -> Dict[str, str]: - """Provide an isolated environment mapping for tests.""" - - env: Dict[str, str] = {} - monkeypatch.setattr(settings_service, "os", SimpleNamespace(environ=env)) - return env - - -def test_css_key_to_env_var_formatting(): - assert ( - settings_service.css_key_to_env_var("--color-background") - == "CALMINER_THEME_COLOR_BACKGROUND" - ) - assert ( - settings_service.css_key_to_env_var("--color-primary-stronger") - == "CALMINER_THEME_COLOR_PRIMARY_STRONGER" - ) - - -@pytest.mark.parametrize( - "env_key,env_value", - [ - ("--color-background", "#ffffff"), - ("--color-primary", "rgb(10, 20, 30)"), - ("--color-accent", "rgba(1,2,3,0.5)"), - ("--color-text-secondary", "hsla(210, 40%, 40%, 1)"), - ], -) -def test_read_css_color_env_overrides_valid_values( - clean_env, env_key, env_value -): - env_var = settings_service.css_key_to_env_var(env_key) - clean_env[env_var] = env_value - - overrides = settings_service.read_css_color_env_overrides(clean_env) - assert overrides[env_key] == env_value - - -@pytest.mark.parametrize( - "invalid_value", - [ - "", # empty - "not-a-color", # arbitrary string - "#12", # short hex - "rgb(1,2)", # malformed rgb - ], -) -def test_read_css_color_env_overrides_invalid_values_raise( - clean_env, invalid_value -): - env_var = settings_service.css_key_to_env_var("--color-background") - clean_env[env_var] = invalid_value - - with pytest.raises(ValueError): - settings_service.read_css_color_env_overrides(clean_env) - - -def test_read_css_color_env_overrides_ignores_missing(clean_env): - overrides = settings_service.read_css_color_env_overrides(clean_env) - assert overrides == {} - - -def test_list_css_env_override_rows_returns_structured_data(clean_env): - clean_env[settings_service.css_key_to_env_var("--color-primary")] = ( - "#123456" - ) - rows = settings_service.list_css_env_override_rows(clean_env) - assert rows == [ - { - "css_key": "--color-primary", - "env_var": settings_service.css_key_to_env_var("--color-primary"), - "value": "#123456", - } - ] - - -def test_normalize_color_value_strips_and_validates(): - assert settings_service._normalize_color_value(" #abcdef ") == "#abcdef" - with pytest.raises(ValueError): - settings_service._normalize_color_value(123) # type: ignore[arg-type] - with pytest.raises(ValueError): - settings_service._normalize_color_value(" ") - with pytest.raises(ValueError): - settings_service._normalize_color_value("#12") - - -def test_ensure_css_color_settings_creates_defaults(db_session: Session): - settings_service.ensure_css_color_settings(db_session) - - stored = { - record.key: record.value - for record in db_session.query(ApplicationSetting).all() - } - assert set(stored.keys()) == set(CSS_COLOR_DEFAULTS.keys()) - assert stored == CSS_COLOR_DEFAULTS - - -def test_update_css_color_settings_persists_changes(db_session: Session): - settings_service.ensure_css_color_settings(db_session) - - updated = settings_service.update_css_color_settings( - db_session, - {"--color-background": "#000000", "--color-accent": "#abcdef"}, - ) - - assert updated["--color-background"] == "#000000" - assert updated["--color-accent"] == "#abcdef" - - stored = { - record.key: record.value - for record in db_session.query(ApplicationSetting).all() - } - assert stored["--color-background"] == "#000000" - assert stored["--color-accent"] == "#abcdef" - - -def 
test_get_css_color_settings_respects_env_overrides( - db_session: Session, clean_env: Dict[str, str] -): - settings_service.ensure_css_color_settings(db_session) - override_value = "#112233" - clean_env[settings_service.css_key_to_env_var("--color-background")] = ( - override_value - ) - - values = settings_service.get_css_color_settings(db_session) - - assert values["--color-background"] == override_value - - db_value = ( - db_session.query(ApplicationSetting) - .filter_by(key="--color-background") - .one() - .value - ) - assert db_value != override_value diff --git a/tests/unit/test_setup_database.py b/tests/unit/test_setup_database.py deleted file mode 100644 index efea513..0000000 --- a/tests/unit/test_setup_database.py +++ /dev/null @@ -1,547 +0,0 @@ -import argparse -from unittest import mock - -import psycopg2 -import pytest -from psycopg2 import errors as psycopg_errors - -import scripts.setup_database as setup_db_module - -from scripts import seed_data -from scripts.setup_database import DatabaseConfig, DatabaseSetup - - -@pytest.fixture() -def mock_config() -> DatabaseConfig: - return DatabaseConfig( - driver="postgresql", - host="localhost", - port=5432, - database="calminer_test", - user="calminer", - password="secret", - schema="public", - admin_user="postgres", - admin_password="secret", - ) - - -@pytest.fixture() -def setup_instance(mock_config: DatabaseConfig) -> DatabaseSetup: - return DatabaseSetup(mock_config, dry_run=True) - - -def test_seed_baseline_data_dry_run_skips_verification( - setup_instance: DatabaseSetup, -) -> None: - with ( - mock.patch("scripts.seed_data.run_with_namespace") as seed_run, - mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock, - ): - setup_instance.seed_baseline_data(dry_run=True) - - seed_run.assert_called_once() - namespace_arg = seed_run.call_args[0][0] - assert isinstance(namespace_arg, argparse.Namespace) - assert namespace_arg.dry_run is True - assert namespace_arg.currencies is True - assert namespace_arg.units is True - assert namespace_arg.theme is True - assert seed_run.call_args.kwargs["config"] is setup_instance.config - verify_mock.assert_not_called() - - -def test_seed_baseline_data_invokes_verification( - setup_instance: DatabaseSetup, -) -> None: - expected_currencies = {code for code, *_ in seed_data.CURRENCY_SEEDS} - expected_units = {code for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS} - - with ( - mock.patch("scripts.seed_data.run_with_namespace") as seed_run, - mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock, - ): - setup_instance.seed_baseline_data(dry_run=False) - - seed_run.assert_called_once() - namespace_arg = seed_run.call_args[0][0] - assert isinstance(namespace_arg, argparse.Namespace) - assert namespace_arg.dry_run is False - assert seed_run.call_args.kwargs["config"] is setup_instance.config - assert namespace_arg.theme is True - verify_mock.assert_called_once_with( - expected_currency_codes=expected_currencies, - expected_unit_codes=expected_units, - ) - - -def test_run_migrations_applies_baseline_when_missing( - mock_config: DatabaseConfig, tmp_path -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - - baseline = tmp_path / "000_base.sql" - baseline.write_text("SELECT 1;", encoding="utf-8") - other_migration = tmp_path / "20251022_add_other.sql" - other_migration.write_text("SELECT 2;", encoding="utf-8") - - migration_calls: list[str] = [] - - def capture_migration(cursor, schema_name: str, path): - migration_calls.append(path.name) - return 
path.name - - connection_mock = mock.MagicMock() - connection_mock.__enter__.return_value = connection_mock - cursor_context = mock.MagicMock() - cursor_mock = mock.MagicMock() - cursor_context.__enter__.return_value = cursor_mock - connection_mock.cursor.return_value = cursor_context - - with ( - mock.patch.object( - setup_instance, - "_application_connection", - return_value=connection_mock, - ), - mock.patch.object( - setup_instance, "_migrations_table_exists", return_value=True - ), - mock.patch.object( - setup_instance, "_fetch_applied_migrations", return_value=set() - ), - mock.patch.object( - setup_instance, - "_apply_migration_file", - side_effect=capture_migration, - ) as apply_mock, - ): - setup_instance.run_migrations(tmp_path) - - assert apply_mock.call_count == 1 - assert migration_calls == ["000_base.sql"] - legacy_marked = any( - call.args[1] == ("20251022_add_other.sql",) - for call in cursor_mock.execute.call_args_list - if len(call.args) == 2 - ) - assert legacy_marked - - -def test_run_migrations_noop_when_all_files_already_applied( - mock_config: DatabaseConfig, tmp_path -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - - baseline = tmp_path / "000_base.sql" - baseline.write_text("SELECT 1;", encoding="utf-8") - other_migration = tmp_path / "20251022_add_other.sql" - other_migration.write_text("SELECT 2;", encoding="utf-8") - - connection_mock, cursor_mock = _connection_with_cursor() - - with ( - mock.patch.object( - setup_instance, - "_application_connection", - return_value=connection_mock, - ), - mock.patch.object( - setup_instance, "_migrations_table_exists", return_value=True - ), - mock.patch.object( - setup_instance, - "_fetch_applied_migrations", - return_value={"000_base.sql", "20251022_add_other.sql"}, - ), - mock.patch.object( - setup_instance, "_apply_migration_file" - ) as apply_mock, - ): - setup_instance.run_migrations(tmp_path) - - apply_mock.assert_not_called() - cursor_mock.execute.assert_not_called() - - -def _connection_with_cursor() -> tuple[mock.MagicMock, mock.MagicMock]: - connection_mock = mock.MagicMock() - connection_mock.__enter__.return_value = connection_mock - cursor_context = mock.MagicMock() - cursor_mock = mock.MagicMock() - cursor_context.__enter__.return_value = cursor_mock - connection_mock.cursor.return_value = cursor_context - return connection_mock, cursor_mock - - -def test_verify_seeded_data_raises_when_currency_missing( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - connection_mock, cursor_mock = _connection_with_cursor() - cursor_mock.fetchall.return_value = [("USD", True)] - - with mock.patch.object( - setup_instance, "_application_connection", return_value=connection_mock - ): - with pytest.raises(RuntimeError) as exc: - setup_instance._verify_seeded_data( - expected_currency_codes={"USD", "EUR"}, - expected_unit_codes=set(), - ) - - assert "EUR" in str(exc.value) - - -def test_verify_seeded_data_raises_when_default_currency_inactive( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - connection_mock, cursor_mock = _connection_with_cursor() - cursor_mock.fetchall.return_value = [("USD", False)] - - with mock.patch.object( - setup_instance, "_application_connection", return_value=connection_mock - ): - with pytest.raises(RuntimeError) as exc: - setup_instance._verify_seeded_data( - expected_currency_codes={"USD"}, - expected_unit_codes=set(), - ) - - assert "inactive" in str(exc.value) - - 
-def test_verify_seeded_data_raises_when_units_missing( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - connection_mock, cursor_mock = _connection_with_cursor() - cursor_mock.fetchall.return_value = [("tonnes", True)] - - with mock.patch.object( - setup_instance, "_application_connection", return_value=connection_mock - ): - with pytest.raises(RuntimeError) as exc: - setup_instance._verify_seeded_data( - expected_currency_codes=set(), - expected_unit_codes={"tonnes", "liters"}, - ) - - assert "liters" in str(exc.value) - - -def test_verify_seeded_data_raises_when_measurement_table_missing( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - connection_mock, cursor_mock = _connection_with_cursor() - cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable( - "relation does not exist" - ) - - with mock.patch.object( - setup_instance, "_application_connection", return_value=connection_mock - ): - with pytest.raises(RuntimeError) as exc: - setup_instance._verify_seeded_data( - expected_currency_codes=set(), - expected_unit_codes={"tonnes"}, - ) - - assert "measurement_unit" in str(exc.value) - connection_mock.rollback.assert_called_once() - - -def test_seed_baseline_data_rerun_uses_existing_records( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - - connection_mock, cursor_mock = _connection_with_cursor() - - currency_rows = [(code, True) for code, *_ in seed_data.CURRENCY_SEEDS] - unit_rows = [(code, True) for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS] - - cursor_mock.fetchall.side_effect = [ - currency_rows, - unit_rows, - currency_rows, - unit_rows, - ] - - with ( - mock.patch.object( - setup_instance, - "_application_connection", - return_value=connection_mock, - ), - mock.patch("scripts.seed_data.run_with_namespace") as seed_run, - ): - setup_instance.seed_baseline_data(dry_run=False) - setup_instance.seed_baseline_data(dry_run=False) - - assert seed_run.call_count == 2 - first_namespace = seed_run.call_args_list[0].args[0] - assert isinstance(first_namespace, argparse.Namespace) - assert first_namespace.dry_run is False - assert seed_run.call_args_list[0].kwargs["config"] is setup_instance.config - assert cursor_mock.execute.call_count == 4 - - -def test_ensure_database_raises_with_context( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - connection_mock = mock.MagicMock() - cursor_mock = mock.MagicMock() - cursor_mock.fetchone.return_value = None - cursor_mock.execute.side_effect = [None, psycopg2.Error("create_fail")] - connection_mock.cursor.return_value = cursor_mock - - with mock.patch.object( - setup_instance, "_admin_connection", return_value=connection_mock - ): - with pytest.raises(RuntimeError) as exc: - setup_instance.ensure_database() - - assert "Failed to create database" in str(exc.value) - - -def test_ensure_role_raises_with_context_during_creation( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - - admin_conn, admin_cursor = _connection_with_cursor() - admin_cursor.fetchone.return_value = None - admin_cursor.execute.side_effect = [None, psycopg2.Error("role_fail")] - - with mock.patch.object( - setup_instance, - "_admin_connection", - side_effect=[admin_conn], - ): - with pytest.raises(RuntimeError) as exc: - setup_instance.ensure_role() - - assert "Failed to create role" in 
str(exc.value) - - -def test_ensure_role_raises_with_context_during_privilege_grants( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - - admin_conn, admin_cursor = _connection_with_cursor() - admin_cursor.fetchone.return_value = (1,) - - privilege_conn, privilege_cursor = _connection_with_cursor() - privilege_cursor.execute.side_effect = [psycopg2.Error("grant_fail")] - - with mock.patch.object( - setup_instance, - "_admin_connection", - side_effect=[admin_conn, privilege_conn], - ): - with pytest.raises(RuntimeError) as exc: - setup_instance.ensure_role() - - assert "Failed to grant privileges" in str(exc.value) - - -def test_ensure_database_dry_run_skips_creation( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=True) - - connection_mock = mock.MagicMock() - cursor_mock = mock.MagicMock() - cursor_mock.fetchone.return_value = None - connection_mock.cursor.return_value = cursor_mock - - with ( - mock.patch.object( - setup_instance, "_admin_connection", return_value=connection_mock - ), - mock.patch("scripts.setup_database.logger") as logger_mock, - ): - setup_instance.ensure_database() - - # expect only existence check, no create attempt - cursor_mock.execute.assert_called_once() - logger_mock.info.assert_any_call( - "Dry run: would create database '%s'. Run without --dry-run to proceed.", - mock_config.database, - ) - - -def test_ensure_role_dry_run_skips_creation_and_grants( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=True) - - admin_conn, admin_cursor = _connection_with_cursor() - admin_cursor.fetchone.return_value = None - - with ( - mock.patch.object( - setup_instance, - "_admin_connection", - side_effect=[admin_conn], - ) as conn_mock, - mock.patch("scripts.setup_database.logger") as logger_mock, - ): - setup_instance.ensure_role() - - assert conn_mock.call_count == 1 - admin_cursor.execute.assert_called_once() - logger_mock.info.assert_any_call( - "Dry run: would create role '%s'. 
Run without --dry-run to apply.", - mock_config.user, - ) - - -def test_register_rollback_skipped_when_dry_run( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=True) - setup_instance._register_rollback("noop", lambda: None) - assert setup_instance._rollback_actions == [] - - -def test_execute_rollbacks_runs_in_reverse_order( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - - calls: list[str] = [] - - def first_action() -> None: - calls.append("first") - - def second_action() -> None: - calls.append("second") - - setup_instance._register_rollback("first", first_action) - setup_instance._register_rollback("second", second_action) - - with mock.patch("scripts.setup_database.logger"): - setup_instance.execute_rollbacks() - - assert calls == ["second", "first"] - assert setup_instance._rollback_actions == [] - - -def test_ensure_database_registers_rollback_action( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - connection_mock = mock.MagicMock() - cursor_mock = mock.MagicMock() - cursor_mock.fetchone.return_value = None - connection_mock.cursor.return_value = cursor_mock - - with ( - mock.patch.object( - setup_instance, "_admin_connection", return_value=connection_mock - ), - mock.patch.object( - setup_instance, "_register_rollback" - ) as register_mock, - mock.patch.object(setup_instance, "_drop_database") as drop_mock, - ): - setup_instance.ensure_database() - register_mock.assert_called_once() - label, action = register_mock.call_args[0] - assert "drop database" in label - action() - drop_mock.assert_called_once_with(mock_config.database) - - -def test_ensure_role_registers_rollback_actions( - mock_config: DatabaseConfig, -) -> None: - setup_instance = DatabaseSetup(mock_config, dry_run=False) - - admin_conn, admin_cursor = _connection_with_cursor() - admin_cursor.fetchone.return_value = None - privilege_conn, privilege_cursor = _connection_with_cursor() - - with ( - mock.patch.object( - setup_instance, - "_admin_connection", - side_effect=[admin_conn, privilege_conn], - ), - mock.patch.object( - setup_instance, "_register_rollback" - ) as register_mock, - mock.patch.object(setup_instance, "_drop_role") as drop_mock, - mock.patch.object( - setup_instance, "_revoke_role_privileges" - ) as revoke_mock, - ): - setup_instance.ensure_role() - assert register_mock.call_count == 2 - drop_label, drop_action = register_mock.call_args_list[0][0] - revoke_label, revoke_action = register_mock.call_args_list[1][0] - - assert "drop role" in drop_label - assert "revoke privileges" in revoke_label - - drop_action() - drop_mock.assert_called_once_with(mock_config.user) - - revoke_action() - revoke_mock.assert_called_once() - - -def test_main_triggers_rollbacks_on_failure( - mock_config: DatabaseConfig, -) -> None: - args = argparse.Namespace( - ensure_database=True, - ensure_role=True, - ensure_schema=False, - initialize_schema=False, - run_migrations=False, - seed_data=False, - migrations_dir=None, - db_driver=None, - db_host=None, - db_port=None, - db_name=None, - db_user=None, - db_password=None, - db_schema=None, - admin_url=None, - admin_user=None, - admin_password=None, - admin_db=None, - dry_run=False, - verbose=0, - ) - - with ( - mock.patch.object(setup_db_module, "parse_args", return_value=args), - mock.patch.object( - setup_db_module.DatabaseConfig, "from_env", return_value=mock_config - ), - mock.patch.object(setup_db_module, 
"DatabaseSetup") as setup_cls, - ): - setup_instance = mock.MagicMock() - setup_instance.dry_run = False - setup_instance._rollback_actions = [ - ("drop role", mock.MagicMock()), - ] - setup_instance.ensure_database.side_effect = RuntimeError("boom") - setup_instance.execute_rollbacks = mock.MagicMock() - setup_instance.clear_rollbacks = mock.MagicMock() - setup_cls.return_value = setup_instance - - with pytest.raises(RuntimeError): - setup_db_module.main() - - setup_instance.execute_rollbacks.assert_called_once() - setup_instance.clear_rollbacks.assert_called_once() diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py deleted file mode 100644 index b89febe..0000000 --- a/tests/unit/test_simulation.py +++ /dev/null @@ -1,232 +0,0 @@ -from math import isclose -from random import Random -from uuid import uuid4 - -import pytest -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session - -from typing import Any, Dict, List - -from models.simulation_result import SimulationResult -from services.simulation import DEFAULT_UNIFORM_SPAN_RATIO, run_simulation - - -@pytest.fixture -def client(api_client: TestClient) -> TestClient: - return api_client - - -def test_run_simulation_function_generates_samples(): - params: List[Dict[str, Any]] = [ - { - "name": "grade", - "value": 1.8, - "distribution": "normal", - "std_dev": 0.2, - }, - { - "name": "recovery", - "value": 0.9, - "distribution": "uniform", - "min": 0.8, - "max": 0.95, - }, - ] - results = run_simulation(params, iterations=5, seed=123) - assert isinstance(results, list) - assert len(results) == 5 - assert results[0]["iteration"] == 1 - - -def test_run_simulation_with_zero_iterations_returns_empty(): - params: List[Dict[str, Any]] = [ - {"name": "grade", "value": 1.2, "distribution": "normal"} - ] - results = run_simulation(params, iterations=0) - assert results == [] - - -@pytest.mark.parametrize( - "parameter_payload,error_message", - [ - ( - {"name": "missing-value"}, - "Parameter at index 0 must include 'value'", - ), - ( - { - "name": "bad-dist", - "value": 1.0, - "distribution": "unsupported", - }, - "Parameter 'bad-dist' has unsupported distribution 'unsupported'", - ), - ( - { - "name": "uniform-range", - "value": 1.0, - "distribution": "uniform", - "min": 5, - "max": 5, - }, - "Parameter 'uniform-range' requires 'min' < 'max' for uniform distribution", - ), - ( - { - "name": "triangular-mode", - "value": 5.0, - "distribution": "triangular", - "min": 1, - "max": 3, - "mode": 5, - }, - "Parameter 'triangular-mode' mode must be within min/max bounds for triangular distribution", - ), - ], -) -def test_run_simulation_parameter_validation_errors( - parameter_payload: Dict[str, Any], error_message: str -) -> None: - with pytest.raises(ValueError) as exc: - run_simulation([parameter_payload]) - assert str(exc.value) == error_message - - -def test_run_simulation_normal_std_dev_fallback(): - params: List[Dict[str, Any]] = [ - { - "name": "std-dev-fallback", - "value": 10.0, - "distribution": "normal", - "std_dev": 0, - } - ] - results = run_simulation(params, iterations=3, seed=99) - assert len(results) == 3 - assert all("result" in entry for entry in results) - - -def test_run_simulation_triangular_sampling_path(): - params: List[Dict[str, Any]] = [ - {"name": "tri", "value": 10.0, "distribution": "triangular"} - ] - seed = 21 - iterations = 4 - results = run_simulation(params, iterations=iterations, seed=seed) - assert len(results) == iterations - span = 10.0 * DEFAULT_UNIFORM_SPAN_RATIO - rng = 
Random(seed) - expected_samples = [ - rng.triangular(10.0 - span, 10.0 + span, 10.0) - for _ in range(iterations) - ] - actual_samples = [entry["result"] for entry in results] - for actual, expected in zip(actual_samples, expected_samples): - assert isclose(actual, expected, rel_tol=1e-9) - - -def test_run_simulation_uniform_defaults_apply_bounds(): - params: List[Dict[str, Any]] = [ - {"name": "uniform-auto", "value": 200.0, "distribution": "uniform"} - ] - seed = 17 - iterations = 3 - results = run_simulation(params, iterations=iterations, seed=seed) - assert len(results) == iterations - span = 200.0 * DEFAULT_UNIFORM_SPAN_RATIO - rng = Random(seed) - expected_samples = [ - rng.uniform(200.0 - span, 200.0 + span) for _ in range(iterations) - ] - actual_samples = [entry["result"] for entry in results] - for actual, expected in zip(actual_samples, expected_samples): - assert isclose(actual, expected, rel_tol=1e-9) - - -def test_run_simulation_without_parameters_returns_empty(): - assert run_simulation([], iterations=5) == [] - - -def test_simulation_endpoint_no_params(client: TestClient): - scenario_payload: Dict[str, Any] = { - "name": f"NoParamScenario-{uuid4()}", - "description": "No parameters run", - } - scenario_resp = client.post("/api/scenarios/", json=scenario_payload) - assert scenario_resp.status_code == 200 - scenario_id = scenario_resp.json()["id"] - - resp = client.post( - "/api/simulations/run", - json={"scenario_id": scenario_id, "parameters": [], "iterations": 10}, - ) - assert resp.status_code == 400 - assert resp.json()["detail"] == "No parameters provided" - - -def test_simulation_endpoint_success(client: TestClient, db_session: Session): - scenario_payload: Dict[str, Any] = { - "name": f"SimScenario-{uuid4()}", - "description": "Simulation test", - } - scenario_resp = client.post("/api/scenarios/", json=scenario_payload) - assert scenario_resp.status_code == 200 - scenario_id = scenario_resp.json()["id"] - - params: List[Dict[str, Any]] = [ - { - "name": "param1", - "value": 2.5, - "distribution": "normal", - "std_dev": 0.5, - } - ] - payload: Dict[str, Any] = { - "scenario_id": scenario_id, - "parameters": params, - "iterations": 10, - "seed": 42, - } - - resp = client.post("/api/simulations/run", json=payload) - assert resp.status_code == 200 - data = resp.json() - assert data["scenario_id"] == scenario_id - assert len(data["results"]) == 10 - assert data["summary"]["count"] == 10 - - db_session.expire_all() - persisted = ( - db_session.query(SimulationResult) - .filter(SimulationResult.scenario_id == scenario_id) - .all() - ) - assert len(persisted) == 10 - - -def test_simulation_endpoint_uses_stored_parameters(client: TestClient): - scenario_payload: Dict[str, Any] = { - "name": f"StoredParams-{uuid4()}", - "description": "Stored parameter simulation", - } - scenario_resp = client.post("/api/scenarios/", json=scenario_payload) - assert scenario_resp.status_code == 200 - scenario_id = scenario_resp.json()["id"] - - parameter_payload: Dict[str, Any] = { - "scenario_id": scenario_id, - "name": "grade", - "value": 1.5, - } - param_resp = client.post("/api/parameters/", json=parameter_payload) - assert param_resp.status_code == 200 - - resp = client.post( - "/api/simulations/run", - json={"scenario_id": scenario_id, "iterations": 3, "seed": 7}, - ) - assert resp.status_code == 200 - data = resp.json() - assert data["summary"]["count"] == 3 - assert len(data["results"]) == 3 diff --git a/tests/unit/test_theme_settings.py b/tests/unit/test_theme_settings.py deleted file 
mode 100644 index e24f7ec..0000000 --- a/tests/unit/test_theme_settings.py +++ /dev/null @@ -1,56 +0,0 @@ -from sqlalchemy.orm import Session - -from services.settings import save_theme_settings, get_theme_settings - - -def test_save_theme_settings(db_session: Session): - theme_data = { - "theme_name": "dark", - "primary_color": "#000000", - "secondary_color": "#333333", - "accent_color": "#ff0000", - "background_color": "#1a1a1a", - "text_color": "#ffffff" - } - - saved_setting = save_theme_settings(db_session, theme_data) - assert str(saved_setting.theme_name) == "dark" - assert str(saved_setting.primary_color) == "#000000" - - -def test_get_theme_settings(db_session: Session): - # Create a theme setting first - theme_data = { - "theme_name": "light", - "primary_color": "#ffffff", - "secondary_color": "#cccccc", - "accent_color": "#0000ff", - "background_color": "#f0f0f0", - "text_color": "#000000" - } - save_theme_settings(db_session, theme_data) - - settings = get_theme_settings(db_session) - assert settings["theme_name"] == "light" - assert settings["primary_color"] == "#ffffff" - - -def test_theme_settings_api(api_client): - # Test API endpoint for saving theme settings - theme_data = { - "theme_name": "test_theme", - "primary_color": "#123456", - "secondary_color": "#789abc", - "accent_color": "#def012", - "background_color": "#345678", - "text_color": "#9abcde" - } - - response = api_client.post("/api/settings/theme", json=theme_data) - assert response.status_code == 200 - assert response.json()["theme"]["theme_name"] == "test_theme" - - # Test API endpoint for getting theme settings - response = api_client.get("/api/settings/theme") - assert response.status_code == 200 - assert response.json()["theme_name"] == "test_theme" diff --git a/tests/unit/test_ui_routes.py b/tests/unit/test_ui_routes.py deleted file mode 100644 index b0757d7..0000000 --- a/tests/unit/test_ui_routes.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Any, Dict, cast - -import pytest -from fastapi.testclient import TestClient - -from models.scenario import Scenario -from services import settings as settings_service - - -def test_dashboard_route_provides_summary( - api_client: TestClient, seeded_ui_data: Dict[str, Any] -) -> None: - response = api_client.get("/ui/dashboard") - assert response.status_code == 200 - - template = getattr(response, "template", None) - assert template is not None - assert template.name == "Dashboard.html" - - context = cast(Dict[str, Any], getattr(response, "context", {})) - assert context.get("report_available") is True - - metric_labels = {item["label"] for item in context["summary_metrics"]} - assert { - "CAPEX Total", - "OPEX Total", - "Production", - "Simulation Iterations", - }.issubset(metric_labels) - - scenario = cast(Scenario, seeded_ui_data["scenario"]) - scenario_row = next( - row - for row in context["scenario_rows"] - if row["scenario_name"] == scenario.name - ) - assert scenario_row["iterations"] == 3 - assert scenario_row["simulation_mean_display"] == "971,666.67" - assert scenario_row["capex_display"] == "$1,000,000.00" - assert scenario_row["opex_display"] == "$250,000.00" - assert scenario_row["production_display"] == "800.00" - assert scenario_row["consumption_display"] == "1,200.00" - - -def test_scenarios_route_lists_seeded_scenario( - api_client: TestClient, seeded_ui_data: Dict[str, Any] -) -> None: - response = api_client.get("/ui/scenarios") - assert response.status_code == 200 - - template = getattr(response, "template", None) - assert template is not 
None - assert template.name == "ScenarioForm.html" - - context = cast(Dict[str, Any], getattr(response, "context", {})) - names = [item["name"] for item in context["scenarios"]] - scenario = cast(Scenario, seeded_ui_data["scenario"]) - assert scenario.name in names - - -def test_reporting_route_includes_summary( - api_client: TestClient, seeded_ui_data: Dict[str, Any] -) -> None: - response = api_client.get("/ui/reporting") - assert response.status_code == 200 - - template = getattr(response, "template", None) - assert template is not None - assert template.name == "reporting.html" - - context = cast(Dict[str, Any], getattr(response, "context", {})) - summaries = context["report_summaries"] - scenario = cast(Scenario, seeded_ui_data["scenario"]) - scenario_summary = next( - item for item in summaries if item["scenario_id"] == scenario.id - ) - assert scenario_summary["iterations"] == 3 - mean_value = float(scenario_summary["summary"]["mean"]) - assert abs(mean_value - 971_666.6666666666) < 1e-6 - - -def test_dashboard_data_endpoint_returns_aggregates( - api_client: TestClient, seeded_ui_data: Dict[str, Any] -) -> None: - response = api_client.get("/ui/dashboard/data") - assert response.status_code == 200 - - payload = response.json() - assert payload["report_available"] is True - - metric_map = { - item["label"]: item["value"] for item in payload["summary_metrics"] - } - assert metric_map["CAPEX Total"].startswith("$") - assert metric_map["Maintenance Cost"].startswith("$") - - scenario = cast(Scenario, seeded_ui_data["scenario"]) - scenario_rows = payload["scenario_rows"] - scenario_entry = next( - row for row in scenario_rows if row["scenario_name"] == scenario.name - ) - assert scenario_entry["capex_display"] == "$1,000,000.00" - assert scenario_entry["production_display"] == "800.00" - - labels = payload["scenario_cost_chart"]["labels"] - idx = labels.index(scenario.name) - assert payload["scenario_cost_chart"]["capex"][idx] == 1_000_000.0 - - activity_labels = payload["scenario_activity_chart"]["labels"] - activity_idx = activity_labels.index(scenario.name) - assert ( - payload["scenario_activity_chart"]["production"][activity_idx] == 800.0 - ) - - -@pytest.mark.parametrize( - ("path", "template_name"), - [ - ("/", "Dashboard.html"), - ("/ui/parameters", "ParameterInput.html"), - ("/ui/costs", "costs.html"), - ("/ui/consumption", "consumption.html"), - ("/ui/production", "production.html"), - ("/ui/equipment", "equipment.html"), - ("/ui/maintenance", "maintenance.html"), - ("/ui/simulations", "simulations.html"), - ], -) -def test_additional_ui_routes_render_templates( - api_client: TestClient, - seeded_ui_data: Dict[str, Any], - path: str, - template_name: str, -) -> None: - response = api_client.get(path) - assert response.status_code == 200 - - template = getattr(response, "template", None) - assert template is not None - assert template.name == template_name - - context = cast(Dict[str, Any], getattr(response, "context", {})) - assert context - - -def test_settings_route_provides_css_context( - api_client: TestClient, - monkeypatch: pytest.MonkeyPatch, -) -> None: - env_var = settings_service.css_key_to_env_var("--color-accent") - monkeypatch.setenv(env_var, "#abcdef") - - response = api_client.get("/ui/settings") - assert response.status_code == 200 - - template = getattr(response, "template", None) - assert template is not None - assert template.name == "settings.html" - - context = cast(Dict[str, Any], getattr(response, "context", {})) - assert "css_variables" in context - 
assert "css_defaults" in context - assert "css_env_overrides" in context - assert "css_env_override_rows" in context - assert "css_env_override_meta" in context - - assert context["css_variables"]["--color-accent"] == "#abcdef" - assert ( - context["css_defaults"]["--color-accent"] - == settings_service.CSS_COLOR_DEFAULTS["--color-accent"] - ) - assert context["css_env_overrides"]["--color-accent"] == "#abcdef" - - override_rows = context["css_env_override_rows"] - assert any(row["env_var"] == env_var for row in override_rows) - - meta = context["css_env_override_meta"]["--color-accent"] - assert meta["value"] == "#abcdef" - assert meta["env_var"] == env_var diff --git a/tests/unit/test_validation.py b/tests/unit/test_validation.py deleted file mode 100644 index 70473ae..0000000 --- a/tests/unit/test_validation.py +++ /dev/null @@ -1,28 +0,0 @@ -from uuid import uuid4 - -import pytest -from fastapi import HTTPException -from fastapi.testclient import TestClient - - -def test_validate_json_allows_valid_payload(api_client: TestClient) -> None: - payload = { - "name": f"ValidJSON-{uuid4()}", - "description": "Middleware should allow valid JSON.", - } - response = api_client.post("/api/scenarios/", json=payload) - assert response.status_code == 200 - data = response.json() - assert data["name"] == payload["name"] - - -def test_validate_json_rejects_invalid_payload(api_client: TestClient) -> None: - with pytest.raises(HTTPException) as exc_info: - api_client.post( - "/api/scenarios/", - content=b"{not valid json", - headers={"Content-Type": "application/json"}, - ) - - assert exc_info.value.status_code == 400 - assert exc_info.value.detail == "Invalid JSON payload"