+
+
+ {% set status_options = [ {"id": "true", "name": "Active"}, {"id": "false",
+ "name": "Inactive"} ] %}
+
+
+ {{ feedback("currency-form-feedback") }}
+
+{% endblock %} {% block scripts %} {{ super() }}
+
+
+{% endblock %}
diff --git a/templates/partials/base_footer.html b/templates/partials/base_footer.html
index 1d204e0..de97869 100644
--- a/templates/partials/base_footer.html
+++ b/templates/partials/base_footer.html
@@ -1,5 +1,8 @@
diff --git a/templates/partials/base_header.html b/templates/partials/base_header.html
index 1140a18..a8d67e2 100644
--- a/templates/partials/base_header.html
+++ b/templates/partials/base_header.html
@@ -2,6 +2,7 @@
("/", "Dashboard"),
("/ui/scenarios", "Scenarios"),
("/ui/parameters", "Parameters"),
+ ("/ui/currencies", "Currencies"),
("/ui/costs", "Costs"),
("/ui/consumption", "Consumption"),
("/ui/production", "Production"),
diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index f3a9d4b..011d6c7 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -64,6 +64,40 @@ def live_server() -> Generator[str, None, None]:
process.wait(timeout=5)
+@pytest.fixture(scope="session", autouse=True)
+def seed_default_currencies(live_server: str) -> None:
+ """Ensure a baseline set of currencies exists for UI flows."""
+
+ seeds = [
+ {"code": "EUR", "name": "Euro", "symbol": "EUR", "is_active": True},
+ {"code": "CLP", "name": "Chilean Peso", "symbol": "CLP$", "is_active": True},
+ ]
+
+ with httpx.Client(base_url=live_server, timeout=5.0) as client:
+ try:
+ response = client.get("/api/currencies/?include_inactive=true")
+ response.raise_for_status()
+ existing_codes = {
+ str(item.get("code"))
+ for item in response.json()
+ if isinstance(item, dict) and item.get("code")
+ }
+ except httpx.HTTPError as exc: # noqa: BLE001
+ raise RuntimeError("Failed to read existing currencies") from exc
+
+ for payload in seeds:
+ if payload["code"] in existing_codes:
+ continue
+ try:
+ create_response = client.post("/api/currencies/", json=payload)
+ except httpx.HTTPError as exc: # noqa: BLE001
+ raise RuntimeError("Failed to seed currencies") from exc
+
+ if create_response.status_code == 409:
+ continue
+ create_response.raise_for_status()
+
+
@pytest.fixture(scope="session")
def playwright_instance() -> Generator[Playwright, None, None]:
"""Provide a Playwright instance for the test session."""
diff --git a/tests/e2e/test_currencies.py b/tests/e2e/test_currencies.py
new file mode 100644
index 0000000..b467ad1
--- /dev/null
+++ b/tests/e2e/test_currencies.py
@@ -0,0 +1,130 @@
+import random
+import string
+
+from playwright.sync_api import Page, expect
+
+
+def _unique_currency_code(existing: set[str]) -> str:
+ """Generate a unique three-letter code not present in *existing*."""
+ alphabet = string.ascii_uppercase
+ for _ in range(100):
+ candidate = "".join(random.choices(alphabet, k=3))
+ if candidate not in existing and candidate != "USD":
+ return candidate
+ raise AssertionError(
+ "Unable to generate a unique currency code for the test run.")
+
+
+def _metric_value(page: Page, element_id: str) -> int:
+ locator = page.locator(f"#{element_id}")
+ expect(locator).to_be_visible()
+ return int(locator.inner_text().strip())
+
+
+def _expect_feedback(page: Page, expected_text: str) -> None:
+ page.wait_for_function(
+ "expected => {"
+ " const el = document.getElementById('currency-form-feedback');"
+ " if (!el) return false;"
+ " const text = (el.textContent || '').trim();"
+ " return !el.classList.contains('hidden') && text === expected;"
+ "}",
+ arg=expected_text,
+ )
+ feedback = page.locator("#currency-form-feedback")
+ expect(feedback).to_have_text(expected_text)
+
+
+def test_currency_workflow_create_update_toggle(page: Page) -> None:
+ """Exercise create, update, and toggle flows on the currency settings page."""
+ page.goto("/ui/currencies")
+ expect(page).to_have_title("Currencies · CalMiner")
+ expect(page.locator("h2:has-text('Currency Overview')")).to_be_visible()
+
+ code_cells = page.locator("#currencies-table-body tr td:nth-child(1)")
+ existing_codes = {text.strip().upper()
+ for text in code_cells.all_inner_texts()}
+
+ total_before = _metric_value(page, "currency-metric-total")
+ active_before = _metric_value(page, "currency-metric-active")
+ inactive_before = _metric_value(page, "currency-metric-inactive")
+
+ new_code = _unique_currency_code(existing_codes)
+ new_name = f"Test Currency {new_code}"
+ new_symbol = new_code[0]
+
+ page.fill("#currency-form-code", new_code)
+ page.fill("#currency-form-name", new_name)
+ page.fill("#currency-form-symbol", new_symbol)
+ page.select_option("#currency-form-status", "true")
+
+ with page.expect_response("**/api/currencies/") as create_info:
+ page.click("button[type='submit']")
+ create_response = create_info.value
+ assert create_response.status == 201
+
+ _expect_feedback(page, "Currency created successfully.")
+
+ page.wait_for_function(
+ "expected => Number(document.getElementById('currency-metric-total').textContent.trim()) === expected",
+ arg=total_before + 1,
+ )
+ page.wait_for_function(
+ "expected => Number(document.getElementById('currency-metric-active').textContent.trim()) === expected",
+ arg=active_before + 1,
+ )
+
+ row = page.locator("#currencies-table-body tr").filter(has_text=new_code)
+ expect(row).to_be_visible()
+ expect(row.locator("td").nth(3)).to_have_text("Active")
+
+ # Switch to update mode using the existing currency option.
+ page.select_option("#currency-form-existing", new_code)
+ updated_name = f"{new_name} Updated"
+ updated_symbol = f"{new_symbol}$"
+ page.fill("#currency-form-name", updated_name)
+ page.fill("#currency-form-symbol", updated_symbol)
+ page.select_option("#currency-form-status", "false")
+
+ with page.expect_response(f"**/api/currencies/{new_code}") as update_info:
+ page.click("button[type='submit']")
+ update_response = update_info.value
+ assert update_response.status == 200
+
+ _expect_feedback(page, "Currency updated successfully.")
+
+ page.wait_for_function(
+ "expected => Number(document.getElementById('currency-metric-active').textContent.trim()) === expected",
+ arg=active_before,
+ )
+ page.wait_for_function(
+ "expected => Number(document.getElementById('currency-metric-inactive').textContent.trim()) === expected",
+ arg=inactive_before + 1,
+ )
+
+ expect(row.locator("td").nth(1)).to_have_text(updated_name)
+ expect(row.locator("td").nth(2)).to_have_text(updated_symbol)
+ expect(row.locator("td").nth(3)).to_contain_text("Inactive")
+
+ toggle_button = row.locator("button[data-action='toggle']")
+ expect(toggle_button).to_have_text("Activate")
+
+ with page.expect_response(f"**/api/currencies/{new_code}/activation") as toggle_info:
+ toggle_button.click()
+ toggle_response = toggle_info.value
+ assert toggle_response.status == 200
+
+ page.wait_for_function(
+ "expected => Number(document.getElementById('currency-metric-active').textContent.trim()) === expected",
+ arg=active_before + 1,
+ )
+ page.wait_for_function(
+ "expected => Number(document.getElementById('currency-metric-inactive').textContent.trim()) === expected",
+ arg=inactive_before,
+ )
+
+ _expect_feedback(page, f"Currency {new_code} activated.")
+
+ expect(row.locator("td").nth(3)).to_contain_text("Active")
+ expect(row.locator("button[data-action='toggle']")
+ ).to_have_text("Deactivate")
diff --git a/tests/e2e/test_smoke.py b/tests/e2e/test_smoke.py
index a601dcc..01c0f18 100644
--- a/tests/e2e/test_smoke.py
+++ b/tests/e2e/test_smoke.py
@@ -14,6 +14,7 @@ UI_ROUTES = [
("/ui/maintenance", "Maintenance · CalMiner", "Maintenance Schedule"),
("/ui/simulations", "Simulations · CalMiner", "Monte Carlo Simulations"),
("/ui/reporting", "Reporting · CalMiner", "Scenario KPI Summary"),
+ ("/ui/currencies", "Currencies · CalMiner", "Currency Overview"),
]
diff --git a/tests/unit/test_currencies.py b/tests/unit/test_currencies.py
new file mode 100644
index 0000000..5aa674c
--- /dev/null
+++ b/tests/unit/test_currencies.py
@@ -0,0 +1,101 @@
+from typing import Dict
+
+import pytest
+
+from models.currency import Currency
+
+
+@pytest.fixture(autouse=True)
+def _cleanup_currencies(db_session):
+ db_session.query(Currency).delete()
+ db_session.commit()
+ yield
+ db_session.query(Currency).delete()
+ db_session.commit()
+
+
+def _assert_currency(payload: Dict[str, object], code: str, name: str, symbol: str | None, is_active: bool) -> None:
+ assert payload["code"] == code
+ assert payload["name"] == name
+ assert payload["is_active"] is is_active
+ if symbol is None:
+ assert payload["symbol"] is None
+ else:
+ assert payload["symbol"] == symbol
+
+
+def test_list_returns_default_currency(api_client, db_session):
+ response = api_client.get("/api/currencies/")
+ assert response.status_code == 200
+ data = response.json()
+ assert any(item["code"] == "USD" for item in data)
+
+
+def test_create_currency_success(api_client, db_session):
+ payload = {"code": "EUR", "name": "Euro", "symbol": "€", "is_active": True}
+ response = api_client.post("/api/currencies/", json=payload)
+ assert response.status_code == 201
+ data = response.json()
+ _assert_currency(data, "EUR", "Euro", "€", True)
+
+ stored = db_session.query(Currency).filter_by(code="EUR").one()
+ assert stored.name == "Euro"
+ assert stored.symbol == "€"
+ assert stored.is_active is True
+
+
+def test_create_currency_conflict(api_client, db_session):
+ api_client.post(
+ "/api/currencies/",
+ json={"code": "CAD", "name": "Canadian Dollar",
+ "symbol": "$", "is_active": True},
+ )
+ duplicate = api_client.post(
+ "/api/currencies/",
+ json={"code": "CAD", "name": "Canadian Dollar",
+ "symbol": "$", "is_active": True},
+ )
+ assert duplicate.status_code == 409
+
+
+def test_update_currency_fields(api_client, db_session):
+ api_client.post(
+ "/api/currencies/",
+ json={"code": "GBP", "name": "British Pound",
+ "symbol": "£", "is_active": True},
+ )
+
+ response = api_client.put(
+ "/api/currencies/GBP",
+ json={"name": "Pound Sterling", "symbol": "£", "is_active": False},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ _assert_currency(data, "GBP", "Pound Sterling", "£", False)
+
+
+def test_toggle_currency_activation(api_client, db_session):
+ api_client.post(
+ "/api/currencies/",
+ json={"code": "AUD", "name": "Australian Dollar",
+ "symbol": "A$", "is_active": True},
+ )
+
+ response = api_client.patch(
+ "/api/currencies/AUD/activation",
+ json={"is_active": False},
+ )
+ assert response.status_code == 200
+ data = response.json()
+ _assert_currency(data, "AUD", "Australian Dollar", "A$", False)
+
+
+def test_default_currency_cannot_be_deactivated(api_client, db_session):
+ api_client.get("/api/currencies/")
+ response = api_client.patch(
+ "/api/currencies/USD/activation",
+ json={"is_active": False},
+ )
+ assert response.status_code == 400
+ assert response.json()[
+ "detail"] == "The default currency cannot be deactivated."
diff --git a/tests/unit/test_setup_database.py b/tests/unit/test_setup_database.py
new file mode 100644
index 0000000..c67e1ab
--- /dev/null
+++ b/tests/unit/test_setup_database.py
@@ -0,0 +1,459 @@
+import argparse
+from unittest import mock
+
+import psycopg2
+import pytest
+from psycopg2 import errors as psycopg_errors
+
+import scripts.setup_database as setup_db_module
+
+from scripts import seed_data
+from scripts.setup_database import DatabaseConfig, DatabaseSetup
+
+
+@pytest.fixture()
+def mock_config() -> DatabaseConfig:
+ return DatabaseConfig(
+ driver="postgresql",
+ host="localhost",
+ port=5432,
+ database="calminer_test",
+ user="calminer",
+ password="secret",
+ schema="public",
+ admin_user="postgres",
+ admin_password="secret",
+ )
+
+
+@pytest.fixture()
+def setup_instance(mock_config: DatabaseConfig) -> DatabaseSetup:
+ return DatabaseSetup(mock_config, dry_run=True)
+
+
+def test_seed_baseline_data_dry_run_skips_verification(setup_instance: DatabaseSetup) -> None:
+ with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object(
+ setup_instance, "_verify_seeded_data"
+ ) as verify_mock:
+ setup_instance.seed_baseline_data(dry_run=True)
+
+ seed_run.assert_called_once()
+ namespace_arg = seed_run.call_args[0][0]
+ assert isinstance(namespace_arg, argparse.Namespace)
+ assert namespace_arg.dry_run is True
+ assert namespace_arg.currencies is True
+ assert namespace_arg.units is True
+ assert seed_run.call_args.kwargs["config"] is setup_instance.config
+ verify_mock.assert_not_called()
+
+
+def test_seed_baseline_data_invokes_verification(setup_instance: DatabaseSetup) -> None:
+ expected_currencies = {code for code, *_ in seed_data.CURRENCY_SEEDS}
+ expected_units = {code for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS}
+
+ with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object(
+ setup_instance, "_verify_seeded_data"
+ ) as verify_mock:
+ setup_instance.seed_baseline_data(dry_run=False)
+
+ seed_run.assert_called_once()
+ namespace_arg = seed_run.call_args[0][0]
+ assert isinstance(namespace_arg, argparse.Namespace)
+ assert namespace_arg.dry_run is False
+ assert seed_run.call_args.kwargs["config"] is setup_instance.config
+ verify_mock.assert_called_once_with(
+ expected_currency_codes=expected_currencies,
+ expected_unit_codes=expected_units,
+ )
+
+
+def test_run_migrations_applies_baseline_when_missing(mock_config: DatabaseConfig, tmp_path) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+
+ baseline = tmp_path / "000_base.sql"
+ baseline.write_text("SELECT 1;", encoding="utf-8")
+ other_migration = tmp_path / "20251022_add_other.sql"
+ other_migration.write_text("SELECT 2;", encoding="utf-8")
+
+ migration_calls: list[str] = []
+
+ def capture_migration(cursor, schema_name: str, path):
+ migration_calls.append(path.name)
+ return path.name
+
+ connection_mock = mock.MagicMock()
+ connection_mock.__enter__.return_value = connection_mock
+ cursor_context = mock.MagicMock()
+ cursor_mock = mock.MagicMock()
+ cursor_context.__enter__.return_value = cursor_mock
+ connection_mock.cursor.return_value = cursor_context
+
+ with mock.patch.object(
+ setup_instance, "_application_connection", return_value=connection_mock
+ ), mock.patch.object(
+ setup_instance, "_migrations_table_exists", return_value=True
+ ), mock.patch.object(
+ setup_instance, "_fetch_applied_migrations", return_value=set()
+ ), mock.patch.object(
+ setup_instance, "_apply_migration_file", side_effect=capture_migration
+ ) as apply_mock:
+ setup_instance.run_migrations(tmp_path)
+
+ assert apply_mock.call_count == 1
+ assert migration_calls == ["000_base.sql"]
+ legacy_marked = any(
+ call.args[1] == ("20251022_add_other.sql",)
+ for call in cursor_mock.execute.call_args_list
+ if len(call.args) == 2
+ )
+ assert legacy_marked
+
+
+def test_run_migrations_noop_when_all_files_already_applied(
+ mock_config: DatabaseConfig, tmp_path
+) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+
+ baseline = tmp_path / "000_base.sql"
+ baseline.write_text("SELECT 1;", encoding="utf-8")
+ other_migration = tmp_path / "20251022_add_other.sql"
+ other_migration.write_text("SELECT 2;", encoding="utf-8")
+
+ connection_mock, cursor_mock = _connection_with_cursor()
+
+ with mock.patch.object(
+ setup_instance, "_application_connection", return_value=connection_mock
+ ), mock.patch.object(
+ setup_instance, "_migrations_table_exists", return_value=True
+ ), mock.patch.object(
+ setup_instance,
+ "_fetch_applied_migrations",
+ return_value={"000_base.sql", "20251022_add_other.sql"},
+ ), mock.patch.object(
+ setup_instance, "_apply_migration_file"
+ ) as apply_mock:
+ setup_instance.run_migrations(tmp_path)
+
+ apply_mock.assert_not_called()
+ cursor_mock.execute.assert_not_called()
+
+
+def _connection_with_cursor() -> tuple[mock.MagicMock, mock.MagicMock]:
+ connection_mock = mock.MagicMock()
+ connection_mock.__enter__.return_value = connection_mock
+ cursor_context = mock.MagicMock()
+ cursor_mock = mock.MagicMock()
+ cursor_context.__enter__.return_value = cursor_mock
+ connection_mock.cursor.return_value = cursor_context
+ return connection_mock, cursor_mock
+
+
+def test_verify_seeded_data_raises_when_currency_missing(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+ connection_mock, cursor_mock = _connection_with_cursor()
+ cursor_mock.fetchall.return_value = [("USD", True)]
+
+ with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
+ with pytest.raises(RuntimeError) as exc:
+ setup_instance._verify_seeded_data(
+ expected_currency_codes={"USD", "EUR"},
+ expected_unit_codes=set(),
+ )
+
+ assert "EUR" in str(exc.value)
+
+
+def test_verify_seeded_data_raises_when_default_currency_inactive(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+ connection_mock, cursor_mock = _connection_with_cursor()
+ cursor_mock.fetchall.return_value = [("USD", False)]
+
+ with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
+ with pytest.raises(RuntimeError) as exc:
+ setup_instance._verify_seeded_data(
+ expected_currency_codes={"USD"},
+ expected_unit_codes=set(),
+ )
+
+ assert "inactive" in str(exc.value)
+
+
+def test_verify_seeded_data_raises_when_units_missing(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+ connection_mock, cursor_mock = _connection_with_cursor()
+ cursor_mock.fetchall.return_value = [("tonnes", True)]
+
+ with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
+ with pytest.raises(RuntimeError) as exc:
+ setup_instance._verify_seeded_data(
+ expected_currency_codes=set(),
+ expected_unit_codes={"tonnes", "liters"},
+ )
+
+ assert "liters" in str(exc.value)
+
+
+def test_verify_seeded_data_raises_when_measurement_table_missing(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+ connection_mock, cursor_mock = _connection_with_cursor()
+ cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable("relation does not exist")
+
+ with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
+ with pytest.raises(RuntimeError) as exc:
+ setup_instance._verify_seeded_data(
+ expected_currency_codes=set(),
+ expected_unit_codes={"tonnes"},
+ )
+
+ assert "measurement_unit" in str(exc.value)
+ connection_mock.rollback.assert_called_once()
+
+
+def test_seed_baseline_data_rerun_uses_existing_records(
+ mock_config: DatabaseConfig,
+) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+
+ connection_mock, cursor_mock = _connection_with_cursor()
+
+ currency_rows = [(code, True) for code, *_ in seed_data.CURRENCY_SEEDS]
+ unit_rows = [(code, True) for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS]
+
+ cursor_mock.fetchall.side_effect = [
+ currency_rows,
+ unit_rows,
+ currency_rows,
+ unit_rows,
+ ]
+
+ with mock.patch.object(
+ setup_instance, "_application_connection", return_value=connection_mock
+ ), mock.patch("scripts.seed_data.run_with_namespace") as seed_run:
+ setup_instance.seed_baseline_data(dry_run=False)
+ setup_instance.seed_baseline_data(dry_run=False)
+
+ assert seed_run.call_count == 2
+ first_namespace = seed_run.call_args_list[0].args[0]
+ assert isinstance(first_namespace, argparse.Namespace)
+ assert first_namespace.dry_run is False
+ assert seed_run.call_args_list[0].kwargs["config"] is setup_instance.config
+ assert cursor_mock.execute.call_count == 4
+
+
+def test_ensure_database_raises_with_context(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+ connection_mock = mock.MagicMock()
+ cursor_mock = mock.MagicMock()
+ cursor_mock.fetchone.return_value = None
+ cursor_mock.execute.side_effect = [None, psycopg2.Error("create_fail")]
+ connection_mock.cursor.return_value = cursor_mock
+
+ with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock):
+ with pytest.raises(RuntimeError) as exc:
+ setup_instance.ensure_database()
+
+ assert "Failed to create database" in str(exc.value)
+
+
+def test_ensure_role_raises_with_context_during_creation(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+
+ admin_conn, admin_cursor = _connection_with_cursor()
+ admin_cursor.fetchone.return_value = None
+ admin_cursor.execute.side_effect = [None, psycopg2.Error("role_fail")]
+
+ with mock.patch.object(
+ setup_instance,
+ "_admin_connection",
+ side_effect=[admin_conn],
+ ):
+ with pytest.raises(RuntimeError) as exc:
+ setup_instance.ensure_role()
+
+ assert "Failed to create role" in str(exc.value)
+
+
+def test_ensure_role_raises_with_context_during_privilege_grants(
+ mock_config: DatabaseConfig,
+) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+
+ admin_conn, admin_cursor = _connection_with_cursor()
+ admin_cursor.fetchone.return_value = (1,)
+
+ privilege_conn, privilege_cursor = _connection_with_cursor()
+ privilege_cursor.execute.side_effect = [psycopg2.Error("grant_fail")]
+
+ with mock.patch.object(
+ setup_instance,
+ "_admin_connection",
+ side_effect=[admin_conn, privilege_conn],
+ ):
+ with pytest.raises(RuntimeError) as exc:
+ setup_instance.ensure_role()
+
+ assert "Failed to grant privileges" in str(exc.value)
+
+
+def test_ensure_database_dry_run_skips_creation(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=True)
+
+ connection_mock = mock.MagicMock()
+ cursor_mock = mock.MagicMock()
+ cursor_mock.fetchone.return_value = None
+ connection_mock.cursor.return_value = cursor_mock
+
+ with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch(
+ "scripts.setup_database.logger"
+ ) as logger_mock:
+ setup_instance.ensure_database()
+
+ # expect only existence check, no create attempt
+ cursor_mock.execute.assert_called_once()
+ logger_mock.info.assert_any_call(
+ "Dry run: would create database '%s'. Run without --dry-run to proceed.", mock_config.database
+ )
+
+
+def test_ensure_role_dry_run_skips_creation_and_grants(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=True)
+
+ admin_conn, admin_cursor = _connection_with_cursor()
+ admin_cursor.fetchone.return_value = None
+
+ with mock.patch.object(
+ setup_instance,
+ "_admin_connection",
+ side_effect=[admin_conn],
+ ) as conn_mock, mock.patch("scripts.setup_database.logger") as logger_mock:
+ setup_instance.ensure_role()
+
+ assert conn_mock.call_count == 1
+ admin_cursor.execute.assert_called_once()
+ logger_mock.info.assert_any_call(
+ "Dry run: would create role '%s'. Run without --dry-run to apply.", mock_config.user
+ )
+
+
+def test_register_rollback_skipped_when_dry_run(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=True)
+ setup_instance._register_rollback("noop", lambda: None)
+ assert setup_instance._rollback_actions == []
+
+
+def test_execute_rollbacks_runs_in_reverse_order(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+
+ calls: list[str] = []
+
+ def first_action() -> None:
+ calls.append("first")
+
+ def second_action() -> None:
+ calls.append("second")
+
+ setup_instance._register_rollback("first", first_action)
+ setup_instance._register_rollback("second", second_action)
+
+ with mock.patch("scripts.setup_database.logger"):
+ setup_instance.execute_rollbacks()
+
+ assert calls == ["second", "first"]
+ assert setup_instance._rollback_actions == []
+
+
+def test_ensure_database_registers_rollback_action(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+ connection_mock = mock.MagicMock()
+ cursor_mock = mock.MagicMock()
+ cursor_mock.fetchone.return_value = None
+ connection_mock.cursor.return_value = cursor_mock
+
+ with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch.object(
+ setup_instance, "_register_rollback"
+ ) as register_mock, mock.patch.object(setup_instance, "_drop_database") as drop_mock:
+ setup_instance.ensure_database()
+ register_mock.assert_called_once()
+ label, action = register_mock.call_args[0]
+ assert "drop database" in label
+ action()
+ drop_mock.assert_called_once_with(mock_config.database)
+
+
+def test_ensure_role_registers_rollback_actions(mock_config: DatabaseConfig) -> None:
+ setup_instance = DatabaseSetup(mock_config, dry_run=False)
+
+ admin_conn, admin_cursor = _connection_with_cursor()
+ admin_cursor.fetchone.return_value = None
+ privilege_conn, privilege_cursor = _connection_with_cursor()
+
+ with mock.patch.object(
+ setup_instance,
+ "_admin_connection",
+ side_effect=[admin_conn, privilege_conn],
+ ), mock.patch.object(
+ setup_instance, "_register_rollback"
+ ) as register_mock, mock.patch.object(
+ setup_instance, "_drop_role"
+ ) as drop_mock, mock.patch.object(
+ setup_instance, "_revoke_role_privileges"
+ ) as revoke_mock:
+ setup_instance.ensure_role()
+ assert register_mock.call_count == 2
+ drop_label, drop_action = register_mock.call_args_list[0][0]
+ revoke_label, revoke_action = register_mock.call_args_list[1][0]
+
+ assert "drop role" in drop_label
+ assert "revoke privileges" in revoke_label
+
+ drop_action()
+ drop_mock.assert_called_once_with(mock_config.user)
+
+ revoke_action()
+ revoke_mock.assert_called_once()
+
+
+def test_main_triggers_rollbacks_on_failure(mock_config: DatabaseConfig) -> None:
+ args = argparse.Namespace(
+ ensure_database=True,
+ ensure_role=True,
+ ensure_schema=False,
+ initialize_schema=False,
+ run_migrations=False,
+ seed_data=False,
+ migrations_dir=None,
+ db_driver=None,
+ db_host=None,
+ db_port=None,
+ db_name=None,
+ db_user=None,
+ db_password=None,
+ db_schema=None,
+ admin_url=None,
+ admin_user=None,
+ admin_password=None,
+ admin_db=None,
+ dry_run=False,
+ verbose=0,
+ )
+
+ with mock.patch.object(setup_db_module, "parse_args", return_value=args), mock.patch.object(
+ setup_db_module.DatabaseConfig, "from_env", return_value=mock_config
+ ), mock.patch.object(
+ setup_db_module, "DatabaseSetup"
+ ) as setup_cls:
+ setup_instance = mock.MagicMock()
+ setup_instance.dry_run = False
+ setup_instance._rollback_actions = [
+ ("drop role", mock.MagicMock()),
+ ]
+ setup_instance.ensure_database.side_effect = RuntimeError("boom")
+ setup_instance.execute_rollbacks = mock.MagicMock()
+ setup_instance.clear_rollbacks = mock.MagicMock()
+ setup_cls.return_value = setup_instance
+
+ with pytest.raises(RuntimeError):
+ setup_db_module.main()
+
+ setup_instance.execute_rollbacks.assert_called_once()
+ setup_instance.clear_rollbacks.assert_called_once()