"""Unit tests for ``scripts.setup_database``.

Every database connection is replaced with ``unittest.mock`` doubles, so these
tests exercise the orchestration logic of ``DatabaseSetup`` (dry-run handling,
migration bookkeeping, seed verification, error wrapping, rollback
registration and the ``main`` entry point) without a live PostgreSQL server.
"""

import argparse
from pathlib import Path
from unittest import mock

import psycopg2
import pytest
from psycopg2 import errors as psycopg_errors

import scripts.setup_database as setup_db_module
from scripts import seed_data
from scripts.setup_database import DatabaseConfig, DatabaseSetup


def _connection_with_cursor() -> tuple[mock.MagicMock, mock.MagicMock]:
    """Build a mock connection whose ``cursor()`` is usable as a context manager.

    Returns:
        ``(connection_mock, cursor_mock)``. ``with connection_mock:`` yields
        ``connection_mock`` itself, and ``with connection_mock.cursor():``
        yields ``cursor_mock``, so assertions can be made directly on
        ``cursor_mock.execute`` / ``cursor_mock.fetchall``.
    """
    connection_mock = mock.MagicMock()
    connection_mock.__enter__.return_value = connection_mock
    cursor_context = mock.MagicMock()
    cursor_mock = mock.MagicMock()
    cursor_context.__enter__.return_value = cursor_mock
    connection_mock.cursor.return_value = cursor_context
    return connection_mock, cursor_mock


@pytest.fixture()
def mock_config() -> DatabaseConfig:
    """Return a PostgreSQL configuration aimed at a throwaway test database."""
    return DatabaseConfig(
        driver="postgresql",
        host="localhost",
        port=5432,
        database="calminer_test",
        user="calminer",
        password="secret",
        schema="public",
        admin_user="postgres",
        admin_password="secret",
    )


@pytest.fixture()
def setup_instance(mock_config: DatabaseConfig) -> DatabaseSetup:
    """Return a ``DatabaseSetup`` constructed in dry-run mode."""
    return DatabaseSetup(mock_config, dry_run=True)


def test_seed_baseline_data_dry_run_skips_verification(
    setup_instance: DatabaseSetup,
) -> None:
    """A dry-run seed forwards ``dry_run=True`` to the seeder and skips verification."""
    with (
        mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
        mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock,
    ):
        setup_instance.seed_baseline_data(dry_run=True)

        seed_run.assert_called_once()
        namespace_arg = seed_run.call_args[0][0]
        assert isinstance(namespace_arg, argparse.Namespace)
        assert namespace_arg.dry_run is True
        assert namespace_arg.currencies is True
        assert namespace_arg.units is True
        assert seed_run.call_args.kwargs["config"] is setup_instance.config
        verify_mock.assert_not_called()


def test_seed_baseline_data_invokes_verification(
    setup_instance: DatabaseSetup,
) -> None:
    """A real seed run verifies every currency and unit code from the seed tables."""
    expected_currencies = {code for code, *_ in seed_data.CURRENCY_SEEDS}
    expected_units = {code for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS}

    with (
        mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
        mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock,
    ):
        setup_instance.seed_baseline_data(dry_run=False)

        seed_run.assert_called_once()
        namespace_arg = seed_run.call_args[0][0]
        assert isinstance(namespace_arg, argparse.Namespace)
        assert namespace_arg.dry_run is False
        assert seed_run.call_args.kwargs["config"] is setup_instance.config
        verify_mock.assert_called_once_with(
            expected_currency_codes=expected_currencies,
            expected_unit_codes=expected_units,
        )


def test_run_migrations_applies_baseline_when_missing(
    mock_config: DatabaseConfig, tmp_path: Path
) -> None:
    """With nothing applied yet, only the baseline file is executed.

    The other migration file is not executed. It is only recorded through a
    two-argument ``cursor.execute`` call whose parameters are its file name
    (presumably marking it as already applied — confirm against
    ``run_migrations``).
    """
    setup_instance = DatabaseSetup(mock_config, dry_run=False)

    baseline = tmp_path / "000_base.sql"
    baseline.write_text("SELECT 1;", encoding="utf-8")
    other_migration = tmp_path / "20251022_add_other.sql"
    other_migration.write_text("SELECT 2;", encoding="utf-8")

    migration_calls: list[str] = []

    def capture_migration(cursor, schema_name: str, path: Path) -> str:
        # Stand-in for _apply_migration_file: record which file was applied.
        migration_calls.append(path.name)
        return path.name

    connection_mock, cursor_mock = _connection_with_cursor()

    with (
        mock.patch.object(
            setup_instance,
            "_application_connection",
            return_value=connection_mock,
        ),
        mock.patch.object(
            setup_instance, "_migrations_table_exists", return_value=True
        ),
        mock.patch.object(
            setup_instance, "_fetch_applied_migrations", return_value=set()
        ),
        mock.patch.object(
            setup_instance,
            "_apply_migration_file",
            side_effect=capture_migration,
        ) as apply_mock,
    ):
        setup_instance.run_migrations(tmp_path)

        assert apply_mock.call_count == 1
        assert migration_calls == ["000_base.sql"]

        legacy_marked = any(
            recorded.args[1] == ("20251022_add_other.sql",)
            for recorded in cursor_mock.execute.call_args_list
            if len(recorded.args) == 2
        )
        assert legacy_marked


