diff --git a/.env.development b/.env.development new file mode 100644 index 0000000..256d2ef --- /dev/null +++ b/.env.development @@ -0,0 +1,25 @@ +# Development Environment Configuration +ENVIRONMENT=development +DEBUG=true +LOG_LEVEL=DEBUG + +# Database Configuration +DATABASE_HOST=postgres +DATABASE_PORT=5432 +DATABASE_USER=calminer +DATABASE_PASSWORD=calminer_password +DATABASE_NAME=calminer_db +DATABASE_DRIVER=postgresql + +# Application Settings +CALMINER_EXPORT_MAX_ROWS=1000 +CALMINER_IMPORT_MAX_ROWS=10000 +CALMINER_EXPORT_METADATA=true +CALMINER_IMPORT_STAGING_TTL=300 + +# Admin Seeding (for development) +CALMINER_SEED_ADMIN_EMAIL=admin@calminer.local +CALMINER_SEED_ADMIN_USERNAME=admin +CALMINER_SEED_ADMIN_PASSWORD=ChangeMe123! +CALMINER_SEED_ADMIN_ROLES=admin +CALMINER_SEED_FORCE=false \ No newline at end of file diff --git a/.env.example b/.env.example index 7393cde..de54798 100644 --- a/.env.example +++ b/.env.example @@ -10,5 +10,13 @@ DATABASE_NAME=calminer # Optional: set a schema (comma-separated for multiple entries) # DATABASE_SCHEMA=public -# Legacy fallback (still supported, but granular settings are preferred) -# DATABASE_URL=postgresql://:@localhost:5432/calminer \ No newline at end of file +# Default administrative credentials are provided at deployment time through environment variables +# (`CALMINER_SEED_ADMIN_EMAIL`, `CALMINER_SEED_ADMIN_USERNAME`, `CALMINER_SEED_ADMIN_PASSWORD`, `CALMINER_SEED_ADMIN_ROLES`). +# These values are consumed by a shared bootstrap helper on application startup, ensuring mandatory roles and the administrator account exist before any user interaction. +CALMINER_SEED_ADMIN_EMAIL= +CALMINER_SEED_ADMIN_USERNAME= +CALMINER_SEED_ADMIN_PASSWORD= +CALMINER_SEED_ADMIN_ROLES= +# Operators can request a managed credential reset by setting `CALMINER_SEED_FORCE=true`. +# On the next startup the helper rotates the admin password and reapplies role assignments, so downstream environments must update stored secrets immediately after the reset. 
+# CALMINER_SEED_FORCE=false \ No newline at end of file diff --git a/.env.production b/.env.production new file mode 100644 index 0000000..9f1035c --- /dev/null +++ b/.env.production @@ -0,0 +1,25 @@ +# Production Environment Configuration +ENVIRONMENT=production +DEBUG=false +LOG_LEVEL=WARNING + +# Database Configuration (MUST be set externally - no defaults) +DATABASE_HOST= +DATABASE_PORT=5432 +DATABASE_USER= +DATABASE_PASSWORD= +DATABASE_NAME= +DATABASE_DRIVER=postgresql + +# Application Settings +CALMINER_EXPORT_MAX_ROWS=100000 +CALMINER_IMPORT_MAX_ROWS=100000 +CALMINER_EXPORT_METADATA=true +CALMINER_IMPORT_STAGING_TTL=3600 + +# Admin Seeding (for production - set strong password) +CALMINER_SEED_ADMIN_EMAIL=admin@calminer.com +CALMINER_SEED_ADMIN_USERNAME=admin +CALMINER_SEED_ADMIN_PASSWORD=CHANGE_THIS_VERY_STRONG_PASSWORD +CALMINER_SEED_ADMIN_ROLES=admin +CALMINER_SEED_FORCE=false \ No newline at end of file diff --git a/.env.staging b/.env.staging new file mode 100644 index 0000000..1deca22 --- /dev/null +++ b/.env.staging @@ -0,0 +1,25 @@ +# Staging Environment Configuration +ENVIRONMENT=staging +DEBUG=false +LOG_LEVEL=INFO + +# Database Configuration (override with actual staging values) +DATABASE_HOST=postgres +DATABASE_PORT=5432 +DATABASE_USER=calminer_staging +DATABASE_PASSWORD=CHANGE_THIS_STRONG_PASSWORD +DATABASE_NAME=calminer_staging_db +DATABASE_DRIVER=postgresql + +# Application Settings +CALMINER_EXPORT_MAX_ROWS=50000 +CALMINER_IMPORT_MAX_ROWS=50000 +CALMINER_EXPORT_METADATA=true +CALMINER_IMPORT_STAGING_TTL=600 + +# Admin Seeding (for staging) +CALMINER_SEED_ADMIN_EMAIL=admin@staging.calminer.com +CALMINER_SEED_ADMIN_USERNAME=admin +CALMINER_SEED_ADMIN_PASSWORD=CHANGE_THIS_STRONG_PASSWORD +CALMINER_SEED_ADMIN_ROLES=admin +CALMINER_SEED_FORCE=false \ No newline at end of file diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dcdad2b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +* text=auto + +Dockerfile text eol=lf diff --git a/.gitea/workflows/ci-build.yml b/.gitea/workflows/ci-build.yml new file mode 100644 index 0000000..99456a5 --- /dev/null +++ b/.gitea/workflows/ci-build.yml @@ -0,0 +1,150 @@ +name: CI - Build + +on: + workflow_call: + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + env: + DEFAULT_BRANCH: main + REGISTRY_URL: ${{ secrets.REGISTRY_URL }} + REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} + REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} + REGISTRY_CONTAINER_NAME: calminer + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Collect workflow metadata + id: meta + shell: bash + env: + DEFAULT_BRANCH: ${{ env.DEFAULT_BRANCH }} + run: | + ref_name="${GITHUB_REF_NAME:-${GITHUB_REF##*/}}" + event_name="${GITHUB_EVENT_NAME:-}" + sha="${GITHUB_SHA:-}" + + if [ "$ref_name" = "${DEFAULT_BRANCH:-main}" ]; then + echo "on_default=true" >> "$GITHUB_OUTPUT" + else + echo "on_default=false" >> "$GITHUB_OUTPUT" + fi + + echo "ref_name=$ref_name" >> "$GITHUB_OUTPUT" + echo "event_name=$event_name" >> "$GITHUB_OUTPUT" + echo "sha=$sha" >> "$GITHUB_OUTPUT" + + - name: Set up QEMU and Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to gitea registry + if: ${{ steps.meta.outputs.on_default == 'true' }} + uses: docker/login-action@v3 + continue-on-error: true + with: + registry: ${{ env.REGISTRY_URL }} + username: ${{ env.REGISTRY_USERNAME }} + password: ${{ env.REGISTRY_PASSWORD }} + + - name: Build image + id: build-image + env: + REGISTRY_URL: ${{ env.REGISTRY_URL }} + 
REGISTRY_CONTAINER_NAME: ${{ env.REGISTRY_CONTAINER_NAME }} + SHA_TAG: ${{ steps.meta.outputs.sha }} + PUSH_IMAGE: ${{ steps.meta.outputs.on_default == 'true' && steps.meta.outputs.event_name != 'pull_request' && env.REGISTRY_URL != '' && env.REGISTRY_USERNAME != '' && env.REGISTRY_PASSWORD != '' }} + run: | + set -eo pipefail + LOG_FILE=build.log + if [ "${PUSH_IMAGE}" = "true" ]; then + docker buildx build \ + --push \ + --tag "${REGISTRY_URL}/allucanget/${REGISTRY_CONTAINER_NAME}:latest" \ + --tag "${REGISTRY_URL}/allucanget/${REGISTRY_CONTAINER_NAME}:${SHA_TAG}" \ + --file Dockerfile \ + . 2>&1 | tee "${LOG_FILE}" + else + docker buildx build \ + --load \ + --tag "${REGISTRY_CONTAINER_NAME}:ci" \ + --file Dockerfile \ + . 2>&1 | tee "${LOG_FILE}" + fi + + - name: Upload docker build logs + if: failure() + uses: actions/upload-artifact@v4 + with: + name: docker-build-logs + path: build.log + + deploy: + needs: build + if: github.ref == 'refs/heads/main' && github.event_name != 'pull_request' + runs-on: ubuntu-latest + env: + REGISTRY_URL: ${{ secrets.REGISTRY_URL }} + REGISTRY_CONTAINER_NAME: calminer + KUBE_CONFIG: ${{ secrets.KUBE_CONFIG }} + STAGING_KUBE_CONFIG: ${{ secrets.STAGING_KUBE_CONFIG }} + PROD_KUBE_CONFIG: ${{ secrets.PROD_KUBE_CONFIG }} + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up kubectl for staging + if: github.event.head_commit && contains(github.event.head_commit.message, '[deploy staging]') + uses: azure/k8s-set-context@v3 + with: + method: kubeconfig + kubeconfig: ${{ env.STAGING_KUBE_CONFIG }} + + - name: Set up kubectl for production + if: github.event.head_commit && contains(github.event.head_commit.message, '[deploy production]') + uses: azure/k8s-set-context@v3 + with: + method: kubeconfig + kubeconfig: ${{ env.PROD_KUBE_CONFIG }} + + - name: Deploy to staging + if: github.event.head_commit && contains(github.event.head_commit.message, '[deploy staging]') + run: | + kubectl set image deployment/calminer-app calminer=${REGISTRY_URL}/allucanget/${REGISTRY_CONTAINER_NAME}:latest + kubectl apply -f k8s/configmap.yaml + kubectl apply -f k8s/secret.yaml + kubectl rollout status deployment/calminer-app + + - name: Collect staging deployment logs + if: github.event.head_commit && contains(github.event.head_commit.message, '[deploy staging]') + run: | + mkdir -p logs/deployment/staging + kubectl get pods -o wide > logs/deployment/staging/pods.txt + kubectl get deployment calminer-app -o yaml > logs/deployment/staging/deployment.yaml + kubectl logs deployment/calminer-app --all-containers=true --tail=500 > logs/deployment/staging/calminer-app.log + + - name: Deploy to production + if: github.event.head_commit && contains(github.event.head_commit.message, '[deploy production]') + run: | + kubectl set image deployment/calminer-app calminer=${REGISTRY_URL}/allucanget/${REGISTRY_CONTAINER_NAME}:latest + kubectl apply -f k8s/configmap.yaml + kubectl apply -f k8s/secret.yaml + kubectl rollout status deployment/calminer-app + + - name: Collect production deployment logs + if: github.event.head_commit && contains(github.event.head_commit.message, '[deploy production]') + run: | + mkdir -p logs/deployment/production + kubectl get pods -o wide > logs/deployment/production/pods.txt + kubectl get deployment calminer-app -o yaml > logs/deployment/production/deployment.yaml + kubectl logs deployment/calminer-app --all-containers=true --tail=500 > logs/deployment/production/calminer-app.log + + - name: Upload deployment logs + if: always() + uses: 
actions/upload-artifact@v4 + with: + name: deployment-logs + path: logs/deployment + if-no-files-found: ignore diff --git a/.gitea/workflows/ci-lint.yml b/.gitea/workflows/ci-lint.yml new file mode 100644 index 0000000..b905a36 --- /dev/null +++ b/.gitea/workflows/ci-lint.yml @@ -0,0 +1,44 @@ +name: CI - Lint + +on: + workflow_call: + workflow_dispatch: + +jobs: + lint: + runs-on: ubuntu-latest + env: + APT_CACHER_NG: http://192.168.88.14:3142 + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Configure apt proxy + run: | + if [ -n "${APT_CACHER_NG}" ]; then + echo "Acquire::http::Proxy \"${APT_CACHER_NG}\";" | tee /etc/apt/apt.conf.d/01apt-cacher-ng + fi + + - name: Install system packages + run: | + apt-get update + apt-get install -y build-essential libpq-dev + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r requirements-test.txt + + - name: Run Ruff + run: ruff check . + + - name: Run Black + run: black --check . + + - name: Run Bandit + run: bandit -c pyproject.toml -r . diff --git a/.gitea/workflows/ci-test.yml b/.gitea/workflows/ci-test.yml new file mode 100644 index 0000000..0cf1572 --- /dev/null +++ b/.gitea/workflows/ci-test.yml @@ -0,0 +1,72 @@ +name: CI - Test + +on: + workflow_call: + workflow_dispatch: + +jobs: + test: + runs-on: ubuntu-latest + env: + APT_CACHER_NG: http://192.168.88.14:3142 + DB_DRIVER: postgresql+psycopg2 + DB_NAME: calminer_test + DB_USER: calminer + DB_PASSWORD: calminer_password + services: + postgres: + image: postgres:17 + env: + POSTGRES_USER: ${{ env.DB_USER }} + POSTGRES_PASSWORD: ${{ env.DB_PASSWORD }} + POSTGRES_DB: ${{ env.DB_NAME }} + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Configure apt proxy + run: | + if [ -n "${APT_CACHER_NG}" ]; then + echo "Acquire::http::Proxy \"${APT_CACHER_NG}\";" | tee /etc/apt/apt.conf.d/01apt-cacher-ng + fi + + - name: Install system packages + run: | + apt-get update + apt-get install -y build-essential libpq-dev + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r requirements-test.txt + + - name: Run tests + env: + DATABASE_DRIVER: ${{ env.DB_DRIVER }} + DATABASE_HOST: postgres + DATABASE_PORT: 5432 + DATABASE_USER: ${{ env.DB_USER }} + DATABASE_PASSWORD: ${{ env.DB_PASSWORD }} + DATABASE_NAME: ${{ env.DB_NAME }} + run: | + pytest --cov=.
--cov-report=term-missing --cov-report=xml --cov-fail-under=80 --junitxml=pytest-report.xml + + - name: Upload test artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-artifacts + path: | + coverage.xml + pytest-report.xml diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..2896f94 --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,30 @@ +name: CI + +on: + push: + branches: + - main + - develop + - v2 + pull_request: + branches: + - main + - develop + workflow_dispatch: + +jobs: + lint: + uses: ./.gitea/workflows/ci-lint.yml + secrets: inherit + + test: + needs: lint + uses: ./.gitea/workflows/ci-test.yml + secrets: inherit + + build: + needs: + - lint + - test + uses: ./.gitea/workflows/ci-build.yml + secrets: inherit diff --git a/.gitea/workflows/cicache.yml b/.gitea/workflows/cicache.yml deleted file mode 100644 index 3f0d38b..0000000 --- a/.gitea/workflows/cicache.yml +++ /dev/null @@ -1,141 +0,0 @@ -name: CI - -on: - push: - branches: [main, develop] - pull_request: - branches: [main, develop] - -jobs: - test: - env: - APT_CACHER_NG: http://192.168.88.14:3142 - DB_DRIVER: postgresql+psycopg2 - DB_HOST: 192.168.88.35 - DB_NAME: calminer_test - DB_USER: calminer - DB_PASSWORD: calminer_password - runs-on: ubuntu-latest - - services: - postgres: - image: postgres:17 - env: - POSTGRES_USER: ${{ env.DB_USER }} - POSTGRES_PASSWORD: ${{ env.DB_PASSWORD }} - POSTGRES_DB: ${{ env.DB_NAME }} - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.11' - - - name: Get pip cache dir - id: pip-cache - run: | - echo "path=$(pip cache dir)" >> $GITEA_OUTPUT - echo "Pip cache dir: $(pip cache dir)" - - - name: Cache pip dependencies - uses: actions/cache@v4 - with: - path: ${{ steps.pip-cache.outputs.path }} - key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt', 'requirements-test.txt') }} - restore-keys: | - ${{ runner.os }}-pip- - - - name: Update apt-cacher-ng config - run: |- - echo 'Acquire::http::Proxy "{{ env.APT_CACHER_NG }}";' | tee /etc/apt/apt.conf.d/01apt-cacher-ng - apt-get update - - - name: Update system packages - run: apt-get upgrade -y - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -r requirements-test.txt - - - name: Install Playwright system dependencies - run: playwright install-deps - - - name: Install Playwright browsers - run: playwright install - - - name: Run tests - env: - DATABASE_DRIVER: ${{ env.DB_DRIVER }} - DATABASE_HOST: postgres - DATABASE_PORT: 5432 - DATABASE_USER: ${{ env.DB_USER }} - DATABASE_PASSWORD: ${{ env.DB_PASSWORD }} - DATABASE_NAME: ${{ env.DB_NAME }} - run: | - pytest tests/ --cov=. - - - name: Build Docker image - run: | - docker build -t calminer . 
- - build: - runs-on: ubuntu-latest - needs: test - env: - DEFAULT_BRANCH: main - REGISTRY_URL: ${{ secrets.REGISTRY_URL }} - REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} - REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} - REGISTRY_CONTAINER_NAME: calminer - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Collect workflow metadata - id: meta - shell: bash - run: | - ref_name="${GITHUB_REF_NAME:-${GITHUB_REF##*/}}" - event_name="${GITHUB_EVENT_NAME:-}" - sha="${GITHUB_SHA:-}" - - if [ "$ref_name" = "${DEFAULT_BRANCH:-main}" ]; then - echo "on_default=true" >> "$GITHUB_OUTPUT" - else - echo "on_default=false" >> "$GITHUB_OUTPUT" - fi - - echo "ref_name=$ref_name" >> "$GITHUB_OUTPUT" - echo "event_name=$event_name" >> "$GITHUB_OUTPUT" - echo "sha=$sha" >> "$GITHUB_OUTPUT" - - - name: Set up QEMU and Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to gitea registry - if: ${{ steps.meta.outputs.on_default == 'true' }} - uses: docker/login-action@v3 - continue-on-error: true - with: - registry: ${{ env.REGISTRY_URL }} - username: ${{ env.REGISTRY_USERNAME }} - password: ${{ env.REGISTRY_PASSWORD }} - - - name: Build and push image - uses: docker/build-push-action@v5 - with: - context: . - file: Dockerfile - push: ${{ steps.meta.outputs.on_default == 'true' && steps.meta.outputs.event_name != 'pull_request' && (env.REGISTRY_URL != '' && env.REGISTRY_USERNAME != '' && env.REGISTRY_PASSWORD != '') }} - tags: | - ${{ env.REGISTRY_URL }}/allucanget/${{ env.REGISTRY_CONTAINER_NAME }}:latest - ${{ env.REGISTRY_URL }}/allucanget/${{ env.REGISTRY_CONTAINER_NAME }}:${{ steps.meta.outputs.sha }} diff --git a/.gitignore b/.gitignore index a5cba30..26355e8 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ env/ # environment variables .env *.env +.env.* # except example files !config/*.env.example @@ -46,8 +47,10 @@ htmlcov/ logs/ # SQLite database +data/ *.sqlite3 test*.db +local*.db # Act runner files .runner diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c1227bd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.1 + hooks: + - id: ruff + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.8.0 + hooks: + - id: black + - repo: https://github.com/PyCQA/bandit + rev: 1.7.9 + hooks: + - id: bandit diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 0ca3806..0000000 --- a/.prettierrc +++ /dev/null @@ -1,8 +0,0 @@ -{ - "semi": true, - "singleQuote": true, - "trailingComma": "es5", - "printWidth": 80, - "tabWidth": 2, - "useTabs": false -} diff --git a/Dockerfile b/Dockerfile index 2565f21..a023964 100644 --- a/Dockerfile +++ b/Dockerfile @@ -41,8 +41,25 @@ if url: finally: sock.close() PY -apt-get update -apt-get install -y --no-install-recommends build-essential gcc libpq-dev +APT_PROXY_CONFIG=/etc/apt/apt.conf.d/01proxy + +apt_update_with_fallback() { + if ! apt-get update; then + rm -f "$APT_PROXY_CONFIG" + apt-get update + fi +} + +apt_install_with_fallback() { + if ! 
apt-get install -y --no-install-recommends "$@"; then + rm -f "$APT_PROXY_CONFIG" + apt-get update + apt-get install -y --no-install-recommends "$@" + fi +} + +apt_update_with_fallback +apt_install_with_fallback build-essential gcc libpq-dev pip install --upgrade pip pip wheel --no-deps --wheel-dir /wheels -r requirements.txt apt-get purge -y --auto-remove build-essential gcc @@ -88,8 +105,25 @@ if url: finally: sock.close() PY -apt-get update -apt-get install -y --no-install-recommends libpq5 +APT_PROXY_CONFIG=/etc/apt/apt.conf.d/01proxy + +apt_update_with_fallback() { + if ! apt-get update; then + rm -f "$APT_PROXY_CONFIG" + apt-get update + fi +} + +apt_install_with_fallback() { + if ! apt-get install -y --no-install-recommends "$@"; then + rm -f "$APT_PROXY_CONFIG" + apt-get update + apt-get install -y --no-install-recommends "$@" + fi +} + +apt_update_with_fallback +apt_install_with_fallback libpq5 rm -rf /var/lib/apt/lists/* EOF @@ -108,4 +142,6 @@ USER appuser EXPOSE 8003 -CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8003", "--workers", "4"] +ENTRYPOINT ["uvicorn"] + +CMD ["main:app", "--host", "0.0.0.0", "--port", "8003", "--workers", "4"] diff --git a/README.md b/README.md index e9eac24..6e47f2d 100644 --- a/README.md +++ b/README.md @@ -8,4 +8,6 @@ The system is designed to help mining companies make informed decisions by simul ## Documentation & quickstart -This repository contains only code. See detailed developer and architecture documentation in the [Docs](https://git.allucanget.biz/allucanget/calminer-docs) repository. +- Detailed developer, architecture, and operations guides live in the companion [calminer-docs](https://git.allucanget.biz/allucanget/calminer-docs) repository; see its README for instructions. +- For a local run, create a `.env` (see `.env.example`), install requirements, then execute `python -m scripts.init_db` followed by `uvicorn main:app --reload`. The initializer is safe to rerun and seeds demo data automatically. +- To wipe and recreate the schema in development, run `CALMINER_ENV=development python -m scripts.reset_db` before invoking the initializer again. diff --git a/changelog.md b/changelog.md new file mode 100644 index 0000000..712d860 --- /dev/null +++ b/changelog.md @@ -0,0 +1,112 @@ +# Changelog + +## 2025-11-13 + +- Completed the UI alignment initiative by consolidating shared form and button styles into `static/css/forms.css` and `static/css/main.css`, introducing the semantic palette in `static/css/theme-default.css`, and spot-checking key pages plus contrast reports. +- Refactored the architecture data model docs by turning `calminer-docs/architecture/08_concepts/02_data_model.md` into a concise overview that links to new detail pages covering SQLAlchemy models, navigation metadata, enumerations, Pydantic schemas, and monitoring tables. +- Nested the calculator navigation under Projects by updating `scripts/init_db.py` seeds, teaching `services/navigation.py` to resolve scenario-scoped hrefs for profitability/opex/capex, and extending sidebar coverage through `tests/integration/test_navigation_sidebar_calculations.py` plus `tests/services/test_navigation_service.py` to validate admin/viewer visibility and contextual URL generation.
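The scenario-scoped href resolution described in the entry above follows a common pattern: navigation entries carry templated paths, and the service substitutes the active project/scenario context before rendering. A minimal sketch of that pattern (the `NavEntry` shape and `resolve_href` name are illustrative, not the actual `services/navigation.py` API):

```python
import string
from dataclasses import dataclass


@dataclass(frozen=True)
class NavEntry:
    label: str
    # Templated path, e.g. "/calculations/projects/{project_id}/scenarios/{scenario_id}/capex"
    href_template: str


def resolve_href(
    entry: NavEntry,
    *,
    project_id: int | None = None,
    scenario_id: int | None = None,
) -> str | None:
    """Fill the template from the active context; return None when scope is missing."""
    context = {"project_id": project_id, "scenario_id": scenario_id}
    fields = [
        name for _, name, _, _ in string.Formatter().parse(entry.href_template) if name
    ]
    if any(context.get(name) is None for name in fields):
        return None
    return entry.href_template.format(**context)
```

Returning `None` for missing context lets the sidebar simply omit calculator links until a project and scenario are selected.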
+- Added navigation sidebar integration coverage by extending `tests/conftest.py` with role-switching headers, seeding admin/viewer test users, and adding `tests/integration/test_navigation_sidebar.py` to assert ordered link rendering for admins, viewer filtering of admin-only entries, and anonymous rejection of the endpoint. +- Finalised the financial data import/export templates by inventorying required fields, defining CSV column specs with validation rules, drafting Excel workbook layouts, documenting end-user workflows in `calminer-docs/userguide/data_import_export.md`, and recording stakeholder review steps alongside updated TODO/DONE tracking. +- Scoped profitability calculator UI under the scenario hierarchy by adding `/calculations/projects/{project_id}/scenarios/{scenario_id}/profitability` GET/POST handlers, updating scenario templates and sidebar navigation to link to the new route, and extending `tests/test_project_scenario_routes.py` with coverage for the scenario path plus legacy redirect behaviour (module run: 14 passed). +- Extended scenario frontend regression coverage by updating `tests/test_project_scenario_routes.py` to assert project/scenario breadcrumbs and calculator navigation, normalising escaped URLs, and re-running the module tests (13 passing). +- Cleared FastAPI and Pydantic deprecation warnings by migrating `scripts/init_db.py` to `@field_validator`, replacing the `main.py` startup hook with a lifespan handler, auditing template response call signatures, confirming HTTP 422 constant usage, and re-running the full pytest suite to ensure a clean warning slate. +- Delivered the capex planner end-to-end: added scaffolded UI in `templates/scenarios/capex.html`, wired GET/POST handlers through `routes/calculations.py`, implemented calculation logic plus snapshot persistence in `services/calculations.py` and `models/capex_snapshot.py`, updated navigation links, and introduced unit tests in `tests/services/test_calculations_capex.py`. +- Updated UI navigation to surface the opex planner by adding the sidebar link in `templates/partials/sidebar_nav.html` and wiring a scenario detail action in `templates/scenarios/detail.html`. +- Completed manual validation of the Capex Planner UI flows (sidebar entry, scenario deep link, validation errors, successful calculation) with results captured in `manual_tests/capex.md`, documented snapshot verification steps, and noted the optional JSON client check for future follow-up. +- Added opex calculation unit tests in `tests/services/test_calculations_opex.py` covering success metrics, currency validation, frequency enforcement, and evaluation horizon extension. +- Documented the Opex Planner workflow in `calminer-docs/userguide/opex_planner.md`, linked it from the user guide index, extended `calminer-docs/architecture/08_concepts/02_data_model.md` with snapshot coverage, and captured the completion in `.github/instructions/DONE.md`. +- Implemented opex integration coverage in `tests/integration/test_opex_calculations.py`, exercising HTML and JSON flows, verifying snapshot persistence, and asserting currency mismatch handling for form and API submissions. +- Executed the full pytest suite with coverage (211 tests) to confirm no regressions or warnings after the opex documentation updates.
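The lifespan handler mentioned in the deprecation-warning entry above replaces FastAPI's removed `@app.on_event("startup")` hook. A minimal sketch of the pattern, with a stub standing in for the project's real bootstrap work:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


def run_startup_tasks() -> None:
    """Stand-in for the real bootstrap work (role seeding, admin account, pricing defaults)."""


@asynccontextmanager
async def lifespan(app: FastAPI):
    run_startup_tasks()  # runs once before the first request is served
    yield
    # teardown (closing pools, flushing metrics) belongs after the yield


app = FastAPI(lifespan=lifespan)
```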
+- Completed the navigation sidebar API migration by finalising the database-backed service, refactoring `templates/partials/sidebar_nav.html` to consume the endpoint, hydrating via `static/js/navigation_sidebar.js`, and updating HTML route dependencies (`routes/projects.py`, `routes/scenarios.py`, `routes/reports.py`, `routes/imports.py`, `routes/calculations.py`) to use redirect-aware guards so anonymous visitors receive login redirects instead of JSON errors (manual verification via curl across projects, scenarios, reports, and calculations pages). + +## 2025-11-12 + +- Fixed critical 500 error in reporting dashboard by correcting route reference in reporting.html template - changed 'reports.project_list_page' to 'projects.project_list_page' to resolve NoMatchFound error when accessing /ui/reporting. +- Completed navigation validation by inventorying all sidebar navigation links, identifying missing routes for simulations, reporting, settings, themes, and currencies, created new UI routes in routes/ui.py with proper authentication guards, built corresponding templates (simulations.html, reporting.html, settings.html, theme_settings.html, currencies.html), registered the UI router in main.py, updated sidebar navigation to use route names instead of hardcoded URLs, and enhanced navigation.js to use dynamic URL resolution for proper route handling. +- Fixed critical template rendering error in sidebar_nav.html where URL objects from `request.url_for()` were being used with string methods, causing TypeError. Added `|string` filters to convert URL objects to strings for proper template rendering. +- Integrated Plotly charting for interactive visualizations in reporting templates, added chart generation methods to ReportingService (`generate_npv_comparison_chart`, `generate_distribution_histogram`), updated project summary and scenario distribution contexts to include chart JSON data, enhanced templates with chart containers and JavaScript rendering, added chart-container CSS styling, and validated all reporting tests pass. + +- Completed local run verification: started application with `uvicorn main:app --reload` without errors, verified authenticated routes (/login, /, /projects/ui, /projects) load correctly with seeded data, and summarized findings for deployment pipeline readiness. +- Fixed docker-compose.override.yml command array to remove duplicate "uvicorn" entry, enabling successful container startup with uvicorn reload in development mode. +- Completed deployment pipeline verification: built Docker image without errors, validated docker-compose configuration, deployed locally with docker-compose (app and postgres containers started successfully), and confirmed application startup logs showing database bootstrap and seeded data initialization. +- Completed documentation of current data models: updated `calminer-docs/architecture/08_concepts/02_data_model.md` with comprehensive SQLAlchemy model schemas, enumerations, Pydantic API schemas, and analysis of discrepancies between models and schemas. +- Switched `models/performance_metric.py` to reuse the shared declarative base from `config.database`, clearing the SQLAlchemy 2.0 `declarative_base` deprecation warning and verifying repository tests still pass. 
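The Plotly integration recorded above typically hands charts to templates as JSON: the service builds a figure server-side, serialises it, and the page renders it with `Plotly.newPlot`. A hedged sketch of what a method like `generate_npv_comparison_chart` might reduce to (the scenario dict fields are assumptions):

```python
import plotly.graph_objects as go


def npv_comparison_chart(scenarios: list[dict]) -> str:
    """Build a bar chart of NPV per scenario and return Plotly JSON for the template."""
    fig = go.Figure(
        data=[
            go.Bar(
                x=[s["name"] for s in scenarios],  # assumed field names
                y=[s["npv"] for s in scenarios],
            )
        ]
    )
    fig.update_layout(title="NPV comparison", yaxis_title="NPV")
    return fig.to_json()  # the template embeds this and calls Plotly.newPlot


chart_json = npv_comparison_chart([{"name": "Base case", "npv": 1_250_000.0}])
```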
+- Replaced the Alembic migration workflow with the idempotent Pydantic-backed initializer (`scripts/init_db.py`), added a guarded reset utility (`scripts/reset_db.py`), removed migration artifacts/tooling (Alembic directory, config, Docker entrypoint), refreshed the container entrypoint to invoke `uvicorn` directly, and updated installation/architecture docs plus the README to direct developers to the new seeding/reset flow. +- Eliminated Bandit hardcoded-secret findings by replacing literal JWT tokens and passwords across auth/security tests with randomized helpers drawn from `tests/utils/security.py`, ensuring fixtures still assert expected behaviours. +- Centralized Bandit configuration in `pyproject.toml`, reran `bandit -c pyproject.toml -r calminer tests`, and verified the scan now reports zero issues. +- Diagnosed admin bootstrap failure caused by legacy `roles` schema, added Alembic migration `20251112_00_add_roles_metadata_columns.py` to backfill `display_name`, `description`, `created_at`, and `updated_at`, and verified the migration via full pytest run in the activated `.venv`. +- Resolved Ruff E402 warnings by moving module docstrings ahead of `from __future__ import annotations` across currency and pricing service modules, dropped the unused `HTTPException` import in `monitoring/__init__.py`, and confirmed a clean `ruff check .` run. +- Enhanced the deploy job in `.gitea/workflows/cicache.yml` to capture Kubernetes pod, deployment, and container logs into `/logs/deployment/` for staging/production rollouts and publish them via a `deployment-logs` artifact, updating CI/CD documentation with retrieval instructions. +- Fixed CI dashboard template lookup failures by renaming `templates/Dashboard.html` to `templates/dashboard.html` and verifying `tests/test_dashboard_route.py` locally to ensure TemplateNotFound no longer occurs on case-sensitive filesystems. +- Implemented SQLite support as primary local database with environment-driven backend switching (`CALMINER_USE_SQLITE=true`), updated `scripts/init_db.py` for database-agnostic DDL generation (PostgreSQL enums vs SQLite CHECK constraints), tested compatibility with both backends, and verified application startup and seeded data initialization work seamlessly across SQLite and PostgreSQL. + +## 2025-11-11 + +- Collapsed legacy Alembic revisions into `alembic/versions/00_initial.py`, removed superseded migration files, and verified the consolidated schema via SQLite upgrade and Postgres version stamping. +- Implemented base URL routing to redirect unauthenticated users to login and authenticated users to dashboard. +- Added comprehensive end-to-end tests for login flow, including redirects, session handling, and error messaging for invalid/inactive accounts. +- Updated header and footer templates to consistently use `logo_big.png` image instead of text logo, with appropriate CSS styling for sizing. +- Centralised ISO-4217 currency validation across scenarios, imports, and export filters (`models/scenario.py`, `routes/scenarios.py`, `schemas/scenario.py`, `schemas/imports.py`, `services/export_query.py`) so malformed codes are rejected consistently at every entry point. +- Updated scenario services and UI flows to surface friendly validation errors and added regression coverage for imports, exports, API creation, and lifecycle flows ensuring currencies are normalised end-to-end. 
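The centralised ISO-4217 validation covered by the last two entries reduces to normalising input once and rejecting anything that fails the code pattern. A minimal sketch of such a shared validator (the supported-codes set is an illustrative subset, not the project's actual list):

```python
import re

_ISO_4217_RE = re.compile(r"[A-Z]{3}")
_SUPPORTED = {"USD", "EUR", "AUD", "CAD", "CLP"}  # illustrative subset


def normalise_currency(raw: str | None) -> str:
    """Trim and upper-case a currency code, rejecting anything non-ISO-4217-shaped."""
    code = (raw or "").strip().upper()
    if not _ISO_4217_RE.fullmatch(code):
        raise ValueError(f"Malformed ISO-4217 currency code: {raw!r}")
    if code not in _SUPPORTED:
        raise ValueError(f"Unsupported currency code: {code}")
    return code


assert normalise_currency(" usd ") == "USD"
```

Centralising this in the model, schema, and import/export layers means every entry point rejects a malformed code with the same message.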
+- Linked projects to their pricing settings by updating SQLAlchemy models, repositories, seeding utilities, and migrations, and added regression tests to cover the new association and default backfill. +- Bootstrapped database-stored pricing settings at application startup, aligned initial data seeding with the database-first metadata flow, and added tests covering pricing bootstrap creation, project assignment, and idempotency. +- Extended pricing configuration support to prefer persisted metadata via `dependencies.get_pricing_metadata`, added retrieval tests for project/default fallbacks, and refreshed docs (`calminer-docs/specifications/price_calculation.md`, `pricing_settings_data_model.md`) to describe the database-backed workflow and bootstrap behaviour. +- Added `services/financial.py` NPV, IRR, and payback helpers with robust cash-flow normalisation, convergence safeguards, and fractional period support, plus comprehensive pytest coverage exercising representative project scenarios and failure modes. +- Authored `calminer-docs/specifications/financial_metrics.md` capturing DCF assumptions, solver behaviours, and worked examples, and cross-linked the architecture concepts to the new reference for consistent navigation. +- Implemented `services/simulation.py` Monte Carlo engine with configurable distributions, summary aggregation, and reproducible RNG seeding, introduced regression tests in `tests/test_simulation.py`, and documented configuration/usage in `calminer-docs/specifications/monte_carlo_simulation.md` with architecture cross-links. +- Polished reporting HTML contexts by cleaning stray fragments in `routes/reports.py`, adding download action metadata for project and scenario pages, and generating scenario comparison download URLs with correctly serialised repeated `scenario_ids` parameters. +- Consolidated Alembic history into a single initial migration (`20251111_00_initial_schema.py`), removed superseded revision files, and ensured Alembic metadata still references the project metadata for clean bootstrap. +- Added `scripts/run_migrations.py` and a Docker entrypoint wrapper to run Alembic migrations before `uvicorn` starts, removed the fallback `Base.metadata.create_all` call, and updated `calminer-docs/admin/installation.md` so developers know how to apply migrations locally or via Docker. +- Configured pytest defaults to collect coverage (`--cov`) with an 80% fail-under gate, excluded entrypoint/reporting scaffolds from the calculation, updated contributor docs with the standard `pytest` command, and verified the suite now reports 83% coverage. +- Standardized color scheme and typography by moving alert styles to `main.css`, adding typography rules with CSS variables, updating auth templates for consistent button classes, and ensuring all templates use centralized color and spacing variables. +- Improved navigation flow by adding two big chevron buttons on top of the navigation sidebar to allow users to navigate to the previous and next page in the page navigation list, including JavaScript logic for determining current page and handling navigation. +- Established pytest-based unit and integration test suites with coverage thresholds, achieving 83% coverage across 181 tests, with configuration in pyproject.toml and documentation in CONTRIBUTING.md. +- Configured CI pipelines to run tests, linting, and security checks on each change, adding Bandit security scanning to the workflow and verifying execution on pushes and PRs to main/develop branches. 
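For reference, the NPV and IRR helpers recorded in the `services/financial.py` entry above rest on standard discounted-cash-flow math. A compact sketch using bisection for IRR (not the repository's exact implementation, which adds cash-flow normalisation and convergence safeguards):

```python
def npv(rate: float, cash_flows: list[float]) -> float:
    """Discount period cash flows (index 0 = today) at the given per-period rate."""
    return sum(cf / (1.0 + rate) ** t for t, cf in enumerate(cash_flows))


def irr(cash_flows: list[float], lo: float = -0.99, hi: float = 10.0, tol: float = 1e-7) -> float:
    """Locate the rate where NPV crosses zero via bisection; assumes one sign change."""
    f_lo = npv(lo, cash_flows)
    if f_lo * npv(hi, cash_flows) > 0:
        raise ValueError("IRR is not bracketed for these cash flows")
    while hi - lo > tol:
        mid = (lo + hi) / 2.0
        if npv(mid, cash_flows) * f_lo <= 0:
            hi = mid  # root lies in the lower half
        else:
            lo, f_lo = mid, npv(mid, cash_flows)  # root lies in the upper half
    return (lo + hi) / 2.0


# e.g. irr([-1000.0, 500.0, 500.0, 500.0]) is roughly 0.234 (23.4% per period)
```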
+- Added deployment automation with Docker Compose for local development and Kubernetes manifests for production, ensuring environment parity and documenting processes in calminer-docs/admin/installation.md. +- Completed monitoring instrumentation by adding business metrics observation to project and scenario repository operations, and simulation performance tracking to Monte Carlo service with success/error status and duration metrics. +- Updated TODO list to reflect completed monitoring implementation tasks and validated changes with passing simulation tests. +- Implemented comprehensive performance monitoring for scalability (FR-006) with Prometheus metrics collection for HTTP requests, import/export operations, and general application metrics. +- Added database model for persistent metric storage with aggregation endpoints for KPIs like request latency, error rates, and throughput. +- Created FastAPI middleware for automatic request metric collection and background persistence to database. +- Extended monitoring router with performance metrics API endpoints and detailed health checks. +- Added Alembic migration for performance_metrics table and updated model imports. +- Completed concurrent interaction testing implementation, validating database transaction isolation under threading and establishing async testing framework for future concurrency enhancements. +- Implemented comprehensive deployment automation with Docker Compose configurations for development, staging, and production environments ensuring environment parity. +- Set up Kubernetes manifests with resource limits, health checks, and secrets management for production deployment. +- Configured CI/CD workflows for automated Docker image building, registry pushing, and Kubernetes deployment to staging/production environments. +- Documented deployment processes, environment configurations, and CI/CD workflows in project documentation. +- Validated deployment automation through Docker Compose configuration testing and CI/CD pipeline structure. + +## 2025-11-10 + +- Added dedicated pytest coverage for guard dependencies, exercising success plus failure paths (missing session, inactive user, missing roles, project/scenario access errors) via `tests/test_dependencies_guards.py`. +- Added integration tests in `tests/test_authorization_integration.py` verifying anonymous 401 responses, role-based 403s, and authorized project manager flows across API and UI endpoints. +- Implemented environment-driven admin bootstrap settings, wired the `bootstrap_admin` helper into FastAPI startup, added pytest coverage for creation/idempotency/reset logic, and documented operational guidance in the RBAC plan and security concept. +- Retired the legacy authentication RBAC implementation plan document after migrating its guidance into live documentation and synchronized the contributor instructions to reflect the removal. +- Completed the Authentication & RBAC checklist by shipping the new models, migrations, repositories, guard dependencies, and integration tests. +- Documented the project/scenario import/export field mapping and file format guidelines in `calminer-docs/requirements/FR-008.md`, and introduced `schemas/imports.py` with Pydantic models that normalise incoming CSV/Excel rows for projects and scenarios. +- Added `services/importers.py` to load CSV/XLSX files into the new import schemas, pulled in `openpyxl` for Excel support, and covered the parsing behaviour with `tests/test_import_parsing.py`. 
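The importer described in the last entry follows a standard shape: branch on file extension, read header-keyed rows, and feed each dict to a Pydantic row model. A hedged sketch using the standard library and `openpyxl` (the downstream validation step is indicated but elided):

```python
import csv
from pathlib import Path

from openpyxl import load_workbook


def iter_import_rows(path: Path) -> list[dict]:
    """Return header-keyed row dicts from a .csv or .xlsx upload."""
    if path.suffix.lower() == ".csv":
        with path.open(newline="", encoding="utf-8") as handle:
            return list(csv.DictReader(handle))
    workbook = load_workbook(path, read_only=True, data_only=True)
    rows = workbook.active.iter_rows(values_only=True)
    header = [str(cell) for cell in next(rows)]
    return [dict(zip(header, row)) for row in rows]


# Each dict is then validated by the row schemas in schemas/imports.py,
# e.g. ProjectImportRow(**row), before anything touches the database.
```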
+- Expanded the import ingestion workflow with staging previews, transactional persistence commits, FastAPI preview/commit endpoints under `/imports`, and new API tests (`tests/test_import_ingestion.py`, `tests/test_import_api.py`) ensuring end-to-end coverage. +- Added persistent audit logging via `ImportExportLog`, structured log emission, Prometheus metrics instrumentation, `/metrics` endpoint exposure, and updated operator/deployment documentation to guide monitoring setup. + +## 2025-11-09 + +- Captured current implementation status, requirements coverage, missing features, and prioritized roadmap in `calminer-docs/implementation_status.md` to guide future development. +- Added core SQLAlchemy domain models, shared metadata descriptors, and Alembic migration setup (with initial schema snapshot) to establish the persistence layer foundation. +- Introduced repository and unit-of-work helpers for projects, scenarios, financial inputs, and simulation parameters to support service-layer operations. +- Added SQLite-backed pytest coverage for repository and unit-of-work behaviours to validate persistence interactions. +- Exposed project and scenario CRUD APIs with validated schemas and integrated them into the FastAPI application. +- Connected project and scenario routers to new Jinja2 list/detail/edit views with HTML forms and redirects. +- Implemented FR-009 client-side enhancements with responsive navigation toggle, mobile-first scenario tables, and shared asset loading across templates. +- Added scenario comparison validator, FastAPI comparison endpoint, and comprehensive unit tests to enforce FR-009 validation rules through API errors. +- Delivered a new dashboard experience with `templates/dashboard.html`, dedicated styling, and a FastAPI route supplying real project/scenario metrics via repository helpers. +- Extended repositories with count/recency utilities and added pytest coverage, including a dashboard rendering smoke test validating empty-state messaging. +- Brought project and scenario detail pages plus their forms in line with the dashboard visuals, adding metric cards, layout grids, and refreshed CTA styles. +- Reordered project route registration to prioritize static UI paths, eliminating 422 errors on `/projects/ui` and `/projects/create`, and added pytest smoke coverage for the navigation endpoints. +- Added end-to-end integration tests for project and scenario lifecycles, validating HTML redirects, template rendering, and API interactions, and updated `ProjectRepository.get` to deduplicate joined loads for detail views. +- Updated all Jinja2 template responses to the new Starlette signature to eliminate deprecation warnings while keeping request-aware context available to the templates. +- Introduced `services/security.py` to centralize Argon2 password hashing utilities and JWT creation/verification with typed payloads, and added pytest coverage for hashing, expiry, tampering, and token type mismatch scenarios. +- Added `routes/auth.py` with registration, login, and password reset flows, refreshed auth templates with error messaging, wired navigation links, and introduced end-to-end pytest coverage for the new forms and token flows. +- Implemented cookie-based authentication session middleware with automatic access token refresh, logout handling, navigation adjustments, and documentation/test updates capturing the new behaviour. 
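The `services/security.py` utilities noted above pair Argon2 hashing with typed JWTs. A minimal sketch using `argon2-cffi` and `PyJWT`; the secret and TTL defaults here are illustrative, the real values come from `config/settings.py`:

```python
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError

_hasher = PasswordHasher()
_SECRET = "change-me"  # illustrative; the real key is CALMINER_JWT_SECRET


def hash_password(password: str) -> str:
    return _hasher.hash(password)


def verify_password(password: str, stored_hash: str) -> bool:
    try:
        return _hasher.verify(stored_hash, password)
    except VerifyMismatchError:
        return False


def create_access_token(subject: str, ttl: timedelta = timedelta(minutes=15)) -> str:
    payload = {"sub": subject, "type": "access", "exp": datetime.now(timezone.utc) + ttl}
    return jwt.encode(payload, _SECRET, algorithm="HS256")


def verify_access_token(token: str) -> dict:
    claims = jwt.decode(token, _SECRET, algorithms=["HS256"])  # raises on expiry/tamper
    if claims.get("type") != "access":
        raise ValueError("Unexpected token type")
    return claims
```

Typing the payload (`"type": "access"` vs `"refresh"`) is what lets the session middleware reject a refresh token presented where an access token is expected.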
+- Delivered idempotent seeding utilities with `scripts/initial_data.py`, entry-point runner `scripts/00_initial_data.py`, documentation updates, and pytest coverage to verify role/admin provisioning. +- Secured project and scenario routers with RBAC guard dependencies, enforced repository access checks via helper utilities, and aligned template routes with FastAPI dependency injection patterns. diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..56096f2 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1 @@ +"""Configuration package.""" diff --git a/config/database.py b/config/database.py index ff6d3c0..e4d7dd4 100644 --- a/config/database.py +++ b/config/database.py @@ -11,12 +11,23 @@ def _build_database_url() -> str: """Construct the SQLAlchemy database URL from granular environment vars. Falls back to `DATABASE_URL` for backward compatibility. + Supports SQLite when CALMINER_USE_SQLITE is set. """ legacy_url = os.environ.get("DATABASE_URL", "") if legacy_url and legacy_url.strip() != "": return legacy_url + use_sqlite = os.environ.get("CALMINER_USE_SQLITE", "").lower() in ("true", "1", "yes") + if use_sqlite: + # Use SQLite database + db_path = os.environ.get("DATABASE_PATH", "./data/calminer.db") + # Ensure the parent directory exists (dirname is empty for bare filenames) + db_dir = os.path.dirname(db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + return f"sqlite:///{db_path}" + driver = os.environ.get("DATABASE_DRIVER", "postgresql") host = os.environ.get("DATABASE_HOST") port = os.environ.get("DATABASE_PORT", "5432") @@ -54,7 +65,15 @@ def _build_database_url() -> str: DATABASE_URL = _build_database_url() engine = create_engine(DATABASE_URL, echo=True, future=True) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# Avoid expiring ORM objects on commit so that objects returned from UnitOfWork +# remain usable for the duration of the request cycle without causing +# DetachedInstanceError when accessed after the session commits. +SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, + expire_on_commit=False, +) Base = declarative_base() diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..cb30fb4 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from datetime import timedelta +from functools import lru_cache + +from typing import Optional + +from services.pricing import PricingMetadata + +from services.security import JWTSettings + + +@dataclass(frozen=True, slots=True) +class AdminBootstrapSettings: + """Default administrator bootstrap configuration.""" + + email: str + username: str + password: str + roles: tuple[str, ...]
+ force_reset: bool + + +@dataclass(frozen=True, slots=True) +class SessionSettings: + """Cookie and header configuration for session token transport.""" + + access_cookie_name: str + refresh_cookie_name: str + cookie_secure: bool + cookie_domain: Optional[str] + cookie_path: str + header_name: str + header_prefix: str + allow_header_fallback: bool + + +@dataclass(frozen=True, slots=True) +class Settings: + """Application configuration sourced from environment variables.""" + + jwt_secret_key: str = "change-me" + jwt_algorithm: str = "HS256" + jwt_access_token_minutes: int = 15 + jwt_refresh_token_days: int = 7 + session_access_cookie_name: str = "calminer_access_token" + session_refresh_cookie_name: str = "calminer_refresh_token" + session_cookie_secure: bool = False + session_cookie_domain: Optional[str] = None + session_cookie_path: str = "/" + session_header_name: str = "Authorization" + session_header_prefix: str = "Bearer" + session_allow_header_fallback: bool = True + admin_email: str = "admin@calminer.local" + admin_username: str = "admin" + admin_password: str = "ChangeMe123!" + admin_roles: tuple[str, ...] = ("admin",) + admin_force_reset: bool = False + pricing_default_payable_pct: float = 100.0 + pricing_default_currency: str | None = "USD" + pricing_moisture_threshold_pct: float = 8.0 + pricing_moisture_penalty_per_pct: float = 0.0 + + @classmethod + def from_environment(cls) -> "Settings": + """Construct settings from environment variables.""" + + return cls( + jwt_secret_key=os.getenv("CALMINER_JWT_SECRET", "change-me"), + jwt_algorithm=os.getenv("CALMINER_JWT_ALGORITHM", "HS256"), + jwt_access_token_minutes=cls._int_from_env( + "CALMINER_JWT_ACCESS_MINUTES", 15 + ), + jwt_refresh_token_days=cls._int_from_env( + "CALMINER_JWT_REFRESH_DAYS", 7 + ), + session_access_cookie_name=os.getenv( + "CALMINER_SESSION_ACCESS_COOKIE", "calminer_access_token" + ), + session_refresh_cookie_name=os.getenv( + "CALMINER_SESSION_REFRESH_COOKIE", "calminer_refresh_token" + ), + session_cookie_secure=cls._bool_from_env( + "CALMINER_SESSION_COOKIE_SECURE", False + ), + session_cookie_domain=os.getenv("CALMINER_SESSION_COOKIE_DOMAIN"), + session_cookie_path=os.getenv("CALMINER_SESSION_COOKIE_PATH", "/"), + session_header_name=os.getenv( + "CALMINER_SESSION_HEADER_NAME", "Authorization" + ), + session_header_prefix=os.getenv( + "CALMINER_SESSION_HEADER_PREFIX", "Bearer" + ), + session_allow_header_fallback=cls._bool_from_env( + "CALMINER_SESSION_ALLOW_HEADER_FALLBACK", True + ), + admin_email=os.getenv( + "CALMINER_SEED_ADMIN_EMAIL", "admin@calminer.local" + ), + admin_username=os.getenv( + "CALMINER_SEED_ADMIN_USERNAME", "admin" + ), + admin_password=os.getenv( + "CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!" 
+ ), + admin_roles=cls._parse_admin_roles( + os.getenv("CALMINER_SEED_ADMIN_ROLES") + ), + admin_force_reset=cls._bool_from_env( + "CALMINER_SEED_FORCE", False + ), + pricing_default_payable_pct=cls._float_from_env( + "CALMINER_PRICING_DEFAULT_PAYABLE_PCT", 100.0 + ), + pricing_default_currency=cls._optional_str( + "CALMINER_PRICING_DEFAULT_CURRENCY", "USD" + ), + pricing_moisture_threshold_pct=cls._float_from_env( + "CALMINER_PRICING_MOISTURE_THRESHOLD_PCT", 8.0 + ), + pricing_moisture_penalty_per_pct=cls._float_from_env( + "CALMINER_PRICING_MOISTURE_PENALTY_PER_PCT", 0.0 + ), + ) + + @staticmethod + def _int_from_env(name: str, default: int) -> int: + raw_value = os.getenv(name) + if raw_value is None: + return default + try: + return int(raw_value) + except ValueError: + return default + + @staticmethod + def _bool_from_env(name: str, default: bool) -> bool: + raw_value = os.getenv(name) + if raw_value is None: + return default + lowered = raw_value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + @staticmethod + def _parse_admin_roles(raw_value: str | None) -> tuple[str, ...]: + if not raw_value: + return ("admin",) + parts = [segment.strip() + for segment in raw_value.split(",") if segment.strip()] + if "admin" not in parts: + parts.insert(0, "admin") + seen: set[str] = set() + ordered: list[str] = [] + for role_name in parts: + if role_name not in seen: + ordered.append(role_name) + seen.add(role_name) + return tuple(ordered) + + @staticmethod + def _float_from_env(name: str, default: float) -> float: + raw_value = os.getenv(name) + if raw_value is None: + return default + try: + return float(raw_value) + except ValueError: + return default + + @staticmethod + def _optional_str(name: str, default: str | None = None) -> str | None: + raw_value = os.getenv(name) + if raw_value is None or raw_value.strip() == "": + return default + return raw_value.strip() + + def jwt_settings(self) -> JWTSettings: + """Build runtime JWT settings compatible with token helpers.""" + + return JWTSettings( + secret_key=self.jwt_secret_key, + algorithm=self.jwt_algorithm, + access_token_ttl=timedelta(minutes=self.jwt_access_token_minutes), + refresh_token_ttl=timedelta(days=self.jwt_refresh_token_days), + ) + + def session_settings(self) -> SessionSettings: + """Provide transport configuration for session tokens.""" + + return SessionSettings( + access_cookie_name=self.session_access_cookie_name, + refresh_cookie_name=self.session_refresh_cookie_name, + cookie_secure=self.session_cookie_secure, + cookie_domain=self.session_cookie_domain, + cookie_path=self.session_cookie_path, + header_name=self.session_header_name, + header_prefix=self.session_header_prefix, + allow_header_fallback=self.session_allow_header_fallback, + ) + + def admin_bootstrap_settings(self) -> AdminBootstrapSettings: + """Return configured admin bootstrap settings.""" + + return AdminBootstrapSettings( + email=self.admin_email, + username=self.admin_username, + password=self.admin_password, + roles=self.admin_roles, + force_reset=self.admin_force_reset, + ) + + def pricing_metadata(self) -> PricingMetadata: + """Build pricing metadata defaults.""" + + return PricingMetadata( + default_payable_pct=self.pricing_default_payable_pct, + default_currency=self.pricing_default_currency, + moisture_threshold_pct=self.pricing_moisture_threshold_pct, + moisture_penalty_per_pct=self.pricing_moisture_penalty_per_pct, + ) + + 
+@lru_cache(maxsize=1) +def get_settings() -> Settings: + """Return cached application settings.""" + + return Settings.from_environment() diff --git a/config/setup_production.env.example b/config/setup_production.env.example deleted file mode 100644 index fefd6f2..0000000 --- a/config/setup_production.env.example +++ /dev/null @@ -1,35 +0,0 @@ -# Copy this file to config/setup_production.env and replace values with production secrets - -# Container image and runtime configuration -CALMINER_IMAGE=registry.example.com/calminer/api:latest -CALMINER_DOMAIN=calminer.example.com -TRAEFIK_ACME_EMAIL=ops@example.com -CALMINER_API_PORT=8000 -UVICORN_WORKERS=4 -UVICORN_LOG_LEVEL=info -CALMINER_NETWORK=calminer_backend -API_LIMIT_CPUS=1.0 -API_LIMIT_MEMORY=1g -API_RESERVATION_MEMORY=512m -TRAEFIK_LIMIT_CPUS=0.5 -TRAEFIK_LIMIT_MEMORY=512m -POSTGRES_LIMIT_CPUS=1.0 -POSTGRES_LIMIT_MEMORY=2g -POSTGRES_RESERVATION_MEMORY=1g - -# Application database connection -DATABASE_DRIVER=postgresql+psycopg2 -DATABASE_HOST=production-db.internal -DATABASE_PORT=5432 -DATABASE_NAME=calminer -DATABASE_USER=calminer_app -DATABASE_PASSWORD=ChangeMe123! -DATABASE_SCHEMA=public - -# Optional consolidated SQLAlchemy URL (overrides granular settings when set) -# DATABASE_URL=postgresql+psycopg2://calminer_app:ChangeMe123!@production-db.internal:5432/calminer - -# Superuser credentials used by scripts/setup_database.py for migrations/seed data -DATABASE_SUPERUSER=postgres -DATABASE_SUPERUSER_PASSWORD=ChangeMeSuper123! -DATABASE_SUPERUSER_DB=postgres diff --git a/config/setup_staging.env.example b/config/setup_staging.env.example deleted file mode 100644 index a166e1f..0000000 --- a/config/setup_staging.env.example +++ /dev/null @@ -1,11 +0,0 @@ -# Sample environment configuration for staging deployment -DATABASE_HOST=staging-db.internal -DATABASE_PORT=5432 -DATABASE_NAME=calminer_staging -DATABASE_USER=calminer_app -DATABASE_PASSWORD= - -# Admin connection used for provisioning database and roles -DATABASE_SUPERUSER=postgres -DATABASE_SUPERUSER_PASSWORD= -DATABASE_SUPERUSER_DB=postgres diff --git a/config/setup_test.env.example b/config/setup_test.env.example deleted file mode 100644 index 2228373..0000000 --- a/config/setup_test.env.example +++ /dev/null @@ -1,14 +0,0 @@ -# Sample environment configuration for running scripts/setup_database.py against a test instance -DATABASE_DRIVER=postgresql -DATABASE_HOST=postgres -DATABASE_PORT=5432 -DATABASE_NAME=calminer_test -DATABASE_USER=calminer_test -DATABASE_PASSWORD= -# optional: specify schema if different from 'public' -#DATABASE_SCHEMA=public - -# Admin connection used for provisioning database and roles -DATABASE_SUPERUSER=postgres -DATABASE_SUPERUSER_PASSWORD= -DATABASE_SUPERUSER_DB=postgres diff --git a/dependencies.py b/dependencies.py new file mode 100644 index 0000000..5047755 --- /dev/null +++ b/dependencies.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +from collections.abc import Callable, Iterable, Generator + +from fastapi import Depends, HTTPException, Request, status + +from config.settings import Settings, get_settings +from models import Project, Role, Scenario, User +from services.authorization import ( + ensure_project_access as ensure_project_access_helper, + ensure_scenario_access as ensure_scenario_access_helper, + ensure_scenario_in_project as ensure_scenario_in_project_helper, +) +from services.exceptions import AuthorizationError, EntityNotFoundError +from services.security import JWTSettings +from services.session import ( + AuthSession, 
+ SessionStrategy, + SessionTokens, + build_session_strategy, + extract_session_tokens, +) +from services.unit_of_work import UnitOfWork +from services.importers import ImportIngestionService +from services.pricing import PricingMetadata +from services.navigation import NavigationService +from services.scenario_evaluation import ScenarioPricingConfig, ScenarioPricingEvaluator +from services.repositories import pricing_settings_to_metadata + + +def get_unit_of_work() -> Generator[UnitOfWork, None, None]: + """FastAPI dependency yielding a unit-of-work instance.""" + + with UnitOfWork() as uow: + yield uow + + +_IMPORT_INGESTION_SERVICE = ImportIngestionService(lambda: UnitOfWork()) + + +def get_import_ingestion_service() -> ImportIngestionService: + """Provide singleton import ingestion service.""" + + return _IMPORT_INGESTION_SERVICE + + +def get_application_settings() -> Settings: + """Provide cached application settings instance.""" + + return get_settings() + + +def get_pricing_metadata( + settings: Settings = Depends(get_application_settings), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> PricingMetadata: + """Return pricing metadata defaults sourced from persisted pricing settings.""" + + stored = uow.get_pricing_metadata() + if stored is not None: + return stored + + fallback = settings.pricing_metadata() + seed_result = uow.ensure_default_pricing_settings(metadata=fallback) + return pricing_settings_to_metadata(seed_result.settings) + + +def get_navigation_service( + uow: UnitOfWork = Depends(get_unit_of_work), +) -> NavigationService: + if not uow.navigation: + raise RuntimeError("Navigation repository is not initialised") + return NavigationService(uow.navigation) + + +def get_pricing_evaluator( + metadata: PricingMetadata = Depends(get_pricing_metadata), +) -> ScenarioPricingEvaluator: + """Provide a configured scenario pricing evaluator.""" + + return ScenarioPricingEvaluator(ScenarioPricingConfig(metadata=metadata)) + + +def get_jwt_settings() -> JWTSettings: + """Provide JWT runtime configuration derived from settings.""" + + return get_settings().jwt_settings() + + +def get_session_strategy( + settings: Settings = Depends(get_application_settings), +) -> SessionStrategy: + """Yield configured session transport strategy.""" + + return build_session_strategy(settings.session_settings()) + + +def get_session_tokens( + request: Request, + strategy: SessionStrategy = Depends(get_session_strategy), +) -> SessionTokens: + """Extract raw session tokens from the incoming request.""" + + existing = getattr(request.state, "auth_session", None) + if isinstance(existing, AuthSession): + return existing.tokens + + tokens = extract_session_tokens(request, strategy) + request.state.auth_session = AuthSession(tokens=tokens) + return tokens + + +def get_auth_session( + request: Request, + tokens: SessionTokens = Depends(get_session_tokens), +) -> AuthSession: + """Provide authentication session context for the current request.""" + + existing = getattr(request.state, "auth_session", None) + if isinstance(existing, AuthSession): + return existing + + if tokens.is_empty: + session = AuthSession.anonymous() + else: + session = AuthSession(tokens=tokens) + request.state.auth_session = session + return session + + +def get_current_user( + session: AuthSession = Depends(get_auth_session), +) -> User | None: + """Return the current authenticated user if present.""" + + return session.user + + +def require_current_user( + session: AuthSession = Depends(get_auth_session), +) -> User: + """Ensure 
that a request is authenticated and return the user context."""
+
+    if session.user is None or session.tokens.is_empty:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Authentication required.",
+        )
+    return session.user
+
+
+def require_authenticated_user(
+    user: User = Depends(require_current_user),
+) -> User:
+    """Ensure the current user account is active."""
+
+    if not user.is_active:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="User account is disabled.",
+        )
+    return user
+
+
+def require_authenticated_user_html(
+    request: Request,
+    session: AuthSession = Depends(get_auth_session),
+) -> User:
+    """HTML-aware authenticated dependency that redirects anonymous sessions."""
+
+    user = session.user
+    if user is None or session.tokens.is_empty:
+        login_url = str(request.url_for("auth.login_form"))
+        raise HTTPException(
+            status_code=status.HTTP_303_SEE_OTHER,
+            headers={"Location": login_url},
+        )
+
+    if not user.is_active:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="User account is disabled.",
+        )
+    return user
+
+
+def _user_role_names(user: User) -> set[str]:
+    roles: Iterable[Role] = getattr(user, "roles", []) or []
+    return {role.name for role in roles}
+
+
+def require_roles(*roles: str) -> Callable[[User], User]:
+    """Dependency factory enforcing membership in one of the given roles."""
+
+    required = tuple(role.strip() for role in roles if role.strip())
+    if not required:
+        raise ValueError("require_roles requires at least one role name")
+
+    def _dependency(user: User = Depends(require_authenticated_user)) -> User:
+        if user.is_superuser:
+            return user
+
+        role_names = _user_role_names(user)
+        if not any(role in role_names for role in required):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Insufficient permissions for this action.",
+            )
+        return user
+
+    return _dependency
+
+
+def require_any_role(*roles: str) -> Callable[[User], User]:
+    """Alias of require_roles for readability."""
+
+    return require_roles(*roles)
+
+
+def require_roles_html(*roles: str) -> Callable[[Request], User]:
+    """Dependency factory enforcing role membership for HTML responses; anonymous sessions are redirected to the login form."""
+
+    required = tuple(role.strip() for role in roles if role.strip())
+    if not required:
+        raise ValueError("require_roles_html requires at least one role name")
+
+    def _dependency(
+        request: Request,
+        session: AuthSession = Depends(get_auth_session),
+    ) -> User:
+        user = session.user
+        if user is None:
+            login_url = str(request.url_for("auth.login_form"))
+            raise HTTPException(
+                status_code=status.HTTP_303_SEE_OTHER,
+                headers={"Location": login_url},
+            )
+
+        if user.is_superuser:
+            return user
+
+        role_names = _user_role_names(user)
+        if not any(role in role_names for role in required):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Insufficient permissions for this action.",
+            )
+        return user
+
+    return _dependency
+
+
+def require_any_role_html(*roles: str) -> Callable[[Request], User]:
+    """Alias of require_roles_html for readability."""
+
+    return require_roles_html(*roles)
+
+
+def require_project_resource(
+    *,
+    require_manage: bool = False,
+    user_dependency: Callable[..., User] = require_authenticated_user,
+) -> Callable[[int], Project]:
+    """Dependency factory that resolves a project with authorization checks."""
+
+    def _dependency(
+        project_id: int,
+        user: User = Depends(user_dependency),
+        uow: UnitOfWork =
Depends(get_unit_of_work), + ) -> Project: + try: + return ensure_project_access_helper( + uow, + project_id=project_id, + user=user, + require_manage=require_manage, + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + except AuthorizationError as exc: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(exc), + ) from exc + + return _dependency + + +def require_scenario_resource( + *, + require_manage: bool = False, + with_children: bool = False, + user_dependency: Callable[..., User] = require_authenticated_user, +) -> Callable[[int], Scenario]: + """Dependency factory that resolves a scenario with authorization checks.""" + + def _dependency( + scenario_id: int, + user: User = Depends(user_dependency), + uow: UnitOfWork = Depends(get_unit_of_work), + ) -> Scenario: + try: + return ensure_scenario_access_helper( + uow, + scenario_id=scenario_id, + user=user, + require_manage=require_manage, + with_children=with_children, + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + except AuthorizationError as exc: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(exc), + ) from exc + + return _dependency + + +def require_project_scenario_resource( + *, + require_manage: bool = False, + with_children: bool = False, + user_dependency: Callable[..., User] = require_authenticated_user, +) -> Callable[[int, int], Scenario]: + """Dependency factory ensuring a scenario belongs to the given project and is accessible.""" + + def _dependency( + project_id: int, + scenario_id: int, + user: User = Depends(user_dependency), + uow: UnitOfWork = Depends(get_unit_of_work), + ) -> Scenario: + try: + return ensure_scenario_in_project_helper( + uow, + project_id=project_id, + scenario_id=scenario_id, + user=user, + require_manage=require_manage, + with_children=with_children, + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + except AuthorizationError as exc: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(exc), + ) from exc + + return _dependency + + +def require_project_resource_html( + *, require_manage: bool = False +) -> Callable[[int], Project]: + """HTML-aware project loader that redirects anonymous sessions.""" + + return require_project_resource( + require_manage=require_manage, + user_dependency=require_authenticated_user_html, + ) + + +def require_scenario_resource_html( + *, + require_manage: bool = False, + with_children: bool = False, +) -> Callable[[int], Scenario]: + """HTML-aware scenario loader that redirects anonymous sessions.""" + + return require_scenario_resource( + require_manage=require_manage, + with_children=with_children, + user_dependency=require_authenticated_user_html, + ) + + +def require_project_scenario_resource_html( + *, + require_manage: bool = False, + with_children: bool = False, +) -> Callable[[int, int], Scenario]: + """HTML-aware project-scenario loader redirecting anonymous sessions.""" + + return require_project_scenario_resource( + require_manage=require_manage, + with_children=with_children, + user_dependency=require_authenticated_user_html, + ) diff --git a/docker-compose.override.yml b/docker-compose.override.yml new file mode 100644 index 0000000..b056724 --- /dev/null +++ b/docker-compose.override.yml @@ -0,0 +1,59 @@ +version: 
"3.8" + +services: + app: + build: + context: . + dockerfile: Dockerfile + args: + APT_CACHE_URL: ${APT_CACHE_URL:-} + environment: + - ENVIRONMENT=development + - DEBUG=true + - LOG_LEVEL=DEBUG + # Override database to use local postgres service + - DATABASE_HOST=postgres + - DATABASE_PORT=5432 + - DATABASE_USER=calminer + - DATABASE_PASSWORD=calminer_password + - DATABASE_NAME=calminer_db + - DATABASE_DRIVER=postgresql + # Development-specific settings + - CALMINER_EXPORT_MAX_ROWS=1000 + - CALMINER_IMPORT_MAX_ROWS=10000 + volumes: + # Mount source code for live reloading (if using --reload) + - .:/app:ro + # Override logs volume to local for easier access + - ./logs:/app/logs + ports: + - "8003:8003" + # Override command for development with reload + command: + [ + "main:app", + "--host", + "0.0.0.0", + "--port", + "8003", + "--reload", + "--workers", + "1", + ] + depends_on: + - postgres + restart: unless-stopped + + postgres: + environment: + - POSTGRES_USER=calminer + - POSTGRES_PASSWORD=calminer_password + - POSTGRES_DB=calminer_db + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + restart: unless-stopped + +volumes: + postgres_data: diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml new file mode 100644 index 0000000..cd3e264 --- /dev/null +++ b/docker-compose.prod.yml @@ -0,0 +1,77 @@ +version: "3.8" + +services: + app: + build: + context: . + dockerfile: Dockerfile + args: + APT_CACHE_URL: ${APT_CACHE_URL:-} + environment: + - ENVIRONMENT=production + - DEBUG=false + - LOG_LEVEL=WARNING + # Database configuration - must be provided externally + - DATABASE_HOST=${DATABASE_HOST} + - DATABASE_PORT=${DATABASE_PORT:-5432} + - DATABASE_USER=${DATABASE_USER} + - DATABASE_PASSWORD=${DATABASE_PASSWORD} + - DATABASE_NAME=${DATABASE_NAME} + - DATABASE_DRIVER=postgresql + # Production-specific settings + - CALMINER_EXPORT_MAX_ROWS=100000 + - CALMINER_IMPORT_MAX_ROWS=100000 + - CALMINER_EXPORT_METADATA=true + - CALMINER_IMPORT_STAGING_TTL=3600 + ports: + - "8003:8003" + depends_on: + postgres: + condition: service_healthy + restart: unless-stopped + # Production health checks + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8003/health"] + interval: 60s + timeout: 30s + retries: 5 + start_period: 60s + # Resource limits for production + deploy: + resources: + limits: + cpus: "1.0" + memory: 1G + reservations: + cpus: "0.5" + memory: 512M + + postgres: + environment: + - POSTGRES_USER=${DATABASE_USER} + - POSTGRES_PASSWORD=${DATABASE_PASSWORD} + - POSTGRES_DB=${DATABASE_NAME} + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + restart: unless-stopped + # Production postgres health check + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${DATABASE_USER} -d ${DATABASE_NAME}"] + interval: 60s + timeout: 30s + retries: 5 + start_period: 60s + # Resource limits for postgres + deploy: + resources: + limits: + cpus: "1.0" + memory: 2G + reservations: + cpus: "0.5" + memory: 1G + +volumes: + postgres_data: diff --git a/docker-compose.staging.yml b/docker-compose.staging.yml new file mode 100644 index 0000000..f75682b --- /dev/null +++ b/docker-compose.staging.yml @@ -0,0 +1,62 @@ +version: "3.8" + +services: + app: + build: + context: . 
+ dockerfile: Dockerfile + args: + APT_CACHE_URL: ${APT_CACHE_URL:-} + environment: + - ENVIRONMENT=staging + - DEBUG=false + - LOG_LEVEL=INFO + # Database configuration - can be overridden by external env + - DATABASE_HOST=${DATABASE_HOST:-postgres} + - DATABASE_PORT=${DATABASE_PORT:-5432} + - DATABASE_USER=${DATABASE_USER:-calminer} + - DATABASE_PASSWORD=${DATABASE_PASSWORD} + - DATABASE_NAME=${DATABASE_NAME:-calminer_db} + - DATABASE_DRIVER=postgresql + # Staging-specific settings + - CALMINER_EXPORT_MAX_ROWS=50000 + - CALMINER_IMPORT_MAX_ROWS=50000 + - CALMINER_EXPORT_METADATA=true + - CALMINER_IMPORT_STAGING_TTL=600 + ports: + - "8003:8003" + depends_on: + - postgres + restart: unless-stopped + # Health check for staging + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8003/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + + postgres: + environment: + - POSTGRES_USER=${DATABASE_USER:-calminer} + - POSTGRES_PASSWORD=${DATABASE_PASSWORD} + - POSTGRES_DB=${DATABASE_NAME:-calminer_db} + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + restart: unless-stopped + # Health check for postgres + healthcheck: + test: + [ + "CMD-SHELL", + "pg_isready -U ${DATABASE_USER:-calminer} -d ${DATABASE_NAME:-calminer_db}", + ] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + +volumes: + postgres_data: diff --git a/docker-compose.yml b/docker-compose.yml index 9680f1e..f983cf8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,11 +8,13 @@ services: ports: - "8003:8003" environment: - - DATABASE_HOST=postgres - - DATABASE_PORT=5432 - - DATABASE_USER=calminer - - DATABASE_PASSWORD=calminer_password - - DATABASE_NAME=calminer_db + # Environment-specific variables should be set in override files + - ENVIRONMENT=${ENVIRONMENT:-production} + - DATABASE_HOST=${DATABASE_HOST:-postgres} + - DATABASE_PORT=${DATABASE_PORT:-5432} + - DATABASE_USER=${DATABASE_USER} + - DATABASE_PASSWORD=${DATABASE_PASSWORD} + - DATABASE_NAME=${DATABASE_NAME} - DATABASE_DRIVER=postgresql depends_on: - postgres @@ -23,9 +25,9 @@ services: postgres: image: postgres:17 environment: - - POSTGRES_USER=calminer - - POSTGRES_PASSWORD=calminer_password - - POSTGRES_DB=calminer_db + - POSTGRES_USER=${DATABASE_USER} + - POSTGRES_PASSWORD=${DATABASE_PASSWORD} + - POSTGRES_DB=${DATABASE_NAME} ports: - "5432:5432" volumes: diff --git a/k8s/configmap.yaml b/k8s/configmap.yaml new file mode 100644 index 0000000..8773639 --- /dev/null +++ b/k8s/configmap.yaml @@ -0,0 +1,14 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: calminer-config +data: + DATABASE_HOST: "calminer-db" + DATABASE_PORT: "5432" + DATABASE_USER: "calminer" + DATABASE_NAME: "calminer_db" + DATABASE_DRIVER: "postgresql" + CALMINER_EXPORT_MAX_ROWS: "10000" + CALMINER_EXPORT_METADATA: "true" + CALMINER_IMPORT_STAGING_TTL: "300" + CALMINER_IMPORT_MAX_ROWS: "50000" diff --git a/k8s/deployment.yaml b/k8s/deployment.yaml new file mode 100644 index 0000000..c15682c --- /dev/null +++ b/k8s/deployment.yaml @@ -0,0 +1,54 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: calminer-app + labels: + app: calminer +spec: + replicas: 3 + selector: + matchLabels: + app: calminer + template: + metadata: + labels: + app: calminer + spec: + containers: + - name: calminer + image: registry.example.com/calminer:latest + ports: + - containerPort: 8003 + envFrom: + - configMapRef: + name: calminer-config + - secretRef: + name: calminer-secrets + resources: + requests: + memory: "256Mi" + cpu: 
"250m" + limits: + memory: "512Mi" + cpu: "500m" + livenessProbe: + httpGet: + path: /health + port: 8003 + initialDelaySeconds: 30 + periodSeconds: 10 + readinessProbe: + httpGet: + path: /health + port: 8003 + initialDelaySeconds: 5 + periodSeconds: 5 + initContainers: + - name: wait-for-db + image: postgres:17 + command: + [ + "sh", + "-c", + "until pg_isready -h calminer-db -p 5432; do echo waiting for database; sleep 2; done;", + ] diff --git a/k8s/ingress.yaml b/k8s/ingress.yaml new file mode 100644 index 0000000..36a738a --- /dev/null +++ b/k8s/ingress.yaml @@ -0,0 +1,18 @@ +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: calminer-ingress + annotations: + nginx.ingress.kubernetes.io/rewrite-target: / +spec: + rules: + - host: calminer.example.com + http: + paths: + - path: / + pathType: Prefix + backend: + service: + name: calminer-service + port: + number: 80 diff --git a/k8s/postgres-service.yaml b/k8s/postgres-service.yaml new file mode 100644 index 0000000..05eeb45 --- /dev/null +++ b/k8s/postgres-service.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Service +metadata: + name: calminer-db + labels: + app: calminer-db +spec: + selector: + app: calminer-db + ports: + - port: 5432 + targetPort: 5432 + clusterIP: None # Headless service for StatefulSet diff --git a/k8s/postgres.yaml b/k8s/postgres.yaml new file mode 100644 index 0000000..2e6c29c --- /dev/null +++ b/k8s/postgres.yaml @@ -0,0 +1,48 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: calminer-db +spec: + serviceName: calminer-db + replicas: 1 + selector: + matchLabels: + app: calminer-db + template: + metadata: + labels: + app: calminer-db + spec: + containers: + - name: postgres + image: postgres:17 + ports: + - containerPort: 5432 + env: + - name: POSTGRES_USER + value: "calminer" + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: calminer-secrets + key: DATABASE_PASSWORD + - name: POSTGRES_DB + value: "calminer_db" + resources: + requests: + memory: "256Mi" + cpu: "250m" + limits: + memory: "512Mi" + cpu: "500m" + volumeMounts: + - name: postgres-storage + mountPath: /var/lib/postgresql/data + volumeClaimTemplates: + - metadata: + name: postgres-storage + spec: + accessModes: ["ReadWriteOnce"] + resources: + requests: + storage: 10Gi diff --git a/k8s/secret.yaml b/k8s/secret.yaml new file mode 100644 index 0000000..f49ae32 --- /dev/null +++ b/k8s/secret.yaml @@ -0,0 +1,8 @@ +apiVersion: v1 +kind: Secret +metadata: + name: calminer-secrets +type: Opaque +data: + DATABASE_PASSWORD: Y2FsbWluZXJfcGFzc3dvcmQ= # base64 encoded 'calminer_password' + CALMINER_SEED_ADMIN_PASSWORD: Q2hhbmdlTWUxMjMh # base64 encoded 'ChangeMe123!' 
diff --git a/k8s/service.yaml b/k8s/service.yaml new file mode 100644 index 0000000..de72195 --- /dev/null +++ b/k8s/service.yaml @@ -0,0 +1,14 @@ +apiVersion: v1 +kind: Service +metadata: + name: calminer-service + labels: + app: calminer +spec: + selector: + app: calminer + ports: + - port: 80 + targetPort: 8003 + protocol: TCP + type: ClusterIP diff --git a/main.py b/main.py index 858296d..fe5f31c 100644 --- a/main.py +++ b/main.py @@ -1,28 +1,88 @@ -from routes.distributions import router as distributions_router -from routes.ui import router as ui_router -from routes.parameters import router as parameters_router +import logging +from contextlib import asynccontextmanager from typing import Awaitable, Callable from fastapi import FastAPI, Request, Response from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse + +from config.settings import get_settings +from middleware.auth_session import AuthSessionMiddleware +from middleware.metrics import MetricsMiddleware from middleware.validation import validate_json -from config.database import Base, engine +from routes.auth import router as auth_router +from routes.dashboard import router as dashboard_router +from routes.calculations import router as calculations_router +from routes.imports import router as imports_router +from routes.exports import router as exports_router +from routes.projects import router as projects_router +from routes.reports import router as reports_router from routes.scenarios import router as scenarios_router -from routes.costs import router as costs_router -from routes.consumption import router as consumption_router -from routes.production import router as production_router -from routes.equipment import router as equipment_router -from routes.reporting import router as reporting_router -from routes.currencies import router as currencies_router -from routes.simulations import router as simulations_router -from routes.maintenance import router as maintenance_router -from routes.settings import router as settings_router -from routes.users import router as users_router +from routes.ui import router as ui_router +from routes.navigation import router as navigation_router +from monitoring import router as monitoring_router +from services.bootstrap import bootstrap_admin, bootstrap_pricing_settings +from scripts.init_db import init_db as init_db_script -# Initialize database schema -Base.metadata.create_all(bind=engine) +logger = logging.getLogger(__name__) -app = FastAPI() + +async def _bootstrap_startup() -> None: + settings = get_settings() + admin_settings = settings.admin_bootstrap_settings() + pricing_metadata = settings.pricing_metadata() + try: + try: + init_db_script() + except Exception: + logger.exception( + "DB initializer failed; continuing to bootstrap (non-fatal)") + + role_result, admin_result = bootstrap_admin(settings=admin_settings) + pricing_result = bootstrap_pricing_settings(metadata=pricing_metadata) + logger.info( + "Admin bootstrap completed: roles=%s created=%s updated=%s rotated=%s assigned=%s", + role_result.ensured, + admin_result.created_user, + admin_result.updated_user, + admin_result.password_rotated, + admin_result.roles_granted, + ) + try: + seed = pricing_result.seed + slug = getattr(seed.settings, "slug", None) if seed and getattr( + seed, "settings", None) else None + created = getattr(seed, "created", None) + updated_fields = getattr(seed, "updated_fields", None) + impurity_upserts = getattr(seed, "impurity_upserts", None) + logger.info( + "Pricing 
settings bootstrap completed: slug=%s created=%s updated_fields=%s impurity_upserts=%s projects_assigned=%s", + slug, + created, + updated_fields, + impurity_upserts, + pricing_result.projects_assigned, + ) + except Exception: + logger.info( + "Pricing settings bootstrap completed (partial): projects_assigned=%s", + pricing_result.projects_assigned, + ) + except Exception: # pragma: no cover - defensive logging + logger.exception( + "Failed to bootstrap administrator or pricing settings") + + +@asynccontextmanager +async def app_lifespan(_: FastAPI): + await _bootstrap_startup() + yield + + +app = FastAPI(lifespan=app_lifespan) + +app.add_middleware(AuthSessionMiddleware) +app.add_middleware(MetricsMiddleware) @app.middleware("http") @@ -37,20 +97,23 @@ async def health() -> dict[str, str]: return {"status": "ok"} -app.mount("/static", StaticFiles(directory="static"), name="static") +@app.get("/favicon.ico", include_in_schema=False) +async def favicon() -> Response: + static_directory = "static" + favicon_img = "favicon.ico" + return FileResponse(f"{static_directory}/{favicon_img}") -# Include API routers + +app.include_router(dashboard_router) +app.include_router(calculations_router) +app.include_router(auth_router) +app.include_router(imports_router) +app.include_router(exports_router) +app.include_router(projects_router) app.include_router(scenarios_router) -app.include_router(parameters_router) -app.include_router(distributions_router) -app.include_router(costs_router) -app.include_router(consumption_router) -app.include_router(simulations_router) -app.include_router(production_router) -app.include_router(equipment_router) -app.include_router(maintenance_router) -app.include_router(reporting_router) -app.include_router(currencies_router) -app.include_router(settings_router) +app.include_router(reports_router) app.include_router(ui_router) -app.include_router(users_router) +app.include_router(monitoring_router) +app.include_router(navigation_router) + +app.mount("/static", StaticFiles(directory="static"), name="static") diff --git a/middleware/auth_session.py b/middleware/auth_session.py new file mode 100644 index 0000000..c697fb1 --- /dev/null +++ b/middleware/auth_session.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Iterable, Optional + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.types import ASGIApp + +from config.settings import Settings, get_settings +from sqlalchemy.orm.exc import DetachedInstanceError +from models import User +from monitoring.metrics import ACTIVE_CONNECTIONS +from services.exceptions import EntityNotFoundError +from services.security import ( + JWTSettings, + TokenDecodeError, + TokenError, + TokenExpiredError, + TokenTypeMismatchError, + create_access_token, + create_refresh_token, + decode_access_token, + decode_refresh_token, +) +from services.session import ( + AuthSession, + SessionStrategy, + SessionTokens, + build_session_strategy, + clear_session_cookies, + extract_session_tokens, + set_session_cookies, +) +from services.unit_of_work import UnitOfWork + +_AUTH_SCOPE = "auth" + + +@dataclass(slots=True) +class _ResolutionResult: + session: AuthSession + strategy: SessionStrategy + jwt_settings: JWTSettings + + +class AuthSessionMiddleware(BaseHTTPMiddleware): + """Resolve authenticated users from session cookies and refresh tokens.""" + + _active_sessions: int = 0 + + def __init__( + 
self, + app: ASGIApp, + *, + settings_provider: Callable[[], Settings] = get_settings, + unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork, + refresh_scopes: Iterable[str] | None = None, + ) -> None: + super().__init__(app) + self._settings_provider = settings_provider + self._unit_of_work_factory = unit_of_work_factory + self._refresh_scopes = tuple( + refresh_scopes) if refresh_scopes else (_AUTH_SCOPE,) + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + resolved = self._resolve_session(request) + + # Track active sessions for authenticated users + try: + user_active = bool(resolved.session.user and getattr( + resolved.session.user, "is_active", False)) + except DetachedInstanceError: + user_active = False + + if user_active: + AuthSessionMiddleware._active_sessions += 1 + ACTIVE_CONNECTIONS.set(AuthSessionMiddleware._active_sessions) + + response: Response | None = None + try: + response = await call_next(request) + return response + finally: + # Always decrement the active sessions counter if we incremented it. + if user_active: + AuthSessionMiddleware._active_sessions = max( + 0, AuthSessionMiddleware._active_sessions - 1) + ACTIVE_CONNECTIONS.set(AuthSessionMiddleware._active_sessions) + + # Only apply session cookies if a response was produced by downstream + # application. If an exception occurred before a response was created + # we avoid raising another error here. + import logging + if response is not None: + try: + self._apply_session(response, resolved) + except Exception: + logging.getLogger(__name__).exception( + "Failed to apply session cookies to response" + ) + else: + logging.getLogger(__name__).debug( + "AuthSessionMiddleware: no response produced by downstream app (response is None)" + ) + + def _resolve_session(self, request: Request) -> _ResolutionResult: + settings = self._settings_provider() + jwt_settings = settings.jwt_settings() + strategy = build_session_strategy(settings.session_settings()) + + tokens = extract_session_tokens(request, strategy) + session = AuthSession(tokens=tokens) + request.state.auth_session = session + + if tokens.access_token: + if self._try_access_token(session, tokens, jwt_settings): + return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings) + + if tokens.refresh_token: + self._try_refresh_token( + session, tokens.refresh_token, jwt_settings) + + return _ResolutionResult(session=session, strategy=strategy, jwt_settings=jwt_settings) + + def _try_access_token( + self, + session: AuthSession, + tokens: SessionTokens, + jwt_settings: JWTSettings, + ) -> bool: + try: + payload = decode_access_token( + tokens.access_token or "", jwt_settings) + except TokenExpiredError: + return False + except (TokenDecodeError, TokenTypeMismatchError, TokenError): + session.mark_cleared() + return False + + user = self._load_user(payload.sub) + if not user or not user.is_active or _AUTH_SCOPE not in payload.scopes: + session.mark_cleared() + return False + + session.user = user + session.scopes = tuple(payload.scopes) + session.set_role_slugs(role.name for role in getattr(user, "roles", []) if role) + return True + + def _try_refresh_token( + self, + session: AuthSession, + refresh_token: str, + jwt_settings: JWTSettings, + ) -> None: + try: + payload = decode_refresh_token(refresh_token, jwt_settings) + except (TokenExpiredError, TokenDecodeError, TokenTypeMismatchError, TokenError): + session.mark_cleared() + return + + user = self._load_user(payload.sub) + if not user 
or not user.is_active or not self._is_refresh_scope_allowed(payload.scopes):
+            session.mark_cleared()
+            return
+
+        session.user = user
+        session.scopes = tuple(payload.scopes)
+        session.set_role_slugs(role.name for role in getattr(user, "roles", []) if role)
+
+        access_token = create_access_token(
+            str(user.id),
+            jwt_settings,
+            scopes=payload.scopes,
+        )
+        new_refresh = create_refresh_token(
+            str(user.id),
+            jwt_settings,
+            scopes=payload.scopes,
+        )
+        session.issue_tokens(access_token=access_token,
+                             refresh_token=new_refresh)
+
+    def _is_refresh_scope_allowed(self, scopes: Iterable[str]) -> bool:
+        candidate_scopes = set(scopes)
+        return any(scope in candidate_scopes for scope in self._refresh_scopes)
+
+    def _load_user(self, subject: str) -> Optional[User]:
+        try:
+            user_id = int(subject)
+        except ValueError:
+            return None
+
+        with self._unit_of_work_factory() as uow:
+            if not uow.users:
+                return None
+            try:
+                user = uow.users.get(user_id, with_roles=True)
+            except EntityNotFoundError:
+                return None
+            return user
+
+    def _apply_session(self, response: Response, resolved: _ResolutionResult) -> None:
+        session = resolved.session
+        if session.clear_cookies:
+            clear_session_cookies(response, resolved.strategy)
+            return
+
+        if session.issued_access_token:
+            refresh_token = session.issued_refresh_token or session.tokens.refresh_token
+            set_session_cookies(
+                response,
+                access_token=session.issued_access_token,
+                refresh_token=refresh_token,
+                strategy=resolved.strategy,
+                jwt_settings=resolved.jwt_settings,
+            )
diff --git a/middleware/metrics.py b/middleware/metrics.py
new file mode 100644
index 0000000..38a297d
--- /dev/null
+++ b/middleware/metrics.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import time
+from typing import Callable
+
+from fastapi import Request, Response
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from monitoring.metrics import observe_request
+from services.metrics import get_metrics_service
+
+
+class MetricsMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response:
+        start_time = time.time()
+        response = await call_next(request)
+        process_time = time.time() - start_time
+
+        observe_request(
+            method=request.method,
+            endpoint=request.url.path,
+            status=response.status_code,
+            seconds=process_time,
+        )
+
+        # Store in database asynchronously
+        background_tasks = getattr(request.state, "background_tasks", None)
+        if background_tasks:
+            background_tasks.add_task(
+                store_request_metric,
+                method=request.method,
+                endpoint=request.url.path,
+                status_code=response.status_code,
+                duration_seconds=process_time,
+            )
+
+        return response
+
+
+async def store_request_metric(
+    method: str, endpoint: str, status_code: int, duration_seconds: float
+) -> None:
+    """Store request metric in database."""
+    try:
+        service = get_metrics_service()
+        service.store_metric(
+            metric_name="http_request",
+            value=duration_seconds,
+            labels={"method": method, "endpoint": endpoint,
+                    "status": status_code},
+            endpoint=endpoint,
+            method=method,
+            status_code=status_code,
+            duration_seconds=duration_seconds,
+        )
+    except Exception:
+        # Intentionally swallowed: metrics persistence must never fail the request.
+        pass
diff --git a/middleware/validation.py b/middleware/validation.py
index 9f2249e..cd238d5 100644
--- a/middleware/validation.py
+++ b/middleware/validation.py
@@ -10,10 +10,14 @@ async def validate_json(
 ) -> Response:
     # Only validate JSON for requests with a body
     if request.method in ("POST", "PUT", "PATCH"):
-        try:
- # attempt to parse json body - await request.json() - except Exception: - raise HTTPException(status_code=400, detail="Invalid JSON payload") + # Only attempt JSON parsing when the client indicates a JSON content type. + content_type = (request.headers.get("content-type") or "").lower() + if "json" in content_type: + try: + # attempt to parse json body + await request.json() + except Exception: + raise HTTPException( + status_code=400, detail="Invalid JSON payload") response = await call_next(request) return response diff --git a/models/__init__.py b/models/__init__.py index a46e508..2ee714e 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,10 +1,72 @@ -""" -models package initializer. Import key models so they're registered -with the shared Base.metadata when the package is imported by tests. -""" +"""Database models and shared metadata for the CalMiner domain.""" -from . import application_setting # noqa: F401 -from . import currency # noqa: F401 -from . import role # noqa: F401 -from . import user # noqa: F401 -from . import theme_setting # noqa: F401 +from .financial_input import FinancialInput +from .metadata import ( + COST_BUCKET_METADATA, + RESOURCE_METADATA, + STOCHASTIC_VARIABLE_METADATA, + ResourceDescriptor, + StochasticVariableDescriptor, +) +from .performance_metric import PerformanceMetric +from .pricing_settings import ( + PricingImpuritySettings, + PricingMetalSettings, + PricingSettings, +) +from .enums import ( + CostBucket, + DistributionType, + FinancialCategory, + MiningOperationType, + ResourceType, + ScenarioStatus, + StochasticVariable, +) +from .project import Project +from .scenario import Scenario +from .simulation_parameter import SimulationParameter +from .user import Role, User, UserRole, password_context +from .navigation import NavigationGroup, NavigationLink + +from .profitability_snapshot import ProjectProfitability, ScenarioProfitability +from .capex_snapshot import ProjectCapexSnapshot, ScenarioCapexSnapshot +from .opex_snapshot import ( + ProjectOpexSnapshot, + ScenarioOpexSnapshot, +) + +__all__ = [ + "FinancialCategory", + "FinancialInput", + "MiningOperationType", + "Project", + "ProjectProfitability", + "ProjectCapexSnapshot", + "ProjectOpexSnapshot", + "PricingSettings", + "PricingMetalSettings", + "PricingImpuritySettings", + "Scenario", + "ScenarioProfitability", + "ScenarioCapexSnapshot", + "ScenarioOpexSnapshot", + "ScenarioStatus", + "DistributionType", + "SimulationParameter", + "ResourceType", + "CostBucket", + "StochasticVariable", + "RESOURCE_METADATA", + "COST_BUCKET_METADATA", + "STOCHASTIC_VARIABLE_METADATA", + "ResourceDescriptor", + "StochasticVariableDescriptor", + "User", + "Role", + "UserRole", + "password_context", + "PerformanceMetric", + "NavigationGroup", + "NavigationLink", +] diff --git a/models/application_setting.py b/models/application_setting.py deleted file mode 100644 index ed98160..0000000 --- a/models/application_setting.py +++ /dev/null @@ -1,38 +0,0 @@ -from datetime import datetime -from typing import Optional - -from sqlalchemy import Boolean, DateTime, String, Text -from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy.sql import func - -from config.database import Base - - -class ApplicationSetting(Base): - __tablename__ = "application_setting" - - id: Mapped[int] = mapped_column(primary_key=True, index=True) - key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False) - value: Mapped[str] = mapped_column(Text, nullable=False) - value_type: Mapped[str] = mapped_column( 
- String(32), nullable=False, default="string" - ) - category: Mapped[str] = mapped_column( - String(32), nullable=False, default="general" - ) - description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - is_editable: Mapped[bool] = mapped_column( - Boolean, nullable=False, default=True - ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) - updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - server_default=func.now(), - onupdate=func.now(), - nullable=False, - ) - - def __repr__(self) -> str: - return f"" diff --git a/models/capex.py b/models/capex.py deleted file mode 100644 index 68b6749..0000000 --- a/models/capex.py +++ /dev/null @@ -1,71 +0,0 @@ -from sqlalchemy import event, text -from sqlalchemy import Column, Integer, Float, String, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class Capex(Base): - __tablename__ = "capex" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - amount = Column(Float, nullable=False) - description = Column(String, nullable=True) - currency_id = Column(Integer, ForeignKey("currency.id"), nullable=False) - - scenario = relationship("Scenario", back_populates="capex_items") - currency = relationship("Currency", back_populates="capex_items") - - def __repr__(self): - return ( - f"" - ) - - @property - def currency_code(self) -> str: - return self.currency.code if self.currency else None - - @currency_code.setter - def currency_code(self, value: str) -> None: - # store pending code so application code or migrations can pick it up - setattr( - self, "_currency_code_pending", (value or "USD").strip().upper() - ) - - -# SQLAlchemy event handlers to ensure currency_id is set before insert/update - - -def _resolve_currency(mapper, connection, target): - # If currency_id already set, nothing to do - if getattr(target, "currency_id", None): - return - code = getattr(target, "_currency_code_pending", None) or "USD" - # Try to find existing currency id - row = connection.execute( - text("SELECT id FROM currency WHERE code = :code"), {"code": code} - ).fetchone() - if row: - cid = row[0] - else: - # Insert new currency and attempt to get lastrowid - res = connection.execute( - text( - "INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)" - ), - {"code": code, "name": code, "symbol": None, "active": True}, - ) - try: - cid = res.lastrowid - except Exception: - # fallback: select after insert - cid = connection.execute( - text("SELECT id FROM currency WHERE code = :code"), - {"code": code}, - ).scalar() - target.currency_id = cid - - -event.listen(Capex, "before_insert", _resolve_currency) -event.listen(Capex, "before_update", _resolve_currency) diff --git a/models/capex_snapshot.py b/models/capex_snapshot.py new file mode 100644 index 0000000..08b6ae8 --- /dev/null +++ b/models/capex_snapshot.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import JSON, DateTime, ForeignKey, Integer, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func + +from config.database import Base + +if TYPE_CHECKING: # pragma: no cover + from .project import Project + from .scenario import Scenario + from .user import User + + +class ProjectCapexSnapshot(Base): + 
"""Snapshot of aggregated capex metrics at the project level.""" + + __tablename__ = "project_capex_snapshots" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + project_id: Mapped[int] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True + ) + created_by_id: Mapped[int | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True + ) + calculation_source: Mapped[str | None] = mapped_column( + String(64), nullable=True) + calculated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + currency_code: Mapped[str | None] = mapped_column(String(3), nullable=True) + total_capex: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + contingency_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + contingency_amount: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + total_with_contingency: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + component_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + payload: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + project: Mapped[Project] = relationship( + "Project", back_populates="capex_snapshots" + ) + created_by: Mapped[User | None] = relationship("User") + + def __repr__(self) -> str: # pragma: no cover + return ( + "ProjectCapexSnapshot(id={id!r}, project_id={project_id!r}, total_capex={total_capex!r})".format( + id=self.id, project_id=self.project_id, total_capex=self.total_capex + ) + ) + + +class ScenarioCapexSnapshot(Base): + """Snapshot of capex metrics for an individual scenario.""" + + __tablename__ = "scenario_capex_snapshots" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + scenario_id: Mapped[int] = mapped_column( + ForeignKey("scenarios.id", ondelete="CASCADE"), nullable=False, index=True + ) + created_by_id: Mapped[int | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True + ) + calculation_source: Mapped[str | None] = mapped_column( + String(64), nullable=True) + calculated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + currency_code: Mapped[str | None] = mapped_column(String(3), nullable=True) + total_capex: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + contingency_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + contingency_amount: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + total_with_contingency: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + component_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + payload: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + scenario: Mapped[Scenario] = relationship( + "Scenario", back_populates="capex_snapshots" + ) + created_by: Mapped[User | None] = 
relationship("User") + + def __repr__(self) -> str: # pragma: no cover + return ( + "ScenarioCapexSnapshot(id={id!r}, scenario_id={scenario_id!r}, total_capex={total_capex!r})".format( + id=self.id, scenario_id=self.scenario_id, total_capex=self.total_capex + ) + ) diff --git a/models/consumption.py b/models/consumption.py deleted file mode 100644 index c5239bc..0000000 --- a/models/consumption.py +++ /dev/null @@ -1,22 +0,0 @@ -from sqlalchemy import Column, Integer, Float, String, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class Consumption(Base): - __tablename__ = "consumption" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - amount = Column(Float, nullable=False) - description = Column(String, nullable=True) - unit_name = Column(String(64), nullable=True) - unit_symbol = Column(String(16), nullable=True) - - scenario = relationship("Scenario", back_populates="consumption_items") - - def __repr__(self): - return ( - f"" - ) diff --git a/models/currency.py b/models/currency.py deleted file mode 100644 index de95abd..0000000 --- a/models/currency.py +++ /dev/null @@ -1,24 +0,0 @@ -from sqlalchemy import Column, Integer, String, Boolean -from sqlalchemy.orm import relationship -from config.database import Base - - -class Currency(Base): - __tablename__ = "currency" - - id = Column(Integer, primary_key=True, index=True) - code = Column(String(3), nullable=False, unique=True, index=True) - name = Column(String(128), nullable=False) - symbol = Column(String(8), nullable=True) - is_active = Column(Boolean, nullable=False, default=True) - - # reverse relationships (optional) - capex_items = relationship( - "Capex", back_populates="currency", lazy="select" - ) - opex_items = relationship("Opex", back_populates="currency", lazy="select") - - def __repr__(self): - return ( - f"" - ) diff --git a/models/distribution.py b/models/distribution.py deleted file mode 100644 index 9f9832a..0000000 --- a/models/distribution.py +++ /dev/null @@ -1,14 +0,0 @@ -from sqlalchemy import Column, Integer, String, JSON -from config.database import Base - - -class Distribution(Base): - __tablename__ = "distribution" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, nullable=False) - distribution_type = Column(String, nullable=False) - parameters = Column(JSON, nullable=True) - - def __repr__(self): - return f"" diff --git a/models/enums.py b/models/enums.py new file mode 100644 index 0000000..9e6f970 --- /dev/null +++ b/models/enums.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from enum import Enum +from typing import Type + +from sqlalchemy import Enum as SQLEnum + + +def sql_enum(enum_cls: Type[Enum], *, name: str) -> SQLEnum: + """Build a SQLAlchemy Enum that maps using the enum member values.""" + + return SQLEnum( + enum_cls, + name=name, + create_type=False, + validate_strings=True, + values_callable=lambda enum_cls: [member.value for member in enum_cls], + ) + + +class MiningOperationType(str, Enum): + """Supported mining operation categories.""" + + OPEN_PIT = "open_pit" + UNDERGROUND = "underground" + IN_SITU_LEACH = "in_situ_leach" + PLACER = "placer" + QUARRY = "quarry" + MOUNTAINTOP_REMOVAL = "mountaintop_removal" + OTHER = "other" + + +class ScenarioStatus(str, Enum): + """Lifecycle states for project scenarios.""" + + DRAFT = "draft" + ACTIVE = "active" + ARCHIVED = "archived" + + +class FinancialCategory(str, Enum): + 
"""Enumeration of cost and revenue classifications.""" + + CAPITAL_EXPENDITURE = "capex" + OPERATING_EXPENDITURE = "opex" + REVENUE = "revenue" + CONTINGENCY = "contingency" + OTHER = "other" + + +class DistributionType(str, Enum): + """Supported stochastic distribution families for simulations.""" + + NORMAL = "normal" + TRIANGULAR = "triangular" + UNIFORM = "uniform" + LOGNORMAL = "lognormal" + CUSTOM = "custom" + + +class ResourceType(str, Enum): + """Primary consumables and resources used in mining operations.""" + + DIESEL = "diesel" + ELECTRICITY = "electricity" + WATER = "water" + EXPLOSIVES = "explosives" + REAGENTS = "reagents" + LABOR = "labor" + EQUIPMENT_HOURS = "equipment_hours" + TAILINGS_CAPACITY = "tailings_capacity" + + +class CostBucket(str, Enum): + """Granular cost buckets aligned with project accounting.""" + + CAPITAL_INITIAL = "capital_initial" + CAPITAL_SUSTAINING = "capital_sustaining" + OPERATING_FIXED = "operating_fixed" + OPERATING_VARIABLE = "operating_variable" + MAINTENANCE = "maintenance" + RECLAMATION = "reclamation" + ROYALTIES = "royalties" + GENERAL_ADMIN = "general_admin" + + +class StochasticVariable(str, Enum): + """Domain variables that typically require probabilistic modelling.""" + + ORE_GRADE = "ore_grade" + RECOVERY_RATE = "recovery_rate" + METAL_PRICE = "metal_price" + OPERATING_COST = "operating_cost" + CAPITAL_COST = "capital_cost" + DISCOUNT_RATE = "discount_rate" + THROUGHPUT = "throughput" diff --git a/models/equipment.py b/models/equipment.py deleted file mode 100644 index e431891..0000000 --- a/models/equipment.py +++ /dev/null @@ -1,17 +0,0 @@ -from sqlalchemy import Column, Integer, String, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class Equipment(Base): - __tablename__ = "equipment" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - name = Column(String, nullable=False) - description = Column(String, nullable=True) - - scenario = relationship("Scenario", back_populates="equipment_items") - - def __repr__(self): - return f"" diff --git a/models/financial_input.py b/models/financial_input.py new file mode 100644 index 0000000..929f121 --- /dev/null +++ b/models/financial_input.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING + +from sqlalchemy import ( + Date, + DateTime, + ForeignKey, + Integer, + Numeric, + String, + Text, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship, validates + +from sqlalchemy.sql import func + +from config.database import Base +from .enums import CostBucket, FinancialCategory, sql_enum +from services.currency import normalise_currency + +if TYPE_CHECKING: # pragma: no cover + from .scenario import Scenario + + +class FinancialInput(Base): + """Line-item financial assumption attached to a scenario.""" + + __tablename__ = "financial_inputs" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + scenario_id: Mapped[int] = mapped_column( + ForeignKey("scenarios.id", ondelete="CASCADE"), nullable=False, index=True + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + category: Mapped[FinancialCategory] = mapped_column( + sql_enum(FinancialCategory, name="financialcategory"), nullable=False + ) + cost_bucket: Mapped[CostBucket | None] = mapped_column( + sql_enum(CostBucket, name="costbucket"), nullable=True + ) + amount: Mapped[float] = mapped_column(Numeric(18, 2), 
nullable=False) + currency: Mapped[str | None] = mapped_column(String(3), nullable=True) + effective_date: Mapped[date | None] = mapped_column(Date, nullable=True) + notes: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + scenario: Mapped["Scenario"] = relationship( + "Scenario", back_populates="financial_inputs") + + @validates("currency") + def _validate_currency(self, key: str, value: str | None) -> str | None: + return normalise_currency(value) + + def __repr__(self) -> str: # pragma: no cover + return f"FinancialInput(id={self.id!r}, scenario_id={self.scenario_id!r}, name={self.name!r})" diff --git a/models/import_export_log.py b/models/import_export_log.py new file mode 100644 index 0000000..026a345 --- /dev/null +++ b/models/import_export_log.py @@ -0,0 +1,31 @@ +from __future__ import annotations + + +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text +from sqlalchemy.sql import func + +from config.database import Base + + +class ImportExportLog(Base): + """Audit log for import and export operations.""" + + __tablename__ = "import_export_logs" + + id = Column(Integer, primary_key=True, index=True) + action = Column(String(32), nullable=False) # preview, commit, export + dataset = Column(String(32), nullable=False) # projects, scenarios, etc. + status = Column(String(16), nullable=False) # success, failure + filename = Column(String(255), nullable=True) + row_count = Column(Integer, nullable=True) + detail = Column(Text, nullable=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + created_at = Column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + def __repr__(self) -> str: # pragma: no cover + return ( + f"ImportExportLog(id={self.id}, action={self.action}, " + f"dataset={self.dataset}, status={self.status})" + ) diff --git a/models/maintenance.py b/models/maintenance.py deleted file mode 100644 index 43a7aea..0000000 --- a/models/maintenance.py +++ /dev/null @@ -1,23 +0,0 @@ -from sqlalchemy import Column, Date, Float, ForeignKey, Integer, String -from sqlalchemy.orm import relationship -from config.database import Base - - -class Maintenance(Base): - __tablename__ = "maintenance" - - id = Column(Integer, primary_key=True, index=True) - equipment_id = Column(Integer, ForeignKey("equipment.id"), nullable=False) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - maintenance_date = Column(Date, nullable=False) - description = Column(String, nullable=True) - cost = Column(Float, nullable=False) - - equipment = relationship("Equipment") - scenario = relationship("Scenario", back_populates="maintenance_items") - - def __repr__(self) -> str: - return ( - f"" - ) diff --git a/models/metadata.py b/models/metadata.py new file mode 100644 index 0000000..a3e14c3 --- /dev/null +++ b/models/metadata.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from dataclasses import dataclass +from .enums import ResourceType, CostBucket, StochasticVariable + + +@dataclass(frozen=True) +class ResourceDescriptor: + """Describes canonical metadata for a resource type.""" + + unit: str + description: str + + +RESOURCE_METADATA: dict[ResourceType, ResourceDescriptor] = { + ResourceType.DIESEL: ResourceDescriptor(unit="L", 
description="Diesel fuel consumption"), + ResourceType.ELECTRICITY: ResourceDescriptor(unit="kWh", description="Electrical power usage"), + ResourceType.WATER: ResourceDescriptor(unit="m3", description="Process and dust suppression water"), + ResourceType.EXPLOSIVES: ResourceDescriptor(unit="kg", description="Blasting agent consumption"), + ResourceType.REAGENTS: ResourceDescriptor(unit="kg", description="Processing reagents"), + ResourceType.LABOR: ResourceDescriptor(unit="hours", description="Direct labor hours"), + ResourceType.EQUIPMENT_HOURS: ResourceDescriptor(unit="hours", description="Mobile equipment operating hours"), + ResourceType.TAILINGS_CAPACITY: ResourceDescriptor(unit="m3", description="Tailings storage usage"), +} + + +@dataclass(frozen=True) +class CostBucketDescriptor: + """Describes reporting label and guidance for a cost bucket.""" + + label: str + description: str + + +COST_BUCKET_METADATA: dict[CostBucket, CostBucketDescriptor] = { + CostBucket.CAPITAL_INITIAL: CostBucketDescriptor( + label="Initial Capital", + description="Pre-production capital required to construct the mine", + ), + CostBucket.CAPITAL_SUSTAINING: CostBucketDescriptor( + label="Sustaining Capital", + description="Ongoing capital investments to maintain operations", + ), + CostBucket.OPERATING_FIXED: CostBucketDescriptor( + label="Fixed Operating", + description="Fixed operating costs independent of production rate", + ), + CostBucket.OPERATING_VARIABLE: CostBucketDescriptor( + label="Variable Operating", + description="Costs that scale with throughput or production", + ), + CostBucket.MAINTENANCE: CostBucketDescriptor( + label="Maintenance", + description="Maintenance and repair expenditures", + ), + CostBucket.RECLAMATION: CostBucketDescriptor( + label="Reclamation", + description="Mine closure and reclamation liabilities", + ), + CostBucket.ROYALTIES: CostBucketDescriptor( + label="Royalties", + description="Royalty and streaming obligations", + ), + CostBucket.GENERAL_ADMIN: CostBucketDescriptor( + label="G&A", + description="Corporate and site general and administrative costs", + ), +} + + +@dataclass(frozen=True) +class StochasticVariableDescriptor: + """Metadata describing how a stochastic variable is typically modelled.""" + + unit: str + description: str + + +STOCHASTIC_VARIABLE_METADATA: dict[StochasticVariable, StochasticVariableDescriptor] = { + StochasticVariable.ORE_GRADE: StochasticVariableDescriptor( + unit="g/t", + description="Head grade variability across the ore body", + ), + StochasticVariable.RECOVERY_RATE: StochasticVariableDescriptor( + unit="%", + description="Metallurgical recovery uncertainty", + ), + StochasticVariable.METAL_PRICE: StochasticVariableDescriptor( + unit="$/unit", + description="Commodity price fluctuations", + ), + StochasticVariable.OPERATING_COST: StochasticVariableDescriptor( + unit="$/t", + description="Operating cost per tonne volatility", + ), + StochasticVariable.CAPITAL_COST: StochasticVariableDescriptor( + unit="$", + description="Capital cost overrun/underrun potential", + ), + StochasticVariable.DISCOUNT_RATE: StochasticVariableDescriptor( + unit="%", + description="Discount rate sensitivity", + ), + StochasticVariable.THROUGHPUT: StochasticVariableDescriptor( + unit="t/d", + description="Plant throughput variability", + ), +} diff --git a/models/navigation.py b/models/navigation.py new file mode 100644 index 0000000..0186576 --- /dev/null +++ b/models/navigation.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from datetime import 
datetime +from typing import List, Optional + +from sqlalchemy import ( + Boolean, + CheckConstraint, + DateTime, + ForeignKey, + Index, + Integer, + String, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func +from sqlalchemy.ext.mutable import MutableList +from sqlalchemy import JSON + +from config.database import Base + + +class NavigationGroup(Base): + __tablename__ = "navigation_groups" + __table_args__ = ( + UniqueConstraint("slug", name="uq_navigation_groups_slug"), + Index("ix_navigation_groups_sort_order", "sort_order"), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + slug: Mapped[str] = mapped_column(String(64), nullable=False) + label: Mapped[str] = mapped_column(String(128), nullable=False) + sort_order: Mapped[int] = mapped_column( + Integer, nullable=False, default=100) + icon: Mapped[Optional[str]] = mapped_column(String(64)) + tooltip: Mapped[Optional[str]] = mapped_column(String(255)) + is_enabled: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + links: Mapped[List["NavigationLink"]] = relationship( + "NavigationLink", + back_populates="group", + cascade="all, delete-orphan", + order_by="NavigationLink.sort_order", + ) + + def __repr__(self) -> str: # pragma: no cover + return f"NavigationGroup(id={self.id!r}, slug={self.slug!r})" + + +class NavigationLink(Base): + __tablename__ = "navigation_links" + __table_args__ = ( + UniqueConstraint("group_id", "slug", + name="uq_navigation_links_group_slug"), + Index("ix_navigation_links_group_sort", "group_id", "sort_order"), + Index("ix_navigation_links_parent_sort", + "parent_link_id", "sort_order"), + CheckConstraint( + "(route_name IS NOT NULL) OR (href_override IS NOT NULL)", + name="ck_navigation_links_route_or_href", + ), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + group_id: Mapped[int] = mapped_column( + ForeignKey("navigation_groups.id", ondelete="CASCADE"), nullable=False + ) + parent_link_id: Mapped[Optional[int]] = mapped_column( + ForeignKey("navigation_links.id", ondelete="CASCADE") + ) + slug: Mapped[str] = mapped_column(String(64), nullable=False) + label: Mapped[str] = mapped_column(String(128), nullable=False) + route_name: Mapped[Optional[str]] = mapped_column(String(128)) + href_override: Mapped[Optional[str]] = mapped_column(String(512)) + match_prefix: Mapped[Optional[str]] = mapped_column(String(512)) + sort_order: Mapped[int] = mapped_column( + Integer, nullable=False, default=100) + icon: Mapped[Optional[str]] = mapped_column(String(64)) + tooltip: Mapped[Optional[str]] = mapped_column(String(255)) + required_roles: Mapped[list[str]] = mapped_column( + MutableList.as_mutable(JSON), nullable=False, default=list + ) + is_enabled: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True) + is_external: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + group: Mapped[NavigationGroup] = relationship( + NavigationGroup, + 
back_populates="links", + ) + parent: Mapped[Optional["NavigationLink"]] = relationship( + "NavigationLink", + remote_side="NavigationLink.id", + back_populates="children", + ) + children: Mapped[List["NavigationLink"]] = relationship( + "NavigationLink", + back_populates="parent", + cascade="all, delete-orphan", + order_by="NavigationLink.sort_order", + ) + + def is_visible_for_roles(self, roles: list[str]) -> bool: + if not self.required_roles: + return True + role_set = set(roles) + return any(role in role_set for role in self.required_roles) + + def __repr__(self) -> str: # pragma: no cover + return f"NavigationLink(id={self.id!r}, slug={self.slug!r})" diff --git a/models/opex.py b/models/opex.py deleted file mode 100644 index 5c0e703..0000000 --- a/models/opex.py +++ /dev/null @@ -1,63 +0,0 @@ -from sqlalchemy import event, text -from sqlalchemy import Column, Integer, Float, String, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class Opex(Base): - __tablename__ = "opex" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - amount = Column(Float, nullable=False) - description = Column(String, nullable=True) - currency_id = Column(Integer, ForeignKey("currency.id"), nullable=False) - - scenario = relationship("Scenario", back_populates="opex_items") - currency = relationship("Currency", back_populates="opex_items") - - def __repr__(self): - return ( - f"" - ) - - @property - def currency_code(self) -> str: - return self.currency.code if self.currency else None - - @currency_code.setter - def currency_code(self, value: str) -> None: - setattr( - self, "_currency_code_pending", (value or "USD").strip().upper() - ) - - -def _resolve_currency_opex(mapper, connection, target): - if getattr(target, "currency_id", None): - return - code = getattr(target, "_currency_code_pending", None) or "USD" - row = connection.execute( - text("SELECT id FROM currency WHERE code = :code"), {"code": code} - ).fetchone() - if row: - cid = row[0] - else: - res = connection.execute( - text( - "INSERT INTO currency (code, name, symbol, is_active) VALUES (:code, :name, :symbol, :active)" - ), - {"code": code, "name": code, "symbol": None, "active": True}, - ) - try: - cid = res.lastrowid - except Exception: - cid = connection.execute( - text("SELECT id FROM currency WHERE code = :code"), - {"code": code}, - ).scalar() - target.currency_id = cid - - -event.listen(Opex, "before_insert", _resolve_currency_opex) -event.listen(Opex, "before_update", _resolve_currency_opex) diff --git a/models/opex_snapshot.py b/models/opex_snapshot.py new file mode 100644 index 0000000..adb15ba --- /dev/null +++ b/models/opex_snapshot.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, Integer, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func + +from config.database import Base + +if TYPE_CHECKING: # pragma: no cover + from .project import Project + from .scenario import Scenario + from .user import User + + +class ProjectOpexSnapshot(Base): + """Snapshot of recurring opex metrics at the project level.""" + + __tablename__ = "project_opex_snapshots" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + project_id: Mapped[int] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, 
index=True + ) + created_by_id: Mapped[int | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True + ) + calculation_source: Mapped[str | None] = mapped_column( + String(64), nullable=True) + calculated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + currency_code: Mapped[str | None] = mapped_column(String(3), nullable=True) + overall_annual: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + escalated_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + annual_average: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + evaluation_horizon_years: Mapped[int | None] = mapped_column( + Integer, nullable=True) + escalation_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + apply_escalation: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True) + component_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + payload: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + project: Mapped[Project] = relationship( + "Project", back_populates="opex_snapshots" + ) + created_by: Mapped[User | None] = relationship("User") + + def __repr__(self) -> str: # pragma: no cover + return ( + "ProjectOpexSnapshot(id={id!r}, project_id={project_id!r}, overall_annual={overall_annual!r})".format( + id=self.id, + project_id=self.project_id, + overall_annual=self.overall_annual, + ) + ) + + +class ScenarioOpexSnapshot(Base): + """Snapshot of opex metrics for an individual scenario.""" + + __tablename__ = "scenario_opex_snapshots" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + scenario_id: Mapped[int] = mapped_column( + ForeignKey("scenarios.id", ondelete="CASCADE"), nullable=False, index=True + ) + created_by_id: Mapped[int | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True + ) + calculation_source: Mapped[str | None] = mapped_column( + String(64), nullable=True) + calculated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + currency_code: Mapped[str | None] = mapped_column(String(3), nullable=True) + overall_annual: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + escalated_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + annual_average: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + evaluation_horizon_years: Mapped[int | None] = mapped_column( + Integer, nullable=True) + escalation_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + apply_escalation: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True) + component_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + payload: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + scenario: Mapped[Scenario] = relationship( + "Scenario", 
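+ # The collection side (Scenario.opex_snapshots, defined later in this diff)
+ # orders rows by calculated_at descending, so the newest snapshot sits at
+ # index 0 (illustrative): latest = scenario.opex_snapshots[0], which is
+ # exactly the access that Scenario.latest_opex wraps.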
back_populates="opex_snapshots" + ) + created_by: Mapped[User | None] = relationship("User") + + def __repr__(self) -> str: # pragma: no cover + return ( + "ScenarioOpexSnapshot(id={id!r}, scenario_id={scenario_id!r}, overall_annual={overall_annual!r})".format( + id=self.id, + scenario_id=self.scenario_id, + overall_annual=self.overall_annual, + ) + ) diff --git a/models/parameters.py b/models/parameters.py deleted file mode 100644 index 822a011..0000000 --- a/models/parameters.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Any, Dict, Optional - -from sqlalchemy import ForeignKey, JSON -from sqlalchemy.orm import Mapped, mapped_column, relationship -from config.database import Base - - -class Parameter(Base): - __tablename__ = "parameter" - - id: Mapped[int] = mapped_column(primary_key=True, index=True) - scenario_id: Mapped[int] = mapped_column( - ForeignKey("scenario.id"), nullable=False - ) - name: Mapped[str] = mapped_column(nullable=False) - value: Mapped[float] = mapped_column(nullable=False) - distribution_id: Mapped[Optional[int]] = mapped_column( - ForeignKey("distribution.id"), nullable=True - ) - distribution_type: Mapped[Optional[str]] = mapped_column(nullable=True) - distribution_parameters: Mapped[Optional[Dict[str, Any]]] = mapped_column( - JSON, nullable=True - ) - - scenario = relationship("Scenario", back_populates="parameters") - distribution = relationship("Distribution") - - def __repr__(self): - return f"" diff --git a/models/performance_metric.py b/models/performance_metric.py new file mode 100644 index 0000000..0304fef --- /dev/null +++ b/models/performance_metric.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy import Column, DateTime, Float, Integer, String + +from config.database import Base + + +class PerformanceMetric(Base): + __tablename__ = "performance_metrics" + + id = Column(Integer, primary_key=True, index=True) + timestamp = Column(DateTime, default=datetime.utcnow, index=True) + metric_name = Column(String, index=True) + value = Column(Float) + labels = Column(String) # JSON string of labels + endpoint = Column(String, index=True, nullable=True) + method = Column(String, nullable=True) + status_code = Column(Integer, nullable=True) + duration_seconds = Column(Float, nullable=True) + + def __repr__(self) -> str: + return f"" diff --git a/models/pricing_settings.py b/models/pricing_settings.py new file mode 100644 index 0000000..7c7b8ce --- /dev/null +++ b/models/pricing_settings.py @@ -0,0 +1,176 @@ +"""Database models for persisted pricing configuration settings.""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import ( + JSON, + DateTime, + ForeignKey, + Integer, + Numeric, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship, validates +from sqlalchemy.sql import func + +from config.database import Base +from services.currency import normalise_currency + +if TYPE_CHECKING: # pragma: no cover + from .project import Project + + +class PricingSettings(Base): + """Persisted pricing defaults applied to scenario evaluations.""" + + __tablename__ = "pricing_settings" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str] = mapped_column(String(128), nullable=False, unique=True) + slug: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + 
default_currency: Mapped[str | None] = mapped_column( + String(3), nullable=True) + default_payable_pct: Mapped[float] = mapped_column( + Numeric(5, 2), nullable=False, default=100.0 + ) + moisture_threshold_pct: Mapped[float] = mapped_column( + Numeric(5, 2), nullable=False, default=8.0 + ) + moisture_penalty_per_pct: Mapped[float] = mapped_column( + Numeric(14, 4), nullable=False, default=0.0 + ) + metadata_payload: Mapped[dict | None] = mapped_column( + "metadata", JSON, nullable=True + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + metal_overrides: Mapped[list["PricingMetalSettings"]] = relationship( + "PricingMetalSettings", + back_populates="pricing_settings", + cascade="all, delete-orphan", + passive_deletes=True, + ) + impurity_overrides: Mapped[list["PricingImpuritySettings"]] = relationship( + "PricingImpuritySettings", + back_populates="pricing_settings", + cascade="all, delete-orphan", + passive_deletes=True, + ) + projects: Mapped[list["Project"]] = relationship( + "Project", + back_populates="pricing_settings", + cascade="all", + ) + + @validates("slug") + def _normalise_slug(self, key: str, value: str) -> str: + return value.strip().lower() + + @validates("default_currency") + def _validate_currency(self, key: str, value: str | None) -> str | None: + return normalise_currency(value) + + def __repr__(self) -> str: # pragma: no cover + return f"PricingSettings(id={self.id!r}, slug={self.slug!r})" + + +class PricingMetalSettings(Base): + """Contract-specific overrides for a particular metal.""" + + __tablename__ = "pricing_metal_settings" + __table_args__ = ( + UniqueConstraint( + "pricing_settings_id", "metal_code", name="uq_pricing_metal_settings_code" + ), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + pricing_settings_id: Mapped[int] = mapped_column( + ForeignKey("pricing_settings.id", ondelete="CASCADE"), nullable=False, index=True + ) + metal_code: Mapped[str] = mapped_column(String(32), nullable=False) + payable_pct: Mapped[float | None] = mapped_column( + Numeric(5, 2), nullable=True) + moisture_threshold_pct: Mapped[float | None] = mapped_column( + Numeric(5, 2), nullable=True) + moisture_penalty_per_pct: Mapped[float | None] = mapped_column( + Numeric(14, 4), nullable=True + ) + data: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + pricing_settings: Mapped["PricingSettings"] = relationship( + "PricingSettings", back_populates="metal_overrides" + ) + + @validates("metal_code") + def _normalise_metal_code(self, key: str, value: str) -> str: + return value.strip().lower() + + def __repr__(self) -> str: # pragma: no cover + return ( + "PricingMetalSettings(" # noqa: ISC001 + f"id={self.id!r}, pricing_settings_id={self.pricing_settings_id!r}, " + f"metal_code={self.metal_code!r})" + ) + + +class PricingImpuritySettings(Base): + """Impurity penalty thresholds associated with pricing settings.""" + + __tablename__ = "pricing_impurity_settings" + __table_args__ = ( + UniqueConstraint( + "pricing_settings_id", + "impurity_code", + 
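+ # Note the asymmetry in this module's validators: metal_code is lower-cased
+ # (validator above) while impurity_code is upper-cased (validator below),
+ # e.g. "CU" is stored as "cu" but "as" is stored as "AS"; the unique
+ # constraints therefore always compare canonical forms.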
name="uq_pricing_impurity_settings_code", + ), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + pricing_settings_id: Mapped[int] = mapped_column( + ForeignKey("pricing_settings.id", ondelete="CASCADE"), nullable=False, index=True + ) + impurity_code: Mapped[str] = mapped_column(String(32), nullable=False) + threshold_ppm: Mapped[float] = mapped_column( + Numeric(14, 4), nullable=False, default=0.0) + penalty_per_ppm: Mapped[float] = mapped_column( + Numeric(14, 4), nullable=False, default=0.0) + notes: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + pricing_settings: Mapped["PricingSettings"] = relationship( + "PricingSettings", back_populates="impurity_overrides" + ) + + @validates("impurity_code") + def _normalise_impurity_code(self, key: str, value: str) -> str: + return value.strip().upper() + + def __repr__(self) -> str: # pragma: no cover + return ( + "PricingImpuritySettings(" # noqa: ISC001 + f"id={self.id!r}, pricing_settings_id={self.pricing_settings_id!r}, " + f"impurity_code={self.impurity_code!r})" + ) diff --git a/models/production_output.py b/models/production_output.py deleted file mode 100644 index fde7cb8..0000000 --- a/models/production_output.py +++ /dev/null @@ -1,24 +0,0 @@ -from sqlalchemy import Column, Integer, Float, String, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class ProductionOutput(Base): - __tablename__ = "production_output" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - amount = Column(Float, nullable=False) - description = Column(String, nullable=True) - unit_name = Column(String(64), nullable=True) - unit_symbol = Column(String(16), nullable=True) - - scenario = relationship( - "Scenario", back_populates="production_output_items" - ) - - def __repr__(self): - return ( - f"" - ) diff --git a/models/profitability_snapshot.py b/models/profitability_snapshot.py new file mode 100644 index 0000000..f2dc5cb --- /dev/null +++ b/models/profitability_snapshot.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import JSON, DateTime, ForeignKey, Integer, Numeric, String +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func + +from config.database import Base + +if TYPE_CHECKING: # pragma: no cover + from .project import Project + from .scenario import Scenario + from .user import User + + +class ProjectProfitability(Base): + """Snapshot of aggregated profitability metrics at the project level.""" + + __tablename__ = "project_profitability_snapshots" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + project_id: Mapped[int] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True + ) + created_by_id: Mapped[int | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True + ) + calculation_source: Mapped[str | None] = mapped_column( + String(64), nullable=True) + calculated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + currency_code: Mapped[str | None] = 
mapped_column(String(3), nullable=True) + npv: Mapped[float | None] = mapped_column(Numeric(18, 2), nullable=True) + irr_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + payback_period_years: Mapped[float | None] = mapped_column( + Numeric(12, 4), nullable=True + ) + margin_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + revenue_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + opex_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True + ) + sustaining_capex_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True + ) + capex: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + net_cash_flow_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True + ) + payload: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + project: Mapped[Project] = relationship( + "Project", back_populates="profitability_snapshots") + created_by: Mapped[User | None] = relationship("User") + + def __repr__(self) -> str: # pragma: no cover + return ( + "ProjectProfitability(id={id!r}, project_id={project_id!r}, npv={npv!r})".format( + id=self.id, project_id=self.project_id, npv=self.npv + ) + ) + + +class ScenarioProfitability(Base): + """Snapshot of profitability metrics for an individual scenario.""" + + __tablename__ = "scenario_profitability_snapshots" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + scenario_id: Mapped[int] = mapped_column( + ForeignKey("scenarios.id", ondelete="CASCADE"), nullable=False, index=True + ) + created_by_id: Mapped[int | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True + ) + calculation_source: Mapped[str | None] = mapped_column( + String(64), nullable=True) + calculated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + currency_code: Mapped[str | None] = mapped_column(String(3), nullable=True) + npv: Mapped[float | None] = mapped_column(Numeric(18, 2), nullable=True) + irr_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + payback_period_years: Mapped[float | None] = mapped_column( + Numeric(12, 4), nullable=True + ) + margin_pct: Mapped[float | None] = mapped_column( + Numeric(12, 6), nullable=True) + revenue_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + opex_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True + ) + sustaining_capex_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True + ) + capex: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True) + net_cash_flow_total: Mapped[float | None] = mapped_column( + Numeric(18, 2), nullable=True + ) + payload: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + scenario: Mapped[Scenario] = relationship( + "Scenario", back_populates="profitability_snapshots") + created_by: Mapped[User 
| None] = relationship("User") + + def __repr__(self) -> str: # pragma: no cover + return ( + "ScenarioProfitability(id={id!r}, scenario_id={scenario_id!r}, npv={npv!r})".format( + id=self.id, scenario_id=self.scenario_id, npv=self.npv + ) + ) diff --git a/models/project.py b/models/project.py new file mode 100644 index 0000000..d651e5a --- /dev/null +++ b/models/project.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, List + +from .enums import MiningOperationType, sql_enum +from .profitability_snapshot import ProjectProfitability +from .capex_snapshot import ProjectCapexSnapshot +from .opex_snapshot import ProjectOpexSnapshot + +from sqlalchemy import DateTime, ForeignKey, Integer, String, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func + +from config.database import Base + +if TYPE_CHECKING: # pragma: no cover + from .scenario import Scenario + from .pricing_settings import PricingSettings + + +class Project(Base): + """Top-level mining project grouping multiple scenarios.""" + + __tablename__ = "projects" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + location: Mapped[str | None] = mapped_column(String(255), nullable=True) + operation_type: Mapped[MiningOperationType] = mapped_column( + sql_enum(MiningOperationType, name="miningoperationtype"), + nullable=False, + default=MiningOperationType.OTHER, + ) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + pricing_settings_id: Mapped[int | None] = mapped_column( + ForeignKey("pricing_settings.id", ondelete="SET NULL"), + nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + scenarios: Mapped[List["Scenario"]] = relationship( + "Scenario", + back_populates="project", + cascade="all, delete-orphan", + passive_deletes=True, + ) + pricing_settings: Mapped["PricingSettings | None"] = relationship( + "PricingSettings", + back_populates="projects", + ) + profitability_snapshots: Mapped[List["ProjectProfitability"]] = relationship( + "ProjectProfitability", + back_populates="project", + cascade="all, delete-orphan", + order_by=lambda: ProjectProfitability.calculated_at.desc(), + passive_deletes=True, + ) + capex_snapshots: Mapped[List["ProjectCapexSnapshot"]] = relationship( + "ProjectCapexSnapshot", + back_populates="project", + cascade="all, delete-orphan", + order_by=lambda: ProjectCapexSnapshot.calculated_at.desc(), + passive_deletes=True, + ) + opex_snapshots: Mapped[List["ProjectOpexSnapshot"]] = relationship( + "ProjectOpexSnapshot", + back_populates="project", + cascade="all, delete-orphan", + order_by=lambda: ProjectOpexSnapshot.calculated_at.desc(), + passive_deletes=True, + ) + + @property + def latest_profitability(self) -> "ProjectProfitability | None": + """Return the most recent profitability snapshot, if any.""" + + if not self.profitability_snapshots: + return None + return self.profitability_snapshots[0] + + @property + def latest_capex(self) -> "ProjectCapexSnapshot | None": + """Return the most recent capex snapshot, if any.""" + + if not self.capex_snapshots: + return None + return self.capex_snapshots[0] + + @property + def latest_opex(self) 
-> "ProjectOpexSnapshot | None": + """Return the most recent opex snapshot, if any.""" + + if not self.opex_snapshots: + return None + return self.opex_snapshots[0] + + def __repr__(self) -> str: # pragma: no cover - helpful for debugging + return f"Project(id={self.id!r}, name={self.name!r})" diff --git a/models/role.py b/models/role.py deleted file mode 100644 index 3351908..0000000 --- a/models/role.py +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy import Column, Integer, String -from sqlalchemy.orm import relationship - -from config.database import Base - - -class Role(Base): - __tablename__ = "roles" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, unique=True, index=True) - - users = relationship("User", back_populates="role") diff --git a/models/scenario.py b/models/scenario.py index 66d4fd2..5a8b2bd 100644 --- a/models/scenario.py +++ b/models/scenario.py @@ -1,36 +1,133 @@ -from sqlalchemy import Column, Integer, String, DateTime, func -from sqlalchemy.orm import relationship -from models.simulation_result import SimulationResult -from models.capex import Capex -from models.opex import Opex -from models.consumption import Consumption -from models.production_output import ProductionOutput -from models.equipment import Equipment -from models.maintenance import Maintenance +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING, List + +from sqlalchemy import ( + Date, + DateTime, + ForeignKey, + Integer, + Numeric, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship, validates +from sqlalchemy.sql import func + from config.database import Base +from services.currency import normalise_currency +from .enums import ResourceType, ScenarioStatus, sql_enum +from .profitability_snapshot import ScenarioProfitability +from .capex_snapshot import ScenarioCapexSnapshot +from .opex_snapshot import ScenarioOpexSnapshot + +if TYPE_CHECKING: # pragma: no cover + from .financial_input import FinancialInput + from .project import Project + from .simulation_parameter import SimulationParameter class Scenario(Base): - __tablename__ = "scenario" + """A specific configuration of assumptions for a project.""" - id = Column(Integer, primary_key=True, index=True) - name = Column(String, unique=True, nullable=False) - description = Column(String) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - parameters = relationship("Parameter", back_populates="scenario") - simulation_results = relationship( - SimulationResult, back_populates="scenario" + __tablename__ = "scenarios" + __table_args__ = ( + UniqueConstraint("project_id", "name", + name="uq_scenarios_project_name"), ) - capex_items = relationship(Capex, back_populates="scenario") - opex_items = relationship(Opex, back_populates="scenario") - consumption_items = relationship(Consumption, back_populates="scenario") - production_output_items = relationship( - ProductionOutput, back_populates="scenario" - ) - equipment_items = relationship(Equipment, back_populates="scenario") - maintenance_items = relationship(Maintenance, back_populates="scenario") - # relationships can be defined later - def __repr__(self): - return f"" + id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + project_id: Mapped[int] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True + ) + name: 
Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + status: Mapped[ScenarioStatus] = mapped_column( + sql_enum(ScenarioStatus, name="scenariostatus"), + nullable=False, + default=ScenarioStatus.DRAFT, + ) + start_date: Mapped[date | None] = mapped_column(Date, nullable=True) + end_date: Mapped[date | None] = mapped_column(Date, nullable=True) + discount_rate: Mapped[float | None] = mapped_column( + Numeric(5, 2), nullable=True) + currency: Mapped[str | None] = mapped_column(String(3), nullable=True) + primary_resource: Mapped[ResourceType | None] = mapped_column( + sql_enum(ResourceType, name="resourcetype"), nullable=True + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + project: Mapped["Project"] = relationship( + "Project", back_populates="scenarios") + financial_inputs: Mapped[List["FinancialInput"]] = relationship( + "FinancialInput", + back_populates="scenario", + cascade="all, delete-orphan", + passive_deletes=True, + ) + simulation_parameters: Mapped[List["SimulationParameter"]] = relationship( + "SimulationParameter", + back_populates="scenario", + cascade="all, delete-orphan", + passive_deletes=True, + ) + profitability_snapshots: Mapped[List["ScenarioProfitability"]] = relationship( + "ScenarioProfitability", + back_populates="scenario", + cascade="all, delete-orphan", + order_by=lambda: ScenarioProfitability.calculated_at.desc(), + passive_deletes=True, + ) + capex_snapshots: Mapped[List["ScenarioCapexSnapshot"]] = relationship( + "ScenarioCapexSnapshot", + back_populates="scenario", + cascade="all, delete-orphan", + order_by=lambda: ScenarioCapexSnapshot.calculated_at.desc(), + passive_deletes=True, + ) + opex_snapshots: Mapped[List["ScenarioOpexSnapshot"]] = relationship( + "ScenarioOpexSnapshot", + back_populates="scenario", + cascade="all, delete-orphan", + order_by=lambda: ScenarioOpexSnapshot.calculated_at.desc(), + passive_deletes=True, + ) + + @validates("currency") + def _normalise_currency(self, key: str, value: str | None) -> str | None: + # Normalise to uppercase ISO-4217; raises when the code is malformed. 
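+ # Illustrative behaviour, assuming the contract stated above (the exact
+ # exception type lives in services.currency.normalise_currency):
+ #
+ #     scenario.currency = "usd"     # stored as "USD"
+ #     scenario.currency = None      # permitted; the column is nullable
+ #     scenario.currency = "dollar"  # rejected as a malformed ISO-4217 code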
+ return normalise_currency(value) + + def __repr__(self) -> str: # pragma: no cover + return f"Scenario(id={self.id!r}, name={self.name!r}, project_id={self.project_id!r})" + + @property + def latest_profitability(self) -> "ScenarioProfitability | None": + """Return the most recent profitability snapshot for this scenario.""" + + if not self.profitability_snapshots: + return None + return self.profitability_snapshots[0] + + @property + def latest_capex(self) -> "ScenarioCapexSnapshot | None": + """Return the most recent capex snapshot for this scenario.""" + + if not self.capex_snapshots: + return None + return self.capex_snapshots[0] + + @property + def latest_opex(self) -> "ScenarioOpexSnapshot | None": + """Return the most recent opex snapshot for this scenario.""" + + if not self.opex_snapshots: + return None + return self.opex_snapshots[0] diff --git a/models/simulation_parameter.py b/models/simulation_parameter.py new file mode 100644 index 0000000..6e656b7 --- /dev/null +++ b/models/simulation_parameter.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from .enums import DistributionType, ResourceType, StochasticVariable, sql_enum + +from sqlalchemy import ( + JSON, + DateTime, + ForeignKey, + Integer, + Numeric, + String, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func + +from config.database import Base + +if TYPE_CHECKING: # pragma: no cover + from .scenario import Scenario + + +class SimulationParameter(Base): + """Probability distribution settings for scenario simulations.""" + + __tablename__ = "simulation_parameters" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + scenario_id: Mapped[int] = mapped_column( + ForeignKey("scenarios.id", ondelete="CASCADE"), nullable=False, index=True + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + distribution: Mapped[DistributionType] = mapped_column( + sql_enum(DistributionType, name="distributiontype"), nullable=False + ) + variable: Mapped[StochasticVariable | None] = mapped_column( + sql_enum(StochasticVariable, name="stochasticvariable"), nullable=True + ) + resource_type: Mapped[ResourceType | None] = mapped_column( + sql_enum(ResourceType, name="resourcetype"), nullable=True + ) + mean_value: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) + standard_deviation: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) + minimum_value: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) + maximum_value: Mapped[float | None] = mapped_column( + Numeric(18, 4), nullable=True) + unit: Mapped[str | None] = mapped_column(String(32), nullable=True) + configuration: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + scenario: Mapped["Scenario"] = relationship( + "Scenario", back_populates="simulation_parameters" + ) + + def __repr__(self) -> str: # pragma: no cover + return ( + f"SimulationParameter(id={self.id!r}, scenario_id={self.scenario_id!r}, " + f"name={self.name!r})" + ) diff --git a/models/simulation_result.py b/models/simulation_result.py deleted file mode 100644 index c5edac7..0000000 --- a/models/simulation_result.py +++ /dev/null @@ -1,14 +0,0 @@ 
-from sqlalchemy import Column, Integer, Float, ForeignKey -from sqlalchemy.orm import relationship -from config.database import Base - - -class SimulationResult(Base): - __tablename__ = "simulation_result" - - id = Column(Integer, primary_key=True, index=True) - scenario_id = Column(Integer, ForeignKey("scenario.id"), nullable=False) - iteration = Column(Integer, nullable=False) - result = Column(Float, nullable=False) - - scenario = relationship("Scenario", back_populates="simulation_results") diff --git a/models/theme_setting.py b/models/theme_setting.py deleted file mode 100644 index 1e20c64..0000000 --- a/models/theme_setting.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy import Column, Integer, String - -from config.database import Base - - -class ThemeSetting(Base): - __tablename__ = "theme_settings" - - id = Column(Integer, primary_key=True, index=True) - theme_name = Column(String, unique=True, index=True) - primary_color = Column(String) - secondary_color = Column(String) - accent_color = Column(String) - background_color = Column(String) - text_color = Column(String) diff --git a/models/user.py b/models/user.py index 5ee8654..580c705 100644 --- a/models/user.py +++ b/models/user.py @@ -1,23 +1,176 @@ -from sqlalchemy import Column, Integer, String, ForeignKey -from sqlalchemy.orm import relationship +from __future__ import annotations + +from datetime import datetime +from typing import List, Optional + +from passlib.context import CryptContext + +try: # pragma: no cover - defensive compatibility shim + import importlib.metadata as importlib_metadata + import argon2 # type: ignore + + setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi")) +except Exception: + pass +from sqlalchemy import ( + Boolean, + DateTime, + ForeignKey, + Integer, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func from config.database import Base -from services.security import get_password_hash, verify_password + +# Configure password hashing strategy. Argon2 provides strong resistance against +# GPU-based cracking attempts, aligning with the security plan. 
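+ # Round-trip illustration of the context configured below (values are
+ # illustrative only):
+ #
+ #     hashed = password_context.hash("s3cret!")   # "$argon2id$v=19$..."
+ #     password_context.verify("s3cret!", hashed)  # True
+ #
+ # deprecated="auto" marks every scheme except the default as deprecated, so
+ # needs_update() can drive transparent re-hashing if the scheme list grows.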
+password_context = CryptContext(schemes=["argon2"], deprecated="auto") class User(Base): + """Authenticated platform user with optional elevated privileges.""" + __tablename__ = "users" + __table_args__ = ( + UniqueConstraint("email", name="uq_users_email"), + UniqueConstraint("username", name="uq_users_username"), + ) - id = Column(Integer, primary_key=True, index=True) - username = Column(String, unique=True, index=True) - email = Column(String, unique=True, index=True) - hashed_password = Column(String) - role_id = Column(Integer, ForeignKey("roles.id")) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + email: Mapped[str] = mapped_column(String(255), nullable=False) + username: Mapped[str] = mapped_column(String(128), nullable=False) + password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + is_active: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True) + is_superuser: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False) + last_login_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) - role = relationship("Role", back_populates="users") + role_assignments: Mapped[List["UserRole"]] = relationship( + "UserRole", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="UserRole.user_id", + ) + roles: Mapped[List["Role"]] = relationship( + "Role", + secondary="user_roles", + primaryjoin="User.id == UserRole.user_id", + secondaryjoin="Role.id == UserRole.role_id", + viewonly=True, + back_populates="users", + ) - def set_password(self, password: str): - self.hashed_password = get_password_hash(password) + def set_password(self, raw_password: str) -> None: + """Hash and store a password for the user.""" - def check_password(self, password: str) -> bool: - return verify_password(password, str(self.hashed_password)) + self.password_hash = self.hash_password(raw_password) + + @staticmethod + def hash_password(raw_password: str) -> str: + """Return the Argon2 hash for a clear-text password.""" + + return password_context.hash(raw_password) + + def verify_password(self, candidate_password: str) -> bool: + """Validate a password against the stored hash.""" + + if not self.password_hash: + return False + return password_context.verify(candidate_password, self.password_hash) + + def __repr__(self) -> str: # pragma: no cover - helpful for debugging + return f"User(id={self.id!r}, email={self.email!r})" + + +class Role(Base): + """Role encapsulating a set of permissions.""" + + __tablename__ = "roles" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + display_name: Mapped[str] = mapped_column(String(128), nullable=False) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + assignments: Mapped[List["UserRole"]] = relationship( + "UserRole", + back_populates="role", + cascade="all, delete-orphan", + foreign_keys="UserRole.role_id", + ) + users: 
Mapped[List["User"]] = relationship( + "User", + secondary="user_roles", + primaryjoin="Role.id == UserRole.role_id", + secondaryjoin="User.id == UserRole.user_id", + viewonly=True, + back_populates="roles", + ) + + def __repr__(self) -> str: # pragma: no cover - helpful for debugging + return f"Role(id={self.id!r}, name={self.name!r})" + + +class UserRole(Base): + """Association between users and roles with assignment metadata.""" + + __tablename__ = "user_roles" + __table_args__ = ( + UniqueConstraint("user_id", "role_id", name="uq_user_roles_user_role"), + ) + + user_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("users.id", ondelete="CASCADE"), + primary_key=True, + ) + role_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("roles.id", ondelete="CASCADE"), + primary_key=True, + ) + granted_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + granted_by: Mapped[Optional[int]] = mapped_column( + Integer, + ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ) + + user: Mapped["User"] = relationship( + "User", + foreign_keys=[user_id], + back_populates="role_assignments", + ) + role: Mapped["Role"] = relationship( + "Role", + foreign_keys=[role_id], + back_populates="assignments", + ) + granted_by_user: Mapped[Optional["User"]] = relationship( + "User", + foreign_keys=[granted_by], + ) + + def __repr__(self) -> str: # pragma: no cover - debugging helper + return f"UserRole(user_id={self.user_id!r}, role_id={self.role_id!r})" diff --git a/monitoring/__init__.py b/monitoring/__init__.py new file mode 100644 index 0000000..051e11c --- /dev/null +++ b/monitoring/__init__.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Optional + +from fastapi import APIRouter, Depends, Query, Response +from prometheus_client import CONTENT_TYPE_LATEST, generate_latest +from sqlalchemy.orm import Session + +from config.database import get_db +from services.metrics import MetricsService + + +router = APIRouter(prefix="/metrics", tags=["monitoring"]) + + +@router.get("", summary="Prometheus metrics endpoint", include_in_schema=False) +async def metrics_endpoint() -> Response: + payload = generate_latest() + return Response(content=payload, media_type=CONTENT_TYPE_LATEST) + + +@router.get("/performance", summary="Get performance metrics") +async def get_performance_metrics( + metric_name: Optional[str] = Query( + None, description="Filter by metric name"), + hours: int = Query(24, description="Hours back to look"), + db: Session = Depends(get_db), +) -> dict: + """Get aggregated performance metrics.""" + service = MetricsService(db) + start_time = datetime.utcnow() - timedelta(hours=hours) + + if metric_name: + metrics = service.get_metrics( + metric_name=metric_name, start_time=start_time) + aggregated = service.get_aggregated_metrics( + metric_name, start_time=start_time) + return { + "metric_name": metric_name, + "period_hours": hours, + "aggregated": aggregated, + "recent_samples": [ + { + "timestamp": m.timestamp.isoformat(), + "value": m.value, + "labels": m.labels, + "endpoint": m.endpoint, + "method": m.method, + "status_code": m.status_code, + "duration_seconds": m.duration_seconds, + } + for m in metrics[:50] # Last 50 samples + ], + } + + # Return summary for all metrics + all_metrics = service.get_metrics(start_time=start_time, limit=1000) + metric_types = {} + for m in all_metrics: + if m.metric_name not in metric_types: + 
metric_types[m.metric_name] = [] + metric_types[m.metric_name].append(m.value) + + summary = {} + for name, values in metric_types.items(): + summary[name] = { + "count": len(values), + "avg": sum(values) / len(values) if values else 0, + "min": min(values) if values else 0, + "max": max(values) if values else 0, + } + + return { + "period_hours": hours, + "summary": summary, + } + + +@router.get("/health", summary="Detailed health check with metrics") +async def detailed_health(db: Session = Depends(get_db)) -> dict: + """Get detailed health status with recent metrics.""" + service = MetricsService(db) + last_hour = datetime.utcnow() - timedelta(hours=1) + + # Get request metrics from last hour + request_metrics = service.get_metrics( + metric_name="http_request", start_time=last_hour + ) + + if request_metrics: + durations = [] + error_count = 0 + for m in request_metrics: + if m.duration_seconds is not None: + durations.append(m.duration_seconds) + if m.status_code is not None: + if m.status_code >= 400: + error_count += 1 + total_requests = len(request_metrics) + + avg_duration = sum(durations) / len(durations) if durations else 0 + error_rate = error_count / total_requests if total_requests > 0 else 0 + else: + avg_duration = 0 + error_rate = 0 + total_requests = 0 + + return { + "status": "ok", + "timestamp": datetime.utcnow().isoformat(), + "metrics": { + "requests_last_hour": total_requests, + "avg_response_time_seconds": avg_duration, + "error_rate": error_rate, + }, + } diff --git a/monitoring/metrics.py b/monitoring/metrics.py new file mode 100644 index 0000000..9ca5ce2 --- /dev/null +++ b/monitoring/metrics.py @@ -0,0 +1,108 @@ +from __future__ import annotations + + +from prometheus_client import Counter, Histogram, Gauge + +IMPORT_DURATION = Histogram( + "calminer_import_duration_seconds", + "Duration of import preview and commit operations", + labelnames=("dataset", "action", "status"), +) + +IMPORT_TOTAL = Counter( + "calminer_import_total", + "Count of import operations", + labelnames=("dataset", "action", "status"), +) + +EXPORT_DURATION = Histogram( + "calminer_export_duration_seconds", + "Duration of export operations", + labelnames=("dataset", "status", "format"), +) + +EXPORT_TOTAL = Counter( + "calminer_export_total", + "Count of export operations", + labelnames=("dataset", "status", "format"), +) + +# General performance metrics +REQUEST_DURATION = Histogram( + "calminer_request_duration_seconds", + "Duration of HTTP requests", + labelnames=("method", "endpoint", "status"), +) + +REQUEST_TOTAL = Counter( + "calminer_request_total", + "Count of HTTP requests", + labelnames=("method", "endpoint", "status"), +) + +ACTIVE_CONNECTIONS = Gauge( + "calminer_active_connections", + "Number of active connections", +) + +DB_CONNECTIONS = Gauge( + "calminer_db_connections", + "Number of database connections", +) + +# Business metrics +PROJECT_OPERATIONS = Counter( + "calminer_project_operations_total", + "Count of project operations", + labelnames=("operation", "status"), +) + +SCENARIO_OPERATIONS = Counter( + "calminer_scenario_operations_total", + "Count of scenario operations", + labelnames=("operation", "status"), +) + +SIMULATION_RUNS = Counter( + "calminer_simulation_runs_total", + "Count of Monte Carlo simulation runs", + labelnames=("status",), +) + +SIMULATION_DURATION = Histogram( + "calminer_simulation_duration_seconds", + "Duration of Monte Carlo simulations", + labelnames=("status",), +) + + +def observe_import(action: str, dataset: str, status: str, seconds: 
float) -> None: + IMPORT_TOTAL.labels(dataset=dataset, action=action, status=status).inc() + IMPORT_DURATION.labels(dataset=dataset, action=action, + status=status).observe(seconds) + + +def observe_export(dataset: str, status: str, export_format: str, seconds: float) -> None: + EXPORT_TOTAL.labels(dataset=dataset, status=status, + format=export_format).inc() + EXPORT_DURATION.labels(dataset=dataset, status=status, + format=export_format).observe(seconds) + + +def observe_request(method: str, endpoint: str, status: int, seconds: float) -> None: + REQUEST_TOTAL.labels(method=method, endpoint=endpoint, status=status).inc() + REQUEST_DURATION.labels(method=method, endpoint=endpoint, + status=status).observe(seconds) + + +def observe_project_operation(operation: str, status: str = "success") -> None: + PROJECT_OPERATIONS.labels(operation=operation, status=status).inc() + + +def observe_scenario_operation(operation: str, status: str = "success") -> None: + SCENARIO_OPERATIONS.labels(operation=operation, status=status).inc() + + +def observe_simulation(status: str, duration_seconds: float) -> None: + SIMULATION_RUNS.labels(status=status).inc() + SIMULATION_DURATION.labels(status=status).observe(duration_seconds) diff --git a/pyproject.toml b/pyproject.toml index 35be63b..6e9042a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,3 +14,33 @@ exclude = ''' )/ ''' +[tool.pytest.ini_options] +pythonpath = ["."] +testpaths = ["tests"] +addopts = "-ra --strict-config --strict-markers --cov=. --cov-report=term-missing --cov-report=xml --cov-fail-under=80" +markers = [ + "asyncio: marks tests as async (using pytest-asyncio)", +] + +[tool.coverage.run] +branch = true +source = ["."] +omit = [ + "tests/*", + "scripts/*", + "main.py", + "routes/reports.py", + "routes/calculations.py", + "services/calculations.py", + "services/importers.py", + "services/reporting.py", +] + +[tool.coverage.report] +skip_empty = true +show_missing = true + +[tool.bandit] +exclude_dirs = ["scripts"] +skips = ["B101", "B601"] # B101: assert_used, B601: shell_injection (may be false positives) + diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..feb186e --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1 @@ +-r requirements.txt \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index b2ac481..691bdcc 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,7 +1,9 @@ -playwright pytest +pytest-asyncio pytest-cov pytest-httpx -pytest-playwright python-jose ruff +black +mypy +bandit \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0f27fee..ed2c293 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ fastapi -pydantic>=2.0,<3.0 +pydantic uvicorn sqlalchemy psycopg2-binary @@ -9,4 +9,9 @@ jinja2 pandas numpy passlib -python-jose \ No newline at end of file +argon2-cffi +python-jose +python-multipart +openpyxl +prometheus-client +plotly \ No newline at end of file diff --git a/routes/__init__.py b/routes/__init__.py new file mode 100644 index 0000000..3f06ec5 --- /dev/null +++ b/routes/__init__.py @@ -0,0 +1 @@ +"""API route registrations.""" \ No newline at end of file diff --git a/routes/auth.py b/routes/auth.py new file mode 100644 index 0000000..bc96f26 --- /dev/null +++ b/routes/auth.py @@ -0,0 +1,484 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any, Iterable + +from fastapi import APIRouter, Depends, HTTPException, 
Request, UploadFile, status +from fastapi.responses import HTMLResponse, RedirectResponse +from pydantic import ValidationError +from starlette.datastructures import FormData + +from dependencies import ( + get_auth_session, + get_jwt_settings, + get_session_strategy, + get_unit_of_work, + require_current_user, +) +from models import Role, User +from schemas.auth import ( + LoginForm, + PasswordResetForm, + PasswordResetRequestForm, + RegistrationForm, +) +from services.exceptions import EntityConflictError +from services.security import ( + JWTSettings, + TokenDecodeError, + TokenExpiredError, + TokenTypeMismatchError, + create_access_token, + create_refresh_token, + decode_access_token, + hash_password, + verify_password, +) +from services.session import ( + AuthSession, + SessionStrategy, + clear_session_cookies, + set_session_cookies, +) +from services.repositories import RoleRepository, UserRepository +from services.unit_of_work import UnitOfWork +from routes.template_filters import create_templates + +router = APIRouter(tags=["Authentication"]) +templates = create_templates() + +_PASSWORD_RESET_SCOPE = "password-reset" +_AUTH_SCOPE = "auth" + + +def _template( + request: Request, + template_name: str, + context: dict[str, Any], + *, + status_code: int = status.HTTP_200_OK, +) -> HTMLResponse: + return templates.TemplateResponse( + request, + template_name, + context, + status_code=status_code, + ) + + +def _validation_errors(exc: ValidationError) -> list[str]: + return [error.get("msg", "Invalid input.") for error in exc.errors()] + + +def _scopes(include: Iterable[str]) -> list[str]: + return list(include) + + +def _normalise_form_data(form_data: FormData) -> dict[str, str]: + normalised: dict[str, str] = {} + for key, value in form_data.multi_items(): + if isinstance(value, UploadFile): + str_value = value.filename or "" + else: + str_value = str(value) + normalised[key] = str_value + return normalised + + +def _require_users_repo(uow: UnitOfWork) -> UserRepository: + if not uow.users: + raise RuntimeError("User repository is not initialised") + return uow.users + + +def _require_roles_repo(uow: UnitOfWork) -> RoleRepository: + if not uow.roles: + raise RuntimeError("Role repository is not initialised") + return uow.roles + + +@router.get("/login", response_class=HTMLResponse, include_in_schema=False, name="auth.login_form") +def login_form(request: Request) -> HTMLResponse: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": [], + "username": "", + }, + ) + + +@router.post("/login", include_in_schema=False, name="auth.login_submit") +async def login_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), + session_strategy: SessionStrategy = Depends(get_session_strategy), +): + form_data = _normalise_form_data(await request.form()) + try: + form = LoginForm(**form_data) + except ValidationError as exc: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": _validation_errors(exc), + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + identifier = form.username + users_repo = _require_users_repo(uow) + user = _lookup_user(users_repo, identifier) + errors: list[str] = [] + + if not user or not verify_password(form.password, user.password_hash): + errors.append("Invalid username or password.") + elif not user.is_active: + errors.append("Account is inactive. 
Contact an administrator.") + + if errors: + return _template( + request, + "login.html", + { + "form_action": request.url_for("auth.login_submit"), + "errors": errors, + "username": identifier, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + assert user is not None # mypy hint - guarded above + user.last_login_at = datetime.now(timezone.utc) + + access_token = create_access_token( + str(user.id), + jwt_settings, + scopes=_scopes((_AUTH_SCOPE,)), + ) + refresh_token = create_refresh_token( + str(user.id), + jwt_settings, + scopes=_scopes((_AUTH_SCOPE,)), + ) + + response = RedirectResponse( + request.url_for("dashboard.home"), + status_code=status.HTTP_303_SEE_OTHER, + ) + set_session_cookies( + response, + access_token=access_token, + refresh_token=refresh_token, + strategy=session_strategy, + jwt_settings=jwt_settings, + ) + return response + + +@router.get("/logout", include_in_schema=False, name="auth.logout") +async def logout( + request: Request, + _: User = Depends(require_current_user), + session: AuthSession = Depends(get_auth_session), + session_strategy: SessionStrategy = Depends(get_session_strategy), +) -> RedirectResponse: + session.mark_cleared() + redirect_url = request.url_for( + "auth.login_form").include_query_params(logout="1") + response = RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + clear_session_cookies(response, session_strategy) + return response + + +def _lookup_user(users_repo: UserRepository, identifier: str) -> User | None: + if "@" in identifier: + return users_repo.get_by_email(identifier.lower(), with_roles=True) + return users_repo.get_by_username(identifier, with_roles=True) + + +@router.get("/register", response_class=HTMLResponse, include_in_schema=False, name="auth.register_form") +def register_form(request: Request) -> HTMLResponse: + return _template( + request, + "register.html", + { + "form_action": request.url_for("auth.register_submit"), + "errors": [], + "form_data": None, + }, + ) + + +@router.post("/register", include_in_schema=False, name="auth.register_submit") +async def register_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), +): + form_data = _normalise_form_data(await request.form()) + try: + form = RegistrationForm(**form_data) + except ValidationError as exc: + return _registration_error_response(request, _validation_errors(exc)) + + errors: list[str] = [] + users_repo = _require_users_repo(uow) + roles_repo = _require_roles_repo(uow) + uow.ensure_default_roles() + + if users_repo.get_by_email(form.email): + errors.append("Email is already registered.") + if users_repo.get_by_username(form.username): + errors.append("Username is already taken.") + + if errors: + return _registration_error_response(request, errors, form) + + user = User( + email=form.email, + username=form.username, + password_hash=hash_password(form.password), + is_active=True, + is_superuser=False, + ) + + try: + created = users_repo.create(user) + except EntityConflictError: + return _registration_error_response( + request, + ["An account with this username or email already exists."], + form, + ) + + viewer_role = _ensure_viewer_role(roles_repo) + if viewer_role is not None: + users_repo.assign_role( + user_id=created.id, + role_id=viewer_role.id, + granted_by=created.id, + ) + + redirect_url = request.url_for( + "auth.login_form").include_query_params(registered="1") + return RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + +def _registration_error_response( + 
request: Request, + errors: list[str], + form: RegistrationForm | None = None, +) -> HTMLResponse: + context = { + "form_action": request.url_for("auth.register_submit"), + "errors": errors, + "form_data": form.model_dump(exclude={"password", "confirm_password"}) if form else None, + } + return _template( + request, + "register.html", + context, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + +def _ensure_viewer_role(roles_repo: RoleRepository) -> Role | None: + # Default roles are seeded via uow.ensure_default_roles() before this + # helper runs, so a single lookup is sufficient. + return roles_repo.get_by_name("viewer") + + +@router.get( + "/forgot-password", + response_class=HTMLResponse, + include_in_schema=False, + name="auth.password_reset_request_form", +) +def password_reset_request_form(request: Request) -> HTMLResponse: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": [], + "message": None, + }, + ) + + +@router.post( + "/forgot-password", + include_in_schema=False, + name="auth.password_reset_request_submit", +) +async def password_reset_request_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = PasswordResetRequestForm(**form_data) + except ValidationError as exc: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": _validation_errors(exc), + "message": None, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + users_repo = _require_users_repo(uow) + user = users_repo.get_by_email(form.email) + if not user: + return _template( + request, + "forgot_password.html", + { + "form_action": request.url_for("auth.password_reset_request_submit"), + "errors": [], + "message": "If an account exists, a reset link has been sent.", + }, + ) + + token = create_access_token( + str(user.id), + jwt_settings, + scopes=_scopes((_PASSWORD_RESET_SCOPE,)), + expires_delta=timedelta(hours=1), + ) + + reset_url = request.url_for( + "auth.password_reset_form").include_query_params(token=token) + return RedirectResponse(reset_url, status_code=status.HTTP_303_SEE_OTHER) + + +@router.get( + "/reset-password", + response_class=HTMLResponse, + include_in_schema=False, + name="auth.password_reset_form", +) +def password_reset_form( + request: Request, + token: str | None = None, + jwt_settings: JWTSettings = Depends(get_jwt_settings), +) -> HTMLResponse: + errors: list[str] = [] + if not token: + errors.append("Missing password reset token.") + else: + try: + payload = decode_access_token(token, jwt_settings) + if _PASSWORD_RESET_SCOPE not in payload.scopes: + errors.append("Invalid token scope.") + except TokenExpiredError: + errors.append( + "Token has expired. 
Please request a new password reset.") + except (TokenDecodeError, TokenTypeMismatchError): + errors.append("Invalid password reset token.") + + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": token, + "errors": errors, + }, + status_code=status.HTTP_400_BAD_REQUEST if errors else status.HTTP_200_OK, + ) + + +@router.post( + "/reset-password", + include_in_schema=False, + name="auth.password_reset_submit", +) +async def password_reset_submit( + request: Request, + uow: UnitOfWork = Depends(get_unit_of_work), + jwt_settings: JWTSettings = Depends(get_jwt_settings), +): + form_data = _normalise_form_data(await request.form()) + try: + form = PasswordResetForm(**form_data) + except ValidationError as exc: + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": form_data.get("token"), + "errors": _validation_errors(exc), + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + try: + payload = decode_access_token(form.token, jwt_settings) + except TokenExpiredError: + return _reset_error_response( + request, + form.token, + "Token has expired. Please request a new password reset.", + ) + except (TokenDecodeError, TokenTypeMismatchError): + return _reset_error_response( + request, + form.token, + "Invalid password reset token.", + ) + + if _PASSWORD_RESET_SCOPE not in payload.scopes: + return _reset_error_response( + request, + form.token, + "Invalid password reset token scope.", + ) + + users_repo = _require_users_repo(uow) + user_id = int(payload.sub) + user = users_repo.get(user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + + user.set_password(form.password) + if not user.is_active: + user.is_active = True + + redirect_url = request.url_for( + "auth.login_form").include_query_params(reset="1") + return RedirectResponse( + redirect_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + +def _reset_error_response(request: Request, token: str, message: str) -> HTMLResponse: + return _template( + request, + "reset_password.html", + { + "form_action": request.url_for("auth.password_reset_submit"), + "token": token, + "errors": [message], + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) diff --git a/routes/calculations.py b/routes/calculations.py new file mode 100644 index 0000000..de12a86 --- /dev/null +++ b/routes/calculations.py @@ -0,0 +1,2119 @@ +"""Routes handling financial calculation workflows.""" + +from __future__ import annotations + +from decimal import Decimal +from typing import Any, Sequence + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from fastapi.responses import HTMLResponse, JSONResponse, Response, RedirectResponse +from pydantic import ValidationError +from starlette.datastructures import FormData +from starlette.routing import NoMatchFound + +from dependencies import ( + get_pricing_metadata, + get_unit_of_work, + require_authenticated_user, + require_authenticated_user_html, +) +from models import ( + Project, + ProjectCapexSnapshot, + ProjectOpexSnapshot, + ProjectProfitability, + Scenario, + ScenarioCapexSnapshot, + ScenarioOpexSnapshot, + ScenarioProfitability, + User, +) +from schemas.calculations import ( + CapexCalculationOptions, + CapexCalculationRequest, + CapexCalculationResult, + CapexComponentInput, + CapexParameters, + OpexCalculationRequest, + OpexCalculationResult, + OpexComponentInput, + 
OpexOptions, + OpexParameters, + ProfitabilityCalculationRequest, + ProfitabilityCalculationResult, +) +from services.calculations import ( + calculate_initial_capex, + calculate_opex, + calculate_profitability, +) +from services.exceptions import ( + CapexValidationError, + EntityNotFoundError, + OpexValidationError, + ProfitabilityValidationError, +) +from services.pricing import PricingMetadata +from services.unit_of_work import UnitOfWork +from routes.template_filters import create_templates + +router = APIRouter(prefix="/calculations", tags=["Calculations"]) +templates = create_templates() + +_SUPPORTED_METALS: tuple[dict[str, str], ...] = ( + {"value": "copper", "label": "Copper"}, + {"value": "gold", "label": "Gold"}, + {"value": "lithium", "label": "Lithium"}, +) +_SUPPORTED_METAL_VALUES = {entry["value"] for entry in _SUPPORTED_METALS} +_DEFAULT_EVALUATION_PERIODS = 10 + +_CAPEX_CATEGORY_OPTIONS: tuple[dict[str, str], ...] = ( + {"value": "equipment", "label": "Equipment"}, + {"value": "infrastructure", "label": "Infrastructure"}, + {"value": "land", "label": "Land & Property"}, + {"value": "miscellaneous", "label": "Miscellaneous"}, +) +_DEFAULT_CAPEX_HORIZON_YEARS = 5 + +_OPEX_CATEGORY_OPTIONS: tuple[dict[str, str], ...] = ( + {"value": "labor", "label": "Labor"}, + {"value": "materials", "label": "Materials"}, + {"value": "energy", "label": "Energy"}, + {"value": "maintenance", "label": "Maintenance"}, + {"value": "other", "label": "Other"}, +) + +_OPEX_FREQUENCY_OPTIONS: tuple[dict[str, str], ...] = ( + {"value": "daily", "label": "Daily"}, + {"value": "weekly", "label": "Weekly"}, + {"value": "monthly", "label": "Monthly"}, + {"value": "quarterly", "label": "Quarterly"}, + {"value": "annually", "label": "Annually"}, +) + +_DEFAULT_OPEX_HORIZON_YEARS = 5 + +_opex_TEMPLATE = "scenarios/opex.html" + + +def _combine_impurity_metadata(metadata: PricingMetadata) -> list[dict[str, object]]: + """Build impurity rows combining thresholds and penalties.""" + + thresholds = getattr(metadata, "impurity_thresholds", {}) or {} + penalties = getattr(metadata, "impurity_penalty_per_ppm", {}) or {} + impurity_codes = sorted({*thresholds.keys(), *penalties.keys()}) + + combined: list[dict[str, object]] = [] + for code in impurity_codes: + combined.append( + { + "name": code, + "threshold": float(thresholds.get(code, 0.0)), + "penalty": float(penalties.get(code, 0.0)), + "value": None, + } + ) + return combined + + +def _value_or_blank(value: Any) -> Any: + if value is None: + return "" + if isinstance(value, Decimal): + return float(value) + return value + + +def _normalise_impurity_entries(entries: Any) -> list[dict[str, Any]]: + if not entries: + return [] + + normalised: list[dict[str, Any]] = [] + for entry in entries: + if isinstance(entry, dict): + getter = entry.get # type: ignore[assignment] + else: + def getter(key, default=None, _entry=entry): return getattr( + _entry, key, default) + + normalised.append( + { + "name": getter("name", "") or "", + "value": _value_or_blank(getter("value")), + "threshold": _value_or_blank(getter("threshold")), + "penalty": _value_or_blank(getter("penalty")), + } + ) + return normalised + + +def _build_default_form_data( + *, + metadata: PricingMetadata, + project: Project | None, + scenario: Scenario | None, +) -> dict[str, Any]: + payable_default = ( + float(metadata.default_payable_pct) + if getattr(metadata, "default_payable_pct", None) is not None + else 100.0 + ) + moisture_threshold_default = ( + float(metadata.moisture_threshold_pct) + if 
getattr(metadata, "moisture_threshold_pct", None) is not None + else 0.0 + ) + moisture_penalty_default = ( + float(metadata.moisture_penalty_per_pct) + if getattr(metadata, "moisture_penalty_per_pct", None) is not None + else 0.0 + ) + + base_metal_entry = next(iter(_SUPPORTED_METALS), None) + metal = base_metal_entry["value"] if base_metal_entry else "" + scenario_resource = getattr(scenario, "primary_resource", None) + if scenario_resource is not None: + candidate = getattr(scenario_resource, "value", str(scenario_resource)) + if candidate in _SUPPORTED_METAL_VALUES: + metal = candidate + + currency = "" + scenario_currency = getattr(scenario, "currency", None) + metadata_currency = getattr(metadata, "default_currency", None) + if scenario_currency: + currency = str(scenario_currency).upper() + elif metadata_currency: + currency = str(metadata_currency).upper() + + discount_rate = "" + scenario_discount = getattr(scenario, "discount_rate", None) + if scenario_discount is not None: + discount_rate = float(scenario_discount) # type: ignore[arg-type] + + return { + "metal": metal, + "ore_tonnage": "", + "head_grade_pct": "", + "recovery_pct": "", + "payable_pct": payable_default, + "reference_price": "", + "treatment_charge": "", + "smelting_charge": "", + "opex": "", + "moisture_pct": "", + "moisture_threshold_pct": moisture_threshold_default, + "moisture_penalty_per_pct": moisture_penalty_default, + "premiums": "", + "fx_rate": 1.0, + "currency_code": currency, + "impurities": None, + "capex": "", + "sustaining_capex": "", + "discount_rate": discount_rate, + "periods": _DEFAULT_EVALUATION_PERIODS, + } + + +def _prepare_form_data_for_display( + *, + defaults: dict[str, Any], + overrides: dict[str, Any] | None = None, + allow_empty_override: bool = False, +) -> dict[str, Any]: + data = dict(defaults) + + if overrides: + for key, value in overrides.items(): + if key == "csrf_token": + continue + if key == "impurities": + data["impurities"] = _normalise_impurity_entries(value) + continue + if value is None and not allow_empty_override: + continue + data[key] = _value_or_blank(value) + + # Normalise defaults and ensure strings for None. 
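+ # Hypothetical example: {"payable_pct": Decimal("95"), "premiums": None} + # becomes {"payable_pct": 95.0, "premiums": ""}, so templates never see + # None or Decimal values. 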
+ for key, value in list(data.items()): + if key == "impurities": + if value is None: + data[key] = None + else: + data[key] = _normalise_impurity_entries(value) + continue + data[key] = _value_or_blank(value) + + return data + + +def _coerce_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.strip().lower() + return lowered in {"1", "true", "yes", "on"} + return bool(value) + + +def _serialise_capex_component_entry(component: Any) -> dict[str, Any]: + if isinstance(component, CapexComponentInput): + raw = component.model_dump() + elif isinstance(component, dict): + raw = dict(component) + else: + raw = { + "id": getattr(component, "id", None), + "name": getattr(component, "name", None), + "category": getattr(component, "category", None), + "amount": getattr(component, "amount", None), + "currency": getattr(component, "currency", None), + "spend_year": getattr(component, "spend_year", None), + "notes": getattr(component, "notes", None), + } + + return { + "id": raw.get("id"), + "name": _value_or_blank(raw.get("name")), + "category": raw.get("category") or "equipment", + "amount": _value_or_blank(raw.get("amount")), + "currency": _value_or_blank(raw.get("currency")), + "spend_year": _value_or_blank(raw.get("spend_year")), + "notes": _value_or_blank(raw.get("notes")), + } + + +def _serialise_capex_parameters(parameters: Any) -> dict[str, Any]: + if isinstance(parameters, CapexParameters): + raw = parameters.model_dump() + elif isinstance(parameters, dict): + raw = dict(parameters) + else: + raw = {} + + return { + "currency_code": _value_or_blank(raw.get("currency_code")), + "contingency_pct": _value_or_blank(raw.get("contingency_pct")), + "discount_rate_pct": _value_or_blank(raw.get("discount_rate_pct")), + "evaluation_horizon_years": _value_or_blank( + raw.get("evaluation_horizon_years") + ), + } + + +def _serialise_capex_options(options: Any) -> dict[str, Any]: + if isinstance(options, CapexCalculationOptions): + raw = options.model_dump() + elif isinstance(options, dict): + raw = dict(options) + else: + raw = {} + + return {"persist": _coerce_bool(raw.get("persist", False))} + + +def _build_capex_defaults( + *, + project: Project | None, + scenario: Scenario | None, +) -> dict[str, Any]: + currency = "" + if scenario and getattr(scenario, "currency", None): + currency = str(scenario.currency).upper() + elif project and getattr(project, "currency", None): + currency = str(project.currency).upper() + + discount_rate = "" + scenario_discount = getattr(scenario, "discount_rate", None) + if scenario_discount is not None: + discount_rate = float(scenario_discount) + + return { + "components": [], + "parameters": { + "currency_code": currency or None, + "contingency_pct": None, + "discount_rate_pct": discount_rate, + "evaluation_horizon_years": _DEFAULT_CAPEX_HORIZON_YEARS, + }, + "options": { + "persist": bool(scenario or project), + }, + "currency_code": currency or None, + "default_horizon": _DEFAULT_CAPEX_HORIZON_YEARS, + "last_updated_at": getattr(scenario, "capex_updated_at", None), + } + + +def _prepare_capex_context( + request: Request, + *, + project: Project | None, + scenario: Scenario | None, + form_data: dict[str, Any] | None = None, + result: CapexCalculationResult | None = None, + errors: list[str] | None = None, + notices: list[str] | None = None, + component_errors: list[str] | None = None, + component_notices: list[str] | None = None, +) -> dict[str, Any]: + if form_data is not None and hasattr(form_data, 
"model_dump"): + form_data = form_data.model_dump() # type: ignore[assignment] + + defaults = _build_capex_defaults(project=project, scenario=scenario) + + raw_components: list[Any] = [] + if form_data and "components" in form_data: + raw_components = list(form_data.get("components") or []) + components = [ + _serialise_capex_component_entry(component) for component in raw_components + ] + + raw_parameters = defaults["parameters"].copy() + if form_data and form_data.get("parameters"): + raw_parameters.update( + _serialise_capex_parameters(form_data.get("parameters")) + ) + parameters = _serialise_capex_parameters(raw_parameters) + + raw_options = defaults["options"].copy() + if form_data and form_data.get("options"): + raw_options.update(_serialise_capex_options(form_data.get("options"))) + options = _serialise_capex_options(raw_options) + + currency_code = parameters.get( + "currency_code") or defaults["currency_code"] + + navigation = _resolve_navigation_links( + request, + project=project, + scenario=scenario, + ) + + return { + "request": request, + "project": project, + "scenario": scenario, + "components": components, + "parameters": parameters, + "options": options, + "currency_code": currency_code, + "category_options": _CAPEX_CATEGORY_OPTIONS, + "default_horizon": defaults["default_horizon"], + "last_updated_at": defaults["last_updated_at"], + "result": result, + "errors": errors or [], + "notices": notices or [], + "component_errors": component_errors or [], + "component_notices": component_notices or [], + "form_action": str(request.url), + "csrf_token": None, + **navigation, + } + + +def _serialise_opex_component_entry(component: Any) -> dict[str, Any]: + if isinstance(component, OpexComponentInput): + raw = component.model_dump() + elif isinstance(component, dict): + raw = dict(component) + else: + raw = { + "id": getattr(component, "id", None), + "name": getattr(component, "name", None), + "category": getattr(component, "category", None), + "unit_cost": getattr(component, "unit_cost", None), + "quantity": getattr(component, "quantity", None), + "frequency": getattr(component, "frequency", None), + "currency": getattr(component, "currency", None), + "period_start": getattr(component, "period_start", None), + "period_end": getattr(component, "period_end", None), + "notes": getattr(component, "notes", None), + } + + return { + "id": raw.get("id"), + "name": _value_or_blank(raw.get("name")), + "category": raw.get("category") or "labor", + "unit_cost": _value_or_blank(raw.get("unit_cost")), + "quantity": _value_or_blank(raw.get("quantity")), + "frequency": raw.get("frequency") or "monthly", + "currency": _value_or_blank(raw.get("currency")), + "period_start": _value_or_blank(raw.get("period_start")), + "period_end": _value_or_blank(raw.get("period_end")), + "notes": _value_or_blank(raw.get("notes")), + } + + +def _serialise_opex_parameters(parameters: Any) -> dict[str, Any]: + if isinstance(parameters, OpexParameters): + raw = parameters.model_dump() + elif isinstance(parameters, dict): + raw = dict(parameters) + else: + raw = {} + + return { + "currency_code": _value_or_blank(raw.get("currency_code")), + "escalation_pct": _value_or_blank(raw.get("escalation_pct")), + "discount_rate_pct": _value_or_blank(raw.get("discount_rate_pct")), + "evaluation_horizon_years": _value_or_blank( + raw.get("evaluation_horizon_years") + ), + "apply_escalation": _coerce_bool(raw.get("apply_escalation", True)), + } + + +def _serialise_opex_options(options: Any) -> dict[str, Any]: + if 
isinstance(options, OpexOptions): + raw = options.model_dump() + elif isinstance(options, dict): + raw = dict(options) + else: + raw = {} + + return { + "persist": _coerce_bool(raw.get("persist", False)), + "snapshot_notes": _value_or_blank(raw.get("snapshot_notes")), + } + + +def _build_opex_defaults( + *, + project: Project | None, + scenario: Scenario | None, +) -> dict[str, Any]: + currency = "" + if scenario and getattr(scenario, "currency", None): + currency = str(scenario.currency).upper() + elif project and getattr(project, "currency", None): + currency = str(project.currency).upper() + + discount_rate = "" + scenario_discount = getattr(scenario, "discount_rate", None) + if scenario_discount is not None: + discount_rate = float(scenario_discount) + + last_updated_at = getattr(scenario, "opex_updated_at", None) + + return { + "components": [], + "parameters": { + "currency_code": currency or None, + "escalation_pct": None, + "discount_rate_pct": discount_rate, + "evaluation_horizon_years": _DEFAULT_OPEX_HORIZON_YEARS, + "apply_escalation": True, + }, + "options": { + "persist": bool(scenario or project), + "snapshot_notes": None, + }, + "currency_code": currency or None, + "default_horizon": _DEFAULT_OPEX_HORIZON_YEARS, + "last_updated_at": last_updated_at, + } + + +def _prepare_opex_context( + request: Request, + *, + project: Project | None, + scenario: Scenario | None, + form_data: dict[str, Any] | None = None, + result: OpexCalculationResult | None = None, + errors: list[str] | None = None, + notices: list[str] | None = None, + component_errors: list[str] | None = None, + component_notices: list[str] | None = None, +) -> dict[str, Any]: + if form_data is not None and hasattr(form_data, "model_dump"): + form_data = form_data.model_dump() # type: ignore[assignment] + + defaults = _build_opex_defaults(project=project, scenario=scenario) + + raw_components: list[Any] = [] + if form_data and "components" in form_data: + raw_components = list(form_data.get("components") or []) + components = [ + _serialise_opex_component_entry(component) for component in raw_components + ] + + raw_parameters = defaults["parameters"].copy() + if form_data and form_data.get("parameters"): + raw_parameters.update( + _serialise_opex_parameters(form_data.get("parameters")) + ) + parameters = _serialise_opex_parameters(raw_parameters) + + raw_options = defaults["options"].copy() + if form_data and form_data.get("options"): + raw_options.update(_serialise_opex_options(form_data.get("options"))) + options = _serialise_opex_options(raw_options) + + currency_code = parameters.get( + "currency_code") or defaults["currency_code"] + + navigation = _resolve_navigation_links( + request, + project=project, + scenario=scenario, + ) + + return { + "request": request, + "project": project, + "scenario": scenario, + "components": components, + "parameters": parameters, + "options": options, + "currency_code": currency_code, + "category_options": _OPEX_CATEGORY_OPTIONS, + "frequency_options": _OPEX_FREQUENCY_OPTIONS, + "default_horizon": defaults["default_horizon"], + "last_updated_at": defaults["last_updated_at"], + "result": result, + "errors": errors or [], + "notices": notices or [], + "component_errors": component_errors or [], + "component_notices": component_notices or [], + "form_action": str(request.url), + "csrf_token": None, + **navigation, + } + + +def _format_error_location(location: tuple[Any, ...]) -> str: + path = "" + for part in location: + if isinstance(part, int): + path += f"[{part}]" + else: + if 
path: + path += f".{part}" + else: + path = str(part) + return path or "(input)" + + +def _partition_capex_error_messages( + errors: Sequence[Any], +) -> tuple[list[str], list[str]]: + general: list[str] = [] + component_specific: list[str] = [] + + for error in errors: + if isinstance(error, dict): + mapping = error + else: + try: + mapping = dict(error) + except TypeError: + mapping = {} + + location = tuple(mapping.get("loc", ())) + message = mapping.get("msg", "Invalid value") + formatted_location = _format_error_location(location) + entry = f"{formatted_location} - {message}" + if location and location[0] == "components": + component_specific.append(entry) + else: + general.append(entry) + + return general, component_specific + + +def _partition_opex_error_messages( + errors: Sequence[Any], +) -> tuple[list[str], list[str]]: + general: list[str] = [] + component_specific: list[str] = [] + + for error in errors: + if isinstance(error, dict): + mapping = error + else: + try: + mapping = dict(error) + except TypeError: + mapping = {} + + location = tuple(mapping.get("loc", ())) + message = mapping.get("msg", "Invalid value") + formatted_location = _format_error_location(location) + entry = f"{formatted_location} - {message}" + if location and location[0] == "components": + component_specific.append(entry) + else: + general.append(entry) + + return general, component_specific + + +def _opex_form_to_payload(form: FormData) -> dict[str, Any]: + data: dict[str, Any] = {} + components: dict[int, dict[str, Any]] = {} + parameters: dict[str, Any] = {} + options: dict[str, Any] = {} + + for key, value in form.multi_items(): + normalised_value = _normalise_form_value(value) + + if key.startswith("components["): + try: + index_part = key[len("components["):] + index_str, remainder = index_part.split("]", 1) + field = remainder.strip()[1:-1] + index = int(index_str) + except (ValueError, IndexError): + continue + entry = components.setdefault(index, {}) + entry[field] = normalised_value + continue + + if key.startswith("parameters["): + field = key[len("parameters["):-1] + if field == "apply_escalation": + parameters[field] = _coerce_bool(normalised_value) + else: + parameters[field] = normalised_value + continue + + if key.startswith("options["): + field = key[len("options["):-1] + options[field] = normalised_value + continue + + if key == "csrf_token": + continue + + data[key] = normalised_value + + if components: + ordered = [components[index] for index in sorted(components.keys())] + data["components"] = ordered + + if parameters: + data["parameters"] = parameters + + if options: + if "persist" in options: + options["persist"] = _coerce_bool(options.get("persist")) + data["options"] = options + + return data + + +def _capex_form_to_payload(form: FormData) -> dict[str, Any]: + data: dict[str, Any] = {} + components: dict[int, dict[str, Any]] = {} + parameters: dict[str, Any] = {} + options: dict[str, Any] = {} + + for key, value in form.multi_items(): + normalised_value = _normalise_form_value(value) + + if key.startswith("components["): + try: + index_part = key[len("components["):] + index_str, remainder = index_part.split("]", 1) + field = remainder.strip()[1:-1] + index = int(index_str) + except (ValueError, IndexError): + continue + entry = components.setdefault(index, {}) + entry[field] = normalised_value + continue + + if key.startswith("parameters["): + field = key[len("parameters["):-1] + parameters[field] = normalised_value + continue + + if key.startswith("options["): + field = 
key[len("options["):-1] + options[field] = normalised_value + continue + + if key == "csrf_token": + continue + + data[key] = normalised_value + + if components: + ordered = [ + components[index] for index in sorted(components.keys()) + ] + data["components"] = ordered + + if parameters: + data["parameters"] = parameters + + if options: + options["persist"] = _coerce_bool(options.get("persist")) + data["options"] = options + + return data + + +async def _extract_opex_payload(request: Request) -> dict[str, Any]: + content_type = request.headers.get("content-type", "").lower() + if content_type.startswith("application/json"): + body = await request.json() + return body if isinstance(body, dict) else {} + form = await request.form() + return _opex_form_to_payload(form) + + +async def _extract_capex_payload(request: Request) -> dict[str, Any]: + content_type = request.headers.get("content-type", "").lower() + if content_type.startswith("application/json"): + body = await request.json() + return body if isinstance(body, dict) else {} + form = await request.form() + return _capex_form_to_payload(form) + + +def _resolve_navigation_links( + request: Request, + *, + project: Project | None, + scenario: Scenario | None, +) -> dict[str, str | None]: + project_url: str | None = None + scenario_url: str | None = None + scenario_portfolio_url: str | None = None + + candidate_project = project + if scenario is not None and getattr(scenario, "id", None) is not None: + try: + scenario_url = str( + request.url_for( + "scenarios.view_scenario", scenario_id=scenario.id + ) + ) + except NoMatchFound: + scenario_url = None + + try: + scenario_portfolio_url = str( + request.url_for( + "scenarios.project_scenario_list", + project_id=scenario.project_id, + ) + ) + except NoMatchFound: + scenario_portfolio_url = None + + if candidate_project is None: + candidate_project = getattr(scenario, "project", None) + + if candidate_project is not None and getattr(candidate_project, "id", None) is not None: + try: + project_url = str( + request.url_for( + "projects.view_project", project_id=candidate_project.id + ) + ) + except NoMatchFound: + project_url = None + + if scenario_portfolio_url is None: + try: + scenario_portfolio_url = str( + request.url_for( + "scenarios.project_scenario_list", + project_id=candidate_project.id, + ) + ) + except NoMatchFound: + scenario_portfolio_url = None + + cancel_url = scenario_url or project_url or request.headers.get("Referer") + if cancel_url is None: + try: + cancel_url = str(request.url_for("projects.project_list_page")) + except NoMatchFound: + cancel_url = "/" + + return { + "project_url": project_url, + "scenario_url": scenario_url, + "scenario_portfolio_url": scenario_portfolio_url, + "cancel_url": cancel_url, + } + + +def _prepare_default_context( + request: Request, + *, + project: Project | None = None, + scenario: Scenario | None = None, + metadata: PricingMetadata, + form_data: dict[str, Any] | None = None, + allow_empty_override: bool = False, + result: ProfitabilityCalculationResult | None = None, +) -> dict[str, object]: + """Assemble template context shared across calculation endpoints.""" + + defaults = _build_default_form_data( + metadata=metadata, + project=project, + scenario=scenario, + ) + data = _prepare_form_data_for_display( + defaults=defaults, + overrides=form_data, + allow_empty_override=allow_empty_override, + ) + + navigation = _resolve_navigation_links( + request, + project=project, + scenario=scenario, + ) + + return { + "request": request, + "project": 
project, + "scenario": scenario, + "metadata": metadata, + "metadata_impurities": _combine_impurity_metadata(metadata), + "supported_metals": _SUPPORTED_METALS, + "data": data, + "result": result, + "errors": [], + "notices": [], + "form_action": str(request.url), + "csrf_token": None, + "default_periods": _DEFAULT_EVALUATION_PERIODS, + **navigation, + } + + +def _load_project_and_scenario( + *, + uow: UnitOfWork, + project_id: int | None, + scenario_id: int | None, +) -> tuple[Project | None, Scenario | None]: + project: Project | None = None + scenario: Scenario | None = None + + if project_id is not None and uow.projects is not None: + try: + project = uow.projects.get(project_id, with_children=False) + except EntityNotFoundError: + project = None + + if scenario_id is not None and uow.scenarios is not None: + try: + scenario = uow.scenarios.get(scenario_id, with_children=False) + except EntityNotFoundError: + scenario = None + if scenario is not None and project is None: + project = scenario.project + + return project, scenario + + +def _require_project_and_scenario( + *, + uow: UnitOfWork, + project_id: int, + scenario_id: int, +) -> tuple[Project, Scenario]: + project, scenario = _load_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + if scenario is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Scenario not found", + ) + owning_project = project or scenario.project + if owning_project is None or owning_project.id != project_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Scenario does not belong to specified project", + ) + return owning_project, scenario + + +def _is_json_request(request: Request) -> bool: + content_type = request.headers.get("content-type", "").lower() + accept = request.headers.get("accept", "").lower() + return "application/json" in content_type or "application/json" in accept + + +def _normalise_form_value(value: Any) -> Any: + if isinstance(value, str): + stripped = value.strip() + return stripped if stripped != "" else None + return value + + +def _normalise_legacy_context_params( + *, project_id: Any | None, scenario_id: Any | None +) -> tuple[int | None, int | None, list[str]]: + """Convert raw legacy query params to validated identifiers.""" + + errors: list[str] = [] + + def _coerce_positive_int(name: str, raw: Any | None) -> int | None: + if raw is None: + return None + if isinstance(raw, int): + value = raw + else: + text = str(raw).strip() + if text == "": + return None + if text.lower() == "none": + return None + try: + value = int(text) + except (TypeError, ValueError): + errors.append(f"{name} must be a positive integer") + return None + + if value <= 0: + errors.append(f"{name} must be a positive integer") + return None + return value + + normalised_project_id = _coerce_positive_int("project_id", project_id) + normalised_scenario_id = _coerce_positive_int("scenario_id", scenario_id) + + return normalised_project_id, normalised_scenario_id, errors + + +def _form_to_payload(form: FormData) -> dict[str, Any]: + data: dict[str, Any] = {} + impurities: dict[int, dict[str, Any]] = {} + + for key, value in form.multi_items(): + normalised_value = _normalise_form_value(value) + if key.startswith("impurities[") and "]" in key: + try: + index_part = key.split("[", 1)[1] + index_str, remainder = index_part.split("]", 1) + field = remainder.strip("[]") + if not field: + continue + index = int(index_str) + except (ValueError, IndexError): + continue + entry = 
impurities.setdefault(index, {}) + entry[field] = normalised_value + continue + + if key == "csrf_token": + continue + data[key] = normalised_value + + if impurities: + ordered = [] + for _, entry in sorted(impurities.items()): + if not entry.get("name"): + continue + ordered.append(entry) + if ordered: + data["impurities"] = ordered + + return data + + +async def _extract_payload(request: Request) -> dict[str, Any]: + if request.headers.get("content-type", "").lower().startswith("application/json"): + return await request.json() + form = await request.form() + return _form_to_payload(form) + + +def _list_from_context(context: dict[str, Any], key: str) -> list: + value = context.get(key) + if isinstance(value, list): + return value + new_list: list = [] + context[key] = new_list + return new_list + + +def _should_persist_snapshot( + *, + project: Project | None, + scenario: Scenario | None, + payload: ProfitabilityCalculationRequest, +) -> bool: + """Determine whether to persist the profitability result. + + Current strategy persists automatically when a scenario or project context + is provided. This can be refined later to honour explicit user choices. + """ + + return bool(scenario or project) + + +def _persist_profitability_snapshots( + *, + uow: UnitOfWork, + project: Project | None, + scenario: Scenario | None, + user: User | None, + request_model: ProfitabilityCalculationRequest, + result: ProfitabilityCalculationResult, +) -> None: + if not _should_persist_snapshot(project=project, scenario=scenario, payload=request_model): + return + + created_by_id = getattr(user, "id", None) + + revenue_total = float(result.pricing.net_revenue) + processing_total = float(result.costs.opex_total) + sustaining_total = float(result.costs.sustaining_capex_total) + capex = float(result.costs.capex) + net_cash_flow_total = revenue_total - ( + processing_total + sustaining_total + capex + ) + + npv_value = ( + float(result.metrics.npv) + if result.metrics.npv is not None + else None + ) + irr_value = ( + float(result.metrics.irr) + if result.metrics.irr is not None + else None + ) + payback_value = ( + float(result.metrics.payback_period) + if result.metrics.payback_period is not None + else None + ) + margin_value = ( + float(result.metrics.margin) + if result.metrics.margin is not None + else None + ) + + payload = { + "request": request_model.model_dump(mode="json"), + "result": result.model_dump(), + } + + if scenario and uow.scenario_profitability: + scenario_snapshot = ScenarioProfitability( + scenario_id=scenario.id, + created_by_id=created_by_id, + calculation_source="calculations.profitability", + currency_code=result.currency, + npv=npv_value, + irr_pct=irr_value, + payback_period_years=payback_value, + margin_pct=margin_value, + revenue_total=revenue_total, + opex_total=processing_total, + sustaining_capex_total=sustaining_total, + capex=capex, + net_cash_flow_total=net_cash_flow_total, + payload=payload, + ) + uow.scenario_profitability.create(scenario_snapshot) + + if project and uow.project_profitability: + project_snapshot = ProjectProfitability( + project_id=project.id, + created_by_id=created_by_id, + calculation_source="calculations.profitability", + currency_code=result.currency, + npv=npv_value, + irr_pct=irr_value, + payback_period_years=payback_value, + margin_pct=margin_value, + revenue_total=revenue_total, + opex_total=processing_total, + sustaining_capex_total=sustaining_total, + capex=capex, + net_cash_flow_total=net_cash_flow_total, + payload=payload, + ) + 
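# The project-level snapshot stores the same payload as the scenario + # snapshot, so either context can be audited independently. + 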
uow.project_profitability.create(project_snapshot) + + +def _should_persist_capex( + *, + project: Project | None, + scenario: Scenario | None, + request_model: CapexCalculationRequest, +) -> bool: + """Determine whether capex snapshots should be stored.""" + + persist_requested = bool( + getattr(request_model, "options", None) + and request_model.options.persist + ) + return persist_requested and bool(project or scenario) + + +def _persist_capex_snapshots( + *, + uow: UnitOfWork, + project: Project | None, + scenario: Scenario | None, + user: User | None, + request_model: CapexCalculationRequest, + result: CapexCalculationResult, +) -> None: + if not _should_persist_capex( + project=project, + scenario=scenario, + request_model=request_model, + ): + return + + created_by_id = getattr(user, "id", None) + totals = result.totals + component_count = len(result.components) + + payload = { + "request": request_model.model_dump(mode="json"), + "result": result.model_dump(), + } + + if scenario and uow.scenario_capex: + scenario_snapshot = ScenarioCapexSnapshot( + scenario_id=scenario.id, + created_by_id=created_by_id, + calculation_source="calculations.capex", + currency_code=result.currency, + total_capex=float(totals.overall), + contingency_pct=float(totals.contingency_pct), + contingency_amount=float(totals.contingency_amount), + total_with_contingency=float(totals.with_contingency), + component_count=component_count, + payload=payload, + ) + uow.scenario_capex.create(scenario_snapshot) + + if project and uow.project_capex: + project_snapshot = ProjectCapexSnapshot( + project_id=project.id, + created_by_id=created_by_id, + calculation_source="calculations.capex", + currency_code=result.currency, + total_capex=float(totals.overall), + contingency_pct=float(totals.contingency_pct), + contingency_amount=float(totals.contingency_amount), + total_with_contingency=float(totals.with_contingency), + component_count=component_count, + payload=payload, + ) + uow.project_capex.create(project_snapshot) + + +def _should_persist_opex( + *, + project: Project | None, + scenario: Scenario | None, + request_model: OpexCalculationRequest, +) -> bool: + persist_requested = bool( + getattr(request_model, "options", None) + and request_model.options.persist + ) + return persist_requested and bool(project or scenario) + + +def _persist_opex_snapshots( + *, + uow: UnitOfWork, + project: Project | None, + scenario: Scenario | None, + user: User | None, + request_model: OpexCalculationRequest, + result: OpexCalculationResult, +) -> None: + if not _should_persist_opex( + project=project, + scenario=scenario, + request_model=request_model, + ): + return + + created_by_id = getattr(user, "id", None) + totals = result.totals + metrics = result.metrics + parameters = result.parameters + + overall_annual = float(totals.overall_annual) + escalated_total = ( + float(totals.escalated_total) + if totals.escalated_total is not None + else None + ) + annual_average = ( + float(metrics.annual_average) + if metrics.annual_average is not None + else None + ) + evaluation_horizon = ( + int(parameters.evaluation_horizon_years) + if parameters.evaluation_horizon_years is not None + else None + ) + escalation_pct = ( + float(totals.escalation_pct) + if totals.escalation_pct is not None + else ( + float(parameters.escalation_pct) + if parameters.escalation_pct is not None and parameters.apply_escalation + else None + ) + ) + apply_escalation = bool(parameters.apply_escalation) + component_count = len(result.components) + + payload = { + 
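# Raw request and result are stored verbatim for later audit or replay. + 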
"request": request_model.model_dump(mode="json"), + "result": result.model_dump(), + } + + if scenario and uow.scenario_opex: + scenario_snapshot = ScenarioOpexSnapshot( + scenario_id=scenario.id, + created_by_id=created_by_id, + calculation_source="calculations.opex", + currency_code=result.currency, + overall_annual=overall_annual, + escalated_total=escalated_total, + annual_average=annual_average, + evaluation_horizon_years=evaluation_horizon, + escalation_pct=escalation_pct, + apply_escalation=apply_escalation, + component_count=component_count, + payload=payload, + ) + uow.scenario_opex.create(scenario_snapshot) + + if project and uow.project_opex: + project_snapshot = ProjectOpexSnapshot( + project_id=project.id, + created_by_id=created_by_id, + calculation_source="calculations.opex", + currency_code=result.currency, + overall_annual=overall_annual, + escalated_total=escalated_total, + annual_average=annual_average, + evaluation_horizon_years=evaluation_horizon, + escalation_pct=escalation_pct, + apply_escalation=apply_escalation, + component_count=component_count, + payload=payload, + ) + uow.project_opex.create(project_snapshot) + + +@router.get( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/opex", + response_class=HTMLResponse, + name="calculations.scenario_opex_form", +) +def opex_form( + request: Request, + project_id: int, + scenario_id: int, + _: User = Depends(require_authenticated_user_html), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> HTMLResponse: + """Render the opex planner with default context.""" + + project, scenario = _require_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + context = _prepare_opex_context( + request, + project=project, + scenario=scenario, + ) + return templates.TemplateResponse(request, _opex_TEMPLATE, context) + + +@router.post( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/opex", + name="calculations.scenario_opex_submit", +) +async def opex_submit( + request: Request, + project_id: int, + scenario_id: int, + current_user: User = Depends(require_authenticated_user), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> Response: + """Handle opex submissions and respond with HTML or JSON.""" + + wants_json = _is_json_request(request) + payload_data = await _extract_opex_payload(request) + + project, scenario = _require_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + + try: + request_model = OpexCalculationRequest.model_validate( + payload_data + ) + result = calculate_opex(request_model) + except ValidationError as exc: + if wants_json: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + content={"errors": exc.errors()}, + ) + + general_errors, component_errors = _partition_opex_error_messages( + exc.errors() + ) + context = _prepare_opex_context( + request, + project=project, + scenario=scenario, + form_data=payload_data, + errors=general_errors, + component_errors=component_errors, + ) + return templates.TemplateResponse( + request, + _opex_TEMPLATE, + context, + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + ) + except OpexValidationError as exc: + if wants_json: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + content={ + "errors": list(exc.field_errors or []), + "message": exc.message, + }, + ) + + errors = list(exc.field_errors or []) or [exc.message] + context = _prepare_opex_context( + request, + project=project, + scenario=scenario, + form_data=payload_data, + 
errors=errors, + ) + return templates.TemplateResponse( + request, + _opex_TEMPLATE, + context, + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + ) + + _persist_opex_snapshots( + uow=uow, + project=project, + scenario=scenario, + user=current_user, + request_model=request_model, + result=result, + ) + + if wants_json: + return JSONResponse( + status_code=status.HTTP_200_OK, + content=result.model_dump(), + ) + + context = _prepare_opex_context( + request, + project=project, + scenario=scenario, + form_data=request_model.model_dump(mode="json"), + result=result, + ) + notices = _list_from_context(context, "notices") + notices.append("Opex calculation completed successfully.") + + return templates.TemplateResponse( + request, + _opex_TEMPLATE, + context, + status_code=status.HTTP_200_OK, + ) + + +@router.get( + "/opex", + response_class=HTMLResponse, + name="calculations.opex_form_legacy", +) +def opex_form_legacy( + request: Request, + _: User = Depends(require_authenticated_user_html), + uow: UnitOfWork = Depends(get_unit_of_work), + project_id: str | None = Query( + None, description="Optional project identifier"), + scenario_id: str | None = Query( + None, description="Optional scenario identifier"), +) -> Response: + normalised_project_id, normalised_scenario_id, errors = _normalise_legacy_context_params( + project_id=project_id, + scenario_id=scenario_id, + ) + + if errors: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="; ".join(errors), + ) + + if normalised_scenario_id is not None: + project, scenario = _load_project_and_scenario( + uow=uow, + project_id=normalised_project_id, + scenario_id=normalised_scenario_id, + ) + if scenario is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Scenario not found", + ) + owning_project = project or scenario.project + if owning_project is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found", + ) + redirect_url = request.url_for( + "calculations.opex_form", + project_id=owning_project.id, + scenario_id=scenario.id, + ) + return RedirectResponse( + redirect_url, + status_code=status.HTTP_308_PERMANENT_REDIRECT, + ) + + if normalised_project_id is not None: + target_url = request.url_for( + "scenarios.project_scenario_list", project_id=normalised_project_id + ) + return RedirectResponse( + target_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + return RedirectResponse( + request.url_for("projects.project_list_page"), + status_code=status.HTTP_303_SEE_OTHER, + ) + + +@router.post( + "/opex", + name="calculations.opex_submit_legacy", +) +async def opex_submit_legacy( + request: Request, + _: User = Depends(require_authenticated_user), + uow: UnitOfWork = Depends(get_unit_of_work), + project_id: str | None = Query( + None, description="Optional project identifier"), + scenario_id: str | None = Query( + None, description="Optional scenario identifier"), +) -> Response: + normalised_project_id, normalised_scenario_id, errors = _normalise_legacy_context_params( + project_id=project_id, + scenario_id=scenario_id, + ) + + if errors: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="; ".join(errors), + ) + + if normalised_scenario_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="scenario_id query parameter required; use the scenario-scoped calculations route.", + ) + + project, scenario = _load_project_and_scenario( + uow=uow, + project_id=normalised_project_id, + 
scenario_id=normalised_scenario_id, + ) + if scenario is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Scenario not found", + ) + owning_project = project or scenario.project + if owning_project is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found", + ) + + redirect_url = request.url_for( + "calculations.opex_submit", + project_id=owning_project.id, + scenario_id=scenario.id, + ) + return RedirectResponse( + redirect_url, + status_code=status.HTTP_308_PERMANENT_REDIRECT, + ) + + +@router.get( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/capex", + response_class=HTMLResponse, + name="calculations.scenario_capex_form", +) +def capex_form( + request: Request, + project_id: int, + scenario_id: int, + _: User = Depends(require_authenticated_user_html), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> HTMLResponse: + """Render the capex planner template with defaults.""" + + project, scenario = _require_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + context = _prepare_capex_context( + request, + project=project, + scenario=scenario, + ) + return templates.TemplateResponse(request, "scenarios/capex.html", context) + + +@router.post( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/capex", + name="calculations.scenario_capex_submit", +) +async def capex_submit( + request: Request, + project_id: int, + scenario_id: int, + current_user: User = Depends(require_authenticated_user), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> Response: + """Process capex submissions and return aggregated results.""" + + wants_json = _is_json_request(request) + payload_data = await _extract_capex_payload(request) + + project, scenario = _require_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + + try: + request_model = CapexCalculationRequest.model_validate(payload_data) + result = calculate_initial_capex(request_model) + except ValidationError as exc: + if wants_json: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + content={"errors": exc.errors()}, + ) + + general_errors, component_errors = _partition_capex_error_messages( + exc.errors() + ) + context = _prepare_capex_context( + request, + project=project, + scenario=scenario, + form_data=payload_data, + errors=general_errors, + component_errors=component_errors, + ) + return templates.TemplateResponse( + request, + "scenarios/capex.html", + context, + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + ) + except CapexValidationError as exc: + if wants_json: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + content={ + "errors": list(exc.field_errors or []), + "message": exc.message, + }, + ) + + errors = list(exc.field_errors or []) or [exc.message] + context = _prepare_capex_context( + request, + project=project, + scenario=scenario, + form_data=payload_data, + errors=errors, + ) + return templates.TemplateResponse( + request, + "scenarios/capex.html", + context, + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + ) + + _persist_capex_snapshots( + uow=uow, + project=project, + scenario=scenario, + user=current_user, + request_model=request_model, + result=result, + ) + + if wants_json: + return JSONResponse( + status_code=status.HTTP_200_OK, + content=result.model_dump(), + ) + + context = _prepare_capex_context( + request, + project=project, + scenario=scenario, + 
form_data=request_model.model_dump(mode="json"), + result=result, + ) + notices = _list_from_context(context, "notices") + notices.append("Capex calculation completed successfully.") + + return templates.TemplateResponse( + request, + "scenarios/capex.html", + context, + status_code=status.HTTP_200_OK, + ) + + +# Route name aliases retained for legacy integrations using the former identifiers. +router.add_api_route( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/opex", + opex_form, + response_class=HTMLResponse, + methods=["GET"], + name="calculations.opex_form", + include_in_schema=False, +) +router.add_api_route( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/opex", + opex_submit, + methods=["POST"], + name="calculations.opex_submit", + include_in_schema=False, +) +router.add_api_route( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/capex", + capex_form, + response_class=HTMLResponse, + methods=["GET"], + name="calculations.capex_form", + include_in_schema=False, +) +router.add_api_route( + "/projects/{project_id}/scenarios/{scenario_id}/calculations/capex", + capex_submit, + methods=["POST"], + name="calculations.capex_submit", + include_in_schema=False, +) + + +@router.get( + "/capex", + response_class=HTMLResponse, + name="calculations.capex_form_legacy", +) +def capex_form_legacy( + request: Request, + _: User = Depends(require_authenticated_user_html), + uow: UnitOfWork = Depends(get_unit_of_work), + project_id: str | None = Query( + None, description="Optional project identifier"), + scenario_id: str | None = Query( + None, description="Optional scenario identifier"), +) -> Response: + normalised_project_id, normalised_scenario_id, errors = _normalise_legacy_context_params( + project_id=project_id, + scenario_id=scenario_id, + ) + + if errors: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="; ".join(errors), + ) + + if normalised_scenario_id is not None: + project, scenario = _load_project_and_scenario( + uow=uow, + project_id=normalised_project_id, + scenario_id=normalised_scenario_id, + ) + if scenario is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Scenario not found", + ) + owning_project = project or scenario.project + if owning_project is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found", + ) + redirect_url = request.url_for( + "calculations.capex_form", + project_id=owning_project.id, + scenario_id=scenario.id, + ) + return RedirectResponse( + redirect_url, + status_code=status.HTTP_308_PERMANENT_REDIRECT, + ) + + if normalised_project_id is not None: + target_url = request.url_for( + "scenarios.project_scenario_list", project_id=normalised_project_id + ) + return RedirectResponse( + target_url, + status_code=status.HTTP_303_SEE_OTHER, + ) + + return RedirectResponse( + request.url_for("projects.project_list_page"), + status_code=status.HTTP_303_SEE_OTHER, + ) + + +@router.post( + "/capex", + name="calculations.capex_submit_legacy", +) +async def capex_submit_legacy( + request: Request, + _: User = Depends(require_authenticated_user), + uow: UnitOfWork = Depends(get_unit_of_work), + project_id: str | None = Query( + None, description="Optional project identifier"), + scenario_id: str | None = Query( + None, description="Optional scenario identifier"), +) -> Response: + normalised_project_id, normalised_scenario_id, errors = _normalise_legacy_context_params( + project_id=project_id, + scenario_id=scenario_id, + ) + + 
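# e.g. a hypothetical POST to /calculations/capex?project_id=abc fails + # this validation and returns HTTP 400 before any repository lookups. + 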
if errors: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="; ".join(errors), + ) + + if normalised_scenario_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="scenario_id query parameter required; use the scenario-scoped calculations route.", + ) + + project, scenario = _load_project_and_scenario( + uow=uow, + project_id=normalised_project_id, + scenario_id=normalised_scenario_id, + ) + if scenario is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Scenario not found", + ) + owning_project = project or scenario.project + if owning_project is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found", + ) + + redirect_url = request.url_for( + "calculations.capex_submit", + project_id=owning_project.id, + scenario_id=scenario.id, + ) + return RedirectResponse( + redirect_url, + status_code=status.HTTP_308_PERMANENT_REDIRECT, + ) + + +def _render_profitability_form( + request: Request, + *, + metadata: PricingMetadata, + uow: UnitOfWork, + project_id: int | None, + scenario_id: int | None, + allow_redirect: bool, +) -> Response: + project, scenario = _load_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + + if allow_redirect and scenario is not None and getattr(scenario, "id", None): + target_project_id = project_id or getattr(scenario, "project_id", None) + if target_project_id is None and getattr(scenario, "project", None) is not None: + target_project_id = getattr(scenario.project, "id", None) + + if target_project_id is not None: + redirect_url = request.url_for( + "calculations.profitability_form", + project_id=target_project_id, + scenario_id=scenario.id, + ) + if redirect_url != str(request.url): + return RedirectResponse( + redirect_url, status_code=status.HTTP_307_TEMPORARY_REDIRECT + ) + + context = _prepare_default_context( + request, + project=project, + scenario=scenario, + metadata=metadata, + ) + + return templates.TemplateResponse( + request, + "scenarios/profitability.html", + context, + ) + + +@router.get( + "/projects/{project_id}/scenarios/{scenario_id}/profitability", + response_class=HTMLResponse, + include_in_schema=False, + name="calculations.profitability_form", +) +def profitability_form_for_scenario( + request: Request, + project_id: int, + scenario_id: int, + _: User = Depends(require_authenticated_user_html), + metadata: PricingMetadata = Depends(get_pricing_metadata), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> Response: + return _render_profitability_form( + request, + metadata=metadata, + uow=uow, + project_id=project_id, + scenario_id=scenario_id, + allow_redirect=False, + ) + + +@router.get( + "/profitability", + response_class=HTMLResponse, +) +def profitability_form( + request: Request, + _: User = Depends(require_authenticated_user_html), + metadata: PricingMetadata = Depends(get_pricing_metadata), + uow: UnitOfWork = Depends(get_unit_of_work), + project_id: int | None = Query( + None, description="Optional project identifier" + ), + scenario_id: int | None = Query( + None, description="Optional scenario identifier" + ), +) -> Response: + """Render the profitability calculation form with default metadata.""" + + return _render_profitability_form( + request, + metadata=metadata, + uow=uow, + project_id=project_id, + scenario_id=scenario_id, + allow_redirect=True, + ) + + +async def _handle_profitability_submission( + request: Request, + *, + current_user: User, + metadata: 
PricingMetadata, + uow: UnitOfWork, + project_id: int | None, + scenario_id: int | None, +) -> Response: + wants_json = _is_json_request(request) + payload_data = await _extract_payload(request) + + try: + request_model = ProfitabilityCalculationRequest.model_validate( + payload_data + ) + result = calculate_profitability(request_model, metadata=metadata) + except ValidationError as exc: + if wants_json: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + content={"errors": exc.errors()}, + ) + + project, scenario = _load_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + context = _prepare_default_context( + request, + project=project, + scenario=scenario, + metadata=metadata, + form_data=payload_data, + allow_empty_override=True, + ) + errors = _list_from_context(context, "errors") + errors.extend( + [f"{err['loc']} - {err['msg']}" for err in exc.errors()] + ) + return templates.TemplateResponse( + request, + "scenarios/profitability.html", + context, + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + ) + except ProfitabilityValidationError as exc: + if wants_json: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + content={ + "errors": exc.field_errors or [], + "message": exc.message, + }, + ) + + project, scenario = _load_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + context = _prepare_default_context( + request, + project=project, + scenario=scenario, + metadata=metadata, + form_data=payload_data, + allow_empty_override=True, + ) + messages = list(exc.field_errors or []) or [exc.message] + errors = _list_from_context(context, "errors") + errors.extend(messages) + return templates.TemplateResponse( + request, + "scenarios/profitability.html", + context, + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + ) + + project, scenario = _load_project_and_scenario( + uow=uow, project_id=project_id, scenario_id=scenario_id + ) + + _persist_profitability_snapshots( + uow=uow, + project=project, + scenario=scenario, + user=current_user, + request_model=request_model, + result=result, + ) + + if wants_json: + return JSONResponse( + status_code=status.HTTP_200_OK, + content=result.model_dump(), + ) + + context = _prepare_default_context( + request, + project=project, + scenario=scenario, + metadata=metadata, + form_data=request_model.model_dump(mode="json"), + result=result, + ) + notices = _list_from_context(context, "notices") + notices.append("Profitability calculation completed successfully.") + + return templates.TemplateResponse( + request, + "scenarios/profitability.html", + context, + status_code=status.HTTP_200_OK, + ) + + +@router.post( + "/projects/{project_id}/scenarios/{scenario_id}/profitability", + include_in_schema=False, + name="calculations.profitability_submit", +) +async def profitability_submit_for_scenario( + request: Request, + project_id: int, + scenario_id: int, + current_user: User = Depends(require_authenticated_user), + metadata: PricingMetadata = Depends(get_pricing_metadata), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> Response: + return await _handle_profitability_submission( + request, + current_user=current_user, + metadata=metadata, + uow=uow, + project_id=project_id, + scenario_id=scenario_id, + ) + + +@router.post( + "/profitability", +) +async def profitability_submit( + request: Request, + current_user: User = Depends(require_authenticated_user), + metadata: PricingMetadata = Depends(get_pricing_metadata), + uow: 
UnitOfWork = Depends(get_unit_of_work), + project_id: int | None = Query( + None, description="Optional project identifier" + ), + scenario_id: int | None = Query( + None, description="Optional scenario identifier" + ), +) -> Response: + """Handle profitability calculations and return HTML or JSON.""" + + return await _handle_profitability_submission( + request, + current_user=current_user, + metadata=metadata, + uow=uow, + project_id=project_id, + scenario_id=scenario_id, + ) diff --git a/routes/consumption.py b/routes/consumption.py deleted file mode 100644 index e03785d..0000000 --- a/routes/consumption.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends, status -from pydantic import BaseModel, ConfigDict, PositiveFloat, field_validator -from sqlalchemy.orm import Session - -from models.consumption import Consumption -from routes.dependencies import get_db - - -router = APIRouter(prefix="/api/consumption", tags=["Consumption"]) - - -class ConsumptionBase(BaseModel): - scenario_id: int - amount: PositiveFloat - description: Optional[str] = None - unit_name: Optional[str] = None - unit_symbol: Optional[str] = None - - @field_validator("unit_name", "unit_symbol") - @classmethod - def _normalize_text(cls, value: Optional[str]) -> Optional[str]: - if value is None: - return None - stripped = value.strip() - return stripped or None - - -class ConsumptionCreate(ConsumptionBase): - pass - - -class ConsumptionRead(ConsumptionBase): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post( - "/", response_model=ConsumptionRead, status_code=status.HTTP_201_CREATED -) -def create_consumption(item: ConsumptionCreate, db: Session = Depends(get_db)): - db_item = Consumption(**item.model_dump()) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/", response_model=List[ConsumptionRead]) -def list_consumption(db: Session = Depends(get_db)): - return db.query(Consumption).all() diff --git a/routes/costs.py b/routes/costs.py deleted file mode 100644 index e22f18a..0000000 --- a/routes/costs.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends -from pydantic import BaseModel, ConfigDict, field_validator -from sqlalchemy.orm import Session - -from models.capex import Capex -from models.opex import Opex -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/costs", tags=["Costs"]) -# Pydantic schemas for CAPEX and OPEX - - -class _CostBase(BaseModel): - scenario_id: int - amount: float - description: Optional[str] = None - currency_code: Optional[str] = "USD" - currency_id: Optional[int] = None - - @field_validator("currency_code") - @classmethod - def _normalize_currency(cls, value: Optional[str]) -> str: - code = (value or "USD").strip().upper() - return code[:3] if len(code) > 3 else code - - -class CapexCreate(_CostBase): - pass - - -class CapexRead(_CostBase): - id: int - # use from_attributes so Pydantic reads attributes off SQLAlchemy model - model_config = ConfigDict(from_attributes=True) - - # optionally include nested currency info - currency: Optional["CurrencyRead"] = None - - -class OpexCreate(_CostBase): - pass - - -class OpexRead(_CostBase): - id: int - model_config = ConfigDict(from_attributes=True) - currency: Optional["CurrencyRead"] = None - - -class CurrencyRead(BaseModel): - id: int - code: str - name: Optional[str] = None - symbol: Optional[str] = None - is_active: Optional[bool] = True - - 
model_config = ConfigDict(from_attributes=True) - - -# forward refs -CapexRead.model_rebuild() -OpexRead.model_rebuild() - - -# Capex endpoints -@router.post("/capex", response_model=CapexRead) -def create_capex(item: CapexCreate, db: Session = Depends(get_db)): - payload = item.model_dump() - # Prefer explicit currency_id if supplied - cid = payload.get("currency_id") - if not cid: - code = (payload.pop("currency_code", "USD") or "USD").strip().upper() - currency_cls = __import__( - "models.currency", fromlist=["Currency"] - ).Currency - currency = db.query(currency_cls).filter_by(code=code).one_or_none() - if currency is None: - currency = currency_cls(code=code, name=code, symbol=None) - db.add(currency) - db.flush() - payload["currency_id"] = currency.id - db_item = Capex(**payload) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/capex", response_model=List[CapexRead]) -def list_capex(db: Session = Depends(get_db)): - return db.query(Capex).all() - - -# Opex endpoints -@router.post("/opex", response_model=OpexRead) -def create_opex(item: OpexCreate, db: Session = Depends(get_db)): - payload = item.model_dump() - cid = payload.get("currency_id") - if not cid: - code = (payload.pop("currency_code", "USD") or "USD").strip().upper() - currency_cls = __import__( - "models.currency", fromlist=["Currency"] - ).Currency - currency = db.query(currency_cls).filter_by(code=code).one_or_none() - if currency is None: - currency = currency_cls(code=code, name=code, symbol=None) - db.add(currency) - db.flush() - payload["currency_id"] = currency.id - db_item = Opex(**payload) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/opex", response_model=List[OpexRead]) -def list_opex(db: Session = Depends(get_db)): - return db.query(Opex).all() diff --git a/routes/currencies.py b/routes/currencies.py deleted file mode 100644 index 8899366..0000000 --- a/routes/currencies.py +++ /dev/null @@ -1,193 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends, HTTPException, Query, status -from pydantic import BaseModel, ConfigDict, Field, field_validator -from sqlalchemy.orm import Session -from sqlalchemy.exc import IntegrityError - -from models.currency import Currency -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/currencies", tags=["Currencies"]) - - -DEFAULT_CURRENCY_CODE = "USD" -DEFAULT_CURRENCY_NAME = "US Dollar" -DEFAULT_CURRENCY_SYMBOL = "$" - - -class CurrencyBase(BaseModel): - name: str = Field(..., min_length=1, max_length=128) - symbol: Optional[str] = Field(default=None, max_length=8) - - @staticmethod - def _normalize_symbol(value: Optional[str]) -> Optional[str]: - if value is None: - return None - value = value.strip() - return value or None - - @field_validator("name") - @classmethod - def _strip_name(cls, value: str) -> str: - return value.strip() - - @field_validator("symbol") - @classmethod - def _strip_symbol(cls, value: Optional[str]) -> Optional[str]: - return cls._normalize_symbol(value) - - -class CurrencyCreate(CurrencyBase): - code: str = Field(..., min_length=3, max_length=3) - is_active: bool = True - - @field_validator("code") - @classmethod - def _normalize_code(cls, value: str) -> str: - return value.strip().upper() - - -class CurrencyUpdate(CurrencyBase): - is_active: Optional[bool] = None - - -class CurrencyActivation(BaseModel): - is_active: bool - - -class CurrencyRead(CurrencyBase): - id: int - code: str - is_active: bool - - model_config 
= ConfigDict(from_attributes=True) - - -def _ensure_default_currency(db: Session) -> Currency: - existing = ( - db.query(Currency) - .filter(Currency.code == DEFAULT_CURRENCY_CODE) - .one_or_none() - ) - if existing: - return existing - - default_currency = Currency( - code=DEFAULT_CURRENCY_CODE, - name=DEFAULT_CURRENCY_NAME, - symbol=DEFAULT_CURRENCY_SYMBOL, - is_active=True, - ) - db.add(default_currency) - try: - db.commit() - except IntegrityError: - db.rollback() - existing = ( - db.query(Currency) - .filter(Currency.code == DEFAULT_CURRENCY_CODE) - .one() - ) - return existing - db.refresh(default_currency) - return default_currency - - -def _get_currency_or_404(db: Session, code: str) -> Currency: - normalized = code.strip().upper() - currency = ( - db.query(Currency).filter(Currency.code == normalized).one_or_none() - ) - if currency is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Currency not found" - ) - return currency - - -@router.get("/", response_model=List[CurrencyRead]) -def list_currencies( - include_inactive: bool = Query( - False, description="Include inactive currencies" - ), - db: Session = Depends(get_db), -): - _ensure_default_currency(db) - query = db.query(Currency) - if not include_inactive: - query = query.filter(Currency.is_active.is_(True)) - currencies = query.order_by(Currency.code).all() - return currencies - - -@router.post( - "/", response_model=CurrencyRead, status_code=status.HTTP_201_CREATED -) -def create_currency(payload: CurrencyCreate, db: Session = Depends(get_db)): - code = payload.code - existing = db.query(Currency).filter(Currency.code == code).one_or_none() - if existing is not None: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=f"Currency '{code}' already exists", - ) - - currency = Currency( - code=code, - name=payload.name, - symbol=CurrencyBase._normalize_symbol(payload.symbol), - is_active=payload.is_active, - ) - db.add(currency) - db.commit() - db.refresh(currency) - return currency - - -@router.put("/{code}", response_model=CurrencyRead) -def update_currency( - code: str, payload: CurrencyUpdate, db: Session = Depends(get_db) -): - currency = _get_currency_or_404(db, code) - - if payload.name is not None: - setattr(currency, "name", payload.name) - if payload.symbol is not None or payload.symbol == "": - setattr( - currency, - "symbol", - CurrencyBase._normalize_symbol(payload.symbol), - ) - if payload.is_active is not None: - code_value = getattr(currency, "code") - if code_value == DEFAULT_CURRENCY_CODE and payload.is_active is False: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The default currency cannot be deactivated.", - ) - setattr(currency, "is_active", payload.is_active) - - db.add(currency) - db.commit() - db.refresh(currency) - return currency - - -@router.patch("/{code}/activation", response_model=CurrencyRead) -def toggle_currency_activation( - code: str, body: CurrencyActivation, db: Session = Depends(get_db) -): - currency = _get_currency_or_404(db, code) - code_value = getattr(currency, "code") - if code_value == DEFAULT_CURRENCY_CODE and body.is_active is False: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The default currency cannot be deactivated.", - ) - - setattr(currency, "is_active", body.is_active) - db.add(currency) - db.commit() - db.refresh(currency) - return currency diff --git a/routes/dashboard.py b/routes/dashboard.py new file mode 100644 index 0000000..ed6bcf7 --- /dev/null +++ 
b/routes/dashboard.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from datetime import datetime + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import HTMLResponse, RedirectResponse +from routes.template_filters import create_templates + +from dependencies import get_current_user, get_unit_of_work +from models import ScenarioStatus, User +from services.unit_of_work import UnitOfWork + +router = APIRouter(tags=["Dashboard"]) +templates = create_templates() + + +def _format_timestamp(moment: datetime | None) -> str | None: + if moment is None: + return None + return moment.strftime("%Y-%m-%d") + + +def _format_timestamp_with_time(moment: datetime | None) -> str | None: + if moment is None: + return None + return moment.strftime("%Y-%m-%d %H:%M") + + +def _load_metrics(uow: UnitOfWork) -> dict[str, object]: + if not uow.projects or not uow.scenarios or not uow.financial_inputs: + raise RuntimeError("UnitOfWork repositories not initialised") + total_projects = uow.projects.count() + active_scenarios = uow.scenarios.count_by_status(ScenarioStatus.ACTIVE) + pending_simulations = uow.scenarios.count_by_status(ScenarioStatus.DRAFT) + last_import_at = uow.financial_inputs.latest_created_at() + return { + "total_projects": total_projects, + "active_scenarios": active_scenarios, + "pending_simulations": pending_simulations, + "last_import": _format_timestamp(last_import_at), + } + + +def _load_recent_projects(uow: UnitOfWork) -> list: + if not uow.projects: + raise RuntimeError("Project repository not initialised") + return list(uow.projects.recent(limit=5)) + + +def _load_simulation_updates(uow: UnitOfWork) -> list[dict[str, object]]: + updates: list[dict[str, object]] = [] + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") + scenarios = uow.scenarios.recent(limit=5, with_project=True) + for scenario in scenarios: + project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}" + timestamp_label = _format_timestamp_with_time(scenario.updated_at) + updates.append( + { + "title": f"{scenario.name} · {scenario.status.value.title()}", + "description": f"Latest update recorded for {project_name}.", + "timestamp": scenario.updated_at, + "timestamp_label": timestamp_label, + } + ) + return updates + + +def _load_scenario_alerts( + request: Request, uow: UnitOfWork +) -> list[dict[str, object]]: + alerts: list[dict[str, object]] = [] + + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") + + drafts = uow.scenarios.list_by_status( + ScenarioStatus.DRAFT, limit=3, with_project=True + ) + for scenario in drafts: + project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}" + alerts.append( + { + "title": f"Draft scenario: {scenario.name}", + "message": f"{project_name} has a scenario awaiting validation.", + "link": request.url_for( + "projects.view_project", project_id=scenario.project_id + ), + } + ) + + if not alerts: + archived = uow.scenarios.list_by_status( + ScenarioStatus.ARCHIVED, limit=3, with_project=True + ) + for scenario in archived: + project_name = scenario.project.name if scenario.project else f"Project #{scenario.project_id}" + alerts.append( + { + "title": f"Archived scenario: {scenario.name}", + "message": f"Review archived scenario insights for {project_name}.", + "link": request.url_for( + "scenarios.view_scenario", scenario_id=scenario.id + ), + } + ) + + return alerts + + +@router.get("/", include_in_schema=False, 
name="dashboard.home", response_model=None) +def dashboard_home( + request: Request, + user: User | None = Depends(get_current_user), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> HTMLResponse | RedirectResponse: + if user is None: + return RedirectResponse(request.url_for("auth.login_form"), status_code=303) + + context = { + "metrics": _load_metrics(uow), + "recent_projects": _load_recent_projects(uow), + "simulation_updates": _load_simulation_updates(uow), + "scenario_alerts": _load_scenario_alerts(request, uow), + "export_modals": { + "projects": request.url_for("exports.modal", dataset="projects"), + "scenarios": request.url_for("exports.modal", dataset="scenarios"), + }, + } + return templates.TemplateResponse(request, "dashboard.html", context) diff --git a/routes/dependencies.py b/routes/dependencies.py deleted file mode 100644 index 0afc871..0000000 --- a/routes/dependencies.py +++ /dev/null @@ -1,13 +0,0 @@ -from collections.abc import Generator - -from sqlalchemy.orm import Session - -from config.database import SessionLocal - - -def get_db() -> Generator[Session, None, None]: - db = SessionLocal() - try: - yield db - finally: - db.close() diff --git a/routes/distributions.py b/routes/distributions.py deleted file mode 100644 index 34a0cc8..0000000 --- a/routes/distributions.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Dict, List - -from fastapi import APIRouter, Depends -from pydantic import BaseModel, ConfigDict -from sqlalchemy.orm import Session - -from models.distribution import Distribution -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/distributions", tags=["Distributions"]) - - -class DistributionCreate(BaseModel): - name: str - distribution_type: str - parameters: Dict[str, float | int] - - -class DistributionRead(DistributionCreate): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post("/", response_model=DistributionRead) -async def create_distribution( - dist: DistributionCreate, db: Session = Depends(get_db) -): - db_dist = Distribution(**dist.model_dump()) - db.add(db_dist) - db.commit() - db.refresh(db_dist) - return db_dist - - -@router.get("/", response_model=List[DistributionRead]) -async def list_distributions(db: Session = Depends(get_db)): - dists = db.query(Distribution).all() - return dists diff --git a/routes/equipment.py b/routes/equipment.py deleted file mode 100644 index a5800a9..0000000 --- a/routes/equipment.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends -from pydantic import BaseModel, ConfigDict -from sqlalchemy.orm import Session - -from models.equipment import Equipment -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/equipment", tags=["Equipment"]) -# Pydantic schemas - - -class EquipmentCreate(BaseModel): - scenario_id: int - name: str - description: Optional[str] = None - - -class EquipmentRead(EquipmentCreate): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post("/", response_model=EquipmentRead) -async def create_equipment( - item: EquipmentCreate, db: Session = Depends(get_db) -): - db_item = Equipment(**item.model_dump()) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/", response_model=List[EquipmentRead]) -async def list_equipment(db: Session = Depends(get_db)): - return db.query(Equipment).all() diff --git a/routes/exports.py b/routes/exports.py new file mode 100644 index 0000000..f6069d6 --- /dev/null +++ 
b/routes/exports.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import logging +import time +from datetime import datetime, timezone +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi.responses import HTMLResponse, StreamingResponse + +from dependencies import get_unit_of_work, require_any_role +from schemas.exports import ( + ExportFormat, + ProjectExportRequest, + ScenarioExportRequest, +) +from services.export_serializers import ( + export_projects_to_excel, + export_scenarios_to_excel, + stream_projects_to_csv, + stream_scenarios_to_csv, +) +from services.unit_of_work import UnitOfWork +from models.import_export_log import ImportExportLog +from monitoring.metrics import observe_export +from routes.template_filters import create_templates + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/exports", tags=["exports"]) +templates = create_templates() + + +@router.get( + "/modal/{dataset}", + response_model=None, + response_class=HTMLResponse, + include_in_schema=False, + name="exports.modal", +) +async def export_modal( + dataset: str, + request: Request, +) -> HTMLResponse: + dataset = dataset.lower() + if dataset not in {"projects", "scenarios"}: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Unknown dataset") + + submit_url = request.url_for( + "export_projects" if dataset == "projects" else "export_scenarios" + ) + return templates.TemplateResponse( + request, + "exports/modal.html", + { + "dataset": dataset, + "submit_url": submit_url, + }, + ) + + +def _timestamp_suffix() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") + + +def _ensure_repository(repo, name: str): + if repo is None: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"{name} repository unavailable") + return repo + + +def _record_export_audit( + *, + uow: UnitOfWork, + dataset: str, + status: str, + export_format: ExportFormat, + row_count: int, + filename: str | None, +) -> None: + try: + if uow.session is None: + return + log = ImportExportLog( + action="export", + dataset=dataset, + status=status, + filename=filename, + row_count=row_count, + detail=f"format={export_format.value}", + ) + uow.session.add(log) + uow.commit() + except Exception: + # best-effort auditing, do not break exports + if uow.session is not None: + uow.session.rollback() + logger.exception( + "export.audit.failed", + extra={ + "event": "export.audit", + "dataset": dataset, + "status": status, + "format": export_format.value, + }, + ) + + +@router.post( + "/projects", + status_code=status.HTTP_200_OK, + response_class=StreamingResponse, + dependencies=[Depends(require_any_role( + "admin", "project_manager", "analyst"))], +) +async def export_projects( + request: ProjectExportRequest, + uow: Annotated[UnitOfWork, Depends(get_unit_of_work)], +) -> Response: + project_repo = _ensure_repository( + getattr(uow, "projects", None), "Project") + start = time.perf_counter() + try: + projects = project_repo.filtered_for_export(request.filters) + except ValueError as exc: + _record_export_audit( + uow=uow, + dataset="projects", + status="failure", + export_format=request.format, + row_count=0, + filename=None, + ) + logger.warning( + "export.validation_failed", + extra={ + "event": "export", + "dataset": "projects", + "status": "validation_failed", + "format": request.format.value, + "error": str(exc), + }, + ) + raise HTTPException( + 
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + except Exception as exc: + _record_export_audit( + uow=uow, + dataset="projects", + status="failure", + export_format=request.format, + row_count=0, + filename=None, + ) + logger.exception( + "export.failed", + extra={ + "event": "export", + "dataset": "projects", + "status": "failure", + "format": request.format.value, + }, + ) + raise exc + + filename = f"projects-{_timestamp_suffix()}" + + if request.format == ExportFormat.CSV: + stream = stream_projects_to_csv(projects) + response = StreamingResponse(stream, media_type="text/csv") + response.headers["Content-Disposition"] = f"attachment; filename={filename}.csv" + _record_export_audit( + uow=uow, + dataset="projects", + status="success", + export_format=request.format, + row_count=len(projects), + filename=f"{filename}.csv", + ) + logger.info( + "export", + extra={ + "event": "export", + "dataset": "projects", + "status": "success", + "format": request.format.value, + "row_count": len(projects), + "filename": f"{filename}.csv", + }, + ) + observe_export( + dataset="projects", + status="success", + export_format=request.format.value, + seconds=time.perf_counter() - start, + ) + return response + + data = export_projects_to_excel(projects) + _record_export_audit( + uow=uow, + dataset="projects", + status="success", + export_format=request.format, + row_count=len(projects), + filename=f"{filename}.xlsx", + ) + logger.info( + "export", + extra={ + "event": "export", + "dataset": "projects", + "status": "success", + "format": request.format.value, + "row_count": len(projects), + "filename": f"{filename}.xlsx", + }, + ) + observe_export( + dataset="projects", + status="success", + export_format=request.format.value, + seconds=time.perf_counter() - start, + ) + return StreamingResponse( + iter([data]), + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={ + "Content-Disposition": f"attachment; filename={filename}.xlsx", + }, + ) + + +@router.post( + "/scenarios", + status_code=status.HTTP_200_OK, + response_class=StreamingResponse, + dependencies=[Depends(require_any_role( + "admin", "project_manager", "analyst"))], +) +async def export_scenarios( + request: ScenarioExportRequest, + uow: Annotated[UnitOfWork, Depends(get_unit_of_work)], +) -> Response: + scenario_repo = _ensure_repository( + getattr(uow, "scenarios", None), "Scenario") + start = time.perf_counter() + try: + scenarios = scenario_repo.filtered_for_export( + request.filters, include_project=True) + except ValueError as exc: + _record_export_audit( + uow=uow, + dataset="scenarios", + status="failure", + export_format=request.format, + row_count=0, + filename=None, + ) + logger.warning( + "export.validation_failed", + extra={ + "event": "export", + "dataset": "scenarios", + "status": "validation_failed", + "format": request.format.value, + "error": str(exc), + }, + ) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + except Exception as exc: + _record_export_audit( + uow=uow, + dataset="scenarios", + status="failure", + export_format=request.format, + row_count=0, + filename=None, + ) + logger.exception( + "export.failed", + extra={ + "event": "export", + "dataset": "scenarios", + "status": "failure", + "format": request.format.value, + }, + ) + raise exc + + filename = f"scenarios-{_timestamp_suffix()}" + + if request.format == ExportFormat.CSV: + stream = stream_scenarios_to_csv(scenarios) + 
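# Hand the CSV generator to StreamingResponse so rows are written to the client + # incrementally rather than buffered in memory before the download begins. +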
response = StreamingResponse(stream, media_type="text/csv") + response.headers["Content-Disposition"] = f"attachment; filename={filename}.csv" + _record_export_audit( + uow=uow, + dataset="scenarios", + status="success", + export_format=request.format, + row_count=len(scenarios), + filename=f"{filename}.csv", + ) + logger.info( + "export", + extra={ + "event": "export", + "dataset": "scenarios", + "status": "success", + "format": request.format.value, + "row_count": len(scenarios), + "filename": f"{filename}.csv", + }, + ) + observe_export( + dataset="scenarios", + status="success", + export_format=request.format.value, + seconds=time.perf_counter() - start, + ) + return response + + data = export_scenarios_to_excel(scenarios) + _record_export_audit( + uow=uow, + dataset="scenarios", + status="success", + export_format=request.format, + row_count=len(scenarios), + filename=f"{filename}.xlsx", + ) + logger.info( + "export", + extra={ + "event": "export", + "dataset": "scenarios", + "status": "success", + "format": request.format.value, + "row_count": len(scenarios), + "filename": f"{filename}.xlsx", + }, + ) + observe_export( + dataset="scenarios", + status="success", + export_format=request.format.value, + seconds=time.perf_counter() - start, + ) + return StreamingResponse( + iter([data]), + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={ + "Content-Disposition": f"attachment; filename={filename}.xlsx", + }, + ) diff --git a/routes/imports.py b/routes/imports.py new file mode 100644 index 0000000..f73fdbb --- /dev/null +++ b/routes/imports.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from io import BytesIO + +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import Request +from fastapi.responses import HTMLResponse + +from dependencies import ( + get_import_ingestion_service, + require_roles, + require_roles_html, +) +from models import User +from schemas.imports import ( + ImportCommitRequest, + ProjectImportCommitResponse, + ProjectImportPreviewResponse, + ScenarioImportCommitResponse, + ScenarioImportPreviewResponse, +) +from services.importers import ImportIngestionService, UnsupportedImportFormat +from routes.template_filters import create_templates + +router = APIRouter(prefix="/imports", tags=["Imports"]) +templates = create_templates() + +MANAGE_ROLES = ("project_manager", "admin") + + +@router.get( + "/ui", + response_class=HTMLResponse, + include_in_schema=False, + name="imports.ui", +) +def import_dashboard( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "imports/ui.html", + { + "title": "Imports", + }, + ) + + +async def _read_upload_file(upload: UploadFile) -> BytesIO: + content = await upload.read() + if not content: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Uploaded file is empty.", + ) + return BytesIO(content) + + +@router.post( + "/projects/preview", + response_model=ProjectImportPreviewResponse, + status_code=status.HTTP_200_OK, +) +async def preview_project_import( + file: UploadFile = File(..., + description="Project import file (CSV or Excel)"), + _: User = Depends(require_roles(*MANAGE_ROLES)), + ingestion_service: ImportIngestionService = Depends( + get_import_ingestion_service), +) -> ProjectImportPreviewResponse: + if not file.filename: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Filename is required 
for import.", + ) + + stream = await _read_upload_file(file) + + try: + preview = ingestion_service.preview_projects(stream, file.filename) + except UnsupportedImportFormat as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(exc), + ) from exc + + return ProjectImportPreviewResponse.model_validate(preview) + + +@router.post( + "/scenarios/preview", + response_model=ScenarioImportPreviewResponse, + status_code=status.HTTP_200_OK, +) +async def preview_scenario_import( + file: UploadFile = File(..., + description="Scenario import file (CSV or Excel)"), + _: User = Depends(require_roles(*MANAGE_ROLES)), + ingestion_service: ImportIngestionService = Depends( + get_import_ingestion_service), +) -> ScenarioImportPreviewResponse: + if not file.filename: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Filename is required for import.", + ) + + stream = await _read_upload_file(file) + + try: + preview = ingestion_service.preview_scenarios(stream, file.filename) + except UnsupportedImportFormat as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(exc), + ) from exc + + return ScenarioImportPreviewResponse.model_validate(preview) + + +def _value_error_status(exc: ValueError) -> int: + detail = str(exc) + if detail.lower().startswith("unknown"): + return status.HTTP_404_NOT_FOUND + return status.HTTP_400_BAD_REQUEST + + +@router.post( + "/projects/commit", + response_model=ProjectImportCommitResponse, + status_code=status.HTTP_200_OK, +) +async def commit_project_import_endpoint( + payload: ImportCommitRequest, + _: User = Depends(require_roles(*MANAGE_ROLES)), + ingestion_service: ImportIngestionService = Depends( + get_import_ingestion_service), +) -> ProjectImportCommitResponse: + try: + result = ingestion_service.commit_project_import(payload.token) + except ValueError as exc: + raise HTTPException( + status_code=_value_error_status(exc), + detail=str(exc), + ) from exc + + return ProjectImportCommitResponse.model_validate(result) + + +@router.post( + "/scenarios/commit", + response_model=ScenarioImportCommitResponse, + status_code=status.HTTP_200_OK, +) +async def commit_scenario_import_endpoint( + payload: ImportCommitRequest, + _: User = Depends(require_roles(*MANAGE_ROLES)), + ingestion_service: ImportIngestionService = Depends( + get_import_ingestion_service), +) -> ScenarioImportCommitResponse: + try: + result = ingestion_service.commit_scenario_import(payload.token) + except ValueError as exc: + raise HTTPException( + status_code=_value_error_status(exc), + detail=str(exc), + ) from exc + + return ScenarioImportCommitResponse.model_validate(result) diff --git a/routes/maintenance.py b/routes/maintenance.py deleted file mode 100644 index 93683fd..0000000 --- a/routes/maintenance.py +++ /dev/null @@ -1,91 +0,0 @@ -from datetime import date -from typing import List, Optional - -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, ConfigDict, PositiveFloat -from sqlalchemy.orm import Session - -from models.maintenance import Maintenance -from routes.dependencies import get_db - - -router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"]) - - -class MaintenanceBase(BaseModel): - equipment_id: int - scenario_id: int - maintenance_date: date - description: Optional[str] = None - cost: PositiveFloat - - -class MaintenanceCreate(MaintenanceBase): - pass - - -class MaintenanceUpdate(MaintenanceBase): - pass - - -class MaintenanceRead(MaintenanceBase): - 
id: int - model_config = ConfigDict(from_attributes=True) - - -def _get_maintenance_or_404(db: Session, maintenance_id: int) -> Maintenance: - maintenance = ( - db.query(Maintenance).filter(Maintenance.id == maintenance_id).first() - ) - if maintenance is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Maintenance record {maintenance_id} not found", - ) - return maintenance - - -@router.post( - "/", response_model=MaintenanceRead, status_code=status.HTTP_201_CREATED -) -def create_maintenance( - maintenance: MaintenanceCreate, db: Session = Depends(get_db) -): - db_maintenance = Maintenance(**maintenance.model_dump()) - db.add(db_maintenance) - db.commit() - db.refresh(db_maintenance) - return db_maintenance - - -@router.get("/", response_model=List[MaintenanceRead]) -def list_maintenance( - skip: int = 0, limit: int = 100, db: Session = Depends(get_db) -): - return db.query(Maintenance).offset(skip).limit(limit).all() - - -@router.get("/{maintenance_id}", response_model=MaintenanceRead) -def get_maintenance(maintenance_id: int, db: Session = Depends(get_db)): - return _get_maintenance_or_404(db, maintenance_id) - - -@router.put("/{maintenance_id}", response_model=MaintenanceRead) -def update_maintenance( - maintenance_id: int, - payload: MaintenanceUpdate, - db: Session = Depends(get_db), -): - db_maintenance = _get_maintenance_or_404(db, maintenance_id) - for field, value in payload.model_dump().items(): - setattr(db_maintenance, field, value) - db.commit() - db.refresh(db_maintenance) - return db_maintenance - - -@router.delete("/{maintenance_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_maintenance(maintenance_id: int, db: Session = Depends(get_db)): - db_maintenance = _get_maintenance_or_404(db, maintenance_id) - db.delete(db_maintenance) - db.commit() diff --git a/routes/navigation.py b/routes/navigation.py new file mode 100644 index 0000000..d4fd5ef --- /dev/null +++ b/routes/navigation.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, Request + +from dependencies import ( + get_auth_session, + get_navigation_service, + require_authenticated_user, +) +from models import User +from schemas.navigation import ( + NavigationGroupSchema, + NavigationLinkSchema, + NavigationSidebarResponse, +) +from services.navigation import NavigationGroupDTO, NavigationLinkDTO, NavigationService +from services.session import AuthSession + +router = APIRouter(prefix="/navigation", tags=["Navigation"]) + + +def _to_link_schema(dto: NavigationLinkDTO) -> NavigationLinkSchema: + return NavigationLinkSchema( + id=dto.id, + label=dto.label, + href=dto.href, + match_prefix=dto.match_prefix, + icon=dto.icon, + tooltip=dto.tooltip, + is_external=dto.is_external, + children=[_to_link_schema(child) for child in dto.children], + ) + + +def _to_group_schema(dto: NavigationGroupDTO) -> NavigationGroupSchema: + return NavigationGroupSchema( + id=dto.id, + label=dto.label, + icon=dto.icon, + tooltip=dto.tooltip, + links=[_to_link_schema(link) for link in dto.links], + ) + + +@router.get( + "/sidebar", + response_model=NavigationSidebarResponse, + name="navigation.sidebar", +) +async def get_sidebar_navigation( + request: Request, + _: User = Depends(require_authenticated_user), + session: AuthSession = Depends(get_auth_session), + service: NavigationService = Depends(get_navigation_service), +) -> NavigationSidebarResponse: + dto = service.build_sidebar(session=session, 
request=request) + return NavigationSidebarResponse( + groups=[_to_group_schema(group) for group in dto.groups], + roles=list(dto.roles), + generated_at=datetime.now(tz=timezone.utc), + ) diff --git a/routes/parameters.py b/routes/parameters.py deleted file mode 100644 index 59f09c8..0000000 --- a/routes/parameters.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Any, Dict, List, Optional - -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, ConfigDict, field_validator -from sqlalchemy.orm import Session - -from models.distribution import Distribution -from models.parameters import Parameter -from models.scenario import Scenario -from routes.dependencies import get_db - -router = APIRouter(prefix="/api/parameters", tags=["parameters"]) - - -class ParameterCreate(BaseModel): - scenario_id: int - name: str - value: float - distribution_id: Optional[int] = None - distribution_type: Optional[str] = None - distribution_parameters: Optional[Dict[str, Any]] = None - - @field_validator("distribution_type") - @classmethod - def normalize_type(cls, value: Optional[str]) -> Optional[str]: - if value is None: - return value - normalized = value.strip().lower() - if not normalized: - return None - if normalized not in {"normal", "uniform", "triangular"}: - raise ValueError( - "distribution_type must be normal, uniform, or triangular" - ) - return normalized - - @field_validator("distribution_parameters") - @classmethod - def empty_dict_to_none( - cls, value: Optional[Dict[str, Any]] - ) -> Optional[Dict[str, Any]]: - if value is None: - return None - return value or None - - -class ParameterRead(ParameterCreate): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post("/", response_model=ParameterRead) -def create_parameter(param: ParameterCreate, db: Session = Depends(get_db)): - scen = db.query(Scenario).filter(Scenario.id == param.scenario_id).first() - if not scen: - raise HTTPException(status_code=404, detail="Scenario not found") - distribution_id = param.distribution_id - distribution_type = param.distribution_type - distribution_parameters = param.distribution_parameters - - if distribution_id is not None: - distribution = ( - db.query(Distribution) - .filter(Distribution.id == distribution_id) - .first() - ) - if not distribution: - raise HTTPException( - status_code=404, detail="Distribution not found" - ) - distribution_type = distribution.distribution_type - distribution_parameters = distribution.parameters or None - - new_param = Parameter( - scenario_id=param.scenario_id, - name=param.name, - value=param.value, - distribution_id=distribution_id, - distribution_type=distribution_type, - distribution_parameters=distribution_parameters, - ) - db.add(new_param) - db.commit() - db.refresh(new_param) - return new_param - - -@router.get("/", response_model=List[ParameterRead]) -def list_parameters(db: Session = Depends(get_db)): - return db.query(Parameter).all() diff --git a/routes/production.py b/routes/production.py deleted file mode 100644 index ad4a059..0000000 --- a/routes/production.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import List, Optional - -from fastapi import APIRouter, Depends, status -from pydantic import BaseModel, ConfigDict, PositiveFloat, field_validator -from sqlalchemy.orm import Session - -from models.production_output import ProductionOutput -from routes.dependencies import get_db - - -router = APIRouter(prefix="/api/production", tags=["Production"]) - - -class ProductionOutputBase(BaseModel): - 
scenario_id: int - amount: PositiveFloat - description: Optional[str] = None - unit_name: Optional[str] = None - unit_symbol: Optional[str] = None - - @field_validator("unit_name", "unit_symbol") - @classmethod - def _normalize_text(cls, value: Optional[str]) -> Optional[str]: - if value is None: - return None - stripped = value.strip() - return stripped or None - - -class ProductionOutputCreate(ProductionOutputBase): - pass - - -class ProductionOutputRead(ProductionOutputBase): - id: int - model_config = ConfigDict(from_attributes=True) - - -@router.post( - "/", - response_model=ProductionOutputRead, - status_code=status.HTTP_201_CREATED, -) -def create_production( - item: ProductionOutputCreate, db: Session = Depends(get_db) -): - db_item = ProductionOutput(**item.model_dump()) - db.add(db_item) - db.commit() - db.refresh(db_item) - return db_item - - -@router.get("/", response_model=List[ProductionOutputRead]) -def list_production(db: Session = Depends(get_db)): - return db.query(ProductionOutput).all() diff --git a/routes/projects.py b/routes/projects.py new file mode 100644 index 0000000..f3323a5 --- /dev/null +++ b/routes/projects.py @@ -0,0 +1,337 @@ +from __future__ import annotations + +from typing import List + +from fastapi import APIRouter, Depends, Form, HTTPException, Request, status +from fastapi.responses import HTMLResponse, RedirectResponse + +from dependencies import ( + get_pricing_metadata, + get_unit_of_work, + require_any_role, + require_any_role_html, + require_project_resource, + require_project_resource_html, + require_roles, + require_roles_html, +) +from models import MiningOperationType, Project, ScenarioStatus, User +from schemas.project import ProjectCreate, ProjectRead, ProjectUpdate +from services.exceptions import EntityConflictError +from services.pricing import PricingMetadata +from services.unit_of_work import UnitOfWork +from routes.template_filters import create_templates + +router = APIRouter(prefix="/projects", tags=["Projects"]) +templates = create_templates() + +READ_ROLES = ("viewer", "analyst", "project_manager", "admin") +MANAGE_ROLES = ("project_manager", "admin") + + +def _to_read_model(project: Project) -> ProjectRead: + return ProjectRead.model_validate(project) + + +def _require_project_repo(uow: UnitOfWork): + if not uow.projects: + raise RuntimeError("Project repository not initialised") + return uow.projects + + +def _operation_type_choices() -> list[tuple[str, str]]: + return [ + (op.value, op.name.replace("_", " ").title()) for op in MiningOperationType + ] + + +@router.get("", response_model=List[ProjectRead]) +def list_projects( + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> List[ProjectRead]: + projects = _require_project_repo(uow).list() + return [_to_read_model(project) for project in projects] + + +@router.post("", response_model=ProjectRead, status_code=status.HTTP_201_CREATED) +def create_project( + payload: ProjectCreate, + _: User = Depends(require_roles(*MANAGE_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + metadata: PricingMetadata = Depends(get_pricing_metadata), +) -> ProjectRead: + project = Project(**payload.model_dump()) + try: + created = _require_project_repo(uow).create(project) + except EntityConflictError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=str(exc) + ) from exc + default_settings = uow.ensure_default_pricing_settings( + metadata=metadata).settings + uow.set_project_pricing_settings(created, 
default_settings) + return _to_read_model(created) + + +@router.get( + "/ui", + response_class=HTMLResponse, + include_in_schema=False, + name="projects.project_list_page", +) +def project_list_page( + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> HTMLResponse: + projects = _require_project_repo(uow).list(with_children=True) + for project in projects: + setattr(project, "scenario_count", len(project.scenarios)) + return templates.TemplateResponse( + request, + "projects/list.html", + { + "projects": projects, + }, + ) + + +@router.get( + "/create", + response_class=HTMLResponse, + include_in_schema=False, + name="projects.create_project_form", +) +def create_project_form( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "projects/form.html", + { + "project": None, + "operation_types": _operation_type_choices(), + "form_action": request.url_for("projects.create_project_submit"), + "cancel_url": request.url_for("projects.project_list_page"), + }, + ) + + +@router.post( + "/create", + include_in_schema=False, + name="projects.create_project_submit", +) +def create_project_submit( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), + name: str = Form(...), + location: str | None = Form(None), + operation_type: str = Form(...), + description: str | None = Form(None), + uow: UnitOfWork = Depends(get_unit_of_work), + metadata: PricingMetadata = Depends(get_pricing_metadata), +): + def _normalise(value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + return value or None + + try: + op_type = MiningOperationType(operation_type) + except ValueError: + return templates.TemplateResponse( + request, + "projects/form.html", + { + "project": None, + "operation_types": _operation_type_choices(), + "form_action": request.url_for("projects.create_project_submit"), + "cancel_url": request.url_for("projects.project_list_page"), + "error": "Invalid operation type.", + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + project = Project( + name=name.strip(), + location=_normalise(location), + operation_type=op_type, + description=_normalise(description), + ) + try: + created = _require_project_repo(uow).create(project) + except EntityConflictError: + return templates.TemplateResponse( + request, + "projects/form.html", + { + "project": project, + "operation_types": _operation_type_choices(), + "form_action": request.url_for("projects.create_project_submit"), + "cancel_url": request.url_for("projects.project_list_page"), + "error": "Project with this name already exists.", + }, + status_code=status.HTTP_409_CONFLICT, + ) + + default_settings = uow.ensure_default_pricing_settings( + metadata=metadata).settings + uow.set_project_pricing_settings(created, default_settings) + + return RedirectResponse( + request.url_for("projects.project_list_page"), + status_code=status.HTTP_303_SEE_OTHER, + ) + + +@router.get("/{project_id}", response_model=ProjectRead) +def get_project(project: Project = Depends(require_project_resource())) -> ProjectRead: + return _to_read_model(project) + + +@router.put("/{project_id}", response_model=ProjectRead) +def update_project( + payload: ProjectUpdate, + project: Project = Depends( + require_project_resource(require_manage=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> ProjectRead: + update_data = 
payload.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(project, field, value) + + uow.flush() + return _to_read_model(project) + + +@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_project( + project: Project = Depends(require_project_resource(require_manage=True)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> None: + _require_project_repo(uow).delete(project.id) + + +@router.get( + "/{project_id}/view", + response_class=HTMLResponse, + include_in_schema=False, + name="projects.view_project", +) +def view_project( + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), + project: Project = Depends(require_project_resource_html()), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> HTMLResponse: + project = _require_project_repo(uow).get(project.id, with_children=True) + + scenarios = sorted(project.scenarios, key=lambda s: s.created_at) + scenario_stats = { + "total": len(scenarios), + "active": sum(1 for scenario in scenarios if scenario.status == ScenarioStatus.ACTIVE), + "draft": sum(1 for scenario in scenarios if scenario.status == ScenarioStatus.DRAFT), + "archived": sum(1 for scenario in scenarios if scenario.status == ScenarioStatus.ARCHIVED), + "latest_update": max( + (scenario.updated_at for scenario in scenarios if scenario.updated_at), + default=None, + ), + } + return templates.TemplateResponse( + request, + "projects/detail.html", + { + "project": project, + "scenarios": scenarios, + "scenario_stats": scenario_stats, + }, + ) + + +@router.get( + "/{project_id}/edit", + response_class=HTMLResponse, + include_in_schema=False, + name="projects.edit_project_form", +) +def edit_project_form( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), + project: Project = Depends( + require_project_resource_html(require_manage=True) + ), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "projects/form.html", + { + "project": project, + "operation_types": _operation_type_choices(), + "form_action": request.url_for( + "projects.edit_project_submit", project_id=project.id + ), + "cancel_url": request.url_for( + "projects.view_project", project_id=project.id + ), + }, + ) + + +@router.post( + "/{project_id}/edit", + include_in_schema=False, + name="projects.edit_project_submit", +) +def edit_project_submit( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), + project: Project = Depends( + require_project_resource_html(require_manage=True) + ), + name: str = Form(...), + location: str | None = Form(None), + operation_type: str | None = Form(None), + description: str | None = Form(None), + uow: UnitOfWork = Depends(get_unit_of_work), +): + def _normalise(value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + return value or None + + project.name = name.strip() + project.location = _normalise(location) + if operation_type: + try: + project.operation_type = MiningOperationType(operation_type) + except ValueError: + return templates.TemplateResponse( + request, + "projects/form.html", + { + "project": project, + "operation_types": _operation_type_choices(), + "form_action": request.url_for( + "projects.edit_project_submit", project_id=project.id + ), + "cancel_url": request.url_for( + "projects.view_project", project_id=project.id + ), + "error": "Invalid operation type.", + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + project.description = _normalise(description) + + uow.flush() + + 
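# Post/Redirect/Get: flush the pending edits, then send the browser back to the + # project detail page so a refresh cannot resubmit the form. +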
return RedirectResponse( + request.url_for("projects.view_project", project_id=project.id), + status_code=status.HTTP_303_SEE_OTHER, + ) diff --git a/routes/reporting.py b/routes/reporting.py deleted file mode 100644 index 09a9417..0000000 --- a/routes/reporting.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Any, Dict, List, cast - -from fastapi import APIRouter, HTTPException, Request, status -from pydantic import BaseModel - -from services.reporting import generate_report - - -router = APIRouter(prefix="/api/reporting", tags=["Reporting"]) - - -def _validate_payload(payload: Any) -> List[Dict[str, float]]: - if not isinstance(payload, list): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid input format", - ) - - typed_payload = cast(List[Any], payload) - - validated: List[Dict[str, float]] = [] - for index, item in enumerate(typed_payload): - if not isinstance(item, dict): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Entry at index {index} must be an object", - ) - value = cast(Dict[str, Any], item).get("result") - if not isinstance(value, (int, float)): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Entry at index {index} must include numeric 'result'", - ) - validated.append({"result": float(value)}) - return validated - - -class ReportSummary(BaseModel): - count: int - mean: float - median: float - min: float - max: float - std_dev: float - variance: float - percentile_10: float - percentile_90: float - percentile_5: float - percentile_95: float - value_at_risk_95: float - expected_shortfall_95: float - - -@router.post("/summary", response_model=ReportSummary) -async def summary_report(request: Request): - payload = await request.json() - validated_payload = _validate_payload(payload) - summary = generate_report(validated_payload) - return ReportSummary( - count=int(summary["count"]), - mean=float(summary["mean"]), - median=float(summary["median"]), - min=float(summary["min"]), - max=float(summary["max"]), - std_dev=float(summary["std_dev"]), - variance=float(summary["variance"]), - percentile_10=float(summary["percentile_10"]), - percentile_90=float(summary["percentile_90"]), - percentile_5=float(summary["percentile_5"]), - percentile_95=float(summary["percentile_95"]), - value_at_risk_95=float(summary["value_at_risk_95"]), - expected_shortfall_95=float(summary["expected_shortfall_95"]), - ) diff --git a/routes/reports.py b/routes/reports.py new file mode 100644 index 0000000..5d632bb --- /dev/null +++ b/routes/reports.py @@ -0,0 +1,434 @@ +from __future__ import annotations + +from datetime import date + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from fastapi.encoders import jsonable_encoder +from fastapi.responses import HTMLResponse + +from dependencies import ( + get_unit_of_work, + require_any_role, + require_any_role_html, + require_project_resource, + require_scenario_resource, + require_project_resource_html, + require_scenario_resource_html, +) +from models import Project, Scenario, User +from services.exceptions import EntityNotFoundError, ScenarioValidationError +from services.reporting import ( + DEFAULT_ITERATIONS, + IncludeOptions, + ReportFilters, + ReportingService, + parse_include_tokens, + validate_percentiles, +) +from services.unit_of_work import UnitOfWork +from routes.template_filters import create_templates + +router = APIRouter(prefix="/reports", tags=["Reports"]) +templates = create_templates() + +READ_ROLES = ("viewer", 
"analyst", "project_manager", "admin") +MANAGE_ROLES = ("project_manager", "admin") + + +@router.get("/projects/{project_id}", name="reports.project_summary") +def project_summary_report( + project: Project = Depends(require_project_resource()), + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + include: str | None = Query( + None, + description="Comma-separated include tokens (distribution,samples,all).", + ), + scenario_ids: list[int] | None = Query( + None, + alias="scenario_ids", + description="Repeatable scenario identifier filter.", + ), + start_date: date | None = Query( + None, + description="Filter scenarios starting on or after this date.", + ), + end_date: date | None = Query( + None, + description="Filter scenarios ending on or before this date.", + ), + fmt: str = Query( + "json", + alias="format", + description="Response format (json only for this endpoint).", + ), + iterations: int | None = Query( + None, + gt=0, + description="Override Monte Carlo iteration count when distribution is included.", + ), + percentiles: list[float] | None = Query( + None, + description="Percentiles (0-100) for Monte Carlo summaries when included.", + ), +) -> dict[str, object]: + if fmt.lower() != "json": + raise HTTPException( + status_code=status.HTTP_406_NOT_ACCEPTABLE, + detail="Only JSON responses are supported; use the HTML endpoint for templates.", + ) + + include_options = parse_include_tokens(include) + try: + percentile_values = validate_percentiles(percentiles) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + + scenario_filter = ReportFilters( + scenario_ids=set(scenario_ids) if scenario_ids else None, + start_date=start_date, + end_date=end_date, + ) + + service = ReportingService(uow) + report = service.project_summary( + project, + filters=scenario_filter, + include=include_options, + iterations=iterations or DEFAULT_ITERATIONS, + percentiles=percentile_values, + ) + return jsonable_encoder(report) + + +@router.get( + "/projects/{project_id}/scenarios/compare", + name="reports.project_scenario_comparison", +) +def project_scenario_comparison_report( + project: Project = Depends(require_project_resource()), + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + scenario_ids: list[int] = Query( + ..., alias="scenario_ids", description="Repeatable scenario identifier."), + include: str | None = Query( + None, + description="Comma-separated include tokens (distribution,samples,all).", + ), + fmt: str = Query( + "json", + alias="format", + description="Response format (json only for this endpoint).", + ), + iterations: int | None = Query( + None, + gt=0, + description="Override Monte Carlo iteration count when distribution is included.", + ), + percentiles: list[float] | None = Query( + None, + description="Percentiles (0-100) for Monte Carlo summaries when included.", + ), +) -> dict[str, object]: + unique_ids = list(dict.fromkeys(scenario_ids)) + if len(unique_ids) < 2: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="At least two unique scenario_ids must be provided for comparison.", + ) + if fmt.lower() != "json": + raise HTTPException( + status_code=status.HTTP_406_NOT_ACCEPTABLE, + detail="Only JSON responses are supported; use the HTML endpoint for templates.", + ) + + include_options = parse_include_tokens(include) + try: + percentile_values = 
validate_percentiles(percentiles) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + + try: + scenarios = uow.validate_scenarios_for_comparison(unique_ids) + except ScenarioValidationError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail={ + "code": exc.code, + "message": exc.message, + "scenario_ids": list(exc.scenario_ids or []), + }, + ) from exc + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + + if any(scenario.project_id != project.id for scenario in scenarios): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="One or more scenarios are not associated with this project.", + ) + + service = ReportingService(uow) + report = service.scenario_comparison( + project, + scenarios, + include=include_options, + iterations=iterations or DEFAULT_ITERATIONS, + percentiles=percentile_values, + ) + return jsonable_encoder(report) + + +@router.get( + "/scenarios/{scenario_id}/distribution", + name="reports.scenario_distribution", +) +def scenario_distribution_report( + scenario: Scenario = Depends(require_scenario_resource()), + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + include: str | None = Query( + None, + description="Comma-separated include tokens (samples,all).", + ), + fmt: str = Query( + "json", + alias="format", + description="Response format (json only for this endpoint).", + ), + iterations: int | None = Query( + None, + gt=0, + description="Override Monte Carlo iteration count (default applies otherwise).", + ), + percentiles: list[float] | None = Query( + None, + description="Percentiles (0-100) for Monte Carlo summaries.", + ), +) -> dict[str, object]: + if fmt.lower() != "json": + raise HTTPException( + status_code=status.HTTP_406_NOT_ACCEPTABLE, + detail="Only JSON responses are supported; use the HTML endpoint for templates.", + ) + + requested = parse_include_tokens(include) + include_options = IncludeOptions( + distribution=True, samples=requested.samples) + + try: + percentile_values = validate_percentiles(percentiles) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + + service = ReportingService(uow) + report = service.scenario_distribution( + scenario, + include=include_options, + iterations=iterations or DEFAULT_ITERATIONS, + percentiles=percentile_values, + ) + return jsonable_encoder(report) + + +@router.get( + "/projects/{project_id}/ui", + response_class=HTMLResponse, + include_in_schema=False, + name="reports.project_summary_page", +) +def project_summary_page( + request: Request, + project: Project = Depends(require_project_resource_html()), + _: User = Depends(require_any_role_html(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + include: str | None = Query( + None, + description="Comma-separated include tokens (distribution,samples,all).", + ), + scenario_ids: list[int] | None = Query( + None, + alias="scenario_ids", + description="Repeatable scenario identifier filter.", + ), + start_date: date | None = Query( + None, + description="Filter scenarios starting on or after this date.", + ), + end_date: date | None = Query( + None, + description="Filter scenarios ending on or before this date.", + ), + iterations: int | None = Query( + None, + gt=0, + description="Override Monte 
Carlo iteration count when distribution is included.", + ), + percentiles: list[float] | None = Query( + None, + description="Percentiles (0-100) for Monte Carlo summaries when included.", + ), +) -> HTMLResponse: + include_options = parse_include_tokens(include) + try: + percentile_values = validate_percentiles(percentiles) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + + scenario_filter = ReportFilters( + scenario_ids=set(scenario_ids) if scenario_ids else None, + start_date=start_date, + end_date=end_date, + ) + + service = ReportingService(uow) + context = service.build_project_summary_context( + project, scenario_filter, include_options, iterations or DEFAULT_ITERATIONS, percentile_values, request + ) + return templates.TemplateResponse( + request, + "reports/project_summary.html", + context, + ) + + +@router.get( + "/projects/{project_id}/scenarios/compare/ui", + response_class=HTMLResponse, + include_in_schema=False, + name="reports.project_scenario_comparison_page", +) +def project_scenario_comparison_page( + request: Request, + project: Project = Depends(require_project_resource_html()), + _: User = Depends(require_any_role_html(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + scenario_ids: list[int] = Query( + ..., alias="scenario_ids", description="Repeatable scenario identifier."), + include: str | None = Query( + None, + description="Comma-separated include tokens (distribution,samples,all).", + ), + iterations: int | None = Query( + None, + gt=0, + description="Override Monte Carlo iteration count when distribution is included.", + ), + percentiles: list[float] | None = Query( + None, + description="Percentiles (0-100) for Monte Carlo summaries when included.", + ), +) -> HTMLResponse: + unique_ids = list(dict.fromkeys(scenario_ids)) + if len(unique_ids) < 2: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="At least two unique scenario_ids must be provided for comparison.", + ) + + include_options = parse_include_tokens(include) + try: + percentile_values = validate_percentiles(percentiles) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + + try: + scenarios = uow.validate_scenarios_for_comparison(unique_ids) + except ScenarioValidationError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail={ + "code": exc.code, + "message": exc.message, + "scenario_ids": list(exc.scenario_ids or []), + }, + ) from exc + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(exc), + ) from exc + + if any(scenario.project_id != project.id for scenario in scenarios): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="One or more scenarios are not associated with this project.", + ) + + service = ReportingService(uow) + context = service.build_scenario_comparison_context( + project, scenarios, include_options, iterations or DEFAULT_ITERATIONS, percentile_values, request + ) + return templates.TemplateResponse( + request, + "reports/scenario_comparison.html", + context, + ) + + +@router.get( + "/scenarios/{scenario_id}/distribution/ui", + response_class=HTMLResponse, + include_in_schema=False, + name="reports.scenario_distribution_page", +) +def scenario_distribution_page( + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), 
+ scenario: Scenario = Depends( + require_scenario_resource_html() + ), + uow: UnitOfWork = Depends(get_unit_of_work), + include: str | None = Query( + None, + description="Comma-separated include tokens (samples,all).", + ), + iterations: int | None = Query( + None, + gt=0, + description="Override Monte Carlo iteration count (default applies otherwise).", + ), + percentiles: list[float] | None = Query( + None, + description="Percentiles (0-100) for Monte Carlo summaries.", + ), +) -> HTMLResponse: + requested = parse_include_tokens(include) + include_options = IncludeOptions( + distribution=True, samples=requested.samples) + + try: + percentile_values = validate_percentiles(percentiles) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=str(exc), + ) from exc + + service = ReportingService(uow) + context = service.build_scenario_distribution_context( + scenario, include_options, iterations or DEFAULT_ITERATIONS, percentile_values, request + ) + return templates.TemplateResponse( + request, + "reports/scenario_distribution.html", + context, + ) diff --git a/routes/scenarios.py b/routes/scenarios.py index 4454f74..5cc477c 100644 --- a/routes/scenarios.py +++ b/routes/scenarios.py @@ -1,42 +1,656 @@ -from datetime import datetime -from typing import Optional +from __future__ import annotations -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, ConfigDict -from sqlalchemy.orm import Session +from datetime import date +from types import SimpleNamespace +from typing import List -from models.scenario import Scenario -from routes.dependencies import get_db +from fastapi import APIRouter, Depends, Form, HTTPException, Request, status +from fastapi.responses import HTMLResponse, RedirectResponse -router = APIRouter(prefix="/api/scenarios", tags=["scenarios"]) +from dependencies import ( + get_pricing_metadata, + get_unit_of_work, + require_any_role, + require_any_role_html, + require_roles, + require_roles_html, + require_scenario_resource, + require_scenario_resource_html, +) +from models import ResourceType, Scenario, ScenarioStatus, User +from schemas.scenario import ( + ScenarioComparisonRequest, + ScenarioComparisonResponse, + ScenarioCreate, + ScenarioRead, + ScenarioUpdate, +) +from services.currency import CurrencyValidationError, normalise_currency +from services.exceptions import ( + EntityConflictError, + EntityNotFoundError, + ScenarioValidationError, +) +from services.pricing import PricingMetadata +from services.unit_of_work import UnitOfWork +from routes.template_filters import create_templates -# Pydantic schemas +router = APIRouter(tags=["Scenarios"]) +templates = create_templates() + +READ_ROLES = ("viewer", "analyst", "project_manager", "admin") +MANAGE_ROLES = ("project_manager", "admin") -class ScenarioCreate(BaseModel): - name: str - description: Optional[str] = None +def _to_read_model(scenario: Scenario) -> ScenarioRead: + return ScenarioRead.model_validate(scenario) -class ScenarioRead(ScenarioCreate): - id: int - created_at: datetime - updated_at: Optional[datetime] = None - model_config = ConfigDict(from_attributes=True) +def _resource_type_choices() -> list[tuple[str, str]]: + return [ + (resource.value, resource.value.replace("_", " ").title()) + for resource in ResourceType + ] -@router.post("/", response_model=ScenarioRead) -def create_scenario(scenario: ScenarioCreate, db: Session = Depends(get_db)): - db_s = db.query(Scenario).filter(Scenario.name == 
scenario.name).first() - if db_s: - raise HTTPException(status_code=400, detail="Scenario already exists") - new_s = Scenario(name=scenario.name, description=scenario.description) - db.add(new_s) - db.commit() - db.refresh(new_s) - return new_s +def _scenario_status_choices() -> list[tuple[str, str]]: + return [ + (status.value, status.value.title()) for status in ScenarioStatus + ] -@router.get("/", response_model=list[ScenarioRead]) -def list_scenarios(db: Session = Depends(get_db)): - return db.query(Scenario).all() +def _require_project_repo(uow: UnitOfWork): + if not uow.projects: + raise RuntimeError("Project repository not initialised") + return uow.projects + + +def _require_scenario_repo(uow: UnitOfWork): + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") + return uow.scenarios + + +@router.get( + "/projects/{project_id}/scenarios", + response_model=List[ScenarioRead], +) +def list_scenarios_for_project( + project_id: int, + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> List[ScenarioRead]: + project_repo = _require_project_repo(uow) + scenario_repo = _require_scenario_repo(uow) + try: + project_repo.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + + scenarios = scenario_repo.list_for_project(project_id) + return [_to_read_model(scenario) for scenario in scenarios] + + +@router.post( + "/projects/{project_id}/scenarios/compare", + response_model=ScenarioComparisonResponse, + status_code=status.HTTP_200_OK, +) +def compare_scenarios( + project_id: int, + payload: ScenarioComparisonRequest, + _: User = Depends(require_any_role(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> ScenarioComparisonResponse: + try: + _require_project_repo(uow).get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + + try: + scenarios = uow.validate_scenarios_for_comparison(payload.scenario_ids) + if any(scenario.project_id != project_id for scenario in scenarios): + raise ScenarioValidationError( + code="SCENARIO_PROJECT_MISMATCH", + message="Selected scenarios do not belong to the same project.", + scenario_ids=[ + scenario.id for scenario in scenarios if scenario.id is not None + ], + ) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + except ScenarioValidationError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail={ + "code": exc.code, + "message": exc.message, + "scenario_ids": list(exc.scenario_ids or []), + }, + ) from exc + + return ScenarioComparisonResponse( + project_id=project_id, + scenarios=[_to_read_model(scenario) for scenario in scenarios], + ) + + +@router.post( + "/projects/{project_id}/scenarios", + response_model=ScenarioRead, + status_code=status.HTTP_201_CREATED, +) +def create_scenario_for_project( + project_id: int, + payload: ScenarioCreate, + _: User = Depends(require_roles(*MANAGE_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + metadata: PricingMetadata = Depends(get_pricing_metadata), +) -> ScenarioRead: + project_repo = _require_project_repo(uow) + scenario_repo = _require_scenario_repo(uow) + try: + project_repo.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc 
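+    # The payload's currency is optional; when it is omitted, the scenario
+    # inherits the deployment-wide default exposed by PricingMetadata.
+    # Illustrative example (values invented, not taken from the fixtures):
+    #   payload = {"name": "Base case"} with metadata.default_currency == "USD"
+    #   -> the scenario created below is persisted with currency="USD".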
+ + scenario_data = payload.model_dump() + if not scenario_data.get("currency") and metadata.default_currency: + scenario_data["currency"] = metadata.default_currency + scenario = Scenario(project_id=project_id, **scenario_data) + + try: + created = scenario_repo.create(scenario) + except EntityConflictError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + return _to_read_model(created) + + +@router.get( + "/projects/{project_id}/scenarios/ui", + response_class=HTMLResponse, + include_in_schema=False, + name="scenarios.project_scenario_list", +) +def project_scenario_list_page( + project_id: int, + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> HTMLResponse: + try: + project = _require_project_repo(uow).get( + project_id, with_children=True) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + + scenarios = sorted( + project.scenarios, + key=lambda scenario: scenario.updated_at or scenario.created_at, + reverse=True, + ) + scenario_totals = { + "total": len(scenarios), + "active": sum( + 1 for scenario in scenarios if scenario.status == ScenarioStatus.ACTIVE + ), + "draft": sum( + 1 for scenario in scenarios if scenario.status == ScenarioStatus.DRAFT + ), + "archived": sum( + 1 for scenario in scenarios if scenario.status == ScenarioStatus.ARCHIVED + ), + "latest_update": max( + ( + scenario.updated_at or scenario.created_at + for scenario in scenarios + if scenario.updated_at or scenario.created_at + ), + default=None, + ), + } + + return templates.TemplateResponse( + request, + "scenarios/list.html", + { + "project": project, + "scenarios": scenarios, + "scenario_totals": scenario_totals, + }, + ) + + +@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead) +def get_scenario( + scenario: Scenario = Depends(require_scenario_resource()), +) -> ScenarioRead: + return _to_read_model(scenario) + + +@router.put("/scenarios/{scenario_id}", response_model=ScenarioRead) +def update_scenario( + payload: ScenarioUpdate, + scenario: Scenario = Depends( + require_scenario_resource(require_manage=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> ScenarioRead: + update_data = payload.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(scenario, field, value) + + uow.flush() + return _to_read_model(scenario) + + +@router.delete("/scenarios/{scenario_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_scenario( + scenario: Scenario = Depends( + require_scenario_resource(require_manage=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> None: + _require_scenario_repo(uow).delete(scenario.id) + + +def _normalise(value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + return value or None + + +def _parse_date(value: str | None) -> date | None: + value = _normalise(value) + if not value: + return None + return date.fromisoformat(value) + + +def _parse_discount_rate(value: str | None) -> float | None: + value = _normalise(value) + if not value: + return None + try: + return float(value) + except ValueError: + return None + + +def _scenario_form_state( + *, + project_id: int, + name: str, + description: str | None, + status: ScenarioStatus, + start_date: date | None, + end_date: date | None, + discount_rate: float | None, + currency: str | None, + primary_resource: ResourceType | None, + 
scenario_id: int | None = None, +) -> SimpleNamespace: + return SimpleNamespace( + id=scenario_id, + project_id=project_id, + name=name, + description=description, + status=status, + start_date=start_date, + end_date=end_date, + discount_rate=discount_rate, + currency=currency, + primary_resource=primary_resource, + ) + + +@router.get( + "/projects/{project_id}/scenarios/new", + response_class=HTMLResponse, + include_in_schema=False, + name="scenarios.create_scenario_form", +) +def create_scenario_form( + project_id: int, + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), + uow: UnitOfWork = Depends(get_unit_of_work), + metadata: PricingMetadata = Depends(get_pricing_metadata), +) -> HTMLResponse: + try: + project = _require_project_repo(uow).get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + + return templates.TemplateResponse( + request, + "scenarios/form.html", + { + "project": project, + "scenario": None, + "scenario_statuses": _scenario_status_choices(), + "resource_types": _resource_type_choices(), + "form_action": request.url_for( + "scenarios.create_scenario_submit", project_id=project_id + ), + "cancel_url": request.url_for( + "projects.view_project", project_id=project_id + ), + "default_currency": metadata.default_currency, + }, + ) + + +@router.post( + "/projects/{project_id}/scenarios/new", + include_in_schema=False, + name="scenarios.create_scenario_submit", +) +def create_scenario_submit( + project_id: int, + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), + name: str = Form(...), + description: str | None = Form(None), + status_value: str = Form(ScenarioStatus.DRAFT.value), + start_date: str | None = Form(None), + end_date: str | None = Form(None), + discount_rate: str | None = Form(None), + currency: str | None = Form(None), + primary_resource: str | None = Form(None), + uow: UnitOfWork = Depends(get_unit_of_work), + metadata: PricingMetadata = Depends(get_pricing_metadata), +): + project_repo = _require_project_repo(uow) + scenario_repo = _require_scenario_repo(uow) + try: + project = project_repo.get(project_id) + except EntityNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + + try: + status_enum = ScenarioStatus(status_value) + except ValueError: + status_enum = ScenarioStatus.DRAFT + + resource_enum = None + if primary_resource: + try: + resource_enum = ResourceType(primary_resource) + except ValueError: + resource_enum = None + + name_value = name.strip() + description_value = _normalise(description) + start_date_value = _parse_date(start_date) + end_date_value = _parse_date(end_date) + discount_rate_value = _parse_discount_rate(discount_rate) + currency_input = _normalise(currency) + effective_currency = currency_input or metadata.default_currency + + try: + currency_value = ( + normalise_currency(effective_currency) + if effective_currency else None + ) + except CurrencyValidationError as exc: + form_state = _scenario_form_state( + project_id=project_id, + name=name_value, + description=description_value, + status=status_enum, + start_date=start_date_value, + end_date=end_date_value, + discount_rate=discount_rate_value, + currency=currency_input or metadata.default_currency, + primary_resource=resource_enum, + ) + return templates.TemplateResponse( + request, + "scenarios/form.html", + { + "project": project, + "scenario": form_state, + 
"scenario_statuses": _scenario_status_choices(), + "resource_types": _resource_type_choices(), + "form_action": request.url_for( + "scenarios.create_scenario_submit", project_id=project_id + ), + "cancel_url": request.url_for( + "projects.view_project", project_id=project_id + ), + "error": str(exc), + "error_field": "currency", + "default_currency": metadata.default_currency, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + scenario = Scenario( + project_id=project_id, + name=name_value, + description=description_value, + status=status_enum, + start_date=start_date_value, + end_date=end_date_value, + discount_rate=discount_rate_value, + currency=currency_value, + primary_resource=resource_enum, + ) + + try: + scenario_repo.create(scenario) + except EntityConflictError: + return templates.TemplateResponse( + request, + "scenarios/form.html", + { + "project": project, + "scenario": scenario, + "scenario_statuses": _scenario_status_choices(), + "resource_types": _resource_type_choices(), + "form_action": request.url_for( + "scenarios.create_scenario_submit", project_id=project_id + ), + "cancel_url": request.url_for( + "projects.view_project", project_id=project_id + ), + "error": "Scenario with this name already exists for this project.", + "error_field": "name", + "default_currency": metadata.default_currency, + }, + status_code=status.HTTP_409_CONFLICT, + ) + + return RedirectResponse( + request.url_for("projects.view_project", project_id=project_id), + status_code=status.HTTP_303_SEE_OTHER, + ) + + +@router.get( + "/scenarios/{scenario_id}/view", + response_class=HTMLResponse, + include_in_schema=False, + name="scenarios.view_scenario", +) +def view_scenario( + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), + scenario: Scenario = Depends( + require_scenario_resource_html(with_children=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), +) -> HTMLResponse: + project = _require_project_repo(uow).get(scenario.project_id) + financial_inputs = sorted( + scenario.financial_inputs, key=lambda item: item.created_at + ) + simulation_parameters = sorted( + scenario.simulation_parameters, key=lambda item: item.created_at + ) + + scenario_metrics = { + "financial_count": len(financial_inputs), + "parameter_count": len(simulation_parameters), + "currency": scenario.currency, + "primary_resource": scenario.primary_resource.value.replace('_', ' ').title() if scenario.primary_resource else None, + } + + return templates.TemplateResponse( + request, + "scenarios/detail.html", + { + "project": project, + "scenario": scenario, + "scenario_metrics": scenario_metrics, + "financial_inputs": financial_inputs, + "simulation_parameters": simulation_parameters, + }, + ) + + +@router.get( + "/scenarios/{scenario_id}/edit", + response_class=HTMLResponse, + include_in_schema=False, + name="scenarios.edit_scenario_form", +) +def edit_scenario_form( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), + scenario: Scenario = Depends( + require_scenario_resource_html(require_manage=True) + ), + uow: UnitOfWork = Depends(get_unit_of_work), + metadata: PricingMetadata = Depends(get_pricing_metadata), +) -> HTMLResponse: + project = _require_project_repo(uow).get(scenario.project_id) + + return templates.TemplateResponse( + request, + "scenarios/form.html", + { + "project": project, + "scenario": scenario, + "scenario_statuses": _scenario_status_choices(), + "resource_types": _resource_type_choices(), + "form_action": request.url_for( + 
"scenarios.edit_scenario_submit", scenario_id=scenario.id + ), + "cancel_url": request.url_for( + "scenarios.view_scenario", scenario_id=scenario.id + ), + "default_currency": metadata.default_currency, + }, + ) + + +@router.post( + "/scenarios/{scenario_id}/edit", + include_in_schema=False, + name="scenarios.edit_scenario_submit", +) +def edit_scenario_submit( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), + scenario: Scenario = Depends( + require_scenario_resource_html(require_manage=True) + ), + name: str = Form(...), + description: str | None = Form(None), + status_value: str = Form(ScenarioStatus.DRAFT.value), + start_date: str | None = Form(None), + end_date: str | None = Form(None), + discount_rate: str | None = Form(None), + currency: str | None = Form(None), + primary_resource: str | None = Form(None), + uow: UnitOfWork = Depends(get_unit_of_work), + metadata: PricingMetadata = Depends(get_pricing_metadata), +): + project = _require_project_repo(uow).get(scenario.project_id) + + name_value = name.strip() + description_value = _normalise(description) + try: + scenario.status = ScenarioStatus(status_value) + except ValueError: + scenario.status = ScenarioStatus.DRAFT + status_enum = scenario.status + + resource_enum = None + if primary_resource: + try: + resource_enum = ResourceType(primary_resource) + except ValueError: + resource_enum = None + + start_date_value = _parse_date(start_date) + end_date_value = _parse_date(end_date) + discount_rate_value = _parse_discount_rate(discount_rate) + currency_input = _normalise(currency) + + try: + currency_value = normalise_currency(currency_input) + except CurrencyValidationError as exc: + form_state = _scenario_form_state( + scenario_id=scenario.id, + project_id=scenario.project_id, + name=name_value, + description=description_value, + status=status_enum, + start_date=start_date_value, + end_date=end_date_value, + discount_rate=discount_rate_value, + currency=currency_input, + primary_resource=resource_enum, + ) + return templates.TemplateResponse( + request, + "scenarios/form.html", + { + "project": project, + "scenario": form_state, + "scenario_statuses": _scenario_status_choices(), + "resource_types": _resource_type_choices(), + "form_action": request.url_for( + "scenarios.edit_scenario_submit", scenario_id=scenario.id + ), + "cancel_url": request.url_for( + "scenarios.view_scenario", scenario_id=scenario.id + ), + "error": str(exc), + "error_field": "currency", + "default_currency": metadata.default_currency, + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + scenario.name = name_value + scenario.description = description_value + scenario.start_date = start_date_value + scenario.end_date = end_date_value + scenario.discount_rate = discount_rate_value + scenario.currency = currency_value + scenario.primary_resource = resource_enum + + uow.flush() + + return RedirectResponse( + request.url_for("scenarios.view_scenario", scenario_id=scenario.id), + status_code=status.HTTP_303_SEE_OTHER, + ) diff --git a/routes/settings.py b/routes/settings.py deleted file mode 100644 index ed06fb5..0000000 --- a/routes/settings.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Dict, List - -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, Field, model_validator -from sqlalchemy.orm import Session - -from routes.dependencies import get_db -from services.settings import ( - CSS_COLOR_DEFAULTS, - get_css_color_settings, - list_css_env_override_rows, - 
read_css_color_env_overrides, - update_css_color_settings, - get_theme_settings, - save_theme_settings, -) - -router = APIRouter(prefix="/api/settings", tags=["Settings"]) - - -class CSSSettingsPayload(BaseModel): - variables: Dict[str, str] = Field(default_factory=dict) - - @model_validator(mode="after") - def _validate_allowed_keys(self) -> "CSSSettingsPayload": - invalid = set(self.variables.keys()) - set(CSS_COLOR_DEFAULTS.keys()) - if invalid: - invalid_keys = ", ".join(sorted(invalid)) - raise ValueError( - f"Unsupported CSS variables: {invalid_keys}." - " Accepted keys align with the default theme variables." - ) - return self - - -class EnvOverride(BaseModel): - css_key: str - env_var: str - value: str - - -class CSSSettingsResponse(BaseModel): - variables: Dict[str, str] - env_overrides: Dict[str, str] = Field(default_factory=dict) - env_sources: List[EnvOverride] = Field(default_factory=list) - - -@router.get("/css", response_model=CSSSettingsResponse) -def read_css_settings(db: Session = Depends(get_db)) -> CSSSettingsResponse: - try: - values = get_css_color_settings(db) - env_overrides = read_css_color_env_overrides() - env_sources = [ - EnvOverride(**row) for row in list_css_env_override_rows() - ] - except ValueError as exc: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(exc), - ) from exc - return CSSSettingsResponse( - variables=values, - env_overrides=env_overrides, - env_sources=env_sources, - ) - - -@router.put( - "/css", response_model=CSSSettingsResponse, status_code=status.HTTP_200_OK -) -def update_css_settings( - payload: CSSSettingsPayload, db: Session = Depends(get_db) -) -> CSSSettingsResponse: - try: - values = update_css_color_settings(db, payload.variables) - env_overrides = read_css_color_env_overrides() - env_sources = [ - EnvOverride(**row) for row in list_css_env_override_rows() - ] - except ValueError as exc: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, - detail=str(exc), - ) from exc - return CSSSettingsResponse( - variables=values, - env_overrides=env_overrides, - env_sources=env_sources, - ) - - -class ThemeSettings(BaseModel): - theme_name: str - primary_color: str - secondary_color: str - accent_color: str - background_color: str - text_color: str - - -@router.post("/theme") -async def update_theme(theme_data: ThemeSettings, db: Session = Depends(get_db)): - data_dict = theme_data.model_dump() - save_theme_settings(db, data_dict) - return {"message": "Theme updated", "theme": data_dict} - - -@router.get("/theme") -async def get_theme(db: Session = Depends(get_db)): - return get_theme_settings(db) diff --git a/routes/simulations.py b/routes/simulations.py deleted file mode 100644 index 5500805..0000000 --- a/routes/simulations.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import Dict, List, Optional - -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, PositiveInt -from sqlalchemy.orm import Session - -from models.parameters import Parameter -from models.scenario import Scenario -from models.simulation_result import SimulationResult -from routes.dependencies import get_db -from services.reporting import generate_report -from services.simulation import run_simulation - -router = APIRouter(prefix="/api/simulations", tags=["Simulations"]) - - -class SimulationParameterInput(BaseModel): - name: str - value: float - distribution: Optional[str] = "normal" - std_dev: Optional[float] = None - min: Optional[float] = None - max: 
Optional[float] = None - mode: Optional[float] = None - - -class SimulationRunRequest(BaseModel): - scenario_id: int - iterations: PositiveInt = 1000 - parameters: Optional[List[SimulationParameterInput]] = None - seed: Optional[int] = None - - -class SimulationResultItem(BaseModel): - iteration: int - result: float - - -class SimulationRunResponse(BaseModel): - scenario_id: int - iterations: int - results: List[SimulationResultItem] - summary: Dict[str, float | int] - - -def _load_parameters( - db: Session, scenario_id: int -) -> List[SimulationParameterInput]: - db_params = ( - db.query(Parameter) - .filter(Parameter.scenario_id == scenario_id) - .order_by(Parameter.id) - .all() - ) - return [ - SimulationParameterInput( - name=item.name, - value=item.value, - ) - for item in db_params - ] - - -@router.post("/run", response_model=SimulationRunResponse) -async def simulate( - payload: SimulationRunRequest, db: Session = Depends(get_db) -): - scenario = ( - db.query(Scenario).filter(Scenario.id == payload.scenario_id).first() - ) - if scenario is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Scenario not found", - ) - - parameters = payload.parameters or _load_parameters(db, payload.scenario_id) - if not parameters: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="No parameters provided", - ) - - raw_results = run_simulation( - [param.model_dump(exclude_none=True) for param in parameters], - iterations=payload.iterations, - seed=payload.seed, - ) - - if not raw_results: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Simulation produced no results", - ) - - # Persist results (replace existing values for scenario) - db.query(SimulationResult).filter( - SimulationResult.scenario_id == payload.scenario_id - ).delete() - db.bulk_save_objects( - [ - SimulationResult( - scenario_id=payload.scenario_id, - iteration=item["iteration"], - result=item["result"], - ) - for item in raw_results - ] - ) - db.commit() - - summary = generate_report(raw_results) - - response = SimulationRunResponse( - scenario_id=payload.scenario_id, - iterations=payload.iterations, - results=[ - SimulationResultItem( - iteration=int(item["iteration"]), - result=float(item["result"]), - ) - for item in raw_results - ], - summary=summary, - ) - return response diff --git a/routes/template_filters.py b/routes/template_filters.py new file mode 100644 index 0000000..db2e43c --- /dev/null +++ b/routes/template_filters.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Any + +from fastapi import Request +from fastapi.templating import Jinja2Templates + +from services.navigation import NavigationService +from services.session import AuthSession +from services.unit_of_work import UnitOfWork + + +logger = logging.getLogger(__name__) + + +def format_datetime(value: Any) -> str: + """Render datetime values consistently for templates.""" + if not isinstance(value, datetime): + return "" + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.strftime("%Y-%m-%d %H:%M UTC") + + +def currency_display(value: Any, currency_code: str | None) -> str: + """Format numeric values with currency context.""" + if value is None: + return "—" + if isinstance(value, (int, float)): + formatted_value = f"{value:,.2f}" + else: + formatted_value = str(value) + if currency_code: + return f"{currency_code} {formatted_value}" + return formatted_value + + 
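+# Illustrative sketch of the filter behaviour above (values invented, not
+# drawn from the test suite). Note that format_datetime treats naive
+# datetimes as already being UTC:
+#
+#     >>> format_datetime(datetime(2024, 1, 2, 3, 4, tzinfo=timezone.utc))
+#     '2024-01-02 03:04 UTC'
+#     >>> currency_display(1234.5, "USD")
+#     'USD 1,234.50'
+#     >>> currency_display(None, "USD")
+#     '—'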
+def format_metric(value: Any, metric_name: str, currency_code: str | None = None) -> str: + """Format metrics according to their semantic type.""" + if value is None: + return "—" + + currency_metrics = { + "npv", + "inflows", + "outflows", + "net", + "total_inflows", + "total_outflows", + "total_net", + } + if metric_name in currency_metrics and currency_code: + return currency_display(value, currency_code) + + percentage_metrics = {"irr", "payback_period"} + if metric_name in percentage_metrics: + if isinstance(value, (int, float)): + return f"{value:.2f}%" + return f"{value}%" + + if isinstance(value, (int, float)): + return f"{value:,.2f}" + + return str(value) + + +def percentage_display(value: Any) -> str: + """Format numeric values as percentages.""" + if value is None: + return "—" + if isinstance(value, (int, float)): + return f"{value:.2f}%" + return f"{value}%" + + +def period_display(value: Any) -> str: + """Format period values in years.""" + if value is None: + return "—" + if isinstance(value, (int, float)): + if value == int(value): + return f"{int(value)} years" + return f"{value:.1f} years" + return str(value) + + +def register_common_filters(templates: Jinja2Templates) -> None: + templates.env.filters["format_datetime"] = format_datetime + templates.env.filters["currency_display"] = currency_display + templates.env.filters["format_metric"] = format_metric + templates.env.filters["percentage_display"] = percentage_display + templates.env.filters["period_display"] = period_display + + +def _sidebar_navigation_for_request(request: Request | None): + if request is None: + return None + + cached = getattr(request.state, "_navigation_sidebar_dto", None) + if cached is not None: + return cached + + session_context = getattr(request.state, "auth_session", None) + if isinstance(session_context, AuthSession): + session = session_context + else: + session = AuthSession.anonymous() + + try: + with UnitOfWork() as uow: + if not uow.navigation: + logger.debug("Navigation repository unavailable for sidebar rendering") + sidebar_dto = None + else: + service = NavigationService(uow.navigation) + sidebar_dto = service.build_sidebar(session=session, request=request) + except Exception: # pragma: no cover - defensive fallback for templates + logger.exception("Failed to build sidebar navigation during template render") + sidebar_dto = None + + setattr(request.state, "_navigation_sidebar_dto", sidebar_dto) + return sidebar_dto + + +def register_navigation_globals(templates: Jinja2Templates) -> None: + templates.env.globals["get_sidebar_navigation"] = _sidebar_navigation_for_request + + +def create_templates() -> Jinja2Templates: + templates = Jinja2Templates(directory="templates") + register_common_filters(templates) + register_navigation_globals(templates) + return templates + + +__all__ = [ + "format_datetime", + "currency_display", + "format_metric", + "percentage_display", + "period_display", + "register_common_filters", + "register_navigation_globals", + "create_templates", +] diff --git a/routes/ui.py b/routes/ui.py index e690dba..c4a77df 100644 --- a/routes/ui.py +++ b/routes/ui.py @@ -1,784 +1,109 @@ -from collections import defaultdict -from datetime import datetime, timezone -from typing import Any, Dict, Optional +from __future__ import annotations from fastapi import APIRouter, Depends, Request -from fastapi.responses import HTMLResponse, JSONResponse -from fastapi.templating import Jinja2Templates -from sqlalchemy.orm import Session +from fastapi.responses import HTMLResponse 
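+# Each router module builds its own Jinja2Templates instance via
+# create_templates(), which registers the shared display filters and the
+# get_sidebar_navigation template global. The sidebar DTO is cached on
+# request.state for the lifetime of a request, so rendering several
+# templates while handling one request builds the navigation only once.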
-from models.capex import Capex -from models.consumption import Consumption -from models.equipment import Equipment -from models.maintenance import Maintenance -from models.opex import Opex -from models.parameters import Parameter -from models.production_output import ProductionOutput -from models.scenario import Scenario -from models.simulation_result import SimulationResult -from routes.dependencies import get_db -from services.reporting import generate_report -from models.currency import Currency -from routes.currencies import DEFAULT_CURRENCY_CODE, _ensure_default_currency -from services.settings import ( - CSS_COLOR_DEFAULTS, - get_css_color_settings, - list_css_env_override_rows, - read_css_color_env_overrides, +from dependencies import require_any_role_html, require_roles_html +from models import User +from routes.template_filters import create_templates + +router = APIRouter(tags=["UI"]) +templates = create_templates() + +READ_ROLES = ("viewer", "analyst", "project_manager", "admin") +MANAGE_ROLES = ("project_manager", "admin") + + +@router.get( + "/ui/simulations", + response_class=HTMLResponse, + include_in_schema=False, + name="ui.simulations", ) - - -CURRENCY_CHOICES: list[Dict[str, Any]] = [ - {"id": "USD", "name": "US Dollar (USD)"}, - {"id": "EUR", "name": "Euro (EUR)"}, - {"id": "CLP", "name": "Chilean Peso (CLP)"}, - {"id": "RMB", "name": "Chinese Yuan (RMB)"}, - {"id": "GBP", "name": "British Pound (GBP)"}, - {"id": "CAD", "name": "Canadian Dollar (CAD)"}, - {"id": "AUD", "name": "Australian Dollar (AUD)"}, -] - -MEASUREMENT_UNITS: list[Dict[str, Any]] = [ - {"id": "tonnes", "name": "Tonnes", "symbol": "t"}, - {"id": "kilograms", "name": "Kilograms", "symbol": "kg"}, - {"id": "pounds", "name": "Pounds", "symbol": "lb"}, - {"id": "liters", "name": "Liters", "symbol": "L"}, - {"id": "cubic_meters", "name": "Cubic Meters", "symbol": "m3"}, - {"id": "kilowatt_hours", "name": "Kilowatt Hours", "symbol": "kWh"}, -] - -router = APIRouter() - -# Set up Jinja2 templates directory -templates = Jinja2Templates(directory="templates") - - -def _context( - request: Request, extra: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - payload: Dict[str, Any] = { - "request": request, - "current_year": datetime.now(timezone.utc).year, - } - if extra: - payload.update(extra) - return payload - - -def _render( +def simulations_dashboard( request: Request, - template_name: str, - extra: Optional[Dict[str, Any]] = None, -): - context = _context(request, extra) - return templates.TemplateResponse(request, template_name, context) - - -def _format_currency(value: float) -> str: - return f"${value:,.2f}" - - -def _format_decimal(value: float) -> str: - return f"{value:,.2f}" - - -def _format_int(value: int) -> str: - return f"{value:,}" - - -def _load_scenarios(db: Session) -> Dict[str, Any]: - scenarios: list[Dict[str, Any]] = [ + _: User = Depends(require_any_role_html(*READ_ROLES)), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "simulations.html", { - "id": item.id, - "name": item.name, - "description": item.description, - } - for item in db.query(Scenario).order_by(Scenario.name).all() - ] - return {"scenarios": scenarios} - - -def _load_parameters(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for param in db.query(Parameter).order_by( - Parameter.scenario_id, Parameter.id - ): - grouped[param.scenario_id].append( - { - "id": param.id, - "name": param.name, - "value": param.value, - "distribution_type": 
param.distribution_type, - "distribution_parameters": param.distribution_parameters, - } - ) - return {"parameters_by_scenario": dict(grouped)} - - -def _load_costs(db: Session) -> Dict[str, Any]: - capex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for capex in db.query(Capex).order_by(Capex.scenario_id, Capex.id).all(): - capex_grouped[int(getattr(capex, "scenario_id"))].append( - { - "id": int(getattr(capex, "id")), - "scenario_id": int(getattr(capex, "scenario_id")), - "amount": float(getattr(capex, "amount", 0.0)), - "description": getattr(capex, "description", "") or "", - "currency_code": getattr(capex, "currency_code", "USD") - or "USD", - } - ) - - opex_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for opex in db.query(Opex).order_by(Opex.scenario_id, Opex.id).all(): - opex_grouped[int(getattr(opex, "scenario_id"))].append( - { - "id": int(getattr(opex, "id")), - "scenario_id": int(getattr(opex, "scenario_id")), - "amount": float(getattr(opex, "amount", 0.0)), - "description": getattr(opex, "description", "") or "", - "currency_code": getattr(opex, "currency_code", "USD") or "USD", - } - ) - - return { - "capex_by_scenario": dict(capex_grouped), - "opex_by_scenario": dict(opex_grouped), - } - - -def _load_currencies(db: Session) -> Dict[str, Any]: - items: list[Dict[str, Any]] = [] - for c in ( - db.query(Currency) - .filter_by(is_active=True) - .order_by(Currency.code) - .all() - ): - items.append( - {"id": c.code, "name": f"{c.name} ({c.code})", "symbol": c.symbol} - ) - if not items: - items.append({"id": "USD", "name": "US Dollar (USD)", "symbol": "$"}) - return {"currency_options": items} - - -def _load_currency_settings(db: Session) -> Dict[str, Any]: - _ensure_default_currency(db) - records = db.query(Currency).order_by(Currency.code).all() - currencies: list[Dict[str, Any]] = [] - for record in records: - code_value = getattr(record, "code") - currencies.append( - { - "id": int(getattr(record, "id")), - "code": code_value, - "name": getattr(record, "name"), - "symbol": getattr(record, "symbol"), - "is_active": bool(getattr(record, "is_active", True)), - "is_default": code_value == DEFAULT_CURRENCY_CODE, - } - ) - - active_count = sum(1 for item in currencies if item["is_active"]) - inactive_count = len(currencies) - active_count - - return { - "currencies": currencies, - "currency_stats": { - "total": len(currencies), - "active": active_count, - "inactive": inactive_count, + "title": "Simulations", }, - "default_currency_code": DEFAULT_CURRENCY_CODE, - "currency_api_base": "/api/currencies", - } - - -def _load_css_settings(db: Session) -> Dict[str, Any]: - variables = get_css_color_settings(db) - env_overrides = read_css_color_env_overrides() - env_rows = list_css_env_override_rows() - env_meta = {row["css_key"]: row for row in env_rows} - return { - "css_variables": variables, - "css_defaults": CSS_COLOR_DEFAULTS, - "css_env_overrides": env_overrides, - "css_env_override_rows": env_rows, - "css_env_override_meta": env_meta, - } - - -def _load_consumption(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(Consumption) - .order_by(Consumption.scenario_id, Consumption.id) - .all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - amount_value = float(getattr(record, "amount", 0.0)) - description = getattr(record, "description", "") or "" - unit_name = getattr(record, "unit_name", None) - unit_symbol 
= getattr(record, "unit_symbol", None) - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "amount": amount_value, - "description": description, - "unit_name": unit_name, - "unit_symbol": unit_symbol, - } - ) - return {"consumption_by_scenario": dict(grouped)} - - -def _load_production(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(ProductionOutput) - .order_by(ProductionOutput.scenario_id, ProductionOutput.id) - .all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - amount_value = float(getattr(record, "amount", 0.0)) - description = getattr(record, "description", "") or "" - unit_name = getattr(record, "unit_name", None) - unit_symbol = getattr(record, "unit_symbol", None) - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "amount": amount_value, - "description": description, - "unit_name": unit_name, - "unit_symbol": unit_symbol, - } - ) - return {"production_by_scenario": dict(grouped)} - - -def _load_equipment(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(Equipment).order_by(Equipment.scenario_id, Equipment.id).all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - name_value = getattr(record, "name", "") or "" - description = getattr(record, "description", "") or "" - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "name": name_value, - "description": description, - } - ) - return {"equipment_by_scenario": dict(grouped)} - - -def _load_maintenance(db: Session) -> Dict[str, Any]: - grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(Maintenance) - .order_by(Maintenance.scenario_id, Maintenance.maintenance_date) - .all() - ): - record_id = int(getattr(record, "id")) - scenario_id = int(getattr(record, "scenario_id")) - equipment_id = int(getattr(record, "equipment_id")) - equipment_obj = getattr(record, "equipment", None) - equipment_name = ( - getattr(equipment_obj, "name", "") if equipment_obj else "" - ) - maintenance_date = getattr(record, "maintenance_date", None) - cost_value = float(getattr(record, "cost", 0.0)) - description = getattr(record, "description", "") or "" - - grouped[scenario_id].append( - { - "id": record_id, - "scenario_id": scenario_id, - "equipment_id": equipment_id, - "equipment_name": equipment_name, - "maintenance_date": ( - maintenance_date.isoformat() if maintenance_date else "" - ), - "cost": cost_value, - "description": description, - } - ) - return {"maintenance_by_scenario": dict(grouped)} - - -def _load_simulations(db: Session) -> Dict[str, Any]: - scenarios: list[Dict[str, Any]] = [ - { - "id": item.id, - "name": item.name, - } - for item in db.query(Scenario).order_by(Scenario.name).all() - ] - - results_grouped: defaultdict[int, list[Dict[str, Any]]] = defaultdict(list) - for record in ( - db.query(SimulationResult) - .order_by(SimulationResult.scenario_id, SimulationResult.iteration) - .all() - ): - scenario_id = int(getattr(record, "scenario_id")) - results_grouped[scenario_id].append( - { - "iteration": int(getattr(record, "iteration")), - "result": float(getattr(record, "result", 0.0)), - } - ) - - runs: list[Dict[str, Any]] = [] - sample_limit = 20 - for item in scenarios: - scenario_id = int(item["id"]) - scenario_results = 
results_grouped.get(scenario_id, []) - summary = ( - generate_report(scenario_results) - if scenario_results - else generate_report([]) - ) - runs.append( - { - "scenario_id": scenario_id, - "scenario_name": item["name"], - "iterations": int(summary.get("count", 0)), - "summary": summary, - "sample_results": scenario_results[:sample_limit], - } - ) - - return { - "simulation_scenarios": scenarios, - "simulation_runs": runs, - } - - -def _load_reporting(db: Session) -> Dict[str, Any]: - scenarios = _load_scenarios(db)["scenarios"] - runs = _load_simulations(db)["simulation_runs"] - - summaries: list[Dict[str, Any]] = [] - runs_by_scenario = {run["scenario_id"]: run for run in runs} - - for scenario in scenarios: - scenario_id = scenario["id"] - run = runs_by_scenario.get(scenario_id) - summary = run["summary"] if run else generate_report([]) - summaries.append( - { - "scenario_id": scenario_id, - "scenario_name": scenario["name"], - "summary": summary, - "iterations": run["iterations"] if run else 0, - } - ) - - return { - "report_summaries": summaries, - } - - -def _load_dashboard(db: Session) -> Dict[str, Any]: - scenarios = _load_scenarios(db)["scenarios"] - parameters_by_scenario = _load_parameters(db)["parameters_by_scenario"] - costs_context = _load_costs(db) - capex_by_scenario = costs_context["capex_by_scenario"] - opex_by_scenario = costs_context["opex_by_scenario"] - consumption_by_scenario = _load_consumption(db)["consumption_by_scenario"] - production_by_scenario = _load_production(db)["production_by_scenario"] - equipment_by_scenario = _load_equipment(db)["equipment_by_scenario"] - maintenance_by_scenario = _load_maintenance(db)["maintenance_by_scenario"] - simulation_context = _load_simulations(db) - simulation_runs = simulation_context["simulation_runs"] - - runs_by_scenario = {run["scenario_id"]: run for run in simulation_runs} - - def sum_amounts( - grouped: Dict[int, list[Dict[str, Any]]], field: str = "amount" - ) -> float: - total = 0.0 - for items in grouped.values(): - for item in items: - value = item.get(field, 0.0) - if isinstance(value, (int, float)): - total += float(value) - return total - - total_capex = sum_amounts(capex_by_scenario) - total_opex = sum_amounts(opex_by_scenario) - total_consumption = sum_amounts(consumption_by_scenario) - total_production = sum_amounts(production_by_scenario) - total_maintenance_cost = sum_amounts(maintenance_by_scenario, field="cost") - - total_parameters = sum( - len(items) for items in parameters_by_scenario.values() - ) - total_equipment = sum( - len(items) for items in equipment_by_scenario.values() - ) - total_maintenance_events = sum( - len(items) for items in maintenance_by_scenario.values() - ) - total_simulation_iterations = sum( - run["iterations"] for run in simulation_runs ) - scenario_rows: list[Dict[str, Any]] = [] - scenario_labels: list[str] = [] - scenario_capex: list[float] = [] - scenario_opex: list[float] = [] - activity_labels: list[str] = [] - activity_production: list[float] = [] - activity_consumption: list[float] = [] - for scenario in scenarios: - scenario_id = scenario["id"] - scenario_name = scenario["name"] - param_count = len(parameters_by_scenario.get(scenario_id, [])) - equipment_count = len(equipment_by_scenario.get(scenario_id, [])) - maintenance_count = len(maintenance_by_scenario.get(scenario_id, [])) - - capex_total = sum( - float(item.get("amount", 0.0)) - for item in capex_by_scenario.get(scenario_id, []) - ) - opex_total = sum( - float(item.get("amount", 0.0)) - for item in 
opex_by_scenario.get(scenario_id, []) - ) - consumption_total = sum( - float(item.get("amount", 0.0)) - for item in consumption_by_scenario.get(scenario_id, []) - ) - production_total = sum( - float(item.get("amount", 0.0)) - for item in production_by_scenario.get(scenario_id, []) - ) - - run = runs_by_scenario.get(scenario_id) - summary = run["summary"] if run else generate_report([]) - iterations = run["iterations"] if run else 0 - mean_value = float(summary.get("mean", 0.0)) - - scenario_rows.append( - { - "scenario_name": scenario_name, - "parameter_count": param_count, - "parameter_display": _format_int(param_count), - "equipment_count": equipment_count, - "equipment_display": _format_int(equipment_count), - "capex_total": capex_total, - "capex_display": _format_currency(capex_total), - "opex_total": opex_total, - "opex_display": _format_currency(opex_total), - "production_total": production_total, - "production_display": _format_decimal(production_total), - "consumption_total": consumption_total, - "consumption_display": _format_decimal(consumption_total), - "maintenance_count": maintenance_count, - "maintenance_display": _format_int(maintenance_count), - "iterations": iterations, - "iterations_display": _format_int(iterations), - "simulation_mean": mean_value, - "simulation_mean_display": _format_decimal(mean_value), - } - ) - - scenario_labels.append(scenario_name) - scenario_capex.append(capex_total) - scenario_opex.append(opex_total) - - activity_labels.append(scenario_name) - activity_production.append(production_total) - activity_consumption.append(consumption_total) - - scenario_rows.sort(key=lambda row: row["scenario_name"].lower()) - - all_simulation_results = [ - {"result": float(getattr(item, "result", 0.0))} - for item in db.query(SimulationResult).all() - ] - overall_report = generate_report(all_simulation_results) - - overall_report_metrics = [ +@router.get( + "/ui/reporting", + response_class=HTMLResponse, + include_in_schema=False, + name="ui.reporting", +) +def reporting_dashboard( + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "reporting.html", { - "label": "Runs", - "value": _format_int(int(overall_report.get("count", 0))), + "title": "Reporting", }, - { - "label": "Mean", - "value": _format_decimal(float(overall_report.get("mean", 0.0))), - }, - { - "label": "Median", - "value": _format_decimal(float(overall_report.get("median", 0.0))), - }, - { - "label": "Std Dev", - "value": _format_decimal(float(overall_report.get("std_dev", 0.0))), - }, - { - "label": "95th Percentile", - "value": _format_decimal( - float(overall_report.get("percentile_95", 0.0)) - ), - }, - { - "label": "VaR (95%)", - "value": _format_decimal( - float(overall_report.get("value_at_risk_95", 0.0)) - ), - }, - { - "label": "Expected Shortfall (95%)", - "value": _format_decimal( - float(overall_report.get("expected_shortfall_95", 0.0)) - ), - }, - ] - - recent_simulations: list[Dict[str, Any]] = [ - { - "scenario_name": run["scenario_name"], - "iterations": run["iterations"], - "iterations_display": _format_int(run["iterations"]), - "mean_display": _format_decimal( - float(run["summary"].get("mean", 0.0)) - ), - "p95_display": _format_decimal( - float(run["summary"].get("percentile_95", 0.0)) - ), - } - for run in simulation_runs - if run["iterations"] > 0 - ] - recent_simulations.sort(key=lambda item: item["iterations"], reverse=True) - recent_simulations = recent_simulations[:5] - - 
upcoming_maintenance: list[Dict[str, Any]] = [] - for record in ( - db.query(Maintenance) - .order_by(Maintenance.maintenance_date.asc()) - .limit(5) - .all() - ): - maintenance_date = getattr(record, "maintenance_date", None) - upcoming_maintenance.append( - { - "scenario_name": getattr( - getattr(record, "scenario", None), "name", "Unknown" - ), - "equipment_name": getattr( - getattr(record, "equipment", None), "name", "Unknown" - ), - "date_display": ( - maintenance_date.strftime("%Y-%m-%d") - if maintenance_date - else "—" - ), - "cost_display": _format_currency( - float(getattr(record, "cost", 0.0)) - ), - "description": getattr(record, "description", "") or "—", - } - ) - - cost_chart_has_data = any(value > 0 for value in scenario_capex) or any( - value > 0 for value in scenario_opex ) - activity_chart_has_data = any( - value > 0 for value in activity_production - ) or any(value > 0 for value in activity_consumption) - scenario_cost_chart: Dict[str, list[Any]] = { - "labels": scenario_labels, - "capex": scenario_capex, - "opex": scenario_opex, - } - scenario_activity_chart: Dict[str, list[Any]] = { - "labels": activity_labels, - "production": activity_production, - "consumption": activity_consumption, - } - summary_metrics = [ - {"label": "Active Scenarios", "value": _format_int(len(scenarios))}, - {"label": "Parameters", "value": _format_int(total_parameters)}, - {"label": "CAPEX Total", "value": _format_currency(total_capex)}, - {"label": "OPEX Total", "value": _format_currency(total_opex)}, - {"label": "Equipment Assets", "value": _format_int(total_equipment)}, +@router.get( + "/ui/settings", + response_class=HTMLResponse, + include_in_schema=False, + name="ui.settings", +) +def settings_page( + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "settings.html", { - "label": "Maintenance Events", - "value": _format_int(total_maintenance_events), + "title": "Settings", }, - {"label": "Consumption", "value": _format_decimal(total_consumption)}, - {"label": "Production", "value": _format_decimal(total_production)}, + ) + + +@router.get( + "/theme-settings", + response_class=HTMLResponse, + include_in_schema=False, + name="ui.theme_settings", +) +def theme_settings_page( + request: Request, + _: User = Depends(require_any_role_html(*READ_ROLES)), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "theme_settings.html", { - "label": "Simulation Iterations", - "value": _format_int(total_simulation_iterations), + "title": "Theme Settings", }, + ) + + +@router.get( + "/ui/currencies", + response_class=HTMLResponse, + include_in_schema=False, + name="ui.currencies", +) +def currencies_page( + request: Request, + _: User = Depends(require_roles_html(*MANAGE_ROLES)), +) -> HTMLResponse: + return templates.TemplateResponse( + request, + "currencies.html", { - "label": "Maintenance Cost", - "value": _format_currency(total_maintenance_cost), + "title": "Currency Management", }, - ] - - return { - "summary_metrics": summary_metrics, - "scenario_rows": scenario_rows, - "overall_report_metrics": overall_report_metrics, - "recent_simulations": recent_simulations, - "upcoming_maintenance": upcoming_maintenance, - "scenario_cost_chart": scenario_cost_chart, - "scenario_activity_chart": scenario_activity_chart, - "cost_chart_has_data": cost_chart_has_data, - "activity_chart_has_data": activity_chart_has_data, - "report_available": overall_report.get("count", 0) > 0, - } - - 
-@router.get("/", response_class=HTMLResponse) -async def dashboard_root(request: Request, db: Session = Depends(get_db)): - """Render the primary dashboard landing page.""" - return _render(request, "Dashboard.html", _load_dashboard(db)) - - -@router.get("/ui/dashboard", response_class=HTMLResponse) -async def dashboard(request: Request, db: Session = Depends(get_db)): - """Render the legacy dashboard route for backward compatibility.""" - return _render(request, "Dashboard.html", _load_dashboard(db)) - - -@router.get("/ui/dashboard/data", response_class=JSONResponse) -async def dashboard_data(db: Session = Depends(get_db)) -> JSONResponse: - """Expose dashboard aggregates as JSON for client-side refreshes.""" - return JSONResponse(_load_dashboard(db)) - - -@router.get("/ui/scenarios", response_class=HTMLResponse) -async def scenario_form(request: Request, db: Session = Depends(get_db)): - """Render the scenario creation form.""" - context = _load_scenarios(db) - return _render(request, "ScenarioForm.html", context) - - -@router.get("/ui/parameters", response_class=HTMLResponse) -async def parameter_form(request: Request, db: Session = Depends(get_db)): - """Render the parameter input form.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_parameters(db)) - return _render(request, "ParameterInput.html", context) - - -@router.get("/ui/costs", response_class=HTMLResponse) -async def costs_view(request: Request, db: Session = Depends(get_db)): - """Render the costs view with CAPEX and OPEX data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_costs(db)) - context.update(_load_currencies(db)) - return _render(request, "costs.html", context) - - -@router.get("/ui/consumption", response_class=HTMLResponse) -async def consumption_view(request: Request, db: Session = Depends(get_db)): - """Render the consumption view with scenario consumption data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_consumption(db)) - context["unit_options"] = MEASUREMENT_UNITS - return _render(request, "consumption.html", context) - - -@router.get("/ui/production", response_class=HTMLResponse) -async def production_view(request: Request, db: Session = Depends(get_db)): - """Render the production view with scenario production data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_production(db)) - context["unit_options"] = MEASUREMENT_UNITS - return _render(request, "production.html", context) - - -@router.get("/ui/equipment", response_class=HTMLResponse) -async def equipment_view(request: Request, db: Session = Depends(get_db)): - """Render the equipment view with scenario equipment data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_equipment(db)) - return _render(request, "equipment.html", context) - - -@router.get("/ui/maintenance", response_class=HTMLResponse) -async def maintenance_view(request: Request, db: Session = Depends(get_db)): - """Render the maintenance view with scenario maintenance data.""" - context: Dict[str, Any] = {} - context.update(_load_scenarios(db)) - context.update(_load_equipment(db)) - context.update(_load_maintenance(db)) - return _render(request, "maintenance.html", context) - - -@router.get("/ui/simulations", response_class=HTMLResponse) -async def simulations_view(request: Request, db: Session = Depends(get_db)): - """Render the simulations view with 
scenario information and recent runs.""" - return _render(request, "simulations.html", _load_simulations(db)) - - -@router.get("/ui/reporting", response_class=HTMLResponse) -async def reporting_view(request: Request, db: Session = Depends(get_db)): - """Render the reporting view with scenario KPI summaries.""" - return _render(request, "reporting.html", _load_reporting(db)) - - -@router.get("/ui/settings", response_class=HTMLResponse) -async def settings_view(request: Request, db: Session = Depends(get_db)): - """Render the settings landing page.""" - context = _load_css_settings(db) - return _render(request, "settings.html", context) - - -@router.get("/ui/currencies", response_class=HTMLResponse) -async def currencies_view(request: Request, db: Session = Depends(get_db)): - """Render the currency administration page with full currency context.""" - context = _load_currency_settings(db) - return _render(request, "currencies.html", context) - - -@router.get("/login", response_class=HTMLResponse) -async def login_page(request: Request): - return _render(request, "login.html") - - -@router.get("/register", response_class=HTMLResponse) -async def register_page(request: Request): - return _render(request, "register.html") - - -@router.get("/profile", response_class=HTMLResponse) -async def profile_page(request: Request): - return _render(request, "profile.html") - - -@router.get("/forgot-password", response_class=HTMLResponse) -async def forgot_password_page(request: Request): - return _render(request, "forgot_password.html") - - -@router.get("/theme-settings", response_class=HTMLResponse) -async def theme_settings_page(request: Request, db: Session = Depends(get_db)): - """Render the theme settings page.""" - context = _load_css_settings(db) - return _render(request, "theme_settings.html", context) + ) diff --git a/routes/users.py b/routes/users.py deleted file mode 100644 index 5de7092..0000000 --- a/routes/users.py +++ /dev/null @@ -1,107 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session - -from config.database import get_db -from models.user import User -from services.security import create_access_token, get_current_user -from schemas.user import ( - PasswordReset, - PasswordResetRequest, - UserCreate, - UserInDB, - UserLogin, - UserUpdate, -) - -router = APIRouter(prefix="/users", tags=["users"]) - - -@router.post("/register", response_model=UserInDB, status_code=status.HTTP_201_CREATED) -async def register_user(user: UserCreate, db: Session = Depends(get_db)): - db_user = db.query(User).filter(User.username == user.username).first() - if db_user: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered") - db_user = db.query(User).filter(User.email == user.email).first() - if db_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") - - # Get or create default role - from models.role import Role - default_role = db.query(Role).filter(Role.name == "user").first() - if not default_role: - default_role = Role(name="user") - db.add(default_role) - db.commit() - db.refresh(default_role) - - new_user = User(username=user.username, email=user.email, - role_id=default_role.id) - new_user.set_password(user.password) - db.add(new_user) - db.commit() - db.refresh(new_user) - return new_user - - -@router.post("/login") -async def login_user(user: UserLogin, db: Session = Depends(get_db)): - db_user = db.query(User).filter(User.username == 
user.username).first() - if not db_user or not db_user.check_password(user.password): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password") - access_token = create_access_token(subject=db_user.username) - return {"access_token": access_token, "token_type": "bearer"} - - -@router.get("/me") -async def read_users_me(current_user: User = Depends(get_current_user)): - return current_user - - -@router.put("/me", response_model=UserInDB) -async def update_user_me(user_update: UserUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): - if user_update.username and user_update.username != current_user.username: - existing_user = db.query(User).filter( - User.username == user_update.username).first() - if existing_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken") - setattr(current_user, "username", user_update.username) - - if user_update.email and user_update.email != current_user.email: - existing_user = db.query(User).filter( - User.email == user_update.email).first() - if existing_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") - setattr(current_user, "email", user_update.email) - - if user_update.password: - current_user.set_password(user_update.password) - - db.add(current_user) - db.commit() - db.refresh(current_user) - return current_user - - -@router.post("/forgot-password") -async def forgot_password(request: PasswordResetRequest): - # In a real application, this would send an email with a reset token - return {"message": "Password reset email sent (not really)"} - - -@router.post("/reset-password") -async def reset_password(request: PasswordReset, db: Session = Depends(get_db)): - # In a real application, the token would be verified - user = db.query(User).filter(User.username == - request.token).first() # Use token as username for test - if not user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token or user") - user.set_password(request.new_password) - db.add(user) - db.commit() - return {"message": "Password has been reset successfully"} diff --git a/schemas/auth.py b/schemas/auth.py new file mode 100644 index 0000000..3a16191 --- /dev/null +++ b/schemas/auth.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator + + +class FormModel(BaseModel): + """Base Pydantic model for HTML form submissions.""" + + model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) + + +class RegistrationForm(FormModel): + username: str = Field(min_length=3, max_length=128) + email: str = Field(min_length=5, max_length=255) + password: str = Field(min_length=8, max_length=256) + confirm_password: str + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + if "@" not in value or value.startswith("@") or value.endswith("@"): + raise ValueError("Invalid email address.") + local, domain = value.split("@", 1) + if not local or "." 
not in domain: + raise ValueError("Invalid email address.") + return value.lower() + + @field_validator("confirm_password") + @classmethod + def passwords_match(cls, value: str, info: ValidationInfo) -> str: + password = info.data.get("password") + if password != value: + raise ValueError("Passwords do not match.") + return value + + +class LoginForm(FormModel): + username: str = Field(min_length=1, max_length=255) + password: str = Field(min_length=1, max_length=256) + + +class PasswordResetRequestForm(FormModel): + email: str = Field(min_length=5, max_length=255) + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + if "@" not in value or value.startswith("@") or value.endswith("@"): + raise ValueError("Invalid email address.") + local, domain = value.split("@", 1) + if not local or "." not in domain: + raise ValueError("Invalid email address.") + return value.lower() + + +class PasswordResetForm(FormModel): + token: str = Field(min_length=1) + password: str = Field(min_length=8, max_length=256) + confirm_password: str + + @field_validator("confirm_password") + @classmethod + def reset_passwords_match(cls, value: str, info: ValidationInfo) -> str: + password = info.data.get("password") + if password != value: + raise ValueError("Passwords do not match.") + return value diff --git a/schemas/calculations.py b/schemas/calculations.py new file mode 100644 index 0000000..407812f --- /dev/null +++ b/schemas/calculations.py @@ -0,0 +1,346 @@ +"""Pydantic schemas for calculation workflows.""" + +from __future__ import annotations + +from typing import List + +from pydantic import BaseModel, Field, PositiveFloat, ValidationError, field_validator + +from services.pricing import PricingResult + + +class ImpurityInput(BaseModel): + """Impurity configuration row supplied by the client.""" + + name: str = Field(..., min_length=1) + value: float | None = Field(None, ge=0) + threshold: float | None = Field(None, ge=0) + penalty: float | None = Field(None) + + @field_validator("name") + @classmethod + def _normalise_name(cls, value: str) -> str: + return value.strip() + + +class ProfitabilityCalculationRequest(BaseModel): + """Request payload for profitability calculations.""" + + metal: str = Field(..., min_length=1) + ore_tonnage: PositiveFloat + head_grade_pct: float = Field(..., gt=0, le=100) + recovery_pct: float = Field(..., gt=0, le=100) + payable_pct: float | None = Field(None, gt=0, le=100) + reference_price: PositiveFloat + treatment_charge: float = Field(0, ge=0) + smelting_charge: float = Field(0, ge=0) + moisture_pct: float = Field(0, ge=0, le=100) + moisture_threshold_pct: float | None = Field(None, ge=0, le=100) + moisture_penalty_per_pct: float | None = None + premiums: float = Field(0) + fx_rate: PositiveFloat = Field(1) + currency_code: str | None = Field(None, min_length=3, max_length=3) + opex: float = Field(0, ge=0) + sustaining_capex: float = Field(0, ge=0) + capex: float = Field(0, ge=0) + discount_rate: float | None = Field(None, ge=0, le=100) + periods: int = Field(10, ge=1, le=120) + impurities: List[ImpurityInput] = Field(default_factory=list) + + @field_validator("currency_code") + @classmethod + def _uppercase_currency(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip().upper() + + @field_validator("metal") + @classmethod + def _normalise_metal(cls, value: str) -> str: + return value.strip().lower() + + +class ProfitabilityCosts(BaseModel): + """Aggregated cost components for profitability 
output.""" + + opex_total: float + sustaining_capex_total: float + capex: float + + +class ProfitabilityMetrics(BaseModel): + """Financial KPIs yielded by the profitability calculation.""" + + npv: float | None + irr: float | None + payback_period: float | None + margin: float | None + + +class CashFlowEntry(BaseModel): + """Normalized cash flow row for reporting and charting.""" + + period: int + revenue: float + opex: float + sustaining_capex: float + net: float + + +class ProfitabilityCalculationResult(BaseModel): + """Response body summarizing profitability calculation outputs.""" + + pricing: PricingResult + costs: ProfitabilityCosts + metrics: ProfitabilityMetrics + cash_flows: list[CashFlowEntry] + currency: str | None + + +class CapexComponentInput(BaseModel): + """Capex component entry supplied by the UI.""" + + id: int | None = Field(default=None, ge=1) + name: str = Field(..., min_length=1) + category: str = Field(..., min_length=1) + amount: float = Field(..., ge=0) + currency: str | None = Field(None, min_length=3, max_length=3) + spend_year: int | None = Field(None, ge=0, le=120) + notes: str | None = Field(None, max_length=500) + + @field_validator("currency") + @classmethod + def _uppercase_currency(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip().upper() + + @field_validator("category") + @classmethod + def _normalise_category(cls, value: str) -> str: + return value.strip().lower() + + @field_validator("name") + @classmethod + def _trim_name(cls, value: str) -> str: + return value.strip() + + +class CapexParameters(BaseModel): + """Global parameters applied to capex calculations.""" + + currency_code: str | None = Field(None, min_length=3, max_length=3) + contingency_pct: float | None = Field(0, ge=0, le=100) + discount_rate_pct: float | None = Field(None, ge=0, le=100) + evaluation_horizon_years: int | None = Field(10, ge=1, le=100) + + @field_validator("currency_code") + @classmethod + def _uppercase_currency(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip().upper() + + +class CapexCalculationOptions(BaseModel): + """Optional behaviour flags for capex calculations.""" + + persist: bool = False + + +class CapexCalculationRequest(BaseModel): + """Request payload for capex aggregation.""" + + components: List[CapexComponentInput] = Field(default_factory=list) + parameters: CapexParameters = Field( + default_factory=CapexParameters, # type: ignore[arg-type] + ) + options: CapexCalculationOptions = Field( + default_factory=CapexCalculationOptions, # type: ignore[arg-type] + ) + + +class CapexCategoryBreakdown(BaseModel): + """Breakdown entry describing category totals.""" + + category: str + amount: float = Field(..., ge=0) + share: float | None = Field(None, ge=0, le=100) + + +class CapexTotals(BaseModel): + """Aggregated totals for capex workflows.""" + + overall: float = Field(..., ge=0) + contingency_pct: float = Field(0, ge=0, le=100) + contingency_amount: float = Field(..., ge=0) + with_contingency: float = Field(..., ge=0) + by_category: List[CapexCategoryBreakdown] = Field(default_factory=list) + + +class CapexTimelineEntry(BaseModel): + """Spend profile entry grouped by year.""" + + year: int + spend: float = Field(..., ge=0) + cumulative: float = Field(..., ge=0) + + +class CapexCalculationResult(BaseModel): + """Response body for capex calculations.""" + + totals: CapexTotals + timeline: List[CapexTimelineEntry] = Field(default_factory=list) + components: 
List[CapexComponentInput] = Field(default_factory=list) + parameters: CapexParameters + options: CapexCalculationOptions + currency: str | None + + +class OpexComponentInput(BaseModel): + """opex component entry supplied by the UI.""" + + id: int | None = Field(default=None, ge=1) + name: str = Field(..., min_length=1) + category: str = Field(..., min_length=1) + unit_cost: float = Field(..., ge=0) + quantity: float = Field(..., ge=0) + frequency: str = Field(..., min_length=1) + currency: str | None = Field(None, min_length=3, max_length=3) + period_start: int | None = Field(None, ge=0, le=240) + period_end: int | None = Field(None, ge=0, le=240) + notes: str | None = Field(None, max_length=500) + + @field_validator("currency") + @classmethod + def _uppercase_currency(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip().upper() + + @field_validator("category") + @classmethod + def _normalise_category(cls, value: str) -> str: + return value.strip().lower() + + @field_validator("frequency") + @classmethod + def _normalise_frequency(cls, value: str) -> str: + return value.strip().lower() + + @field_validator("name") + @classmethod + def _trim_name(cls, value: str) -> str: + return value.strip() + + +class OpexParameters(BaseModel): + """Global parameters applied to opex calculations.""" + + currency_code: str | None = Field(None, min_length=3, max_length=3) + escalation_pct: float | None = Field(None, ge=0, le=100) + discount_rate_pct: float | None = Field(None, ge=0, le=100) + evaluation_horizon_years: int | None = Field(10, ge=1, le=100) + apply_escalation: bool = True + + @field_validator("currency_code") + @classmethod + def _uppercase_currency(cls, value: str | None) -> str | None: + if value is None: + return None + return value.strip().upper() + + +class OpexOptions(BaseModel): + """Optional behaviour flags for opex calculations.""" + + persist: bool = False + snapshot_notes: str | None = Field(None, max_length=500) + + +class OpexCalculationRequest(BaseModel): + """Request payload for opex aggregation.""" + + components: List[OpexComponentInput] = Field( + default_factory=list) + parameters: OpexParameters = Field( + default_factory=OpexParameters, # type: ignore[arg-type] + ) + options: OpexOptions = Field( + default_factory=OpexOptions, # type: ignore[arg-type] + ) + + +class OpexCategoryBreakdown(BaseModel): + """Category breakdown for opex totals.""" + + category: str + annual_cost: float = Field(..., ge=0) + share: float | None = Field(None, ge=0, le=100) + + +class OpexTimelineEntry(BaseModel): + """Timeline entry representing cost over evaluation periods.""" + + period: int + base_cost: float = Field(..., ge=0) + escalated_cost: float | None = Field(None, ge=0) + + +class OpexMetrics(BaseModel): + """Derived KPIs for opex outputs.""" + + annual_average: float | None + cost_per_ton: float | None + + +class OpexTotals(BaseModel): + """Aggregated totals for opex.""" + + overall_annual: float = Field(..., ge=0) + escalated_total: float | None = Field(None, ge=0) + escalation_pct: float | None = Field(None, ge=0, le=100) + by_category: List[OpexCategoryBreakdown] = Field( + default_factory=list + ) + + +class OpexCalculationResult(BaseModel): + """Response body summarising opex calculations.""" + + totals: OpexTotals + timeline: List[OpexTimelineEntry] = Field(default_factory=list) + metrics: OpexMetrics + components: List[OpexComponentInput] = Field( + default_factory=list) + parameters: OpexParameters + options: OpexOptions + currency: 
str | None + + +__all__ = [ + "ImpurityInput", + "ProfitabilityCalculationRequest", + "ProfitabilityCosts", + "ProfitabilityMetrics", + "CashFlowEntry", + "ProfitabilityCalculationResult", + "CapexComponentInput", + "CapexParameters", + "CapexCalculationOptions", + "CapexCalculationRequest", + "CapexCategoryBreakdown", + "CapexTotals", + "CapexTimelineEntry", + "CapexCalculationResult", + "OpexComponentInput", + "OpexParameters", + "OpexOptions", + "OpexCalculationRequest", + "OpexCategoryBreakdown", + "OpexTimelineEntry", + "OpexMetrics", + "OpexTotals", + "OpexCalculationResult", + "ValidationError", +] diff --git a/schemas/exports.py b/schemas/exports.py new file mode 100644 index 0000000..f1a6105 --- /dev/null +++ b/schemas/exports.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from enum import Enum +from typing import Literal + +from pydantic import BaseModel, ConfigDict, field_validator + +from services.export_query import ProjectExportFilters, ScenarioExportFilters + + +class ExportFormat(str, Enum): + CSV = "csv" + XLSX = "xlsx" + + +class BaseExportRequest(BaseModel): + format: ExportFormat = ExportFormat.CSV + include_metadata: bool = False + + model_config = ConfigDict(extra="forbid") + + +class ProjectExportRequest(BaseExportRequest): + filters: ProjectExportFilters | None = None + + @field_validator("filters", mode="before") + @classmethod + def validate_filters(cls, value: ProjectExportFilters | None) -> ProjectExportFilters | None: + if value is None: + return None + if isinstance(value, ProjectExportFilters): + return value + return ProjectExportFilters(**value) + + +class ScenarioExportRequest(BaseExportRequest): + filters: ScenarioExportFilters | None = None + + @field_validator("filters", mode="before") + @classmethod + def validate_filters(cls, value: ScenarioExportFilters | None) -> ScenarioExportFilters | None: + if value is None: + return None + if isinstance(value, ScenarioExportFilters): + return value + return ScenarioExportFilters(**value) + + +class ExportTicket(BaseModel): + token: str + format: ExportFormat + resource: Literal["projects", "scenarios"] + + model_config = ConfigDict(extra="forbid") + + +class ExportResponse(BaseModel): + ticket: ExportTicket + + model_config = ConfigDict(extra="forbid") + + +__all__ = [ + "ExportFormat", + "ProjectExportRequest", + "ScenarioExportRequest", + "ExportTicket", + "ExportResponse", +] diff --git a/schemas/imports.py b/schemas/imports.py new file mode 100644 index 0000000..e9f3895 --- /dev/null +++ b/schemas/imports.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import Any, Mapping +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from models import MiningOperationType, ResourceType, ScenarioStatus +from services.currency import CurrencyValidationError, normalise_currency + +PreviewStateLiteral = Literal["new", "update", "skip", "error"] + + +def _normalise_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value.strip() + return str(value).strip() + + +def _strip_or_none(value: Any | None) -> str | None: + if value is None: + return None + text = _normalise_string(value) + return text or None + + +def _coerce_enum(value: Any, enum_cls: Any, aliases: Mapping[str, Any]) -> Any: + if value is None: + return value + if isinstance(value, enum_cls): + return value + text = _normalise_string(value).lower() + if not text: + return None + if 
text in aliases: + return aliases[text] + try: + return enum_cls(text) + except ValueError as exc: # pragma: no cover - surfaced by Pydantic + raise ValueError( + f"Invalid value '{value}' for {enum_cls.__name__}") from exc + + +OPERATION_TYPE_ALIASES: dict[str, MiningOperationType] = { + "open pit": MiningOperationType.OPEN_PIT, + "openpit": MiningOperationType.OPEN_PIT, + "underground": MiningOperationType.UNDERGROUND, + "in-situ leach": MiningOperationType.IN_SITU_LEACH, + "in situ": MiningOperationType.IN_SITU_LEACH, + "placer": MiningOperationType.PLACER, + "quarry": MiningOperationType.QUARRY, + "mountaintop removal": MiningOperationType.MOUNTAINTOP_REMOVAL, + "other": MiningOperationType.OTHER, +} + + +SCENARIO_STATUS_ALIASES: dict[str, ScenarioStatus] = { + "draft": ScenarioStatus.DRAFT, + "active": ScenarioStatus.ACTIVE, + "archived": ScenarioStatus.ARCHIVED, +} + + +RESOURCE_TYPE_ALIASES: dict[str, ResourceType] = { + key.replace("_", " ").lower(): value for key, value in ResourceType.__members__.items() +} +RESOURCE_TYPE_ALIASES.update( + {value.value.replace("_", " ").lower(): value for value in ResourceType} +) + + +class ProjectImportRow(BaseModel): + name: str + location: str | None = None + operation_type: MiningOperationType + description: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("name", mode="before") + @classmethod + def validate_name(cls, value: Any) -> str: + text = _normalise_string(value) + if not text: + raise ValueError("Project name is required") + return text + + @field_validator("location", "description", mode="before") + @classmethod + def optional_text(cls, value: Any | None) -> str | None: + return _strip_or_none(value) + + @field_validator("operation_type", mode="before") + @classmethod + def map_operation_type(cls, value: Any) -> MiningOperationType | None: + return _coerce_enum(value, MiningOperationType, OPERATION_TYPE_ALIASES) + + +class ScenarioImportRow(BaseModel): + project_name: str + name: str + status: ScenarioStatus = ScenarioStatus.DRAFT + start_date: date | None = None + end_date: date | None = None + discount_rate: float | None = None + currency: str | None = None + primary_resource: ResourceType | None = None + description: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("project_name", "name", mode="before") + @classmethod + def validate_required_text(cls, value: Any, info) -> str: + text = _normalise_string(value) + if not text: + raise ValueError( + f"{info.field_name.replace('_', ' ').title()} is required") + return text + + @field_validator("status", mode="before") + @classmethod + def map_status(cls, value: Any) -> ScenarioStatus | None: + return _coerce_enum(value, ScenarioStatus, SCENARIO_STATUS_ALIASES) + + @field_validator("primary_resource", mode="before") + @classmethod + def map_resource(cls, value: Any) -> ResourceType | None: + return _coerce_enum(value, ResourceType, RESOURCE_TYPE_ALIASES) + + @field_validator("description", mode="before") + @classmethod + def optional_description(cls, value: Any | None) -> str | None: + return _strip_or_none(value) + + @field_validator("currency", mode="before") + @classmethod + def normalise_currency(cls, value: Any | None) -> str | None: + text = _strip_or_none(value) + if text is None: + return None + try: + return normalise_currency(text) + except 
CurrencyValidationError as exc: + raise ValueError(str(exc)) from exc + + @field_validator("discount_rate", mode="before") + @classmethod + def coerce_discount_rate(cls, value: Any | None) -> float | None: + if value is None: + return None + if isinstance(value, (int, float)): + return float(value) + text = _normalise_string(value) + if not text: + return None + if text.endswith("%"): + text = text[:-1] + try: + return float(text) + except ValueError as exc: + raise ValueError("Discount rate must be numeric") from exc + + @model_validator(mode="after") + def validate_dates(self) -> "ScenarioImportRow": + if self.start_date and self.end_date and self.start_date > self.end_date: + raise ValueError("End date must be on or after start date") + return self + + +class ImportRowErrorModel(BaseModel): + row_number: int + field: str | None = None + message: str + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportPreviewRowIssueModel(BaseModel): + message: str + field: str | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportPreviewRowIssuesModel(BaseModel): + row_number: int + state: PreviewStateLiteral | None = None + issues: list[ImportPreviewRowIssueModel] = Field(default_factory=list) + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportPreviewSummaryModel(BaseModel): + total_rows: int + accepted: int + skipped: int + errored: int + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportPreviewRow(BaseModel): + row_number: int + data: ProjectImportRow + state: PreviewStateLiteral + issues: list[str] = Field(default_factory=list) + context: dict[str, Any] | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ScenarioImportPreviewRow(BaseModel): + row_number: int + data: ScenarioImportRow + state: PreviewStateLiteral + issues: list[str] = Field(default_factory=list) + context: dict[str, Any] | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportPreviewResponse(BaseModel): + rows: list[ProjectImportPreviewRow] + summary: ImportPreviewSummaryModel + row_issues: list[ImportPreviewRowIssuesModel] = Field(default_factory=list) + parser_errors: list[ImportRowErrorModel] = Field(default_factory=list) + stage_token: str | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ScenarioImportPreviewResponse(BaseModel): + rows: list[ScenarioImportPreviewRow] + summary: ImportPreviewSummaryModel + row_issues: list[ImportPreviewRowIssuesModel] = Field(default_factory=list) + parser_errors: list[ImportRowErrorModel] = Field(default_factory=list) + stage_token: str | None = None + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportCommitSummaryModel(BaseModel): + created: int + updated: int + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportCommitRow(BaseModel): + row_number: int + data: ProjectImportRow + context: dict[str, Any] + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ScenarioImportCommitRow(BaseModel): + row_number: int + data: ScenarioImportRow + context: dict[str, Any] + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ProjectImportCommitResponse(BaseModel): + token: str + rows: list[ProjectImportCommitRow] + summary: ImportCommitSummaryModel + + model_config = ConfigDict(from_attributes=True, 
extra="forbid") + + +class ScenarioImportCommitResponse(BaseModel): + token: str + rows: list[ScenarioImportCommitRow] + summary: ImportCommitSummaryModel + + model_config = ConfigDict(from_attributes=True, extra="forbid") + + +class ImportCommitRequest(BaseModel): + token: str + + model_config = ConfigDict(extra="forbid") diff --git a/schemas/navigation.py b/schemas/navigation.py new file mode 100644 index 0000000..4701e9b --- /dev/null +++ b/schemas/navigation.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from datetime import datetime +from typing import List + +from pydantic import BaseModel, Field + + +class NavigationLinkSchema(BaseModel): + id: int + label: str + href: str + match_prefix: str | None = Field(default=None) + icon: str | None = Field(default=None) + tooltip: str | None = Field(default=None) + is_external: bool = Field(default=False) + children: List["NavigationLinkSchema"] = Field(default_factory=list) + + +class NavigationGroupSchema(BaseModel): + id: int + label: str + icon: str | None = Field(default=None) + tooltip: str | None = Field(default=None) + links: List[NavigationLinkSchema] = Field(default_factory=list) + + +class NavigationSidebarResponse(BaseModel): + groups: List[NavigationGroupSchema] + roles: List[str] = Field(default_factory=list) + generated_at: datetime + + +NavigationLinkSchema.model_rebuild() +NavigationGroupSchema.model_rebuild() +NavigationSidebarResponse.model_rebuild() diff --git a/schemas/project.py b/schemas/project.py new file mode 100644 index 0000000..1b0107d --- /dev/null +++ b/schemas/project.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel, ConfigDict + +from models import MiningOperationType + + +class ProjectBase(BaseModel): + name: str + location: str | None = None + operation_type: MiningOperationType + description: str | None = None + + model_config = ConfigDict(extra="forbid") + + +class ProjectCreate(ProjectBase): + pass + + +class ProjectUpdate(BaseModel): + name: str | None = None + location: str | None = None + operation_type: MiningOperationType | None = None + description: str | None = None + + model_config = ConfigDict(extra="forbid") + + +class ProjectRead(ProjectBase): + id: int + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) diff --git a/schemas/scenario.py b/schemas/scenario.py new file mode 100644 index 0000000..4295868 --- /dev/null +++ b/schemas/scenario.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from datetime import date, datetime + +from pydantic import BaseModel, ConfigDict, field_validator, model_validator + +from models import ResourceType, ScenarioStatus +from services.currency import CurrencyValidationError, normalise_currency + + +class ScenarioBase(BaseModel): + name: str + description: str | None = None + status: ScenarioStatus = ScenarioStatus.DRAFT + start_date: date | None = None + end_date: date | None = None + discount_rate: float | None = None + currency: str | None = None + primary_resource: ResourceType | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("currency") + @classmethod + def normalise_currency(cls, value: str | None) -> str | None: + if value is None: + return None + candidate = value if isinstance(value, str) else str(value) + candidate = candidate.strip() + if not candidate: + return None + try: + return normalise_currency(candidate) + except CurrencyValidationError as exc: + raise ValueError(str(exc)) 
from exc + + +class ScenarioCreate(ScenarioBase): + pass + + +class ScenarioUpdate(BaseModel): + name: str | None = None + description: str | None = None + status: ScenarioStatus | None = None + start_date: date | None = None + end_date: date | None = None + discount_rate: float | None = None + currency: str | None = None + primary_resource: ResourceType | None = None + + model_config = ConfigDict(extra="forbid") + + @field_validator("currency") + @classmethod + def normalise_currency(cls, value: str | None) -> str | None: + if value is None: + return None + candidate = value if isinstance(value, str) else str(value) + candidate = candidate.strip() + if not candidate: + return None + try: + return normalise_currency(candidate) + except CurrencyValidationError as exc: + raise ValueError(str(exc)) from exc + + +class ScenarioRead(ScenarioBase): + id: int + project_id: int + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) + + +class ScenarioComparisonRequest(BaseModel): + scenario_ids: list[int] + + model_config = ConfigDict(extra="forbid") + + @model_validator(mode="after") + def ensure_minimum_ids(self) -> "ScenarioComparisonRequest": + unique_ids: list[int] = list(dict.fromkeys(self.scenario_ids)) + if len(unique_ids) < 2: + raise ValueError( + "At least two unique scenario identifiers are required for comparison.") + self.scenario_ids = unique_ids + return self + + +class ScenarioComparisonResponse(BaseModel): + project_id: int + scenarios: list[ScenarioRead] + + model_config = ConfigDict(from_attributes=True) diff --git a/schemas/user.py b/schemas/user.py deleted file mode 100644 index fafce5b..0000000 --- a/schemas/user.py +++ /dev/null @@ -1,41 +0,0 @@ -from pydantic import BaseModel, ConfigDict - - -class UserCreate(BaseModel): - username: str - email: str - password: str - - -class UserInDB(BaseModel): - id: int - username: str - email: str - role_id: int - - model_config = ConfigDict(from_attributes=True) - - -class UserLogin(BaseModel): - username: str - password: str - - -class UserUpdate(BaseModel): - username: str | None = None - email: str | None = None - password: str | None = None - - -class PasswordResetRequest(BaseModel): - email: str - - -class PasswordReset(BaseModel): - token: str - new_password: str - - -class Token(BaseModel): - access_token: str - token_type: str diff --git a/scripts/00_initial_data.py b/scripts/00_initial_data.py new file mode 100644 index 0000000..e189001 --- /dev/null +++ b/scripts/00_initial_data.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import logging + +from scripts.initial_data import load_config, seed_initial_data + + +def main() -> int: + logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") + try: + config = load_config() + seed_initial_data(config) + except Exception as exc: # pragma: no cover - operational guard + logging.exception("Seeding failed: %s", exc) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..395066d --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1 @@ +"""Utility scripts for CalMiner maintenance tasks.""" diff --git a/scripts/_route_verification.py b/scripts/_route_verification.py new file mode 100644 index 0000000..2b97182 --- /dev/null +++ b/scripts/_route_verification.py @@ -0,0 +1,112 @@ +"""Utility script to verify key authenticated routes respond without errors.""" +from __future__ import annotations + +import 
json
+import os
+import sys
+import urllib.parse
+from http.client import HTTPConnection
+from http.cookies import SimpleCookie
+from typing import Dict, List, Tuple
+
+HOST = "127.0.0.1"
+PORT = 8000
+
+cookies: Dict[str, str] = {}
+
+
+def _update_cookies(headers: List[Tuple[str, str]]) -> None:
+    for name, value in headers:
+        if name.lower() != "set-cookie":
+            continue
+        cookie = SimpleCookie()
+        cookie.load(value)
+        for key, morsel in cookie.items():
+            cookies[key] = morsel.value
+
+
+def _cookie_header() -> str | None:
+    if not cookies:
+        return None
+    return "; ".join(f"{key}={value}" for key, value in cookies.items())
+
+
+def request(method: str, path: str, *, body: bytes | None = None, headers: Dict[str, str] | None = None) -> Tuple[int, Dict[str, str], bytes]:
+    conn = HTTPConnection(HOST, PORT, timeout=10)
+    prepared_headers = {"User-Agent": "route-checker"}
+    if headers:
+        prepared_headers.update(headers)
+    cookie_header = _cookie_header()
+    if cookie_header:
+        prepared_headers["Cookie"] = cookie_header
+
+    conn.request(method, path, body=body, headers=prepared_headers)
+    resp = conn.getresponse()
+    payload = resp.read()
+    status = resp.status
+    reason = resp.reason
+    response_headers = {name: value for name, value in resp.getheaders()}
+    _update_cookies(list(resp.getheaders()))
+    conn.close()
+    print(f"{method} {path} -> {status} {reason}")
+    return status, response_headers, payload
+
+
+def main() -> int:
+    status, _, _ = request("GET", "/login")
+    if status != 200:
+        print("Unexpected status for GET /login", file=sys.stderr)
+        return 1
+
+    admin_username = os.getenv("CALMINER_SEED_ADMIN_USERNAME", "admin")
+    admin_password = os.getenv("CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!")  # dev-seed placeholder; never hardcode a real credential
+    login_payload = urllib.parse.urlencode(
+        {"username": admin_username, "password": admin_password}
+    ).encode()
+    status, headers, _ = request(
+        "POST",
+        "/login",
+        body=login_payload,
+        headers={"Content-Type": "application/x-www-form-urlencoded"},
+    )
+    if status not in {200, 303}:
+        print("Login failed", file=sys.stderr)
+        return 1
+
+    location = headers.get("Location", "/")
+    redirect_path = urllib.parse.urlsplit(location).path or "/"
+    request("GET", redirect_path)
+
+    request("GET", "/")
+    request("GET", "/projects/ui")
+
+    status, headers, body = request(
+        "GET",
+        "/projects",
+        headers={"Accept": "application/json"},
+    )
+    projects: List[dict] = []
+    if headers.get("Content-Type", "").startswith("application/json"):
+        projects = json.loads(body.decode())
+
+    if projects:
+        project_id = projects[0]["id"]
+        request("GET", f"/projects/{project_id}/view")
+        status, headers, body = request(
+            "GET",
+            f"/projects/{project_id}/scenarios",
+            headers={"Accept": "application/json"},
+        )
+        scenarios: List[dict] = []
+        if headers.get("Content-Type", "").startswith("application/json"):
+            scenarios = json.loads(body.decode())
+        if scenarios:
+            scenario_id = scenarios[0]["id"]
+            request("GET", f"/scenarios/{scenario_id}/view")
+
+    print("Cookies:", cookies)
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/scripts/apply_users_sequence_fix.py b/scripts/apply_users_sequence_fix.py
new file mode 100644
index 0000000..4322d91
--- /dev/null
+++ b/scripts/apply_users_sequence_fix.py
@@ -0,0 +1,18 @@
+from sqlalchemy import create_engine, text
+from config.database import DATABASE_URL
+
+engine = create_engine(DATABASE_URL, future=True)
+sqls = [
+    "CREATE SEQUENCE IF NOT EXISTS users_id_seq;",
+    "ALTER TABLE users ALTER COLUMN id SET DEFAULT nextval('users_id_seq');",
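+    # The remaining two statements keep the script idempotent: re-seed the
+    # sequence from the current MAX(id) so the next insert cannot collide,
+    # then hand ownership to users.id so the sequence is dropped with the table.
+    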
"SELECT setval('users_id_seq', COALESCE((SELECT MAX(id) FROM users), 1));", + "ALTER SEQUENCE users_id_seq OWNED BY users.id;", +] +with engine.begin() as conn: + for s in sqls: + print('EXECUTING:', s) + conn.execute(text(s)) +print('SEQUENCE fix applied') diff --git a/scripts/backfill_currency.py b/scripts/backfill_currency.py deleted file mode 100644 index 4651021..0000000 --- a/scripts/backfill_currency.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Backfill script to populate currency_id for capex and opex rows using existing currency_code. - -Usage: - python scripts/backfill_currency.py --dry-run - python scripts/backfill_currency.py --create-missing - -This script is intentionally cautious: it defaults to dry-run mode and will refuse to run -if database connection settings are missing. It supports creating missing currency rows when `--create-missing` -is provided. Always run against a development/staging database first. -""" - -from __future__ import annotations -import argparse -import importlib -import sys -from pathlib import Path - -from sqlalchemy import text, create_engine - - -PROJECT_ROOT = Path(__file__).resolve().parent.parent -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - - -def load_database_url() -> str: - try: - db_module = importlib.import_module("config.database") - except RuntimeError as exc: - raise RuntimeError( - "Database configuration missing: set DATABASE_URL or provide granular " - "variables (DATABASE_DRIVER, DATABASE_HOST, DATABASE_PORT, DATABASE_USER, " - "DATABASE_PASSWORD, DATABASE_NAME, optional DATABASE_SCHEMA)." - ) from exc - - return getattr(db_module, "DATABASE_URL") - - -def backfill( - db_url: str, dry_run: bool = True, create_missing: bool = False -) -> None: - engine = create_engine(db_url) - with engine.begin() as conn: - # Ensure currency table exists - if db_url.startswith("sqlite:"): - conn.execute( - text( - "SELECT name FROM sqlite_master WHERE type='table' AND name='currency';" - ) - ) - else: - conn.execute(text("SELECT to_regclass('public.currency');")) - # Note: we don't strictly depend on the above - we assume migration was already applied - - # Helper: find or create currency by code - def find_currency_id(code: str): - r = conn.execute( - text("SELECT id FROM currency WHERE code = :code"), - {"code": code}, - ).fetchone() - if r: - return r[0] - if create_missing: - # insert and return id - conn.execute( - text( - "INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)" - ), - {"c": code, "n": code}, - ) - r2 = conn.execute( - text("SELECT id FROM currency WHERE code = :code"), - {"code": code}, - ).fetchone() - if not r2: - raise RuntimeError( - f"Unable to determine currency ID for '{code}' after insert" - ) - return r2[0] - return None - - # Process tables capex and opex - for table in ("capex", "opex"): - # Check if currency_id column exists - try: - cols = ( - conn.execute( - text( - f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'" - ) - ) - if not db_url.startswith("sqlite:") - else [(1,)] - ) - except Exception: - cols = [(1,)] - - if not cols: - print(f"Skipping {table}: no currency_id column found") - continue - - # Find rows where currency_id IS NULL but currency_code exists - rows = conn.execute( - text( - f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''" - ) - ) - changed = 0 - for r in rows: - rid = r[0] - code = (r[1] or "USD").strip().upper() - cid = 
find_currency_id(code) - if cid is None: - print( - f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping" - ) - continue - if dry_run: - print( - f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})" - ) - else: - conn.execute( - text( - f"UPDATE {table} SET currency_id = :cid WHERE id = :rid" - ), - {"cid": cid, "rid": rid}, - ) - changed += 1 - - print(f"{table}: processed, changed={changed} (dry_run={dry_run})") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Backfill currency_id from currency_code for capex/opex tables" - ) - parser.add_argument( - "--dry-run", - action="store_true", - default=True, - help="Show actions without writing", - ) - parser.add_argument( - "--create-missing", - action="store_true", - help="Create missing currency rows in the currency table", - ) - args = parser.parse_args() - - db = load_database_url() - backfill(db, dry_run=args.dry_run, create_missing=args.create_missing) - - -if __name__ == "__main__": - main() diff --git a/scripts/check_docs_links.py b/scripts/check_docs_links.py deleted file mode 100644 index aebc1fe..0000000 --- a/scripts/check_docs_links.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Simple Markdown link checker for local docs/ files. - -Checks only local file links (relative paths) and reports missing targets. - -Run from the repository root using the project's Python environment. -""" - -import re -from pathlib import Path - -ROOT = Path(__file__).resolve().parent.parent -DOCS = ROOT / "docs" - -MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") - -errors = [] - -for md in DOCS.rglob("*.md"): - text = md.read_text(encoding="utf-8") - for m in MD_LINK_RE.finditer(text): - label, target = m.groups() - # skip URLs - if ( - target.startswith("http://") - or target.startswith("https://") - or target.startswith("#") - ): - continue - # strip anchors - target_path = target.split("#")[0] - # if link is to a directory index, allow - candidate = (md.parent / target_path).resolve() - if candidate.exists(): - continue - # check common implicit index: target/ -> target/README.md or target/index.md - candidate_dir = md.parent / target_path - if candidate_dir.is_dir(): - if (candidate_dir / "README.md").exists() or ( - candidate_dir / "index.md" - ).exists(): - continue - errors.append((str(md.relative_to(ROOT)), target, label)) - -if errors: - print("Broken local links found:") - for src, tgt, label in errors: - print(f"- {src} -> {tgt} ({label})") - exit(2) - -print("No broken local links detected.") diff --git a/scripts/format_docs_md.py b/scripts/format_docs_md.py deleted file mode 100644 index 5e1e856..0000000 --- a/scripts/format_docs_md.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Lightweight Markdown formatter: normalizes first-line H1, adds code-fence language hints for common shebangs, trims trailing whitespace. - -This is intentionally small and non-destructive; it touches only files under docs/ and makes safe changes. 
-""" - -import re -from pathlib import Path - -DOCS = Path(__file__).resolve().parents[1] / "docs" - -CODE_LANG_HINTS = { - "powershell": ("powershell",), - "bash": ("bash", "sh"), - "sql": ("sql",), - "python": ("python",), -} - - -def add_code_fence_language(match): - fence = match.group(0) - inner = match.group(1) - # If language already present, return unchanged - if fence.startswith("```") and len(fence.splitlines()[0].strip()) > 3: - return fence - # Try to infer language from the code content - code = inner.strip().splitlines()[0] if inner.strip() else "" - lang = "" - if ( - code.startswith("$") - or code.startswith("PS") - or code.lower().startswith("powershell") - ): - lang = "powershell" - elif ( - code.startswith("#") - or code.startswith("import") - or code.startswith("from") - ): - lang = "python" - elif re.match(r"^(select|insert|update|create)\b", code.strip(), re.I): - lang = "sql" - elif ( - code.startswith("git") - or code.startswith("./") - or code.startswith("sudo") - ): - lang = "bash" - if lang: - return f"```{lang}\n{inner}\n```" - return fence - - -def normalize_file(path: Path): - text = path.read_text(encoding="utf-8") - orig = text - # Trim trailing whitespace and ensure single trailing newline - text = "\n".join(line.rstrip() for line in text.splitlines()) + "\n" - # Ensure first non-empty line is H1 - lines = text.splitlines() - for i, ln in enumerate(lines): - if ln.strip(): - if not ln.startswith("#"): - lines[i] = "# " + ln - break - text = "\n".join(lines) + "\n" - # Add basic code fence languages where missing (simple heuristic) - text = re.sub(r"```\n([\s\S]*?)\n```", add_code_fence_language, text) - if text != orig: - path.write_text(text, encoding="utf-8") - return True - return False - - -def main(): - changed = [] - for p in DOCS.rglob("*.md"): - if p.is_file(): - try: - if normalize_file(p): - changed.append(str(p.relative_to(Path.cwd()))) - except Exception as e: - print(f"Failed to format {p}: {e}") - if changed: - print("Formatted files:") - for c in changed: - print(" -", c) - else: - print("No formatting changes required.") - - -if __name__ == "__main__": - main() diff --git a/scripts/init_db.py b/scripts/init_db.py new file mode 100644 index 0000000..958ce4f --- /dev/null +++ b/scripts/init_db.py @@ -0,0 +1,1468 @@ +"""Idempotent DB initialization and seeding using Pydantic validation and raw SQL. + +Usage: + from scripts.init_db import init_db + init_db() + +This module creates PostgreSQL ENUM types if missing, creates minimal tables +required for bootstrapping (roles, users, user_roles, pricing_settings and +ancillary pricing tables), and seeds initial rows using INSERT ... ON CONFLICT +DO NOTHING so it's safe to run multiple times. + +Notes: +- This module avoids importing application models at import time to prevent + side-effects. Database connections are created inside functions. +- It intentionally performs non-destructive operations only (CREATE IF NOT + EXISTS, INSERT ... ON CONFLICT). 
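+- PostgreSQL has no CREATE TYPE IF NOT EXISTS, so enum creation must be
+  guarded per type. The usual shape (a sketch of the pattern, not necessarily
+  the exact SQL emitted below) is:
+
+      DO $$ BEGIN
+          CREATE TYPE scenariostatus AS ENUM ('draft', 'active', 'archived');
+      EXCEPTION WHEN duplicate_object THEN NULL;
+      END $$;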
+""" +from __future__ import annotations + +from typing import List, Optional, Set +import os +import logging +from decimal import Decimal + +from pydantic import BaseModel, Field, field_validator +from sqlalchemy import JSON, create_engine, text +from sqlalchemy.engine import Engine +from passlib.context import CryptContext +from sqlalchemy.sql import bindparam + +logger = logging.getLogger(__name__) +password_context = CryptContext(schemes=["argon2"], deprecated="auto") + +# ENUM definitions matching previous schema +ENUM_DEFINITIONS = { + "miningoperationtype": [ + "open_pit", + "underground", + "in_situ_leach", + "placer", + "quarry", + "mountaintop_removal", + "other", + ], + "scenariostatus": ["draft", "active", "archived"], + "financialcategory": ["capex", "opex", "revenue", "contingency", "other"], + "costbucket": [ + "capital_initial", + "capital_sustaining", + "operating_fixed", + "operating_variable", + "maintenance", + "reclamation", + "royalties", + "general_admin", + ], + "distributiontype": ["normal", "triangular", "uniform", "lognormal", "custom"], + "stochasticvariable": [ + "ore_grade", + "recovery_rate", + "metal_price", + "operating_cost", + "capital_cost", + "discount_rate", + "throughput", + ], + "resourcetype": [ + "diesel", + "electricity", + "water", + "explosives", + "reagents", + "labor", + "equipment_hours", + "tailings_capacity", + ], +} + +# Minimal DDL for tables we seed / that bootstrap relies on + + +def _get_table_ddls(is_sqlite: bool) -> List[str]: + if is_sqlite: + return [ + # roles + """ + CREATE TABLE IF NOT EXISTS roles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + display_name TEXT NOT NULL, + description TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """, + # users + """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT NOT NULL UNIQUE, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + is_active INTEGER NOT NULL DEFAULT 1, + is_superuser INTEGER NOT NULL DEFAULT 0, + last_login_at DATETIME, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """, + # user_roles + """ + CREATE TABLE IF NOT EXISTS user_roles ( + user_id INTEGER NOT NULL, + role_id INTEGER NOT NULL, + granted_at DATETIME DEFAULT CURRENT_TIMESTAMP, + granted_by INTEGER, + PRIMARY KEY (user_id, role_id) + ); + """, + """ + CREATE TABLE IF NOT EXISTS navigation_groups ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + slug TEXT NOT NULL UNIQUE, + label TEXT NOT NULL, + sort_order INTEGER NOT NULL DEFAULT 100, + icon TEXT, + tooltip TEXT, + is_enabled INTEGER NOT NULL DEFAULT 1, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """, + """ + CREATE TABLE IF NOT EXISTS navigation_links ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + group_id INTEGER NOT NULL REFERENCES navigation_groups(id) ON DELETE CASCADE, + parent_link_id INTEGER REFERENCES navigation_links(id) ON DELETE CASCADE, + slug TEXT NOT NULL, + label TEXT NOT NULL, + route_name TEXT, + href_override TEXT, + match_prefix TEXT, + sort_order INTEGER NOT NULL DEFAULT 100, + icon TEXT, + tooltip TEXT, + required_roles TEXT NOT NULL DEFAULT '[]', + is_enabled INTEGER NOT NULL DEFAULT 1, + is_external INTEGER NOT NULL DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE (group_id, slug) + ); + """, + # pricing_settings + """ + CREATE 
TABLE IF NOT EXISTS pricing_settings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + slug TEXT NOT NULL UNIQUE, + description TEXT, + default_currency TEXT, + default_payable_pct REAL DEFAULT 100.00 NOT NULL, + moisture_threshold_pct REAL DEFAULT 8.00 NOT NULL, + moisture_penalty_per_pct REAL DEFAULT 0.0000 NOT NULL, + metadata TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """, + # pricing_metal_settings + """ + CREATE TABLE IF NOT EXISTS pricing_metal_settings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + pricing_settings_id INTEGER NOT NULL REFERENCES pricing_settings(id) ON DELETE CASCADE, + metal_code TEXT NOT NULL, + payable_pct REAL, + moisture_threshold_pct REAL, + moisture_penalty_per_pct REAL, + data TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE (pricing_settings_id, metal_code) + ); + """, + # pricing_impurity_settings + """ + CREATE TABLE IF NOT EXISTS pricing_impurity_settings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + pricing_settings_id INTEGER NOT NULL REFERENCES pricing_settings(id) ON DELETE CASCADE, + impurity_code TEXT NOT NULL, + threshold_ppm REAL DEFAULT 0.0000 NOT NULL, + penalty_per_ppm REAL DEFAULT 0.0000 NOT NULL, + notes TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE (pricing_settings_id, impurity_code) + ); + """, + # core domain tables: projects, scenarios, financial_inputs, simulation_parameters + """ + CREATE TABLE IF NOT EXISTS projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + location TEXT, + operation_type TEXT NOT NULL CHECK (operation_type IN ('open_pit', 'underground', 'in_situ_leach', 'placer', 'quarry', 'mountaintop_removal', 'other')), + description TEXT, + pricing_settings_id INTEGER REFERENCES pricing_settings(id) ON DELETE SET NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """, + """ + CREATE TABLE IF NOT EXISTS scenarios ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + name TEXT NOT NULL, + description TEXT, + status TEXT NOT NULL CHECK (status IN ('draft', 'active', 'archived')), + start_date DATE, + end_date DATE, + discount_rate REAL, + currency TEXT, + primary_resource TEXT CHECK (primary_resource IN ('diesel', 'electricity', 'water', 'explosives', 'reagents', 'labor', 'equipment_hours', 'tailings_capacity') OR primary_resource IS NULL), + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE (project_id, name) + ); + """, + """ + CREATE TABLE IF NOT EXISTS financial_inputs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + scenario_id INTEGER NOT NULL REFERENCES scenarios(id) ON DELETE CASCADE, + name TEXT NOT NULL, + category TEXT NOT NULL CHECK (category IN ('capex', 'opex', 'revenue', 'contingency', 'other')), + cost_bucket TEXT CHECK (cost_bucket IN ('capital_initial', 'capital_sustaining', 'operating_fixed', 'operating_variable', 'maintenance', 'reclamation', 'royalties', 'general_admin') OR cost_bucket IS NULL), + amount REAL NOT NULL, + currency TEXT, + effective_date DATE, + notes TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE (scenario_id, name) + ); + """, + """ + CREATE TABLE IF NOT EXISTS simulation_parameters ( + id INTEGER PRIMARY KEY 
AUTOINCREMENT, + scenario_id INTEGER NOT NULL REFERENCES scenarios(id) ON DELETE CASCADE, + name TEXT NOT NULL, + distribution TEXT NOT NULL CHECK (distribution IN ('normal', 'triangular', 'uniform', 'lognormal', 'custom')), + variable TEXT CHECK (variable IN ('ore_grade', 'recovery_rate', 'metal_price', 'operating_cost', 'capital_cost', 'discount_rate', 'throughput') OR variable IS NULL), + resource_type TEXT CHECK (resource_type IN ('diesel', 'electricity', 'water', 'explosives', 'reagents', 'labor', 'equipment_hours', 'tailings_capacity') OR resource_type IS NULL), + mean_value REAL, + standard_deviation REAL, + minimum_value REAL, + maximum_value REAL, + unit TEXT, + configuration TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """, + ] + else: + # PostgreSQL DDLs + return [ + # roles + """ + CREATE TABLE IF NOT EXISTS roles ( + id INTEGER PRIMARY KEY, + name VARCHAR(64) NOT NULL, + display_name VARCHAR(128) NOT NULL, + description TEXT, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_roles_name UNIQUE (name) + ); + """, + # users + """ + CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + email VARCHAR(255) NOT NULL, + username VARCHAR(128) NOT NULL, + password_hash VARCHAR(255) NOT NULL, + is_active BOOLEAN NOT NULL DEFAULT true, + is_superuser BOOLEAN NOT NULL DEFAULT false, + last_login_at TIMESTAMPTZ, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_users_email UNIQUE (email), + CONSTRAINT uq_users_username UNIQUE (username) + ); + """, + # user_roles + """ + CREATE TABLE IF NOT EXISTS user_roles ( + user_id INTEGER NOT NULL, + role_id INTEGER NOT NULL, + granted_at TIMESTAMPTZ DEFAULT now(), + granted_by INTEGER, + PRIMARY KEY (user_id, role_id), + CONSTRAINT uq_user_roles_user_role UNIQUE (user_id, role_id) + ); + """, + """ + CREATE TABLE IF NOT EXISTS navigation_groups ( + id SERIAL PRIMARY KEY, + slug VARCHAR(64) NOT NULL, + label VARCHAR(128) NOT NULL, + sort_order INTEGER NOT NULL DEFAULT 100, + icon VARCHAR(64), + tooltip VARCHAR(255), + is_enabled BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_navigation_groups_slug UNIQUE (slug) + ); + """, + """ + CREATE TABLE IF NOT EXISTS navigation_links ( + id SERIAL PRIMARY KEY, + group_id INTEGER NOT NULL REFERENCES navigation_groups(id) ON DELETE CASCADE, + parent_link_id INTEGER REFERENCES navigation_links(id) ON DELETE CASCADE, + slug VARCHAR(64) NOT NULL, + label VARCHAR(128) NOT NULL, + route_name VARCHAR(128), + href_override VARCHAR(512), + match_prefix VARCHAR(512), + sort_order INTEGER NOT NULL DEFAULT 100, + icon VARCHAR(64), + tooltip VARCHAR(255), + required_roles JSONB NOT NULL DEFAULT '[]'::jsonb, + is_enabled BOOLEAN NOT NULL DEFAULT true, + is_external BOOLEAN NOT NULL DEFAULT false, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_navigation_links_group_slug UNIQUE (group_id, slug) + ); + """, + # pricing_settings + """ + CREATE TABLE IF NOT EXISTS pricing_settings ( + id SERIAL PRIMARY KEY, + name VARCHAR(128) NOT NULL, + slug VARCHAR(64) NOT NULL, + description TEXT, + default_currency VARCHAR(3), + default_payable_pct NUMERIC(5,2) DEFAULT 100.00 NOT NULL, + moisture_threshold_pct NUMERIC(5,2) DEFAULT 8.00 NOT NULL, + moisture_penalty_per_pct NUMERIC(14,4) DEFAULT 0.0000 NOT NULL, + metadata JSONB, + created_at TIMESTAMPTZ 
DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_pricing_settings_slug UNIQUE (slug), + CONSTRAINT uq_pricing_settings_name UNIQUE (name) + ); + """, + # pricing_metal_settings + """ + CREATE TABLE IF NOT EXISTS pricing_metal_settings ( + id SERIAL PRIMARY KEY, + pricing_settings_id INTEGER NOT NULL REFERENCES pricing_settings(id) ON DELETE CASCADE, + metal_code VARCHAR(32) NOT NULL, + payable_pct NUMERIC(5,2), + moisture_threshold_pct NUMERIC(5,2), + moisture_penalty_per_pct NUMERIC(14,4), + data JSONB, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_pricing_metal_settings_code UNIQUE (pricing_settings_id, metal_code) + ); + """, + # pricing_impurity_settings + """ + CREATE TABLE IF NOT EXISTS pricing_impurity_settings ( + id SERIAL PRIMARY KEY, + pricing_settings_id INTEGER NOT NULL REFERENCES pricing_settings(id) ON DELETE CASCADE, + impurity_code VARCHAR(32) NOT NULL, + threshold_ppm NUMERIC(14,4) DEFAULT 0.0000 NOT NULL, + penalty_per_ppm NUMERIC(14,4) DEFAULT 0.0000 NOT NULL, + notes TEXT, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_pricing_impurity_settings_code UNIQUE (pricing_settings_id, impurity_code) + ); + """, + # core domain tables: projects, scenarios, financial_inputs, simulation_parameters + """ + CREATE TABLE IF NOT EXISTS projects ( + id SERIAL PRIMARY KEY, + name VARCHAR(255) NOT NULL, + location VARCHAR(255), + operation_type miningoperationtype NOT NULL, + description TEXT, + pricing_settings_id INTEGER REFERENCES pricing_settings(id) ON DELETE SET NULL, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_projects_name UNIQUE (name) + ); + """, + """ + CREATE TABLE IF NOT EXISTS scenarios ( + id SERIAL PRIMARY KEY, + project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + description TEXT, + status scenariostatus NOT NULL, + start_date DATE, + end_date DATE, + discount_rate NUMERIC(5,2), + currency VARCHAR(3), + primary_resource resourcetype, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name) + ); + """, + """ + CREATE TABLE IF NOT EXISTS financial_inputs ( + id SERIAL PRIMARY KEY, + scenario_id INTEGER NOT NULL REFERENCES scenarios(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + category financialcategory NOT NULL, + cost_bucket costbucket, + amount NUMERIC(18,2) NOT NULL, + currency VARCHAR(3), + effective_date DATE, + notes TEXT, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now(), + CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name) + ); + """, + """ + CREATE TABLE IF NOT EXISTS simulation_parameters ( + id SERIAL PRIMARY KEY, + scenario_id INTEGER NOT NULL REFERENCES scenarios(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + distribution distributiontype NOT NULL, + variable stochasticvariable, + resource_type resourcetype, + mean_value NUMERIC(18,4), + standard_deviation NUMERIC(18,4), + minimum_value NUMERIC(18,4), + maximum_value NUMERIC(18,4), + unit VARCHAR(32), + configuration JSONB, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() + ); + """, + ] + + +# Seeds +TABLE_DDLS: List[str] = _get_table_ddls(is_sqlite=False) + + +DEFAULT_ROLES = [ + {"id": 1, "name": "admin", "display_name": "Administrator", + "description": "Full platform access with user management 
rights."}, + {"id": 2, "name": "project_manager", "display_name": "Project Manager", + "description": "Manage projects, scenarios, and associated data."}, + {"id": 3, "name": "analyst", "display_name": "Analyst", + "description": "Review dashboards and scenario outputs."}, + {"id": 4, "name": "viewer", "display_name": "Viewer", + "description": "Read-only access to assigned projects and reports."}, +] + +DEFAULT_ADMIN = {"id": 1, "email": "admin@calminer.local", "username": "admin", + "password": "ChangeMe123!", "is_active": True, "is_superuser": True} +DEFAULT_PRICING = { + "slug": "default", + "name": "Default Pricing", + "description": "Automatically generated default pricing settings.", + "default_currency": "USD", + "default_payable_pct": 100.0, + "moisture_threshold_pct": 8.0, + "moisture_penalty_per_pct": 0.0, +} + + +class ProjectSeed(BaseModel): + name: str + location: str | None = None + operation_type: str + description: str | None = None + + +class ScenarioSeed(BaseModel): + project_name: str + name: str + description: str | None = None + status: str = "active" + discount_rate: float | None = Field(default=None) + currency: str | None = Field(default="USD") + primary_resource: str | None = Field(default=None) + + +class FinancialInputSeed(BaseModel): + scenario_name: str + project_name: str + name: str + category: str + cost_bucket: str | None = None + amount: Decimal + currency: str = "USD" + notes: str | None = None + + +class RoleSeed(BaseModel): + id: int + name: str + display_name: str + description: Optional[str] + + +class UserSeed(BaseModel): + id: int + email: str + username: str + password: str + is_active: bool = True + is_superuser: bool = False + + @field_validator("password") + def password_min_len(cls, v: str) -> str: + if not v or len(v) < 8: + raise ValueError("password must be at least 8 characters") + return v + + +class PricingSeed(BaseModel): + slug: str + name: str + description: Optional[str] + default_currency: Optional[str] + default_payable_pct: float + moisture_threshold_pct: float + moisture_penalty_per_pct: float + + +class NavigationGroupSeed(BaseModel): + slug: str + label: str + sort_order: int = 100 + icon: Optional[str] = None + tooltip: Optional[str] = None + is_enabled: bool = True + + +class NavigationLinkSeed(BaseModel): + slug: str + group_slug: str + label: str + route_name: Optional[str] = None + href_override: Optional[str] = None + match_prefix: Optional[str] = None + sort_order: int = 100 + icon: Optional[str] = None + tooltip: Optional[str] = None + required_roles: list[str] = Field(default_factory=list) + is_enabled: bool = True + is_external: bool = False + parent_slug: Optional[str] = None + + @field_validator("required_roles", mode="after") + def _normalise_roles(cls, value: list[str]) -> list[str]: + normalised = [] + for role in value: + if not role: + continue + slug = role.strip().lower() + if slug and slug not in normalised: + normalised.append(slug) + return normalised + + @field_validator("route_name") + def _route_or_href(cls, value: Optional[str], info): + href = info.data.get("href_override") + if not value and not href: + raise ValueError( + "navigation link requires route_name or href_override") + return value + + +DEFAULT_NAVIGATION_GROUPS: list[NavigationGroupSeed] = [ + NavigationGroupSeed( + slug="workspace", + label="Workspace", + sort_order=10, + icon="briefcase", + tooltip="Primary work hub", + ), + NavigationGroupSeed( + slug="insights", + label="Insights", + sort_order=20, + icon="insights", + 
tooltip="Analytics and reports", + ), + NavigationGroupSeed( + slug="configuration", + label="Configuration", + sort_order=30, + icon="cog", + tooltip="Administration and settings", + ), + NavigationGroupSeed( + slug="account", + label="Account", + sort_order=40, + icon="user", + tooltip="Session management", + ), +] + + +DEFAULT_NAVIGATION_LINKS: list[NavigationLinkSeed] = [ + NavigationLinkSeed( + slug="dashboard", + group_slug="workspace", + label="Dashboard", + route_name="dashboard.home", + match_prefix="/", + sort_order=10, + ), + NavigationLinkSeed( + slug="projects", + group_slug="workspace", + label="Projects", + route_name="projects.project_list_page", + match_prefix="/projects", + sort_order=20, + ), + NavigationLinkSeed( + slug="project-create", + group_slug="workspace", + label="New Project", + route_name="projects.create_project_form", + match_prefix="/projects/create", + sort_order=30, + required_roles=["project_manager", "admin"], + ), + NavigationLinkSeed( + slug="imports", + group_slug="workspace", + label="Imports", + href_override="/imports/ui", + match_prefix="/imports", + sort_order=40, + required_roles=["analyst", "admin"], + ), + NavigationLinkSeed( + slug="profitability", + group_slug="workspace", + label="Profitability Calculator", + route_name="calculations.profitability_form", + href_override="/calculations/profitability", + match_prefix="/calculations/profitability", + sort_order=50, + required_roles=["analyst", "admin"], + parent_slug="projects", + ), + NavigationLinkSeed( + slug="opex", + group_slug="workspace", + label="Opex Planner", + route_name="calculations.opex_form", + href_override="/calculations/opex", + match_prefix="/calculations/opex", + sort_order=60, + required_roles=["analyst", "admin"], + parent_slug="projects", + ), + NavigationLinkSeed( + slug="capex", + group_slug="workspace", + label="Capex Planner", + route_name="calculations.capex_form", + href_override="/calculations/capex", + match_prefix="/calculations/capex", + sort_order=70, + required_roles=["analyst", "admin"], + parent_slug="projects", + ), + NavigationLinkSeed( + slug="simulations", + group_slug="insights", + label="Simulations", + href_override="/ui/simulations", + match_prefix="/ui/simulations", + sort_order=10, + required_roles=["analyst", "admin"], + ), + NavigationLinkSeed( + slug="reporting", + group_slug="insights", + label="Reporting", + href_override="/ui/reporting", + match_prefix="/ui/reporting", + sort_order=20, + required_roles=["analyst", "admin"], + ), + NavigationLinkSeed( + slug="settings", + group_slug="configuration", + label="Settings", + href_override="/ui/settings", + match_prefix="/ui/settings", + sort_order=10, + required_roles=["admin"], + ), + NavigationLinkSeed( + slug="themes", + group_slug="configuration", + label="Themes", + href_override="/theme-settings", + match_prefix="/theme-settings", + sort_order=20, + required_roles=["admin"], + parent_slug="settings", + ), + NavigationLinkSeed( + slug="currencies", + group_slug="configuration", + label="Currency Management", + href_override="/ui/currencies", + match_prefix="/ui/currencies", + sort_order=30, + required_roles=["admin"], + parent_slug="settings", + ), + NavigationLinkSeed( + slug="logout", + group_slug="account", + label="Logout", + route_name="auth.logout", + match_prefix="/logout", + sort_order=10, + required_roles=["viewer", "analyst", "project_manager", "admin"], + ), + NavigationLinkSeed( + slug="login", + group_slug="account", + label="Login", + route_name="auth.login_form", + 
match_prefix="/login", + sort_order=10, + required_roles=["anonymous"], + ), + NavigationLinkSeed( + slug="register", + group_slug="account", + label="Register", + route_name="auth.register_form", + match_prefix="/register", + sort_order=20, + required_roles=["anonymous"], + ), + NavigationLinkSeed( + slug="forgot-password", + group_slug="account", + label="Forgot Password", + route_name="auth.password_reset_request_form", + match_prefix="/forgot-password", + sort_order=30, + required_roles=["anonymous"], + ), +] + + +DEFAULT_PROJECTS: list[ProjectSeed] = [ + ProjectSeed( + name="Helios Copper", + location="Chile", + operation_type="open_pit", + description="Flagship open pit copper operation used for demos", + ), + ProjectSeed( + name="Luna Nickel", + location="Australia", + operation_type="underground", + description="Underground nickel sulphide project with stochastic modelling", + ), +] + + +DEFAULT_SCENARIOS: list[ScenarioSeed] = [ + ScenarioSeed( + project_name="Helios Copper", + name="Base Case", + description="Deterministic base case for Helios", + status="active", + discount_rate=8.0, + primary_resource="diesel", + ), + ScenarioSeed( + project_name="Helios Copper", + name="Expansion Case", + description="Expansion scenario with increased throughput", + status="draft", + discount_rate=9.0, + primary_resource="electricity", + ), + ScenarioSeed( + project_name="Luna Nickel", + name="Feasibility", + description="Feasibility scenario targeting steady state", + status="active", + discount_rate=10.0, + primary_resource="electricity", + ), +] + + +DEFAULT_FINANCIAL_INPUTS: list[FinancialInputSeed] = [ + FinancialInputSeed( + project_name="Helios Copper", + scenario_name="Base Case", + name="Initial Capital", + category="capex", + cost_bucket="capital_initial", + amount=Decimal("450000000"), + notes="Initial mine development costs", + ), + FinancialInputSeed( + project_name="Helios Copper", + scenario_name="Base Case", + name="Opex", + category="opex", + cost_bucket="operating_variable", + amount=Decimal("75000000"), + notes="Annual processing operating expenditure", + ), + FinancialInputSeed( + project_name="Helios Copper", + scenario_name="Expansion Case", + name="Expansion Capital", + category="capex", + cost_bucket="capital_sustaining", + amount=Decimal("120000000"), + ), + FinancialInputSeed( + project_name="Luna Nickel", + scenario_name="Feasibility", + name="Nickel Revenue", + category="revenue", + cost_bucket=None, + amount=Decimal("315000000"), + ), +] + + +def _get_database_url() -> str: + # Prefer the same DATABASE_URL used by the application + from config.database import DATABASE_URL + + return DATABASE_URL + + +def _is_sqlite(database_url: str) -> bool: + return database_url.startswith("sqlite://") + + +def _create_engine(database_url: Optional[str] = None) -> Engine: + database_url = database_url or _get_database_url() + engine = create_engine(database_url, future=True) + return engine + + +def _create_enum_if_missing_sql(type_name: str, values: List[str]) -> str: + # Use a DO block to safely create the enum only if it is missing + vals = ", ".join(f"'{v}'" for v in values) + sql = ( + "DO $$ BEGIN " + f"IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = '{type_name}') THEN " + f"CREATE TYPE {type_name} AS ENUM ({vals}); " + "END IF; END $$;" + ) + return sql + + +def ensure_enums(engine: Engine, is_sqlite: bool) -> None: + if is_sqlite: + # SQLite doesn't have enums, constraints are in table DDL + logger.debug("Skipping enum creation for SQLite") + return + with 
engine.begin() as conn: + for name, vals in ENUM_DEFINITIONS.items(): + sql = _create_enum_if_missing_sql(name, vals) + logger.debug("Ensuring enum %s: %s", name, sql) + conn.execute(text(sql)) + + +def _fetch_enum_values(conn, type_name: str) -> Set[str]: + rows = conn.execute( + text( + """ + SELECT e.enumlabel + FROM pg_enum e + JOIN pg_type t ON t.oid = e.enumtypid + WHERE t.typname = :type_name + """ + ), + {"type_name": type_name}, + ) + return {row.enumlabel for row in rows} + + +def normalize_enum_values(engine: Engine, is_sqlite: bool) -> None: + if is_sqlite: + # No enums to normalize in SQLite + logger.debug("Skipping enum normalization for SQLite") + return + with engine.begin() as conn: + for type_name, expected_values in ENUM_DEFINITIONS.items(): + try: + existing_values = _fetch_enum_values(conn, type_name) + except Exception as exc: # pragma: no cover - system catalogs missing + logger.debug( + "Skipping enum normalization for %s due to error: %s", + type_name, + exc, + ) + continue + + expected_set = set(expected_values) + for value in list(existing_values): + if value in expected_set: + continue + + normalized = value.lower() + if ( + normalized != value + and normalized in expected_set + and normalized not in existing_values + ): + logger.info( + "Renaming enum value %s.%s -> %s", + type_name, + value, + normalized, + ) + conn.execute( + text( + f"ALTER TYPE {type_name} RENAME VALUE :old_value TO :new_value" + ), + {"old_value": value, "new_value": normalized}, + ) + existing_values.remove(value) + existing_values.add(normalized) + + +def ensure_tables(engine: Engine, is_sqlite: bool) -> None: + table_ddls = _get_table_ddls(is_sqlite) + with engine.begin() as conn: + for ddl in table_ddls: + logger.debug("Executing DDL:\n%s", ddl) + conn.execute(text(ddl)) + + +CONSTRAINT_DDLS = [ + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'uq_scenarios_project_name' + ) THEN + ALTER TABLE scenarios + ADD CONSTRAINT uq_scenarios_project_name UNIQUE (project_id, name); + END IF; + END; + $$; + """, + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'uq_financial_inputs_scenario_name' + ) THEN + ALTER TABLE financial_inputs + ADD CONSTRAINT uq_financial_inputs_scenario_name UNIQUE (scenario_id, name); + END IF; + END; + $$; + """, +] + + +def ensure_constraints(engine: Engine, is_sqlite: bool) -> None: + if is_sqlite: + # Constraints are already in table DDL for SQLite + logger.debug("Skipping constraint creation for SQLite") + return + with engine.begin() as conn: + for ddl in CONSTRAINT_DDLS: + logger.debug("Ensuring constraint via:\n%s", ddl) + conn.execute(text(ddl)) + + +def seed_roles(engine: Engine, is_sqlite: bool) -> None: + with engine.begin() as conn: + for r in DEFAULT_ROLES: + seed = RoleSeed(**r) + conn.execute( + text( + "INSERT INTO roles (id, name, display_name, description) VALUES (:id, :name, :display_name, :description) " + "ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, display_name = EXCLUDED.display_name, description = EXCLUDED.description" + ), + dict(id=seed.id, name=seed.name, + display_name=seed.display_name, description=seed.description), + ) + + +def seed_admin_user(engine: Engine, is_sqlite: bool) -> None: + with engine.begin() as conn: + # Use environment-configured admin settings when present so initializer + # aligns with the application's bootstrap configuration. 
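+        # Example (hypothetical values): CALMINER_SEED_ADMIN_EMAIL=ops@example.com
+        # overrides DEFAULT_ADMIN["email"]; any variable left unset falls back to
+        # the development default from DEFAULT_ADMIN above.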
+ admin_email = os.getenv( + "CALMINER_SEED_ADMIN_EMAIL", DEFAULT_ADMIN["email"]) + admin_username = os.getenv( + "CALMINER_SEED_ADMIN_USERNAME", DEFAULT_ADMIN["username"]) + admin_password = os.getenv( + "CALMINER_SEED_ADMIN_PASSWORD", DEFAULT_ADMIN["password"]) + u = UserSeed( + id=DEFAULT_ADMIN.get("id", 1), + email=admin_email, + username=admin_username, + password=admin_password, + is_active=DEFAULT_ADMIN.get("is_active", True), + is_superuser=DEFAULT_ADMIN.get("is_superuser", True), + ) + password_hash = password_context.hash(u.password) + # Upsert by username to avoid conflicting with different admin email configs + conn.execute( + text( + "INSERT INTO users (email, username, password_hash, is_active, is_superuser) " + "VALUES (:email, :username, :password_hash, :is_active, :is_superuser) " + "ON CONFLICT (username) DO UPDATE SET email = EXCLUDED.email, password_hash = EXCLUDED.password_hash, is_active = EXCLUDED.is_active, is_superuser = EXCLUDED.is_superuser" + ), + dict(email=u.email, username=u.username, password_hash=password_hash, + is_active=u.is_active, is_superuser=u.is_superuser), + ) + # ensure admin has admin role + # Resolve user_id for role assignment: select by username + row = conn.execute(text("SELECT id FROM users WHERE username = :username"), dict( + username=u.username)).fetchone() + if row is not None: + user_id = row.id + else: + user_id = None + if user_id is not None: + conn.execute( + text( + "INSERT INTO user_roles (user_id, role_id, granted_by) VALUES (:user_id, :role_id, :granted_by) " + "ON CONFLICT (user_id, role_id) DO NOTHING" + ), + dict(user_id=user_id, role_id=1, granted_by=user_id), + ) + + +def ensure_default_pricing(engine: Engine, is_sqlite: bool) -> None: + with engine.begin() as conn: + p = PricingSeed(**DEFAULT_PRICING) + # Try insert on slug conflict + conn.execute( + text( + "INSERT INTO pricing_settings (slug, name, description, default_currency, default_payable_pct, moisture_threshold_pct, moisture_penalty_per_pct) " + "VALUES (:slug, :name, :description, :default_currency, :default_payable_pct, :moisture_threshold_pct, :moisture_penalty_per_pct) " + "ON CONFLICT (slug) DO UPDATE SET name = EXCLUDED.name" + ), + dict( + slug=p.slug, + name=p.name, + description=p.description, + default_currency=p.default_currency, + default_payable_pct=p.default_payable_pct, + moisture_threshold_pct=p.moisture_threshold_pct, + moisture_penalty_per_pct=p.moisture_penalty_per_pct, + ), + ) + + +def seed_navigation(engine: Engine, is_sqlite: bool) -> None: + group_insert_sql = text( + """ + INSERT INTO navigation_groups (slug, label, sort_order, icon, tooltip, is_enabled) + VALUES (:slug, :label, :sort_order, :icon, :tooltip, :is_enabled) + ON CONFLICT (slug) DO UPDATE SET + label = EXCLUDED.label, + sort_order = EXCLUDED.sort_order, + icon = EXCLUDED.icon, + tooltip = EXCLUDED.tooltip, + is_enabled = EXCLUDED.is_enabled + """ + ) + + link_insert_sql = text( + """ + INSERT INTO navigation_links ( + group_id, parent_link_id, slug, label, route_name, href_override, + match_prefix, sort_order, icon, tooltip, required_roles, is_enabled, is_external + ) + VALUES ( + :group_id, :parent_link_id, :slug, :label, :route_name, :href_override, + :match_prefix, :sort_order, :icon, :tooltip, :required_roles, :is_enabled, :is_external + ) + ON CONFLICT (group_id, slug) DO UPDATE SET + parent_link_id = EXCLUDED.parent_link_id, + label = EXCLUDED.label, + route_name = EXCLUDED.route_name, + href_override = EXCLUDED.href_override, + match_prefix = EXCLUDED.match_prefix, + 
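+            -- required_roles is bound with a JSON type via bindparams (below),
+            -- so it serialises cleanly to TEXT (SQLite) and JSONB (Postgres)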
sort_order = EXCLUDED.sort_order, + icon = EXCLUDED.icon, + tooltip = EXCLUDED.tooltip, + required_roles = EXCLUDED.required_roles, + is_enabled = EXCLUDED.is_enabled, + is_external = EXCLUDED.is_external + """ + ) + link_insert_sql = link_insert_sql.bindparams( + bindparam("required_roles", type_=JSON) + ) + + with engine.begin() as conn: + role_rows = conn.execute(text("SELECT name FROM roles")).fetchall() + available_roles = {row.name for row in role_rows} + + def resolve_roles(raw_roles: list[str]) -> list[str]: + if not raw_roles: + return [] + + resolved: list[str] = [] + missing: list[str] = [] + for slug in raw_roles: + if slug == "anonymous": + if slug not in resolved: + resolved.append(slug) + continue + if slug in available_roles: + if slug not in resolved: + resolved.append(slug) + else: + missing.append(slug) + + if missing: + logger.warning( + "Navigation seed roles %s are missing; defaulting link access to admin only", + ", ".join(missing), + ) + if "admin" in available_roles and "admin" not in resolved: + resolved.append("admin") + + return resolved + + group_ids: dict[str, int] = {} + for group_seed in DEFAULT_NAVIGATION_GROUPS: + conn.execute( + group_insert_sql, + group_seed.model_dump(), + ) + row = conn.execute( + text("SELECT id FROM navigation_groups WHERE slug = :slug"), + {"slug": group_seed.slug}, + ).fetchone() + if row is not None: + group_ids[group_seed.slug] = row.id + + if not group_ids: + logger.warning( + "Navigation seeding skipped because no groups were inserted") + return + + link_ids: dict[tuple[str, str], int] = {} + parent_pending: list[NavigationLinkSeed] = [] + + for link_seed in DEFAULT_NAVIGATION_LINKS: + if link_seed.parent_slug: + parent_pending.append(link_seed) + continue + + group_id = group_ids.get(link_seed.group_slug) + if group_id is None: + logger.warning( + "Skipping navigation link '%s' because group '%s' is missing", + link_seed.slug, + link_seed.group_slug, + ) + continue + + resolved_roles = resolve_roles(link_seed.required_roles) + + payload = { + "group_id": group_id, + "parent_link_id": None, + "slug": link_seed.slug, + "label": link_seed.label, + "route_name": link_seed.route_name, + "href_override": link_seed.href_override, + "match_prefix": link_seed.match_prefix, + "sort_order": link_seed.sort_order, + "icon": link_seed.icon, + "tooltip": link_seed.tooltip, + "required_roles": resolved_roles, + "is_enabled": link_seed.is_enabled, + "is_external": link_seed.is_external, + } + conn.execute(link_insert_sql, payload) + row = conn.execute( + text( + "SELECT id FROM navigation_links WHERE group_id = :group_id AND slug = :slug" + ), + {"group_id": group_id, "slug": link_seed.slug}, + ).fetchone() + if row is not None: + link_ids[(link_seed.group_slug, link_seed.slug)] = row.id + + for link_seed in parent_pending: + group_id = group_ids.get(link_seed.group_slug) + if group_id is None: + logger.warning( + "Skipping child navigation link '%s' because group '%s' is missing", + link_seed.slug, + link_seed.group_slug, + ) + continue + + parent_key = (link_seed.group_slug, link_seed.parent_slug or "") + parent_id = link_ids.get(parent_key) + if parent_id is None: + parent_row = conn.execute( + text( + "SELECT id FROM navigation_links WHERE group_id = :group_id AND slug = :slug" + ), + {"group_id": group_id, "slug": link_seed.parent_slug}, + ).fetchone() + parent_id = parent_row.id if parent_row else None + + if parent_id is None: + logger.warning( + "Skipping child navigation link '%s' because parent '%s' is missing", + link_seed.slug, + 
link_seed.parent_slug, + ) + continue + + resolved_roles = resolve_roles(link_seed.required_roles) + + payload = { + "group_id": group_id, + "parent_link_id": parent_id, + "slug": link_seed.slug, + "label": link_seed.label, + "route_name": link_seed.route_name, + "href_override": link_seed.href_override, + "match_prefix": link_seed.match_prefix, + "sort_order": link_seed.sort_order, + "icon": link_seed.icon, + "tooltip": link_seed.tooltip, + "required_roles": resolved_roles, + "is_enabled": link_seed.is_enabled, + "is_external": link_seed.is_external, + } + conn.execute(link_insert_sql, payload) + row = conn.execute( + text( + "SELECT id FROM navigation_links WHERE group_id = :group_id AND slug = :slug" + ), + {"group_id": group_id, "slug": link_seed.slug}, + ).fetchone() + if row is not None: + link_ids[(link_seed.group_slug, link_seed.slug)] = row.id + + +def _project_id_by_name(conn, project_name: str) -> Optional[int]: + row = conn.execute( + text("SELECT id FROM projects WHERE name = :name"), + {"name": project_name}, + ).fetchone() + return row.id if row else None + + +def ensure_default_projects(engine: Engine, is_sqlite: bool) -> None: + with engine.begin() as conn: + for project in DEFAULT_PROJECTS: + conn.execute( + text( + """ + INSERT INTO projects (name, location, operation_type, description) + VALUES (:name, :location, :operation_type, :description) + ON CONFLICT (name) DO UPDATE SET + location = EXCLUDED.location, + operation_type = EXCLUDED.operation_type, + description = EXCLUDED.description + """ + ), + project.model_dump(), + ) + + +def ensure_default_scenarios(engine: Engine, is_sqlite: bool) -> None: + with engine.begin() as conn: + for scenario in DEFAULT_SCENARIOS: + project_id = _project_id_by_name(conn, scenario.project_name) + if project_id is None: + logger.warning( + "Skipping scenario seed '%s' because project '%s' does not exist", + scenario.name, + scenario.project_name, + ) + continue + + payload = scenario.model_dump(exclude={"project_name"}) + payload.update({"project_id": project_id}) + if is_sqlite: + sql = """ + INSERT INTO scenarios ( + project_id, name, description, status, discount_rate, + currency, primary_resource + ) + VALUES ( + :project_id, :name, :description, :status, + :discount_rate, :currency, :primary_resource + ) + ON CONFLICT (project_id, name) DO UPDATE SET + description = EXCLUDED.description, + status = EXCLUDED.status, + discount_rate = EXCLUDED.discount_rate, + currency = EXCLUDED.currency, + primary_resource = EXCLUDED.primary_resource + """ + else: + sql = """ + INSERT INTO scenarios ( + project_id, name, description, status, discount_rate, + currency, primary_resource + ) + VALUES ( + :project_id, :name, :description, CAST(:status AS scenariostatus), + :discount_rate, :currency, + CASE WHEN :primary_resource IS NULL + THEN NULL + ELSE CAST(:primary_resource AS resourcetype) + END + ) + ON CONFLICT (project_id, name) DO UPDATE SET + description = EXCLUDED.description, + status = EXCLUDED.status, + discount_rate = EXCLUDED.discount_rate, + currency = EXCLUDED.currency, + primary_resource = EXCLUDED.primary_resource + """ + conn.execute(text(sql), payload) + + +def ensure_default_financial_inputs(engine: Engine, is_sqlite: bool) -> None: + with engine.begin() as conn: + for item in DEFAULT_FINANCIAL_INPUTS: + project_id = _project_id_by_name(conn, item.project_name) + if project_id is None: + logger.warning( + "Skipping financial input '%s'; project '%s' missing", + item.name, + item.project_name, + ) + continue + + scenario_row = 
conn.execute(
+                text(
+                    "SELECT id FROM scenarios WHERE project_id = :project_id AND name = :name"
+                ),
+                {"project_id": project_id, "name": item.scenario_name},
+            ).fetchone()
+            if scenario_row is None:
+                logger.warning(
+                    "Skipping financial input '%s'; scenario '%s' missing for project '%s'",
+                    item.name,
+                    item.scenario_name,
+                    item.project_name,
+                )
+                continue
+
+            payload = item.model_dump(
+                exclude={"project_name", "scenario_name"},
+            )
+            if is_sqlite:
+                # Convert Decimal to float for SQLite
+                payload["amount"] = float(payload["amount"])
+            payload.update({"scenario_id": scenario_row.id})
+            if is_sqlite:
+                sql = """
+                    INSERT INTO financial_inputs (
+                        scenario_id, name, category, cost_bucket, amount, currency, notes
+                    )
+                    VALUES (
+                        :scenario_id, :name, :category, :cost_bucket,
+                        :amount, :currency, :notes
+                    )
+                    ON CONFLICT (scenario_id, name) DO UPDATE SET
+                        category = EXCLUDED.category,
+                        cost_bucket = EXCLUDED.cost_bucket,
+                        amount = EXCLUDED.amount,
+                        currency = EXCLUDED.currency,
+                        notes = EXCLUDED.notes
+                """
+            else:
+                sql = """
+                    INSERT INTO financial_inputs (
+                        scenario_id, name, category, cost_bucket, amount, currency, notes
+                    )
+                    VALUES (
+                        :scenario_id, :name, CAST(:category AS financialcategory),
+                        CASE WHEN :cost_bucket IS NULL THEN NULL
+                             ELSE CAST(:cost_bucket AS costbucket)
+                        END,
+                        :amount,
+                        :currency,
+                        :notes
+                    )
+                    ON CONFLICT (scenario_id, name) DO UPDATE SET
+                        category = EXCLUDED.category,
+                        cost_bucket = EXCLUDED.cost_bucket,
+                        amount = EXCLUDED.amount,
+                        currency = EXCLUDED.currency,
+                        notes = EXCLUDED.notes
+                """
+            conn.execute(text(sql), payload)
+
+
+def init_db(database_url: Optional[str] = None) -> None:
+    """Run the idempotent initialization sequence.
+
+    Steps:
+    - Ensure enum types exist and legacy mixed-case enum values are normalized.
+    - Ensure required tables exist.
+    - Ensure unique constraints exist.
+    - Seed roles and the admin user.
+    - Ensure the default pricing settings record exists.
+    - Seed navigation groups and links.
+    - Seed sample projects, scenarios, and financial inputs.
+    """
+    database_url = database_url or _get_database_url()
+    is_sqlite = _is_sqlite(database_url)
+    engine = _create_engine(database_url)
+    logger.info("Starting DB initialization using engine=%s", engine)
+    ensure_enums(engine, is_sqlite)
+    normalize_enum_values(engine, is_sqlite)
+    ensure_tables(engine, is_sqlite)
+    ensure_constraints(engine, is_sqlite)
+    seed_roles(engine, is_sqlite)
+    seed_admin_user(engine, is_sqlite)
+    ensure_default_pricing(engine, is_sqlite)
+    seed_navigation(engine, is_sqlite)
+    ensure_default_projects(engine, is_sqlite)
+    ensure_default_scenarios(engine, is_sqlite)
+    ensure_default_financial_inputs(engine, is_sqlite)
+    logger.info("DB initialization complete")
+
+
+if __name__ == "__main__":
+    # Allow running manually: python -m scripts.init_db
+    logging.basicConfig(level=logging.INFO)
+    init_db()
diff --git a/scripts/initial_data.py b/scripts/initial_data.py
new file mode 100644
index 0000000..fa752e8
--- /dev/null
+++ b/scripts/initial_data.py
@@ -0,0 +1,231 @@
+from __future__ import annotations
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import Callable, Iterable
+
+from dotenv import load_dotenv
+
+from config.settings import Settings
+from models import Role, User
+from services.repositories import (
+    DEFAULT_ROLE_DEFINITIONS,
+    PricingSettingsSeedResult,
+    RoleRepository,
+    UserRepository,
+    ensure_default_pricing_settings,
+)
+from services.unit_of_work import UnitOfWork
+
+
+@dataclass
+class SeedConfig:
+    admin_email: str
+    admin_username: str
+    admin_password: str
+    admin_roles: tuple[str, ...]
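+    # admin_roles (above) is de-duplicated and guaranteed to contain "admin";
+    # see normalise_role_list below.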
+ force_reset: bool + + +@dataclass +class RoleSeedResult: + created: int + updated: int + total: int + + +@dataclass +class AdminSeedResult: + created_user: bool + updated_user: bool + password_rotated: bool + roles_granted: int + + +def parse_bool(value: str | None) -> bool: + if value is None: + return False + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def normalise_role_list(raw_value: str | None) -> tuple[str, ...]: + if not raw_value: + return ("admin",) + parts = [segment.strip() + for segment in raw_value.split(",") if segment.strip()] + if "admin" not in parts: + parts.insert(0, "admin") + seen: set[str] = set() + ordered: list[str] = [] + for role_name in parts: + if role_name not in seen: + ordered.append(role_name) + seen.add(role_name) + return tuple(ordered) + + +def load_config() -> SeedConfig: + load_dotenv() + admin_email = os.getenv("CALMINER_SEED_ADMIN_EMAIL", + "admin@calminer.local") + admin_username = os.getenv("CALMINER_SEED_ADMIN_USERNAME", "admin") + admin_password = os.getenv("CALMINER_SEED_ADMIN_PASSWORD", "ChangeMe123!") + admin_roles = normalise_role_list(os.getenv("CALMINER_SEED_ADMIN_ROLES")) + force_reset = parse_bool(os.getenv("CALMINER_SEED_FORCE")) + return SeedConfig( + admin_email=admin_email, + admin_username=admin_username, + admin_password=admin_password, + admin_roles=admin_roles, + force_reset=force_reset, + ) + + +def ensure_default_roles( + role_repo: RoleRepository, + definitions: Iterable[dict[str, str]] = DEFAULT_ROLE_DEFINITIONS, +) -> RoleSeedResult: + created = 0 + updated = 0 + total = 0 + for definition in definitions: + total += 1 + existing = role_repo.get_by_name(definition["name"]) + if existing is None: + role_repo.create(Role(**definition)) + created += 1 + continue + changed = False + if existing.display_name != definition["display_name"]: + existing.display_name = definition["display_name"] + changed = True + if existing.description != definition["description"]: + existing.description = definition["description"] + changed = True + if changed: + updated += 1 + role_repo.session.flush() + return RoleSeedResult(created=created, updated=updated, total=total) + + +def ensure_admin_user( + user_repo: UserRepository, + role_repo: RoleRepository, + config: SeedConfig, +) -> AdminSeedResult: + created_user = False + updated_user = False + password_rotated = False + roles_granted = 0 + + user = user_repo.get_by_email(config.admin_email, with_roles=True) + if user is None: + user = User( + email=config.admin_email, + username=config.admin_username, + password_hash=User.hash_password(config.admin_password), + is_active=True, + is_superuser=True, + ) + user_repo.create(user) + created_user = True + else: + if user.username != config.admin_username: + user.username = config.admin_username + updated_user = True + if not user.is_active: + user.is_active = True + updated_user = True + if not user.is_superuser: + user.is_superuser = True + updated_user = True + if config.force_reset: + user.password_hash = User.hash_password(config.admin_password) + password_rotated = True + updated_user = True + user_repo.session.flush() + + for role_name in config.admin_roles: + role = role_repo.get_by_name(role_name) + if role is None: + logging.warning( + "Role '%s' is not defined and will be skipped", role_name) + continue + already_assigned = any(assignment.role_id == + role.id for assignment in user.role_assignments) + if already_assigned: + continue + user_repo.assign_role( + user_id=user.id, role_id=role.id, granted_by=user.id) + 
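+            # Count only grants performed in this run; assignments that already
+            # existed were skipped by the already_assigned check above.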
roles_granted += 1 + + return AdminSeedResult( + created_user=created_user, + updated_user=updated_user, + password_rotated=password_rotated, + roles_granted=roles_granted, + ) + + +def seed_initial_data( + config: SeedConfig, + *, + unit_of_work_factory: Callable[[], UnitOfWork] | None = None, +) -> None: + logging.info("Starting initial data seeding") + factory = unit_of_work_factory or UnitOfWork + with factory() as uow: + assert ( + uow.roles is not None + and uow.users is not None + and uow.pricing_settings is not None + and uow.projects is not None + ) + role_result = ensure_default_roles(uow.roles) + admin_result = ensure_admin_user(uow.users, uow.roles, config) + pricing_metadata = uow.get_pricing_metadata() + metadata_source = "database" + if pricing_metadata is None: + pricing_metadata = Settings.from_environment().pricing_metadata() + metadata_source = "environment" + pricing_result: PricingSettingsSeedResult = ensure_default_pricing_settings( + uow.pricing_settings, + metadata=pricing_metadata, + ) + + projects_without_pricing = [ + project + for project in uow.projects.list(with_pricing=True) + if project.pricing_settings is None + ] + assigned_projects = 0 + for project in projects_without_pricing: + uow.set_project_pricing_settings(project, pricing_result.settings) + assigned_projects += 1 + logging.info( + "Roles processed: %s total, %s created, %s updated", + role_result.total, + role_result.created, + role_result.updated, + ) + logging.info( + "Admin user: created=%s updated=%s password_rotated=%s roles_granted=%s", + admin_result.created_user, + admin_result.updated_user, + admin_result.password_rotated, + admin_result.roles_granted, + ) + logging.info( + "Pricing settings ensured (source=%s): slug=%s created=%s updated_fields=%s impurity_upserts=%s", + metadata_source, + pricing_result.settings.slug, + pricing_result.created, + pricing_result.updated_fields, + pricing_result.impurity_upserts, + ) + logging.info( + "Projects updated with default pricing settings: %s", + assigned_projects, + ) + logging.info("Initial data seeding completed successfully") diff --git a/scripts/migrations/000_base.sql b/scripts/migrations/000_base.sql deleted file mode 100644 index 11f9358..0000000 --- a/scripts/migrations/000_base.sql +++ /dev/null @@ -1,189 +0,0 @@ --- Baseline migration for CalMiner database schema --- Date: 2025-10-25 --- Purpose: Consolidate foundational tables and reference data - -BEGIN; - --- Currency reference table -CREATE TABLE IF NOT EXISTS currency ( - id SERIAL PRIMARY KEY, - code VARCHAR(3) NOT NULL UNIQUE, - name VARCHAR(128) NOT NULL, - symbol VARCHAR(8), - is_active BOOLEAN NOT NULL DEFAULT TRUE -); - -INSERT INTO currency (code, name, symbol, is_active) -VALUES - ('USD', 'United States Dollar', 'USD$', TRUE), - ('EUR', 'Euro', 'EUR', TRUE), - ('CLP', 'Chilean Peso', 'CLP$', TRUE), - ('RMB', 'Chinese Yuan', 'RMB', TRUE), - ('GBP', 'British Pound', 'GBP', TRUE), - ('CAD', 'Canadian Dollar', 'CAD$', TRUE), - ('AUD', 'Australian Dollar', 'AUD$', TRUE) -ON CONFLICT (code) DO UPDATE -SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - is_active = EXCLUDED.is_active; - --- Application-level settings table -CREATE TABLE IF NOT EXISTS application_setting ( - id SERIAL PRIMARY KEY, - key VARCHAR(128) NOT NULL UNIQUE, - value TEXT NOT NULL, - value_type VARCHAR(32) NOT NULL DEFAULT 'string', - category VARCHAR(32) NOT NULL DEFAULT 'general', - description TEXT, - is_editable BOOLEAN NOT NULL DEFAULT TRUE, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - 
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE UNIQUE INDEX IF NOT EXISTS ux_application_setting_key - ON application_setting (key); - -CREATE INDEX IF NOT EXISTS ix_application_setting_category - ON application_setting (category); - --- Measurement unit reference table -CREATE TABLE IF NOT EXISTS measurement_unit ( - id SERIAL PRIMARY KEY, - code VARCHAR(64) NOT NULL UNIQUE, - name VARCHAR(128) NOT NULL, - symbol VARCHAR(16), - unit_type VARCHAR(32) NOT NULL, - is_active BOOLEAN NOT NULL DEFAULT TRUE, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -INSERT INTO measurement_unit (code, name, symbol, unit_type, is_active) -VALUES - ('tonnes', 'Tonnes', 't', 'mass', TRUE), - ('kilograms', 'Kilograms', 'kg', 'mass', TRUE), - ('pounds', 'Pounds', 'lb', 'mass', TRUE), - ('liters', 'Liters', 'L', 'volume', TRUE), - ('cubic_meters', 'Cubic Meters', 'm3', 'volume', TRUE), - ('kilowatt_hours', 'Kilowatt Hours', 'kWh', 'energy', TRUE) -ON CONFLICT (code) DO UPDATE -SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - unit_type = EXCLUDED.unit_type, - is_active = EXCLUDED.is_active; - --- Consumption and production measurement metadata -ALTER TABLE consumption - ADD COLUMN IF NOT EXISTS unit_name VARCHAR(64); -ALTER TABLE consumption - ADD COLUMN IF NOT EXISTS unit_symbol VARCHAR(16); - -ALTER TABLE production_output - ADD COLUMN IF NOT EXISTS unit_name VARCHAR(64); -ALTER TABLE production_output - ADD COLUMN IF NOT EXISTS unit_symbol VARCHAR(16); - --- Currency integration for CAPEX and OPEX -ALTER TABLE capex - ADD COLUMN IF NOT EXISTS currency_id INTEGER; -ALTER TABLE opex - ADD COLUMN IF NOT EXISTS currency_id INTEGER; - -DO $$ -DECLARE - usd_id INTEGER; -BEGIN - -- Ensure currency_id columns align with legacy currency_code values when present - IF EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_name = 'capex' AND column_name = 'currency_code' - ) THEN - UPDATE capex AS c - SET currency_id = cur.id - FROM currency AS cur - WHERE c.currency_code = cur.code - AND (c.currency_id IS DISTINCT FROM cur.id); - END IF; - - IF EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_name = 'opex' AND column_name = 'currency_code' - ) THEN - UPDATE opex AS o - SET currency_id = cur.id - FROM currency AS cur - WHERE o.currency_code = cur.code - AND (o.currency_id IS DISTINCT FROM cur.id); - END IF; - - SELECT id INTO usd_id FROM currency WHERE code = 'USD'; - IF usd_id IS NOT NULL THEN - UPDATE capex SET currency_id = usd_id WHERE currency_id IS NULL; - UPDATE opex SET currency_id = usd_id WHERE currency_id IS NULL; - END IF; -END $$; - -ALTER TABLE capex - ALTER COLUMN currency_id SET NOT NULL; -ALTER TABLE opex - ALTER COLUMN currency_id SET NOT NULL; - -DO $$ -BEGIN - IF NOT EXISTS ( - SELECT 1 FROM information_schema.table_constraints - WHERE table_schema = current_schema() - AND table_name = 'capex' - AND constraint_name = 'fk_capex_currency' - ) THEN - ALTER TABLE capex - ADD CONSTRAINT fk_capex_currency FOREIGN KEY (currency_id) - REFERENCES currency (id) ON DELETE RESTRICT; - END IF; - - IF NOT EXISTS ( - SELECT 1 FROM information_schema.table_constraints - WHERE table_schema = current_schema() - AND table_name = 'opex' - AND constraint_name = 'fk_opex_currency' - ) THEN - ALTER TABLE opex - ADD CONSTRAINT fk_opex_currency FOREIGN KEY (currency_id) - REFERENCES currency (id) ON DELETE RESTRICT; - END IF; -END $$; - -ALTER TABLE capex - DROP COLUMN IF EXISTS currency_code; -ALTER TABLE opex - DROP COLUMN 
IF EXISTS currency_code;
-
--- Role-based access control tables
-CREATE TABLE IF NOT EXISTS roles (
-    id SERIAL PRIMARY KEY,
-    name VARCHAR(255) UNIQUE NOT NULL
-);
-
-CREATE TABLE IF NOT EXISTS users (
-    id SERIAL PRIMARY KEY,
-    username VARCHAR(255) UNIQUE NOT NULL,
-    email VARCHAR(255) UNIQUE NOT NULL,
-    hashed_password VARCHAR(255) NOT NULL,
-    role_id INTEGER NOT NULL REFERENCES roles (id) ON DELETE RESTRICT
-);
-
-CREATE INDEX IF NOT EXISTS ix_users_username ON users (username);
-CREATE INDEX IF NOT EXISTS ix_users_email ON users (email);
-
--- Theme settings configuration table
-CREATE TABLE IF NOT EXISTS theme_settings (
-    id SERIAL PRIMARY KEY,
-    theme_name VARCHAR(255) UNIQUE NOT NULL,
-    primary_color VARCHAR(7) NOT NULL,
-    secondary_color VARCHAR(7) NOT NULL,
-    accent_color VARCHAR(7) NOT NULL,
-    background_color VARCHAR(7) NOT NULL,
-    text_color VARCHAR(7) NOT NULL
-);
-
-COMMIT;
diff --git a/scripts/reset_db.py b/scripts/reset_db.py
new file mode 100644
index 0000000..cd5a3c1
--- /dev/null
+++ b/scripts/reset_db.py
@@ -0,0 +1,91 @@
+"""Utility to reset development Postgres schema artifacts.
+
+This script drops managed tables and enum types created by `scripts.init_db`.
+It is intended for local development only; it refuses to run if CALMINER_ENV
+indicates production or staging. The operation is idempotent: missing objects
+are ignored. Run manually with `python -m scripts.reset_db`. Use with caution.
+"""
+from __future__ import annotations
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import Iterable
+
+from sqlalchemy import text
+from sqlalchemy.engine import Engine
+
+from config.database import DATABASE_URL
+from scripts.init_db import ENUM_DEFINITIONS, _create_engine
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(slots=True)
+class ResetOptions:
+    drop_tables: bool = True
+    drop_enums: bool = True
+
+
+# Ordered children-first so explicit drops never hit a dependent table,
+# even though DROP ... CASCADE would also handle it. Navigation tables are
+# included because scripts.init_db creates and seeds them as well.
+MANAGED_TABLES: tuple[str, ...] = (
+    "simulation_parameters",
+    "financial_inputs",
+    "scenarios",
+    "projects",
+    "pricing_impurity_settings",
+    "pricing_metal_settings",
+    "pricing_settings",
+    "navigation_links",
+    "navigation_groups",
+    "user_roles",
+    "users",
+    "roles",
+)
+
+
+FORBIDDEN_ENVIRONMENTS: set[str] = {"production", "staging", "prod", "stage"}
+
+
+def _ensure_safe_environment() -> None:
+    env = os.getenv("CALMINER_ENV", "development").lower()
+    if env in FORBIDDEN_ENVIRONMENTS:
+        raise RuntimeError(
+            f"Refusing to reset database in environment '{env}'. "
+            "Set CALMINER_ENV to 'development' to proceed."
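+            # NB: the bundled .env templates set ENVIRONMENT=..., whereas this
+            # guard reads CALMINER_ENV; deployments presumably need to set both
+            # consistently (assumption; not verified against config.settings).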
+ ) + + +def _drop_tables(engine: Engine, tables: Iterable[str]) -> None: + if not tables: + return + with engine.begin() as conn: + for table in tables: + logger.info("Dropping table if exists: %s", table) + conn.execute(text(f"DROP TABLE IF EXISTS {table} CASCADE")) + + +def _drop_enums(engine: Engine, enum_names: Iterable[str]) -> None: + if not enum_names: + return + with engine.begin() as conn: + for enum_name in enum_names: + logger.info("Dropping enum type if exists: %s", enum_name) + conn.execute(text(f"DROP TYPE IF EXISTS {enum_name} CASCADE")) + + +def reset_database(*, options: ResetOptions | None = None, database_url: str | None = None) -> None: + """Drop managed tables and enums for a clean slate.""" + _ensure_safe_environment() + opts = options or ResetOptions() + engine = _create_engine(database_url or DATABASE_URL) + + if opts.drop_tables: + _drop_tables(engine, MANAGED_TABLES) + + if opts.drop_enums: + _drop_enums(engine, ENUM_DEFINITIONS.keys()) + + logger.info("Database reset complete") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + reset_database() diff --git a/scripts/seed_data.py b/scripts/seed_data.py deleted file mode 100644 index b762d04..0000000 --- a/scripts/seed_data.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Seed baseline data for CalMiner in an idempotent manner. - -Usage examples --------------- - -```powershell -# Use existing environment variables (or load from setup_test.env.example) -python scripts/seed_data.py --currencies --units --defaults - -# Dry-run to preview actions -python scripts/seed_data.py --currencies --dry-run -``` -""" - -from __future__ import annotations - -import argparse -import logging -from typing import Optional - -import psycopg2 -from psycopg2 import errors -from psycopg2.extras import execute_values - -from scripts.setup_database import DatabaseConfig - - -logger = logging.getLogger(__name__) - -CURRENCY_SEEDS = ( - ("USD", "United States Dollar", "USD$", True), - ("EUR", "Euro", "EUR", True), - ("CLP", "Chilean Peso", "CLP$", True), - ("RMB", "Chinese Yuan", "RMB", True), - ("GBP", "British Pound", "GBP", True), - ("CAD", "Canadian Dollar", "CAD$", True), - ("AUD", "Australian Dollar", "AUD$", True), -) - -MEASUREMENT_UNIT_SEEDS = ( - ("tonnes", "Tonnes", "t", "mass", True), - ("kilograms", "Kilograms", "kg", "mass", True), - ("pounds", "Pounds", "lb", "mass", True), - ("liters", "Liters", "L", "volume", True), - ("cubic_meters", "Cubic Meters", "m3", "volume", True), - ("kilowatt_hours", "Kilowatt Hours", "kWh", "energy", True), -) - -THEME_SETTING_SEEDS = ( - ("--color-background", "#f4f5f7", "color", - "theme", "CSS variable --color-background", True), - ("--color-surface", "#ffffff", "color", - "theme", "CSS variable --color-surface", True), - ("--color-text-primary", "#2a1f33", "color", - "theme", "CSS variable --color-text-primary", True), - ("--color-text-secondary", "#624769", "color", - "theme", "CSS variable --color-text-secondary", True), - ("--color-text-muted", "#64748b", "color", - "theme", "CSS variable --color-text-muted", True), - ("--color-text-subtle", "#94a3b8", "color", - "theme", "CSS variable --color-text-subtle", True), - ("--color-text-invert", "#ffffff", "color", - "theme", "CSS variable --color-text-invert", True), - ("--color-text-dark", "#0f172a", "color", - "theme", "CSS variable --color-text-dark", True), - ("--color-text-strong", "#111827", "color", - "theme", "CSS variable --color-text-strong", True), - ("--color-primary", "#5f320d", "color", - "theme", "CSS variable 
--color-primary", True), - ("--color-primary-strong", "#7e4c13", "color", - "theme", "CSS variable --color-primary-strong", True), - ("--color-primary-stronger", "#837c15", "color", - "theme", "CSS variable --color-primary-stronger", True), - ("--color-accent", "#bff838", "color", - "theme", "CSS variable --color-accent", True), - ("--color-border", "#e2e8f0", "color", - "theme", "CSS variable --color-border", True), - ("--color-border-strong", "#cbd5e1", "color", - "theme", "CSS variable --color-border-strong", True), - ("--color-highlight", "#eef2ff", "color", - "theme", "CSS variable --color-highlight", True), - ("--color-panel-shadow", "rgba(15, 23, 42, 0.08)", "color", - "theme", "CSS variable --color-panel-shadow", True), - ("--color-panel-shadow-deep", "rgba(15, 23, 42, 0.12)", "color", - "theme", "CSS variable --color-panel-shadow-deep", True), - ("--color-surface-alt", "#f8fafc", "color", - "theme", "CSS variable --color-surface-alt", True), - ("--color-success", "#047857", "color", - "theme", "CSS variable --color-success", True), - ("--color-error", "#b91c1c", "color", - "theme", "CSS variable --color-error", True), -) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Seed baseline CalMiner data") - parser.add_argument( - "--currencies", action="store_true", help="Seed currency table" - ) - parser.add_argument("--units", action="store_true", help="Seed unit table") - parser.add_argument( - "--theme", action="store_true", help="Seed theme settings" - ) - parser.add_argument( - "--defaults", action="store_true", help="Seed default records" - ) - parser.add_argument( - "--dry-run", action="store_true", help="Print actions without executing" - ) - parser.add_argument( - "--verbose", - "-v", - action="count", - default=0, - help="Increase logging verbosity", - ) - return parser.parse_args() - - -def _configure_logging(args: argparse.Namespace) -> None: - level = logging.WARNING - (10 * min(args.verbose, 2)) - logging.basicConfig( - level=max(level, logging.INFO), format="%(levelname)s %(message)s" - ) - - -def main() -> None: - args = parse_args() - run_with_namespace(args) - - -def run_with_namespace( - args: argparse.Namespace, - *, - config: Optional[DatabaseConfig] = None, -) -> None: - if not hasattr(args, "verbose"): - args.verbose = 0 - if not hasattr(args, "dry_run"): - args.dry_run = False - - _configure_logging(args) - - currencies = bool(getattr(args, "currencies", False)) - units = bool(getattr(args, "units", False)) - theme = bool(getattr(args, "theme", False)) - defaults = bool(getattr(args, "defaults", False)) - dry_run = bool(getattr(args, "dry_run", False)) - - if not any((currencies, units, theme, defaults)): - logger.info("No seeding options provided; exiting") - return - - config = config or DatabaseConfig.from_env() - - with psycopg2.connect(config.application_dsn()) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - if currencies: - _seed_currencies(cursor, dry_run=dry_run) - if units: - _seed_units(cursor, dry_run=dry_run) - if theme: - _seed_theme(cursor, dry_run=dry_run) - if defaults: - _seed_defaults(cursor, dry_run=dry_run) - - -def _seed_currencies(cursor, *, dry_run: bool) -> None: - logger.info("Seeding currency table (%d rows)", len(CURRENCY_SEEDS)) - if dry_run: - for code, name, symbol, active in CURRENCY_SEEDS: - logger.info("Dry run: would upsert currency %s (%s)", code, name) - return - - execute_values( - cursor, - """ - INSERT INTO currency (code, name, symbol, is_active) - VALUES %s - 
ON CONFLICT (code) DO UPDATE - SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - is_active = EXCLUDED.is_active - """, - CURRENCY_SEEDS, - ) - logger.info("Currency seed complete") - - -def _seed_units(cursor, *, dry_run: bool) -> None: - total = len(MEASUREMENT_UNIT_SEEDS) - logger.info("Seeding measurement_unit table (%d rows)", total) - if dry_run: - for code, name, symbol, unit_type, _ in MEASUREMENT_UNIT_SEEDS: - logger.info( - "Dry run: would upsert measurement unit %s (%s - %s)", - code, - name, - unit_type, - ) - return - - try: - execute_values( - cursor, - """ - INSERT INTO measurement_unit (code, name, symbol, unit_type, is_active) - VALUES %s - ON CONFLICT (code) DO UPDATE - SET name = EXCLUDED.name, - symbol = EXCLUDED.symbol, - unit_type = EXCLUDED.unit_type, - is_active = EXCLUDED.is_active - """, - MEASUREMENT_UNIT_SEEDS, - ) - except errors.UndefinedTable: - logger.warning( - "measurement_unit table does not exist; skipping unit seeding." - ) - cursor.connection.rollback() - return - - logger.info("Measurement unit seed complete") - - -def _seed_theme(cursor, *, dry_run: bool) -> None: - logger.info("Seeding theme settings (%d rows)", len(THEME_SETTING_SEEDS)) - if dry_run: - for key, value, _, _, _, _ in THEME_SETTING_SEEDS: - logger.info( - "Dry run: would upsert theme setting %s = %s", key, value) - return - - try: - execute_values( - cursor, - """ - INSERT INTO application_setting (key, value, value_type, category, description, is_editable) - VALUES %s - ON CONFLICT (key) DO UPDATE - SET value = EXCLUDED.value, - value_type = EXCLUDED.value_type, - category = EXCLUDED.category, - description = EXCLUDED.description, - is_editable = EXCLUDED.is_editable - """, - THEME_SETTING_SEEDS, - ) - except errors.UndefinedTable: - logger.warning( - "application_setting table does not exist; skipping theme seeding." - ) - cursor.connection.rollback() - return - - logger.info("Theme settings seed complete") - - -def _seed_defaults(cursor, *, dry_run: bool) -> None: - logger.info("Seeding default records") - _seed_theme(cursor, dry_run=dry_run) - logger.info("Default records seed complete") - - -if __name__ == "__main__": - main() diff --git a/scripts/setup_database.py b/scripts/setup_database.py deleted file mode 100644 index 918d1e6..0000000 --- a/scripts/setup_database.py +++ /dev/null @@ -1,1233 +0,0 @@ -"""Utilities to bootstrap the CalMiner PostgreSQL database. - -This script is designed to be idempotent. Each step checks the existing -state before attempting to modify it so repeated executions are safe. - -Environment variables (with defaults) used when establishing connections: - -* ``DATABASE_DRIVER`` (``postgresql``) -* ``DATABASE_HOST`` (required) -* ``DATABASE_PORT`` (``5432``) -* ``DATABASE_NAME`` (required) -* ``DATABASE_USER`` (required) -* ``DATABASE_PASSWORD`` (optional, required for password auth) -* ``DATABASE_SCHEMA`` (``public``) -* ``DATABASE_ADMIN_URL`` (overrides individual admin settings) -* ``DATABASE_SUPERUSER`` (falls back to ``DATABASE_USER`` or ``postgres``) -* ``DATABASE_SUPERUSER_PASSWORD`` (falls back to ``DATABASE_PASSWORD``) -* ``DATABASE_SUPERUSER_DB`` (``postgres``) - -Set ``DATABASE_URL`` if other parts of the application rely on a single -connection string; this script will still honor the granular inputs above. 
-""" - -from __future__ import annotations -from config.database import Base -import argparse -import importlib -import logging -import os -import pkgutil -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Optional, cast -from urllib.parse import quote_plus, urlencode -import psycopg2 -from psycopg2 import errors -from psycopg2 import sql -from psycopg2 import extensions -from psycopg2.extensions import connection as PGConnection, parse_dsn -from dotenv import load_dotenv -from sqlalchemy import create_engine, inspect - -ROOT_DIR = Path(__file__).resolve().parents[1] -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - - -logger = logging.getLogger(__name__) - -SCRIPTS_DIR = Path(__file__).resolve().parent -DEFAULT_MIGRATIONS_DIR = SCRIPTS_DIR / "migrations" -MIGRATIONS_TABLE = "schema_migrations" - - -@dataclass(slots=True) -class DatabaseConfig: - """Configuration required to manage the application database.""" - - driver: str - host: str - port: int - database: str - user: str - password: Optional[str] - schema: Optional[str] - - admin_user: str - admin_password: Optional[str] - admin_database: str = "postgres" - - @classmethod - def from_env( - cls, - overrides: Optional[dict[str, Optional[str]]] = None, - ) -> "DatabaseConfig": - load_dotenv() - - override_map: dict[str, Optional[str]] = dict(overrides or {}) - - def _get(name: str, default: Optional[str] = None) -> Optional[str]: - if name in override_map and override_map[name] is not None: - return override_map[name] - env_value = os.getenv(name) - if env_value is not None: - return env_value - return default - - driver = _get("DATABASE_DRIVER", "postgresql") - host = _get("DATABASE_HOST") - port_value = _get("DATABASE_PORT", "5432") - database = _get("DATABASE_NAME") - user = _get("DATABASE_USER") - password = _get("DATABASE_PASSWORD") - schema = _get("DATABASE_SCHEMA", "public") - - try: - port = int(port_value) if port_value is not None else 5432 - except ValueError as exc: - raise RuntimeError( - "Invalid DATABASE_PORT value: expected integer, got" - f" '{port_value}'" - ) from exc - - admin_url = _get("DATABASE_ADMIN_URL") - if admin_url: - admin_conninfo = parse_dsn(admin_url) - admin_user = admin_conninfo.get("user") or user or "postgres" - admin_password = admin_conninfo.get("password") - admin_database = admin_conninfo.get("dbname") or "postgres" - host = admin_conninfo.get("host") or host - port = int(admin_conninfo.get("port") or port) - else: - admin_user = _get("DATABASE_SUPERUSER", user or "postgres") - admin_password = _get("DATABASE_SUPERUSER_PASSWORD", password) - admin_database = _get("DATABASE_SUPERUSER_DB", "postgres") - - missing = [ - name - for name, value in ( - ("DATABASE_HOST", host), - ("DATABASE_NAME", database), - ("DATABASE_USER", user), - ) - if not value - ] - if missing: - raise RuntimeError( - "Missing required database configuration: " + - ", ".join(missing) - ) - - host = cast(str, host) - database = cast(str, database) - user = cast(str, user) - driver = cast(str, driver) - admin_user = cast(str, admin_user) - admin_database = cast(str, admin_database) - - return cls( - driver=driver, - host=host, - port=port, - database=database, - user=user, - password=password, - schema=schema, - admin_user=admin_user, - admin_password=admin_password, - admin_database=admin_database, - ) - - def admin_dsn(self, database: Optional[str] = None) -> str: - target_db = database or self.admin_database - return self._compose_url( - 
user=self.admin_user, - password=self.admin_password, - database=target_db, - schema=None, - ) - - def application_dsn(self) -> str: - """Return a SQLAlchemy URL for connecting as the application role.""" - - return self._compose_url( - user=self.user, - password=self.password, - database=self.database, - schema=self.schema, - ) - - def _compose_url( - self, - *, - user: Optional[str], - password: Optional[str], - database: str, - schema: Optional[str], - ) -> str: - auth = "" - if user: - encoded_user = quote_plus(user) - if password: - encoded_pass = quote_plus(password) - auth = f"{encoded_user}:{encoded_pass}@" - else: - auth = f"{encoded_user}@" - - host = self.host - if ":" in host and not host.startswith("["): - host = f"[{host}]" - - host_port = host - if self.port: - host_port = f"{host}:{self.port}" - - url = f"{self.driver}://{auth}{host_port}/{database}" - - params = {} - if schema and schema.strip() and schema != "public": - params["options"] = f"-csearch_path={schema}" - - if params: - url = f"{url}?{urlencode(params, quote_via=quote_plus)}" - - return url - - -class DatabaseSetup: - """Encapsulates the full setup workflow.""" - - def __init__( - self, config: DatabaseConfig, *, dry_run: bool = False - ) -> None: - self.config = config - self.dry_run = dry_run - self._models_loaded = False - self._rollback_actions: list[tuple[str, Callable[[], None]]] = [] - - def _register_rollback( - self, label: str, action: Callable[[], None] - ) -> None: - if self.dry_run: - return - self._rollback_actions.append((label, action)) - - def execute_rollbacks(self) -> None: - if not self._rollback_actions: - logger.info("No rollback actions registered; nothing to undo.") - return - - logger.warning( - "Attempting rollback of %d action(s)", len(self._rollback_actions) - ) - for label, action in reversed(self._rollback_actions): - try: - logger.warning("Rollback step: %s", label) - action() - except Exception: - logger.exception("Rollback action '%s' failed", label) - self._rollback_actions.clear() - - def clear_rollbacks(self) -> None: - self._rollback_actions.clear() - - def _describe_connection(self, user: str, database: str) -> str: - return f"{user}@{self.config.host}:{self.config.port}/{database}" - - def validate_admin_connection(self) -> None: - descriptor = self._describe_connection( - self.config.admin_user, self.config.admin_database - ) - logger.info("[CONNECT] Validating admin connection (%s)", descriptor) - try: - with self._admin_connection(self.config.admin_database) as conn: - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - except psycopg2.Error as exc: - raise RuntimeError( - "Unable to connect with admin credentials. " - "Check DATABASE_ADMIN_URL or DATABASE_SUPERUSER settings." - f" Target: {descriptor}" - ) from exc - logger.info("[CONNECT] Admin connection verified (%s)", descriptor) - - def validate_application_connection(self) -> None: - descriptor = self._describe_connection( - self.config.user, self.config.database - ) - logger.info( - "[CONNECT] Validating application connection (%s)", descriptor) - try: - with self._application_connection() as conn: - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - except psycopg2.Error as exc: - raise RuntimeError( - "Unable to connect using application credentials. " - "Ensure the role exists and credentials are correct. 
" - f"Target: {descriptor}" - ) from exc - logger.info( - "[CONNECT] Application connection verified (%s)", descriptor) - - def ensure_database(self) -> None: - """Create the target database when it does not already exist.""" - - logger.info("Ensuring database '%s' exists", self.config.database) - try: - conn = self._admin_connection(self.config.admin_database) - except RuntimeError: - logger.error( - "Could not connect to admin database '%s' while creating '%s'.", - self.config.admin_database, - self.config.database, - ) - raise - try: - conn.autocommit = True - conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) - cursor = conn.cursor() - try: - try: - cursor.execute( - "SELECT 1 FROM pg_database WHERE datname = %s", - (self.config.database,), - ) - except psycopg2.Error as exc: - message = ( - "Unable to inspect existing databases while ensuring '%s'." - " Verify admin permissions." - ) % self.config.database - logger.error(message) - raise RuntimeError(message) from exc - - exists = cursor.fetchone() is not None - if exists: - logger.info( - "Database '%s' already present", self.config.database - ) - return - - if self.dry_run: - logger.info( - "Dry run: would create database '%s'. Run without --dry-run to proceed.", - self.config.database, - ) - return - - try: - cursor.execute( - sql.SQL("CREATE DATABASE {} ENCODING 'UTF8'").format( - sql.Identifier(self.config.database) - ) - ) - except psycopg2.Error as exc: - message = ( - "Failed to create database '%s'. Rerun with --dry-run for diagnostics" - ) % self.config.database - logger.error(message) - raise RuntimeError(message) from exc - else: - rollback_label = f"drop database {self.config.database}" - self._register_rollback( - rollback_label, - lambda db=self.config.database: self._drop_database( - db), - ) - logger.info("Created database '%s'", self.config.database) - finally: - cursor.close() - finally: - conn.close() - - def ensure_role(self) -> None: - """Create the application role and assign privileges when missing.""" - - logger.info("Ensuring role '%s' exists", self.config.user) - try: - admin_conn = self._admin_connection(self.config.admin_database) - except RuntimeError: - logger.error( - "Unable to connect with admin credentials while ensuring role '%s'", - self.config.user, - ) - raise - - with admin_conn as conn: - conn.autocommit = True - with conn.cursor() as cursor: - try: - cursor.execute( - "SELECT 1 FROM pg_roles WHERE rolname = %s", - (self.config.user,), - ) - except psycopg2.Error as exc: - message = ( - "Unable to inspect existing roles while ensuring role '%s'." - " Verify admin permissions." - ) % self.config.user - logger.error(message) - raise RuntimeError(message) from exc - role_exists = cursor.fetchone() is not None - if not role_exists: - logger.info("Creating role '%s'", self.config.user) - if self.dry_run: - logger.info( - "Dry run: would create role '%s'. Run without --dry-run to apply.", - self.config.user, - ) - return - try: - if self.config.password: - cursor.execute( - sql.SQL( - "CREATE ROLE {} WITH LOGIN PASSWORD %s" - ).format(sql.Identifier(self.config.user)), - (self.config.password,), - ) - else: - cursor.execute( - sql.SQL("CREATE ROLE {} WITH LOGIN").format( - sql.Identifier(self.config.user) - ) - ) - except psycopg2.Error as exc: - message = ( - "Failed to create role '%s'. Review admin privileges and rerun." 
- ) % self.config.user - logger.error(message) - raise RuntimeError(message) from exc - else: - rollback_label = f"drop role {self.config.user}" - self._register_rollback( - rollback_label, - lambda role=self.config.user: self._drop_role( - role), - ) - else: - logger.info("Role '%s' already present", self.config.user) - - try: - role_conn = self._admin_connection(self.config.database) - except RuntimeError: - logger.error( - "Unable to connect to application database '%s' while granting privileges to role '%s'", - self.config.database, - self.config.user, - ) - raise - - if self.dry_run: - logger.info( - "Dry run: would grant privileges on schema/database to role '%s'.", - self.config.user, - ) - return - - with role_conn as conn: - conn.autocommit = True - with conn.cursor() as cursor: - schema_name = self.config.schema or "public" - schema_identifier = sql.Identifier(schema_name) - role_identifier = sql.Identifier(self.config.user) - - try: - cursor.execute( - sql.SQL("GRANT CONNECT ON DATABASE {} TO {}").format( - sql.Identifier(self.config.database), - role_identifier, - ) - ) - cursor.execute( - sql.SQL("GRANT USAGE ON SCHEMA {} TO {}").format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL("GRANT CREATE ON SCHEMA {} TO {}").format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} GRANT USAGE, SELECT ON SEQUENCES TO {}" - ).format( - schema_identifier, - role_identifier, - ) - ) - except psycopg2.Error as exc: - message = ( - "Failed to grant privileges to role '%s' in schema '%s'." - " Rerun with --dry-run for more context." 
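# the GRANT statements above cover existing tables and sequences, while
# the ALTER DEFAULT PRIVILEGES pair extends the same access to objects
# the admin role creates later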
- ) % (self.config.user, schema_name) - logger.error(message) - raise RuntimeError(message) from exc - logger.info( - "Granted privileges on schema '%s' to role '%s'", - schema_name, - self.config.user, - ) - rollback_label = f"revoke privileges for {self.config.user}" - self._register_rollback( - rollback_label, - lambda schema=schema_name: self._revoke_role_privileges( - schema_name=schema - ), - ) - - def ensure_schema(self) -> None: - """Create the configured schema when it does not exist.""" - - schema_name = self.config.schema - if not schema_name or schema_name == "public": - logger.info("Using default schema 'public'; nothing to ensure") - return - - logger.info("Ensuring schema '%s' exists", schema_name) - with self._admin_connection(self.config.database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - sql.SQL( - "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s" - ), - (schema_name,), - ) - exists = cursor.fetchone() is not None - if not exists: - if self.dry_run: - logger.info( - "Dry run: would create schema '%s'", - schema_name, - ) - else: - cursor.execute( - sql.SQL("CREATE SCHEMA {}").format( - sql.Identifier(schema_name) - ) - ) - logger.info("Created schema '%s'", schema_name) - try: - if self.dry_run: - logger.info( - "Dry run: would set schema '%s' owner to '%s'", - schema_name, - self.config.user, - ) - else: - cursor.execute( - sql.SQL("ALTER SCHEMA {} OWNER TO {}").format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - except errors.UndefinedObject: - logger.warning( - "Role '%s' not found when assigning ownership to schema '%s'." - " Run --ensure-role after creating the schema.", - self.config.user, - schema_name, - ) - - def application_role_exists(self) -> bool: - try: - with self._admin_connection(self.config.admin_database) as conn: - with conn.cursor() as cursor: - try: - cursor.execute( - "SELECT 1 FROM pg_roles WHERE rolname = %s", - (self.config.user,), - ) - except psycopg2.Error as exc: - message = ( - "Unable to inspect existing roles while checking for role '%s'." - " Verify admin permissions." - ) % self.config.user - logger.error(message) - raise RuntimeError(message) from exc - return cursor.fetchone() is not None - except RuntimeError: - raise - - def _connect(self, dsn: str, descriptor: str) -> PGConnection: - try: - return psycopg2.connect(dsn) - except psycopg2.Error as exc: - raise RuntimeError( - f"Unable to establish connection. 
Target: {descriptor}" - ) from exc - - def _admin_connection(self, database: Optional[str] = None) -> PGConnection: - target_db = database or self.config.admin_database - dsn = self.config.admin_dsn(database) - descriptor = self._describe_connection( - self.config.admin_user, target_db - ) - return self._connect(dsn, descriptor) - - def _application_connection(self) -> PGConnection: - dsn = self.config.application_dsn() - descriptor = self._describe_connection( - self.config.user, self.config.database - ) - return self._connect(dsn, descriptor) - - def initialize_schema(self) -> None: - """Create database objects from SQLAlchemy metadata if missing.""" - - self._ensure_models_loaded() - logger.info("Ensuring SQLAlchemy metadata is reflected in database") - engine = create_engine(self.config.application_dsn(), future=True) - try: - inspector = inspect(engine) - existing_tables = set( - inspector.get_table_names(schema=self.config.schema) - ) - metadata_tables = set(Base.metadata.tables.keys()) - missing_tables = sorted(metadata_tables - existing_tables) - - if missing_tables: - logger.info("Pending tables: %s", ", ".join(missing_tables)) - else: - logger.info("All tables already exist") - - if self.dry_run: - if missing_tables: - logger.info("Dry run: skipping creation of pending tables") - return - - Base.metadata.create_all(bind=engine, checkfirst=True) - finally: - engine.dispose() - - logger.info("Schema initialization complete") - - def _ensure_models_loaded(self) -> None: - if self._models_loaded: - return - - package = importlib.import_module("models") - for module_info in pkgutil.iter_modules(package.__path__): - importlib.import_module(f"{package.__name__}.{module_info.name}") - self._models_loaded = True - - def run_migrations( - self, migrations_dir: Optional[Path | str] = None - ) -> None: - """Execute pending SQL migrations in chronological order.""" - - directory = ( - Path(migrations_dir) - if migrations_dir is not None - else DEFAULT_MIGRATIONS_DIR - ) - directory = directory.resolve() - - if not directory.exists(): - logger.warning("Migrations directory '%s' not found", directory) - return - - migration_files = sorted(directory.glob("*.sql")) - if not migration_files: - logger.info("No migration scripts found in '%s'", directory) - return - - baseline_name = "000_base.sql" - baseline_path = directory / baseline_name - - schema_name = self.config.schema or "public" - - with self._application_connection() as conn: - conn.autocommit = True - with conn.cursor() as cursor: - table_exists = self._migrations_table_exists( - cursor, schema_name - ) - if not table_exists: - if self.dry_run: - logger.info( - "Dry run: would create migration history table %s.%s", - schema_name, - MIGRATIONS_TABLE, - ) - applied: set[str] = set() - else: - self._create_migrations_table(cursor, schema_name) - logger.info( - "Created migration history table %s.%s", - schema_name, - MIGRATIONS_TABLE, - ) - applied = set() - else: - applied = self._fetch_applied_migrations( - cursor, schema_name - ) - - self._handle_baseline_migration( - cursor, schema_name, baseline_path, baseline_name, migration_files, applied - ) - - pending = [ - path for path in migration_files if path.name not in applied - ] - - if not pending: - logger.info("No pending migrations") - return - - logger.info( - "Pending migrations: %s", - ", ".join(path.name for path in pending), - ) - - if self.dry_run: - logger.info("Dry run: skipping migration execution") - return - - for path in pending: - self._apply_migration_file(cursor, 
schema_name, path) - - logger.info("Applied %d migrations", len(pending)) - - def _handle_baseline_migration( - self, - cursor: extensions.cursor, - schema_name: str, - baseline_path: Path, - baseline_name: str, - migration_files: list[Path], - applied: set[str], - ) -> None: - if baseline_path.exists() and baseline_name not in applied: - if self.dry_run: - logger.info( - "Dry run: baseline migration '%s' pending; would apply and mark legacy files", - baseline_name, - ) - else: - logger.info( - "[MIGRATE] Baseline migration '%s' pending; applying and marking older migrations", - baseline_name, - ) - try: - baseline_applied = self._apply_migration_file( - cursor, schema_name, baseline_path - ) - except Exception: - logger.error( - "Failed while applying baseline migration '%s'." - " Review the migration contents and rerun with --dry-run for diagnostics.", - baseline_name, - exc_info=True, - ) - raise - applied.add(baseline_applied) - self._mark_legacy_migrations_as_applied( - cursor, schema_name, migration_files, baseline_name, applied - ) - - def _mark_legacy_migrations_as_applied( - self, - cursor: extensions.cursor, - schema_name: str, - migration_files: list[Path], - baseline_name: str, - applied: set[str], - ) -> None: - legacy_files = [ - path - for path in migration_files - if path.name != baseline_name - ] - for legacy in legacy_files: - if legacy.name not in applied: - try: - cursor.execute( - sql.SQL( - "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" - ).format( - sql.Identifier( - schema_name, - MIGRATIONS_TABLE, - ) - ), - (legacy.name,), - ) - except Exception: - logger.error( - "Unable to record legacy migration '%s' after baseline application." - " Check schema_migrations table in schema '%s' for partial state.", - legacy.name, - schema_name, - exc_info=True, - ) - raise - applied.add(legacy.name) - logger.info( - "Marked legacy migration '%s' as applied via baseline", - legacy.name, - ) - - def _apply_migration_file( - self, - cursor, - schema_name: str, - path: Path, - ) -> str: - logger.info("Applying migration '%s'", path.name) - sql_text = path.read_text(encoding="utf-8") - try: - cursor.execute(sql_text) - cursor.execute( - sql.SQL( - "INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())" - ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)), - (path.name,), - ) - return path.name - except Exception: - logger.exception("Failed to apply migration '%s'", path.name) - raise - - def _migrations_table_exists(self, cursor, schema_name: str) -> bool: - cursor.execute( - """ - SELECT 1 - FROM information_schema.tables - WHERE table_schema = %s AND table_name = %s - """, - (schema_name, MIGRATIONS_TABLE), - ) - return cursor.fetchone() is not None - - def _create_migrations_table(self, cursor, schema_name: str) -> None: - cursor.execute( - sql.SQL( - "CREATE TABLE IF NOT EXISTS {} (" - "filename TEXT PRIMARY KEY," - "applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()" - ")" - ).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)) - ) - - def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]: - cursor.execute( - sql.SQL("SELECT filename FROM {} ORDER BY filename").format( - sql.Identifier(schema_name, MIGRATIONS_TABLE) - ) - ) - return {row[0] for row in cursor.fetchall()} - - def seed_baseline_data(self, *, dry_run: bool) -> None: - """Seed reference data such as currencies.""" - - from scripts import seed_data - - seed_args = argparse.Namespace( - currencies=True, - units=True, - theme=True, - defaults=False, - dry_run=dry_run, - verbose=0, - 
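# defaults=False is deliberate: currencies, units and theme are driven
# explicitly here, and the --defaults path would re-run theme seeding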
) - try: - seed_data.run_with_namespace(seed_args, config=self.config) - except Exception: - logger.error( - "[SEED] Failed during baseline data seeding. " - "Review seed_data.py and rerun with --dry-run for diagnostics.", - exc_info=True, - ) - raise - - if dry_run: - logger.info("[SEED] Dry run: skipped seed verification") - return - - expected_currencies = { - code for code, *_ in getattr(seed_data, "CURRENCY_SEEDS", ()) - } - expected_units = { - code - for code, *_ in getattr(seed_data, "MEASUREMENT_UNIT_SEEDS", ()) - } - self._verify_seeded_data( - expected_currency_codes=expected_currencies, - expected_unit_codes=expected_units, - ) - - def _verify_seeded_data( - self, - *, - expected_currency_codes: set[str], - expected_unit_codes: set[str], - ) -> None: - if not expected_currency_codes and not expected_unit_codes: - logger.info("No seed datasets configured for verification") - return - - with self._application_connection() as conn: - with conn.cursor() as cursor: - if expected_currency_codes: - cursor.execute( - "SELECT code, is_active FROM currency WHERE code = ANY(%s)", - (list(expected_currency_codes),), - ) - rows = cursor.fetchall() - found_codes = {row[0] for row in rows} - missing_codes = sorted( - expected_currency_codes - found_codes - ) - if missing_codes: - message = ( - "Missing expected currencies after seeding: %s. " - "Run scripts/seed_data.py --currencies to restore them." - ) % ", ".join(missing_codes) - logger.error(message) - raise RuntimeError(message) - - logger.info( - "[VERIFY] Verified %d seeded currencies present", - len(found_codes), - ) - - default_status = next( - (row[1] for row in rows if row[0] == "USD"), None - ) - if default_status is False: - message = ( - "Default currency 'USD' is inactive after seeding. " - "Reactivate it or rerun the seeding command." - ) - logger.error(message) - raise RuntimeError(message) - elif default_status is None: - message = ( - "Default currency 'USD' not found after seeding. " - "Ensure baseline migration 000_base.sql ran successfully." - ) - logger.error(message) - raise RuntimeError(message) - else: - logger.info( - "[VERIFY] Verified default currency 'USD' active") - - if expected_unit_codes: - try: - cursor.execute( - "SELECT code, is_active FROM measurement_unit WHERE code = ANY(%s)", - (list(expected_unit_codes),), - ) - except errors.UndefinedTable: - conn.rollback() - message = ( - "measurement_unit table not found during seed verification. " - "Ensure baseline migration 000_base.sql has been applied." - ) - logger.error(message) - raise RuntimeError(message) - else: - rows = cursor.fetchall() - found_units = {row[0] for row in rows} - missing_units = sorted( - expected_unit_codes - found_units - ) - if missing_units: - message = ( - "Missing expected measurement units after seeding: %s. " - "Run scripts/seed_data.py --units to restore them." - ) % ", ".join(missing_units) - logger.error(message) - raise RuntimeError(message) - - inactive_units = sorted( - row[0] for row in rows if not bool(row[1]) - ) - if inactive_units: - message = ( - "Measurement units inactive after seeding: %s. " - "Reactivate them or rerun unit seeding." 
- ) % ", ".join(inactive_units) - logger.error(message) - raise RuntimeError(message) - - logger.info( - "Verified %d measurement units present", - len(found_units), - ) - - logger.info("Seed verification complete") - - def _drop_database(self, database: str) -> None: - logger.warning("Rollback: dropping database '%s'", database) - with self._admin_connection(self.config.admin_database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s", - (database,), - ) - cursor.execute( - sql.SQL("DROP DATABASE IF EXISTS {}").format( - sql.Identifier(database) - ) - ) - - def _drop_role(self, role: str) -> None: - logger.warning("Rollback: dropping role '%s'", role) - with self._admin_connection(self.config.admin_database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - sql.SQL("DROP ROLE IF EXISTS {}").format( - sql.Identifier(role) - ) - ) - - def _revoke_role_privileges(self, *, schema_name: str) -> None: - logger.warning( - "Rollback: revoking privileges on schema '%s' for role '%s'", - schema_name, - self.config.user, - ) - with self._admin_connection(self.config.database) as conn: - conn.autocommit = True - with conn.cursor() as cursor: - cursor.execute( - sql.SQL( - "REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - cursor.execute( - sql.SQL( - "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - cursor.execute( - sql.SQL( - "ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" - ).format( - sql.Identifier(schema_name), - sql.Identifier(self.config.user), - ) - ) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Bootstrap CalMiner database") - parser.add_argument( - "--ensure-database", - action="store_true", - help="Create the application database when it does not already exist.", - ) - parser.add_argument( - "--ensure-role", - action="store_true", - help="Create the application role and grant necessary privileges.", - ) - parser.add_argument( - "--ensure-schema", - action="store_true", - help="Create the configured schema if it does not exist.", - ) - parser.add_argument( - "--initialize-schema", - action="store_true", - help="Create missing tables based on SQLAlchemy models.", - ) - parser.add_argument( - "--run-migrations", - action="store_true", - help="Execute pending SQL migrations.", - ) - parser.add_argument( - "--seed-data", - action="store_true", - help="Seed baseline reference data (currencies, etc.).", - ) - parser.add_argument( - "--migrations-dir", - default=None, - help="Override the default migrations directory.", - ) - parser.add_argument("--db-driver", help="Override DATABASE_DRIVER") - parser.add_argument("--db-host", help="Override DATABASE_HOST") - parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT") - parser.add_argument("--db-name", help="Override DATABASE_NAME") - parser.add_argument("--db-user", help="Override DATABASE_USER") - parser.add_argument("--db-password", help="Override DATABASE_PASSWORD") - 
parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA") - parser.add_argument( - "--admin-url", - help="Override DATABASE_ADMIN_URL for administrative operations", - ) - parser.add_argument( - "--admin-user", help="Override DATABASE_SUPERUSER for admin ops" - ) - parser.add_argument( - "--admin-password", - help="Override DATABASE_SUPERUSER_PASSWORD for admin ops", - ) - parser.add_argument( - "--admin-db", - help="Override DATABASE_SUPERUSER_DB for admin ops", - ) - parser.add_argument( - "--dry-run", - action="store_true", - help="Log actions without applying changes.", - ) - parser.add_argument( - "--verbose", - "-v", - action="count", - default=0, - help="Increase logging verbosity", - ) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - level = logging.WARNING - (10 * min(args.verbose, 2)) - logging.basicConfig( - level=max(level, logging.INFO), format="%(levelname)s %(message)s" - ) - - override_args: dict[str, Optional[str]] = { - "DATABASE_DRIVER": args.db_driver, - "DATABASE_HOST": args.db_host, - "DATABASE_NAME": args.db_name, - "DATABASE_USER": args.db_user, - "DATABASE_PASSWORD": args.db_password, - "DATABASE_SCHEMA": args.db_schema, - "DATABASE_ADMIN_URL": args.admin_url, - "DATABASE_SUPERUSER": args.admin_user, - "DATABASE_SUPERUSER_PASSWORD": args.admin_password, - "DATABASE_SUPERUSER_DB": args.admin_db, - } - if args.db_port is not None: - override_args["DATABASE_PORT"] = str(args.db_port) - - config = DatabaseConfig.from_env(overrides=override_args) - setup = DatabaseSetup(config, dry_run=args.dry_run) - - admin_tasks_requested = ( - args.ensure_database or args.ensure_role or args.ensure_schema - ) - if admin_tasks_requested: - setup.validate_admin_connection() - - app_validated = False - - def ensure_application_connection_for(operation: str) -> bool: - nonlocal app_validated - if app_validated: - return True - if setup.dry_run and not setup.application_role_exists(): - logger.info( - "Dry run: skipping %s because application role '%s' does not exist yet.", - operation, - setup.config.user, - ) - return False - setup.validate_application_connection() - app_validated = True - return True - - should_run_migrations = args.run_migrations - auto_run_migrations_reason: Optional[str] = None - if args.seed_data and not should_run_migrations: - should_run_migrations = True - auto_run_migrations_reason = "Seed data requested without explicit --run-migrations; applying migrations first." 
- - try: - if args.ensure_database: - setup.ensure_database() - if args.ensure_role: - setup.ensure_role() - if args.ensure_schema: - setup.ensure_schema() - - if args.initialize_schema: - if ensure_application_connection_for( - "SQLAlchemy schema initialization" - ): - setup.initialize_schema() - if should_run_migrations: - if ensure_application_connection_for("migration execution"): - if auto_run_migrations_reason: - logger.info(auto_run_migrations_reason) - migrations_path = ( - Path(args.migrations_dir) if args.migrations_dir else None - ) - setup.run_migrations(migrations_path) - if args.seed_data: - if ensure_application_connection_for("baseline data seeding"): - setup.seed_baseline_data(dry_run=args.dry_run) - except Exception: - if not setup.dry_run: - setup.execute_rollbacks() - raise - finally: - if not setup.dry_run: - setup.clear_rollbacks() - - -if __name__ == "__main__": - main() diff --git a/scripts/verify_db.py b/scripts/verify_db.py new file mode 100644 index 0000000..5662710 --- /dev/null +++ b/scripts/verify_db.py @@ -0,0 +1,86 @@ +"""Verify DB initialization results: enums, roles, admin user, pricing_settings.""" +from __future__ import annotations +import logging +from sqlalchemy import create_engine, text +from config.database import DATABASE_URL + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +ENUMS = [ + 'miningoperationtype', + 'scenariostatus', + 'financialcategory', + 'costbucket', + 'distributiontype', + 'stochasticvariable', + 'resourcetype', +] + +SQL_CHECK_ENUM = "SELECT typname FROM pg_type WHERE typname = ANY(:names)" +SQL_ROLES = "SELECT id, name, display_name FROM roles ORDER BY id" +SQL_ADMIN = "SELECT id, email, username, is_active, is_superuser FROM users WHERE id = 1" +SQL_USER_ROLES = "SELECT user_id, role_id, granted_by FROM user_roles WHERE user_id = 1" +SQL_PRICING = "SELECT id, slug, name, default_currency FROM pricing_settings WHERE slug = 'default'" + + +def run(): + engine = create_engine(DATABASE_URL, future=True) + with engine.connect() as conn: + print('Using DATABASE_URL:', DATABASE_URL) + # enums + res = conn.execute(text(SQL_CHECK_ENUM), dict(names=ENUMS)).fetchall() + found = [r[0] for r in res] + print('\nEnums found:') + for name in ENUMS: + print(f' {name}:', 'YES' if name in found else 'NO') + + # roles + try: + roles = conn.execute(text(SQL_ROLES)).fetchall() + print('\nRoles:') + if roles: + for r in roles: + print(f' id={r.id} name={r.name} display_name={r.display_name}') + else: + print(' (no roles found)') + except Exception as e: + print('\nRoles query failed:', e) + + # admin user + try: + admin = conn.execute(text(SQL_ADMIN)).fetchone() + print('\nAdmin user:') + if admin: + print(f' id={admin.id} email={admin.email} username={admin.username} is_active={admin.is_active} is_superuser={admin.is_superuser}') + else: + print(' (admin user not found)') + except Exception as e: + print('\nAdmin query failed:', e) + + # user_roles + try: + ur = conn.execute(text(SQL_USER_ROLES)).fetchall() + print('\nUser roles for user_id=1:') + if ur: + for row in ur: + print(f' user_id={row.user_id} role_id={row.role_id} granted_by={row.granted_by}') + else: + print(' (no user_roles rows for user_id=1)') + except Exception as e: + print('\nUser_roles query failed:', e) + + # pricing settings + try: + p = conn.execute(text(SQL_PRICING)).fetchone() + print('\nPricing settings (slug=default):') + if p: + print(f' id={p.id} slug={p.slug} name={p.name} default_currency={p.default_currency}') + else: + print(' 
(default pricing settings not found)') + except Exception as e: + print('\nPricing query failed:', e) + + +if __name__ == '__main__': + run() diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..c452a0a --- /dev/null +++ b/services/__init__.py @@ -0,0 +1,12 @@ +"""Service layer utilities.""" + +from .pricing import calculate_pricing, PricingInput, PricingMetadata, PricingResult +from .calculations import calculate_profitability + +__all__ = [ + "calculate_pricing", + "PricingInput", + "PricingMetadata", + "PricingResult", + "calculate_profitability", +] diff --git a/services/authorization.py b/services/authorization.py new file mode 100644 index 0000000..3a19a39 --- /dev/null +++ b/services/authorization.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import Iterable + +from models import Project, Role, Scenario, User +from services.exceptions import AuthorizationError, EntityNotFoundError +from services.repositories import ProjectRepository, ScenarioRepository +from services.unit_of_work import UnitOfWork + +READ_ROLES: frozenset[str] = frozenset( + {"viewer", "analyst", "project_manager", "admin"} +) +MANAGE_ROLES: frozenset[str] = frozenset({"project_manager", "admin"}) + + +def _user_role_names(user: User) -> set[str]: + roles: Iterable[Role] = getattr(user, "roles", []) or [] + return {role.name for role in roles} + + +def _require_project_repo(uow: UnitOfWork) -> ProjectRepository: + if not uow.projects: + raise RuntimeError("Project repository not initialised") + return uow.projects + + +def _require_scenario_repo(uow: UnitOfWork) -> ScenarioRepository: + if not uow.scenarios: + raise RuntimeError("Scenario repository not initialised") + return uow.scenarios + + +def _assert_user_can_access(user: User, *, require_manage: bool) -> None: + if not user.is_active: + raise AuthorizationError("User account is disabled.") + if user.is_superuser: + return + + allowed = MANAGE_ROLES if require_manage else READ_ROLES + if not _user_role_names(user) & allowed: + raise AuthorizationError( + "Insufficient role permissions for this action.") + + +def ensure_project_access( + uow: UnitOfWork, + *, + project_id: int, + user: User, + require_manage: bool = False, +) -> Project: + """Resolve a project and ensure the user holds the required permissions.""" + + repo = _require_project_repo(uow) + project = repo.get(project_id) + _assert_user_can_access(user, require_manage=require_manage) + return project + + +def ensure_scenario_access( + uow: UnitOfWork, + *, + scenario_id: int, + user: User, + require_manage: bool = False, + with_children: bool = False, +) -> Scenario: + """Resolve a scenario and ensure the user holds the required permissions.""" + + repo = _require_scenario_repo(uow) + scenario = repo.get(scenario_id, with_children=with_children) + _assert_user_can_access(user, require_manage=require_manage) + return scenario + + +def ensure_scenario_in_project( + uow: UnitOfWork, + *, + project_id: int, + scenario_id: int, + user: User, + require_manage: bool = False, + with_children: bool = False, +) -> Scenario: + """Resolve a scenario ensuring it belongs to the project and the user may access it.""" + + project = ensure_project_access( + uow, + project_id=project_id, + user=user, + require_manage=require_manage, + ) + scenario = ensure_scenario_access( + uow, + scenario_id=scenario_id, + user=user, + require_manage=require_manage, + with_children=with_children, + ) + if scenario.project_id != project.id: + raise 
EntityNotFoundError( + f"Scenario {scenario_id} does not belong to project {project_id}." + ) + return scenario diff --git a/services/bootstrap.py b/services/bootstrap.py new file mode 100644 index 0000000..8c918d6 --- /dev/null +++ b/services/bootstrap.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Callable + +from config.settings import AdminBootstrapSettings +from models import User +from services.pricing import PricingMetadata +from services.repositories import ( + PricingSettingsSeedResult, + ensure_default_roles, +) +from services.unit_of_work import UnitOfWork + + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class RoleBootstrapResult: + created: int + ensured: int + + +@dataclass(slots=True) +class AdminBootstrapResult: + created_user: bool + updated_user: bool + password_rotated: bool + roles_granted: int + + +@dataclass(slots=True) +class PricingBootstrapResult: + seed: PricingSettingsSeedResult + projects_assigned: int + + +def bootstrap_admin( + *, + settings: AdminBootstrapSettings, + unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork, +) -> tuple[RoleBootstrapResult, AdminBootstrapResult]: + """Ensure default roles and administrator account exist.""" + + with unit_of_work_factory() as uow: + assert uow.roles is not None and uow.users is not None + + role_result = _bootstrap_roles(uow) + admin_result = _bootstrap_admin_user(uow, settings) + + logger.info( + "Admin bootstrap result: created_user=%s updated_user=%s password_rotated=%s roles_granted=%s", + admin_result.created_user, + admin_result.updated_user, + admin_result.password_rotated, + admin_result.roles_granted, + ) + return role_result, admin_result + + +def _bootstrap_roles(uow: UnitOfWork) -> RoleBootstrapResult: + assert uow.roles is not None + before = {role.name for role in uow.roles.list()} + ensure_default_roles(uow.roles) + after = {role.name for role in uow.roles.list()} + created = len(after - before) + return RoleBootstrapResult(created=created, ensured=len(after)) + + +def _bootstrap_admin_user( + uow: UnitOfWork, + settings: AdminBootstrapSettings, +) -> AdminBootstrapResult: + assert uow.users is not None and uow.roles is not None + + created_user = False + updated_user = False + password_rotated = False + roles_granted = 0 + + user = uow.users.get_by_email(settings.email, with_roles=True) + if user is None: + user = User( + email=settings.email, + username=settings.username, + password_hash=User.hash_password(settings.password), + is_active=True, + is_superuser=True, + ) + uow.users.create(user) + created_user = True + else: + if user.username != settings.username: + user.username = settings.username + updated_user = True + if not user.is_active: + user.is_active = True + updated_user = True + if not user.is_superuser: + user.is_superuser = True + updated_user = True + if settings.force_reset: + user.password_hash = User.hash_password(settings.password) + password_rotated = True + updated_user = True + uow.users.session.flush() + + user = uow.users.get(user.id, with_roles=True) + assert user is not None + + existing_roles = {role.name for role in user.roles} + for role_name in settings.roles: + role = uow.roles.get_by_name(role_name) + if role is None: + logger.warning( + "Bootstrap admin role '%s' is not defined; skipping assignment", + role_name, + ) + continue + if role.name in existing_roles: + continue + uow.users.assign_role( + user_id=user.id, + role_id=role.id, + granted_by=user.id, + 
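+            # granted_by=user.id: at bootstrap time no prior actor exists,
+            # so the admin account is recorded as the grantor of its own roles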
) + roles_granted += 1 + existing_roles.add(role.name) + + uow.users.session.flush() + + return AdminBootstrapResult( + created_user=created_user, + updated_user=updated_user, + password_rotated=password_rotated, + roles_granted=roles_granted, + ) + + +def bootstrap_pricing_settings( + *, + metadata: PricingMetadata, + unit_of_work_factory: Callable[[], UnitOfWork] = UnitOfWork, + default_slug: str = "default", +) -> PricingBootstrapResult: + """Ensure baseline pricing settings exist and projects reference them.""" + + with unit_of_work_factory() as uow: + seed_result = uow.ensure_default_pricing_settings( + metadata=metadata, + slug=default_slug, + ) + + assigned = 0 + if uow.projects: + default_settings = seed_result.settings + projects = uow.projects.list(with_pricing=True) + for project in projects: + if project.pricing_settings is None: + uow.set_project_pricing_settings(project, default_settings) + assigned += 1 + + # Capture logging-safe primitives while the UnitOfWork (and session) + # are still active to avoid DetachedInstanceError when accessing ORM + # instances outside the session scope. + seed_slug = seed_result.settings.slug if seed_result and seed_result.settings else None + seed_created = getattr(seed_result, "created", None) + seed_updated_fields = getattr(seed_result, "updated_fields", None) + seed_impurity_upserts = getattr(seed_result, "impurity_upserts", None) + + logger.info( + "Pricing bootstrap result: slug=%s created=%s updated_fields=%s impurity_upserts=%s projects_assigned=%s", + seed_slug, + seed_created, + seed_updated_fields, + seed_impurity_upserts, + assigned, + ) + + return PricingBootstrapResult(seed=seed_result, projects_assigned=assigned) diff --git a/services/calculations.py b/services/calculations.py new file mode 100644 index 0000000..ef82330 --- /dev/null +++ b/services/calculations.py @@ -0,0 +1,535 @@ +"""Service functions for financial calculations.""" + +from __future__ import annotations + +from collections import defaultdict +from statistics import fmean + +from services.currency import CurrencyValidationError, normalise_currency +from services.exceptions import ( + CapexValidationError, + OpexValidationError, + ProfitabilityValidationError, +) +from services.financial import ( + CashFlow, + ConvergenceError, + PaybackNotReachedError, + internal_rate_of_return, + net_present_value, + payback_period, +) +from services.pricing import PricingInput, PricingMetadata, PricingResult, calculate_pricing +from schemas.calculations import ( + CapexCalculationRequest, + CapexCalculationResult, + CapexCategoryBreakdown, + CapexComponentInput, + CapexTotals, + CapexTimelineEntry, + CashFlowEntry, + OpexCalculationRequest, + OpexCalculationResult, + OpexCategoryBreakdown, + OpexComponentInput, + OpexMetrics, + OpexParameters, + OpexTotals, + OpexTimelineEntry, + ProfitabilityCalculationRequest, + ProfitabilityCalculationResult, + ProfitabilityCosts, + ProfitabilityMetrics, +) + + +_FREQUENCY_MULTIPLIER = { + "daily": 365, + "weekly": 52, + "monthly": 12, + "quarterly": 4, + "annually": 1, +} + + +def _build_pricing_input( + request: ProfitabilityCalculationRequest, +) -> PricingInput: + """Construct a pricing input instance including impurity overrides.""" + + impurity_values: dict[str, float] = {} + impurity_thresholds: dict[str, float] = {} + impurity_penalties: dict[str, float] = {} + + for impurity in request.impurities: + code = impurity.name.strip() + if not code: + continue + code = code.upper() + if impurity.value is not None: + impurity_values[code] 
= float(impurity.value) + if impurity.threshold is not None: + impurity_thresholds[code] = float(impurity.threshold) + if impurity.penalty is not None: + impurity_penalties[code] = float(impurity.penalty) + + pricing_input = PricingInput( + metal=request.metal, + ore_tonnage=request.ore_tonnage, + head_grade_pct=request.head_grade_pct, + recovery_pct=request.recovery_pct, + payable_pct=request.payable_pct, + reference_price=request.reference_price, + treatment_charge=request.treatment_charge, + smelting_charge=request.smelting_charge, + moisture_pct=request.moisture_pct, + moisture_threshold_pct=request.moisture_threshold_pct, + moisture_penalty_per_pct=request.moisture_penalty_per_pct, + impurity_ppm=impurity_values, + impurity_thresholds=impurity_thresholds, + impurity_penalty_per_ppm=impurity_penalties, + premiums=request.premiums, + fx_rate=request.fx_rate, + currency_code=request.currency_code, + ) + + return pricing_input + + +def _generate_cash_flows( + *, + periods: int, + net_per_period: float, + capex: float, +) -> tuple[list[CashFlow], list[CashFlowEntry]]: + """Create cash flow structures for financial metric calculations.""" + + cash_flow_models: list[CashFlow] = [ + CashFlow(amount=-capex, period_index=0) + ] + cash_flow_entries: list[CashFlowEntry] = [ + CashFlowEntry( + period=0, + revenue=0.0, + opex=0.0, + sustaining_capex=0.0, + net=-capex, + ) + ] + + for period in range(1, periods + 1): + cash_flow_models.append( + CashFlow(amount=net_per_period, period_index=period)) + cash_flow_entries.append( + CashFlowEntry( + period=period, + revenue=0.0, + opex=0.0, + sustaining_capex=0.0, + net=net_per_period, + ) + ) + + return cash_flow_models, cash_flow_entries + + +def calculate_profitability( + request: ProfitabilityCalculationRequest, + *, + metadata: PricingMetadata, +) -> ProfitabilityCalculationResult: + """Calculate profitability metrics using pricing inputs and cost data.""" + + if request.periods <= 0: + raise ProfitabilityValidationError( + "Evaluation periods must be at least 1.", ["periods"] + ) + + pricing_input = _build_pricing_input(request) + try: + pricing_result: PricingResult = calculate_pricing( + pricing_input, metadata=metadata + ) + except CurrencyValidationError as exc: + raise ProfitabilityValidationError( + str(exc), ["currency_code"]) from exc + + periods = request.periods + revenue_total = float(pricing_result.net_revenue) + revenue_per_period = revenue_total / periods + + processing_total = float(request.opex) * periods + sustaining_total = float(request.sustaining_capex) * periods + capex = float(request.capex) + + net_per_period = ( + revenue_per_period + - float(request.opex) + - float(request.sustaining_capex) + ) + + cash_flow_models, cash_flow_entries = _generate_cash_flows( + periods=periods, + net_per_period=net_per_period, + capex=capex, + ) + + # Update per-period entries to include explicit costs for presentation + for entry in cash_flow_entries[1:]: + entry.revenue = revenue_per_period + entry.opex = float(request.opex) + entry.sustaining_capex = float(request.sustaining_capex) + entry.net = net_per_period + + discount_rate = (request.discount_rate or 0.0) / 100.0 + + npv_value = net_present_value(discount_rate, cash_flow_models) + + try: + irr_value = internal_rate_of_return(cash_flow_models) * 100.0 + except (ValueError, ZeroDivisionError, ConvergenceError): + irr_value = None + + try: + payback_value = payback_period(cash_flow_models) + except (ValueError, PaybackNotReachedError): + payback_value = None + + total_costs = 
processing_total + sustaining_total + capex + total_net = revenue_total - total_costs + + if revenue_total == 0: + margin_value = None + else: + margin_value = (total_net / revenue_total) * 100.0 + + currency = request.currency_code or pricing_result.currency + try: + currency = normalise_currency(currency) + except CurrencyValidationError as exc: + raise ProfitabilityValidationError( + str(exc), ["currency_code"]) from exc + + costs = ProfitabilityCosts( + opex_total=processing_total, + sustaining_capex_total=sustaining_total, + capex=capex, + ) + + metrics = ProfitabilityMetrics( + npv=npv_value, + irr=irr_value, + payback_period=payback_value, + margin=margin_value, + ) + + return ProfitabilityCalculationResult( + pricing=pricing_result, + costs=costs, + metrics=metrics, + cash_flows=cash_flow_entries, + currency=currency, + ) + + +def calculate_initial_capex( + request: CapexCalculationRequest, +) -> CapexCalculationResult: + """Aggregate capex components into totals and timelines.""" + + if not request.components: + raise CapexValidationError( + "At least one capex component is required for calculation.", + ["components"], + ) + + parameters = request.parameters + + base_currency = parameters.currency_code + if base_currency: + try: + base_currency = normalise_currency(base_currency) + except CurrencyValidationError as exc: + raise CapexValidationError( + str(exc), ["parameters.currency_code"] + ) from exc + + overall = 0.0 + category_totals: dict[str, float] = defaultdict(float) + timeline_totals: dict[int, float] = defaultdict(float) + normalised_components: list[CapexComponentInput] = [] + + for index, component in enumerate(request.components): + amount = float(component.amount) + overall += amount + + category_totals[component.category] += amount + + spend_year = component.spend_year or 0 + timeline_totals[spend_year] += amount + + component_currency = component.currency + if component_currency: + try: + component_currency = normalise_currency(component_currency) + except CurrencyValidationError as exc: + raise CapexValidationError( + str(exc), [f"components[{index}].currency"] + ) from exc + + if base_currency is None and component_currency: + base_currency = component_currency + elif ( + base_currency is not None + and component_currency is not None + and component_currency != base_currency + ): + raise CapexValidationError( + ( + "Component currency does not match the global currency. " + f"Expected {base_currency}, got {component_currency}." 
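+                    # mixed currencies are rejected rather than converted;
+                    # the offending component index is surfaced in field_errors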
+ ), + [f"components[{index}].currency"], + ) + + normalised_components.append( + CapexComponentInput( + id=component.id, + name=component.name, + category=component.category, + amount=amount, + currency=component_currency, + spend_year=component.spend_year, + notes=component.notes, + ) + ) + + contingency_pct = float(parameters.contingency_pct or 0.0) + contingency_amount = overall * (contingency_pct / 100.0) + grand_total = overall + contingency_amount + + category_breakdowns: list[CapexCategoryBreakdown] = [] + if category_totals: + for category, total in sorted(category_totals.items()): + share = (total / overall * 100.0) if overall else None + category_breakdowns.append( + CapexCategoryBreakdown( + category=category, + amount=total, + share=share, + ) + ) + + cumulative = 0.0 + timeline_entries: list[CapexTimelineEntry] = [] + for year, spend in sorted(timeline_totals.items()): + cumulative += spend + timeline_entries.append( + CapexTimelineEntry(year=year, spend=spend, cumulative=cumulative) + ) + + try: + currency = normalise_currency(base_currency) if base_currency else None + except CurrencyValidationError as exc: + raise CapexValidationError( + str(exc), ["parameters.currency_code"] + ) from exc + + totals = CapexTotals( + overall=overall, + contingency_pct=contingency_pct, + contingency_amount=contingency_amount, + with_contingency=grand_total, + by_category=category_breakdowns, + ) + + return CapexCalculationResult( + totals=totals, + timeline=timeline_entries, + components=normalised_components, + parameters=parameters, + options=request.options, + currency=currency, + ) + + +def calculate_opex( + request: OpexCalculationRequest, +) -> OpexCalculationResult: + """Aggregate opex components into annual totals and timeline.""" + + if not request.components: + raise OpexValidationError( + "At least one opex component is required for calculation.", + ["components"], + ) + + parameters: OpexParameters = request.parameters + base_currency = parameters.currency_code + if base_currency: + try: + base_currency = normalise_currency(base_currency) + except CurrencyValidationError as exc: + raise OpexValidationError( + str(exc), ["parameters.currency_code"] + ) from exc + + evaluation_horizon = parameters.evaluation_horizon_years or 1 + if evaluation_horizon <= 0: + raise OpexValidationError( + "Evaluation horizon must be at least 1 year.", + ["parameters.evaluation_horizon_years"], + ) + + escalation_pct = float(parameters.escalation_pct or 0.0) + apply_escalation = bool(parameters.apply_escalation) + + category_totals: dict[str, float] = defaultdict(float) + timeline_totals: dict[int, float] = defaultdict(float) + timeline_escalated: dict[int, float] = defaultdict(float) + normalised_components: list[OpexComponentInput] = [] + + max_period_end = evaluation_horizon + + for index, component in enumerate(request.components): + frequency = component.frequency.lower() + multiplier = _FREQUENCY_MULTIPLIER.get(frequency) + if multiplier is None: + raise OpexValidationError( + f"Unsupported frequency '{component.frequency}'.", + [f"components[{index}].frequency"], + ) + + unit_cost = float(component.unit_cost) + quantity = float(component.quantity) + annual_cost = unit_cost * quantity * multiplier + + period_start = component.period_start or 1 + period_end = component.period_end or evaluation_horizon + if period_end < period_start: + raise OpexValidationError( + ( + "Component period_end must be greater than or equal to " + "period_start." 
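+                    # periods are 1-based inclusive years: period_start=2 and
+                    # period_end=4 bill the component in years 2, 3 and 4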
+ ), + [f"components[{index}].period_end"], + ) + + max_period_end = max(max_period_end, period_end) + + component_currency = component.currency + if component_currency: + try: + component_currency = normalise_currency(component_currency) + except CurrencyValidationError as exc: + raise OpexValidationError( + str(exc), [f"components[{index}].currency"] + ) from exc + + if base_currency is None and component_currency: + base_currency = component_currency + elif ( + base_currency is not None + and component_currency is not None + and component_currency != base_currency + ): + raise OpexValidationError( + ( + "Component currency does not match the global currency. " + f"Expected {base_currency}, got {component_currency}." + ), + [f"components[{index}].currency"], + ) + + category_totals[component.category] += annual_cost + + for period in range(period_start, period_end + 1): + timeline_totals[period] += annual_cost + + normalised_components.append( + OpexComponentInput( + id=component.id, + name=component.name, + category=component.category, + unit_cost=unit_cost, + quantity=quantity, + frequency=frequency, + currency=component_currency, + period_start=period_start, + period_end=period_end, + notes=component.notes, + ) + ) + + evaluation_horizon = max(evaluation_horizon, max_period_end) + + try: + currency = normalise_currency(base_currency) if base_currency else None + except CurrencyValidationError as exc: + raise OpexValidationError( + str(exc), ["parameters.currency_code"] + ) from exc + + timeline_entries: list[OpexTimelineEntry] = [] + escalated_values: list[float] = [] + overall_annual = timeline_totals.get(1, 0.0) + escalated_total = 0.0 + + for period in range(1, evaluation_horizon + 1): + base_cost = timeline_totals.get(period, 0.0) + if apply_escalation: + factor = (1 + escalation_pct / 100.0) ** (period - 1) + else: + factor = 1.0 + escalated_cost = base_cost * factor + timeline_escalated[period] = escalated_cost + escalated_total += escalated_cost + timeline_entries.append( + OpexTimelineEntry( + period=period, + base_cost=base_cost, + escalated_cost=escalated_cost if apply_escalation else None, + ) + ) + escalated_values.append(escalated_cost) + + category_breakdowns: list[OpexCategoryBreakdown] = [] + total_base = sum(category_totals.values()) + for category, total in sorted(category_totals.items()): + share = (total / total_base * 100.0) if total_base else None + category_breakdowns.append( + OpexCategoryBreakdown( + category=category, + annual_cost=total, + share=share, + ) + ) + + metrics = OpexMetrics( + annual_average=fmean(escalated_values) if escalated_values else None, + cost_per_ton=None, + ) + + totals = OpexTotals( + overall_annual=overall_annual, + escalated_total=escalated_total if apply_escalation else None, + escalation_pct=escalation_pct if apply_escalation else None, + by_category=category_breakdowns, + ) + + return OpexCalculationResult( + totals=totals, + timeline=timeline_entries, + metrics=metrics, + components=normalised_components, + parameters=parameters, + options=request.options, + currency=currency, + ) + + +__all__ = [ + "calculate_profitability", + "calculate_initial_capex", + "calculate_opex", +] diff --git a/services/currency.py b/services/currency.py new file mode 100644 index 0000000..49d61d6 --- /dev/null +++ b/services/currency.py @@ -0,0 +1,43 @@ +"""Utilities for currency normalization within pricing and financial workflows.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +VALID_CURRENCY_PATTERN = 
re.compile(r"^[A-Z]{3}$") + + +@dataclass(frozen=True) +class CurrencyValidationError(ValueError): + """Raised when a currency code fails validation.""" + + code: str + + def __str__(self) -> str: # pragma: no cover - dataclass repr not required in tests + return f"Invalid currency code: {self.code!r}" + + +def normalise_currency(code: str | None) -> str | None: + """Normalise currency codes to uppercase ISO-4217 values.""" + + if code is None: + return None + candidate = code.strip().upper() + if not VALID_CURRENCY_PATTERN.match(candidate): + raise CurrencyValidationError(candidate) + return candidate + + +def require_currency(code: str | None, default: str | None = None) -> str: + """Return normalised currency code, falling back to default when missing.""" + + normalised = normalise_currency(code) + if normalised is not None: + return normalised + if default is None: + raise CurrencyValidationError("") + fallback = normalise_currency(default) + if fallback is None: + raise CurrencyValidationError("") + return fallback diff --git a/services/exceptions.py b/services/exceptions.py new file mode 100644 index 0000000..0eb3b6a --- /dev/null +++ b/services/exceptions.py @@ -0,0 +1,61 @@ +"""Domain-level exceptions for service and repository layers.""" + +from dataclasses import dataclass +from typing import Sequence + + +class EntityNotFoundError(Exception): + """Raised when a requested entity cannot be located.""" + + +class EntityConflictError(Exception): + """Raised when attempting to create or update an entity that violates uniqueness.""" + + +class AuthorizationError(Exception): + """Raised when a user lacks permission to perform an action.""" + + +@dataclass(eq=False) +class ScenarioValidationError(Exception): + """Raised when scenarios fail comparison validation rules.""" + + code: str + message: str + scenario_ids: Sequence[int] | None = None + + def __str__(self) -> str: # pragma: no cover - mirrors message for logging + return self.message + + +@dataclass(eq=False) +class ProfitabilityValidationError(Exception): + """Raised when profitability calculation inputs fail domain validation.""" + + message: str + field_errors: Sequence[str] | None = None + + def __str__(self) -> str: # pragma: no cover - mirrors message for logging + return self.message + + +@dataclass(eq=False) +class CapexValidationError(Exception): + """Raised when capex calculation inputs fail domain validation.""" + + message: str + field_errors: Sequence[str] | None = None + + def __str__(self) -> str: # pragma: no cover - mirrors message for logging + return self.message + + +@dataclass(eq=False) +class OpexValidationError(Exception): + """Raised when opex calculation inputs fail domain validation.""" + + message: str + field_errors: Sequence[str] | None = None + + def __str__(self) -> str: # pragma: no cover - mirrors message for logging + return self.message diff --git a/services/export_query.py b/services/export_query.py new file mode 100644 index 0000000..7f6acf9 --- /dev/null +++ b/services/export_query.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date, datetime +from typing import Iterable + +from models import MiningOperationType, ResourceType, ScenarioStatus +from services.currency import CurrencyValidationError, normalise_currency + + +def _normalise_lower_strings(values: Iterable[str]) -> tuple[str, ...]: + unique: set[str] = set() + for value in values: + if not value: + continue + trimmed = value.strip().lower() + if not trimmed: + 
continue + unique.add(trimmed) + return tuple(sorted(unique)) + + +def _normalise_upper_strings(values: Iterable[str | None]) -> tuple[str, ...]: + unique: set[str] = set() + for value in values: + if value is None: + continue + candidate = value if isinstance(value, str) else str(value) + candidate = candidate.strip() + if not candidate: + continue + try: + normalised = normalise_currency(candidate) + except CurrencyValidationError as exc: + raise ValueError(str(exc)) from exc + if normalised is None: + continue + unique.add(normalised) + return tuple(sorted(unique)) + + +@dataclass(slots=True, frozen=True) +class ProjectExportFilters: + """Filter parameters for project export queries.""" + + ids: tuple[int, ...] = () + names: tuple[str, ...] = () + name_contains: str | None = None + locations: tuple[str, ...] = () + operation_types: tuple[MiningOperationType, ...] = () + created_from: datetime | None = None + created_to: datetime | None = None + updated_from: datetime | None = None + updated_to: datetime | None = None + + def normalised_ids(self) -> tuple[int, ...]: + unique = {identifier for identifier in self.ids if identifier > 0} + return tuple(sorted(unique)) + + def normalised_names(self) -> tuple[str, ...]: + return _normalise_lower_strings(self.names) + + def normalised_locations(self) -> tuple[str, ...]: + return _normalise_lower_strings(self.locations) + + def name_search_pattern(self) -> str | None: + if not self.name_contains: + return None + pattern = self.name_contains.strip() + if not pattern: + return None + return f"%{pattern}%" + + +@dataclass(slots=True, frozen=True) +class ScenarioExportFilters: + """Filter parameters for scenario export queries.""" + + ids: tuple[int, ...] = () + project_ids: tuple[int, ...] = () + project_names: tuple[str, ...] = () + name_contains: str | None = None + statuses: tuple[ScenarioStatus, ...] = () + start_date_from: date | None = None + start_date_to: date | None = None + end_date_from: date | None = None + end_date_to: date | None = None + created_from: datetime | None = None + created_to: datetime | None = None + updated_from: datetime | None = None + updated_to: datetime | None = None + currencies: tuple[str, ...] = () + primary_resources: tuple[ResourceType, ...] 
= () + + def normalised_ids(self) -> tuple[int, ...]: + unique = {identifier for identifier in self.ids if identifier > 0} + return tuple(sorted(unique)) + + def normalised_project_ids(self) -> tuple[int, ...]: + unique = {identifier for identifier in self.project_ids if identifier > 0} + return tuple(sorted(unique)) + + def normalised_project_names(self) -> tuple[str, ...]: + return _normalise_lower_strings(self.project_names) + + def name_search_pattern(self) -> str | None: + if not self.name_contains: + return None + pattern = self.name_contains.strip() + if not pattern: + return None + return f"%{pattern}%" + + def normalised_currencies(self) -> tuple[str, ...]: + return _normalise_upper_strings(self.currencies) + + +__all__ = ( + "ProjectExportFilters", + "ScenarioExportFilters", +) diff --git a/services/export_serializers.py b/services/export_serializers.py new file mode 100644 index 0000000..5242a15 --- /dev/null +++ b/services/export_serializers.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +import csv +from dataclasses import dataclass, field +from datetime import date, datetime, timezone +from decimal import Decimal, InvalidOperation, ROUND_HALF_UP +from enum import Enum +from io import BytesIO, StringIO +from typing import Any, Callable, Iterable, Iterator, Mapping, Sequence + +from openpyxl import Workbook +CSVValueFormatter = Callable[[Any], str] +Accessor = Callable[[Any], Any] + +__all__ = [ + "CSVExportColumn", + "CSVExporter", + "default_project_columns", + "default_scenario_columns", + "stream_projects_to_csv", + "stream_scenarios_to_csv", + "ExcelExporter", + "export_projects_to_excel", + "export_scenarios_to_excel", + "default_formatter", + "format_datetime_utc", + "format_date_iso", + "format_decimal", +] + + +@dataclass(slots=True) +class CSVExportColumn: + """Declarative description of a CSV export column.""" + + header: str + accessor: Accessor | str + formatter: CSVValueFormatter | None = None + required: bool = False + + _accessor: Accessor = field(init=False, repr=False) + + def __post_init__(self) -> None: + object.__setattr__(self, "_accessor", _coerce_accessor(self.accessor)) + + def value_for(self, entity: Any) -> Any: + accessor = object.__getattribute__(self, "_accessor") + try: + return accessor(entity) + except Exception: # pragma: no cover - defensive safeguard + return None + + +class CSVExporter: + """Stream Python objects as UTF-8 encoded CSV rows.""" + + def __init__( + self, + columns: Sequence[CSVExportColumn], + *, + include_header: bool = True, + line_terminator: str = "\n", + ) -> None: + if not columns: + raise ValueError("At least one column is required for CSV export.") + self._columns: tuple[CSVExportColumn, ...] 
= tuple(columns) + self._include_header = include_header + self._line_terminator = line_terminator + + @property + def columns(self) -> tuple[CSVExportColumn, ...]: + return self._columns + + def headers(self) -> tuple[str, ...]: + return tuple(column.header for column in self._columns) + + def iter_bytes(self, records: Iterable[Any]) -> Iterator[bytes]: + buffer = StringIO() + writer = csv.writer(buffer, lineterminator=self._line_terminator) + + if self._include_header: + writer.writerow(self.headers()) + yield _drain_buffer(buffer) + + for record in records: + writer.writerow(self._format_row(record)) + yield _drain_buffer(buffer) + + def _format_row(self, record: Any) -> list[str]: + formatted: list[str] = [] + for column in self._columns: + raw_value = column.value_for(record) + formatter = column.formatter or default_formatter + formatted.append(formatter(raw_value)) + return formatted + + +def default_project_columns( + *, + include_description: bool = True, + include_timestamps: bool = True, +) -> tuple[CSVExportColumn, ...]: + columns: list[CSVExportColumn] = [ + CSVExportColumn("name", "name", required=True), + CSVExportColumn("location", "location"), + CSVExportColumn("operation_type", "operation_type"), + ] + if include_description: + columns.append(CSVExportColumn("description", "description")) + if include_timestamps: + columns.extend( + ( + CSVExportColumn("created_at", "created_at", + formatter=format_datetime_utc), + CSVExportColumn("updated_at", "updated_at", + formatter=format_datetime_utc), + ) + ) + return tuple(columns) + + +def default_scenario_columns( + *, + include_description: bool = True, + include_timestamps: bool = True, +) -> tuple[CSVExportColumn, ...]: + columns: list[CSVExportColumn] = [ + CSVExportColumn( + "project_name", + lambda scenario: getattr( + getattr(scenario, "project", None), "name", None), + required=True, + ), + CSVExportColumn("name", "name", required=True), + CSVExportColumn("status", "status"), + CSVExportColumn("start_date", "start_date", formatter=format_date_iso), + CSVExportColumn("end_date", "end_date", formatter=format_date_iso), + CSVExportColumn("discount_rate", "discount_rate", + formatter=format_decimal), + CSVExportColumn("currency", "currency"), + CSVExportColumn("primary_resource", "primary_resource"), + ] + if include_description: + columns.append(CSVExportColumn("description", "description")) + if include_timestamps: + columns.extend( + ( + CSVExportColumn("created_at", "created_at", + formatter=format_datetime_utc), + CSVExportColumn("updated_at", "updated_at", + formatter=format_datetime_utc), + ) + ) + return tuple(columns) + + +def stream_projects_to_csv( + projects: Iterable[Any], + *, + columns: Sequence[CSVExportColumn] | None = None, +) -> Iterator[bytes]: + resolved_columns = tuple(columns or default_project_columns()) + exporter = CSVExporter(resolved_columns) + yield from exporter.iter_bytes(projects) + + +def stream_scenarios_to_csv( + scenarios: Iterable[Any], + *, + columns: Sequence[CSVExportColumn] | None = None, +) -> Iterator[bytes]: + resolved_columns = tuple(columns or default_scenario_columns()) + exporter = CSVExporter(resolved_columns) + yield from exporter.iter_bytes(scenarios) + + +def default_formatter(value: Any) -> str: + if value is None: + return "" + if isinstance(value, Enum): + return str(value.value) + if isinstance(value, Decimal): + return format_decimal(value) + if isinstance(value, datetime): + return format_datetime_utc(value) + if isinstance(value, date): + return 
format_date_iso(value) + if isinstance(value, bool): + return "true" if value else "false" + return str(value) + + +def format_datetime_utc(value: Any) -> str: + if not isinstance(value, datetime): + return "" + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + value = value.astimezone(timezone.utc) + return value.isoformat().replace("+00:00", "Z") + + +def format_date_iso(value: Any) -> str: + if not isinstance(value, date): + return "" + return value.isoformat() + + +def format_decimal(value: Any) -> str: + if value is None: + return "" + if isinstance(value, Decimal): + try: + quantised = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + except InvalidOperation: # pragma: no cover - unexpected precision issues + quantised = value + return format(quantised, "f") + if isinstance(value, (int, float)): + return f"{value:.2f}" + return default_formatter(value) + + +class ExcelExporter: + """Produce Excel workbooks via write-only streaming.""" + + def __init__( + self, + columns: Sequence[CSVExportColumn], + *, + sheet_name: str = "Export", + workbook_title: str | None = None, + include_header: bool = True, + metadata: Mapping[str, Any] | None = None, + metadata_sheet_name: str = "Metadata", + ) -> None: + if not columns: + raise ValueError( + "At least one column is required for Excel export.") + self._columns: tuple[CSVExportColumn, ...] = tuple(columns) + self._sheet_name = sheet_name or "Export" + self._include_header = include_header + self._metadata = dict(metadata) if metadata else None + self._metadata_sheet_name = metadata_sheet_name or "Metadata" + self._workbook = Workbook(write_only=True) + if workbook_title: + self._workbook.properties.title = workbook_title + + def export(self, records: Iterable[Any]) -> bytes: + sheet = self._workbook.create_sheet(title=self._sheet_name) + if self._include_header: + sheet.append([column.header for column in self._columns]) + + for record in records: + sheet.append(self._format_row(record)) + + self._append_metadata_sheet() + return self._finalize() + + def _format_row(self, record: Any) -> list[Any]: + row: list[Any] = [] + for column in self._columns: + raw_value = column.value_for(record) + formatter = column.formatter or default_formatter + row.append(formatter(raw_value)) + return row + + def _append_metadata_sheet(self) -> None: + if not self._metadata: + return + + sheet_name = self._metadata_sheet_name + existing = set(self._workbook.sheetnames) + if sheet_name in existing: + index = 1 + while True: + candidate = f"{sheet_name}_{index}" + if candidate not in existing: + sheet_name = candidate + break + index += 1 + + meta_ws = self._workbook.create_sheet(title=sheet_name) + meta_ws.append(["Key", "Value"]) + for key, value in self._metadata.items(): + meta_ws.append([ + str(key), + "" if value is None else str(value), + ]) + + def _finalize(self) -> bytes: + buffer = BytesIO() + self._workbook.save(buffer) + buffer.seek(0) + return buffer.getvalue() + + +def export_projects_to_excel( + projects: Iterable[Any], + *, + columns: Sequence[CSVExportColumn] | None = None, + sheet_name: str = "Projects", + workbook_title: str | None = None, + metadata: Mapping[str, Any] | None = None, +) -> bytes: + exporter = ExcelExporter( + columns or default_project_columns(), + sheet_name=sheet_name, + workbook_title=workbook_title, + metadata=metadata, + ) + return exporter.export(projects) + + +def export_scenarios_to_excel( + scenarios: Iterable[Any], + *, + columns: Sequence[CSVExportColumn] | None = None, + sheet_name: 
str = "Scenarios", + workbook_title: str | None = None, + metadata: Mapping[str, Any] | None = None, +) -> bytes: + exporter = ExcelExporter( + columns or default_scenario_columns(), + sheet_name=sheet_name, + workbook_title=workbook_title, + metadata=metadata, + ) + return exporter.export(scenarios) + + +def _coerce_accessor(accessor: Accessor | str) -> Accessor: + if callable(accessor): + return accessor + + path = [segment for segment in accessor.split(".") if segment] + + def _resolve(entity: Any) -> Any: + current: Any = entity + for segment in path: + if current is None: + return None + current = getattr(current, segment, None) + return current + + return _resolve + + +def _drain_buffer(buffer: StringIO) -> bytes: + data = buffer.getvalue() + buffer.seek(0) + buffer.truncate(0) + return data.encode("utf-8") diff --git a/services/financial.py b/services/financial.py new file mode 100644 index 0000000..8137bac --- /dev/null +++ b/services/financial.py @@ -0,0 +1,252 @@ +"""Financial calculation helpers for project evaluation metrics.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date, datetime +from math import isclose, isfinite +from typing import Iterable, List, Sequence, Tuple + +Number = float + + +@dataclass(frozen=True, slots=True) +class CashFlow: + """Represents a dated cash flow in scenario currency.""" + + amount: Number + period_index: int | None = None + date: date | datetime | None = None + + +class ConvergenceError(RuntimeError): + """Raised when an iterative solver fails to converge.""" + + +class PaybackNotReachedError(RuntimeError): + """Raised when cumulative cash flows never reach a non-negative total.""" + + +def _coerce_date(value: date | datetime) -> date: + if isinstance(value, datetime): + return value.date() + return value + + +def normalize_cash_flows( + cash_flows: Iterable[CashFlow], + *, + compounds_per_year: int = 1, +) -> List[Tuple[Number, float]]: + """Normalise cash flows to ``(amount, periods)`` tuples. + + When explicit ``period_index`` values are provided they take precedence. If + only dates are supplied, the first dated cash flow anchors the timeline and + subsequent cash flows convert their day offsets into fractional periods + based on ``compounds_per_year``. When neither a period index nor a date is + present, cash flows are treated as sequential periods in input order. 
+ """ + + flows: Sequence[CashFlow] = list(cash_flows) + if not flows: + return [] + + if compounds_per_year <= 0: + raise ValueError("compounds_per_year must be a positive integer") + + base_date: date | None = None + for flow in flows: + if flow.date is not None: + base_date = _coerce_date(flow.date) + break + + normalised: List[Tuple[Number, float]] = [] + for idx, flow in enumerate(flows): + amount = float(flow.amount) + if flow.period_index is not None: + periods = float(flow.period_index) + elif flow.date is not None and base_date is not None: + current_date = _coerce_date(flow.date) + delta_days = (current_date - base_date).days + period_length_days = 365.0 / float(compounds_per_year) + periods = delta_days / period_length_days + else: + periods = float(idx) + normalised.append((amount, periods)) + + return normalised + + +def discount_factor(rate: Number, periods: float, *, compounds_per_year: int = 1) -> float: + """Return the factor used to discount a value ``periods`` steps in the future.""" + + if compounds_per_year <= 0: + raise ValueError("compounds_per_year must be a positive integer") + + periodic_rate = rate / float(compounds_per_year) + return (1.0 + periodic_rate) ** (-periods) + + +def net_present_value( + rate: Number, + cash_flows: Iterable[CashFlow], + *, + residual_value: Number | None = None, + residual_periods: float | None = None, + compounds_per_year: int = 1, +) -> float: + """Calculate Net Present Value for ``cash_flows``. + + ``rate`` is a decimal (``0.1`` for 10%). Cash flows are discounted using the + given compounding frequency. When ``residual_value`` is provided it is + discounted at ``residual_periods`` periods; by default the value occurs one + period after the final cash flow. + """ + + normalised = normalize_cash_flows( + cash_flows, + compounds_per_year=compounds_per_year, + ) + + if not normalised and residual_value is None: + return 0.0 + + total = 0.0 + for amount, periods in normalised: + factor = discount_factor( + rate, periods, compounds_per_year=compounds_per_year) + total += amount * factor + + if residual_value is not None: + if residual_periods is None: + last_period = normalised[-1][1] if normalised else 0.0 + residual_periods = last_period + 1.0 + factor = discount_factor( + rate, residual_periods, compounds_per_year=compounds_per_year) + total += float(residual_value) * factor + + return total + + +def internal_rate_of_return( + cash_flows: Iterable[CashFlow], + *, + guess: Number = 0.1, + max_iterations: int = 100, + tolerance: float = 1e-6, + compounds_per_year: int = 1, +) -> float: + """Return the internal rate of return for ``cash_flows``. + + Uses Newton-Raphson iteration with a bracketed fallback when the derivative + becomes unstable. Raises :class:`ConvergenceError` if no root is found. 
+ """ + + flows = normalize_cash_flows( + cash_flows, + compounds_per_year=compounds_per_year, + ) + if not flows: + raise ValueError("cash_flows must contain at least one item") + + amounts = [amount for amount, _ in flows] + if not any(amount < 0 for amount in amounts) or not any(amount > 0 for amount in amounts): + raise ValueError( + "cash_flows must include both negative and positive values") + + def _npv_with_flows(rate: float) -> float: + periodic_rate = rate / float(compounds_per_year) + if periodic_rate <= -1.0: + return float("inf") + total = 0.0 + for amount, periods in flows: + factor = (1.0 + periodic_rate) ** (-periods) + total += amount * factor + return total + + def _derivative(rate: float) -> float: + periodic_rate = rate / float(compounds_per_year) + if periodic_rate <= -1.0: + return float("inf") + derivative = 0.0 + for amount, periods in flows: + factor = (1.0 + periodic_rate) ** (-periods - 1.0) + derivative += -amount * periods * \ + factor / float(compounds_per_year) + return derivative + + rate = float(guess) + for _ in range(max_iterations): + value = _npv_with_flows(rate) + if isclose(value, 0.0, abs_tol=tolerance): + return rate + derivative = _derivative(rate) + if derivative == 0.0 or not isfinite(derivative): + break + next_rate = rate - value / derivative + if abs(next_rate - rate) < tolerance: + return next_rate + rate = next_rate + + # Fallback to bracketed bisection between sensible bounds. + lower_bound = -0.99 * float(compounds_per_year) + upper_bound = 10.0 + lower_value = _npv_with_flows(lower_bound) + upper_value = _npv_with_flows(upper_bound) + + attempts = 0 + while lower_value * upper_value > 0 and attempts < 12: + upper_bound *= 2.0 + upper_value = _npv_with_flows(upper_bound) + attempts += 1 + + if lower_value * upper_value > 0: + raise ConvergenceError( + "IRR could not be bracketed within default bounds") + + for _ in range(max_iterations * 2): + midpoint = (lower_bound + upper_bound) / 2.0 + mid_value = _npv_with_flows(midpoint) + if isclose(mid_value, 0.0, abs_tol=tolerance): + return midpoint + if lower_value * mid_value < 0: + upper_bound = midpoint + upper_value = mid_value + else: + lower_bound = midpoint + lower_value = mid_value + raise ConvergenceError("IRR solver failed to converge") + + +def payback_period( + cash_flows: Iterable[CashFlow], + *, + allow_fractional: bool = True, + compounds_per_year: int = 1, +) -> float: + """Return the period index where cumulative cash flow becomes non-negative.""" + + flows = normalize_cash_flows( + cash_flows, + compounds_per_year=compounds_per_year, + ) + if not flows: + raise ValueError("cash_flows must contain at least one item") + + flows = sorted(flows, key=lambda item: item[1]) + cumulative = 0.0 + previous_period = flows[0][1] + + for index, (amount, periods) in enumerate(flows): + next_cumulative = cumulative + amount + if next_cumulative >= 0.0: + if not allow_fractional or isclose(amount, 0.0): + return periods + prev_period = previous_period if index > 0 else periods + fraction = -cumulative / amount + return prev_period + fraction * (periods - prev_period) + cumulative = next_cumulative + previous_period = periods + + raise PaybackNotReachedError( + "Cumulative cash flow never becomes non-negative") diff --git a/services/importers.py b/services/importers.py new file mode 100644 index 0000000..c1107b5 --- /dev/null +++ b/services/importers.py @@ -0,0 +1,905 @@ +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from enum import Enum 
+from pathlib import Path +from typing import Any, BinaryIO, Callable, Generic, Iterable, Mapping, Optional, TypeVar, cast +from uuid import uuid4 +from types import MappingProxyType + +import pandas as pd +from pandas import DataFrame +from pydantic import BaseModel, ValidationError + +from models import Project, Scenario +from schemas.imports import ProjectImportRow, ScenarioImportRow +from services.unit_of_work import UnitOfWork +from models.import_export_log import ImportExportLog +from monitoring.metrics import observe_import + +logger = logging.getLogger(__name__) + +TImportRow = TypeVar("TImportRow", bound=BaseModel) + +PROJECT_COLUMNS: tuple[str, ...] = ( + "name", + "location", + "operation_type", + "description", + "created_at", + "updated_at", +) + +SCENARIO_COLUMNS: tuple[str, ...] = ( + "project_name", + "name", + "status", + "start_date", + "end_date", + "discount_rate", + "currency", + "primary_resource", + "description", + "created_at", + "updated_at", +) + + +@dataclass(slots=True) +class ImportRowError: + row_number: int + field: str | None + message: str + + +@dataclass(slots=True) +class ParsedImportRow(Generic[TImportRow]): + row_number: int + data: TImportRow + + +@dataclass(slots=True) +class ImportResult(Generic[TImportRow]): + rows: list[ParsedImportRow[TImportRow]] + errors: list[ImportRowError] + + +class UnsupportedImportFormat(ValueError): + pass + + +class ImportPreviewState(str, Enum): + NEW = "new" + UPDATE = "update" + SKIP = "skip" + ERROR = "error" + + +@dataclass(slots=True) +class ImportPreviewRow(Generic[TImportRow]): + row_number: int + data: TImportRow + state: ImportPreviewState + issues: list[str] + context: dict[str, Any] | None = None + + +@dataclass(slots=True) +class ImportPreviewSummary: + total_rows: int + accepted: int + skipped: int + errored: int + + +@dataclass(slots=True) +class ImportPreview(Generic[TImportRow]): + rows: list[ImportPreviewRow[TImportRow]] + summary: ImportPreviewSummary + row_issues: list["ImportPreviewRowIssues"] + parser_errors: list[ImportRowError] + stage_token: str | None + + +@dataclass(slots=True) +class StagedRow(Generic[TImportRow]): + parsed: ParsedImportRow[TImportRow] + context: dict[str, Any] + + +@dataclass(slots=True) +class ImportPreviewRowIssue: + message: str + field: str | None = None + + +@dataclass(slots=True) +class ImportPreviewRowIssues: + row_number: int + state: ImportPreviewState | None + issues: list[ImportPreviewRowIssue] + + +@dataclass(slots=True) +class StagedImport(Generic[TImportRow]): + token: str + rows: list[StagedRow[TImportRow]] + + +@dataclass(slots=True, frozen=True) +class StagedRowView(Generic[TImportRow]): + row_number: int + data: TImportRow + context: Mapping[str, Any] + + +@dataclass(slots=True, frozen=True) +class StagedImportView(Generic[TImportRow]): + token: str + rows: tuple[StagedRowView[TImportRow], ...] + + +@dataclass(slots=True, frozen=True) +class ImportCommitSummary: + created: int + updated: int + + +@dataclass(slots=True, frozen=True) +class ImportCommitResult(Generic[TImportRow]): + token: str + rows: tuple[StagedRowView[TImportRow], ...] 
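+    # Counts of created and updated records produced by the commit.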
+ summary: ImportCommitSummary + + +UnitOfWorkFactory = Callable[[], UnitOfWork] + + +class ImportIngestionService: + """Coordinates parsing, validation, and preview staging for imports.""" + + def __init__(self, uow_factory: UnitOfWorkFactory) -> None: + self._uow_factory = uow_factory + self._project_stage: dict[str, StagedImport[ProjectImportRow]] = {} + self._scenario_stage: dict[str, StagedImport[ScenarioImportRow]] = {} + + def preview_projects( + self, + stream: BinaryIO, + filename: str, + ) -> ImportPreview[ProjectImportRow]: + start = time.perf_counter() + result = load_project_imports(stream, filename) + status = "success" if not result.errors else "partial" + self._record_audit_log( + action="preview", + dataset="projects", + status=status, + filename=filename, + row_count=len(result.rows), + detail=f"accepted={len(result.rows)} parser_errors={len(result.errors)}", + ) + observe_import( + action="preview", + dataset="projects", + status=status, + seconds=time.perf_counter() - start, + ) + logger.info( + "import.preview", + extra={ + "event": "import.preview", + "dataset": "projects", + "status": status, + "filename": filename, + "row_count": len(result.rows), + "error_count": len(result.errors), + }, + ) + parser_errors = result.errors + + preview_rows: list[ImportPreviewRow[ProjectImportRow]] = [] + staged_rows: list[StagedRow[ProjectImportRow]] = [] + accepted = skipped = errored = 0 + + seen_names: set[str] = set() + + existing_by_name: dict[str, Project] = {} + if result.rows: + with self._uow_factory() as uow: + if not uow.projects: + raise RuntimeError("Project repository is unavailable") + existing_by_name = dict( + uow.projects.find_by_names( + parsed.data.name for parsed in result.rows + ) + ) + + for parsed in result.rows: + name_key = _normalise_key(parsed.data.name) + issues: list[str] = [] + context: dict[str, Any] | None = None + state = ImportPreviewState.NEW + + if name_key in seen_names: + state = ImportPreviewState.SKIP + issues.append( + "Duplicate project name within upload; row skipped.") + else: + seen_names.add(name_key) + existing = existing_by_name.get(name_key) + if existing: + state = ImportPreviewState.UPDATE + context = { + "mode": "update", + "project_id": existing.id, + } + issues.append("Existing project will be updated.") + else: + context = {"mode": "create"} + + preview_rows.append( + ImportPreviewRow( + row_number=parsed.row_number, + data=parsed.data, + state=state, + issues=issues, + context=context, + ) + ) + + if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}: + accepted += 1 + staged_rows.append( + StagedRow(parsed=parsed, context=context or { + "mode": "create"}) + ) + elif state == ImportPreviewState.SKIP: + skipped += 1 + else: + errored += 1 + + parser_error_rows = {error.row_number for error in parser_errors} + errored += len(parser_error_rows) + total_rows = len(preview_rows) + len(parser_error_rows) + + summary = ImportPreviewSummary( + total_rows=total_rows, + accepted=accepted, + skipped=skipped, + errored=errored, + ) + + row_issues = _compile_row_issues(preview_rows, parser_errors) + + stage_token: str | None = None + if staged_rows: + stage_token = self._store_project_stage(staged_rows) + + return ImportPreview( + rows=preview_rows, + summary=summary, + row_issues=row_issues, + parser_errors=parser_errors, + stage_token=stage_token, + ) + + def preview_scenarios( + self, + stream: BinaryIO, + filename: str, + ) -> ImportPreview[ScenarioImportRow]: + start = time.perf_counter() + result = 
load_scenario_imports(stream, filename) + status = "success" if not result.errors else "partial" + self._record_audit_log( + action="preview", + dataset="scenarios", + status=status, + filename=filename, + row_count=len(result.rows), + detail=f"accepted={len(result.rows)} parser_errors={len(result.errors)}", + ) + observe_import( + action="preview", + dataset="scenarios", + status=status, + seconds=time.perf_counter() - start, + ) + logger.info( + "import.preview", + extra={ + "event": "import.preview", + "dataset": "scenarios", + "status": status, + "filename": filename, + "row_count": len(result.rows), + "error_count": len(result.errors), + }, + ) + parser_errors = result.errors + + preview_rows: list[ImportPreviewRow[ScenarioImportRow]] = [] + staged_rows: list[StagedRow[ScenarioImportRow]] = [] + accepted = skipped = errored = 0 + + seen_pairs: set[tuple[str, str]] = set() + + existing_projects: dict[str, Project] = {} + existing_scenarios: dict[tuple[int, str], Scenario] = {} + + if result.rows: + with self._uow_factory() as uow: + if not uow.projects or not uow.scenarios: + raise RuntimeError("Repositories are unavailable") + + existing_projects = dict( + uow.projects.find_by_names( + parsed.data.project_name for parsed in result.rows + ) + ) + + names_by_project: dict[int, set[str]] = {} + for parsed in result.rows: + project = existing_projects.get( + _normalise_key(parsed.data.project_name) + ) + if not project: + continue + names_by_project.setdefault(project.id, set()).add( + _normalise_key(parsed.data.name) + ) + + for project_id, names in names_by_project.items(): + matches = uow.scenarios.find_by_project_and_names( + project_id, names) + for name_key, scenario in matches.items(): + existing_scenarios[(project_id, name_key)] = scenario + + for parsed in result.rows: + project_key = _normalise_key(parsed.data.project_name) + scenario_key = _normalise_key(parsed.data.name) + issues: list[str] = [] + context: dict[str, Any] | None = None + state = ImportPreviewState.NEW + + if (project_key, scenario_key) in seen_pairs: + state = ImportPreviewState.SKIP + issues.append( + "Duplicate scenario for project within upload; row skipped." + ) + else: + seen_pairs.add((project_key, scenario_key)) + project = existing_projects.get(project_key) + if not project: + state = ImportPreviewState.ERROR + issues.append( + f"Project '{parsed.data.project_name}' does not exist." 
+ ) + else: + context = {"mode": "create", "project_id": project.id} + existing = existing_scenarios.get( + (project.id, scenario_key)) + if existing: + state = ImportPreviewState.UPDATE + context = { + "mode": "update", + "project_id": project.id, + "scenario_id": existing.id, + } + issues.append("Existing scenario will be updated.") + + preview_rows.append( + ImportPreviewRow( + row_number=parsed.row_number, + data=parsed.data, + state=state, + issues=issues, + context=context, + ) + ) + + if state in {ImportPreviewState.NEW, ImportPreviewState.UPDATE}: + accepted += 1 + staged_rows.append( + StagedRow(parsed=parsed, context=context or { + "mode": "create"}) + ) + elif state == ImportPreviewState.SKIP: + skipped += 1 + else: + errored += 1 + + parser_error_rows = {error.row_number for error in parser_errors} + errored += len(parser_error_rows) + total_rows = len(preview_rows) + len(parser_error_rows) + + summary = ImportPreviewSummary( + total_rows=total_rows, + accepted=accepted, + skipped=skipped, + errored=errored, + ) + + row_issues = _compile_row_issues(preview_rows, parser_errors) + + stage_token: str | None = None + if staged_rows: + stage_token = self._store_scenario_stage(staged_rows) + + return ImportPreview( + rows=preview_rows, + summary=summary, + row_issues=row_issues, + parser_errors=parser_errors, + stage_token=stage_token, + ) + + def get_staged_projects( + self, token: str + ) -> StagedImportView[ProjectImportRow] | None: + staged = self._project_stage.get(token) + if not staged: + return None + return _build_staged_view(staged) + + def get_staged_scenarios( + self, token: str + ) -> StagedImportView[ScenarioImportRow] | None: + staged = self._scenario_stage.get(token) + if not staged: + return None + return _build_staged_view(staged) + + def consume_staged_projects( + self, token: str + ) -> StagedImportView[ProjectImportRow] | None: + staged = self._project_stage.pop(token, None) + if not staged: + return None + return _build_staged_view(staged) + + def consume_staged_scenarios( + self, token: str + ) -> StagedImportView[ScenarioImportRow] | None: + staged = self._scenario_stage.pop(token, None) + if not staged: + return None + return _build_staged_view(staged) + + def clear_staged_projects(self, token: str) -> bool: + return self._project_stage.pop(token, None) is not None + + def clear_staged_scenarios(self, token: str) -> bool: + return self._scenario_stage.pop(token, None) is not None + + def commit_project_import(self, token: str) -> ImportCommitResult[ProjectImportRow]: + staged = self._project_stage.get(token) + if not staged: + raise ValueError(f"Unknown project import token: {token}") + + staged_view = _build_staged_view(staged) + created = updated = 0 + + start = time.perf_counter() + try: + with self._uow_factory() as uow: + if not uow.projects: + raise RuntimeError("Project repository is unavailable") + + for row in staged.rows: + mode = row.context.get("mode") + data = row.parsed.data + + if mode == "create": + project = Project( + name=data.name, + location=data.location, + operation_type=data.operation_type, + description=data.description, + ) + if data.created_at: + project.created_at = data.created_at + if data.updated_at: + project.updated_at = data.updated_at + uow.projects.create(project) + created += 1 + elif mode == "update": + project_id = row.context.get("project_id") + if not project_id: + raise ValueError( + "Staged project update is missing project_id context" + ) + project = uow.projects.get(project_id) + project.name = data.name + 
project.location = data.location + project.operation_type = data.operation_type + project.description = data.description + if data.created_at: + project.created_at = data.created_at + if data.updated_at: + project.updated_at = data.updated_at + updated += 1 + else: + raise ValueError( + f"Unsupported staged project mode: {mode!r}") + except Exception as exc: + self._record_audit_log( + action="commit", + dataset="projects", + status="failure", + filename=None, + row_count=len(staged.rows), + detail=f"error={type(exc).__name__}: {exc}", + ) + observe_import( + action="commit", + dataset="projects", + status="failure", + seconds=time.perf_counter() - start, + ) + logger.exception( + "import.commit.failed", + extra={ + "event": "import.commit", + "dataset": "projects", + "status": "failure", + "row_count": len(staged.rows), + "token": token, + }, + ) + raise + else: + self._record_audit_log( + action="commit", + dataset="projects", + status="success", + filename=None, + row_count=len(staged.rows), + detail=f"created={created} updated={updated}", + ) + observe_import( + action="commit", + dataset="projects", + status="success", + seconds=time.perf_counter() - start, + ) + logger.info( + "import.commit", + extra={ + "event": "import.commit", + "dataset": "projects", + "status": "success", + "row_count": len(staged.rows), + "created": created, + "updated": updated, + "token": token, + }, + ) + + self._project_stage.pop(token, None) + return ImportCommitResult( + token=token, + rows=staged_view.rows, + summary=ImportCommitSummary(created=created, updated=updated), + ) + + def commit_scenario_import(self, token: str) -> ImportCommitResult[ScenarioImportRow]: + staged = self._scenario_stage.get(token) + if not staged: + raise ValueError(f"Unknown scenario import token: {token}") + + staged_view = _build_staged_view(staged) + created = updated = 0 + + start = time.perf_counter() + try: + with self._uow_factory() as uow: + if not uow.scenarios or not uow.projects: + raise RuntimeError("Scenario repositories are unavailable") + + for row in staged.rows: + mode = row.context.get("mode") + data = row.parsed.data + + project_id = row.context.get("project_id") + if not project_id: + raise ValueError( + "Staged scenario row is missing project_id context" + ) + + project = uow.projects.get(project_id) + + if mode == "create": + scenario = Scenario( + project_id=project.id, + name=data.name, + status=data.status, + start_date=data.start_date, + end_date=data.end_date, + discount_rate=data.discount_rate, + currency=data.currency, + primary_resource=data.primary_resource, + description=data.description, + ) + if data.created_at: + scenario.created_at = data.created_at + if data.updated_at: + scenario.updated_at = data.updated_at + uow.scenarios.create(scenario) + created += 1 + elif mode == "update": + scenario_id = row.context.get("scenario_id") + if not scenario_id: + raise ValueError( + "Staged scenario update is missing scenario_id context" + ) + scenario = uow.scenarios.get(scenario_id) + scenario.project_id = project.id + scenario.name = data.name + scenario.status = data.status + scenario.start_date = data.start_date + scenario.end_date = data.end_date + scenario.discount_rate = data.discount_rate + scenario.currency = data.currency + scenario.primary_resource = data.primary_resource + scenario.description = data.description + if data.created_at: + scenario.created_at = data.created_at + if data.updated_at: + scenario.updated_at = data.updated_at + updated += 1 + else: + raise ValueError( + 
f"Unsupported staged scenario mode: {mode!r}") + except Exception as exc: + self._record_audit_log( + action="commit", + dataset="scenarios", + status="failure", + filename=None, + row_count=len(staged.rows), + detail=f"error={type(exc).__name__}: {exc}", + ) + observe_import( + action="commit", + dataset="scenarios", + status="failure", + seconds=time.perf_counter() - start, + ) + logger.exception( + "import.commit.failed", + extra={ + "event": "import.commit", + "dataset": "scenarios", + "status": "failure", + "row_count": len(staged.rows), + "token": token, + }, + ) + raise + else: + self._record_audit_log( + action="commit", + dataset="scenarios", + status="success", + filename=None, + row_count=len(staged.rows), + detail=f"created={created} updated={updated}", + ) + observe_import( + action="commit", + dataset="scenarios", + status="success", + seconds=time.perf_counter() - start, + ) + logger.info( + "import.commit", + extra={ + "event": "import.commit", + "dataset": "scenarios", + "status": "success", + "row_count": len(staged.rows), + "created": created, + "updated": updated, + "token": token, + }, + ) + + self._scenario_stage.pop(token, None) + return ImportCommitResult( + token=token, + rows=staged_view.rows, + summary=ImportCommitSummary(created=created, updated=updated), + ) + + def _record_audit_log( + self, + *, + action: str, + dataset: str, + status: str, + row_count: int, + detail: Optional[str], + filename: Optional[str], + ) -> None: + try: + with self._uow_factory() as uow: + if uow.session is None: + return + log = ImportExportLog( + action=action, + dataset=dataset, + status=status, + filename=filename, + row_count=row_count, + detail=detail, + ) + uow.session.add(log) + uow.commit() + except Exception: + # Audit logging must not break core workflows + pass + + def _store_project_stage( + self, rows: list[StagedRow[ProjectImportRow]] + ) -> str: + token = str(uuid4()) + self._project_stage[token] = StagedImport(token=token, rows=rows) + return token + + def _store_scenario_stage( + self, rows: list[StagedRow[ScenarioImportRow]] + ) -> str: + token = str(uuid4()) + self._scenario_stage[token] = StagedImport(token=token, rows=rows) + return token + + +def load_project_imports(stream: BinaryIO, filename: str) -> ImportResult[ProjectImportRow]: + df = _load_dataframe(stream, filename) + return _parse_dataframe(df, ProjectImportRow, PROJECT_COLUMNS) + + +def load_scenario_imports(stream: BinaryIO, filename: str) -> ImportResult[ScenarioImportRow]: + df = _load_dataframe(stream, filename) + return _parse_dataframe(df, ScenarioImportRow, SCENARIO_COLUMNS) + + +def _load_dataframe(stream: BinaryIO, filename: str) -> DataFrame: + stream.seek(0) + suffix = Path(filename).suffix.lower() + if suffix == ".csv": + df = pd.read_csv(stream, dtype=str, + keep_default_na=False, encoding="utf-8") + elif suffix in {".xls", ".xlsx"}: + df = pd.read_excel(stream, dtype=str, engine="openpyxl") + else: + raise UnsupportedImportFormat( + f"Unsupported file type: {suffix or 'unknown'}") + df.columns = [str(col).strip().lower() for col in df.columns] + return df + + +def _parse_dataframe( + df: DataFrame, + model: type[TImportRow], + expected_columns: Iterable[str], +) -> ImportResult[TImportRow]: + rows: list[ParsedImportRow[TImportRow]] = [] + errors: list[ImportRowError] = [] + for index, raw in enumerate(df.to_dict(orient="records"), start=2): + payload = _prepare_payload( + cast(dict[str, object], raw), expected_columns) + try: + rows.append( + ParsedImportRow(row_number=index, 
data=model(**payload)) + ) + except ValidationError as exc: # pragma: no cover - exercised via tests + for detail in exc.errors(): + loc = ".".join(str(part) + for part in detail.get("loc", [])) or None + errors.append( + ImportRowError( + row_number=index, + field=loc, + message=detail.get("msg", "Invalid value"), + ) + ) + return ImportResult(rows=rows, errors=errors) + + +def _prepare_payload( + raw: dict[str, object], expected_columns: Iterable[str] +) -> dict[str, object | None]: + payload: dict[str, object | None] = {} + for column in expected_columns: + if column not in raw: + continue + value = raw.get(column) + if isinstance(value, str): + value = value.strip() + if value == "": + value = None + if value is not None and pd.isna(cast(Any, value)): + value = None + payload[column] = value + return payload + + +def _normalise_key(value: str) -> str: + return value.strip().lower() + + +def _build_staged_view( + staged: StagedImport[TImportRow], +) -> StagedImportView[TImportRow]: + rows = tuple( + StagedRowView( + row_number=row.parsed.row_number, + data=cast(TImportRow, _deep_copy_model(row.parsed.data)), + context=MappingProxyType(dict(row.context)), + ) + for row in staged.rows + ) + return StagedImportView(token=staged.token, rows=rows) + + +def _deep_copy_model(model: BaseModel) -> BaseModel: + copy_method = getattr(model, "model_copy", None) + if callable(copy_method): # pydantic v2 + return cast(BaseModel, copy_method(deep=True)) + return model.copy(deep=True) # type: ignore[attr-defined] + + +def _compile_row_issues( + preview_rows: Iterable[ImportPreviewRow[Any]], + parser_errors: Iterable[ImportRowError], +) -> list[ImportPreviewRowIssues]: + issue_map: dict[int, ImportPreviewRowIssues] = {} + + def ensure_bundle( + row_number: int, + state: ImportPreviewState | None, + ) -> ImportPreviewRowIssues: + bundle = issue_map.get(row_number) + if bundle is None: + bundle = ImportPreviewRowIssues( + row_number=row_number, + state=state, + issues=[], + ) + issue_map[row_number] = bundle + else: + if _state_priority(state) > _state_priority(bundle.state): + bundle.state = state + return bundle + + for row in preview_rows: + if not row.issues: + continue + bundle = ensure_bundle(row.row_number, row.state) + for message in row.issues: + bundle.issues.append(ImportPreviewRowIssue(message=message)) + + for error in parser_errors: + bundle = ensure_bundle(error.row_number, ImportPreviewState.ERROR) + bundle.issues.append( + ImportPreviewRowIssue(message=error.message, field=error.field) + ) + + return sorted(issue_map.values(), key=lambda item: item.row_number) + + +def _state_priority(state: ImportPreviewState | None) -> int: + if state is None: + return -1 + if state == ImportPreviewState.ERROR: + return 3 + if state == ImportPreviewState.SKIP: + return 2 + if state == ImportPreviewState.UPDATE: + return 1 + return 0 diff --git a/services/metrics.py b/services/metrics.py new file mode 100644 index 0000000..1a6e513 --- /dev/null +++ b/services/metrics.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +from datetime import datetime +from typing import Any, Dict, Optional + +from sqlalchemy.orm import Session + +from models.performance_metric import PerformanceMetric + + +class MetricsService: + def __init__(self, db: Session): + self.db = db + + def store_metric( + self, + metric_name: str, + value: float, + labels: Optional[Dict[str, Any]] = None, + endpoint: Optional[str] = None, + method: Optional[str] = None, + status_code: Optional[int] = None, + duration_seconds: 
Optional[float] = None, + ) -> PerformanceMetric: + """Store a performance metric in the database.""" + metric = PerformanceMetric( + timestamp=datetime.utcnow(), + metric_name=metric_name, + value=value, + labels=json.dumps(labels) if labels else None, + endpoint=endpoint, + method=method, + status_code=status_code, + duration_seconds=duration_seconds, + ) + self.db.add(metric) + self.db.commit() + self.db.refresh(metric) + return metric + + def get_metrics( + self, + metric_name: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 100, + ) -> list[PerformanceMetric]: + """Retrieve stored metrics with optional filtering.""" + query = self.db.query(PerformanceMetric) + + if metric_name: + query = query.filter(PerformanceMetric.metric_name == metric_name) + + if start_time: + query = query.filter(PerformanceMetric.timestamp >= start_time) + + if end_time: + query = query.filter(PerformanceMetric.timestamp <= end_time) + + return query.order_by(PerformanceMetric.timestamp.desc()).limit(limit).all() + + def get_aggregated_metrics( + self, + metric_name: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ) -> Dict[str, Any]: + """Get aggregated statistics for a metric.""" + query = self.db.query(PerformanceMetric).filter( + PerformanceMetric.metric_name == metric_name + ) + + if start_time: + query = query.filter(PerformanceMetric.timestamp >= start_time) + + if end_time: + query = query.filter(PerformanceMetric.timestamp <= end_time) + + metrics = query.all() + + if not metrics: + return {"count": 0, "avg": 0, "min": 0, "max": 0} + + values = [m.value for m in metrics] + return { + "count": len(values), + "avg": sum(values) / len(values), + "min": min(values), + "max": max(values), + } + + +def get_metrics_service(db: Session) -> MetricsService: + return MetricsService(db) diff --git a/services/navigation.py b/services/navigation.py new file mode 100644 index 0000000..4c097cf --- /dev/null +++ b/services/navigation.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Iterable, List, Sequence + +from fastapi import Request + +from models.navigation import NavigationLink +from services.repositories import NavigationRepository +from services.session import AuthSession + + +@dataclass(slots=True) +class NavigationLinkDTO: + id: int + label: str + href: str + match_prefix: str | None + icon: str | None + tooltip: str | None + is_external: bool + children: List["NavigationLinkDTO"] = field(default_factory=list) + + +@dataclass(slots=True) +class NavigationGroupDTO: + id: int + label: str + icon: str | None + tooltip: str | None + links: List[NavigationLinkDTO] = field(default_factory=list) + + +@dataclass(slots=True) +class NavigationSidebarDTO: + groups: List[NavigationGroupDTO] + roles: tuple[str, ...] 
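+    # Includes the implicit "anonymous" role for unauthenticated sessions.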
+ + +class NavigationService: + """Build navigation payloads filtered for the current session.""" + + def __init__(self, repository: NavigationRepository) -> None: + self._repository = repository + + def build_sidebar( + self, + *, + session: AuthSession, + request: Request | None = None, + include_disabled: bool = False, + ) -> NavigationSidebarDTO: + roles = self._collect_roles(session) + groups = self._repository.list_groups_with_links( + include_disabled=include_disabled + ) + context = self._derive_context(request) + + mapped_groups: List[NavigationGroupDTO] = [] + for group in groups: + if not include_disabled and not group.is_enabled: + continue + mapped_links = self._map_links( + group.links, + roles, + request=request, + include_disabled=include_disabled, + context=context, + ) + if not mapped_links and not include_disabled: + continue + mapped_groups.append( + NavigationGroupDTO( + id=group.id, + label=group.label, + icon=group.icon, + tooltip=group.tooltip, + links=mapped_links, + ) + ) + return NavigationSidebarDTO(groups=mapped_groups, roles=roles) + + def _map_links( + self, + links: Sequence[NavigationLink], + roles: Iterable[str], + *, + request: Request | None, + include_disabled: bool, + context: dict[str, str | None], + include_children: bool = False, + ) -> List[NavigationLinkDTO]: + resolved_roles = tuple(roles) + mapped: List[NavigationLinkDTO] = [] + for link in sorted(links, key=lambda x: (x.sort_order, x.id)): + if not include_children and link.parent_link_id is not None: + continue + if not include_disabled and (not link.is_enabled): + continue + if not self._link_visible(link, resolved_roles, include_disabled): + continue + href = self._resolve_href(link, request=request, context=context) + if not href: + continue + children = self._map_links( + link.children, + resolved_roles, + request=request, + include_disabled=include_disabled, + context=context, + include_children=True, + ) + match_prefix = link.match_prefix or href + mapped.append( + NavigationLinkDTO( + id=link.id, + label=link.label, + href=href, + match_prefix=match_prefix, + icon=link.icon, + tooltip=link.tooltip, + is_external=link.is_external, + children=children, + ) + ) + return mapped + + @staticmethod + def _collect_roles(session: AuthSession) -> tuple[str, ...]: + roles = tuple((session.role_slugs or ()) if session else ()) + if session and session.is_authenticated: + return roles + if "anonymous" in roles: + return roles + return roles + ("anonymous",) + + @staticmethod + def _derive_context(request: Request | None) -> dict[str, str | None]: + if request is None: + return {"project_id": None, "scenario_id": None} + project_id = request.path_params.get( + "project_id") if hasattr(request, "path_params") else None + scenario_id = request.path_params.get( + "scenario_id") if hasattr(request, "path_params") else None + if not project_id: + project_id = request.query_params.get("project_id") + if not scenario_id: + scenario_id = request.query_params.get("scenario_id") + return {"project_id": project_id, "scenario_id": scenario_id} + + def _resolve_href( + self, + link: NavigationLink, + *, + request: Request | None, + context: dict[str, str | None], + ) -> str | None: + if link.route_name: + if request is None: + fallback = link.href_override + if fallback: + return fallback + # Fallback to route name when no request is available + return f"/{link.route_name.replace('.', '/')}" + requires_context = link.slug in { + "profitability", + "profitability-calculator", + "opex", + "capex", + } + if 
requires_context: + project_id = context.get("project_id") + scenario_id = context.get("scenario_id") + if project_id and scenario_id: + try: + return str( + request.url_for( + link.route_name, + project_id=project_id, + scenario_id=scenario_id, + ) + ) + except Exception: # pragma: no cover - defensive + pass + try: + return str(request.url_for(link.route_name)) + except Exception: # pragma: no cover - defensive + return link.href_override + return link.href_override + + @staticmethod + def _link_visible( + link: NavigationLink, + roles: Iterable[str], + include_disabled: bool, + ) -> bool: + role_tuple = tuple(roles) + if not include_disabled and not link.is_enabled: + return False + if not link.required_roles: + return True + role_set = set(role_tuple) + return any(role in role_set for role in link.required_roles) diff --git a/services/pricing.py b/services/pricing.py new file mode 100644 index 0000000..ab60c02 --- /dev/null +++ b/services/pricing.py @@ -0,0 +1,176 @@ +"""Pricing service implementing commodity revenue calculations. + +This module exposes data models and helpers for computing product pricing +according to the formulas outlined in +``calminer-docs/specifications/price_calculation.md``. It focuses on the core +calculation steps (payable metal, penalties, net revenue) and is intended to be +composed within broader scenario evaluation workflows. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Mapping + +from pydantic import BaseModel, Field, PositiveFloat, field_validator +from services.currency import require_currency + + +class PricingInput(BaseModel): + """Normalized inputs for pricing calculations.""" + + metal: str = Field(..., min_length=1) + ore_tonnage: PositiveFloat = Field( + ..., description="Total ore mass processed (metric tonnes)") + head_grade_pct: PositiveFloat = Field(..., gt=0, + le=100, description="Head grade as percent") + recovery_pct: PositiveFloat = Field(..., gt=0, + le=100, description="Recovery rate percent") + payable_pct: float | None = Field( + None, gt=0, le=100, description="Contractual payable percentage") + reference_price: PositiveFloat = Field( + ..., description="Reference price in base currency per unit") + treatment_charge: float = Field(0, ge=0) + smelting_charge: float = Field(0, ge=0) + moisture_pct: float = Field(0, ge=0, le=100) + moisture_threshold_pct: float | None = Field(None, ge=0, le=100) + moisture_penalty_per_pct: float | None = Field(None) + impurity_ppm: Mapping[str, float] = Field(default_factory=dict) + impurity_thresholds: Mapping[str, float] = Field(default_factory=dict) + impurity_penalty_per_ppm: Mapping[str, float] = Field(default_factory=dict) + premiums: float = Field(0) + fx_rate: PositiveFloat = Field( + 1, description="Multiplier to convert to scenario currency") + currency_code: str | None = Field( + None, description="Optional explicit currency override") + + @field_validator("impurity_ppm", mode="before") + @classmethod + def _validate_impurity_mapping(cls, value): + if isinstance(value, Mapping): + return {k: float(v) for k, v in value.items()} + return value + + +class PricingResult(BaseModel): + """Structured output summarising pricing computation results.""" + + metal: str + ore_tonnage: float + head_grade_pct: float + recovery_pct: float + payable_metal_tonnes: float + reference_price: float + gross_revenue: float + moisture_penalty: float + impurity_penalty: float + treatment_smelt_charges: float + premiums: float + net_revenue: float + 
currency: str | None + + +@dataclass(frozen=True) +class PricingMetadata: + """Metadata defaults applied when explicit inputs are omitted.""" + + default_payable_pct: float = 100.0 + default_currency: str | None = "USD" + moisture_threshold_pct: float = 8.0 + moisture_penalty_per_pct: float = 0.0 + impurity_thresholds: Mapping[str, float] = field(default_factory=dict) + impurity_penalty_per_ppm: Mapping[str, float] = field(default_factory=dict) + + +def calculate_pricing( + pricing_input: PricingInput, + *, + metadata: PricingMetadata | None = None, + currency: str | None = None, +) -> PricingResult: + """Calculate pricing metrics for the provided commodity input. + + Parameters + ---------- + pricing_input: + Normalised input data including ore tonnage, grades, charges, and + optional penalties. + metadata: + Optional default metadata applied when specific values are omitted from + ``pricing_input``. + currency: + Optional override for the output currency label. Falls back to + ``metadata.default_currency`` when not provided. + """ + + applied_metadata = metadata or PricingMetadata() + + payable_pct = ( + pricing_input.payable_pct + if pricing_input.payable_pct is not None + else applied_metadata.default_payable_pct + ) + moisture_threshold = ( + pricing_input.moisture_threshold_pct + if pricing_input.moisture_threshold_pct is not None + else applied_metadata.moisture_threshold_pct + ) + moisture_penalty_factor = ( + pricing_input.moisture_penalty_per_pct + if pricing_input.moisture_penalty_per_pct is not None + else applied_metadata.moisture_penalty_per_pct + ) + + impurity_thresholds = { + **applied_metadata.impurity_thresholds, + **pricing_input.impurity_thresholds, + } + impurity_penalty_factors = { + **applied_metadata.impurity_penalty_per_ppm, + **pricing_input.impurity_penalty_per_ppm, + } + + q_metal = pricing_input.ore_tonnage * (pricing_input.head_grade_pct / 100.0) * ( + pricing_input.recovery_pct / 100.0 + ) + payable_metal = q_metal * (payable_pct / 100.0) + + gross_revenue_ref = payable_metal * pricing_input.reference_price + charges = pricing_input.treatment_charge + pricing_input.smelting_charge + + moisture_excess = max(0.0, pricing_input.moisture_pct - moisture_threshold) + moisture_penalty = moisture_excess * moisture_penalty_factor + + impurity_penalty_total = 0.0 + for impurity, value in pricing_input.impurity_ppm.items(): + threshold = impurity_thresholds.get(impurity, 0.0) + penalty_factor = impurity_penalty_factors.get(impurity, 0.0) + impurity_penalty_total += max(0.0, value - threshold) * penalty_factor + + net_revenue_ref = ( + gross_revenue_ref - charges - moisture_penalty - impurity_penalty_total + ) + net_revenue_ref += pricing_input.premiums + + net_revenue = net_revenue_ref * pricing_input.fx_rate + + currency_code = require_currency( + currency or pricing_input.currency_code, + default=applied_metadata.default_currency, + ) + + return PricingResult( + metal=pricing_input.metal, + ore_tonnage=pricing_input.ore_tonnage, + head_grade_pct=pricing_input.head_grade_pct, + recovery_pct=pricing_input.recovery_pct, + payable_metal_tonnes=payable_metal, + reference_price=pricing_input.reference_price, + gross_revenue=gross_revenue_ref, + moisture_penalty=moisture_penalty, + impurity_penalty=impurity_penalty_total, + treatment_smelt_charges=charges, + premiums=pricing_input.premiums, + net_revenue=net_revenue, + currency=currency_code, + ) diff --git a/services/reporting.py b/services/reporting.py index 98387d6..a708b02 100644 --- a/services/reporting.py +++ 
b/services/reporting.py @@ -1,79 +1,875 @@ -from statistics import mean, median, pstdev -from typing import Any, Dict, Iterable, List, Mapping, Union, cast +"""Reporting service layer aggregating deterministic and simulation metrics.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import date +import math +from typing import Mapping, Sequence +from urllib.parse import urlencode + +import plotly.graph_objects as go +import plotly.io as pio + +from fastapi import Request + +from models import FinancialCategory, Project, Scenario +from services.financial import ( + CashFlow, + ConvergenceError, + PaybackNotReachedError, + internal_rate_of_return, + net_present_value, + payback_period, +) +from services.simulation import ( + CashFlowSpec, + SimulationConfig, + SimulationMetric, + SimulationResult, + run_monte_carlo, +) +from services.unit_of_work import UnitOfWork + +DEFAULT_DISCOUNT_RATE = 0.1 +DEFAULT_ITERATIONS = 500 +DEFAULT_PERCENTILES: tuple[float, float, float] = (5.0, 50.0, 95.0) + +_COST_CATEGORY_SIGNS: Mapping[FinancialCategory, float] = { + FinancialCategory.REVENUE: 1.0, + FinancialCategory.CAPITAL_EXPENDITURE: -1.0, + FinancialCategory.OPERATING_EXPENDITURE: -1.0, + FinancialCategory.CONTINGENCY: -1.0, + FinancialCategory.OTHER: -1.0, +} -def _extract_results(simulation_results: Iterable[object]) -> List[float]: - values: List[float] = [] - for item in simulation_results: - if not isinstance(item, Mapping): - continue - mapping_item = cast(Mapping[str, Any], item) - value = mapping_item.get("result") - if isinstance(value, (int, float)): - values.append(float(value)) - return values +@dataclass(frozen=True) +class IncludeOptions: + """Flags controlling optional sections in report payloads.""" + + distribution: bool = False + samples: bool = False -def _percentile(values: List[float], percentile: float) -> float: - if not values: - return 0.0 - sorted_values = sorted(values) - if len(sorted_values) == 1: - return sorted_values[0] - index = (percentile / 100) * (len(sorted_values) - 1) - lower = int(index) - upper = min(lower + 1, len(sorted_values) - 1) - weight = index - lower - return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight +@dataclass(slots=True) +class ReportFilters: + """Filter parameters applied when selecting scenarios for a report.""" + + scenario_ids: set[int] | None = None + start_date: date | None = None + end_date: date | None = None + + def matches(self, scenario: Scenario) -> bool: + if self.scenario_ids is not None and scenario.id not in self.scenario_ids: + return False + if self.start_date and scenario.start_date and scenario.start_date < self.start_date: + return False + if self.end_date and scenario.end_date and scenario.end_date > self.end_date: + return False + return True + + def to_dict(self) -> dict[str, object]: + payload: dict[str, object] = {} + if self.scenario_ids is not None: + payload["scenario_ids"] = sorted(self.scenario_ids) + if self.start_date is not None: + payload["start_date"] = self.start_date + if self.end_date is not None: + payload["end_date"] = self.end_date + return payload -def generate_report( - simulation_results: List[Dict[str, float]], -) -> Dict[str, Union[float, int]]: - """Aggregate basic statistics for simulation outputs.""" +@dataclass(slots=True) +class ScenarioFinancialTotals: + currency: str | None + inflows: float + outflows: float + net: float + by_category: dict[str, float] - values = _extract_results(simulation_results) - - if not values: + 
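# Monetary values are rounded via _round_optional and categories are emitted + # in sorted order so payloads stay stable across runs. +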
def to_dict(self) -> dict[str, object]: return { - "count": 0, - "mean": 0.0, - "median": 0.0, - "min": 0.0, - "max": 0.0, - "std_dev": 0.0, - "variance": 0.0, - "percentile_10": 0.0, - "percentile_90": 0.0, - "percentile_5": 0.0, - "percentile_95": 0.0, - "value_at_risk_95": 0.0, - "expected_shortfall_95": 0.0, + "currency": self.currency, + "inflows": _round_optional(self.inflows), + "outflows": _round_optional(self.outflows), + "net": _round_optional(self.net), + "by_category": { + key: _round_optional(value) for key, value in sorted(self.by_category.items()) + }, } - summary: Dict[str, Union[float, int]] = { - "count": len(values), - "mean": mean(values), - "median": median(values), - "min": min(values), - "max": max(values), - "percentile_10": _percentile(values, 10), - "percentile_90": _percentile(values, 90), - "percentile_5": _percentile(values, 5), - "percentile_95": _percentile(values, 95), + +@dataclass(slots=True) +class ScenarioDeterministicMetrics: + currency: str | None + discount_rate: float + compounds_per_year: int + npv: float | None + irr: float | None + payback_period: float | None + notes: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, object]: + return { + "currency": self.currency, + "discount_rate": _round_optional(self.discount_rate, digits=4), + "compounds_per_year": self.compounds_per_year, + "npv": _round_optional(self.npv), + "irr": _round_optional(self.irr, digits=6), + "payback_period": _round_optional(self.payback_period, digits=4), + "notes": self.notes, + } + + +@dataclass(slots=True) +class ScenarioMonteCarloResult: + available: bool + notes: list[str] = field(default_factory=list) + result: SimulationResult | None = None + include_samples: bool = False + + def to_dict(self) -> dict[str, object]: + if not self.available or self.result is None: + return { + "available": False, + "notes": self.notes, + } + + metrics: dict[str, dict[str, object]] = {} + for metric, summary in self.result.summaries.items(): + metrics[metric.value] = { + "mean": _round_optional(summary.mean), + "std_dev": _round_optional(summary.std_dev), + "minimum": _round_optional(summary.minimum), + "maximum": _round_optional(summary.maximum), + "percentiles": { + f"{percentile:g}": _round_optional(value) + for percentile, value in sorted(summary.percentiles.items()) + }, + "sample_size": summary.sample_size, + "failed_runs": summary.failed_runs, + } + + samples_payload: dict[str, list[float | None]] | None = None + if self.include_samples and self.result.samples: + samples_payload = {} + for metric, samples in self.result.samples.items(): + samples_payload[metric.value] = [ + _sanitize_float(sample) for sample in samples.tolist() + ] + + payload: dict[str, object] = { + "available": True, + "iterations": self.result.iterations, + "metrics": metrics, + "notes": self.notes, + } + if samples_payload: + payload["samples"] = samples_payload + return payload + + +@dataclass(slots=True) +class ScenarioReport: + scenario: Scenario + totals: ScenarioFinancialTotals + deterministic: ScenarioDeterministicMetrics + monte_carlo: ScenarioMonteCarloResult | None + + def to_dict(self) -> dict[str, object]: + scenario_info = { + "id": self.scenario.id, + "project_id": self.scenario.project_id, + "name": self.scenario.name, + "description": self.scenario.description, + "status": self.scenario.status.value if hasattr(self.scenario.status, 'value') else self.scenario.status, + "start_date": self.scenario.start_date, + "end_date": self.scenario.end_date, + "currency": 
self.scenario.currency, + "primary_resource": self.scenario.primary_resource.value + if self.scenario.primary_resource and hasattr(self.scenario.primary_resource, 'value') + else self.scenario.primary_resource, + "discount_rate": _round_optional(self.deterministic.discount_rate, digits=4), + "created_at": self.scenario.created_at, + "updated_at": self.scenario.updated_at, + "simulation_parameter_count": len(self.scenario.simulation_parameters or []), + } + payload: dict[str, object] = { + "scenario": scenario_info, + "financials": self.totals.to_dict(), + "metrics": self.deterministic.to_dict(), + } + if self.monte_carlo is not None: + payload["monte_carlo"] = self.monte_carlo.to_dict() + return payload + + +@dataclass(slots=True) +class AggregatedMetric: + average: float | None + minimum: float | None + maximum: float | None + + def to_dict(self) -> dict[str, object]: + return { + "average": _round_optional(self.average), + "minimum": _round_optional(self.minimum), + "maximum": _round_optional(self.maximum), + } + + +@dataclass(slots=True) +class ProjectAggregates: + total_inflows: float + total_outflows: float + total_net: float + deterministic_metrics: dict[str, AggregatedMetric] + + def to_dict(self) -> dict[str, object]: + return { + "financials": { + "total_inflows": _round_optional(self.total_inflows), + "total_outflows": _round_optional(self.total_outflows), + "total_net": _round_optional(self.total_net), + }, + "deterministic_metrics": { + metric: data.to_dict() + for metric, data in sorted(self.deterministic_metrics.items()) + }, + } + + +@dataclass(slots=True) +class MetricComparison: + metric: str + direction: str + best: tuple[int, str, float] | None + worst: tuple[int, str, float] | None + average: float | None + + def to_dict(self) -> dict[str, object]: + return { + "metric": self.metric, + "direction": self.direction, + "best": _comparison_entry(self.best), + "worst": _comparison_entry(self.worst), + "average": _round_optional(self.average), + } + + +def parse_include_tokens(raw: str | None) -> IncludeOptions: + tokens: set[str] = set() + if raw: + for part in raw.split(","): + token = part.strip().lower() + if token: + tokens.add(token) + if "all" in tokens: + return IncludeOptions(distribution=True, samples=True) + return IncludeOptions( + distribution=bool({"distribution", "monte_carlo", "mc"} & tokens), + samples="samples" in tokens, + ) + + +def validate_percentiles(values: Sequence[float] | None) -> tuple[float, ...]: + if not values: + return DEFAULT_PERCENTILES + seen: set[float] = set() + cleaned: list[float] = [] + for value in values: + percentile = float(value) + if percentile < 0.0 or percentile > 100.0: + raise ValueError("Percentiles must be between 0 and 100.") + if percentile not in seen: + seen.add(percentile) + cleaned.append(percentile) + if not cleaned: + return DEFAULT_PERCENTILES + return tuple(cleaned) + + +class ReportingService: + """Coordinates project and scenario reporting aggregation.""" + + def __init__(self, uow: UnitOfWork) -> None: + self._uow = uow + + def project_summary( + self, + project: Project, + *, + filters: ReportFilters, + include: IncludeOptions, + iterations: int, + percentiles: tuple[float, ...], + ) -> dict[str, object]: + scenarios = self._load_scenarios(project.id, filters) + reports = [ + self._build_scenario_report( + scenario, + include_distribution=include.distribution, + include_samples=include.samples, + iterations=iterations, + percentiles=percentiles, + ) + for scenario in scenarios + ] + aggregates = 
self._aggregate_project(reports) + return { + "project": _project_payload(project), + "scenario_count": len(reports), + "filters": filters.to_dict(), + "aggregates": aggregates.to_dict(), + "scenarios": [report.to_dict() for report in reports], + } + + def scenario_comparison( + self, + project: Project, + scenarios: Sequence[Scenario], + *, + include: IncludeOptions, + iterations: int, + percentiles: tuple[float, ...], + ) -> dict[str, object]: + reports = [ + self._build_scenario_report( + self._reload_scenario(scenario.id), + include_distribution=include.distribution, + include_samples=include.samples, + iterations=iterations, + percentiles=percentiles, + ) + for scenario in scenarios + ] + comparison = { + metric: data.to_dict() + for metric, data in self._build_comparisons(reports).items() + } + return { + "project": _project_payload(project), + "scenarios": [report.to_dict() for report in reports], + "comparison": comparison, + } + + def scenario_distribution( + self, + scenario: Scenario, + *, + include: IncludeOptions, + iterations: int, + percentiles: tuple[float, ...], + ) -> dict[str, object]: + report = self._build_scenario_report( + self._reload_scenario(scenario.id), + include_distribution=True, + include_samples=include.samples, + iterations=iterations, + percentiles=percentiles, + ) + return { + "scenario": report.to_dict()["scenario"], + "summary": report.totals.to_dict(), + "metrics": report.deterministic.to_dict(), + "monte_carlo": ( + report.monte_carlo.to_dict() if report.monte_carlo else { + "available": False} + ), + } + + def _load_scenarios(self, project_id: int, filters: ReportFilters) -> list[Scenario]: + scenarios = self._uow.scenarios.list_for_project( + project_id, with_children=True) + return [scenario for scenario in scenarios if filters.matches(scenario)] + + def _reload_scenario(self, scenario_id: int) -> Scenario: + return self._uow.scenarios.get(scenario_id, with_children=True) + + def _build_scenario_report( + self, + scenario: Scenario, + *, + include_distribution: bool, + include_samples: bool, + iterations: int, + percentiles: tuple[float, ...], + ) -> ScenarioReport: + cash_flows, totals = _build_cash_flows(scenario) + deterministic = _calculate_deterministic_metrics( + scenario, cash_flows, totals) + monte_carlo: ScenarioMonteCarloResult | None = None + if include_distribution: + monte_carlo = _run_monte_carlo( + scenario, + cash_flows, + include_samples=include_samples, + iterations=iterations, + percentiles=percentiles, + ) + return ScenarioReport( + scenario=scenario, + totals=totals, + deterministic=deterministic, + monte_carlo=monte_carlo, + ) + + def _aggregate_project(self, reports: Sequence[ScenarioReport]) -> ProjectAggregates: + total_inflows = sum(report.totals.inflows for report in reports) + total_outflows = sum(report.totals.outflows for report in reports) + total_net = sum(report.totals.net for report in reports) + + metrics: dict[str, AggregatedMetric] = {} + for metric_name in ("npv", "irr", "payback_period"): + values = [ + getattr(report.deterministic, metric_name) + for report in reports + if getattr(report.deterministic, metric_name) is not None + ] + if values: + metrics[metric_name] = AggregatedMetric( + average=sum(values) / len(values), + minimum=min(values), + maximum=max(values), + ) + return ProjectAggregates( + total_inflows=total_inflows, + total_outflows=total_outflows, + total_net=total_net, + deterministic_metrics=metrics, + ) + + def _build_comparisons( + self, reports: Sequence[ScenarioReport] + ) -> Mapping[str, 
MetricComparison]: + comparisons: dict[str, MetricComparison] = {} + for metric_name, direction in ( + ("npv", "higher_is_better"), + ("irr", "higher_is_better"), + ("payback_period", "lower_is_better"), + ): + entries: list[tuple[int, str, float]] = [] + for report in reports: + value = getattr(report.deterministic, metric_name) + if value is None: + continue + entries.append( + (report.scenario.id, report.scenario.name, value)) + if not entries: + continue + if direction == "higher_is_better": + best = max(entries, key=lambda item: item[2]) + worst = min(entries, key=lambda item: item[2]) + else: + best = min(entries, key=lambda item: item[2]) + worst = max(entries, key=lambda item: item[2]) + average = sum(item[2] for item in entries) / len(entries) + comparisons[metric_name] = MetricComparison( + metric=metric_name, + direction=direction, + best=best, + worst=worst, + average=average, + ) + return comparisons + + def build_project_summary_context( + self, + project: Project, + filters: ReportFilters, + include: IncludeOptions, + iterations: int, + percentiles: tuple[float, ...], + request: Request, + ) -> dict[str, object]: + """Build template context for project summary page.""" + scenarios = self._load_scenarios(project.id, filters) + reports = [ + self._build_scenario_report( + scenario, + include_distribution=include.distribution, + include_samples=include.samples, + iterations=iterations, + percentiles=percentiles, + ) + for scenario in scenarios + ] + aggregates = self._aggregate_project(reports) + + return { + "request": request, + "project": _project_payload(project), + "scenario_count": len(reports), + "aggregates": aggregates.to_dict(), + "scenarios": [report.to_dict() for report in reports], + "filters": filters.to_dict(), + "include_options": include, + "iterations": iterations, + "percentiles": percentiles, + "title": f"Project Summary · {project.name}", + "subtitle": "Aggregated financial and simulation insights across scenarios.", + "actions": [ + { + "href": request.url_for( + "reports.project_summary", + project_id=project.id, + ), + "label": "Download JSON", + } + ], + "chart_data": self._generate_npv_comparison_chart(reports), + } + + def build_scenario_comparison_context( + self, + project: Project, + scenarios: Sequence[Scenario], + include: IncludeOptions, + iterations: int, + percentiles: tuple[float, ...], + request: Request, + ) -> dict[str, object]: + """Build template context for scenario comparison page.""" + reports = [ + self._build_scenario_report( + self._reload_scenario(scenario.id), + include_distribution=include.distribution, + include_samples=include.samples, + iterations=iterations, + percentiles=percentiles, + ) + for scenario in scenarios + ] + comparison = { + metric: data.to_dict() + for metric, data in self._build_comparisons(reports).items() + } + + comparison_json_url = request.url_for( + "reports.project_scenario_comparison", + project_id=project.id, + ) + scenario_ids = [str(s.id) for s in scenarios] + comparison_query = urlencode( + [("scenario_ids", str(identifier)) for identifier in scenario_ids] + ) + if comparison_query: + comparison_json_url = f"{comparison_json_url}?{comparison_query}" + + return { + "request": request, + "project": _project_payload(project), + "scenarios": [report.to_dict() for report in reports], + "comparison": comparison, + "include_options": include, + "iterations": iterations, + "percentiles": percentiles, + "title": f"Scenario Comparison · {project.name}", + "subtitle": "Evaluate deterministic metrics and 
Monte Carlo trends side by side.", + "actions": [ + { + "href": comparison_json_url, + "label": "Download JSON", + } + ], + } + + def build_scenario_distribution_context( + self, + scenario: Scenario, + include: IncludeOptions, + iterations: int, + percentiles: tuple[float, ...], + request: Request, + ) -> dict[str, object]: + """Build template context for scenario distribution page.""" + report = self._build_scenario_report( + self._reload_scenario(scenario.id), + include_distribution=True, + include_samples=include.samples, + iterations=iterations, + percentiles=percentiles, + ) + + return { + "request": request, + "scenario": report.to_dict()["scenario"], + "summary": report.totals.to_dict(), + "metrics": report.deterministic.to_dict(), + "monte_carlo": ( + report.monte_carlo.to_dict() if report.monte_carlo else { + "available": False} + ), + "include_options": include, + "iterations": iterations, + "percentiles": percentiles, + "title": f"Scenario Distribution · {scenario.name}", + "subtitle": "Deterministic and simulated distributions for a single scenario.", + "actions": [ + { + "href": request.url_for( + "reports.scenario_distribution", + scenario_id=scenario.id, + ), + "label": "Download JSON", + } + ], + "chart_data": self._generate_distribution_histogram(report.monte_carlo) if report.monte_carlo else "{}", + } + + def _generate_npv_comparison_chart(self, reports: Sequence[ScenarioReport]) -> str: + """Generate Plotly chart JSON for NPV comparison across scenarios.""" + scenario_names = [] + npv_values = [] + + for report in reports: + scenario_names.append(report.scenario.name) + npv_values.append(report.deterministic.npv or 0) + + fig = go.Figure(data=[ + go.Bar( + x=scenario_names, + y=npv_values, + name='NPV', + marker_color='lightblue' + ) + ]) + + fig.update_layout( + title="NPV Comparison Across Scenarios", + xaxis_title="Scenario", + yaxis_title="NPV", + showlegend=False + ) + + return pio.to_json(fig) or "{}" + + def _generate_distribution_histogram(self, monte_carlo: ScenarioMonteCarloResult) -> str: + """Generate Plotly histogram for Monte Carlo distribution.""" + if not monte_carlo.available or not monte_carlo.result or not monte_carlo.result.samples: + return "{}" + + # Get NPV samples + npv_samples = monte_carlo.result.samples.get(SimulationMetric.NPV, []) + if len(npv_samples) == 0: + return "{}" + + fig = go.Figure(data=[ + go.Histogram( + x=npv_samples, + nbinsx=50, + name='NPV Distribution', + marker_color='lightgreen' + ) + ]) + + fig.update_layout( + title="Monte Carlo NPV Distribution", + xaxis_title="NPV", + yaxis_title="Frequency", + showlegend=False + ) + + return pio.to_json(fig) or "{}" + + +def _build_cash_flows(scenario: Scenario) -> tuple[list[CashFlow], ScenarioFinancialTotals]: + cash_flows: list[CashFlow] = [] + by_category: dict[str, float] = {} + inflows = 0.0 + outflows = 0.0 + net = 0.0 + period_index = 0 + + for financial_input in scenario.financial_inputs or []: + sign = _COST_CATEGORY_SIGNS.get(financial_input.category, -1.0) + amount = float(financial_input.amount) * sign + net += amount + if amount >= 0: + inflows += amount + else: + outflows += -amount + by_category.setdefault(financial_input.category.value, 0.0) + by_category[financial_input.category.value] += amount + + if financial_input.effective_date is not None: + cash_flows.append( + CashFlow(amount=amount, date=financial_input.effective_date) + ) + else: + cash_flows.append( + CashFlow(amount=amount, period_index=period_index)) + period_index += 1 + + currency = scenario.currency + 
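# Fall back to the first financial input's currency when the scenario does + # not declare one; mixed-currency inputs are not reconciled here. +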
if currency is None and scenario.financial_inputs: + currency = scenario.financial_inputs[0].currency + + totals = ScenarioFinancialTotals( + currency=currency, + inflows=inflows, + outflows=outflows, + net=net, + by_category=by_category, + ) + return cash_flows, totals + + +def _calculate_deterministic_metrics( + scenario: Scenario, + cash_flows: Sequence[CashFlow], + totals: ScenarioFinancialTotals, +) -> ScenarioDeterministicMetrics: + notes: list[str] = [] + discount_rate = _normalise_discount_rate(scenario.discount_rate) + if scenario.discount_rate is None: + notes.append( + f"Discount rate not set; defaulted to {discount_rate:.2%}." + ) + + if not cash_flows: + notes.append( + "No financial inputs available for deterministic metrics.") + return ScenarioDeterministicMetrics( + currency=totals.currency, + discount_rate=discount_rate, + compounds_per_year=1, + npv=None, + irr=None, + payback_period=None, + notes=notes, + ) + + npv_value: float | None + try: + npv_value = net_present_value( + discount_rate, + cash_flows, + compounds_per_year=1, + ) + except ValueError as exc: + npv_value = None + notes.append(f"NPV unavailable: {exc}.") + + irr_value: float | None + try: + irr_value = internal_rate_of_return( + cash_flows, + compounds_per_year=1, + ) + except (ValueError, ConvergenceError) as exc: + irr_value = None + notes.append(f"IRR unavailable: {exc}.") + + payback_value: float | None + try: + payback_value = payback_period( + cash_flows, + compounds_per_year=1, + ) + except (ValueError, PaybackNotReachedError) as exc: + payback_value = None + notes.append(f"Payback period unavailable: {exc}.") + + return ScenarioDeterministicMetrics( + currency=totals.currency, + discount_rate=discount_rate, + compounds_per_year=1, + npv=npv_value, + irr=irr_value, + payback_period=payback_value, + notes=notes, + ) + + +def _run_monte_carlo( + scenario: Scenario, + cash_flows: Sequence[CashFlow], + *, + include_samples: bool, + iterations: int, + percentiles: tuple[float, ...], +) -> ScenarioMonteCarloResult: + if not cash_flows: + return ScenarioMonteCarloResult( + available=False, + notes=["No financial inputs available for Monte Carlo simulation."], + ) + + discount_rate = _normalise_discount_rate(scenario.discount_rate) + specs = [CashFlowSpec(cash_flow=flow) for flow in cash_flows] + notes: list[str] = [] + if not scenario.simulation_parameters: + notes.append( + "Scenario has no stochastic parameters; simulation mirrors deterministic cash flows." 
+ ) + config = SimulationConfig( + iterations=iterations, + discount_rate=discount_rate, + metrics=( + SimulationMetric.NPV, + SimulationMetric.IRR, + SimulationMetric.PAYBACK, + ), + percentiles=percentiles, + return_samples=include_samples, + ) + try: + result = run_monte_carlo(specs, config) + except Exception as exc: # pragma: no cover - safeguard for unexpected failures + notes.append(f"Simulation failed: {exc}.") + return ScenarioMonteCarloResult(available=False, notes=notes) + return ScenarioMonteCarloResult( + available=True, + notes=notes, + result=result, + include_samples=include_samples, + ) + + +def _normalise_discount_rate(value: float | None) -> float: + if value is None: + return DEFAULT_DISCOUNT_RATE + rate = float(value) + if rate > 1.0: + return rate / 100.0 + return rate + + +def _sanitize_float(value: float | None) -> float | None: + if value is None: + return None + if math.isnan(value) or math.isinf(value): + return None + return float(value) + + +def _round_optional(value: float | None, *, digits: int = 2) -> float | None: + clean = _sanitize_float(value) + if clean is None: + return None + return round(clean, digits) + + +def _comparison_entry(entry: tuple[int, str, float] | None) -> dict[str, object] | None: + if entry is None: + return None + scenario_id, name, value = entry + return { + "scenario_id": scenario_id, + "name": name, + "value": _round_optional(value), } - std_dev = pstdev(values) if len(values) > 1 else 0.0 - summary["std_dev"] = std_dev - summary["variance"] = std_dev**2 - var_95 = summary["percentile_5"] - summary["value_at_risk_95"] = var_95 - - tail_values = [value for value in values if value <= var_95] - if tail_values: - summary["expected_shortfall_95"] = mean(tail_values) - else: - summary["expected_shortfall_95"] = var_95 - - return summary +def _project_payload(project: Project) -> dict[str, object]: + return { + "id": project.id, + "name": project.name, + "location": project.location, + "operation_type": project.operation_type.value, + "description": project.description, + "created_at": project.created_at, + "updated_at": project.updated_at, + } diff --git a/services/repositories.py b/services/repositories.py new file mode 100644 index 0000000..4e7cc45 --- /dev/null +++ b/services/repositories.py @@ -0,0 +1,1268 @@ +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime +from typing import Mapping, Sequence + +from sqlalchemy import select, func +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, joinedload, selectinload + +from models import ( + FinancialInput, + Project, + PricingImpuritySettings, + PricingMetalSettings, + PricingSettings, + ProjectCapexSnapshot, + ProjectProfitability, + ProjectOpexSnapshot, + NavigationGroup, + NavigationLink, + Role, + Scenario, + ScenarioCapexSnapshot, + ScenarioProfitability, + ScenarioOpexSnapshot, + ScenarioStatus, + SimulationParameter, + User, + UserRole, +) +from services.exceptions import EntityConflictError, EntityNotFoundError +from services.export_query import ProjectExportFilters, ScenarioExportFilters +from services.pricing import PricingMetadata + + +def _enum_value(e): + """Return the underlying value for Enum members, otherwise return as-is.""" + return getattr(e, "value", e) + + +class NavigationRepository: + """Persistence operations for navigation metadata.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list_groups_with_links( + 
self, + *, + include_disabled: bool = False, + ) -> Sequence[NavigationGroup]: + stmt = ( + select(NavigationGroup) + .options( + selectinload(NavigationGroup.links) + .selectinload(NavigationLink.children) + ) + .order_by(NavigationGroup.sort_order, NavigationGroup.id) + ) + if not include_disabled: + stmt = stmt.where(NavigationGroup.is_enabled.is_(True)) + return self.session.execute(stmt).scalars().all() + + def get_group_by_slug(self, slug: str) -> NavigationGroup | None: + stmt = select(NavigationGroup).where(NavigationGroup.slug == slug) + return self.session.execute(stmt).scalar_one_or_none() + + def get_link_by_slug( + self, + slug: str, + *, + group_id: int | None = None, + ) -> NavigationLink | None: + stmt = select(NavigationLink).where(NavigationLink.slug == slug) + if group_id is not None: + stmt = stmt.where(NavigationLink.group_id == group_id) + return self.session.execute(stmt).scalar_one_or_none() + + def add_group(self, group: NavigationGroup) -> NavigationGroup: + self.session.add(group) + self.session.flush() + return group + + def add_link(self, link: NavigationLink) -> NavigationLink: + self.session.add(link) + self.session.flush() + return link + +class ProjectRepository: + """Persistence operations for Project entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list( + self, + *, + with_children: bool = False, + with_pricing: bool = False, + ) -> Sequence[Project]: + stmt = select(Project).order_by(Project.created_at) + if with_children: + stmt = stmt.options(selectinload(Project.scenarios)) + if with_pricing: + stmt = stmt.options(selectinload(Project.pricing_settings)) + return self.session.execute(stmt).scalars().all() + + def count(self) -> int: + stmt = select(func.count(Project.id)) + return self.session.execute(stmt).scalar_one() + + def recent(self, limit: int = 5) -> Sequence[Project]: + stmt = ( + select(Project) + .order_by(Project.updated_at.desc()) + .limit(limit) + ) + return self.session.execute(stmt).scalars().all() + + def get( + self, + project_id: int, + *, + with_children: bool = False, + with_pricing: bool = False, + ) -> Project: + stmt = select(Project).where(Project.id == project_id) + if with_children: + stmt = stmt.options(joinedload(Project.scenarios)) + if with_pricing: + stmt = stmt.options(joinedload(Project.pricing_settings)) + result = self.session.execute(stmt) + if with_children: + result = result.unique() + project = result.scalar_one_or_none() + if project is None: + raise EntityNotFoundError(f"Project {project_id} not found") + return project + + def exists(self, project_id: int) -> bool: + stmt = select(Project.id).where(Project.id == project_id).limit(1) + return self.session.execute(stmt).scalar_one_or_none() is not None + + def create(self, project: Project) -> Project: + self.session.add(project) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - reliance on DB constraints + from monitoring.metrics import observe_project_operation + observe_project_operation("create", "error") + raise EntityConflictError( + "Project violates uniqueness constraints") from exc + from monitoring.metrics import observe_project_operation + observe_project_operation("create", "success") + return project + + def find_by_names(self, names: Iterable[str]) -> Mapping[str, Project]: + normalised = {name.strip().lower() + for name in names if name and name.strip()} + if not normalised: + return {} + stmt = select(Project).where(func.lower(Project.name).in_(normalised)) + 
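# Matching is case-insensitive: incoming names were lower-cased above and + # func.lower() normalises the stored column. +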
records = self.session.execute(stmt).scalars().all() + return {project.name.lower(): project for project in records} + + def filtered_for_export( + self, + filters: ProjectExportFilters | None = None, + *, + include_scenarios: bool = False, + include_pricing: bool = False, + ) -> Sequence[Project]: + stmt = select(Project) + if include_scenarios: + stmt = stmt.options(selectinload(Project.scenarios)) + if include_pricing: + stmt = stmt.options(selectinload(Project.pricing_settings)) + + if filters: + ids = filters.normalised_ids() + if ids: + stmt = stmt.where(Project.id.in_(ids)) + + name_matches = filters.normalised_names() + if name_matches: + stmt = stmt.where(func.lower(Project.name).in_(name_matches)) + + name_pattern = filters.name_search_pattern() + if name_pattern: + stmt = stmt.where(Project.name.ilike(name_pattern)) + + locations = filters.normalised_locations() + if locations: + stmt = stmt.where(func.lower(Project.location).in_(locations)) + + if filters.operation_types: + stmt = stmt.where(Project.operation_type.in_( + filters.operation_types)) + + if filters.created_from: + stmt = stmt.where(Project.created_at >= filters.created_from) + + if filters.created_to: + stmt = stmt.where(Project.created_at <= filters.created_to) + + if filters.updated_from: + stmt = stmt.where(Project.updated_at >= filters.updated_from) + + if filters.updated_to: + stmt = stmt.where(Project.updated_at <= filters.updated_to) + + stmt = stmt.order_by(Project.name, Project.id) + return self.session.execute(stmt).scalars().all() + + def delete(self, project_id: int) -> None: + project = self.get(project_id) + self.session.delete(project) + + def set_pricing_settings( + self, + project: Project, + pricing_settings: PricingSettings | None, + ) -> Project: + project.pricing_settings = pricing_settings + project.pricing_settings_id = ( + pricing_settings.id if pricing_settings is not None else None + ) + self.session.flush() + return project + + +class ScenarioRepository: + """Persistence operations for Scenario entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list_for_project( + self, + project_id: int, + *, + with_children: bool = False, + ) -> Sequence[Scenario]: + stmt = ( + select(Scenario) + .where(Scenario.project_id == project_id) + .order_by(Scenario.created_at) + ) + if with_children: + stmt = stmt.options( + selectinload(Scenario.financial_inputs), + selectinload(Scenario.simulation_parameters), + ) + result = self.session.execute(stmt) + if with_children: + result = result.unique() + return result.scalars().all() + + def count(self) -> int: + stmt = select(func.count(Scenario.id)) + return self.session.execute(stmt).scalar_one() + + def count_by_status(self, status: ScenarioStatus) -> int: + status_val = _enum_value(status) + stmt = select(func.count(Scenario.id)).where( + Scenario.status == status_val) + return self.session.execute(stmt).scalar_one() + + def recent(self, limit: int = 5, *, with_project: bool = False) -> Sequence[Scenario]: + stmt = select(Scenario).order_by( + Scenario.updated_at.desc()).limit(limit) + if with_project: + stmt = stmt.options(joinedload(Scenario.project)) + return self.session.execute(stmt).scalars().all() + + def list_by_status( + self, + status: ScenarioStatus, + *, + limit: int | None = None, + with_project: bool = False, + ) -> Sequence[Scenario]: + status_val = _enum_value(status) + stmt = ( + select(Scenario) + .where(Scenario.status == status_val) + .order_by(Scenario.updated_at.desc()) + ) + if with_project: 
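+ # Eager-load the owning project in the same query to avoid per-row lookups.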
+ stmt = stmt.options(joinedload(Scenario.project)) + if limit is not None: + stmt = stmt.limit(limit) + return self.session.execute(stmt).scalars().all() + + def get(self, scenario_id: int, *, with_children: bool = False) -> Scenario: + stmt = select(Scenario).where(Scenario.id == scenario_id) + if with_children: + stmt = stmt.options( + joinedload(Scenario.financial_inputs), + joinedload(Scenario.simulation_parameters), + ) + result = self.session.execute(stmt) + if with_children: + result = result.unique() + scenario = result.scalar_one_or_none() + if scenario is None: + raise EntityNotFoundError(f"Scenario {scenario_id} not found") + return scenario + + def exists(self, scenario_id: int) -> bool: + stmt = select(Scenario.id).where(Scenario.id == scenario_id).limit(1) + return self.session.execute(stmt).scalar_one_or_none() is not None + + def create(self, scenario: Scenario) -> Scenario: + self.session.add(scenario) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover + from monitoring.metrics import observe_scenario_operation + observe_scenario_operation("create", "error") + raise EntityConflictError("Scenario violates constraints") from exc + from monitoring.metrics import observe_scenario_operation + observe_scenario_operation("create", "success") + return scenario + + def find_by_project_and_names( + self, + project_id: int, + names: Iterable[str], + ) -> Mapping[str, Scenario]: + normalised = {name.strip().lower() + for name in names if name and name.strip()} + if not normalised: + return {} + stmt = ( + select(Scenario) + .where( + Scenario.project_id == project_id, + func.lower(Scenario.name).in_(normalised), + ) + ) + records = self.session.execute(stmt).scalars().all() + return {scenario.name.lower(): scenario for scenario in records} + + def filtered_for_export( + self, + filters: ScenarioExportFilters | None = None, + *, + include_project: bool = True, + ) -> Sequence[Scenario]: + stmt = select(Scenario) + if include_project: + stmt = stmt.options(joinedload(Scenario.project)) + + if filters: + scenario_ids = filters.normalised_ids() + if scenario_ids: + stmt = stmt.where(Scenario.id.in_(scenario_ids)) + + project_ids = filters.normalised_project_ids() + if project_ids: + stmt = stmt.where(Scenario.project_id.in_(project_ids)) + + project_names = filters.normalised_project_names() + if project_names: + project_id_select = select(Project.id).where( + func.lower(Project.name).in_(project_names) + ) + stmt = stmt.where(Scenario.project_id.in_(project_id_select)) + + name_pattern = filters.name_search_pattern() + if name_pattern: + stmt = stmt.where(Scenario.name.ilike(name_pattern)) + + if filters.statuses: + # Accept Enum members or raw values in filters.statuses + status_values = [ + _enum_value(s) for s in (filters.statuses or []) + ] + stmt = stmt.where(Scenario.status.in_(status_values)) + + if filters.start_date_from: + stmt = stmt.where(Scenario.start_date >= + filters.start_date_from) + + if filters.start_date_to: + stmt = stmt.where(Scenario.start_date <= filters.start_date_to) + + if filters.end_date_from: + stmt = stmt.where(Scenario.end_date >= filters.end_date_from) + + if filters.end_date_to: + stmt = stmt.where(Scenario.end_date <= filters.end_date_to) + + if filters.created_from: + stmt = stmt.where(Scenario.created_at >= filters.created_from) + + if filters.created_to: + stmt = stmt.where(Scenario.created_at <= filters.created_to) + + if filters.updated_from: + stmt = stmt.where(Scenario.updated_at >= filters.updated_from) + + if 
filters.updated_to: + stmt = stmt.where(Scenario.updated_at <= filters.updated_to) + + currencies = filters.normalised_currencies() + if currencies: + stmt = stmt.where(func.upper( + Scenario.currency).in_(currencies)) + + if filters.primary_resources: + stmt = stmt.where(Scenario.primary_resource.in_( + filters.primary_resources)) + + stmt = stmt.order_by(Scenario.name, Scenario.id) + return self.session.execute(stmt).scalars().all() + + def delete(self, scenario_id: int) -> None: + scenario = self.get(scenario_id) + self.session.delete(scenario) + + +class ProjectProfitabilityRepository: + """Persistence operations for project-level profitability snapshots.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def create(self, snapshot: ProjectProfitability) -> ProjectProfitability: + self.session.add(snapshot) + self.session.flush() + return snapshot + + def list_for_project( + self, + project_id: int, + *, + limit: int | None = None, + ) -> Sequence[ProjectProfitability]: + stmt = ( + select(ProjectProfitability) + .where(ProjectProfitability.project_id == project_id) + .order_by(ProjectProfitability.calculated_at.desc()) + ) + if limit is not None: + stmt = stmt.limit(limit) + return self.session.execute(stmt).scalars().all() + + def latest_for_project( + self, + project_id: int, + ) -> ProjectProfitability | None: + stmt = ( + select(ProjectProfitability) + .where(ProjectProfitability.project_id == project_id) + .order_by(ProjectProfitability.calculated_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + def delete(self, snapshot_id: int) -> None: + stmt = select(ProjectProfitability).where( + ProjectProfitability.id == snapshot_id + ) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError( + f"Project profitability snapshot {snapshot_id} not found" + ) + self.session.delete(entity) + + +class ScenarioProfitabilityRepository: + """Persistence operations for scenario-level profitability snapshots.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def create(self, snapshot: ScenarioProfitability) -> ScenarioProfitability: + self.session.add(snapshot) + self.session.flush() + return snapshot + + def list_for_scenario( + self, + scenario_id: int, + *, + limit: int | None = None, + ) -> Sequence[ScenarioProfitability]: + stmt = ( + select(ScenarioProfitability) + .where(ScenarioProfitability.scenario_id == scenario_id) + .order_by(ScenarioProfitability.calculated_at.desc()) + ) + if limit is not None: + stmt = stmt.limit(limit) + return self.session.execute(stmt).scalars().all() + + def latest_for_scenario( + self, + scenario_id: int, + ) -> ScenarioProfitability | None: + stmt = ( + select(ScenarioProfitability) + .where(ScenarioProfitability.scenario_id == scenario_id) + .order_by(ScenarioProfitability.calculated_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + def delete(self, snapshot_id: int) -> None: + stmt = select(ScenarioProfitability).where( + ScenarioProfitability.id == snapshot_id + ) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError( + f"Scenario profitability snapshot {snapshot_id} not found" + ) + self.session.delete(entity) + + +class ProjectCapexRepository: + """Persistence operations for project-level capex snapshots.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def create(self, snapshot: 
ProjectCapexSnapshot) -> ProjectCapexSnapshot: + self.session.add(snapshot) + self.session.flush() + return snapshot + + def list_for_project( + self, + project_id: int, + *, + limit: int | None = None, + ) -> Sequence[ProjectCapexSnapshot]: + stmt = ( + select(ProjectCapexSnapshot) + .where(ProjectCapexSnapshot.project_id == project_id) + .order_by(ProjectCapexSnapshot.calculated_at.desc()) + ) + if limit is not None: + stmt = stmt.limit(limit) + return self.session.execute(stmt).scalars().all() + + def latest_for_project( + self, + project_id: int, + ) -> ProjectCapexSnapshot | None: + stmt = ( + select(ProjectCapexSnapshot) + .where(ProjectCapexSnapshot.project_id == project_id) + .order_by(ProjectCapexSnapshot.calculated_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + def delete(self, snapshot_id: int) -> None: + stmt = select(ProjectCapexSnapshot).where( + ProjectCapexSnapshot.id == snapshot_id + ) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError( + f"Project capex snapshot {snapshot_id} not found" + ) + self.session.delete(entity) + + +class ScenarioCapexRepository: + """Persistence operations for scenario-level capex snapshots.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def create(self, snapshot: ScenarioCapexSnapshot) -> ScenarioCapexSnapshot: + self.session.add(snapshot) + self.session.flush() + return snapshot + + def list_for_scenario( + self, + scenario_id: int, + *, + limit: int | None = None, + ) -> Sequence[ScenarioCapexSnapshot]: + stmt = ( + select(ScenarioCapexSnapshot) + .where(ScenarioCapexSnapshot.scenario_id == scenario_id) + .order_by(ScenarioCapexSnapshot.calculated_at.desc()) + ) + if limit is not None: + stmt = stmt.limit(limit) + return self.session.execute(stmt).scalars().all() + + def latest_for_scenario( + self, + scenario_id: int, + ) -> ScenarioCapexSnapshot | None: + stmt = ( + select(ScenarioCapexSnapshot) + .where(ScenarioCapexSnapshot.scenario_id == scenario_id) + .order_by(ScenarioCapexSnapshot.calculated_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + def delete(self, snapshot_id: int) -> None: + stmt = select(ScenarioCapexSnapshot).where( + ScenarioCapexSnapshot.id == snapshot_id + ) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError( + f"Scenario capex snapshot {snapshot_id} not found" + ) + self.session.delete(entity) + + +class ProjectOpexRepository: + """Persistence operations for project-level opex snapshots.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def create( + self, snapshot: ProjectOpexSnapshot + ) -> ProjectOpexSnapshot: + self.session.add(snapshot) + self.session.flush() + return snapshot + + def list_for_project( + self, + project_id: int, + *, + limit: int | None = None, + ) -> Sequence[ProjectOpexSnapshot]: + stmt = ( + select(ProjectOpexSnapshot) + .where(ProjectOpexSnapshot.project_id == project_id) + .order_by(ProjectOpexSnapshot.calculated_at.desc()) + ) + if limit is not None: + stmt = stmt.limit(limit) + return self.session.execute(stmt).scalars().all() + + def latest_for_project( + self, + project_id: int, + ) -> ProjectOpexSnapshot | None: + stmt = ( + select(ProjectOpexSnapshot) + .where(ProjectOpexSnapshot.project_id == project_id) + .order_by(ProjectOpexSnapshot.calculated_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() 
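+ + # Illustrative usage (not part of this changeset), assuming an open + # SQLAlchemy Session bound to the application engine: + # + #     repo = ProjectOpexRepository(session) + #     latest = repo.latest_for_project(project_id) + #     if latest is None: + #         ...  # no opex snapshot recorded yet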
+ + def delete(self, snapshot_id: int) -> None: + stmt = select(ProjectOpexSnapshot).where( + ProjectOpexSnapshot.id == snapshot_id + ) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError( + f"Project opex snapshot {snapshot_id} not found" + ) + self.session.delete(entity) + + +class ScenarioOpexRepository: + """Persistence operations for scenario-level opex snapshots.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def create( + self, snapshot: ScenarioOpexSnapshot + ) -> ScenarioOpexSnapshot: + self.session.add(snapshot) + self.session.flush() + return snapshot + + def list_for_scenario( + self, + scenario_id: int, + *, + limit: int | None = None, + ) -> Sequence[ScenarioOpexSnapshot]: + stmt = ( + select(ScenarioOpexSnapshot) + .where(ScenarioOpexSnapshot.scenario_id == scenario_id) + .order_by(ScenarioOpexSnapshot.calculated_at.desc()) + ) + if limit is not None: + stmt = stmt.limit(limit) + return self.session.execute(stmt).scalars().all() + + def latest_for_scenario( + self, + scenario_id: int, + ) -> ScenarioOpexSnapshot | None: + stmt = ( + select(ScenarioOpexSnapshot) + .where(ScenarioOpexSnapshot.scenario_id == scenario_id) + .order_by(ScenarioOpexSnapshot.calculated_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + def delete(self, snapshot_id: int) -> None: + stmt = select(ScenarioOpexSnapshot).where( + ScenarioOpexSnapshot.id == snapshot_id + ) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError( + f"Scenario opex snapshot {snapshot_id} not found" + ) + self.session.delete(entity) + + +class FinancialInputRepository: + """Persistence operations for FinancialInput entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list_for_scenario(self, scenario_id: int) -> Sequence[FinancialInput]: + stmt = ( + select(FinancialInput) + .where(FinancialInput.scenario_id == scenario_id) + .order_by(FinancialInput.created_at) + ) + return self.session.execute(stmt).scalars().all() + + def bulk_upsert(self, inputs: Iterable[FinancialInput]) -> Sequence[FinancialInput]: + entities = list(inputs) + self.session.add_all(entities) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover + raise EntityConflictError( + "Financial input violates constraints") from exc + return entities + + def delete(self, input_id: int) -> None: + stmt = select(FinancialInput).where(FinancialInput.id == input_id) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError(f"Financial input {input_id} not found") + self.session.delete(entity) + + def latest_created_at(self) -> datetime | None: + stmt = ( + select(FinancialInput.created_at) + .order_by(FinancialInput.created_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + +class SimulationParameterRepository: + """Persistence operations for SimulationParameter entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list_for_scenario(self, scenario_id: int) -> Sequence[SimulationParameter]: + stmt = ( + select(SimulationParameter) + .where(SimulationParameter.scenario_id == scenario_id) + .order_by(SimulationParameter.created_at) + ) + return self.session.execute(stmt).scalars().all() + + def bulk_upsert( + self, parameters: Iterable[SimulationParameter] + ) -> Sequence[SimulationParameter]: + 
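# flush() surfaces constraint violations immediately so callers receive + # EntityConflictError here rather than a deferred commit-time failure. +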
entities = list(parameters) + self.session.add_all(entities) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover + raise EntityConflictError( + "Simulation parameter violates constraints") from exc + return entities + + def delete(self, parameter_id: int) -> None: + stmt = select(SimulationParameter).where( + SimulationParameter.id == parameter_id) + entity = self.session.execute(stmt).scalar_one_or_none() + if entity is None: + raise EntityNotFoundError( + f"Simulation parameter {parameter_id} not found") + self.session.delete(entity) + + +class PricingSettingsRepository: + """Persistence operations for pricing configuration entities.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list(self, *, include_children: bool = False) -> Sequence[PricingSettings]: + stmt = select(PricingSettings).order_by(PricingSettings.created_at) + if include_children: + stmt = stmt.options( + selectinload(PricingSettings.metal_overrides), + selectinload(PricingSettings.impurity_overrides), + ) + result = self.session.execute(stmt) + if include_children: + result = result.unique() + return result.scalars().all() + + def get(self, settings_id: int, *, include_children: bool = False) -> PricingSettings: + stmt = select(PricingSettings).where(PricingSettings.id == settings_id) + if include_children: + stmt = stmt.options( + selectinload(PricingSettings.metal_overrides), + selectinload(PricingSettings.impurity_overrides), + ) + result = self.session.execute(stmt) + if include_children: + result = result.unique() + settings = result.scalar_one_or_none() + if settings is None: + raise EntityNotFoundError( + f"Pricing settings {settings_id} not found") + return settings + + def find_by_slug( + self, + slug: str, + *, + include_children: bool = False, + ) -> PricingSettings | None: + normalised = slug.strip().lower() + stmt = select(PricingSettings).where( + PricingSettings.slug == normalised) + if include_children: + stmt = stmt.options( + selectinload(PricingSettings.metal_overrides), + selectinload(PricingSettings.impurity_overrides), + ) + result = self.session.execute(stmt) + if include_children: + result = result.unique() + return result.scalar_one_or_none() + + def get_by_slug(self, slug: str, *, include_children: bool = False) -> PricingSettings: + settings = self.find_by_slug(slug, include_children=include_children) + if settings is None: + raise EntityNotFoundError( + f"Pricing settings slug '{slug}' not found" + ) + return settings + + def create(self, settings: PricingSettings) -> PricingSettings: + self.session.add(settings) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - relies on DB constraints + raise EntityConflictError( + "Pricing settings violates constraints") from exc + return settings + + def delete(self, settings_id: int) -> None: + settings = self.get(settings_id, include_children=True) + self.session.delete(settings) + + def attach_metal_override( + self, + settings: PricingSettings, + override: PricingMetalSettings, + ) -> PricingMetalSettings: + settings.metal_overrides.append(override) + self.session.add(override) + self.session.flush() + return override + + def attach_impurity_override( + self, + settings: PricingSettings, + override: PricingImpuritySettings, + ) -> PricingImpuritySettings: + settings.impurity_overrides.append(override) + self.session.add(override) + self.session.flush() + return override + + +class RoleRepository: + """Persistence operations for Role entities.""" + + def 
__init__(self, session: Session) -> None: + self.session = session + + def list(self) -> Sequence[Role]: + stmt = select(Role).order_by(Role.name) + return self.session.execute(stmt).scalars().all() + + def get(self, role_id: int) -> Role: + stmt = select(Role).where(Role.id == role_id) + role = self.session.execute(stmt).scalar_one_or_none() + if role is None: + raise EntityNotFoundError(f"Role {role_id} not found") + return role + + def get_by_name(self, name: str) -> Role | None: + stmt = select(Role).where(Role.name == name) + return self.session.execute(stmt).scalar_one_or_none() + + def create(self, role: Role) -> Role: + self.session.add(role) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "Role violates uniqueness constraints") from exc + return role + + +class UserRepository: + """Persistence operations for User entities and their role assignments.""" + + def __init__(self, session: Session) -> None: + self.session = session + + def list(self, *, with_roles: bool = False) -> Sequence[User]: + stmt = select(User).order_by(User.created_at) + if with_roles: + stmt = stmt.options(selectinload(User.roles)) + return self.session.execute(stmt).scalars().all() + + def _apply_role_option(self, stmt, with_roles: bool): + if with_roles: + stmt = stmt.options( + joinedload(User.role_assignments).joinedload(UserRole.role), + selectinload(User.roles), + ) + return stmt + + def get(self, user_id: int, *, with_roles: bool = False) -> User: + stmt = select(User).where(User.id == user_id).execution_options( + populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + user = result.scalar_one_or_none() + if user is None: + raise EntityNotFoundError(f"User {user_id} not found") + return user + + def get_by_email(self, email: str, *, with_roles: bool = False) -> User | None: + stmt = select(User).where(User.email == email).execution_options( + populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + return result.scalar_one_or_none() + + def get_by_username(self, username: str, *, with_roles: bool = False) -> User | None: + stmt = select(User).where(User.username == + username).execution_options(populate_existing=True) + stmt = self._apply_role_option(stmt, with_roles) + result = self.session.execute(stmt) + if with_roles: + result = result.unique() + return result.scalar_one_or_none() + + def create(self, user: User) -> User: + self.session.add(user) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "User violates uniqueness constraints") from exc + return user + + def assign_role( + self, + *, + user_id: int, + role_id: int, + granted_by: int | None = None, + ) -> UserRole: + stmt = select(UserRole).where( + UserRole.user_id == user_id, + UserRole.role_id == role_id, + ) + assignment = self.session.execute(stmt).scalar_one_or_none() + if assignment: + return assignment + + assignment = UserRole( + user_id=user_id, + role_id=role_id, + granted_by=granted_by, + ) + self.session.add(assignment) + try: + self.session.flush() + except IntegrityError as exc: # pragma: no cover - DB constraint enforcement + raise EntityConflictError( + "Assignment violates constraints") from exc + return assignment + + def revoke_role(self, *, 
user_id: int, role_id: int) -> None: + stmt = select(UserRole).where( + UserRole.user_id == user_id, + UserRole.role_id == role_id, + ) + assignment = self.session.execute(stmt).scalar_one_or_none() + if assignment is None: + raise EntityNotFoundError( + f"Role {role_id} not assigned to user {user_id}") + self.session.delete(assignment) + self.session.flush() + + +DEFAULT_PRICING_SETTINGS_NAME = "Default Pricing Settings" +DEFAULT_PRICING_SETTINGS_DESCRIPTION = ( + "Default pricing configuration generated from environment metadata." +) + + +@dataclass(slots=True) +class PricingSettingsSeedResult: + settings: PricingSettings + created: bool + updated_fields: int + impurity_upserts: int + + +def ensure_default_pricing_settings( + repo: PricingSettingsRepository, + *, + metadata: PricingMetadata, + slug: str = "default", + name: str | None = None, + description: str | None = None, +) -> PricingSettingsSeedResult: + """Ensure a baseline pricing settings record exists and matches metadata defaults.""" + + normalised_slug = (slug or "default").strip().lower() or "default" + target_name = name or DEFAULT_PRICING_SETTINGS_NAME + target_description = description or DEFAULT_PRICING_SETTINGS_DESCRIPTION + + updated_fields = 0 + impurity_upserts = 0 + + try: + settings = repo.get_by_slug(normalised_slug, include_children=True) + created = False + except EntityNotFoundError: + settings = PricingSettings( + name=target_name, + slug=normalised_slug, + description=target_description, + default_currency=metadata.default_currency, + default_payable_pct=metadata.default_payable_pct, + moisture_threshold_pct=metadata.moisture_threshold_pct, + moisture_penalty_per_pct=metadata.moisture_penalty_per_pct, + ) + settings.metadata_payload = None + settings = repo.create(settings) + created = True + else: + if settings.name != target_name: + settings.name = target_name + updated_fields += 1 + if target_description and settings.description != target_description: + settings.description = target_description + updated_fields += 1 + if settings.default_currency != metadata.default_currency: + settings.default_currency = metadata.default_currency + updated_fields += 1 + if float(settings.default_payable_pct) != float(metadata.default_payable_pct): + settings.default_payable_pct = metadata.default_payable_pct + updated_fields += 1 + if float(settings.moisture_threshold_pct) != float(metadata.moisture_threshold_pct): + settings.moisture_threshold_pct = metadata.moisture_threshold_pct + updated_fields += 1 + if float(settings.moisture_penalty_per_pct) != float(metadata.moisture_penalty_per_pct): + settings.moisture_penalty_per_pct = metadata.moisture_penalty_per_pct + updated_fields += 1 + + impurity_thresholds = { + code.strip().upper(): float(value) + for code, value in (metadata.impurity_thresholds or {}).items() + if code.strip() + } + impurity_penalties = { + code.strip().upper(): float(value) + for code, value in (metadata.impurity_penalty_per_ppm or {}).items() + if code.strip() + } + + if impurity_thresholds or impurity_penalties: + existing_map = { + override.impurity_code: override + for override in settings.impurity_overrides + } + target_codes = set(impurity_thresholds) | set(impurity_penalties) + for code in sorted(target_codes): + threshold_value = impurity_thresholds.get(code, 0.0) + penalty_value = impurity_penalties.get(code, 0.0) + existing = existing_map.get(code) + if existing is None: + repo.attach_impurity_override( + settings, + PricingImpuritySettings( + impurity_code=code, + 
threshold_ppm=threshold_value, + penalty_per_ppm=penalty_value, + ), + ) + impurity_upserts += 1 + continue + changed = False + if float(existing.threshold_ppm) != float(threshold_value): + existing.threshold_ppm = threshold_value + changed = True + if float(existing.penalty_per_ppm) != float(penalty_value): + existing.penalty_per_ppm = penalty_value + changed = True + if changed: + updated_fields += 1 + + if updated_fields > 0 or impurity_upserts > 0: + repo.session.flush() + + return PricingSettingsSeedResult( + settings=settings, + created=created, + updated_fields=updated_fields, + impurity_upserts=impurity_upserts, + ) + + +def pricing_settings_to_metadata(settings: PricingSettings) -> PricingMetadata: + """Convert a persisted pricing settings record into metadata defaults.""" + + payload = settings.metadata_payload or {} + payload_thresholds = payload.get("impurity_thresholds") or {} + payload_penalties = payload.get("impurity_penalty_per_ppm") or {} + + thresholds: dict[str, float] = { + code.strip().upper(): float(value) + for code, value in payload_thresholds.items() + if isinstance(code, str) and code.strip() + } + penalties: dict[str, float] = { + code.strip().upper(): float(value) + for code, value in payload_penalties.items() + if isinstance(code, str) and code.strip() + } + + for override in settings.impurity_overrides: + code = override.impurity_code.strip().upper() + thresholds[code] = float(override.threshold_ppm) + penalties[code] = float(override.penalty_per_ppm) + + return PricingMetadata( + default_payable_pct=float(settings.default_payable_pct), + default_currency=settings.default_currency, + moisture_threshold_pct=float(settings.moisture_threshold_pct), + moisture_penalty_per_pct=float(settings.moisture_penalty_per_pct), + impurity_thresholds=thresholds, + impurity_penalty_per_ppm=penalties, + ) + + +DEFAULT_ROLE_DEFINITIONS: tuple[dict[str, str], ...] = ( + { + "name": "admin", + "display_name": "Administrator", + "description": "Full platform access with user management rights.", + }, + { + "name": "project_manager", + "display_name": "Project Manager", + "description": "Manage projects, scenarios, and associated data.", + }, + { + "name": "analyst", + "display_name": "Analyst", + "description": "Review dashboards and scenario outputs.", + }, + { + "name": "viewer", + "display_name": "Viewer", + "description": "Read-only access to assigned projects and reports.", + }, +) + + +def ensure_default_roles(role_repo: RoleRepository) -> list[Role]: + """Ensure standard roles exist, creating missing ones. + + Returns all current role records in definition order.
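+ + Illustrative usage (a sketch, assuming an open SQLAlchemy ``session``): + + roles = ensure_default_roles(RoleRepository(session))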
+ """ + + roles: list[Role] = [] + for definition in DEFAULT_ROLE_DEFINITIONS: + existing = role_repo.get_by_name(definition["name"]) + if existing: + roles.append(existing) + continue + role = Role(**definition) + roles.append(role_repo.create(role)) + return roles + + +def ensure_admin_user( + user_repo: UserRepository, + role_repo: RoleRepository, + *, + email: str, + username: str, + password: str, +) -> User: + """Ensure an administrator user exists and holds the admin role.""" + + user = user_repo.get_by_email(email, with_roles=True) + if user is None: + user = User( + email=email, + username=username, + password_hash=User.hash_password(password), + is_active=True, + is_superuser=True, + ) + user_repo.create(user) + else: + if not user.is_active: + user.is_active = True + if not user.is_superuser: + user.is_superuser = True + user_repo.session.flush() + + admin_role = role_repo.get_by_name("admin") + if admin_role is None: # pragma: no cover - safety if ensure_default_roles wasn't called + admin_role = role_repo.create( + Role( + name="admin", + display_name="Administrator", + description="Full platform access with user management rights.", + ) + ) + + user_repo.assign_role( + user_id=user.id, + role_id=admin_role.id, + granted_by=user.id, + ) + return user diff --git a/services/scenario_evaluation.py b/services/scenario_evaluation.py new file mode 100644 index 0000000..c356475 --- /dev/null +++ b/services/scenario_evaluation.py @@ -0,0 +1,54 @@ +"""Scenario evaluation services including pricing integration.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable + +from models.scenario import Scenario +from services.pricing import ( + PricingInput, + PricingMetadata, + PricingResult, + calculate_pricing, +) + + +@dataclass(slots=True) +class ScenarioPricingConfig: + """Configuration for pricing evaluation within a scenario.""" + + metadata: PricingMetadata | None = None + + +@dataclass(slots=True) +class ScenarioPricingSnapshot: + """Captured pricing results for a scenario.""" + + scenario_id: int + results: list[PricingResult] + + +class ScenarioPricingEvaluator: + """Evaluate scenario profitability inputs using pricing services.""" + + def __init__(self, config: ScenarioPricingConfig | None = None) -> None: + self._config = config or ScenarioPricingConfig() + + def evaluate( + self, + scenario: Scenario, + *, + inputs: Iterable[PricingInput], + metadata_override: PricingMetadata | None = None, + ) -> ScenarioPricingSnapshot: + metadata = metadata_override or self._config.metadata + results: list[PricingResult] = [] + for pricing_input in inputs: + result = calculate_pricing( + pricing_input, + metadata=metadata, + currency=scenario.currency, + ) + results.append(result) + return ScenarioPricingSnapshot(scenario_id=scenario.id, results=results) diff --git a/services/scenario_validation.py b/services/scenario_validation.py new file mode 100644 index 0000000..cf4b4a2 --- /dev/null +++ b/services/scenario_validation.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date +from typing import Iterable, Sequence + +from models import Scenario, ScenarioStatus +from services.exceptions import ScenarioValidationError + +ALLOWED_STATUSES: frozenset[ScenarioStatus] = frozenset( + {ScenarioStatus.DRAFT, ScenarioStatus.ACTIVE} +) + + +@dataclass(frozen=True) +class _ValidationContext: + scenarios: Sequence[Scenario] + + @property + def scenario_ids(self) -> list[int]: + return 
[scenario.id for scenario in self.scenarios if scenario.id is not None] + + +class ScenarioComparisonValidator: + """Validates scenarios prior to comparison workflows.""" + + def validate(self, scenarios: Sequence[Scenario] | Iterable[Scenario]) -> None: + scenario_list = list(scenarios) + if len(scenario_list) < 2: + # Nothing to validate when fewer than two scenarios are provided. + return + + context = _ValidationContext(scenario_list) + + self._ensure_same_project(context) + self._ensure_allowed_status(context) + self._ensure_shared_currency(context) + self._ensure_timeline_overlap(context) + self._ensure_shared_primary_resource(context) + + def _ensure_same_project(self, context: _ValidationContext) -> None: + project_ids = {scenario.project_id for scenario in context.scenarios} + if len(project_ids) > 1: + raise ScenarioValidationError( + code="SCENARIO_PROJECT_MISMATCH", + message="Selected scenarios do not belong to the same project.", + scenario_ids=context.scenario_ids, + ) + + def _ensure_allowed_status(self, context: _ValidationContext) -> None: + disallowed = [ + scenario + for scenario in context.scenarios + if scenario.status not in ALLOWED_STATUSES + ] + if disallowed: + raise ScenarioValidationError( + code="SCENARIO_STATUS_INVALID", + message="Archived scenarios cannot be compared.", + scenario_ids=[ + scenario.id for scenario in disallowed if scenario.id is not None], + ) + + def _ensure_shared_currency(self, context: _ValidationContext) -> None: + currencies = { + scenario.currency + for scenario in context.scenarios + if scenario.currency is not None + } + if len(currencies) > 1: + raise ScenarioValidationError( + code="SCENARIO_CURRENCY_MISMATCH", + message="Scenarios use different currencies and cannot be compared.", + scenario_ids=context.scenario_ids, + ) + + def _ensure_timeline_overlap(self, context: _ValidationContext) -> None: + ranges = [ + (scenario.start_date, scenario.end_date) + for scenario in context.scenarios + if scenario.start_date and scenario.end_date + ] + if len(ranges) < 2: + return + + latest_start: date = max(start for start, _ in ranges) + earliest_end: date = min(end for _, end in ranges) + if latest_start > earliest_end: + raise ScenarioValidationError( + code="SCENARIO_TIMELINE_DISJOINT", + message="Scenario timelines do not overlap; adjust the comparison window.", + scenario_ids=context.scenario_ids, + ) + + def _ensure_shared_primary_resource(self, context: _ValidationContext) -> None: + resources = { + scenario.primary_resource + for scenario in context.scenarios + if scenario.primary_resource is not None + } + if len(resources) > 1: + raise ScenarioValidationError( + code="SCENARIO_RESOURCE_MISMATCH", + message="Scenarios target different primary resources and cannot be compared.", + scenario_ids=context.scenario_ids, + ) diff --git a/services/security.py b/services/security.py index 24782c5..34c8209 100644 --- a/services/security.py +++ b/services/security.py @@ -1,59 +1,222 @@ -from datetime import datetime, timedelta -from typing import Any, Union +from __future__ import annotations -from fastapi import HTTPException, status, Depends -from fastapi.security import OAuth2PasswordBearer -from jose import jwt, JWTError +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from hmac import compare_digest +from typing import Any, Dict, Iterable, Literal, Type + +from jose import ExpiredSignatureError, JWTError, jwt from passlib.context import CryptContext -from sqlalchemy.orm import Session -from 
config.database import get_db +try: # pragma: no cover - compatibility shim for passlib/argon2 warning + import importlib.metadata as importlib_metadata + import argon2 # type: ignore + + setattr(argon2, "__version__", importlib_metadata.version("argon2-cffi")) +except Exception: # pragma: no cover - executed only when metadata lookup fails + pass + +from pydantic import BaseModel, Field, ValidationError + +password_context = CryptContext(schemes=["argon2"], deprecated="auto") -ACCESS_TOKEN_EXPIRE_MINUTES = 30 -SECRET_KEY = "your-secret-key" # Change this in production -ALGORITHM = "HS256" +def hash_password(password: str) -> str: + """Derive a secure hash for a plain-text password.""" -pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto") + return password_context.hash(password) -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="users/login") + +def verify_password(candidate: str, hashed: str) -> bool: + """Verify that a candidate password matches a stored hash.""" + + try: + return password_context.verify(candidate, hashed) + except ValueError: + # Raised when the stored hash is malformed or uses an unknown scheme. + return False + + +class TokenError(Exception): + """Base class for token encoding/decoding issues.""" + + +class TokenDecodeError(TokenError): + """Raised when a token cannot be decoded or validated.""" + + +class TokenExpiredError(TokenError): + """Raised when a token has expired.""" + + +class TokenTypeMismatchError(TokenError): + """Raised when a token type does not match the expected flavour.""" + + +TokenKind = Literal["access", "refresh"] + + +class TokenPayload(BaseModel): + """Shared fields for CalMiner JWT payloads.""" + + sub: str + exp: int + type: TokenKind + scopes: list[str] = Field(default_factory=list) + + @property + def expires_at(self) -> datetime: + return datetime.fromtimestamp(self.exp, tz=timezone.utc) + + +@dataclass(slots=True) +class JWTSettings: + """Runtime configuration for JWT encoding and validation.""" + + secret_key: str + algorithm: str = "HS256" + access_token_ttl: timedelta = field( + default_factory=lambda: timedelta(minutes=15)) + refresh_token_ttl: timedelta = field( + default_factory=lambda: timedelta(days=7)) def create_access_token( - subject: Union[str, Any], expires_delta: Union[timedelta, None] = None + subject: str, + settings: JWTSettings, + *, + scopes: Iterable[str] | None = None, + expires_delta: timedelta | None = None, + extra_claims: Dict[str, Any] | None = None, ) -> str: - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + """Issue a signed access token for the provided subject.""" - -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) - - -async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): - from models.user import User - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, + lifetime = expires_delta or settings.access_token_ttl + return _create_token( + subject=subject, + token_type="access", + settings=settings, + lifetime=lifetime, + 
scopes=scopes, + extra_claims=extra_claims, ) + + +def create_refresh_token( + subject: str, + settings: JWTSettings, + *, + scopes: Iterable[str] | None = None, + expires_delta: timedelta | None = None, + extra_claims: Dict[str, Any] | None = None, +) -> str: + """Issue a signed refresh token for the provided subject.""" + + lifetime = expires_delta or settings.refresh_token_ttl + return _create_token( + subject=subject, + token_type="refresh", + settings=settings, + lifetime=lifetime, + scopes=scopes, + extra_claims=extra_claims, + ) + + +def decode_access_token(token: str, settings: JWTSettings) -> TokenPayload: + """Validate and decode an access token.""" + + return _decode_token(token, settings, expected_type="access") + + +def decode_refresh_token(token: str, settings: JWTSettings) -> TokenPayload: + """Validate and decode a refresh token.""" + + return _decode_token(token, settings, expected_type="refresh") + + +def _create_token( + *, + subject: str, + token_type: TokenKind, + settings: JWTSettings, + lifetime: timedelta, + scopes: Iterable[str] | None, + extra_claims: Dict[str, Any] | None, +) -> str: + now = datetime.now(timezone.utc) + expire = now + lifetime + payload: Dict[str, Any] = { + "sub": subject, + "type": token_type, + "iat": int(now.timestamp()), + "exp": int(expire.timestamp()), + } + if scopes: + payload["scopes"] = list(scopes) + if extra_claims: + payload.update(extra_claims) + + return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm) + + +def _decode_token( + token: str, + settings: JWTSettings, + expected_type: TokenKind, +) -> TokenPayload: try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username = payload.get("sub") - if username is None: - raise credentials_exception - except JWTError: - raise credentials_exception - user = db.query(User).filter(User.username == username).first() - if user is None: - raise credentials_exception - return user + decoded = jwt.decode( + token, + settings.secret_key, + algorithms=[settings.algorithm], + options={"verify_aud": False}, + ) + except ExpiredSignatureError as exc: # pragma: no cover - jose marks this path + raise TokenExpiredError("Token has expired") from exc + except JWTError as exc: # pragma: no cover - jose error bubble + raise TokenDecodeError("Unable to decode token") from exc + + expected_token = jwt.encode( + decoded, + settings.secret_key, + algorithm=settings.algorithm, + ) + if not compare_digest(token, expected_token): + raise TokenDecodeError("Token contents have been altered.") + + try: + payload = _model_validate(TokenPayload, decoded) + except ValidationError as exc: + raise TokenDecodeError("Token payload validation failed") from exc + + if payload.type != expected_type: + raise TokenTypeMismatchError( + f"Expected a {expected_type} token but received '{payload.type}'." 
+ ) + + return payload + + +def _model_validate(model: Type[TokenPayload], data: Dict[str, Any]) -> TokenPayload: + if hasattr(model, "model_validate"): + return model.model_validate(data) # type: ignore[attr-defined] + return model.parse_obj(data) # type: ignore[attr-defined] + + +__all__ = [ + "JWTSettings", + "TokenDecodeError", + "TokenError", + "TokenExpiredError", + "TokenKind", + "TokenPayload", + "TokenTypeMismatchError", + "create_access_token", + "create_refresh_token", + "decode_access_token", + "decode_refresh_token", + "hash_password", + "password_context", + "verify_password", +] diff --git a/services/session.py b/services/session.py new file mode 100644 index 0000000..e68066e --- /dev/null +++ b/services/session.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Literal, Optional, TYPE_CHECKING + +from fastapi import Request, Response + +from config.settings import SessionSettings +from services.security import JWTSettings + +if TYPE_CHECKING: # pragma: no cover - used only for static typing + from models import User + + +@dataclass(slots=True) +class SessionStrategy: + """Describe how authentication tokens are transported with requests.""" + + access_cookie_name: str + refresh_cookie_name: str + cookie_secure: bool + cookie_domain: Optional[str] + cookie_path: str + header_name: str + header_prefix: str + allow_header_fallback: bool = True + + @classmethod + def from_settings(cls, settings: SessionSettings) -> "SessionStrategy": + return cls( + access_cookie_name=settings.access_cookie_name, + refresh_cookie_name=settings.refresh_cookie_name, + cookie_secure=settings.cookie_secure, + cookie_domain=settings.cookie_domain, + cookie_path=settings.cookie_path, + header_name=settings.header_name, + header_prefix=settings.header_prefix, + allow_header_fallback=settings.allow_header_fallback, + ) + + +@dataclass(slots=True) +class SessionTokens: + """Raw access and refresh tokens extracted from the transport layer.""" + + access_token: Optional[str] + refresh_token: Optional[str] + access_token_source: Literal["cookie", "header", "none"] = "none" + + @property + def has_access(self) -> bool: + return bool(self.access_token) + + @property + def has_refresh(self) -> bool: + return bool(self.refresh_token) + + @property + def is_empty(self) -> bool: + return not self.has_access and not self.has_refresh + + +@dataclass(slots=True) +class AuthSession: + """Holds authenticated user context resolved from session tokens.""" + + tokens: SessionTokens + user: Optional["User"] = None + scopes: tuple[str, ...] = () + role_slugs: tuple[str, ...] 
= () + issued_access_token: Optional[str] = None + issued_refresh_token: Optional[str] = None + clear_cookies: bool = False + + @property + def is_authenticated(self) -> bool: + return self.user is not None + + @classmethod + def anonymous(cls) -> "AuthSession": + return cls( + tokens=SessionTokens(access_token=None, refresh_token=None), + role_slugs=(), + ) + + def issue_tokens( + self, + *, + access_token: str, + refresh_token: Optional[str] = None, + access_source: Literal["cookie", "header", "none"] = "cookie", + ) -> None: + self.issued_access_token = access_token + if refresh_token is not None: + self.issued_refresh_token = refresh_token + self.tokens = SessionTokens( + access_token=access_token, + refresh_token=refresh_token if refresh_token is not None else self.tokens.refresh_token, + access_token_source=access_source, + ) + + def mark_cleared(self) -> None: + self.clear_cookies = True + self.tokens = SessionTokens(access_token=None, refresh_token=None) + self.user = None + self.scopes = () + self.role_slugs = () + + def set_role_slugs(self, roles: Iterable[str]) -> None: + self.role_slugs = tuple(dict.fromkeys(role.strip().lower() for role in roles if role)) + + +def extract_session_tokens(request: Request, strategy: SessionStrategy) -> SessionTokens: + """Pull tokens from cookies or headers according to configured strategy.""" + + access_token: Optional[str] = None + refresh_token: Optional[str] = None + access_source: Literal["cookie", "header", "none"] = "none" + + if strategy.access_cookie_name in request.cookies: + access_token = request.cookies.get(strategy.access_cookie_name) or None + if access_token: + access_source = "cookie" + + if strategy.refresh_cookie_name in request.cookies: + refresh_token = request.cookies.get( + strategy.refresh_cookie_name) or None + + if not access_token and strategy.allow_header_fallback: + header_value = request.headers.get(strategy.header_name) + if header_value: + candidate = header_value.strip() + prefix = f"{strategy.header_prefix} " if strategy.header_prefix else "" + if prefix and candidate.lower().startswith(prefix.lower()): + candidate = candidate[len(prefix):].strip() + if candidate: + access_token = candidate + access_source = "header" + + return SessionTokens( + access_token=access_token, + refresh_token=refresh_token, + access_token_source=access_source, + ) + + +def build_session_strategy(settings: SessionSettings) -> SessionStrategy: + """Create a session strategy object from settings configuration.""" + + return SessionStrategy.from_settings(settings) + + +def set_session_cookies( + response: Response, + *, + access_token: str, + refresh_token: Optional[str], + strategy: SessionStrategy, + jwt_settings: JWTSettings, +) -> None: + """Persist session cookies on an outgoing response.""" + + access_ttl = int(jwt_settings.access_token_ttl.total_seconds()) + refresh_ttl = int(jwt_settings.refresh_token_ttl.total_seconds()) + response.set_cookie( + strategy.access_cookie_name, + access_token, + httponly=True, + secure=strategy.cookie_secure, + samesite="lax", + max_age=max(access_ttl, 0) or None, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) + if refresh_token is not None: + response.set_cookie( + strategy.refresh_cookie_name, + refresh_token, + httponly=True, + secure=strategy.cookie_secure, + samesite="lax", + max_age=max(refresh_ttl, 0) or None, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) + + +def clear_session_cookies(response: Response, strategy: SessionStrategy) -> None: + """Remove 
session cookies from the client.""" + + response.delete_cookie( + strategy.access_cookie_name, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) + response.delete_cookie( + strategy.refresh_cookie_name, + domain=strategy.cookie_domain, + path=strategy.cookie_path, + ) diff --git a/services/settings.py b/services/settings.py deleted file mode 100644 index 51b49ac..0000000 --- a/services/settings.py +++ /dev/null @@ -1,230 +0,0 @@ -from __future__ import annotations - -import os -import re -from typing import Dict, Mapping - -from sqlalchemy.orm import Session - -from models.application_setting import ApplicationSetting -from models.theme_setting import ThemeSetting # Import ThemeSetting model - -CSS_COLOR_CATEGORY = "theme" -CSS_COLOR_VALUE_TYPE = "color" -CSS_ENV_PREFIX = "CALMINER_THEME_" - -CSS_COLOR_DEFAULTS: Dict[str, str] = { - "--color-background": "#f4f5f7", - "--color-surface": "#ffffff", - "--color-text-primary": "#2a1f33", - "--color-text-secondary": "#624769", - "--color-text-muted": "#64748b", - "--color-text-subtle": "#94a3b8", - "--color-text-invert": "#ffffff", - "--color-text-dark": "#0f172a", - "--color-text-strong": "#111827", - "--color-primary": "#5f320d", - "--color-primary-strong": "#7e4c13", - "--color-primary-stronger": "#837c15", - "--color-accent": "#bff838", - "--color-border": "#e2e8f0", - "--color-border-strong": "#cbd5e1", - "--color-highlight": "#eef2ff", - "--color-panel-shadow": "rgba(15, 23, 42, 0.08)", - "--color-panel-shadow-deep": "rgba(15, 23, 42, 0.12)", - "--color-surface-alt": "#f8fafc", - "--color-success": "#047857", - "--color-error": "#b91c1c", -} - -_COLOR_VALUE_PATTERN = re.compile( - r"^(#([0-9a-fA-F]{3}|[0-9a-fA-F]{6}|[0-9a-fA-F]{8})|rgba?\([^)]+\)|hsla?\([^)]+\))$", - re.IGNORECASE, -) - - -def ensure_css_color_settings(db: Session) -> Dict[str, ApplicationSetting]: - """Ensure the CSS color defaults exist in the settings table.""" - - existing = ( - db.query(ApplicationSetting) - .filter(ApplicationSetting.key.in_(CSS_COLOR_DEFAULTS.keys())) - .all() - ) - by_key = {setting.key: setting for setting in existing} - - created = False - for key, default_value in CSS_COLOR_DEFAULTS.items(): - if key in by_key: - continue - setting = ApplicationSetting( - key=key, - value=default_value, - value_type=CSS_COLOR_VALUE_TYPE, - category=CSS_COLOR_CATEGORY, - description=f"CSS variable {key}", - is_editable=True, - ) - db.add(setting) - by_key[key] = setting - created = True - - if created: - db.commit() - for key, setting in by_key.items(): - db.refresh(setting) - - return by_key - - -def get_css_color_settings(db: Session) -> Dict[str, str]: - """Return CSS color variables, filling missing values with defaults.""" - - settings = ensure_css_color_settings(db) - values: Dict[str, str] = { - key: settings[key].value if key in settings else default - for key, default in CSS_COLOR_DEFAULTS.items() - } - - env_overrides = read_css_color_env_overrides(os.environ) - if env_overrides: - values.update(env_overrides) - - return values - - -def update_css_color_settings( - db: Session, updates: Mapping[str, str] -) -> Dict[str, str]: - """Persist provided CSS color overrides and return the final values.""" - - if not updates: - return get_css_color_settings(db) - - invalid_keys = sorted(set(updates.keys()) - set(CSS_COLOR_DEFAULTS.keys())) - if invalid_keys: - invalid_list = ", ".join(invalid_keys) - raise ValueError(f"Unsupported CSS variables: {invalid_list}") - - normalized: Dict[str, str] = {} - for key, value in updates.items(): - 
normalized[key] = _normalize_color_value(value) - - settings = ensure_css_color_settings(db) - changed = False - - for key, value in normalized.items(): - setting = settings[key] - if setting.value != value: - setting.value = value - changed = True - if setting.value_type != CSS_COLOR_VALUE_TYPE: - setting.value_type = CSS_COLOR_VALUE_TYPE - changed = True - if setting.category != CSS_COLOR_CATEGORY: - setting.category = CSS_COLOR_CATEGORY - changed = True - if not setting.is_editable: - setting.is_editable = True - changed = True - - if changed: - db.commit() - for key in normalized.keys(): - db.refresh(settings[key]) - - return get_css_color_settings(db) - - -def read_css_color_env_overrides( - env: Mapping[str, str] | None = None, -) -> Dict[str, str]: - """Return validated CSS overrides sourced from environment variables.""" - - if env is None: - env = os.environ - - overrides: Dict[str, str] = {} - for css_key in CSS_COLOR_DEFAULTS.keys(): - env_name = css_key_to_env_var(css_key) - raw_value = env.get(env_name) - if raw_value is None: - continue - overrides[css_key] = _normalize_color_value(raw_value) - - return overrides - - -def _normalize_color_value(value: str) -> str: - if not isinstance(value, str): - raise ValueError("Color value must be a string") - trimmed = value.strip() - if not trimmed: - raise ValueError("Color value cannot be empty") - if not _COLOR_VALUE_PATTERN.match(trimmed): - raise ValueError( - "Color value must be a hex code or an rgb/rgba/hsl/hsla expression" - ) - _validate_functional_color(trimmed) - return trimmed - - -def _validate_functional_color(value: str) -> None: - lowered = value.lower() - if lowered.startswith("rgb(") or lowered.startswith("hsl("): - _ensure_component_count(value, expected=3) - elif lowered.startswith("rgba(") or lowered.startswith("hsla("): - _ensure_component_count(value, expected=4) - - -def _ensure_component_count(value: str, expected: int) -> None: - if not value.endswith(")"): - raise ValueError( - "Color function expressions must end with a closing parenthesis" - ) - inner = value[value.index("(") + 1: -1] - parts = [segment.strip() for segment in inner.split(",")] - if len(parts) != expected: - raise ValueError( - "Color function expressions must provide the expected number of components" - ) - if any(not component for component in parts): - raise ValueError("Color function components cannot be empty") - - -def css_key_to_env_var(css_key: str) -> str: - sanitized = css_key.lstrip("-").replace("-", "_").upper() - return f"{CSS_ENV_PREFIX}{sanitized}" - - -def list_css_env_override_rows( - env: Mapping[str, str] | None = None, -) -> list[Dict[str, str]]: - overrides = read_css_color_env_overrides(env) - rows: list[Dict[str, str]] = [] - for css_key, value in overrides.items(): - rows.append( - { - "css_key": css_key, - "env_var": css_key_to_env_var(css_key), - "value": value, - } - ) - return rows - - -def save_theme_settings(db: Session, theme_data: dict): - theme = db.query(ThemeSetting).first() or ThemeSetting() - for key, value in theme_data.items(): - setattr(theme, key, value) - db.add(theme) - db.commit() - db.refresh(theme) - return theme - - -def get_theme_settings(db: Session): - theme = db.query(ThemeSetting).first() - if theme: - return {c.name: getattr(theme, c.name) for c in theme.__table__.columns} - return {} diff --git a/services/simulation.py b/services/simulation.py index 6c8ffe1..51b0332 100644 --- a/services/simulation.py +++ b/services/simulation.py @@ -1,144 +1,373 @@ from __future__ import annotations 
from dataclasses import dataclass -from random import Random -from typing import Dict, List, Literal, Optional, Sequence +from enum import Enum +from typing import Any, Dict, Mapping, Sequence +import time + +import numpy as np +from numpy.random import Generator, default_rng + +from .financial import ( + CashFlow, + ConvergenceError, + PaybackNotReachedError, + internal_rate_of_return, + net_present_value, + payback_period, +) +from monitoring.metrics import observe_simulation -DEFAULT_STD_DEV_RATIO = 0.1 -DEFAULT_UNIFORM_SPAN_RATIO = 0.15 -DistributionType = Literal["normal", "uniform", "triangular"] +class DistributionConfigError(ValueError): + """Raised when a distribution specification is invalid.""" -@dataclass -class SimulationParameter: - name: str - base_value: float - distribution: DistributionType - std_dev: Optional[float] = None - minimum: Optional[float] = None - maximum: Optional[float] = None - mode: Optional[float] = None +class SimulationMetric(Enum): + """Supported Monte Carlo summary metrics.""" + + NPV = "npv" + IRR = "irr" + PAYBACK = "payback" -def _ensure_positive_span(span: float, fallback: float) -> float: - return span if span and span > 0 else fallback +class DistributionType(Enum): + """Supported probability distribution families.""" + + NORMAL = "normal" + LOGNORMAL = "lognormal" + TRIANGULAR = "triangular" + DISCRETE = "discrete" -def _compile_parameters( - parameters: Sequence[Dict[str, float]], -) -> List[SimulationParameter]: - compiled: List[SimulationParameter] = [] - for index, item in enumerate(parameters): - if "value" not in item: - raise ValueError(f"Parameter at index {index} must include 'value'") - name = str(item.get("name", f"param_{index}")) - base_value = float(item["value"]) - distribution = str(item.get("distribution", "normal")).lower() - if distribution not in {"normal", "uniform", "triangular"}: - raise ValueError( - f"Parameter '{name}' has unsupported distribution '{distribution}'" - ) +class DistributionSource(Enum): + """Origins for parameter values when sourcing dynamically.""" - span_default = abs(base_value) * DEFAULT_UNIFORM_SPAN_RATIO or 1.0 + STATIC = "static" + SCENARIO_FIELD = "scenario_field" + METADATA_KEY = "metadata_key" - if distribution == "normal": - std_dev = item.get("std_dev") - std_dev_value = ( - float(std_dev) - if std_dev is not None - else abs(base_value) * DEFAULT_STD_DEV_RATIO or 1.0 - ) - compiled.append( - SimulationParameter( - name=name, - base_value=base_value, - distribution="normal", - std_dev=_ensure_positive_span(std_dev_value, 1.0), + +@dataclass(frozen=True, slots=True) +class DistributionSpec: + """Defines the stochastic behaviour for a single cash flow.""" + + type: DistributionType + parameters: Mapping[str, Any] + source: DistributionSource = DistributionSource.STATIC + source_key: str | None = None + + +@dataclass(frozen=True, slots=True) +class CashFlowSpec: + """Pairs a baseline cash flow with an optional distribution.""" + + cash_flow: CashFlow + distribution: DistributionSpec | None = None + + +@dataclass(frozen=True, slots=True) +class SimulationConfig: + """Controls Monte Carlo simulation behaviour.""" + + iterations: int + discount_rate: float + seed: int | None = None + metrics: Sequence[SimulationMetric] = ( + SimulationMetric.NPV, SimulationMetric.IRR, SimulationMetric.PAYBACK) + percentiles: Sequence[float] = (5.0, 50.0, 95.0) + compounds_per_year: int = 1 + return_samples: bool = False + residual_value: float | None = None + residual_periods: float | None = None + + 
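+# Illustrative sketch (commentary only, not exercised by this module): a single +# uncertain cash flow drawn from a normal distribution. With the default STATIC +# source the sampler seeds the normal "mean" from the baseline amount, so only +# "std_dev" must be supplied; the CashFlow signature is assumed from +# services.financial. +# +# spec = CashFlowSpec( +# cash_flow=CashFlow(amount=-1_000_000.0, period_index=0), +# distribution=DistributionSpec( +# type=DistributionType.NORMAL, +# parameters={"std_dev": 50_000.0}, +# ), +# ) +# config = SimulationConfig(iterations=10_000, discount_rate=0.08, seed=42) +# result = run_monte_carlo([spec], config) + +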
+@dataclass(frozen=True, slots=True) +class MetricSummary: + """Aggregated statistics for a simulated metric.""" + + mean: float + std_dev: float + minimum: float + maximum: float + percentiles: Mapping[float, float] + sample_size: int + failed_runs: int + + +@dataclass(frozen=True, slots=True) +class SimulationResult: + """Monte Carlo output including per-metric summaries.""" + + iterations: int + summaries: Mapping[SimulationMetric, MetricSummary] + samples: Mapping[SimulationMetric, np.ndarray] | None = None + + +def run_monte_carlo( + cash_flows: Sequence[CashFlowSpec], + config: SimulationConfig, + *, + scenario_context: Mapping[str, Any] | None = None, + metadata: Mapping[str, Any] | None = None, + rng: Generator | None = None, +) -> SimulationResult: + """Execute Monte Carlo simulation for the provided cash flows.""" + + if config.iterations <= 0: + raise ValueError("iterations must be greater than zero") + if config.compounds_per_year <= 0: + raise ValueError("compounds_per_year must be greater than zero") + for pct in config.percentiles: + if pct < 0.0 or pct > 100.0: + raise ValueError("percentiles must be within [0, 100]") + + start_time = time.time() + try: + generator = rng or default_rng(config.seed) + + metric_arrays: Dict[SimulationMetric, np.ndarray] = { + metric: np.empty(config.iterations, dtype=float) + for metric in config.metrics + } + + for idx in range(config.iterations): + iteration_flows = [ + _realise_cash_flow( + spec, + generator, + scenario_context=scenario_context, + metadata=metadata, ) - ) - continue + for spec in cash_flows + ] - minimum = item.get("min") - maximum = item.get("max") - if minimum is None or maximum is None: - minimum = base_value - span_default - maximum = base_value + span_default - minimum = float(minimum) - maximum = float(maximum) - if minimum >= maximum: - raise ValueError( - f"Parameter '{name}' requires 'min' < 'max' for {distribution} distribution" - ) - - if distribution == "uniform": - compiled.append( - SimulationParameter( - name=name, - base_value=base_value, - distribution="uniform", - minimum=minimum, - maximum=maximum, + if SimulationMetric.NPV in metric_arrays: + metric_arrays[SimulationMetric.NPV][idx] = net_present_value( + config.discount_rate, + iteration_flows, + residual_value=config.residual_value, + residual_periods=config.residual_periods, + compounds_per_year=config.compounds_per_year, ) - ) - else: # triangular - mode = item.get("mode") - if mode is None: - mode = base_value - mode_value = float(mode) - if not (minimum <= mode_value <= maximum): - raise ValueError( - f"Parameter '{name}' mode must be within min/max bounds for triangular distribution" - ) - compiled.append( - SimulationParameter( - name=name, - base_value=base_value, - distribution="triangular", - minimum=minimum, - maximum=maximum, - mode=mode_value, - ) - ) - return compiled + if SimulationMetric.IRR in metric_arrays: + try: + metric_arrays[SimulationMetric.IRR][idx] = internal_rate_of_return( + iteration_flows, + compounds_per_year=config.compounds_per_year, + ) + except (ValueError, ConvergenceError): + metric_arrays[SimulationMetric.IRR][idx] = np.nan + if SimulationMetric.PAYBACK in metric_arrays: + try: + metric_arrays[SimulationMetric.PAYBACK][idx] = payback_period( + iteration_flows, + compounds_per_year=config.compounds_per_year, + ) + except (ValueError, PaybackNotReachedError): + metric_arrays[SimulationMetric.PAYBACK][idx] = np.nan + + summaries = { + metric: _summarise(metric_arrays[metric], config.percentiles) + for metric in 
metric_arrays + } + + samples = metric_arrays if config.return_samples else None + result = SimulationResult( + iterations=config.iterations, + summaries=summaries, + samples=samples, + ) + + # Record successful simulation + duration = time.time() - start_time + observe_simulation( + status="success", + duration_seconds=duration, + ) + return result + + except Exception: + # Record failed simulation + duration = time.time() - start_time + observe_simulation( + status="error", + duration_seconds=duration, + ) + raise -def _sample_parameter(rng: Random, param: SimulationParameter) -> float: - if param.distribution == "normal": - assert param.std_dev is not None - return rng.normalvariate(param.base_value, param.std_dev) - if param.distribution == "uniform": - assert param.minimum is not None and param.maximum is not None - return rng.uniform(param.minimum, param.maximum) - # triangular - assert ( - param.minimum is not None - and param.maximum is not None - and param.mode is not None +def _realise_cash_flow( + spec: CashFlowSpec, + generator: Generator, + *, + scenario_context: Mapping[str, Any] | None, + metadata: Mapping[str, Any] | None, +) -> CashFlow: + if spec.distribution is None: + return spec.cash_flow + + distribution = spec.distribution + base_amount = spec.cash_flow.amount + params = _resolve_parameters( + distribution, + base_amount, + scenario_context=scenario_context, + metadata=metadata, + ) + sample = _sample_distribution( + distribution.type, + params, + generator, + ) + return CashFlow( + amount=float(sample), + period_index=spec.cash_flow.period_index, + date=spec.cash_flow.date, ) - return rng.triangular(param.minimum, param.maximum, param.mode) -def run_simulation( - parameters: Sequence[Dict[str, float]], - iterations: int = 1000, - seed: Optional[int] = None, -) -> List[Dict[str, float]]: - """Run a lightweight Monte Carlo simulation using configurable distributions.""" +def _resolve_parameters( + distribution: DistributionSpec, + base_amount: float, + *, + scenario_context: Mapping[str, Any] | None, + metadata: Mapping[str, Any] | None, +) -> Dict[str, Any]: + params = dict(distribution.parameters) - if iterations <= 0: - return [] + if distribution.source == DistributionSource.SCENARIO_FIELD: + if distribution.source_key is None: + raise DistributionConfigError( + "source_key is required for scenario_field sourcing") + if not scenario_context or distribution.source_key not in scenario_context: + raise DistributionConfigError( + f"scenario field '{distribution.source_key}' not found for distribution" + ) + params.setdefault("mean", float( + scenario_context[distribution.source_key])) + elif distribution.source == DistributionSource.METADATA_KEY: + if distribution.source_key is None: + raise DistributionConfigError( + "source_key is required for metadata_key sourcing") + if not metadata or distribution.source_key not in metadata: + raise DistributionConfigError( + f"metadata key '{distribution.source_key}' not found for distribution" + ) + params.setdefault("mean", float(metadata[distribution.source_key])) + else: + params.setdefault("mean", float(base_amount)) - compiled_params = _compile_parameters(parameters) - if not compiled_params: - return [] + return params - rng = Random(seed) - results: List[Dict[str, float]] = [] - for iteration in range(1, iterations + 1): - total = 0.0 - for param in compiled_params: - sample = _sample_parameter(rng, param) - total += sample - results.append({"iteration": iteration, "result": total}) - return results + +def 
_sample_distribution( + distribution_type: DistributionType, + params: Mapping[str, Any], + generator: Generator, +) -> float: + if distribution_type is DistributionType.NORMAL: + return _sample_normal(params, generator) + if distribution_type is DistributionType.LOGNORMAL: + return _sample_lognormal(params, generator) + if distribution_type is DistributionType.TRIANGULAR: + return _sample_triangular(params, generator) + if distribution_type is DistributionType.DISCRETE: + return _sample_discrete(params, generator) + raise DistributionConfigError( + f"Unsupported distribution type: {distribution_type}") + + +def _sample_normal(params: Mapping[str, Any], generator: Generator) -> float: + if "std_dev" not in params: + raise DistributionConfigError("normal distribution requires 'std_dev'") + std_dev = float(params["std_dev"]) + if std_dev < 0: + raise DistributionConfigError("std_dev must be non-negative") + mean = float(params.get("mean", 0.0)) + if std_dev == 0: + return mean + return float(generator.normal(loc=mean, scale=std_dev)) + + +def _sample_lognormal(params: Mapping[str, Any], generator: Generator) -> float: + if "sigma" not in params: + raise DistributionConfigError( + "lognormal distribution requires 'sigma'") + sigma = float(params["sigma"]) + if sigma < 0: + raise DistributionConfigError("sigma must be non-negative") + if "mean" not in params: + raise DistributionConfigError( + "lognormal distribution requires 'mean' (mu in log space)") + mean = float(params["mean"]) + return float(generator.lognormal(mean=mean, sigma=sigma)) + + +def _sample_triangular(params: Mapping[str, Any], generator: Generator) -> float: + required = {"min", "mode", "max"} + if not required.issubset(params): + missing = ", ".join(sorted(required - params.keys())) + raise DistributionConfigError( + f"triangular distribution missing parameters: {missing}") + left = float(params["min"]) + mode = float(params["mode"]) + right = float(params["max"]) + if not (left <= mode <= right): + raise DistributionConfigError( + "triangular distribution requires min <= mode <= max") + if left == right: + return mode + return float(generator.triangular(left=left, mode=mode, right=right)) + + +def _sample_discrete(params: Mapping[str, Any], generator: Generator) -> float: + values = params.get("values") + probabilities = params.get("probabilities") + if not isinstance(values, Sequence) or not isinstance(probabilities, Sequence): + raise DistributionConfigError( + "discrete distribution requires 'values' and 'probabilities' sequences") + if len(values) != len(probabilities) or not values: + raise DistributionConfigError( + "values and probabilities must be non-empty and of equal length") + probs = np.array(probabilities, dtype=float) + if np.any(probs < 0): + raise DistributionConfigError("probabilities must be non-negative") + total = probs.sum() + if not np.isclose(total, 1.0): + raise DistributionConfigError("probabilities must sum to 1.0") + probs = probs / total + choices = np.array(values, dtype=float) + return float(generator.choice(choices, p=probs)) + + +def _summarise(values: np.ndarray, percentiles: Sequence[float]) -> MetricSummary: + clean = values[~np.isnan(values)] + sample_size = clean.size + failed_runs = values.size - sample_size + + if sample_size == 0: + percentile_map: Dict[float, float] = { + pct: float("nan") for pct in percentiles} + return MetricSummary( + mean=float("nan"), + std_dev=float("nan"), + minimum=float("nan"), + maximum=float("nan"), + percentiles=percentile_map, + sample_size=0, + 
failed_runs=failed_runs, + ) + + percentile_map = { + pct: float(np.percentile(clean, pct)) for pct in percentiles + } + return MetricSummary( + mean=float(np.mean(clean)), + std_dev=float(np.std(clean, ddof=1)) if sample_size > 1 else 0.0, + minimum=float(np.min(clean)), + maximum=float(np.max(clean)), + percentiles=percentile_map, + sample_size=sample_size, + failed_runs=failed_runs, + ) diff --git a/services/unit_of_work.py b/services/unit_of_work.py new file mode 100644 index 0000000..f35a9e3 --- /dev/null +++ b/services/unit_of_work.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from contextlib import AbstractContextManager +from typing import Callable, Sequence + +from sqlalchemy.orm import Session + +from config.database import SessionLocal +from models import PricingSettings, Project, Role, Scenario +from services.pricing import PricingMetadata +from services.repositories import ( + FinancialInputRepository, + PricingSettingsRepository, + PricingSettingsSeedResult, + ProjectRepository, + ProjectProfitabilityRepository, + ProjectOpexRepository, + ProjectCapexRepository, + RoleRepository, + ScenarioRepository, + ScenarioProfitabilityRepository, + ScenarioOpexRepository, + ScenarioCapexRepository, + SimulationParameterRepository, + UserRepository, + ensure_admin_user as ensure_admin_user_record, + ensure_default_pricing_settings, + ensure_default_roles, + pricing_settings_to_metadata, + NavigationRepository, +) +from services.scenario_validation import ScenarioComparisonValidator + + +class UnitOfWork(AbstractContextManager["UnitOfWork"]): + """Simple unit-of-work wrapper around SQLAlchemy sessions.""" + + def __init__(self, session_factory: Callable[[], Session] = SessionLocal) -> None: + self._session_factory = session_factory + self.session: Session | None = None + self._scenario_validator: ScenarioComparisonValidator | None = None + self.projects: ProjectRepository | None = None + self.scenarios: ScenarioRepository | None = None + self.financial_inputs: FinancialInputRepository | None = None + self.simulation_parameters: SimulationParameterRepository | None = None + self.project_profitability: ProjectProfitabilityRepository | None = None + self.project_capex: ProjectCapexRepository | None = None + self.project_opex: ProjectOpexRepository | None = None + self.scenario_profitability: ScenarioProfitabilityRepository | None = None + self.scenario_capex: ScenarioCapexRepository | None = None + self.scenario_opex: ScenarioOpexRepository | None = None + self.users: UserRepository | None = None + self.roles: RoleRepository | None = None + self.pricing_settings: PricingSettingsRepository | None = None + self.navigation: NavigationRepository | None = None + + def __enter__(self) -> "UnitOfWork": + self.session = self._session_factory() + self.projects = ProjectRepository(self.session) + self.scenarios = ScenarioRepository(self.session) + self.financial_inputs = FinancialInputRepository(self.session) + self.simulation_parameters = SimulationParameterRepository( + self.session) + self.project_profitability = ProjectProfitabilityRepository( + self.session) + self.project_capex = ProjectCapexRepository(self.session) + self.project_opex = ProjectOpexRepository( + self.session) + self.scenario_profitability = ScenarioProfitabilityRepository( + self.session + ) + self.scenario_capex = ScenarioCapexRepository(self.session) + self.scenario_opex = ScenarioOpexRepository( + self.session) + self.users = UserRepository(self.session) + self.roles = RoleRepository(self.session) + 
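# Every repository below shares this unit of work's session, so their + # writes commit or roll back together in __exit__. +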
self.pricing_settings = PricingSettingsRepository(self.session) + self.navigation = NavigationRepository(self.session) + self._scenario_validator = ScenarioComparisonValidator() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + assert self.session is not None + if exc_type is None: + self.session.commit() + else: + self.session.rollback() + self.session.close() + self._scenario_validator = None + self.projects = None + self.scenarios = None + self.financial_inputs = None + self.simulation_parameters = None + self.project_profitability = None + self.project_capex = None + self.project_opex = None + self.scenario_profitability = None + self.scenario_capex = None + self.scenario_opex = None + self.users = None + self.roles = None + self.pricing_settings = None + self.navigation = None + + def flush(self) -> None: + if not self.session: + raise RuntimeError("UnitOfWork session is not initialised") + self.session.flush() + + def commit(self) -> None: + if not self.session: + raise RuntimeError("UnitOfWork session is not initialised") + self.session.commit() + + def rollback(self) -> None: + if not self.session: + raise RuntimeError("UnitOfWork session is not initialised") + self.session.rollback() + + def validate_scenarios_for_comparison( + self, scenario_ids: Sequence[int] + ) -> list[Scenario]: + if not self.session or not self._scenario_validator or not self.scenarios: + raise RuntimeError("UnitOfWork session is not initialised") + + scenarios = [self.scenarios.get(scenario_id) + for scenario_id in scenario_ids] + self._scenario_validator.validate(scenarios) + return scenarios + + def validate_scenario_models_for_comparison( + self, scenarios: Sequence[Scenario] + ) -> None: + if not self._scenario_validator: + raise RuntimeError("UnitOfWork session is not initialised") + self._scenario_validator.validate(scenarios) + + def ensure_default_roles(self) -> list[Role]: + if not self.roles: + raise RuntimeError("UnitOfWork session is not initialised") + return ensure_default_roles(self.roles) + + def ensure_admin_user( + self, + *, + email: str, + username: str, + password: str, + ) -> None: + if not self.users or not self.roles: + raise RuntimeError("UnitOfWork session is not initialised") + ensure_default_roles(self.roles) + ensure_admin_user_record( + self.users, + self.roles, + email=email, + username=username, + password=password, + ) + + def ensure_default_pricing_settings( + self, + *, + metadata: PricingMetadata, + slug: str = "default", + name: str | None = None, + description: str | None = None, + ) -> PricingSettingsSeedResult: + if not self.pricing_settings: + raise RuntimeError("UnitOfWork session is not initialised") + return ensure_default_pricing_settings( + self.pricing_settings, + metadata=metadata, + slug=slug, + name=name, + description=description, + ) + + def get_pricing_metadata( + self, + *, + slug: str = "default", + ) -> PricingMetadata | None: + if not self.pricing_settings: + raise RuntimeError("UnitOfWork session is not initialised") + settings = self.pricing_settings.find_by_slug( + slug, + include_children=True, + ) + if settings is None: + return None + return pricing_settings_to_metadata(settings) + + def set_project_pricing_settings( + self, + project: Project, + pricing_settings: PricingSettings | None, + ) -> Project: + if not self.projects: + raise RuntimeError("UnitOfWork session is not initialised") + return self.projects.set_pricing_settings(project, pricing_settings) diff --git a/static/css/dashboard.css 
b/static/css/dashboard.css new file mode 100644 index 0000000..aa8c2d3 --- /dev/null +++ b/static/css/dashboard.css @@ -0,0 +1,82 @@ +:root { + --dashboard-gap: 1.5rem; +} + +.dashboard-metrics { + display: grid; + gap: var(--dashboard-gap); + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + margin-bottom: 2rem; +} + +.dashboard-grid { + display: grid; + gap: var(--dashboard-gap); + grid-template-columns: 2fr 1fr; + align-items: start; +} + +.grid-main { + display: grid; + gap: var(--dashboard-gap); +} + +.grid-sidebar { + display: grid; + gap: var(--dashboard-gap); +} + +.timeline { + list-style: none; + margin: 0; + padding: 0; + display: flex; + flex-direction: column; + gap: 1rem; +} + +.timeline-label { + font-size: 0.85rem; + color: var(--color-text-subtle); + display: block; + margin-bottom: 0.35rem; +} + +.alerts-list, +.links-list { + list-style: none; + margin: 0; + padding: 0; + display: flex; + flex-direction: column; + gap: 0.75rem; +} + +.alerts-list li { + padding: 0.75rem; + border-radius: var(--radius-sm); + background: rgba(209, 75, 75, 0.16); + background: color-mix(in srgb, var(--color-danger) 16%, transparent); + border: 1px solid rgba(209, 75, 75, 0.3); + border: 1px solid color-mix(in srgb, var(--color-danger) 30%, transparent); +} + +.links-list a { + color: var(--brand-3); + text-decoration: none; +} + +.links-list a:hover, +.links-list a:focus { + text-decoration: underline; +} + +@media (max-width: 1024px) { + .dashboard-grid { + grid-template-columns: 1fr; + } + + .grid-sidebar { + grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); + } +} diff --git a/static/css/forms.css b/static/css/forms.css new file mode 100644 index 0000000..7fd67ab --- /dev/null +++ b/static/css/forms.css @@ -0,0 +1,111 @@ +.form { + display: flex; + flex-direction: column; + gap: 1.25rem; +} + +.form-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(240px, 1fr)); + gap: 1.25rem; +} + +.form-group { + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.form-group label { + font-weight: 600; + color: var(--text); + color: var(--color-text-primary); +} + +.form-group input, +.form-group select, +.form-group textarea { + padding: 0.75rem 0.85rem; + border-radius: var(--radius-sm); + border: 1px solid var(--card-border); + background: rgba(8, 12, 19, 0.78); + background: color-mix(in srgb, var(--color-bg-elevated) 78%, transparent); + color: var(--text); + color: var(--color-text-primary); + transition: border-color 0.15s ease, background 0.2s ease, + box-shadow 0.2s ease; +} + +.form-group textarea { + resize: vertical; + min-height: 120px; +} + +.form-group input:focus, +.form-group select:focus, +.form-group textarea:focus { + outline: 2px solid var(--brand-2); + outline: 2px solid var(--color-brand-bright); + outline-offset: 1px; +} + +.form-group input:disabled, +.form-group select:disabled, +.form-group textarea:disabled { + cursor: not-allowed; + opacity: 0.6; +} + +.form-group--error input, +.form-group--error select, +.form-group--error textarea { + border-color: rgba(209, 75, 75, 0.6); + border-color: color-mix(in srgb, var(--color-danger) 60%, transparent); + box-shadow: 0 0 0 1px rgba(209, 75, 75, 0.3); + box-shadow: 0 0 0 1px color-mix(in srgb, var(--color-danger) 30%, transparent); +} + +.field-help { + margin: 0; + font-size: 0.85rem; + color: var(--color-text-subtle); +} + +.field-error { + margin: 0; + font-size: 0.85rem; + color: var(--danger); + color: var(--color-danger); +} + +.form-actions { + display: flex; + 
flex-wrap: wrap; + gap: 0.75rem; + justify-content: flex-end; +} + +.form-fieldset { + border: 1px solid var(--color-border); + border-radius: var(--radius); + background: rgba(21, 27, 35, 0.85); + background: var(--color-surface-overlay); + box-shadow: var(--shadow); + padding: 1.5rem; + display: flex; + flex-direction: column; + gap: 1.25rem; +} + +.form-fieldset legend { + font-weight: 700; + padding: 0 0.5rem; + color: var(--text); + color: var(--color-text-primary); +} + +@media (max-width: 640px) { + .form-actions { + flex-direction: column; + } +} diff --git a/static/css/imports.css b/static/css/imports.css new file mode 100644 index 0000000..86d424a --- /dev/null +++ b/static/css/imports.css @@ -0,0 +1,80 @@ +.import-upload { + background-color: rgba(21, 27, 35, 0.85); + background-color: var(--color-surface-overlay); + border: 1px dashed var(--color-border); + border-radius: var(--radius); + padding: 1.5rem; + margin-bottom: 1.5rem; +} + +.import-upload__header { + margin-bottom: 1rem; +} + +.import-upload__dropzone { + border: 2px dashed var(--color-border); + border-radius: var(--radius-sm); + padding: 2rem; + text-align: center; + transition: border-color 0.2s ease, background-color 0.2s ease; +} + +.import-upload__dropzone.dragover { + border-color: #f6c648; + border-color: var(--color-brand-bright); + background-color: rgba(241, 178, 26, 0.08); + background-color: var(--color-highlight); +} + +.import-upload__actions { + display: flex; + gap: 0.75rem; + margin-top: 1rem; +} + +.table-cell-actions { + display: flex; + align-items: center; + gap: 0.5rem; +} + +.toast { + position: fixed; + right: 1rem; + bottom: 1rem; + display: flex; + align-items: center; + gap: 0.75rem; + padding: 1rem 1.25rem; + border-radius: var(--radius); + color: var(--color-text-invert); + box-shadow: var(--shadow); + z-index: 1000; +} + +.toast.hidden { + display: none; +} + +.toast--success { + background-color: var(--success); + background-color: var(--color-success); +} + +.toast--error { + background-color: var(--danger); + background-color: var(--color-danger); +} + +.toast--info { + background-color: var(--info); + background-color: var(--color-info); +} + +.toast__close { + background: none; + border: none; + color: inherit; + cursor: pointer; + font-size: 1.1rem; +} diff --git a/static/css/main.css b/static/css/main.css index 3b22daa..1834e47 100644 --- a/static/css/main.css +++ b/static/css/main.css @@ -1,29 +1,12 @@ :root { - --bg: #0b0f14; - --bg-2: #0f141b; - --card: #151b23; - --text: #e6edf3; - --muted: #a9b4c0; - --brand: #f1b21a; - --brand-2: #f6c648; - --brand-3: #f9d475; - --accent: #2ba58f; - --danger: #d14b4b; - --shadow: 0 10px 30px rgba(0, 0, 0, 0.35); + /* Radii & layout */ --radius: 14px; --radius-sm: 10px; + --panel-radius: var(--radius); + --table-radius: var(--radius-sm); --container: 1180px; - --muted: var(--muted); - --color-text-subtle: rgba(169, 180, 192, 0.6); - --color-text-invert: #ffffff; - --color-text-dark: #0f172a; - --color-text-strong: #111827; - --color-border: rgba(255, 255, 255, 0.08); - --color-border-strong: rgba(255, 255, 255, 0.12); - --color-highlight: rgba(241, 178, 26, 0.08); - --color-panel-shadow: rgba(0, 0, 0, 0.25); - --color-panel-shadow-deep: rgba(0, 0, 0, 0.35); - --color-surface-alt: rgba(21, 27, 35, 0.7); + + /* Spacing & typography */ --space-2xs: 0.25rem; --space-xs: 0.5rem; --space-sm: 0.75rem; @@ -31,18 +14,13 @@ --space-lg: 1.5rem; --space-xl: 2rem; --space-2xl: 3rem; + --font-size-xs: 0.75rem; --font-size-sm: 0.875rem;
--font-size-base: 1rem; --font-size-lg: 1.25rem; --font-size-xl: 1.5rem; --font-size-2xl: 2rem; - --panel-radius: var(--radius); - --table-radius: var(--radius-sm); -} - -* { - box-sizing: border-box; } html, @@ -52,17 +30,522 @@ body { body { margin: 0; - font-family: ui-sans-serif, system-ui, -apple-system, 'Segoe UI', 'Roboto', - Helvetica, Arial, 'Apple Color Emoji', 'Segoe UI Emoji'; + font-family: ui-sans-serif, system-ui, -apple-system, "Segoe UI", "Roboto", + Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji"; color: var(--text); background: linear-gradient(180deg, var(--bg) 0%, var(--bg-2) 100%); line-height: 1.45; } +.header-actions { + display: flex; + gap: 0.75rem; + flex-wrap: wrap; + justify-content: flex-end; +} + +h1, +h2, +h3, +h4, +h5, +h6 { + margin: 0 0 0.5rem 0; + font-weight: 700; + line-height: 1.2; +} + +h1 { + font-size: var(--font-size-2xl); +} + +h2 { + font-size: var(--font-size-xl); +} + +h3 { + font-size: var(--font-size-lg); +} + +p { + margin: 0 0 1rem 0; +} + a { color: var(--brand); } +.report-overview { + margin-bottom: 2.5rem; +} + +.report-grid { + display: grid; + gap: 1.5rem; + grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); +} + +.report-card { + background: var(--card); + border-radius: var(--radius); + padding: 1.5rem; + border: 1px solid var(--color-border); + box-shadow: 0 12px 30px rgba(4, 7, 14, 0.35); +} + +.report-card h2 { + margin-top: 0; + margin-bottom: 1rem; +} + +.report-section + .report-section { + margin-top: 3rem; +} + +.chart-container { + width: 100%; + height: 400px; + background: rgba(15, 20, 27, 0.8); + border-radius: var(--radius-sm); + border: 1px solid rgba(255, 255, 255, 0.05); + box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.06); + margin-bottom: 1rem; +} + +.section-header { + margin-bottom: 1.25rem; +} + +.section-header h2 { + margin: 0; +} + +.section-subtitle { + margin: 0.35rem 0 0; + color: var(--muted); +} + +.metric-list { + list-style: none; + margin: 0; + padding: 0; + display: flex; + flex-direction: column; + gap: 0.75rem; +} + +.metric-list.compact { + gap: 0.35rem; +} + +.metric-list li { + display: flex; + justify-content: space-between; + align-items: baseline; + font-size: 0.95rem; + color: var(--muted); +} + +.metric-list strong { + font-size: 1.05rem; + color: var(--text); +} + +.metric-card { + background: var(--color-surface-overlay); + border-radius: var(--radius); + padding: 1.5rem; + box-shadow: var(--shadow); + border: 1px solid var(--color-border); + display: flex; + flex-direction: column; + gap: 0.35rem; +} + +.metric-card h2 { + margin: 0; + font-size: 1rem; + color: var(--color-text-muted); + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.metric-value { + font-size: 2rem; + font-weight: 700; + margin: 0; +} + +.metric-caption { + color: var(--color-text-subtle); + font-size: 0.85rem; +} + +.metrics-table { + width: 100%; + border-collapse: collapse; + background: rgba(21, 27, 35, 0.6); + border-radius: var(--radius-sm); + overflow: hidden; +} + +.metrics-table th, +.metrics-table td { + padding: 0.65rem 0.9rem; + text-align: left; + border-bottom: 1px solid rgba(255, 255, 255, 0.08); +} + +.metrics-table th { + font-weight: 600; + color: var(--color-text-dark); +} + +.metrics-table tr:last-child td, +.metrics-table tr:last-child th { + border-bottom: none; +} + +.definition-list { + margin: 0; + display: grid; + gap: 1.25rem 2rem; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); +} + +.definition-list div { + display: grid; + 
grid-template-columns: minmax(140px, 0.6fr) minmax(0, 1fr); + gap: 0.5rem; + align-items: baseline; +} + +.definition-list dt { + margin: 0; + font-weight: 600; + color: var(--color-text-muted); + text-transform: uppercase; + font-size: 0.75rem; + letter-spacing: 0.08em; +} + +.definition-list dd { + margin: 0; + font-size: 1rem; + color: var(--color-text-primary); +} + +.scenario-card { + background: var(--card); + border-radius: var(--radius); + padding: 1.5rem; + border: 1px solid var(--color-border); + box-shadow: 0 16px 32px rgba(4, 7, 14, 0.42); + display: flex; + flex-direction: column; + gap: 1.25rem; +} + +.scenario-card + .scenario-card { + margin-top: 1.75rem; +} + +.scenario-card-header { + display: flex; + justify-content: space-between; + gap: 1rem; + align-items: flex-start; +} + +.scenario-card h3 { + margin: 0; +} + +.scenario-meta { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 1.25rem; +} + +.scenario-card .scenario-meta { + display: block; + text-align: right; +} + +.meta-label { + display: block; + color: var(--muted); + font-size: 0.8rem; + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.meta-value { + font-weight: 600; +} + +.scenario-grid { + display: grid; + gap: 1.25rem; + grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); +} + +.scenario-panel { + background: rgba(15, 20, 27, 0.8); + border-radius: var(--radius-sm); + padding: 1.1rem; + border: 1px solid rgba(255, 255, 255, 0.05); + box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.06); +} + +.scenario-panel h4, +.scenario-panel h5 { + margin-top: 0; + margin-bottom: 0.75rem; +} + +.note-list { + padding-left: 1.1rem; + color: var(--muted); + font-size: 0.9rem; +} + +.muted { + color: var(--muted); +} + +.quick-link-list { + list-style: none; + margin: 0; + padding: 0; + display: flex; + flex-direction: column; + gap: 1rem; +} + +.quick-link-list li a { + font-weight: 600; + color: var(--brand-2); + text-decoration: none; +} + +.quick-link-list li a:hover, +.quick-link-list li a:focus { + text-decoration: underline; +} + +.quick-link-list p { + margin: 0.25rem 0 0; + color: var(--color-text-subtle); + font-size: 0.9rem; +} + +.scenario-list { + list-style: none; + margin: 0; + padding: 0; + display: flex; + flex-direction: column; + gap: 1rem; +} + +.scenario-item { + background: rgba(21, 27, 35, 0.85); + background: color-mix(in srgb, var(--color-surface-default) 85%, transparent); + border: 1px solid var(--color-border); + border-radius: var(--radius); + padding: 1.25rem; + display: flex; + flex-direction: column; + gap: 1rem; +} + +.scenario-item__body { + display: flex; + flex-direction: column; + gap: 1rem; +} + +.scenario-item__header { + display: flex; + flex-wrap: wrap; + align-items: center; + gap: 0.75rem; + justify-content: space-between; +} + +.scenario-item__header h3 { + margin: 0; + font-size: 1.1rem; +} + +.scenario-item__header a { + color: inherit; + text-decoration: none; +} + +.scenario-item__header a:hover, +.scenario-item__header a:focus { + text-decoration: underline; +} + +.scenario-item__meta { + display: grid; + gap: 0.75rem; + grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); +} + +.scenario-item__meta dt { + margin: 0; + font-size: 0.75rem; + color: var(--color-text-muted); + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.scenario-item__meta dd { + margin: 0; + font-size: 0.95rem; +} + +.scenario-item__actions { + display: flex; + gap: 0.75rem; + flex-wrap: wrap; +} + +.scenario-item__actions 
.btn--link { + padding: 0; +} + +.status-pill { + display: inline-flex; + align-items: center; + gap: 0.35rem; + padding: 0.35rem 0.85rem; + border-radius: 999px; + font-size: 0.75rem; + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.status-pill--draft { + background: rgba(59, 130, 246, 0.15); + color: #93c5fd; + background: color-mix(in srgb, var(--color-info) 18%, transparent); + color: color-mix(in srgb, var(--color-info) 70%, white); +} + +.status-pill--active { + background: rgba(34, 197, 94, 0.18); + color: #86efac; + background: color-mix(in srgb, var(--color-success) 18%, transparent); + color: color-mix(in srgb, var(--color-success) 70%, white); +} + +.status-pill--archived { + background: rgba(148, 163, 184, 0.24); + color: #cbd5f5; + background: color-mix(in srgb, var(--color-text-muted) 24%, transparent); + color: color-mix(in srgb, var(--color-text-muted) 60%, white); +} + +.empty-state { + color: var(--color-text-muted); + font-style: italic; +} + +.table { + width: 100%; + border-collapse: collapse; + border-radius: var(--table-radius); + overflow: hidden; + box-shadow: var(--shadow); +} + +.table th, +.table td { + padding: 0.75rem 1rem; + border-bottom: 1px solid var(--color-border); + background: rgba(21, 27, 35, 0.85); + background: color-mix(in srgb, var(--color-surface-default) 85%, transparent); +} + +.table tbody tr:hover { + background: rgba(241, 178, 26, 0.12); + background: var(--color-highlight); +} + +.table-link { + color: var(--brand-2); + text-decoration: none; + margin-left: 0.5rem; +} + +.table-link:hover, +.table-link:focus { + text-decoration: underline; +} + +.table-responsive { + width: 100%; + overflow-x: auto; + -webkit-overflow-scrolling: touch; + border-radius: var(--table-radius); + margin: 0; +} + +.table-responsive .table { + min-width: 640px; +} + +.table-responsive::-webkit-scrollbar { + height: 6px; +} + +.table-responsive::-webkit-scrollbar-thumb { + background: rgba(255, 255, 255, 0.2); + background: color-mix(in srgb, var(--color-text-invert) 20%, transparent); + border-radius: 999px; +} + +.page-actions .button { + text-decoration: none; + background: transparent; + border: 1px solid var(--color-border); + padding: 0.6rem 1rem; + border-radius: var(--radius-sm); + color: var(--text); + font-weight: 600; + transition: background 0.2s ease, border-color 0.2s ease; +} + +.page-actions .button:hover, +.page-actions .button:focus { + background: rgba(241, 178, 26, 0.14); + border-color: var(--brand); +} + +.breadcrumb { + display: flex; + align-items: center; + gap: 0.5rem; + font-size: 0.9rem; + color: var(--muted); + margin-bottom: 1.2rem; +} + +.breadcrumb a { + color: var(--brand-2); + text-decoration: none; +} + +.breadcrumb a::after { + content: ">"; + margin-left: 0.5rem; + color: var(--muted); +} + .app-layout { display: flex; min-height: 100vh; @@ -93,20 +576,58 @@ a { display: flex; align-items: center; gap: 1rem; + padding: 0.5rem 1rem; + border-radius: 0.75rem; +} +a.sidebar-brand { + text-decoration: none; +} +a.sidebar-brand:hover, +a.sidebar-brand:focus { + color: var(--color-text-invert); + background-color: rgba(148, 197, 255, 0.18); +} + +.sidebar-nav-controls { + display: flex; + justify-content: center; + gap: 1rem; + margin: 0; +} + +.nav-chevron { + width: 5rem; + height: 5rem; + border: none; + background: rgba(0, 0, 0, 0.5); + color: rgba(255, 255, 255, 0.88); + font-size: 4.5rem; + font-weight: bold; + cursor: pointer; + display: flex; + align-items: center; + justify-content: center; + transition: 
background 0.2s ease, transform 0.2s ease; +} + +.nav-chevron:hover, +.nav-chevron:focus { + background: rgba(0, 0, 0, 0.1); + color: rgba(255, 255, 255, 1); + transform: scale(0.9); +} + +.nav-chevron:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; } .brand-logo { - display: inline-flex; - align-items: center; - justify-content: center; width: 44px; height: 44px; border-radius: 12px; - background: linear-gradient(0deg, var(--brand-3), var(--accent)); - color: var(--color-text-invert); - font-weight: 700; - font-size: 1.1rem; - letter-spacing: 1px; + object-fit: cover; } .brand-text { @@ -248,7 +769,7 @@ a { .dashboard-header { display: flex; - align-items: flex-start; + align-items: center; justify-content: space-between; gap: 1.5rem; margin-bottom: 2rem; @@ -337,7 +858,7 @@ a { gap: var(--space-sm); font-weight: 600; color: var(--text); - font-family: 'Fira Code', 'Consolas', 'Courier New', monospace; + font-family: "Fira Code", "Consolas", "Courier New", monospace; font-size: 0.85rem; } @@ -366,7 +887,7 @@ a { } .color-value-input { - font-family: 'Fira Code', 'Consolas', 'Courier New', monospace; + font-family: "Fira Code", "Consolas", "Courier New", monospace; } .color-value-input[disabled] { @@ -395,7 +916,7 @@ a { } .env-overrides-table code { - font-family: 'Fira Code', 'Consolas', 'Courier New', monospace; + font-family: "Fira Code", "Consolas", "Courier New", monospace; font-size: 0.85rem; } @@ -550,7 +1071,7 @@ a { } .btn.is-loading::after { - content: ''; + content: ""; width: 0.85rem; height: 0.85rem; border: 2px solid rgba(255, 255, 255, 0.6); @@ -590,36 +1111,6 @@ a { font-size: var(--font-size-lg); } -.form-grid { - display: grid; - gap: var(--space-md); - max-width: 480px; -} - -.form-grid label { - display: flex; - flex-direction: column; - gap: var(--space-sm); - font-weight: 600; - color: var(--text); -} - -.form-grid input, -.form-grid textarea, -.form-grid select { - padding: 0.6rem var(--space-sm); - border: 1px solid var(--color-border-strong); - border-radius: 8px; - font-size: var(--font-size-base); -} - -.form-grid input:focus, -.form-grid textarea:focus, -.form-grid select:focus { - outline: 2px solid var(--brand-2); - outline-offset: 1px; -} - .btn { display: inline-flex; align-items: center; @@ -627,28 +1118,101 @@ a { gap: 0.5rem; padding: 0.65rem 1.25rem; border-radius: 999px; - border: none; + border: 1px solid var(--btn-secondary-border); cursor: pointer; font-weight: 600; - background-color: var(--color-border); - color: var(--color-text-dark); - transition: transform 0.15s ease, box-shadow 0.15s ease; + background-color: var(--btn-secondary-bg); + color: var(--btn-secondary-color); + text-decoration: none; + transition: transform 0.15s ease, box-shadow 0.15s ease, + background-color 0.2s ease, border-color 0.2s ease; } .btn:hover, .btn:focus { transform: translateY(-1px); box-shadow: 0 4px 10px var(--color-panel-shadow); + background-color: var(--btn-secondary-hover); } -.btn.primary { - background-color: var(--brand-2); - color: var(--color-text-invert); +.btn--primary, +.btn.primary, +.btn.btn-primary { + background-color: var(--btn-primary-bg); + border-color: transparent; + color: var(--btn-primary-color); } +.btn--primary:hover, +.btn--primary:focus, .btn.primary:hover, -.btn.primary:focus { - background-color: var(--brand-3); +.btn.primary:focus, +.btn.btn-primary:hover, +.btn.btn-primary:focus { + background-color: var(--btn-primary-hover); +} + +.btn--secondary, +.btn.secondary, +.btn.btn-secondary { + background-color: 
var(--btn-secondary-bg); + border-color: var(--btn-secondary-border); + color: var(--btn-secondary-color); +} + +.btn--secondary:hover, +.btn--secondary:focus, +.btn.secondary:hover, +.btn.secondary:focus, +.btn.btn-secondary:hover, +.btn.btn-secondary:focus { + background-color: var(--btn-secondary-hover); +} + +.btn--link, +.btn.btn-link, +.btn.link { + padding: 0.25rem 0; + border: none; + background: transparent; + color: var(--btn-link-color); + margin: 0; + box-shadow: none; +} + +.btn--link:hover, +.btn--link:focus, +.btn.btn-link:hover, +.btn.btn-link:focus, +.btn.link:hover, +.btn.link:focus { + transform: none; + box-shadow: none; + color: var(--btn-link-hover); + text-decoration: underline; +} + +.btn--ghost { + background: transparent; + border: 1px solid transparent; + color: var(--btn-ghost-color); +} + +.btn--ghost:hover, +.btn--ghost:focus { + background: rgba(255, 255, 255, 0.1); + border-color: rgba(255, 255, 255, 0.2); +} + +.btn--icon { + padding: 0.4rem; + border-radius: 50%; + line-height: 0; +} + +.btn--icon:hover, +.btn--icon:focus { + transform: none; } .result-output { @@ -656,14 +1220,14 @@ a { color: var(--color-surface-alt); padding: 1rem; border-radius: 8px; - font-family: 'Fira Code', 'Consolas', 'Courier New', monospace; + font-family: "Fira Code", "Consolas", "Courier New", monospace; overflow-x: auto; margin-top: 1.5rem; } .monospace-input { width: 100%; - font-family: 'Fira Code', 'Consolas', 'Courier New', monospace; + font-family: "Fira Code", "Consolas", "Courier New", monospace; min-height: 120px; } @@ -726,9 +1290,27 @@ tbody tr:nth-child(even) { color: var(--danger); } +.alert { + padding: 0.75rem 1rem; + border-radius: var(--radius-sm); + margin-bottom: 1rem; +} + +.alert-error { + background: rgba(209, 75, 75, 0.2); + border: 1px solid rgba(209, 75, 75, 0.4); + color: var(--color-text-invert); +} + +.alert-info { + background: rgba(43, 168, 143, 0.2); + border: 1px solid rgba(43, 168, 143, 0.4); + color: var(--color-text-invert); +} + .site-footer { background-color: var(--brand); - color: var(--color-text-invert); + color: var(--color-text-strong); margin-top: 3rem; } @@ -738,12 +1320,156 @@ tbody tr:nth-child(even) { justify-content: center; padding: 1rem 0; font-size: 0.9rem; + gap: 1rem; +} + +.footer-logo { + display: flex; + align-items: center; +} + +.footer-logo-img { + width: 32px; + height: 32px; + border-radius: 8px; + object-fit: cover; +} + +footer p { + margin: 0; +} +footer a { + font-weight: 600; + color: var(--color-text-dark); + text-decoration: underline; +} +footer a:hover, +footer a:focus { + color: var(--color-text-strong); +} + +.sidebar-toggle { + display: none; + align-items: center; + gap: 0.6rem; + padding: 0.55rem 1rem; + border-radius: 999px; + border: none; + background: linear-gradient(135deg, var(--brand-2), var(--brand)); + color: var(--color-text-dark); + font-weight: 600; + cursor: pointer; + box-shadow: 0 6px 16px rgba(0, 0, 0, 0.25); + transition: transform 0.2s ease, box-shadow 0.2s ease; +} + +.sidebar-toggle:hover, +.sidebar-toggle:focus-visible { + transform: translateY(-1px); + box-shadow: 0 8px 20px rgba(0, 0, 0, 0.3); +} + +.sidebar-toggle:focus-visible { + outline: 2px solid rgba(255, 255, 255, 0.65); + outline-offset: 3px; +} + +.sidebar-toggle-icon { + position: relative; + display: inline-block; + width: 18px; + height: 2px; + background-color: currentColor; +} + +.sidebar-toggle-icon::before, +.sidebar-toggle-icon::after { + content: ""; + position: absolute; + left: 0; + width: 18px; + height: 2px; 
+ background-color: currentColor; +} + +.sidebar-toggle-icon::before { + top: -6px; +} + +.sidebar-toggle-icon::after { + top: 6px; +} + +.sidebar-toggle-label { + font-size: 0.95rem; +} + +.sidebar-overlay { + position: fixed; + inset: 0; + background: rgba(7, 11, 17, 0.6); + z-index: 800; + opacity: 0; + pointer-events: none; + transition: opacity 0.25s ease; +} + +@media (min-width: 720px) { + .table-responsive .table { + min-width: 100%; + } +} + +@media (max-width: 640px) { + .table th, + .table td { + padding: 0.55rem 0.65rem; + font-size: 0.9rem; + white-space: nowrap; + } + + .table tbody tr { + border-radius: var(--radius-sm); + } + + .metric-card { + padding: 1.25rem; + } + + .metric-value { + font-size: 1.75rem; + } + + .header-actions { + flex-direction: column; + align-items: stretch; + } +} + +@media (min-width: 960px) { + .header-actions { + justify-content: flex-start; + } + + .scenario-item { + flex-direction: row; + justify-content: space-between; + align-items: center; + } + + .scenario-item__body { + max-width: 70%; + } } @media (max-width: 1024px) { .app-sidebar { width: 240px; } + + .header-actions { + justify-content: flex-start; + } } @media (max-width: 900px) { @@ -773,8 +1499,16 @@ tbody tr:nth-child(even) { justify-content: center; } + .sidebar-nav-controls { + display: none; + } + + .sidebar-link-block { + align-items: center; + } + .sidebar-link { - flex: 1 1 140px; + flex: 1 1 40px; justify-content: center; } @@ -790,4 +1524,38 @@ tbody tr:nth-child(even) { .dashboard-columns { grid-template-columns: 1fr; } + + .sidebar-toggle { + display: inline-flex; + margin: 1rem auto 1.5rem; + } + + body.sidebar-collapsed .app-sidebar { + display: none; + } + + body.sidebar-open { + overflow: hidden; + } + + body.sidebar-open .app-main { + position: relative; + z-index: 1; + } + body.sidebar-open .app-sidebar { + display: block; + position: fixed; + top: 0; + left: 0; + width: min(320px, 82vw); + height: 100vh; + overflow-y: auto; + z-index: 999; + box-shadow: 0 12px 30px rgba(8, 14, 25, 0.4); + } + + body.sidebar-open .sidebar-overlay { + opacity: 1; + pointer-events: auto; + } } diff --git a/static/css/projects.css b/static/css/projects.css new file mode 100644 index 0000000..c9cb895 --- /dev/null +++ b/static/css/projects.css @@ -0,0 +1,183 @@ +.projects-grid { + display: grid; + gap: 1.5rem; + grid-template-columns: repeat(auto-fit, minmax(320px, 1fr)); + margin-top: 1.5rem; +} + +.project-card { + background: var(--color-surface-overlay); + border: 1px solid var(--color-border); + box-shadow: var(--shadow); + border-radius: var(--radius); + padding: 1.5rem; + display: flex; + flex-direction: column; + gap: 1.25rem; + transition: transform 0.2s ease, box-shadow 0.2s ease; +} + +.project-card:hover, +.project-card:focus-within { + transform: translateY(-2px); + box-shadow: 0 22px 45px var(--color-panel-shadow-deep); +} + +.project-card__header { + display: flex; + align-items: baseline; + justify-content: space-between; + gap: 1rem; +} + +.project-card__title { + margin: 0; + font-size: 1.25rem; +} + +.project-card__title a { + color: var(--brand); + text-decoration: none; +} + +.project-card__title a:hover, +.project-card__title a:focus { + text-decoration: underline; +} + +.project-card__type { + font-size: 0.75rem; + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.project-card__description { + margin: 0; + color: var(--color-text-subtle); + min-height: 3rem; +} + +.project-card__meta { + display: grid; + gap: 1rem; + grid-template-columns: 
repeat(auto-fit, minmax(140px, 1fr)); +} + +.project-card__meta div { + display: flex; + flex-direction: column; + gap: 0.35rem; +} + +.project-card__meta dt { + font-size: 0.75rem; + text-transform: uppercase; + color: var(--color-text-muted); + letter-spacing: 0.08em; +} + +.project-card__meta dd { + margin: 0; + font-size: 0.95rem; +} + +.project-card__footer { + display: flex; + align-items: center; + justify-content: space-between; + gap: 1rem; + flex-wrap: wrap; +} + +.project-card__links { + display: flex; + gap: 0.75rem; + flex-wrap: wrap; +} + +.project-card__links .btn--link { + padding: 3px 4px; + border-radius: 8px; +} + +.project-metrics { + display: grid; + gap: 1.5rem; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + margin-bottom: 2rem; +} + +.project-form { + background: var(--color-surface-overlay); + border: 1px solid var(--color-border); + border-radius: var(--radius); + box-shadow: var(--shadow); + padding: 1.75rem; + display: flex; + flex-direction: column; + gap: 1.5rem; +} + +.card { + background: var(--color-surface-overlay); + border: 1px solid var(--color-border); + box-shadow: var(--shadow); + border-radius: var(--radius); + padding: 1.5rem; + margin-bottom: 2rem; +} + +.project-column { + display: grid; + gap: 1.5rem; +} + +.project-actions-card { + display: flex; + flex-direction: column; + gap: 1rem; +} + +.project-scenarios-card { + display: flex; + flex-direction: column; + gap: 1.5rem; +} + +.project-scenarios-card__header { + display: flex; + flex-wrap: wrap; + justify-content: space-between; + gap: 1rem; +} + +.project-scenarios-card__header h2 { + margin: 0; +} + +.card-header { + display: flex; + align-items: center; + justify-content: space-between; + margin-bottom: 1rem; +} + +.card-header h2 { + margin: 0; +} + +.project-layout { + display: grid; + gap: 1.5rem; +} + +.text-right { + text-align: right; +} + +@media (min-width: 960px) { + .project-layout { + grid-template-columns: 1.1fr 1.9fr; + align-items: start; + } +} diff --git a/static/css/scenarios.css b/static/css/scenarios.css new file mode 100644 index 0000000..e286409 --- /dev/null +++ b/static/css/scenarios.css @@ -0,0 +1,154 @@ +.scenario-metrics { + display: grid; + gap: 1.5rem; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + margin-bottom: 2rem; +} + +.scenario-filters { + display: grid; + gap: 0.75rem; + margin-bottom: 1.5rem; +} + +.scenario-filters .filter-field { + display: flex; + flex-direction: column; + gap: 0.35rem; +} + +.scenario-filters .filter-actions { + display: flex; + gap: 0.5rem; + flex-wrap: wrap; + align-items: center; +} + +.scenario-filters input, +.scenario-filters select { + width: 100%; + padding: 0.6rem 0.75rem; + border-radius: var(--radius-sm); + border: 1px solid var(--color-border); + background: rgba(8, 12, 19, 0.75); + background: color-mix(in srgb, var(--color-bg-elevated) 75%, transparent); + color: var(--color-text-primary); +} + +.scenario-form { + background: rgba(21, 27, 35, 0.85); + background: var(--color-surface-overlay); + border: 1px solid var(--color-border); + border-radius: var(--radius); + box-shadow: var(--shadow); + padding: 1.75rem; + display: flex; + flex-direction: column; + gap: 1.5rem; +} + +.scenario-form .card { + background: rgba(21, 27, 35, 0.9); + background: color-mix(in srgb, var(--color-surface-default) 90%, transparent); + border: 1px solid var(--color-border); + border-radius: var(--radius); + padding: 1.5rem; + display: flex; + flex-direction: column; + gap: 1.25rem; +} + +.scenario-form 
.card h2 { + margin: 0; +} + +.scenario-layout { + display: grid; + gap: 1.5rem; +} + +.scenario-column { + display: grid; + gap: 1.5rem; +} + +.quick-actions-card { + display: flex; + flex-direction: column; + gap: 1rem; +} + +.scenario-portfolio { + display: flex; + flex-direction: column; + gap: 1.5rem; +} + +.scenario-portfolio__header { + display: flex; + flex-wrap: wrap; + justify-content: space-between; + gap: 1rem; +} + +.scenario-context-card { + display: flex; + flex-direction: column; + gap: 1rem; +} + +.scenario-context-card .definition-list { + margin: 0; +} + +.scenario-defaults { + list-style: none; + margin: 0; + padding: 0; + display: grid; + gap: 0.75rem; +} + +.scenario-defaults li { + display: flex; + flex-direction: column; + gap: 0.25rem; +} + +.scenario-defaults li strong { + font-size: 0.9rem; + letter-spacing: 0.04em; + text-transform: uppercase; + color: var(--color-text-muted); +} + +.scenario-layout .table tbody tr:hover, +.scenario-portfolio .table tbody tr:hover { + background: rgba(43, 165, 143, 0.12); + background: color-mix(in srgb, var(--color-accent) 18%, transparent); +} + +@media (min-width: 720px) { + .scenario-filters { + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + align-items: end; + } + + .scenario-filters .filter-actions { + justify-content: flex-end; + } +} + +@media (max-width: 640px) { + .breadcrumb { + flex-wrap: wrap; + gap: 0.35rem; + } +} + +@media (min-width: 960px) { + .scenario-layout { + grid-template-columns: 1.1fr 1.9fr; + align-items: start; + } +} diff --git a/static/css/theme-default.css b/static/css/theme-default.css new file mode 100644 index 0000000..66ce055 --- /dev/null +++ b/static/css/theme-default.css @@ -0,0 +1,72 @@ +:root { + /* Neutral surfaces */ + --color-bg-base: #0b0f14; + --color-bg-elevated: #0f141b; + --color-surface-default: #151b23; + --color-surface-overlay: rgba(21, 27, 35, 0.7); + + --color-border-subtle: rgba(255, 255, 255, 0.08); + --color-border-card: rgba(255, 255, 255, 0.08); + --color-border-strong: rgba(255, 255, 255, 0.12); + --color-highlight: rgba(241, 178, 26, 0.08); + + /* Text */ + --color-text-primary: #e6edf3; + --color-text-muted: #a9b4c0; + --color-text-subtle: rgba(169, 180, 192, 0.6); + --color-text-invert: #ffffff; + --color-text-dark: #0f172a; + --color-text-strong: #111827; + + /* Brand & accent */ + --color-brand-base: #f1b21a; + --color-brand-bright: #f6c648; + --color-brand-soft: #f9d475; + --color-accent: #2ba58f; + + /* Semantic states */ + --color-success: #0c864d; + --color-info: #0b3d88; + --color-warning: #f59e0b; + --color-danger: #7a1721; + + /* Shadows & depth */ + --shadow: 0 10px 30px rgba(0, 0, 0, 0.35); + --color-panel-shadow: rgba(0, 0, 0, 0.25); + --color-panel-shadow-deep: rgba(0, 0, 0, 0.35); + + /* Buttons */ + --btn-primary-bg: var(--color-brand-bright); + --btn-primary-color: var(--color-text-dark); + --btn-primary-hover: var(--color-brand-soft); + + --btn-secondary-bg: rgba(21, 27, 35, 0.85); + --btn-secondary-hover: rgba(21, 27, 35, 0.95); + --btn-secondary-border: var(--color-border-strong); + --btn-secondary-color: var(--color-text-primary); + + --btn-danger-bg: var(--color-danger); + --btn-danger-color: var(--color-text-invert); + --btn-danger-hover: #a21d2b; + + --btn-link-color: var(--color-brand-bright); + --btn-link-hover: var(--color-brand-soft); + --btn-ghost-color: var(--color-text-muted); + + /* Legacy aliases */ + --bg: var(--color-bg-base); + --bg-2: var(--color-bg-elevated); + --card: var(--color-surface-default); + --text: 
var(--color-text-primary); + --muted: var(--color-text-muted); + --brand: var(--color-brand-base); + --brand-2: var(--color-brand-bright); + --brand-3: var(--color-brand-soft); + --accent: var(--color-accent); + --success: var(--color-success); + --danger: var(--color-danger); + --info: var(--color-info); + --color-border: var(--color-border-subtle); + --card-border: var(--color-border-card); + --color-surface-alt: var(--color-surface-overlay); +} diff --git a/static/favicon.ico b/static/favicon.ico new file mode 100644 index 0000000..c3eedc9 Binary files /dev/null and b/static/favicon.ico differ diff --git a/static/img/logo.png b/static/img/logo.png new file mode 100644 index 0000000..5cbc767 Binary files /dev/null and b/static/img/logo.png differ diff --git a/static/img/logo_128x128.png b/static/img/logo_128x128.png new file mode 100644 index 0000000..fe2c4ba Binary files /dev/null and b/static/img/logo_128x128.png differ diff --git a/static/img/logo_big.png b/static/img/logo_big.png new file mode 100644 index 0000000..d7c226b Binary files /dev/null and b/static/img/logo_big.png differ diff --git a/static/js/alerts.js b/static/js/alerts.js new file mode 100644 index 0000000..4336d9d --- /dev/null +++ b/static/js/alerts.js @@ -0,0 +1,12 @@ +// Dismiss a toast notification when its [data-toast-close] button is clicked. +document.addEventListener("DOMContentLoaded", () => { + document.querySelectorAll("[data-toast-close]").forEach((button) => { + button.addEventListener("click", () => { + const toast = button.closest(".toast"); + if (toast) { + toast.classList.add("hidden"); + setTimeout(() => toast.remove(), 200); + } + }); + }); +}); diff --git a/static/js/consumption.js b/static/js/consumption.js deleted file mode 100644 index 2866dd9..0000000 --- a/static/js/consumption.js +++ /dev/null @@ -1,205 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("consumption-data"); - let data = { scenarios: [], consumption: {}, unit_options: [] }; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - data = { - scenarios: Array.isArray(parsed.scenarios) ? parsed.scenarios : [], - consumption: - parsed.consumption && typeof parsed.consumption === "object" - ? parsed.consumption - : {}, - unit_options: Array.isArray(parsed.unit_options) - ?
parsed.unit_options - : [], - }; - } - } catch (error) { - console.error("Unable to parse consumption data", error); - } - } - - const consumptionByScenario = data.consumption; - const filterSelect = document.getElementById("consumption-scenario-filter"); - const tableWrapper = document.getElementById("consumption-table-wrapper"); - const tableBody = document.getElementById("consumption-table-body"); - const emptyState = document.getElementById("consumption-empty"); - const form = document.getElementById("consumption-form"); - const feedbackEl = document.getElementById("consumption-feedback"); - const unitSelect = document.getElementById("consumption-form-unit"); - const unitSymbolInput = document.getElementById( - "consumption-form-unit-symbol" - ); - - const showFeedback = (message, type = "success") => { - if (!feedbackEl) { - return; - } - feedbackEl.textContent = message; - feedbackEl.classList.remove("hidden", "success", "error"); - feedbackEl.classList.add(type); - }; - - const hideFeedback = () => { - if (!feedbackEl) { - return; - } - feedbackEl.classList.add("hidden"); - feedbackEl.textContent = ""; - }; - - const formatAmount = (value) => - Number(value).toLocaleString(undefined, { - minimumFractionDigits: 2, - maximumFractionDigits: 2, - }); - - const formatMeasurement = (amount, symbol, name) => { - if (symbol) { - return `${formatAmount(amount)} ${symbol}`; - } - if (name) { - return `${formatAmount(amount)} ${name}`; - } - return formatAmount(amount); - }; - - const renderConsumptionRows = (scenarioId) => { - if (!tableBody || !tableWrapper || !emptyState) { - return; - } - - const key = String(scenarioId); - const records = consumptionByScenario[key] || []; - - tableBody.innerHTML = ""; - - if (!records.length) { - emptyState.textContent = "No consumption records for this scenario yet."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - return; - } - - emptyState.classList.add("hidden"); - tableWrapper.classList.remove("hidden"); - - records.forEach((record) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${formatMeasurement( - record.amount, - record.unit_symbol, - record.unit_name - )} - ${record.description || "—"} - `; - tableBody.appendChild(row); - }); - }; - - if (filterSelect) { - filterSelect.addEventListener("change", (event) => { - const value = event.target.value; - if (!value) { - if (emptyState && tableWrapper && tableBody) { - emptyState.textContent = - "Choose a scenario to review its consumption records."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - tableBody.innerHTML = ""; - } - return; - } - renderConsumptionRows(value); - }); - } - - const submitConsumption = async (event) => { - event.preventDefault(); - hideFeedback(); - - if (!form) { - return; - } - - const formData = new FormData(form); - const scenarioId = formData.get("scenario_id"); - const unitName = formData.get("unit_name"); - const unitSymbol = formData.get("unit_symbol"); - const payload = { - scenario_id: scenarioId ? Number(scenarioId) : null, - amount: Number(formData.get("amount")), - description: formData.get("description") || null, - unit_name: unitName ? String(unitName) : null, - unit_symbol: unitSymbol ? 
String(unitSymbol) : null, - }; - - try { - const response = await fetch("/api/consumption/", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorDetail = await response.json().catch(() => ({})); - throw new Error( - errorDetail.detail || "Unable to add consumption record." - ); - } - - const result = await response.json(); - const mapKey = String(result.scenario_id); - - if (!Array.isArray(consumptionByScenario[mapKey])) { - consumptionByScenario[mapKey] = []; - } - consumptionByScenario[mapKey].push(result); - - form.reset(); - syncUnitSelection(); - showFeedback("Consumption record saved.", "success"); - - if (filterSelect && filterSelect.value === String(result.scenario_id)) { - renderConsumptionRows(filterSelect.value); - } - } catch (error) { - showFeedback(error.message || "An unexpected error occurred.", "error"); - } - }; - - if (form) { - form.addEventListener("submit", submitConsumption); - } - - const syncUnitSelection = () => { - if (!unitSelect || !unitSymbolInput) { - return; - } - if (!unitSelect.value && unitSelect.options.length > 0) { - const firstOption = Array.from(unitSelect.options).find( - (option) => option.value - ); - if (firstOption) { - firstOption.selected = true; - } - } - const selectedOption = unitSelect.options[unitSelect.selectedIndex]; - unitSymbolInput.value = selectedOption - ? selectedOption.getAttribute("data-symbol") || "" - : ""; - }; - - if (unitSelect) { - unitSelect.addEventListener("change", syncUnitSelection); - syncUnitSelection(); - } - - if (filterSelect && filterSelect.value) { - renderConsumptionRows(filterSelect.value); - } -}); diff --git a/static/js/costs.js b/static/js/costs.js deleted file mode 100644 index d75139c..0000000 --- a/static/js/costs.js +++ /dev/null @@ -1,339 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("costs-payload"); - let capexByScenario = {}; - let opexByScenario = {}; - let currencyOptions = []; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - if (parsed.capex && typeof parsed.capex === "object") { - capexByScenario = parsed.capex; - } - if (parsed.opex && typeof parsed.opex === "object") { - opexByScenario = parsed.opex; - } - if (Array.isArray(parsed.currency_options)) { - currencyOptions = parsed.currency_options; - } - } - } catch (error) { - console.error("Unable to parse cost data", error); - } - } - - const filterSelect = document.getElementById("costs-scenario-filter"); - const costsEmptyState = document.getElementById("costs-empty"); - const costsDataWrapper = document.getElementById("costs-data"); - const capexTableBody = document.getElementById("capex-table-body"); - const opexTableBody = document.getElementById("opex-table-body"); - const capexEmpty = document.getElementById("capex-empty"); - const opexEmpty = document.getElementById("opex-empty"); - const capexTotal = document.getElementById("capex-total"); - const opexTotal = document.getElementById("opex-total"); - const capexForm = document.getElementById("capex-form"); - const opexForm = document.getElementById("opex-form"); - const capexFeedback = document.getElementById("capex-feedback"); - const opexFeedback = document.getElementById("opex-feedback"); - const capexFormScenario = document.getElementById("capex-form-scenario"); - const opexFormScenario = document.getElementById("opex-form-scenario"); 
- const capexCurrencySelect = document.getElementById("capex-form-currency"); - const opexCurrencySelect = document.getElementById("opex-form-currency"); - - // If no currency options were injected server-side, fetch from API - const fetchCurrencyOptions = async () => { - try { - const resp = await fetch("/api/currencies/"); - if (!resp.ok) return; - const list = await resp.json(); - if (Array.isArray(list) && list.length) { - currencyOptions = list; - populateCurrencySelects(); - } - } catch (err) { - console.warn("Unable to fetch currency options", err); - } - }; - - const populateCurrencySelects = () => { - const selectElements = [capexCurrencySelect, opexCurrencySelect].filter(Boolean); - selectElements.forEach((sel) => { - if (!sel) return; - // Clear non-empty options except the empty placeholder - const placeholder = sel.querySelector("option[value='']"); - sel.innerHTML = ""; - if (placeholder) sel.appendChild(placeholder); - currencyOptions.forEach((opt) => { - const option = document.createElement("option"); - option.value = opt.id; - option.textContent = opt.name || opt.id; - sel.appendChild(option); - }); - }); - }; - - // populate from injected options first, then fetch to refresh - if (currencyOptions && currencyOptions.length) populateCurrencySelects(); - else fetchCurrencyOptions(); - - const showFeedback = (element, message, type = "success") => { - if (!element) { - return; - } - element.textContent = message; - element.classList.remove("hidden", "success", "error"); - element.classList.add(type); - }; - - const hideFeedback = (element) => { - if (!element) { - return; - } - element.classList.add("hidden"); - element.textContent = ""; - }; - - const formatAmount = (value) => - Number(value).toLocaleString(undefined, { - minimumFractionDigits: 2, - maximumFractionDigits: 2, - }); - - const formatCurrencyAmount = (value, currencyCode) => { - if (!currencyCode) { - return formatAmount(value); - } - try { - return new Intl.NumberFormat(undefined, { - style: "currency", - currency: currencyCode, - minimumFractionDigits: 2, - maximumFractionDigits: 2, - }).format(Number(value)); - } catch (error) { - return `${currencyCode} ${formatAmount(value)}`; - } - }; - - const sumAmount = (records) => - records.reduce((total, record) => total + Number(record.amount || 0), 0); - - const describeTotal = (records) => { - if (!records || records.length === 0) { - return "—"; - } - const total = sumAmount(records); - const currencyCodes = Array.from( - new Set( - records - .map((record) => (record.currency_code || "").trim().toUpperCase()) - .filter(Boolean) - ) - ); - - if (currencyCodes.length === 1) { - return formatCurrencyAmount(total, currencyCodes[0]); - } - return `${formatAmount(total)} (mixed)`; - }; - - const renderCostTables = (scenarioId) => { - if ( - !capexTableBody || - !opexTableBody || - !capexEmpty || - !opexEmpty || - !capexTotal || - !opexTotal - ) { - return; - } - - const capexRecords = capexByScenario[String(scenarioId)] || []; - const opexRecords = opexByScenario[String(scenarioId)] || []; - - capexTableBody.innerHTML = ""; - opexTableBody.innerHTML = ""; - - if (!capexRecords.length) { - capexEmpty.classList.remove("hidden"); - } else { - capexEmpty.classList.add("hidden"); - capexRecords.forEach((record) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${formatCurrencyAmount(record.amount, record.currency_code)} - ${record.description || "—"} - `; - capexTableBody.appendChild(row); - }); - } - - if (!opexRecords.length) { - 
opexEmpty.classList.remove("hidden"); - } else { - opexEmpty.classList.add("hidden"); - opexRecords.forEach((record) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${formatCurrencyAmount(record.amount, record.currency_code)} - ${record.description || "—"} - `; - opexTableBody.appendChild(row); - }); - } - - capexTotal.textContent = describeTotal(capexRecords); - opexTotal.textContent = describeTotal(opexRecords); - }; - - const toggleCostView = (show) => { - if ( - !costsEmptyState || - !costsDataWrapper || - !capexTableBody || - !opexTableBody - ) { - return; - } - - if (show) { - costsEmptyState.classList.add("hidden"); - costsDataWrapper.classList.remove("hidden"); - } else { - costsEmptyState.classList.remove("hidden"); - costsDataWrapper.classList.add("hidden"); - capexTableBody.innerHTML = ""; - opexTableBody.innerHTML = ""; - if (capexTotal) { - capexTotal.textContent = "—"; - } - if (opexTotal) { - opexTotal.textContent = "—"; - } - if (capexEmpty) { - capexEmpty.classList.add("hidden"); - } - if (opexEmpty) { - opexEmpty.classList.add("hidden"); - } - } - }; - - const syncFormSelections = (value) => { - if (capexFormScenario) { - capexFormScenario.value = value || ""; - } - if (opexFormScenario) { - opexFormScenario.value = value || ""; - } - }; - - const ensureCurrencySelection = (selectElement) => { - if (!selectElement || selectElement.value) { - return; - } - const firstOption = selectElement.querySelector( - "option[value]:not([value=''])" - ); - if (firstOption && firstOption.value) { - selectElement.value = firstOption.value; - } - }; - - if (filterSelect) { - filterSelect.addEventListener("change", (event) => { - const value = event.target.value; - if (!value) { - toggleCostView(false); - syncFormSelections(""); - return; - } - toggleCostView(true); - renderCostTables(value); - syncFormSelections(value); - }); - } - - const submitCostEntry = async (event, targetUrl, storageMap, feedbackEl) => { - event.preventDefault(); - hideFeedback(feedbackEl); - - const formData = new FormData(event.target); - const scenarioId = formData.get("scenario_id"); - const currencyCode = formData.get("currency_code"); - const payload = { - scenario_id: scenarioId ? Number(scenarioId) : null, - amount: Number(formData.get("amount")), - description: formData.get("description") || null, - currency_code: currencyCode ? 
String(currencyCode).toUpperCase() : null, - }; - - if (!payload.scenario_id) { - showFeedback(feedbackEl, "Select a scenario before submitting.", "error"); - return; - } - - if (!payload.currency_code) { - showFeedback(feedbackEl, "Choose a currency before submitting.", "error"); - return; - } - - try { - const response = await fetch(targetUrl, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorDetail = await response.json().catch(() => ({})); - throw new Error(errorDetail.detail || "Unable to save cost entry."); - } - - const result = await response.json(); - const mapKey = String(result.scenario_id); - - if (!Array.isArray(storageMap[mapKey])) { - storageMap[mapKey] = []; - } - - storageMap[mapKey].push(result); - - event.target.reset(); - ensureCurrencySelection(event.target.querySelector("select[name='currency_code']")); - showFeedback(feedbackEl, "Entry saved successfully.", "success"); - - if (filterSelect && filterSelect.value === mapKey) { - renderCostTables(mapKey); - } - } catch (error) { - showFeedback( - feedbackEl, - error.message || "An unexpected error occurred.", - "error" - ); - } - }; - - if (capexForm) { - ensureCurrencySelection(capexCurrencySelect); - capexForm.addEventListener("submit", (event) => - submitCostEntry(event, "/api/costs/capex", capexByScenario, capexFeedback) - ); - } - - if (opexForm) { - ensureCurrencySelection(opexCurrencySelect); - opexForm.addEventListener("submit", (event) => - submitCostEntry(event, "/api/costs/opex", opexByScenario, opexFeedback) - ); - } - - if (filterSelect && filterSelect.value) { - toggleCostView(true); - renderCostTables(filterSelect.value); - syncFormSelections(filterSelect.value); - } -}); diff --git a/static/js/currencies.js b/static/js/currencies.js deleted file mode 100644 index 86557f7..0000000 --- a/static/js/currencies.js +++ /dev/null @@ -1,537 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("currencies-data"); - const editorSection = document.getElementById("currencies-editor"); - const tableBody = document.getElementById("currencies-table-body"); - const tableEmptyState = document.getElementById("currencies-table-empty"); - const metrics = { - total: document.getElementById("currency-metric-total"), - active: document.getElementById("currency-metric-active"), - inactive: document.getElementById("currency-metric-inactive"), - }; - - const form = document.getElementById("currency-form"); - const existingSelect = document.getElementById("currency-form-existing"); - const codeInput = document.getElementById("currency-form-code"); - const nameInput = document.getElementById("currency-form-name"); - const symbolInput = document.getElementById("currency-form-symbol"); - const statusSelect = document.getElementById("currency-form-status"); - const resetButton = document.getElementById("currency-form-reset"); - const feedbackElement = document.getElementById("currency-form-feedback"); - - const saveButton = form ? form.querySelector("button[type='submit']") : null; - - const uppercaseCode = (value) => - (value || "").toString().trim().toUpperCase(); - const normalizeSymbol = (value) => { - if (value === undefined || value === null) { - return null; - } - const trimmed = String(value).trim(); - return trimmed ? 
trimmed : null; - }; - - const normalizeApiBase = (value) => { - if (!value || typeof value !== "string") { - return "/api/currencies"; - } - return value.endsWith("/") ? value.slice(0, -1) : value; - }; - - let currencies = []; - let apiBase = "/api/currencies"; - let defaultCurrencyCode = "USD"; - - const buildCurrencyRecord = (record) => { - if (!record || typeof record !== "object") { - return null; - } - const code = uppercaseCode(record.code); - return { - id: record.id ?? null, - code, - name: record.name || "", - symbol: record.symbol || "", - is_active: Boolean(record.is_active), - is_default: code === defaultCurrencyCode, - }; - }; - - const findCurrencyIndex = (code) => { - return currencies.findIndex((item) => item.code === code); - }; - - const upsertCurrency = (record) => { - const normalized = buildCurrencyRecord(record); - if (!normalized) { - return null; - } - const existingIndex = findCurrencyIndex(normalized.code); - if (existingIndex >= 0) { - currencies[existingIndex] = normalized; - } else { - currencies.push(normalized); - } - currencies.sort((a, b) => a.code.localeCompare(b.code)); - return normalized; - }; - - const replaceCurrencyList = (records) => { - if (!Array.isArray(records)) { - return; - } - currencies = records - .map((record) => buildCurrencyRecord(record)) - .filter((record) => record !== null) - .sort((a, b) => a.code.localeCompare(b.code)); - }; - - const applyPayload = () => { - if (!dataElement) { - return; - } - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - if (parsed.default_currency_code) { - defaultCurrencyCode = uppercaseCode(parsed.default_currency_code); - } - if (parsed.currency_api_base) { - apiBase = normalizeApiBase(parsed.currency_api_base); - } - if (Array.isArray(parsed.currencies)) { - replaceCurrencyList(parsed.currencies); - } - } - } catch (error) { - console.error("Unable to parse currencies payload", error); - } - }; - - const showFeedback = (message, type = "success") => { - if (!feedbackElement) { - return; - } - feedbackElement.textContent = message; - feedbackElement.classList.remove("hidden", "success", "error"); - feedbackElement.classList.add(type); - }; - - const hideFeedback = () => { - if (!feedbackElement) { - return; - } - feedbackElement.classList.add("hidden"); - feedbackElement.classList.remove("success", "error"); - feedbackElement.textContent = ""; - }; - - const setButtonLoading = (button, isLoading) => { - if (!button) { - return; - } - button.disabled = isLoading; - button.classList.toggle("is-loading", isLoading); - }; - - const updateMetrics = () => { - const total = currencies.length; - const active = currencies.filter((item) => item.is_active).length; - const inactive = total - active; - if (metrics.total) { - metrics.total.textContent = String(total); - } - if (metrics.active) { - metrics.active.textContent = String(active); - } - if (metrics.inactive) { - metrics.inactive.textContent = String(inactive); - } - }; - - const renderExistingOptions = ( - selectedCode = existingSelect ? existingSelect.value : "" - ) => { - if (!existingSelect) { - return; - } - const placeholder = existingSelect.querySelector("option[value='']"); - const placeholderClone = placeholder ? 
placeholder.cloneNode(true) : null; - existingSelect.innerHTML = ""; - if (placeholderClone) { - existingSelect.appendChild(placeholderClone); - } - const fragment = document.createDocumentFragment(); - currencies.forEach((currency) => { - const option = document.createElement("option"); - option.value = currency.code; - option.textContent = currency.name - ? `${currency.name} (${currency.code})` - : currency.code; - if (selectedCode === currency.code) { - option.selected = true; - } - fragment.appendChild(option); - }); - existingSelect.appendChild(fragment); - if ( - selectedCode && - !currencies.some((item) => item.code === selectedCode) - ) { - existingSelect.value = ""; - } - }; - - const renderTable = () => { - if (!tableBody) { - return; - } - tableBody.innerHTML = ""; - if (!currencies.length) { - if (tableEmptyState) { - tableEmptyState.classList.remove("hidden"); - } - return; - } - if (tableEmptyState) { - tableEmptyState.classList.add("hidden"); - } - const fragment = document.createDocumentFragment(); - currencies.forEach((currency) => { - const row = document.createElement("tr"); - - const codeCell = document.createElement("td"); - codeCell.textContent = currency.code; - row.appendChild(codeCell); - - const nameCell = document.createElement("td"); - nameCell.textContent = currency.name || "—"; - row.appendChild(nameCell); - - const symbolCell = document.createElement("td"); - symbolCell.textContent = currency.symbol || "—"; - row.appendChild(symbolCell); - - const statusCell = document.createElement("td"); - statusCell.textContent = currency.is_active ? "Active" : "Inactive"; - if (currency.is_default) { - statusCell.textContent += " (Default)"; - } - row.appendChild(statusCell); - - const actionsCell = document.createElement("td"); - const editButton = document.createElement("button"); - editButton.type = "button"; - editButton.className = "btn"; - editButton.dataset.action = "edit"; - editButton.dataset.code = currency.code; - editButton.textContent = "Edit"; - editButton.style.marginRight = "0.5rem"; - - const toggleButton = document.createElement("button"); - toggleButton.type = "button"; - toggleButton.className = "btn"; - toggleButton.dataset.action = "toggle"; - toggleButton.dataset.code = currency.code; - toggleButton.textContent = currency.is_active ? 
"Deactivate" : "Activate"; - if (currency.is_default && currency.is_active) { - toggleButton.disabled = true; - toggleButton.title = "The default currency must remain active."; - } - - actionsCell.appendChild(editButton); - actionsCell.appendChild(toggleButton); - - row.appendChild(actionsCell); - fragment.appendChild(row); - }); - tableBody.appendChild(fragment); - }; - - const refreshUI = (selectedCode) => { - currencies.sort((a, b) => a.code.localeCompare(b.code)); - renderTable(); - renderExistingOptions(selectedCode); - updateMetrics(); - }; - - const findCurrency = (code) => - currencies.find((item) => item.code === code) || null; - - const setFormForCurrency = (currency) => { - if (!form || !codeInput || !nameInput || !symbolInput || !statusSelect) { - return; - } - if (!currency) { - form.reset(); - if (existingSelect) { - existingSelect.value = ""; - } - codeInput.readOnly = false; - codeInput.value = ""; - nameInput.value = ""; - symbolInput.value = ""; - statusSelect.disabled = false; - statusSelect.value = "true"; - statusSelect.title = ""; - return; - } - - if (existingSelect) { - existingSelect.value = currency.code; - } - codeInput.readOnly = true; - codeInput.value = currency.code; - nameInput.value = currency.name || ""; - symbolInput.value = currency.symbol || ""; - statusSelect.value = currency.is_active ? "true" : "false"; - if (currency.is_default) { - statusSelect.disabled = true; - statusSelect.value = "true"; - statusSelect.title = "The default currency must remain active."; - } else { - statusSelect.disabled = false; - statusSelect.title = ""; - } - }; - - const resetFormState = () => { - setFormForCurrency(null); - }; - - const parseError = async (response, fallbackMessage) => { - try { - const detail = await response.json(); - if (detail && typeof detail === "object" && detail.detail) { - return detail.detail; - } - } catch (error) { - // ignore JSON parse errors - } - return fallbackMessage; - }; - - const fetchCurrenciesFromApi = async () => { - const url = `${apiBase}/?include_inactive=true`; - try { - const response = await fetch(url); - if (!response.ok) { - return; - } - const list = await response.json(); - if (Array.isArray(list)) { - replaceCurrencyList(list); - refreshUI(existingSelect ? existingSelect.value : undefined); - } - } catch (error) { - console.warn("Unable to refresh currency list", error); - } - }; - - const handleSubmit = async (event) => { - event.preventDefault(); - hideFeedback(); - if (!form || !codeInput || !nameInput || !statusSelect) { - return; - } - - const editingCode = existingSelect - ? uppercaseCode(existingSelect.value) - : ""; - const codeValue = uppercaseCode(codeInput.value); - const nameValue = (nameInput.value || "").trim(); - const symbolValue = normalizeSymbol(symbolInput ? symbolInput.value : ""); - const isActive = statusSelect.value !== "false"; - - if (!nameValue) { - showFeedback("Provide a currency name.", "error"); - return; - } - - if (!editingCode) { - if (!codeValue || codeValue.length !== 3) { - showFeedback("Provide a three-letter currency code.", "error"); - return; - } - } - - const payload = editingCode - ? { - name: nameValue, - symbol: symbolValue, - is_active: isActive, - } - : { - code: codeValue, - name: nameValue, - symbol: symbolValue, - is_active: isActive, - }; - - const targetCode = editingCode || codeValue; - const url = editingCode - ? 
`${apiBase}/${encodeURIComponent(editingCode)}` - : `${apiBase}/`; - - setButtonLoading(saveButton, true); - try { - const response = await fetch(url, { - method: editingCode ? "PUT" : "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const message = await parseError( - response, - editingCode - ? "Unable to update the currency." - : "Unable to create the currency." - ); - throw new Error(message); - } - - const result = await response.json(); - const updated = upsertCurrency(result); - defaultCurrencyCode = uppercaseCode(defaultCurrencyCode); - refreshUI(updated ? updated.code : targetCode); - - if (editingCode) { - showFeedback("Currency updated successfully."); - if (updated) { - setFormForCurrency(updated); - } - } else { - showFeedback("Currency created successfully."); - resetFormState(); - } - } catch (error) { - showFeedback(error.message || "An unexpected error occurred.", "error"); - } finally { - setButtonLoading(saveButton, false); - } - }; - - const handleToggle = async (code, button) => { - const record = findCurrency(code); - if (!record) { - return; - } - hideFeedback(); - const nextState = !record.is_active; - const url = `${apiBase}/${encodeURIComponent(code)}/activation`; - setButtonLoading(button, true); - try { - const response = await fetch(url, { - method: "PATCH", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ is_active: nextState }), - }); - - if (!response.ok) { - const message = await parseError( - response, - nextState - ? "Unable to activate the currency." - : "Unable to deactivate the currency." - ); - throw new Error(message); - } - - const result = await response.json(); - const updated = upsertCurrency(result); - refreshUI(updated ? updated.code : code); - if (existingSelect && existingSelect.value === code && updated) { - setFormForCurrency(updated); - } - const actionMessage = nextState - ? `Currency ${code} activated.` - : `Currency ${code} deactivated.`; - showFeedback(actionMessage); - } catch (error) { - showFeedback(error.message || "An unexpected error occurred.", "error"); - } finally { - setButtonLoading(button, false); - } - }; - - const handleTableClick = (event) => { - const button = event.target.closest("button[data-action]"); - if (!button) { - return; - } - const code = uppercaseCode(button.dataset.code); - const action = button.dataset.action; - if (!code || !action) { - return; - } - if (action === "edit") { - const currency = findCurrency(code); - if (currency) { - setFormForCurrency(currency); - hideFeedback(); - if (nameInput) { - nameInput.focus(); - } - } - } else if (action === "toggle") { - handleToggle(code, button); - } - }; - - applyPayload(); - if (editorSection && editorSection.dataset.defaultCode) { - defaultCurrencyCode = uppercaseCode(editorSection.dataset.defaultCode); - currencies = currencies.map((record) => { - return record - ? 
{ - ...record, - is_default: record.code === defaultCurrencyCode, - } - : record; - }); - } - apiBase = normalizeApiBase(apiBase); - - refreshUI(); - - if (form) { - form.addEventListener("submit", handleSubmit); - } - - if (existingSelect) { - existingSelect.addEventListener("change", (event) => { - const selectedCode = uppercaseCode(event.target.value); - if (!selectedCode) { - hideFeedback(); - resetFormState(); - return; - } - const currency = findCurrency(selectedCode); - if (currency) { - setFormForCurrency(currency); - hideFeedback(); - } - }); - } - - if (resetButton) { - resetButton.addEventListener("click", (event) => { - event.preventDefault(); - hideFeedback(); - resetFormState(); - }); - } - - if (codeInput) { - codeInput.addEventListener("input", () => { - const value = uppercaseCode(codeInput.value).slice(0, 3); - codeInput.value = value; - }); - } - - if (tableBody) { - tableBody.addEventListener("click", handleTableClick); - } - - fetchCurrenciesFromApi(); -}); diff --git a/static/js/dashboard.js b/static/js/dashboard.js deleted file mode 100644 index 5cec954..0000000 --- a/static/js/dashboard.js +++ /dev/null @@ -1,289 +0,0 @@ -(() => { - const dataElement = document.getElementById("dashboard-data"); - if (!dataElement) { - return; - } - - let state = {}; - try { - state = JSON.parse(dataElement.textContent || "{}"); - } catch (error) { - console.error("Failed to parse dashboard data", error); - return; - } - - const statusElement = document.getElementById("dashboard-status"); - const summaryContainer = document.getElementById("summary-metrics"); - const summaryEmpty = document.getElementById("summary-empty"); - const scenarioTableBody = document.querySelector("#scenario-table tbody"); - const scenarioEmpty = document.getElementById("scenario-table-empty"); - const overallMetricsList = document.getElementById("overall-metrics"); - const overallMetricsEmpty = document.getElementById("overall-metrics-empty"); - const recentList = document.getElementById("recent-simulations"); - const recentEmpty = document.getElementById("recent-simulations-empty"); - const maintenanceList = document.getElementById("upcoming-maintenance"); - const maintenanceEmpty = document.getElementById( - "upcoming-maintenance-empty" - ); - const refreshButton = document.getElementById("refresh-dashboard"); - const costChartCanvas = document.getElementById("cost-chart"); - const costChartEmpty = document.getElementById("cost-chart-empty"); - const activityChartCanvas = document.getElementById("activity-chart"); - const activityChartEmpty = document.getElementById("activity-chart-empty"); - - let costChartInstance = null; - let activityChartInstance = null; - - const setStatus = (message, variant = "success") => { - if (!statusElement) { - return; - } - if (!message) { - statusElement.hidden = true; - statusElement.textContent = ""; - statusElement.classList.remove("success", "error"); - return; - } - statusElement.textContent = message; - statusElement.hidden = false; - statusElement.classList.toggle("success", variant === "success"); - statusElement.classList.toggle("error", variant !== "success"); - }; - - const renderSummaryMetrics = () => { - if (!summaryContainer || !summaryEmpty) { - return; - } - summaryContainer.innerHTML = ""; - const metrics = Array.isArray(state.summary_metrics) - ? 
state.summary_metrics - : []; - metrics.forEach((metric) => { - const card = document.createElement("article"); - card.className = "metric-card"; - card.innerHTML = ` - ${metric.label} - ${metric.value} - `; - summaryContainer.appendChild(card); - }); - summaryEmpty.hidden = metrics.length > 0; - }; - - const renderScenarioTable = () => { - if (!scenarioTableBody || !scenarioEmpty) { - return; - } - scenarioTableBody.innerHTML = ""; - const rows = Array.isArray(state.scenario_rows) ? state.scenario_rows : []; - rows.forEach((row) => { - const tr = document.createElement("tr"); - tr.innerHTML = ` - ${row.scenario_name} - ${row.parameter_display} - ${row.equipment_display} - ${row.capex_display} - ${row.opex_display} - ${row.production_display} - ${row.consumption_display} - ${row.maintenance_display} - ${row.iterations_display} - ${row.simulation_mean_display} - `; - scenarioTableBody.appendChild(tr); - }); - scenarioEmpty.hidden = rows.length > 0; - }; - - const renderOverallMetrics = () => { - if (!overallMetricsList || !overallMetricsEmpty) { - return; - } - overallMetricsList.innerHTML = ""; - const items = Array.isArray(state.overall_report_metrics) - ? state.overall_report_metrics - : []; - items.forEach((item) => { - const li = document.createElement("li"); - li.className = "metric-list-item"; - li.textContent = `${item.label}: ${item.value}`; - overallMetricsList.appendChild(li); - }); - overallMetricsEmpty.hidden = items.length > 0; - }; - - const renderRecentSimulations = () => { - if (!recentList || !recentEmpty) { - return; - } - recentList.innerHTML = ""; - const runs = Array.isArray(state.recent_simulations) - ? state.recent_simulations - : []; - runs.forEach((run) => { - const item = document.createElement("li"); - item.className = "metric-list-item"; - item.textContent = `${run.scenario_name} · ${run.iterations_display} iterations · ${run.mean_display}`; - recentList.appendChild(item); - }); - recentEmpty.hidden = runs.length > 0; - }; - - const renderMaintenanceReminders = () => { - if (!maintenanceList || !maintenanceEmpty) { - return; - } - maintenanceList.innerHTML = ""; - const items = Array.isArray(state.upcoming_maintenance) - ? state.upcoming_maintenance - : []; - items.forEach((item) => { - const li = document.createElement("li"); - li.innerHTML = ` - ${item.equipment_name} · ${item.scenario_name} - ${item.date_display} · ${item.cost_display} · ${item.description} - `; - maintenanceList.appendChild(li); - }); - maintenanceEmpty.hidden = items.length > 0; - }; - - const buildChartConfig = (dataset, overrides = {}) => ({ - type: dataset.type || "bar", - data: { - labels: dataset.labels || [], - datasets: dataset.datasets || [], - }, - options: Object.assign( - { - responsive: true, - maintainAspectRatio: false, - plugins: { - legend: { position: "top" }, - tooltip: { enabled: true }, - }, - scales: { - x: { stacked: dataset.stacked ?? false }, - y: { stacked: dataset.stacked ?? 
false, beginAtZero: true }, - }, - }, - overrides.options || {} - ), - }); - - const renderCharts = () => { - if (costChartInstance) { - costChartInstance.destroy(); - } - if (activityChartInstance) { - activityChartInstance.destroy(); - } - - const costData = state.scenario_cost_chart || {}; - const activityData = state.scenario_activity_chart || {}; - - if (costChartCanvas && state.cost_chart_has_data) { - costChartInstance = new Chart( - costChartCanvas, - buildChartConfig(costData, { - options: { - scales: { - y: { - beginAtZero: true, - ticks: { - callback: (value) => - typeof value === "number" - ? value.toLocaleString(undefined, { - maximumFractionDigits: 0, - }) - : value, - }, - }, - }, - }, - }) - ); - if (costChartEmpty) { - costChartEmpty.hidden = true; - } - costChartCanvas.classList.remove("hidden"); - } else if (costChartEmpty && costChartCanvas) { - costChartEmpty.hidden = false; - costChartCanvas.classList.add("hidden"); - } - - if (activityChartCanvas && state.activity_chart_has_data) { - activityChartInstance = new Chart( - activityChartCanvas, - buildChartConfig(activityData, { - options: { - scales: { - y: { - beginAtZero: true, - ticks: { - callback: (value) => - typeof value === "number" - ? value.toLocaleString(undefined, { - maximumFractionDigits: 0, - }) - : value, - }, - }, - }, - }, - }) - ); - if (activityChartEmpty) { - activityChartEmpty.hidden = true; - } - activityChartCanvas.classList.remove("hidden"); - } else if (activityChartEmpty && activityChartCanvas) { - activityChartEmpty.hidden = false; - activityChartCanvas.classList.add("hidden"); - } - }; - - const renderView = () => { - renderSummaryMetrics(); - renderScenarioTable(); - renderOverallMetrics(); - renderRecentSimulations(); - renderMaintenanceReminders(); - renderCharts(); - }; - - const refreshDashboard = async () => { - setStatus("Refreshing dashboard…", "success"); - if (refreshButton) { - refreshButton.classList.add("is-loading"); - } - - try { - const response = await fetch("/ui/dashboard/data", { - headers: { "X-Requested-With": "XMLHttpRequest" }, - }); - - if (!response.ok) { - throw new Error("Unable to refresh dashboard data."); - } - - const payload = await response.json(); - state = payload || {}; - renderView(); - setStatus("Dashboard updated.", "success"); - } catch (error) { - console.error(error); - setStatus(error.message || "Failed to refresh dashboard.", "error"); - } finally { - if (refreshButton) { - refreshButton.classList.remove("is-loading"); - } - } - }; - - renderView(); - - if (refreshButton) { - refreshButton.addEventListener("click", refreshDashboard); - } -})(); diff --git a/static/js/equipment.js b/static/js/equipment.js deleted file mode 100644 index cf2c56d..0000000 --- a/static/js/equipment.js +++ /dev/null @@ -1,145 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("equipment-data"); - let equipmentByScenario = {}; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - if (parsed.equipment && typeof parsed.equipment === "object") { - equipmentByScenario = parsed.equipment; - } - } - } catch (error) { - console.error("Unable to parse equipment data", error); - } - } - - const filterSelect = document.getElementById("equipment-scenario-filter"); - const tableWrapper = document.getElementById("equipment-table-wrapper"); - const tableBody = document.getElementById("equipment-table-body"); - const emptyState = 
document.getElementById("equipment-empty"); - const form = document.getElementById("equipment-form"); - const feedbackEl = document.getElementById("equipment-feedback"); - - const showFeedback = (message, type = "success") => { - if (!feedbackEl) { - return; - } - feedbackEl.textContent = message; - feedbackEl.classList.remove("hidden", "success", "error"); - feedbackEl.classList.add(type); - }; - - const hideFeedback = () => { - if (!feedbackEl) { - return; - } - feedbackEl.classList.add("hidden"); - feedbackEl.textContent = ""; - }; - - const renderEquipmentRows = (scenarioId) => { - if (!tableBody || !tableWrapper || !emptyState) { - return; - } - - const key = String(scenarioId); - const records = equipmentByScenario[key] || []; - - tableBody.innerHTML = ""; - - if (!records.length) { - emptyState.textContent = "No equipment recorded for this scenario yet."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - return; - } - - emptyState.classList.add("hidden"); - tableWrapper.classList.remove("hidden"); - - records.forEach((record) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${record.name || "—"} - ${record.description || "—"} - `; - tableBody.appendChild(row); - }); - }; - - if (filterSelect) { - filterSelect.addEventListener("change", (event) => { - const value = event.target.value; - if (!value) { - if (emptyState && tableWrapper && tableBody) { - emptyState.textContent = - "Choose a scenario to review the equipment list."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - tableBody.innerHTML = ""; - } - return; - } - renderEquipmentRows(value); - }); - } - - const submitEquipment = async (event) => { - event.preventDefault(); - hideFeedback(); - - if (!form) { - return; - } - - const formData = new FormData(form); - const scenarioId = formData.get("scenario_id"); - const payload = { - scenario_id: scenarioId ? Number(scenarioId) : null, - name: formData.get("name"), - description: formData.get("description") || null, - }; - - try { - const response = await fetch("/api/equipment/", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorDetail = await response.json().catch(() => ({})); - throw new Error( - errorDetail.detail || "Unable to add equipment record." 
- ); - } - - const result = await response.json(); - const mapKey = String(result.scenario_id); - - if (!Array.isArray(equipmentByScenario[mapKey])) { - equipmentByScenario[mapKey] = []; - } - equipmentByScenario[mapKey].push(result); - - form.reset(); - showFeedback("Equipment saved.", "success"); - - if (filterSelect && filterSelect.value === String(result.scenario_id)) { - renderEquipmentRows(filterSelect.value); - } - } catch (error) { - showFeedback(error.message || "An unexpected error occurred.", "error"); - } - }; - - if (form) { - form.addEventListener("submit", submitEquipment); - } - - if (filterSelect && filterSelect.value) { - renderEquipmentRows(filterSelect.value); - } -}); diff --git a/static/js/exports.js b/static/js/exports.js new file mode 100644 index 0000000..f47e298 --- /dev/null +++ b/static/js/exports.js @@ -0,0 +1,155 @@ +document.addEventListener("DOMContentLoaded", () => { + const modalContainer = document.createElement("div"); + modalContainer.id = "export-modal-container"; + document.body.appendChild(modalContainer); + + async function loadModal(dataset) { + const response = await fetch(`/exports/modal/${dataset}`); + if (!response.ok) { + throw new Error(`Failed to load export modal (${response.status})`); + } + const html = await response.text(); + modalContainer.innerHTML = html; + const modal = modalContainer.querySelector(".modal"); + if (!modal) return; + modal.classList.add("is-active"); + + const closeButtons = modal.querySelectorAll("[data-dismiss='modal']"); + closeButtons.forEach((btn) => + btn.addEventListener("click", () => closeModal(modal)) + ); + + const form = modal.querySelector("[data-export-form]"); + if (form) { + form.addEventListener("submit", handleSubmit); + } + } + + function closeModal(modal) { + modal.classList.remove("is-active"); + setTimeout(() => { + modalContainer.innerHTML = ""; + }, 200); + } + + async function handleSubmit(event) { + event.preventDefault(); + const form = event.currentTarget; + const submitUrl = form.action; + const formData = new FormData(form); + const format = formData.get("format") || "csv"; + + const submitBtn = form.querySelector("button[type='submit']"); + if (submitBtn) { + submitBtn.disabled = true; + submitBtn.classList.add("loading"); + } + + let response; + try { + response = await fetch(submitUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + format, + include_metadata: formData.get("include_metadata") === "true", + filters: null, + }), + }); + } catch (error) { + console.error(error); + NotificationCenter.show({ + message: "Network error during export.", + level: "error", + }); + const errorContainer = form.querySelector("[data-export-error]"); + if (errorContainer) { + errorContainer.textContent = "Network error during export."; + errorContainer.classList.remove("hidden"); + } + submitBtn?.classList.remove("loading"); + submitBtn?.removeAttribute("disabled"); + return; + } + + if (!response.ok) { + let detail = "Export failed. Please try again."; + try { + const payload = await response.json(); + if (payload?.detail) { + detail = Array.isArray(payload.detail) + ? 
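+ // Hedged assumption: with a FastAPI-style backend, `detail` may be either
+ // a plain string or a list of validation objects carrying a `msg` field;
+ // `item.msg || item` tolerates both.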
payload.detail.map((item) => item.msg || item).join("; ") + : payload.detail; + } + } catch (error) { + // ignore JSON parse issues + } + + NotificationCenter.show({ + message: detail, + level: "error", + }); + + const errorContainer = form.querySelector("[data-export-error]"); + if (errorContainer) { + errorContainer.textContent = detail; + errorContainer.classList.remove("hidden"); + } + + submitBtn?.classList.remove("loading"); + submitBtn?.removeAttribute("disabled"); + return; + } + + const blob = await response.blob(); + const disposition = response.headers.get("Content-Disposition"); + let filename = "export"; + if (disposition) { + const match = disposition.match(/filename=([^;]+)/i); + if (match) { + filename = match[1].replace(/"/g, ""); + } + } + + const url = window.URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = url; + link.download = filename; + document.body.appendChild(link); + link.click(); + link.remove(); + window.URL.revokeObjectURL(url); + + const modal = modalContainer.querySelector(".modal"); + if (modal) { + closeModal(modal); + } + + NotificationCenter.show({ + message: `Export ready: ${filename}`, + level: "success", + }); + + submitBtn?.classList.remove("loading"); + submitBtn?.removeAttribute("disabled"); + } + + document.querySelectorAll("[data-export-trigger]").forEach((button) => { + button.addEventListener("click", async (event) => { + event.preventDefault(); + const dataset = button.getAttribute("data-export-target"); + if (!dataset) return; + try { + await loadModal(dataset); + } catch (error) { + console.error(error); + NotificationCenter.show({ + message: "Unable to open export dialog.", + level: "error", + }); + } + }); + }); +}); diff --git a/static/js/imports.js b/static/js/imports.js new file mode 100644 index 0000000..7354693 --- /dev/null +++ b/static/js/imports.js @@ -0,0 +1,240 @@ +document.addEventListener("DOMContentLoaded", () => { + const moduleEl = document.querySelector("[data-import-module]"); + if (!moduleEl) return; + + const dropzone = moduleEl.querySelector("[data-import-dropzone]"); + const input = dropzone?.querySelector("input[type='file']"); + const uploadButton = moduleEl.querySelector("[data-import-upload-trigger]"); + const resetButton = moduleEl.querySelector("[data-import-reset]"); + const feedbackEl = moduleEl.querySelector("#import-upload-feedback"); + const previewBody = moduleEl.querySelector("[data-import-preview-body]"); + const previewContainer = moduleEl.querySelector("#import-preview-container"); + const actionsEl = moduleEl.querySelector("[data-import-actions]"); + const commitButton = moduleEl.querySelector("[data-import-commit]"); + const cancelButton = moduleEl.querySelector("[data-import-cancel]"); + + let stageToken = null; + + function showFeedback(message, type = "info") { + if (!feedbackEl) return; + feedbackEl.textContent = message; + feedbackEl.classList.remove("hidden", "success", "error", "info"); + feedbackEl.classList.add(type); + } + + function hideFeedback() { + if (!feedbackEl) return; + feedbackEl.textContent = ""; + feedbackEl.classList.add("hidden"); + } + + function clearPreview() { + if (previewBody) { + previewBody.innerHTML = ""; + } + previewContainer?.classList.add("hidden"); + actionsEl?.classList.add("hidden"); + commitButton?.setAttribute("disabled", "disabled"); + stageToken = null; + } + + function enableUpload() { + uploadButton?.removeAttribute("disabled"); + resetButton?.classList.remove("hidden"); + } + + function disableUpload() { + 
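+ // Mirror of enableUpload: grey out the trigger, clear any in-flight
+ // spinner state, and hide the reset control until a new file is chosen.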
uploadButton?.setAttribute("disabled", "disabled"); + uploadButton?.classList.remove("loading"); + resetButton?.classList.add("hidden"); + } + + dropzone?.addEventListener("dragover", (event) => { + event.preventDefault(); + dropzone.classList.add("dragover"); + }); + + dropzone?.addEventListener("dragleave", () => { + dropzone.classList.remove("dragover"); + }); + + dropzone?.addEventListener("drop", (event) => { + event.preventDefault(); + dropzone.classList.remove("dragover"); + if (!event.dataTransfer?.files?.length || !input) { + return; + } + input.files = event.dataTransfer.files; + enableUpload(); + hideFeedback(); + }); + + input?.addEventListener("change", () => { + if (input.files?.length) { + enableUpload(); + hideFeedback(); + } else { + disableUpload(); + } + }); + + resetButton?.addEventListener("click", () => { + if (input) { + input.value = ""; + } + disableUpload(); + hideFeedback(); + clearPreview(); + }); + + async function uploadAndPreview() { + if (!input?.files?.length) { + showFeedback( + "Please select a CSV or XLSX file before uploading.", + "error" + ); + return; + } + + const file = input.files[0]; + showFeedback("Uploading…", "info"); + uploadButton?.classList.add("loading"); + uploadButton?.setAttribute("disabled", "disabled"); + + const formData = new FormData(); + formData.append("file", file); + + let response; + try { + response = await fetch("/imports/projects/preview", { + method: "POST", + body: formData, + }); + } catch (error) { + console.error(error); + NotificationCenter?.show({ + message: "Network error during upload.", + level: "error", + }); + showFeedback("Network error during upload.", "error"); + uploadButton?.classList.remove("loading"); + uploadButton?.removeAttribute("disabled"); + return; + } + + if (!response.ok) { + const detail = await response.json().catch(() => ({})); + const message = detail?.detail || "Upload failed. Please check the file."; + NotificationCenter?.show({ message, level: "error" }); + showFeedback(message, "error"); + uploadButton?.classList.remove("loading"); + uploadButton?.removeAttribute("disabled"); + return; + } + + const payload = await response.json(); + hideFeedback(); + renderPreview(payload); + uploadButton?.classList.remove("loading"); + uploadButton?.removeAttribute("disabled"); + + NotificationCenter?.show({ + message: `Preview ready: ${payload.summary.accepted} row(s) accepted`, + level: "success", + }); + } + + function renderPreview(payload) { + const rows = payload.rows || []; + const issues = payload.row_issues || []; + stageToken = payload.stage_token || null; + + if (!previewBody) return; + previewBody.innerHTML = ""; + + const issueMap = new Map(); + issues.forEach((issue) => { + issueMap.set(issue.row_number, issue.issues); + }); + + rows.forEach((row) => { + const tr = document.createElement("tr"); + const rowIssues = issueMap.get(row.row_number) || []; + const issuesText = [ + ...row.issues, + ...rowIssues.map((i) => i.message), + ].join(", "); + + tr.innerHTML = ` + ${row.row_number} + ${row.state} + ${issuesText || "—"} + ${Object.values(row.data) + .map((value) => `${value ?? 
""}`) + .join("")} + `; + previewBody.appendChild(tr); + }); + + previewContainer?.classList.remove("hidden"); + if (stageToken && payload.summary.accepted > 0) { + actionsEl?.classList.remove("hidden"); + commitButton?.removeAttribute("disabled"); + } else { + actionsEl?.classList.add("hidden"); + commitButton?.setAttribute("disabled", "disabled"); + } + } + + uploadButton?.addEventListener("click", uploadAndPreview); + + commitButton?.addEventListener("click", async () => { + if (!stageToken) return; + commitButton.classList.add("loading"); + commitButton.setAttribute("disabled", "disabled"); + + let response; + try { + response = await fetch("/imports/projects/commit", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ token: stageToken }), + }); + } catch (error) { + console.error(error); + NotificationCenter?.show({ + message: "Network error during commit.", + level: "error", + }); + commitButton.classList.remove("loading"); + commitButton.removeAttribute("disabled"); + return; + } + + if (!response.ok) { + const detail = await response.json().catch(() => ({})); + const message = + detail?.detail || "Commit failed. Please review the import data."; + NotificationCenter?.show({ message, level: "error" }); + commitButton.classList.remove("loading"); + commitButton.removeAttribute("disabled"); + return; + } + + const result = await response.json(); + NotificationCenter?.show({ + message: `Import committed. Created: ${result.summary.created}, Updated: ${result.summary.updated}`, + level: "success", + }); + clearPreview(); + if (input) { + input.value = ""; + } + disableUpload(); + }); + + cancelButton?.addEventListener("click", () => { + clearPreview(); + NotificationCenter?.show({ message: "Import canceled.", level: "info" }); + }); +}); diff --git a/static/js/maintenance.js b/static/js/maintenance.js deleted file mode 100644 index 9ec5f4a..0000000 --- a/static/js/maintenance.js +++ /dev/null @@ -1,243 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("maintenance-data"); - let equipmentByScenario = {}; - let maintenanceByScenario = {}; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - if (parsed.equipment && typeof parsed.equipment === "object") { - equipmentByScenario = parsed.equipment; - } - if (parsed.maintenance && typeof parsed.maintenance === "object") { - maintenanceByScenario = parsed.maintenance; - } - } - } catch (error) { - console.error("Unable to parse maintenance data", error); - } - } - - const filterSelect = document.getElementById("maintenance-scenario-filter"); - const tableWrapper = document.getElementById("maintenance-table-wrapper"); - const tableBody = document.getElementById("maintenance-table-body"); - const emptyState = document.getElementById("maintenance-empty"); - const form = document.getElementById("maintenance-form"); - const feedbackEl = document.getElementById("maintenance-feedback"); - const formScenarioSelect = document.getElementById( - "maintenance-form-scenario" - ); - const equipmentSelect = document.getElementById("maintenance-form-equipment"); - const equipmentEmptyState = document.getElementById( - "maintenance-equipment-empty" - ); - - const showFeedback = (message, type = "success") => { - if (!feedbackEl) { - return; - } - feedbackEl.textContent = message; - feedbackEl.classList.remove("hidden", "success", "error"); - feedbackEl.classList.add(type); - 
}; - - const hideFeedback = () => { - if (!feedbackEl) { - return; - } - feedbackEl.classList.add("hidden"); - feedbackEl.textContent = ""; - }; - - const formatCost = (value) => - Number(value).toLocaleString(undefined, { - minimumFractionDigits: 2, - maximumFractionDigits: 2, - }); - - const formatDate = (value) => { - if (!value) { - return "—"; - } - const parsed = new Date(value); - if (Number.isNaN(parsed.getTime())) { - return value; - } - return parsed.toLocaleDateString(); - }; - - const renderMaintenanceRows = (scenarioId) => { - if (!tableBody || !tableWrapper || !emptyState) { - return; - } - - const key = String(scenarioId); - const records = maintenanceByScenario[key] || []; - - tableBody.innerHTML = ""; - - if (!records.length) { - emptyState.textContent = - "No maintenance entries recorded for this scenario yet."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - return; - } - - emptyState.classList.add("hidden"); - tableWrapper.classList.remove("hidden"); - - records.forEach((record) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${formatDate(record.maintenance_date)} - ${record.equipment_name || "—"} - ${formatCost(record.cost)} - ${record.description || "—"} - `; - tableBody.appendChild(row); - }); - }; - - const populateEquipmentOptions = (scenarioId) => { - if (!equipmentSelect) { - return; - } - - equipmentSelect.innerHTML = - ''; - equipmentSelect.disabled = true; - - if (equipmentEmptyState) { - equipmentEmptyState.classList.add("hidden"); - } - - if (!scenarioId) { - return; - } - - const list = equipmentByScenario[String(scenarioId)] || []; - if (!list.length) { - if (equipmentEmptyState) { - equipmentEmptyState.textContent = - "Add equipment for this scenario before scheduling maintenance."; - equipmentEmptyState.classList.remove("hidden"); - } - return; - } - - list.forEach((item) => { - const option = document.createElement("option"); - option.value = item.id; - option.textContent = item.name || `Equipment ${item.id}`; - equipmentSelect.appendChild(option); - }); - - equipmentSelect.disabled = false; - }; - - if (filterSelect) { - filterSelect.addEventListener("change", (event) => { - const value = event.target.value; - if (!value) { - if (emptyState && tableWrapper && tableBody) { - emptyState.textContent = - "Choose a scenario to review upcoming or completed maintenance."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - tableBody.innerHTML = ""; - } - return; - } - renderMaintenanceRows(value); - }); - } - - if (formScenarioSelect) { - formScenarioSelect.addEventListener("change", (event) => { - const value = event.target.value; - populateEquipmentOptions(value); - }); - } - - const submitMaintenance = async (event) => { - event.preventDefault(); - hideFeedback(); - - if (!form) { - return; - } - - const formData = new FormData(form); - const scenarioId = formData.get("scenario_id"); - const equipmentId = formData.get("equipment_id"); - const payload = { - scenario_id: scenarioId ? Number(scenarioId) : null, - equipment_id: equipmentId ? 
Number(equipmentId) : null, - maintenance_date: formData.get("maintenance_date"), - cost: Number(formData.get("cost")), - description: formData.get("description") || null, - }; - - if (!payload.scenario_id || !payload.equipment_id) { - showFeedback( - "Select a scenario and equipment before submitting.", - "error" - ); - return; - } - - try { - const response = await fetch("/api/maintenance/", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorDetail = await response.json().catch(() => ({})); - throw new Error( - errorDetail.detail || "Unable to add maintenance entry." - ); - } - - const result = await response.json(); - const mapKey = String(result.scenario_id); - - if (!Array.isArray(maintenanceByScenario[mapKey])) { - maintenanceByScenario[mapKey] = []; - } - - const equipmentList = equipmentByScenario[mapKey] || []; - const matchedEquipment = equipmentList.find( - (item) => Number(item.id) === Number(result.equipment_id) - ); - result.equipment_name = matchedEquipment ? matchedEquipment.name : ""; - - maintenanceByScenario[mapKey].push(result); - - form.reset(); - populateEquipmentOptions(null); - showFeedback("Maintenance entry saved.", "success"); - - if (filterSelect && filterSelect.value === String(result.scenario_id)) { - renderMaintenanceRows(filterSelect.value); - } - } catch (error) { - showFeedback(error.message || "An unexpected error occurred.", "error"); - } - }; - - if (form) { - form.addEventListener("submit", submitMaintenance); - } - - if (filterSelect && filterSelect.value) { - renderMaintenanceRows(filterSelect.value); - } - - if (formScenarioSelect && formScenarioSelect.value) { - populateEquipmentOptions(formScenarioSelect.value); - } -}); diff --git a/static/js/navigation.js b/static/js/navigation.js new file mode 100644 index 0000000..f1113fc --- /dev/null +++ b/static/js/navigation.js @@ -0,0 +1,53 @@ +// Navigation chevron buttons logic +document.addEventListener("DOMContentLoaded", function () { + const navPrev = document.getElementById("nav-prev"); + const navNext = document.getElementById("nav-next"); + + if (!navPrev || !navNext) return; + + // Define the navigation order (main pages) + const navPages = [ + window.NAVIGATION_URLS.dashboard, + window.NAVIGATION_URLS.projects, + window.NAVIGATION_URLS.imports, + window.NAVIGATION_URLS.simulations, + window.NAVIGATION_URLS.reporting, + window.NAVIGATION_URLS.settings, + ]; + + const currentPath = window.location.pathname; + + // Find current index + let currentIndex = -1; + for (let i = 0; i < navPages.length; i++) { + if (currentPath.startsWith(navPages[i])) { + currentIndex = i; + break; + } + } + + // If not found, disable both + if (currentIndex === -1) { + navPrev.disabled = true; + navNext.disabled = true; + return; + } + + // Set up prev button + if (currentIndex > 0) { + navPrev.addEventListener("click", function () { + window.location.href = navPages[currentIndex - 1]; + }); + } else { + navPrev.disabled = true; + } + + // Set up next button + if (currentIndex < navPages.length - 1) { + navNext.addEventListener("click", function () { + window.location.href = navPages[currentIndex + 1]; + }); + } else { + navNext.disabled = true; + } +}); diff --git a/static/js/navigation_sidebar.js b/static/js/navigation_sidebar.js new file mode 100644 index 0000000..8be73a2 --- /dev/null +++ b/static/js/navigation_sidebar.js @@ -0,0 +1,230 @@ +(function () { + const NAV_ENDPOINT = "/navigation/sidebar"; + const 
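+ // Endpoint and DOM hooks for client-side sidebar hydration. The dataset
+ // attribute names below are camelCase because they are written through
+ // element.dataset, which maps them to data-navigation-* attributes.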
SIDEBAR_SELECTOR = ".sidebar-nav"; + const DATA_SOURCE_ATTR = "navigationSource"; + const ROLE_ATTR = "navigationRoles"; + + function onReady(callback) { + if (document.readyState === "loading") { + document.addEventListener("DOMContentLoaded", callback, { once: true }); + } else { + callback(); + } + } + + function isActivePath(pathname, matchPrefix) { + if (!matchPrefix) { + return false; + } + if (matchPrefix === "/") { + return pathname === "/"; + } + return pathname.startsWith(matchPrefix); + } + + function createAnchor({ + href, + label, + matchPrefix, + tooltip, + isExternal, + isActive, + className, + }) { + const anchor = document.createElement("a"); + anchor.href = href; + anchor.className = className + (isActive ? " is-active" : ""); + anchor.dataset.matchPrefix = matchPrefix || href; + if (tooltip) { + anchor.title = tooltip; + } + if (isExternal) { + anchor.target = "_blank"; + anchor.rel = "noopener noreferrer"; + anchor.classList.add("is-external"); + } + anchor.textContent = label; + return anchor; + } + + function buildLinkBlock(link, pathname) { + if (!link || !link.href) { + return null; + } + const matchPrefix = link.match_prefix || link.matchPrefix || link.href; + const isActive = isActivePath(pathname, matchPrefix); + + const block = document.createElement("div"); + block.className = "sidebar-link-block"; + if (typeof link.id === "number") { + block.dataset.linkId = String(link.id); + } + + const anchor = createAnchor({ + href: link.href, + label: link.label, + matchPrefix, + tooltip: link.tooltip, + isExternal: Boolean(link.is_external ?? link.isExternal), + isActive, + className: "sidebar-link", + }); + block.appendChild(anchor); + + const children = Array.isArray(link.children) ? link.children : []; + if (children.length > 0) { + const container = document.createElement("div"); + container.className = "sidebar-sublinks"; + for (const child of children) { + if (!child || !child.href) { + continue; + } + const childMatch = + child.match_prefix || child.matchPrefix || child.href; + const childActive = isActivePath(pathname, childMatch); + const childAnchor = createAnchor({ + href: child.href, + label: child.label, + matchPrefix: childMatch, + tooltip: child.tooltip, + isExternal: Boolean(child.is_external ?? child.isExternal), + isActive: childActive, + className: "sidebar-sublink", + }); + container.appendChild(childAnchor); + } + if (container.children.length > 0) { + block.appendChild(container); + } + } + + return block; + } + + function buildGroupSection(group, pathname) { + if (!group) { + return null; + } + const links = Array.isArray(group.links) ? 
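+ // Defensive normalization, as elsewhere in this file: keys may arrive
+ // snake_case (match_prefix, is_external) or camelCase, and missing arrays
+ // fall back to [] so a partial payload never throws.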
group.links : []; + if (links.length === 0) { + return null; + } + + const section = document.createElement("div"); + section.className = "sidebar-section"; + if (typeof group.id === "number") { + section.dataset.groupId = String(group.id); + } + + const label = document.createElement("div"); + label.className = "sidebar-section-label"; + label.textContent = group.label; + section.appendChild(label); + + const linksContainer = document.createElement("div"); + linksContainer.className = "sidebar-section-links"; + + for (const link of links) { + const block = buildLinkBlock(link, pathname); + if (block) { + linksContainer.appendChild(block); + } + } + + if (linksContainer.children.length === 0) { + return null; + } + + section.appendChild(linksContainer); + return section; + } + + function buildEmptyState() { + const section = document.createElement("div"); + section.className = "sidebar-section sidebar-empty-state"; + + const label = document.createElement("div"); + label.className = "sidebar-section-label"; + label.textContent = "Navigation"; + section.appendChild(label); + + const copyWrapper = document.createElement("div"); + copyWrapper.className = "sidebar-section-links"; + + const copy = document.createElement("p"); + copy.className = "sidebar-empty-copy"; + copy.textContent = "Navigation is unavailable."; + copyWrapper.appendChild(copy); + + section.appendChild(copyWrapper); + return section; + } + + function renderSidebar(navContainer, payload) { + const pathname = window.location.pathname; + const groups = Array.isArray(payload?.groups) ? payload.groups : []; + navContainer.replaceChildren(); + + const rendered = []; + for (const group of groups) { + const section = buildGroupSection(group, pathname); + if (section) { + rendered.push(section); + } + } + + if (rendered.length === 0) { + navContainer.appendChild(buildEmptyState()); + navContainer.dataset[DATA_SOURCE_ATTR] = "client-empty"; + delete navContainer.dataset[ROLE_ATTR]; + return; + } + + for (const section of rendered) { + navContainer.appendChild(section); + } + + navContainer.dataset[DATA_SOURCE_ATTR] = "client"; + const roles = Array.isArray(payload?.roles) ? 
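+ // Assumed response shape (inferred from this handler, not a documented
+ // contract): { groups: [{ id, label, links: [...] }], roles: [...] }.
+ // Roles are only echoed into a data attribute on the container.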
payload.roles : []; + if (roles.length > 0) { + navContainer.dataset[ROLE_ATTR] = roles.join(","); + } else { + delete navContainer.dataset[ROLE_ATTR]; + } + } + + async function hydrateSidebar(navContainer) { + try { + const response = await fetch(NAV_ENDPOINT, { + method: "GET", + credentials: "include", + headers: { + Accept: "application/json", + }, + }); + + if (!response.ok) { + if (response.status !== 401 && response.status !== 403) { + console.warn( + "Navigation sidebar hydration failed with status", + response.status + ); + } + return; + } + + const payload = await response.json(); + renderSidebar(navContainer, payload); + } catch (error) { + console.warn("Navigation sidebar hydration failed", error); + } + } + + onReady(() => { + const navContainer = document.querySelector(SIDEBAR_SELECTOR); + if (!navContainer) { + return; + } + hydrateSidebar(navContainer); + }); +})(); diff --git a/static/js/notifications.js b/static/js/notifications.js new file mode 100644 index 0000000..8842ba5 --- /dev/null +++ b/static/js/notifications.js @@ -0,0 +1,38 @@ +(() => { + let container; + + function ensureContainer() { + if (!container) { + container = document.createElement("div"); + container.className = "toast-container"; + document.body.appendChild(container); + } + return container; + } + + function show({ message, level = "info", timeout = 5000 } = {}) { + const root = ensureContainer(); + const toast = document.createElement("div"); + toast.className = `toast toast--${level}`; + toast.setAttribute("role", "alert"); + toast.innerHTML = ` + +
+      <span class="toast__message">${message}</span>
+      <button type="button" class="toast__close" aria-label="Close">&times;</button>
+ + `; + root.appendChild(toast); + + const close = () => { + toast.classList.add("hidden"); + setTimeout(() => toast.remove(), 200); + }; + + toast.querySelector(".toast__close").addEventListener("click", close); + + if (timeout > 0) { + setTimeout(close, timeout); + } + } + + window.NotificationCenter = { show }; +})(); diff --git a/static/js/parameters.js b/static/js/parameters.js deleted file mode 100644 index b96d207..0000000 --- a/static/js/parameters.js +++ /dev/null @@ -1,124 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("parameters-data"); - let parametersByScenario = {}; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - parametersByScenario = parsed; - } - } catch (error) { - console.error("Unable to parse parameter data", error); - } - } - - const form = document.getElementById("parameter-form"); - const scenarioSelect = /** @type {HTMLSelectElement | null} */ ( - document.getElementById("scenario_id") - ); - const nameInput = /** @type {HTMLInputElement | null} */ ( - document.getElementById("name") - ); - const valueInput = /** @type {HTMLInputElement | null} */ ( - document.getElementById("value") - ); - const feedback = document.getElementById("parameter-feedback"); - const tableBody = document.getElementById("parameter-table-body"); - - const setFeedback = (message, variant) => { - if (!feedback) { - return; - } - feedback.textContent = message; - feedback.classList.remove("success", "error"); - if (variant) { - feedback.classList.add(variant); - } - }; - - const renderTable = (scenarioId) => { - if (!tableBody) { - return; - } - tableBody.innerHTML = ""; - const rows = parametersByScenario[String(scenarioId)] || []; - if (!rows.length) { - const emptyRow = document.createElement("tr"); - emptyRow.id = "parameter-empty-state"; - emptyRow.innerHTML = - 'No parameters recorded for this scenario yet.'; - tableBody.appendChild(emptyRow); - return; - } - rows.forEach((row) => { - const tr = document.createElement("tr"); - tr.innerHTML = ` - ${row.name} - ${row.value} - ${row.distribution_type ?? "—"} - ${ - row.distribution_parameters - ? 
JSON.stringify(row.distribution_parameters) - : "—" - } - `; - tableBody.appendChild(tr); - }); - }; - - if (scenarioSelect) { - renderTable(scenarioSelect.value); - scenarioSelect.addEventListener("change", () => - renderTable(scenarioSelect.value) - ); - } - - if (!form || !scenarioSelect || !nameInput || !valueInput) { - return; - } - - form.addEventListener("submit", async (event) => { - event.preventDefault(); - - const scenarioId = scenarioSelect.value; - const payload = { - scenario_id: Number(scenarioId), - name: nameInput.value.trim(), - value: Number(valueInput.value), - }; - - if (!payload.name) { - setFeedback("Parameter name is required.", "error"); - return; - } - - if (!Number.isFinite(payload.value)) { - setFeedback("Enter a numeric value.", "error"); - return; - } - - const response = await fetch("/api/parameters/", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorText = await response.text(); - setFeedback(`Error saving parameter: ${errorText}`, "error"); - return; - } - - const data = await response.json(); - const scenarioKey = String(scenarioId); - parametersByScenario[scenarioKey] = parametersByScenario[scenarioKey] || []; - parametersByScenario[scenarioKey].push(data); - - form.reset(); - scenarioSelect.value = scenarioKey; - renderTable(scenarioKey); - nameInput.focus(); - setFeedback("Parameter saved.", "success"); - }); -}); diff --git a/static/js/production.js b/static/js/production.js deleted file mode 100644 index 9d19a41..0000000 --- a/static/js/production.js +++ /dev/null @@ -1,204 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("production-data"); - let data = { scenarios: [], production: {}, unit_options: [] }; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - data = { - scenarios: Array.isArray(parsed.scenarios) ? parsed.scenarios : [], - production: - parsed.production && typeof parsed.production === "object" - ? parsed.production - : {}, - unit_options: Array.isArray(parsed.unit_options) - ? 
parsed.unit_options - : [], - }; - } - } catch (error) { - console.error("Unable to parse production data", error); - } - } - - const productionByScenario = data.production; - const filterSelect = document.getElementById("production-scenario-filter"); - const tableWrapper = document.getElementById("production-table-wrapper"); - const tableBody = document.getElementById("production-table-body"); - const emptyState = document.getElementById("production-empty"); - const form = document.getElementById("production-form"); - const feedbackEl = document.getElementById("production-feedback"); - const unitSelect = document.getElementById("production-form-unit"); - const unitSymbolInput = document.getElementById("production-form-unit-symbol"); - - const showFeedback = (message, type = "success") => { - if (!feedbackEl) { - return; - } - feedbackEl.textContent = message; - feedbackEl.classList.remove("hidden", "success", "error"); - feedbackEl.classList.add(type); - }; - - const hideFeedback = () => { - if (!feedbackEl) { - return; - } - feedbackEl.classList.add("hidden"); - feedbackEl.textContent = ""; - }; - - const formatAmount = (value) => - Number(value).toLocaleString(undefined, { - minimumFractionDigits: 2, - maximumFractionDigits: 2, - }); - - const formatMeasurement = (amount, symbol, name) => { - if (symbol) { - return `${formatAmount(amount)} ${symbol}`; - } - if (name) { - return `${formatAmount(amount)} ${name}`; - } - return formatAmount(amount); - }; - - const renderProductionRows = (scenarioId) => { - if (!tableBody || !tableWrapper || !emptyState) { - return; - } - - const key = String(scenarioId); - const records = productionByScenario[key] || []; - - tableBody.innerHTML = ""; - - if (!records.length) { - emptyState.textContent = - "No production output recorded for this scenario yet."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - return; - } - - emptyState.classList.add("hidden"); - tableWrapper.classList.remove("hidden"); - - records.forEach((record) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${formatMeasurement( - record.amount, - record.unit_symbol, - record.unit_name - )} - ${record.description || "—"} - `; - tableBody.appendChild(row); - }); - }; - - if (filterSelect) { - filterSelect.addEventListener("change", (event) => { - const value = event.target.value; - if (!value) { - if (emptyState && tableWrapper && tableBody) { - emptyState.textContent = - "Choose a scenario to review its production output."; - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - tableBody.innerHTML = ""; - } - return; - } - renderProductionRows(value); - }); - } - - const submitProduction = async (event) => { - event.preventDefault(); - hideFeedback(); - - if (!form) { - return; - } - - const formData = new FormData(form); - const scenarioId = formData.get("scenario_id"); - const unitName = formData.get("unit_name"); - const unitSymbol = formData.get("unit_symbol"); - const payload = { - scenario_id: scenarioId ? Number(scenarioId) : null, - amount: Number(formData.get("amount")), - description: formData.get("description") || null, - unit_name: unitName ? String(unitName) : null, - unit_symbol: unitSymbol ? 
String(unitSymbol) : null, - }; - - try { - const response = await fetch("/api/production/", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorDetail = await response.json().catch(() => ({})); - throw new Error( - errorDetail.detail || "Unable to add production output record." - ); - } - - const result = await response.json(); - const mapKey = String(result.scenario_id); - - if (!Array.isArray(productionByScenario[mapKey])) { - productionByScenario[mapKey] = []; - } - productionByScenario[mapKey].push(result); - - form.reset(); - syncUnitSelection(); - showFeedback("Production output saved.", "success"); - - if (filterSelect && filterSelect.value === String(result.scenario_id)) { - renderProductionRows(filterSelect.value); - } - } catch (error) { - showFeedback(error.message || "An unexpected error occurred.", "error"); - } - }; - - if (form) { - form.addEventListener("submit", submitProduction); - } - - const syncUnitSelection = () => { - if (!unitSelect || !unitSymbolInput) { - return; - } - if (!unitSelect.value && unitSelect.options.length > 0) { - const firstOption = Array.from(unitSelect.options).find( - (option) => option.value - ); - if (firstOption) { - firstOption.selected = true; - } - } - const selectedOption = unitSelect.options[unitSelect.selectedIndex]; - unitSymbolInput.value = selectedOption - ? selectedOption.getAttribute("data-symbol") || "" - : ""; - }; - - if (unitSelect) { - unitSelect.addEventListener("change", syncUnitSelection); - syncUnitSelection(); - } - - if (filterSelect && filterSelect.value) { - renderProductionRows(filterSelect.value); - } -}); diff --git a/static/js/projects.js b/static/js/projects.js new file mode 100644 index 0000000..3d85798 --- /dev/null +++ b/static/js/projects.js @@ -0,0 +1,137 @@ +document.addEventListener("DOMContentLoaded", () => { + const container = document.querySelector("[data-project-table]"); + const filterInput = document.querySelector("[data-project-filter]"); + + const resolveFilterItems = () => { + if (!container) { + return []; + } + + const entries = Array.from( + container.querySelectorAll("[data-project-entry]") + ); + + if (entries.length) { + return entries; + } + + if (container.tagName === "TABLE") { + return Array.from(container.querySelectorAll("tbody tr")); + } + + return []; + }; + + const filterItems = resolveFilterItems(); + + if (container && filterInput && filterItems.length) { + filterInput.addEventListener("input", () => { + const query = filterInput.value.trim().toLowerCase(); + filterItems.forEach((item) => { + const match = item.textContent.toLowerCase().includes(query); + item.style.display = match ? 
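+ // Plain substring filter: hide any row whose visible text does not
+ // contain the query; an empty query matches every row.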
"" : "none"; + }); + }); + } + + const sidebar = document.querySelector(".app-sidebar"); + const appMain = document.querySelector(".app-main"); + if (!sidebar || !appMain) { + return; + } + + const body = document.body; + const mobileQuery = window.matchMedia("(max-width: 900px)"); + let toggleButton = document.querySelector("[data-sidebar-toggle]"); + + if (!toggleButton) { + toggleButton = document.createElement("button"); + toggleButton.type = "button"; + toggleButton.className = "sidebar-toggle"; + toggleButton.setAttribute("data-sidebar-toggle", ""); + toggleButton.setAttribute("aria-expanded", "false"); + toggleButton.setAttribute("aria-label", "Toggle primary navigation"); + toggleButton.hidden = true; + toggleButton.innerHTML = [ + '', + 'Menu', + ].join(""); + appMain.insertBefore(toggleButton, appMain.firstChild); + } + + let overlay = document.querySelector("[data-sidebar-overlay]"); + if (!overlay) { + overlay = document.createElement("div"); + overlay.className = "sidebar-overlay"; + overlay.setAttribute("data-sidebar-overlay", ""); + overlay.setAttribute("aria-hidden", "true"); + document.body.appendChild(overlay); + } + + const primaryNav = document.querySelector(".sidebar-nav"); + if (primaryNav) { + if (!primaryNav.id) { + primaryNav.id = "primary-navigation"; + } + toggleButton.setAttribute("aria-controls", primaryNav.id); + } + + const openSidebar = () => { + body.classList.remove("sidebar-collapsed"); + body.classList.add("sidebar-open"); + toggleButton.setAttribute("aria-expanded", "true"); + overlay.setAttribute("aria-hidden", "false"); + }; + + const closeSidebar = (focusToggle = false) => { + body.classList.add("sidebar-collapsed"); + body.classList.remove("sidebar-open"); + toggleButton.setAttribute("aria-expanded", "false"); + overlay.setAttribute("aria-hidden", "true"); + if (focusToggle) { + toggleButton.focus({ preventScroll: true }); + } + }; + + const toggleSidebar = () => { + if (body.classList.contains("sidebar-open")) { + closeSidebar(); + } else { + openSidebar(); + sidebar.setAttribute("aria-hidden", "false"); + } + }; + + const applyResponsiveState = (mql) => { + if (!mql.matches) { + toggleButton.hidden = true; + body.classList.remove("sidebar-open", "sidebar-collapsed"); + sidebar.setAttribute("aria-hidden", "true"); + overlay.setAttribute("aria-hidden", "true"); + sidebar.removeAttribute("aria-hidden"); + return; + } + + toggleButton.hidden = false; + if (!body.classList.contains("sidebar-open")) { + body.classList.add("sidebar-collapsed"); + sidebar.setAttribute("aria-hidden", "true"); + } + }; + + toggleButton.addEventListener("click", toggleSidebar); + overlay.addEventListener("click", () => closeSidebar()); + + document.addEventListener("keydown", (event) => { + if (event.key === "Escape" && body.classList.contains("sidebar-open")) { + closeSidebar(true); + } + }); + + applyResponsiveState(mobileQuery); + if (typeof mobileQuery.addEventListener === "function") { + mobileQuery.addEventListener("change", applyResponsiveState); + } else if (typeof mobileQuery.addListener === "function") { + mobileQuery.addListener(applyResponsiveState); + } +}); diff --git a/static/js/reporting.js b/static/js/reporting.js deleted file mode 100644 index 3ca2f64..0000000 --- a/static/js/reporting.js +++ /dev/null @@ -1,149 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("reporting-data"); - let reportingSummaries = []; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || 
"[]"); - if (Array.isArray(parsed)) { - reportingSummaries = parsed; - } - } catch (error) { - console.error("Unable to parse reporting data", error); - } - } - - const REPORT_FIELDS = [ - { key: "iterations", label: "Iterations", decimals: 0 }, - { key: "mean", label: "Mean Result", decimals: 2 }, - { key: "variance", label: "Variance", decimals: 2 }, - { key: "std_dev", label: "Std. Dev", decimals: 2 }, - { key: "percentile_5", label: "Percentile 5", decimals: 2 }, - { key: "percentile_95", label: "Percentile 95", decimals: 2 }, - { key: "value_at_risk_95", label: "Value at Risk (95%)", decimals: 2 }, - { - key: "expected_shortfall_95", - label: "Expected Shortfall (95%)", - decimals: 2, - }, - ]; - - const tableWrapper = document.getElementById("reporting-table-wrapper"); - const tableBody = document.getElementById("reporting-table-body"); - const emptyState = document.getElementById("reporting-empty"); - const refreshButton = document.getElementById("report-refresh"); - const feedbackEl = document.getElementById("report-feedback"); - - const formatNumber = (value, decimals = 2) => { - if (value === null || value === undefined || Number.isNaN(Number(value))) { - return "—"; - } - return Number(value).toLocaleString(undefined, { - minimumFractionDigits: decimals, - maximumFractionDigits: decimals, - }); - }; - - const showFeedback = (message, type = "success") => { - if (!feedbackEl) { - return; - } - feedbackEl.textContent = message; - feedbackEl.classList.remove("hidden", "success", "error"); - feedbackEl.classList.add(type); - }; - - const hideFeedback = () => { - if (!feedbackEl) { - return; - } - feedbackEl.classList.add("hidden"); - feedbackEl.textContent = ""; - }; - - const renderReportingTable = (summaryData) => { - if (!tableBody || !tableWrapper || !emptyState) { - return; - } - - tableBody.innerHTML = ""; - - if (!summaryData.length) { - emptyState.classList.remove("hidden"); - tableWrapper.classList.add("hidden"); - return; - } - - emptyState.classList.add("hidden"); - tableWrapper.classList.remove("hidden"); - - summaryData.forEach((entry) => { - const row = document.createElement("tr"); - const scenarioCell = document.createElement("td"); - scenarioCell.textContent = entry.scenario_name; - row.appendChild(scenarioCell); - - REPORT_FIELDS.forEach((field) => { - const cell = document.createElement("td"); - const source = field.key === "iterations" ? 
entry : entry.summary || {}; - cell.textContent = formatNumber(source[field.key], field.decimals); - row.appendChild(cell); - }); - - tableBody.appendChild(row); - }); - }; - - const refreshMetrics = async () => { - hideFeedback(); - showFeedback("Refreshing metrics…", "success"); - - try { - const response = await fetch("/ui/reporting", { - method: "GET", - headers: { "X-Requested-With": "XMLHttpRequest" }, - }); - - if (!response.ok) { - throw new Error("Unable to refresh reporting data."); - } - - const text = await response.text(); - const parser = new DOMParser(); - const doc = parser.parseFromString(text, "text/html"); - const newTable = doc.querySelector("#reporting-table-wrapper"); - const newFeedback = doc.querySelector("#report-feedback"); - - if (!newTable) { - throw new Error("Unexpected response while refreshing."); - } - - const newEmptyState = doc.querySelector("#reporting-empty"); - - if (emptyState && newEmptyState) { - emptyState.className = newEmptyState.className; - emptyState.textContent = newEmptyState.textContent; - } - - if (tableWrapper) { - tableWrapper.className = newTable.className; - tableWrapper.innerHTML = newTable.innerHTML; - } - - if (newFeedback && feedbackEl) { - feedbackEl.className = newFeedback.className; - feedbackEl.textContent = newFeedback.textContent; - } - - showFeedback("Metrics refreshed.", "success"); - } catch (error) { - showFeedback(error.message || "An unexpected error occurred.", "error"); - } - }; - - renderReportingTable(reportingSummaries); - - if (refreshButton) { - refreshButton.addEventListener("click", refreshMetrics); - } -}); diff --git a/static/js/scenario-form.js b/static/js/scenario-form.js deleted file mode 100644 index f27722a..0000000 --- a/static/js/scenario-form.js +++ /dev/null @@ -1,78 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const form = document.getElementById("scenario-form"); - if (!form) { - return; - } - - const nameInput = /** @type {HTMLInputElement | null} */ ( - document.getElementById("name") - ); - const descriptionInput = /** @type {HTMLInputElement | null} */ ( - document.getElementById("description") - ); - const table = document.getElementById("scenario-table"); - const tableBody = document.getElementById("scenario-table-body"); - const emptyState = document.getElementById("empty-state"); - - form.addEventListener("submit", async (event) => { - event.preventDefault(); - - if (!nameInput || !descriptionInput) { - return; - } - - const payload = { - name: nameInput.value.trim(), - description: descriptionInput.value.trim() || null, - }; - - if (!payload.name) { - return; - } - - const response = await fetch("/api/scenarios/", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorText = await response.text(); - console.error("Scenario creation failed", errorText); - return; - } - - const data = await response.json(); - const row = document.createElement("tr"); - row.dataset.scenarioId = String(data.id); - row.innerHTML = ` - ${data.name} - ${data.description ?? 
"—"} - `; - - if (emptyState) { - emptyState.remove(); - } - - if (table) { - table.classList.remove("hidden"); - table.removeAttribute("aria-hidden"); - } - - if (tableBody) { - tableBody.appendChild(row); - } - - form.reset(); - nameInput.focus(); - - const feedback = document.getElementById("feedback"); - if (feedback) { - feedback.textContent = `Scenario "${data.name}" created successfully.`; - feedback.classList.remove("hidden"); - setTimeout(() => { - feedback.classList.add("hidden"); - }, 3000); - } - }); -}); diff --git a/static/js/settings.js b/static/js/settings.js deleted file mode 100644 index 8d7d5c6..0000000 --- a/static/js/settings.js +++ /dev/null @@ -1,200 +0,0 @@ -(function () { - const dataScript = document.getElementById("theme-settings-data"); - const form = document.getElementById("theme-settings-form"); - const feedbackEl = document.getElementById("theme-settings-feedback"); - const resetBtn = document.getElementById("theme-settings-reset"); - const panel = document.getElementById("theme-settings"); - - if (!dataScript || !form || !feedbackEl || !panel) { - return; - } - - const apiUrl = panel.getAttribute("data-api"); - if (!apiUrl) { - return; - } - - const parsed = JSON.parse(dataScript.textContent || "{}"); - const currentValues = { ...(parsed.variables || {}) }; - const defaultValues = parsed.defaults || {}; - let envOverrides = { ...(parsed.envOverrides || {}) }; - - const previewElements = new Map(); - const inputs = Array.from(form.querySelectorAll(".color-value-input")); - - inputs.forEach((input) => { - const key = input.name; - const field = input.closest(".color-form-field"); - const preview = field ? field.querySelector(".color-preview") : null; - if (preview) { - previewElements.set(input, preview); - } - - if (Object.prototype.hasOwnProperty.call(envOverrides, key)) { - const overrideValue = envOverrides[key]; - input.value = overrideValue; - input.disabled = true; - input.setAttribute("aria-disabled", "true"); - input.dataset.envOverride = "true"; - if (field) { - field.classList.add("is-env-override"); - } - if (preview) { - preview.style.background = overrideValue; - } - return; - } - - input.addEventListener("input", () => { - const previewEl = previewElements.get(input); - if (previewEl) { - previewEl.style.background = input.value || defaultValues[key] || ""; - } - }); - }); - - function setFeedback(message, type) { - feedbackEl.textContent = message; - feedbackEl.classList.remove("hidden", "success", "error"); - if (type) { - feedbackEl.classList.add(type); - } - } - - function clearFeedback() { - feedbackEl.textContent = ""; - feedbackEl.classList.add("hidden"); - feedbackEl.classList.remove("success", "error"); - } - - function updateRootVariables(values) { - if (!values) { - return; - } - const root = document.documentElement; - Object.entries(values).forEach(([key, value]) => { - if (typeof key === "string" && typeof value === "string") { - root.style.setProperty(key, value); - } - }); - } - - function resetTo(source) { - inputs.forEach((input) => { - const key = input.name; - if (input.disabled) { - const previewEl = previewElements.get(input); - const fallback = envOverrides[key] || currentValues[key]; - if (previewEl && fallback) { - previewEl.style.background = fallback; - } - return; - } - if (Object.prototype.hasOwnProperty.call(source, key)) { - input.value = source[key]; - const previewEl = previewElements.get(input); - if (previewEl) { - previewEl.style.background = source[key]; - } - } - }); - } - - // Initialize previews to 
current values after page load. - resetTo(currentValues); - - resetBtn?.addEventListener("click", () => { - resetTo(defaultValues); - clearFeedback(); - setFeedback("Reverted to default values. Submit to save.", "success"); - }); - - form.addEventListener("submit", async (event) => { - event.preventDefault(); - clearFeedback(); - - const payload = {}; - inputs.forEach((input) => { - if (input.disabled) { - return; - } - payload[input.name] = input.value.trim(); - }); - - try { - const response = await fetch(apiUrl, { - method: "PUT", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ variables: payload }), - }); - - if (!response.ok) { - let detail = "Unable to save theme settings."; - try { - const errorData = await response.json(); - if (errorData?.detail) { - detail = Array.isArray(errorData.detail) - ? errorData.detail.map((item) => item.msg || item).join("; ") - : errorData.detail; - } - } catch (parseError) { - // Ignore JSON parse errors and use default detail message. - } - setFeedback(detail, "error"); - return; - } - - const data = await response.json(); - const variables = data?.variables || {}; - const responseOverrides = data?.env_overrides || {}; - - Object.assign(currentValues, variables); - envOverrides = { ...responseOverrides }; - - inputs.forEach((input) => { - const key = input.name; - const field = input.closest(".color-form-field"); - const previewEl = previewElements.get(input); - const isOverride = Object.prototype.hasOwnProperty.call( - envOverrides, - key, - ); - - if (isOverride) { - const overrideValue = envOverrides[key]; - input.value = overrideValue; - if (!input.disabled) { - input.disabled = true; - input.setAttribute("aria-disabled", "true"); - } - if (field) { - field.classList.add("is-env-override"); - } - if (previewEl) { - previewEl.style.background = overrideValue; - } - } else if (input.disabled) { - input.disabled = false; - input.removeAttribute("aria-disabled"); - if (field) { - field.classList.remove("is-env-override"); - } - if ( - previewEl && - Object.prototype.hasOwnProperty.call(variables, key) - ) { - previewEl.style.background = variables[key]; - } - } - }); - - updateRootVariables(variables); - resetTo(variables); - setFeedback("Theme colors updated successfully.", "success"); - } catch (error) { - setFeedback("Network error: unable to save settings.", "error"); - } - }); -})(); diff --git a/static/js/simulations.js b/static/js/simulations.js deleted file mode 100644 index 9e9c8d6..0000000 --- a/static/js/simulations.js +++ /dev/null @@ -1,354 +0,0 @@ -document.addEventListener("DOMContentLoaded", () => { - const dataElement = document.getElementById("simulations-data"); - let simulationScenarios = []; - let initialRuns = []; - - if (dataElement) { - try { - const parsed = JSON.parse(dataElement.textContent || "{}"); - if (parsed && typeof parsed === "object") { - if (Array.isArray(parsed.scenarios)) { - simulationScenarios = parsed.scenarios; - } - if (Array.isArray(parsed.runs)) { - initialRuns = parsed.runs; - } - } - } catch (error) { - console.error("Unable to parse simulations data", error); - } - } - - const SUMMARY_FIELDS = [ - { key: "count", label: "Iterations", decimals: 0 }, - { key: "mean", label: "Mean Result", decimals: 2 }, - { key: "median", label: "Median Result", decimals: 2 }, - { key: "min", label: "Minimum", decimals: 2 }, - { key: "max", label: "Maximum", decimals: 2 }, - { key: "variance", label: "Variance", decimals: 2 }, - { key: "std_dev", label: "Standard Deviation", decimals: 2 }, 
- { key: "percentile_5", label: "Percentile 5", decimals: 2 }, - { key: "percentile_95", label: "Percentile 95", decimals: 2 }, - { key: "value_at_risk_95", label: "Value at Risk (95%)", decimals: 2 }, - { - key: "expected_shortfall_95", - label: "Expected Shortfall (95%)", - decimals: 2, - }, - ]; - const SAMPLE_RESULT_LIMIT = 20; - - const filterSelect = document.getElementById("simulations-scenario-filter"); - const overviewWrapper = document.getElementById( - "simulations-overview-wrapper" - ); - const overviewBody = document.getElementById("simulations-overview-body"); - const overviewEmpty = document.getElementById("simulations-overview-empty"); - const emptyState = document.getElementById("simulations-empty"); - const summaryWrapper = document.getElementById("simulations-summary-wrapper"); - const summaryBody = document.getElementById("simulations-summary-body"); - const summaryEmpty = document.getElementById("simulations-summary-empty"); - const resultsWrapper = document.getElementById("simulations-results-wrapper"); - const resultsBody = document.getElementById("simulations-results-body"); - const resultsEmpty = document.getElementById("simulations-results-empty"); - const simulationForm = document.getElementById("simulation-run-form"); - const simulationFeedback = document.getElementById("simulation-feedback"); - const formScenarioSelect = document.getElementById( - "simulation-form-scenario" - ); - - const simulationRunsMap = Object.create(null); - - const getScenarioName = (id) => { - const match = simulationScenarios.find( - (scenario) => String(scenario.id) === String(id) - ); - return match ? match.name : `Scenario ${id}`; - }; - - const formatNumber = (value, decimals = 2) => { - if (value === null || value === undefined || Number.isNaN(Number(value))) { - return "—"; - } - return Number(value).toLocaleString(undefined, { - minimumFractionDigits: decimals, - maximumFractionDigits: decimals, - }); - }; - - const showFeedback = (element, message, type = "success") => { - if (!element) { - return; - } - element.textContent = message; - element.classList.remove("hidden", "success", "error"); - element.classList.add(type); - }; - - const hideFeedback = (element) => { - if (!element) { - return; - } - element.classList.add("hidden"); - element.textContent = ""; - }; - - const initializeRunsMap = () => { - simulationScenarios.forEach((scenario) => { - const key = String(scenario.id); - simulationRunsMap[key] = { - scenario_id: scenario.id, - scenario_name: scenario.name, - iterations: 0, - summary: null, - sample_results: [], - }; - }); - - initialRuns.forEach((run) => { - const key = String(run.scenario_id); - simulationRunsMap[key] = { - scenario_id: run.scenario_id, - scenario_name: run.scenario_name || getScenarioName(key), - iterations: run.iterations || 0, - summary: run.summary || null, - sample_results: Array.isArray(run.sample_results) - ? 
run.sample_results - : [], - }; - }); - }; - - const renderOverviewTable = () => { - if (!overviewBody) { - return; - } - - overviewBody.innerHTML = ""; - - if (!simulationScenarios.length) { - if (overviewWrapper) { - overviewWrapper.classList.add("hidden"); - } - if (overviewEmpty) { - overviewEmpty.classList.remove("hidden"); - } - return; - } - - if (overviewWrapper) { - overviewWrapper.classList.remove("hidden"); - } - if (overviewEmpty) { - overviewEmpty.classList.add("hidden"); - } - - simulationScenarios.forEach((scenario) => { - const key = String(scenario.id); - const run = simulationRunsMap[key]; - const iterations = run && run.iterations ? run.iterations : 0; - const meanValue = - iterations && run && run.summary ? run.summary.mean : null; - - const row = document.createElement("tr"); - row.innerHTML = ` - ${scenario.name} - ${iterations || 0} - ${iterations ? formatNumber(meanValue) : "—"} - `; - overviewBody.appendChild(row); - }); - }; - - const renderScenarioDetails = (scenarioId) => { - if (!scenarioId) { - if (emptyState) { - emptyState.classList.remove("hidden"); - } - if (summaryWrapper) { - summaryWrapper.classList.add("hidden"); - } - if (summaryEmpty) { - summaryEmpty.classList.add("hidden"); - } - if (resultsWrapper) { - resultsWrapper.classList.add("hidden"); - } - if (resultsEmpty) { - resultsEmpty.classList.add("hidden"); - } - return; - } - - if (emptyState) { - emptyState.classList.add("hidden"); - } - - const run = simulationRunsMap[String(scenarioId)]; - const summary = run ? run.summary : null; - const samples = run ? run.sample_results || [] : []; - - if (!summary) { - if (summaryWrapper) { - summaryWrapper.classList.add("hidden"); - } - if (summaryEmpty) { - summaryEmpty.classList.remove("hidden"); - } - } else { - if (summaryWrapper) { - summaryWrapper.classList.remove("hidden"); - } - if (summaryEmpty) { - summaryEmpty.classList.add("hidden"); - } - - if (summaryBody) { - summaryBody.innerHTML = ""; - SUMMARY_FIELDS.forEach((field) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${field.label} - ${formatNumber(summary[field.key], field.decimals)} - `; - summaryBody.appendChild(row); - }); - } - } - - if (!samples.length) { - if (resultsWrapper) { - resultsWrapper.classList.add("hidden"); - } - if (resultsEmpty) { - resultsEmpty.classList.remove("hidden"); - } - } else { - if (resultsWrapper) { - resultsWrapper.classList.remove("hidden"); - } - if (resultsEmpty) { - resultsEmpty.classList.add("hidden"); - } - - if (resultsBody) { - resultsBody.innerHTML = ""; - samples.slice(0, SAMPLE_RESULT_LIMIT).forEach((item, index) => { - const row = document.createElement("tr"); - row.innerHTML = ` - ${index + 1} - ${formatNumber(item)} - `; - resultsBody.appendChild(row); - }); - } - } - }; - - const runSimulation = async (event) => { - event.preventDefault(); - hideFeedback(simulationFeedback); - - if (!simulationForm) { - return; - } - - const formData = new FormData(simulationForm); - const scenarioId = formData.get("scenario_id"); - const payload = { - scenario_id: scenarioId ? Number(scenarioId) : null, - iterations: Number(formData.get("iterations")), - seed: formData.get("seed") ? 
Number(formData.get("seed")) : null, - }; - - if (!payload.scenario_id) { - showFeedback( - simulationFeedback, - "Select a scenario before running a simulation.", - "error" - ); - return; - } - - try { - const response = await fetch("/api/simulations/", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(payload), - }); - - if (!response.ok) { - const errorDetail = await response.json().catch(() => ({})); - throw new Error(errorDetail.detail || "Unable to run simulation."); - } - - const result = await response.json(); - const mapKey = String(result.scenario_id); - const summary = - result.summary && typeof result.summary === "object" - ? result.summary - : null; - const iterations = - summary && typeof summary.count === "number" - ? summary.count - : payload.iterations || 0; - - simulationRunsMap[mapKey] = { - scenario_id: result.scenario_id, - scenario_name: getScenarioName(mapKey), - iterations, - summary, - sample_results: Array.isArray(result.sample_results) - ? result.sample_results - : [], - }; - - renderOverviewTable(); - renderScenarioDetails(mapKey); - - if (filterSelect) { - filterSelect.value = mapKey; - } - if (formScenarioSelect) { - formScenarioSelect.value = mapKey; - } - - simulationForm.reset(); - showFeedback(simulationFeedback, "Simulation completed.", "success"); - } catch (error) { - showFeedback( - simulationFeedback, - error.message || "An unexpected error occurred.", - "error" - ); - } - }; - - initializeRunsMap(); - renderOverviewTable(); - - if (filterSelect) { - filterSelect.addEventListener("change", (event) => { - const value = event.target.value; - renderScenarioDetails(value); - }); - } - - if (formScenarioSelect) { - formScenarioSelect.addEventListener("change", (event) => { - const value = event.target.value; - if (filterSelect) { - filterSelect.value = value; - } - renderScenarioDetails(value); - }); - } - - if (simulationForm) { - simulationForm.addEventListener("submit", runSimulation); - } - - if (filterSelect && filterSelect.value) { - renderScenarioDetails(filterSelect.value); - } -}); diff --git a/templates/ParameterInput.html b/templates/ParameterInput.html deleted file mode 100644 index 10f30b1..0000000 --- a/templates/ParameterInput.html +++ /dev/null @@ -1,51 +0,0 @@ -{% extends "base.html" %} {% block title %}Process Parameters · CalMiner{% -endblock %} {% block content %} -
-

Scenario Parameters

- {% if scenarios %} -
- - - - -
- -
- - - - - - - - - - -
- <th>Parameter</th> <th>Value</th> <th>Distribution</th> <th>Details</th>
-
- {% else %} -

- No scenarios available. Create a scenario before - adding parameters. -

- {% endif %} -
-{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/ScenarioForm.html b/templates/ScenarioForm.html deleted file mode 100644 index fc5e6ab..0000000 --- a/templates/ScenarioForm.html +++ /dev/null @@ -1,53 +0,0 @@ -{% extends "base.html" %} {% block title %}Scenario Management · CalMiner{% -endblock %} {% block content %} -
-

Create a New Scenario

-
- - - -
- -
- {% if scenarios %} - - - - - - - - - {% for scenario in scenarios %} - - - - - {% endfor %} - -
- <th>Name</th> <th>Description</th>
- <td>{{ scenario.name }}</td> <td>{{ scenario.description or "—" }}</td>
- {% else %} -

- No scenarios yet. Create one to get started. -

- - - - - - - - - - {% endif %} -
-
-{% endblock %} {% block scripts %} {{ super() }} - -{% endblock %} diff --git a/templates/base.html b/templates/base.html index 53722db..15e313d 100644 --- a/templates/base.html +++ b/templates/base.html @@ -4,7 +4,10 @@ {% block title %}CalMiner{% endblock %} - + + + + {% block head_extra %}{% endblock %} @@ -20,6 +23,28 @@ {% block scripts %}{% endblock %} + + + + + + + diff --git a/templates/consumption.html b/templates/consumption.html deleted file mode 100644 index 63c1198..0000000 --- a/templates/consumption.html +++ /dev/null @@ -1,76 +0,0 @@ -{% extends "base.html" %} {% from "partials/components.html" import -select_field, feedback, empty_state, table_container with context %} {% block -title %}Consumption · CalMiner{% endblock %} {% block content %} -
-

Consumption Tracking

-
- {{ select_field( "Scenario filter", "consumption-scenario-filter", - options=scenarios, placeholder="Select a scenario" ) }} -
- {{ empty_state( "consumption-empty", "Choose a scenario to review its - consumption records." ) }} {% call table_container( - "consumption-table-wrapper", hidden=True, aria_label="Scenario consumption - records" ) %} - - - Amount - Description - - - - {% endcall %} -
- -
-

Add Consumption Record

- {% if scenarios %} -
- {{ select_field( "Scenario", "consumption-form-scenario", - name="scenario_id", options=scenarios, required=True, placeholder="Select a - scenario", placeholder_disabled=True ) }} - - - - - -
- {{ feedback("consumption-feedback") }} {% else %} -

- Create a scenario before adding consumption records. -

- {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/costs.html b/templates/costs.html deleted file mode 100644 index 417de97..0000000 --- a/templates/costs.html +++ /dev/null @@ -1,129 +0,0 @@ -{% extends "base.html" %} {% from "partials/components.html" import -select_field, feedback, empty_state, table_container with context %} {% block -title %}Costs · CalMiner{% endblock %} {% block content %} -
-

Cost Overview

- {% if scenarios %} -
- {{ select_field( "Scenario filter", "costs-scenario-filter", - options=scenarios, placeholder="Select a scenario" ) }} -
- {% else %} {{ empty_state( "costs-scenario-empty", "Create a scenario to - review cost information." ) }} {% endif %} {{ empty_state( "costs-empty", - "Choose a scenario to review CAPEX and OPEX details." ) }} - - -
- -
-

Add CAPEX Entry

- {% if scenarios %} -
- {{ select_field( "Scenario", "capex-form-scenario", name="scenario_id", - options=scenarios, required=True, placeholder="Select a scenario", - placeholder_disabled=True ) }} {{ select_field( "Currency", - "capex-form-currency", name="currency_code", options=currency_options, - required=True, placeholder="Select currency", placeholder_disabled=True, - value_attr="id", label_attr="name" ) }} - - - -
- {{ feedback("capex-feedback") }} {% else %} {{ empty_state( - "capex-form-empty", "Create a scenario before adding CAPEX entries." ) }} {% - endif %} -
- -
-

Add OPEX Entry

- {% if scenarios %} -
- {{ select_field( "Scenario", "opex-form-scenario", name="scenario_id", - options=scenarios, required=True, placeholder="Select a scenario", - placeholder_disabled=True ) }} {{ select_field( "Currency", - "opex-form-currency", name="currency_code", options=currency_options, - required=True, placeholder="Select currency", placeholder_disabled=True, - value_attr="id", label_attr="name" ) }} - - - -
- {{ feedback("opex-feedback") }} {% else %} {{ empty_state( "opex-form-empty", - "Create a scenario before adding OPEX entries." ) }} {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/currencies.html b/templates/currencies.html index 6c99515..d98d26f 100644 --- a/templates/currencies.html +++ b/templates/currencies.html @@ -1,131 +1,31 @@ {% extends "base.html" %} -{% from "partials/components.html" import select_field, feedback, empty_state, table_container with context %} - -{% block title %}Currencies · CalMiner{% endblock %} +{% block title %}{{ title }} | CalMiner{% endblock %} {% block content %} -
-
+
- - {% if currency_stats %} -
-
- Total Currencies - {{ currency_stats.total }} -
-
- Active - {{ currency_stats.active }} -
-
- Inactive - {{ currency_stats.inactive }} -
- {% else %} {{ empty_state("currencies-overview-empty", "No currency data - available yet.") }} {% endif %} {% call table_container( - "currencies-table-container", aria_label="Configured currencies", - heading="Configured Currencies" ) %} - - - Code - Name - Symbol - Status - Actions - - - - {% endcall %} {{ empty_state( "currencies-table-empty", "No currencies - configured yet.", hidden=currencies|length > 0 ) }} -
-
-
-
-

Manage Currencies

-

- Create new currencies or update existing configurations inline. -

+
+
+

Currency Configuration

+

Define available currencies and their properties.

+

Currency management coming soon

-
- {% set status_options = [ {"id": "true", "name": "Active"}, {"id": "false", - "name": "Inactive"} ] %} - -
- {{ select_field( "Currency to update (leave blank for new)", - "currency-form-existing", name="existing_code", options=currencies, - placeholder="Create a new currency", value_attr="code", label_attr="name" ) - }} - - - - - - - - {{ select_field( "Status", "currency-form-status", name="is_active", - options=status_options, include_blank=False ) }} - -
- - +
+

Exchange Rates

+

Configure and update currency exchange rates.

+

Exchange rate management coming soon

- - {{ feedback("currency-form-feedback") }} -
-{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} + +
+

Default Settings

+

Set default currencies for new projects and scenarios.

+

Default currency settings coming soon

+
+ +{% endblock %} \ No newline at end of file diff --git a/templates/dashboard.html b/templates/dashboard.html new file mode 100644 index 0000000..9eea868 --- /dev/null +++ b/templates/dashboard.html @@ -0,0 +1,178 @@ +{% extends "base.html" %} {% block title %}Dashboard · CalMiner{% endblock %} {% +block head_extra %} + +{% endblock %} {% block content %} + + +
+
+

Total Projects

+

{{ metrics.total_projects }}

+ Across all operation types +
+
+

Active Scenarios

+

{{ metrics.active_scenarios }}

+ Ready for analysis +
+
+

Pending Simulations

+

{{ metrics.pending_simulations }}

+ Awaiting execution +
+
+

Last Data Import

+

{{ metrics.last_import or '—' }}

+ UTC timestamp +
+
+ +
+
+
+
+

Recent Projects

+ View all +
+ {% if recent_projects %} + + + + + + + + + + {% for project in recent_projects %} + + + + + + {% endfor %} + +
+ <th>Project</th> <th>Operation</th> <th>Updated</th>
+ <td>{{ project.name }}</td>
+ <td>{{ project.operation_type.value.replace('_', ' ') | title }}</td>
+ <td>{{ project.updated_at.strftime('%Y-%m-%d') if project.updated_at else '—' }}</td>
+ {% else %} +

+ No recent projects. + Create one now. +

+ {% endif %} +
+ +
+
+

Simulation Pipeline

+
+ {% if simulation_updates %} +
+    <ul>
+      {% for update in simulation_updates %}
+      <li>
+        <span>{{ update.timestamp_label or '—' }}</span>
+        <strong>{{ update.title }}</strong>
+        <p>{{ update.description }}</p>
+      </li>
+      {% endfor %}
+    </ul>
+ {% else %} +

+ No simulation runs yet. Configure a scenario to start simulations. +

+ {% endif %} +
+
+ + +
+{% endblock %} diff --git a/templates/equipment.html b/templates/equipment.html deleted file mode 100644 index ed02537..0000000 --- a/templates/equipment.html +++ /dev/null @@ -1,78 +0,0 @@ -{% extends "base.html" %} {% block title %}Equipment · CalMiner{% endblock %} {% -block content %} -
-

Equipment Inventory

- {% if scenarios %} -
- -
- {% else %} -

- Create a scenario to view equipment inventory. -

- {% endif %} -
- Choose a scenario to review the equipment list. -
- -
- -
-

Add Equipment

- {% if scenarios %} -
- - - - -
- - {% else %} -

- Create a scenario before managing equipment. -

- {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/exports/modal.html b/templates/exports/modal.html new file mode 100644 index 0000000..4122e17 --- /dev/null +++ b/templates/exports/modal.html @@ -0,0 +1,52 @@ + diff --git a/templates/forgot_password.html b/templates/forgot_password.html index 4d21fd3..9618863 100644 --- a/templates/forgot_password.html +++ b/templates/forgot_password.html @@ -1,17 +1,25 @@ -{% extends "base.html" %} - -{% block title %}Forgot Password{% endblock %} - -{% block content %} +{% extends "base.html" %} {% block title %}Forgot Password{% endblock %} {% +block content %}
-

Forgot Password

-
-
- - -
- -
-

Remember your password? Login here

+

Forgot Password

+  {% if errors %}
+  <ul>
+    {% for error in errors %}
+    <li>{{ error }}</li>
+    {% endfor %}
+  </ul>
+  {% endif %}
+  {% if message %}
+  <p>{{ message }}</p>
+  {% endif %}
+ + +
+ +
+

Remember your password? Login here

{% endblock %} diff --git a/templates/imports/ui.html b/templates/imports/ui.html new file mode 100644 index 0000000..bb2324e --- /dev/null +++ b/templates/imports/ui.html @@ -0,0 +1,34 @@ +{% extends "base.html" %} +{% from "partials/alerts.html" import toast %} + +{% block title %}Imports · CalMiner{% endblock %} + +{% block head_extra %} + +{% endblock %} + +{% block content %} + + +
+
+

Upload Projects or Scenarios

+
+
+ {% include "partials/import_upload.html" %} + {% include "partials/import_preview_table.html" %} + + +
+
+ + {{ toast("import-toast", hidden=True) }} +{% endblock %} \ No newline at end of file diff --git a/templates/login.html b/templates/login.html index 6c2eb00..bff5388 100644 --- a/templates/login.html +++ b/templates/login.html @@ -1,22 +1,34 @@ -{% extends "base.html" %} - -{% block title %}Login{% endblock %} - -{% block content %} +{% extends "base.html" %} {% block title %}Login{% endblock %} {% block content +%}
-

Login

-
-
- - -
-
- - -
- -
-

Don't have an account? Register here

-

Forgot password?

+

Login

+  {% if errors %}
+  <ul>
+    {% for error in errors %}
+    <li>{{ error }}</li>
+    {% endfor %}
+  </ul>
+  {% endif %}
+
+ + +
+
+ + +
+ +
+

Don't have an account? Register here

+

Forgot password?

{% endblock %} diff --git a/templates/maintenance.html b/templates/maintenance.html deleted file mode 100644 index 51b0449..0000000 --- a/templates/maintenance.html +++ /dev/null @@ -1,111 +0,0 @@ -{% extends "base.html" %} {% block title %}Maintenance · CalMiner{% endblock %} -{% block content %} -
-

Maintenance Schedule

- {% if scenarios %} -
- -
- {% else %} -

- Create a scenario to view maintenance entries. -

- {% endif %} -
- Choose a scenario to review upcoming or completed maintenance. -
- -
- -
-

Add Maintenance Entry

- {% if scenarios %} -
- - - - - - - -
- - {% else %} -

- Create a scenario before managing maintenance - entries. -

- {% endif %} -
- -{% endblock %} {% block scripts %} {{ super() }} - - -{% endblock %} diff --git a/templates/partials/alerts.html b/templates/partials/alerts.html new file mode 100644 index 0000000..9315d62 --- /dev/null +++ b/templates/partials/alerts.html @@ -0,0 +1,10 @@ +{% macro toast(id, hidden=True, level="info", message="") %} + +{% endmacro %} diff --git a/templates/partials/base_footer.html b/templates/partials/base_footer.html index de97869..a990d9c 100644 --- a/templates/partials/base_footer.html +++ b/templates/partials/base_footer.html @@ -1,5 +1,8 @@