def test_run_migrations_noop_when_all_files_already_applied(
    mock_config: DatabaseConfig, tmp_path: Path
) -> None:
    """When every file is already recorded, nothing is applied or executed."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)

    baseline = tmp_path / "000_base.sql"
    baseline.write_text("SELECT 1;", encoding="utf-8")
    other_migration = tmp_path / "20251022_add_other.sql"
    other_migration.write_text("SELECT 2;", encoding="utf-8")

    connection_mock, cursor_mock = _connection_with_cursor()

    with (
        mock.patch.object(
            setup_instance,
            "_application_connection",
            return_value=connection_mock,
        ),
        mock.patch.object(
            setup_instance, "_migrations_table_exists", return_value=True
        ),
        mock.patch.object(
            setup_instance,
            "_fetch_applied_migrations",
            return_value={"000_base.sql", "20251022_add_other.sql"},
        ),
        mock.patch.object(
            setup_instance, "_apply_migration_file"
        ) as apply_mock,
    ):
        setup_instance.run_migrations(tmp_path)

        apply_mock.assert_not_called()
        cursor_mock.execute.assert_not_called()


def test_verify_seeded_data_raises_when_currency_missing(
    mock_config: DatabaseConfig,
) -> None:
    """A missing expected currency code is named in the raised RuntimeError."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    connection_mock, cursor_mock = _connection_with_cursor()
    cursor_mock.fetchall.return_value = [("USD", True)]

    with mock.patch.object(
        setup_instance, "_application_connection", return_value=connection_mock
    ):
        with pytest.raises(RuntimeError, match="EUR"):
            setup_instance._verify_seeded_data(
                expected_currency_codes={"USD", "EUR"},
                expected_unit_codes=set(),
            )


def test_verify_seeded_data_raises_when_default_currency_inactive(
    mock_config: DatabaseConfig,
) -> None:
    """A currency row flagged inactive produces a RuntimeError mentioning 'inactive'."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    connection_mock, cursor_mock = _connection_with_cursor()
    cursor_mock.fetchall.return_value = [("USD", False)]

    with mock.patch.object(
        setup_instance, "_application_connection", return_value=connection_mock
    ):
        with pytest.raises(RuntimeError, match="inactive"):
            setup_instance._verify_seeded_data(
                expected_currency_codes={"USD"},
                expected_unit_codes=set(),
            )


def test_verify_seeded_data_raises_when_units_missing(
    mock_config: DatabaseConfig,
) -> None:
    """A missing expected measurement-unit code is named in the raised RuntimeError."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    connection_mock, cursor_mock = _connection_with_cursor()
    cursor_mock.fetchall.return_value = [("tonnes", True)]

    with mock.patch.object(
        setup_instance, "_application_connection", return_value=connection_mock
    ):
        with pytest.raises(RuntimeError, match="liters"):
            setup_instance._verify_seeded_data(
                expected_currency_codes=set(),
                expected_unit_codes={"tonnes", "liters"},
            )


def test_verify_seeded_data_raises_when_measurement_table_missing(
    mock_config: DatabaseConfig,
) -> None:
    """An UndefinedTable error is rolled back and re-raised as a RuntimeError."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    connection_mock, cursor_mock = _connection_with_cursor()
    cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable(
        "relation does not exist"
    )

    with mock.patch.object(
        setup_instance, "_application_connection", return_value=connection_mock
    ):
        with pytest.raises(RuntimeError, match="measurement_unit"):
            setup_instance._verify_seeded_data(
                expected_currency_codes=set(),
                expected_unit_codes={"tonnes"},
            )

    connection_mock.rollback.assert_called_once()


def test_seed_baseline_data_rerun_uses_existing_records(
    mock_config: DatabaseConfig,
) -> None:
    """Seeding twice runs the seeder and the verification queries on each pass."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    connection_mock, cursor_mock = _connection_with_cursor()

    currency_rows = [(code, True) for code, *_ in seed_data.CURRENCY_SEEDS]
    unit_rows = [(code, True) for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS]
    # One (currencies, units) fetch pair per seed_baseline_data call.
    cursor_mock.fetchall.side_effect = [
        currency_rows,
        unit_rows,
        currency_rows,
        unit_rows,
    ]

    with (
        mock.patch.object(
            setup_instance,
            "_application_connection",
            return_value=connection_mock,
        ),
        mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
    ):
        setup_instance.seed_baseline_data(dry_run=False)
        setup_instance.seed_baseline_data(dry_run=False)

        assert seed_run.call_count == 2
        first_namespace = seed_run.call_args_list[0].args[0]
        assert isinstance(first_namespace, argparse.Namespace)
        assert first_namespace.dry_run is False
        assert seed_run.call_args_list[0].kwargs["config"] is setup_instance.config
        assert cursor_mock.execute.call_count == 4


def test_ensure_database_raises_with_context(
    mock_config: DatabaseConfig,
) -> None:
    """A psycopg2 error while creating the database surfaces as a RuntimeError."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    # ensure_database uses the cursor returned directly by connection.cursor().
    connection_mock = mock.MagicMock()
    cursor_mock = mock.MagicMock()
    cursor_mock.fetchone.return_value = None
    # First execute is the existence check; the second (the create) fails.
    cursor_mock.execute.side_effect = [None, psycopg2.Error("create_fail")]
    connection_mock.cursor.return_value = cursor_mock

    with mock.patch.object(
        setup_instance, "_admin_connection", return_value=connection_mock
    ):
        with pytest.raises(RuntimeError, match="Failed to create database"):
            setup_instance.ensure_database()


def test_ensure_role_raises_with_context_during_creation(
    mock_config: DatabaseConfig,
) -> None:
    """A psycopg2 error while creating the role surfaces as a RuntimeError."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    admin_conn, admin_cursor = _connection_with_cursor()
    admin_cursor.fetchone.return_value = None
    admin_cursor.execute.side_effect = [None, psycopg2.Error("role_fail")]

    with mock.patch.object(
        setup_instance,
        "_admin_connection",
        side_effect=[admin_conn],
    ):
        with pytest.raises(RuntimeError, match="Failed to create role"):
            setup_instance.ensure_role()


def test_ensure_role_raises_with_context_during_privilege_grants(
    mock_config: DatabaseConfig,
) -> None:
    """With the role present, a grant failure on the second admin connection raises."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    admin_conn, admin_cursor = _connection_with_cursor()
    # A non-None fetchone means the role already exists.
    admin_cursor.fetchone.return_value = (1,)
    privilege_conn, privilege_cursor = _connection_with_cursor()
    privilege_cursor.execute.side_effect = [psycopg2.Error("grant_fail")]

    with mock.patch.object(
        setup_instance,
        "_admin_connection",
        side_effect=[admin_conn, privilege_conn],
    ):
        with pytest.raises(RuntimeError, match="Failed to grant privileges"):
            setup_instance.ensure_role()


def test_ensure_database_dry_run_skips_creation(
    mock_config: DatabaseConfig,
) -> None:
    """In dry-run mode only the existence check runs and the create is logged."""
    setup_instance = DatabaseSetup(mock_config, dry_run=True)
    connection_mock = mock.MagicMock()
    cursor_mock = mock.MagicMock()
    cursor_mock.fetchone.return_value = None
    connection_mock.cursor.return_value = cursor_mock

    with (
        mock.patch.object(
            setup_instance, "_admin_connection", return_value=connection_mock
        ),
        mock.patch("scripts.setup_database.logger") as logger_mock,
    ):
        setup_instance.ensure_database()

        # expect only existence check, no create attempt
        cursor_mock.execute.assert_called_once()
        logger_mock.info.assert_any_call(
            "Dry run: would create database '%s'. Run without --dry-run to proceed.",
            mock_config.database,
        )


def test_ensure_role_dry_run_skips_creation_and_grants(
    mock_config: DatabaseConfig,
) -> None:
    """In dry-run mode the role is neither created nor granted privileges."""
    setup_instance = DatabaseSetup(mock_config, dry_run=True)
    admin_conn, admin_cursor = _connection_with_cursor()
    admin_cursor.fetchone.return_value = None

    with (
        mock.patch.object(
            setup_instance,
            "_admin_connection",
            side_effect=[admin_conn],
        ) as conn_mock,
        mock.patch("scripts.setup_database.logger") as logger_mock,
    ):
        setup_instance.ensure_role()

        # A single admin connection: no second connection for grants.
        assert conn_mock.call_count == 1
        admin_cursor.execute.assert_called_once()
        logger_mock.info.assert_any_call(
            "Dry run: would create role '%s'. Run without --dry-run to apply.",
            mock_config.user,
        )


def test_register_rollback_skipped_when_dry_run(
    mock_config: DatabaseConfig,
) -> None:
    """Dry-run instances do not record rollback actions."""
    setup_instance = DatabaseSetup(mock_config, dry_run=True)

    setup_instance._register_rollback("noop", lambda: None)

    assert setup_instance._rollback_actions == []


def test_execute_rollbacks_runs_in_reverse_order(
    mock_config: DatabaseConfig,
) -> None:
    """Rollbacks run last-registered-first and the queue is emptied afterwards."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    calls: list[str] = []

    def first_action() -> None:
        calls.append("first")

    def second_action() -> None:
        calls.append("second")

    setup_instance._register_rollback("first", first_action)
    setup_instance._register_rollback("second", second_action)

    with mock.patch("scripts.setup_database.logger"):
        setup_instance.execute_rollbacks()

    assert calls == ["second", "first"]
    assert setup_instance._rollback_actions == []


def test_ensure_database_registers_rollback_action(
    mock_config: DatabaseConfig,
) -> None:
    """Creating the database registers a rollback that drops it again."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    connection_mock = mock.MagicMock()
    cursor_mock = mock.MagicMock()
    cursor_mock.fetchone.return_value = None
    connection_mock.cursor.return_value = cursor_mock

    with (
        mock.patch.object(
            setup_instance, "_admin_connection", return_value=connection_mock
        ),
        mock.patch.object(
            setup_instance, "_register_rollback"
        ) as register_mock,
        mock.patch.object(setup_instance, "_drop_database") as drop_mock,
    ):
        setup_instance.ensure_database()

        register_mock.assert_called_once()
        label, action = register_mock.call_args[0]
        assert "drop database" in label
        # Invoke the action while _drop_database is still patched so the real
        # method (which would contact a database) is never reached.
        action()
        drop_mock.assert_called_once_with(mock_config.database)


def test_ensure_role_registers_rollback_actions(
    mock_config: DatabaseConfig,
) -> None:
    """Creating a role registers 'drop role' then 'revoke privileges' rollbacks."""
    setup_instance = DatabaseSetup(mock_config, dry_run=False)
    admin_conn, admin_cursor = _connection_with_cursor()
    admin_cursor.fetchone.return_value = None
    privilege_conn, privilege_cursor = _connection_with_cursor()

    with (
        mock.patch.object(
            setup_instance,
            "_admin_connection",
            side_effect=[admin_conn, privilege_conn],
        ),
        mock.patch.object(
            setup_instance, "_register_rollback"
        ) as register_mock,
        mock.patch.object(setup_instance, "_drop_role") as drop_mock,
        mock.patch.object(
            setup_instance, "_revoke_role_privileges"
        ) as revoke_mock,
    ):
        setup_instance.ensure_role()

        assert register_mock.call_count == 2
        drop_label, drop_action = register_mock.call_args_list[0][0]
        revoke_label, revoke_action = register_mock.call_args_list[1][0]

        assert "drop role" in drop_label
        assert "revoke privileges" in revoke_label

        # Invoke the actions while the drop/revoke helpers are still patched.
        drop_action()
        drop_mock.assert_called_once_with(mock_config.user)

        revoke_action()
        revoke_mock.assert_called_once()


def test_main_triggers_rollbacks_on_failure(
    mock_config: DatabaseConfig,
) -> None:
    """main() runs and clears pending rollbacks, then re-raises the failure."""
    args = argparse.Namespace(
        ensure_database=True,
        ensure_role=True,
        ensure_schema=False,
        initialize_schema=False,
        run_migrations=False,
        seed_data=False,
        migrations_dir=None,
        db_driver=None,
        db_host=None,
        db_port=None,
        db_name=None,
        db_user=None,
        db_password=None,
        db_schema=None,
        admin_url=None,
        admin_user=None,
        admin_password=None,
        admin_db=None,
        dry_run=False,
        verbose=0,
    )

    with (
        mock.patch.object(setup_db_module, "parse_args", return_value=args),
        mock.patch.object(
            setup_db_module.DatabaseConfig, "from_env", return_value=mock_config
        ),
        mock.patch.object(setup_db_module, "DatabaseSetup") as setup_cls,
    ):
        setup_instance = mock.MagicMock()
        setup_instance.dry_run = False
        setup_instance._rollback_actions = [
            ("drop role", mock.MagicMock()),
        ]
        setup_instance.ensure_database.side_effect = RuntimeError("boom")
        setup_instance.execute_rollbacks = mock.MagicMock()
        setup_instance.clear_rollbacks = mock.MagicMock()
        setup_cls.return_value = setup_instance

        with pytest.raises(RuntimeError):
            setup_db_module.main()

        setup_instance.execute_rollbacks.assert_called_once()
        setup_instance.clear_rollbacks.assert_called_once()