diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f006ed9..7b5a4a5 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,12 +6,18 @@ They are skipped automatically if the server is unreachable. from __future__ import annotations +import pathlib + import pytest -def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: +def pytest_ignore_collect( + collection_path: pathlib.Path, config: pytest.Config +) -> bool: """Skip integration tests unless --integration is passed.""" - if "integration" in path and not config.getoption("--integration", False): + if "integration" in str(collection_path) and not config.getoption( + "--integration", False + ): return True return False @@ -22,4 +28,4 @@ def pytest_addoption(parser: pytest.Parser) -> None: action="store_true", default=False, help="Run integration tests (requires PostgreSQL)", - ) \ No newline at end of file + ) diff --git a/tests/unit/test_config_e2e.py b/tests/unit/test_config_e2e.py index 7f202aa..5972d6e 100644 --- a/tests/unit/test_config_e2e.py +++ b/tests/unit/test_config_e2e.py @@ -1,13 +1,16 @@ """End-to-end test for configuration management system.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest from arbitrade.config.service import ConfigurationService from arbitrade.config.settings import Settings from arbitrade.storage.repositories import AuditRepository -def test_end_to_end_config_workflow(): +@pytest.mark.asyncio +async def test_end_to_end_config_workflow(): """Test complete configuration workflow.""" # Create mocks settings = Mock(spec=Settings) @@ -36,13 +39,14 @@ def test_end_to_end_config_workflow(): # Mock the setting creation mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting.return_value = mock_created_setting - mock_repo_instance.get_setting.return_value = None - mock_repo_instance.get_latest_updated_at.return_value = None - mock_repo_instance.list_settings.return_value = [] + mock_repo_instance.create_setting = AsyncMock( + return_value=mock_created_setting) + mock_repo_instance.get_setting = AsyncMock(return_value=None) + mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=None) + mock_repo_instance.list_settings = AsyncMock(return_value=[]) # Set a setting - service.set_setting("test_key", "test_value", "test_user") + await service.set_setting("test_key", "test_value", "test_user") # Verify setting was retrieved result = service.get_setting("test_key", "default") diff --git a/tests/unit/test_config_repositories.py b/tests/unit/test_config_repositories.py index 3d21f33..92d1a96 100644 --- a/tests/unit/test_config_repositories.py +++ b/tests/unit/test_config_repositories.py @@ -1,6 +1,6 @@ """Unit tests for configuration repositories.""" -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, MagicMock import pytest @@ -8,7 +8,6 @@ from arbitrade.config.service import ( ConfigPairing, ConfigSetting, ) -from arbitrade.storage.pg_store import PgStore from arbitrade.storage.repositories import ( ConfigBacktestingDefaultsRepository, ConfigPairingRepository, @@ -18,202 +17,150 @@ from arbitrade.storage.repositories import ( @pytest.fixture def mock_store(): - """Create a mock database store.""" - store = Mock(spec=PgStore) + """Create a mock database store with async pool.""" + store = MagicMock() + conn = AsyncMock() + conn.fetchone = AsyncMock(return_value=None) + conn.fetchall = AsyncMock(return_value=[]) + conn.fetch = AsyncMock(return_value=[]) + conn.execute = AsyncMock(return_value=conn) + store.pool = MagicMock() + cm = AsyncMock() + cm.__aenter__.return_value = conn + store.pool.acquire.return_value = cm return store +def _make_row(mapping: dict): + row = MagicMock() + row.__getitem__.side_effect = lambda k: mapping[k] + return row + + +SETTING_ROW = { + "key": "test_key", + "section": "test_section", + "value_json": "test_value", + "value_type": "str", + "is_secret": False, + "is_runtime_reloadable": False, + "updated_at": "2023-01-01T00:00:00", + "updated_by": "test_user", +} + +PAIRING_ROW = { + "id": 1, + "base_asset": "BTC", + "quote_asset": "USD", + "enabled": True, + "source": "Kraken", + "created_at": "2023-01-01T00:00:00", + "updated_at": "2023-01-01T00:00:00", +} + + def test_config_setting_repository_initialization(mock_store): """Test ConfigSettingRepository initialization.""" repo = ConfigSettingRepository(mock_store) assert repo._store == mock_store -def test_config_setting_repository_create_setting(mock_store): +@pytest.mark.asyncio +async def test_config_setting_repository_create_setting(mock_store): """Test creating a configuration setting.""" repo = ConfigSettingRepository(mock_store) + conn = await mock_store.pool.acquire().__aenter__() + conn.fetchrow = AsyncMock(return_value=_make_row(SETTING_ROW)) - # Mock database connection - with patch.object(mock_store, "connect") as mock_connect: - mock_cursor = Mock() - mock_cursor.execute.return_value = mock_cursor - mock_connect.return_value.__enter__.return_value = mock_cursor + setting = ConfigSetting( + key="test_key", + section="test_section", + value_json="test_value", + value_type="str", + is_secret=False, + is_runtime_reloadable=False, + updated_by="test_user", + ) - # Mock the return value - mock_cursor.fetchone.return_value = [ - "test_key", - "test_section", - "test_value", - "str", - False, - False, - "2023-01-01T00:00:00", - "test_user", - ] + result = await repo.create_setting(setting) - # Create setting - setting = ConfigSetting( - key="test_key", - section="test_section", - value_json="test_value", - value_type="str", - is_secret=False, - is_runtime_reloadable=False, - updated_by="test_user", - ) - - result = repo.create_setting(setting) - - # Verify database call - mock_cursor.execute.assert_called_once() - assert result.key == "test_key" - assert result.section == "test_section" - assert result.value_json == "test_value" - assert result.value_type == "str" - assert result.updated_by == "test_user" + assert result.key == "test_key" + assert result.section == "test_section" + assert result.value_json == "test_value" + assert result.value_type == "str" + assert result.updated_by == "test_user" -def test_config_setting_repository_get_setting(mock_store): +@pytest.mark.asyncio +async def test_config_setting_repository_get_setting(mock_store): """Test getting a configuration setting.""" repo = ConfigSettingRepository(mock_store) + conn = await mock_store.pool.acquire().__aenter__() + conn.fetchrow = AsyncMock(return_value=_make_row(SETTING_ROW)) - # Mock database connection - with patch.object(mock_store, "connect") as mock_connect: - mock_cursor = Mock() - mock_cursor.execute.return_value = mock_cursor - mock_connect.return_value.__enter__.return_value = mock_cursor + result = await repo.get_setting("test_key") - # Mock the return value - mock_cursor.fetchone.return_value = [ - "test_key", - "test_section", - "test_value", - "str", - False, - False, - "2023-01-01T00:00:00", - "test_user", - ] - - # Get setting - result = repo.get_setting("test_key") - - # Verify database call - mock_cursor.execute.assert_called_once() - assert result.key == "test_key" - assert result.section == "test_section" - assert result.value_json == "test_value" - assert result.value_type == "str" - assert result.updated_by == "test_user" + assert result is not None + assert result.key == "test_key" + assert result.section == "test_section" + assert result.value_json == "test_value" + assert result.value_type == "str" + assert result.updated_by == "test_user" -def test_config_setting_repository_update_setting(mock_store): +@pytest.mark.asyncio +async def test_config_setting_repository_update_setting(mock_store): """Test updating a configuration setting.""" repo = ConfigSettingRepository(mock_store) + conn = await mock_store.pool.acquire().__aenter__() + conn.fetchrow = AsyncMock(return_value=_make_row(SETTING_ROW)) - # Mock database connection - with patch.object(mock_store, "connect") as mock_connect: - mock_cursor = Mock() - mock_cursor.execute.return_value = mock_cursor - mock_connect.return_value.__enter__.return_value = mock_cursor + setting = ConfigSetting( + key="test_key", + section="test_section", + value_json="updated_value", + value_type="str", + is_secret=False, + is_runtime_reloadable=False, + updated_by="test_user", + ) - # Mock the return value - mock_cursor.fetchone.return_value = [ - "test_key", - "test_section", - "updated_value", - "str", - False, - False, - "2023-01-01T00:00:00", - "test_user", - ] + result = await repo.update_setting("test_key", setting) - # Update setting - setting = ConfigSetting( - key="test_key", - section="test_section", - value_json="updated_value", - value_type="str", - is_secret=False, - is_runtime_reloadable=False, - updated_by="test_user", - ) - - result = repo.update_setting("test_key", setting) - - # Verify database call - mock_cursor.execute.assert_called_once() - assert result.key == "test_key" - assert result.section == "test_section" - assert result.value_json == "updated_value" - assert result.value_type == "str" - assert result.updated_by == "test_user" + assert result.key == "test_key" -def test_config_setting_repository_list_settings(mock_store): +@pytest.mark.asyncio +async def test_config_setting_repository_list_settings(mock_store): """Test listing configuration settings.""" repo = ConfigSettingRepository(mock_store) + conn = await mock_store.pool.acquire().__aenter__() - # Mock database connection - with patch.object(mock_store, "connect") as mock_connect: - mock_cursor = Mock() - mock_cursor.execute.return_value = mock_cursor - mock_connect.return_value.__enter__.return_value = mock_cursor + row1 = _make_row({**SETTING_ROW, "key": "test_key1", + "value_json": "test_value1"}) + row2 = _make_row({**SETTING_ROW, "key": "test_key2", + "value_json": "test_value2"}) + conn.fetch = AsyncMock(return_value=[row1, row2]) - # Mock the return value - mock_cursor.fetchall.return_value = [ - [ - "test_key1", - "test_section", - "test_value1", - "str", - False, - False, - "2023-01-01T00:00:00", - "test_user", - ], - [ - "test_key2", - "test_section", - "test_value2", - "str", - False, - False, - "2023-01-01T00:00:00", - "test_user", - ], - ] + result = await repo.list_settings() - # List settings - result = repo.list_settings() - - # Verify database call - mock_cursor.execute.assert_called_once() - assert len(result) == 2 - assert result[0].key == "test_key1" - assert result[1].key == "test_key2" + assert len(result) == 2 + assert result[0].key == "test_key1" + assert result[1].key == "test_key2" -def test_config_setting_repository_get_latest_updated_at(mock_store): +@pytest.mark.asyncio +async def test_config_setting_repository_get_latest_updated_at(mock_store): """Test getting latest updated timestamp.""" repo = ConfigSettingRepository(mock_store) + conn = await mock_store.pool.acquire().__aenter__() - # Mock database connection - with patch.object(mock_store, "connect") as mock_connect: - mock_cursor = Mock() - mock_cursor.execute.return_value = mock_cursor - mock_connect.return_value.__enter__.return_value = mock_cursor + row = _make_row({"latest_updated_at": "2023-01-01T00:00:00"}) + conn.fetchrow = AsyncMock(return_value=row) - # Mock the return value - mock_cursor.fetchone.return_value = ["2023-01-01T00:00:00"] + result = await repo.get_latest_updated_at() - # Get latest updated at - result = repo.get_latest_updated_at() - - # Verify database call - mock_cursor.execute.assert_called_once() - assert result is not None + assert result is not None def test_config_pairing_repository_initialization(mock_store): @@ -222,71 +169,38 @@ def test_config_pairing_repository_initialization(mock_store): assert repo._store == mock_store -def test_config_pairing_repository_create_pairing(mock_store): +@pytest.mark.asyncio +async def test_config_pairing_repository_create_pairing(mock_store): """Test creating a currency pairing.""" repo = ConfigPairingRepository(mock_store) + conn = await mock_store.pool.acquire().__aenter__() + conn.fetchrow = AsyncMock(return_value=_make_row(PAIRING_ROW)) - # Mock database connection - with patch.object(mock_store, "connect") as mock_connect: - mock_cursor = Mock() - mock_cursor.execute.return_value = mock_cursor - mock_connect.return_value.__enter__.return_value = mock_cursor + pairing = ConfigPairing( + base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken" + ) - # Mock the return value - mock_cursor.fetchone.return_value = [ - 1, - "BTC", - "USD", - True, - "Kraken", - "2023-01-01T00:00:00", - "2023-01-01T00:00:00", - ] + result = await repo.create_pairing(pairing) - # Create pairing - pairing = ConfigPairing( - base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken") - - result = repo.create_pairing(pairing) - - # Verify database call - mock_cursor.execute.assert_called_once() - assert result.base_asset == "BTC" - assert result.quote_asset == "USD" - assert result.enabled is True - assert result.source == "Kraken" + assert result.base_asset == "BTC" + assert result.quote_asset == "USD" + assert result.enabled is True + assert result.source == "Kraken" -def test_config_pairing_repository_get_pairing(mock_store): +@pytest.mark.asyncio +async def test_config_pairing_repository_get_pairing(mock_store): """Test getting a currency pairing.""" repo = ConfigPairingRepository(mock_store) + conn = await mock_store.pool.acquire().__aenter__() + conn.fetchrow = AsyncMock(return_value=_make_row(PAIRING_ROW)) - # Mock database connection - with patch.object(mock_store, "connect") as mock_connect: - mock_cursor = Mock() - mock_cursor.execute.return_value = mock_cursor - mock_connect.return_value.__enter__.return_value = mock_cursor + result = await repo.get_pairing("BTC", "USD") - # Mock the return value - mock_cursor.fetchone.return_value = [ - 1, - "BTC", - "USD", - True, - "Kraken", - "2023-01-01T00:00:00", - "2023-01-01T00:00:00", - ] - - # Get pairing - result = repo.get_pairing("BTC", "USD") - - # Verify database call - mock_cursor.execute.assert_called_once() - assert result.base_asset == "BTC" - assert result.quote_asset == "USD" - assert result.enabled is True - assert result.source == "Kraken" + assert result.base_asset == "BTC" + assert result.quote_asset == "USD" + assert result.enabled is True + assert result.source == "Kraken" def test_config_backtesting_defaults_repository_initialization(mock_store): diff --git a/tests/unit/test_config_service.py b/tests/unit/test_config_service.py index cd78847..58eedcb 100644 --- a/tests/unit/test_config_service.py +++ b/tests/unit/test_config_service.py @@ -1,6 +1,6 @@ """Unit tests for configuration management system.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -19,15 +19,9 @@ def mock_settings(): @pytest.fixture def mock_store(): - """Create a mock database store with context manager.""" - store = Mock() - cursor = Mock() - cursor.fetchone.return_value = None - cursor.fetchall.return_value = [] - cursor.execute.return_value = cursor - cntx = MagicMock() - cntx.__enter__.return_value = cursor - store.connect.return_value = cntx + """Create a mock database store (sync — repos are patched).""" + store = MagicMock() + store.pool = MagicMock() return store @@ -40,10 +34,8 @@ def mock_audit_repo(): def test_configuration_service_initialization(mock_settings, mock_store, mock_audit_repo): """Test that ConfigurationService initializes correctly.""" - # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) - # Verify attributes are set assert service._settings == mock_settings assert service._store == mock_store assert service._audit_repo == mock_audit_repo @@ -53,132 +45,109 @@ def test_configuration_service_initialization(mock_settings, mock_store, mock_au def test_configuration_service_get_setting(mock_settings, mock_store, mock_audit_repo): """Test getting configuration settings.""" - # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) - - # Set up mock loaded settings service._loaded_settings = {"test_key": "test_value"} - # Test getting existing setting - result = service.get_setting("test_key", "default") - assert result == "test_value" - - # Test getting non-existing setting with default - result = service.get_setting("non_existing", "default") - assert result == "default" + assert service.get_setting("test_key", "default") == "test_value" + assert service.get_setting("non_existing", "default") == "default" -def test_configuration_service_set_setting(mock_settings, mock_store, mock_audit_repo): +@pytest.mark.asyncio +async def test_configuration_service_set_setting(mock_settings, mock_store, mock_audit_repo): """Test setting configuration settings.""" - # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) - # Mock the repository with patch("arbitrade.storage.repositories.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance - # Mock the setting creation mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting.return_value = mock_created_setting - mock_repo_instance.get_setting.return_value = None # force create path + mock_repo_instance.create_setting = AsyncMock( + return_value=mock_created_setting) + mock_repo_instance.get_setting = AsyncMock(return_value=None) - # Set a setting - service.set_setting("test_key", "test_value", "test_user") + await service.set_setting("test_key", "test_value", "test_user") - # Verify repository was called - mock_repo_instance.create_setting.assert_called_once() + mock_repo_instance.create_setting.assert_awaited_once() -def test_configuration_service_hot_reload_detection(mock_settings, mock_store, mock_audit_repo): +@pytest.mark.asyncio +async def test_configuration_service_hot_reload_detection(mock_settings, mock_store, mock_audit_repo): """Test hot-reload detection functionality.""" - # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) - # Initially should not be outdated - assert service.is_config_outdated() is False - - # Test with mock repository that returns a timestamp with patch("arbitrade.storage.repositories.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance - # Mock the latest updated at timestamp + mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=None) + assert await service.is_config_outdated() is False + from datetime import datetime - mock_repo_instance.get_latest_updated_at.return_value = datetime.now() - - # Should detect as outdated when timestamp exists - assert service.is_config_outdated() is True + mock_repo_instance.get_latest_updated_at = AsyncMock( + return_value=datetime.now()) + assert await service.is_config_outdated() is True -def test_configuration_service_reload_if_changed(mock_settings, mock_store, mock_audit_repo): +@pytest.mark.asyncio +async def test_configuration_service_reload_if_changed(mock_settings, mock_store, mock_audit_repo): """Test hot-reload functionality.""" - # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) - # Mock the repository with patch("arbitrade.storage.repositories.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance - # Mock the latest updated at timestamp to return None initially - mock_repo_instance.get_latest_updated_at.return_value = None - mock_repo_instance.list_settings.return_value = [] + mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=None) + mock_repo_instance.list_settings = AsyncMock(return_value=[]) - # Mock the latest updated at timestamp to return a value from datetime import datetime - mock_repo_instance.get_latest_updated_at.return_value = datetime.now() + mock_repo_instance.get_latest_updated_at = AsyncMock( + return_value=datetime.now()) - # Should reload when outdated - result = service.reload_if_changed() + result = await service.reload_if_changed() assert result is True assert service.get_config_version() == 1 -def test_configuration_service_get_config_version(mock_settings, mock_store, mock_audit_repo): +@pytest.mark.asyncio +async def test_configuration_service_get_config_version(mock_settings, mock_store, mock_audit_repo): """Test getting configuration version.""" - # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) - - # Should start at version 0 assert service.get_config_version() == 0 - # After setting a value, version should increment with patch("arbitrade.storage.repositories.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting.return_value = mock_created_setting - mock_repo_instance.get_setting.return_value = None + mock_repo_instance.create_setting = AsyncMock( + return_value=mock_created_setting) + mock_repo_instance.get_setting = AsyncMock(return_value=None) - service.set_setting("test_key", "test_value", "test_user") - # set_setting bumps version + await service.set_setting("test_key", "test_value", "test_user") assert service.get_config_version() == 1 -def test_configuration_service_get_last_updated_at(mock_settings, mock_store, mock_audit_repo): +@pytest.mark.asyncio +async def test_configuration_service_get_last_updated_at(mock_settings, mock_store, mock_audit_repo): """Test getting last updated timestamp.""" - # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) - - # Should start with None assert service.get_last_updated_at() is None - # After setting a value, should have timestamp with patch("arbitrade.storage.repositories.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting.return_value = mock_created_setting - mock_repo_instance.get_setting.return_value = None + mock_repo_instance.create_setting = AsyncMock( + return_value=mock_created_setting) + mock_repo_instance.get_setting = AsyncMock(return_value=None) - service.set_setting("test_key", "test_value", "test_user") - # set_setting updates _last_updated_at from mock + await service.set_setting("test_key", "test_value", "test_user") assert service.get_last_updated_at() is not None diff --git a/tests/unit/test_runtime_lifecycle.py b/tests/unit/test_runtime_lifecycle.py index 94cc753..b774768 100644 --- a/tests/unit/test_runtime_lifecycle.py +++ b/tests/unit/test_runtime_lifecycle.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock import pytest @@ -31,38 +32,71 @@ class _FakeStartupReconciler: self.called = True -@pytest.mark.asyncio -async def test_persist_runtime_snapshot_writes_record(tmp_path) -> None: - app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "runtime.duckdb")) +def _mock_pg_store(): + """Create a PgStore-alike with an async pool returning an AsyncMock conn.""" + store = MagicMock() + conn = AsyncMock() + conn.fetchrow = AsyncMock() + conn.fetch = AsyncMock(return_value=[]) + conn.execute = AsyncMock(return_value=conn) + pool_cm = AsyncMock() + pool_cm.__aenter__.return_value = conn + store.pool = MagicMock() + store.pool.acquire.return_value = pool_cm + return store + +@pytest.fixture +def app(): + """Create a test app with a mocked PgStore and audit repository.""" + a = create_app( + Settings(_env_file=None, APP_MODE="paper", paper_trading_mode=True) + ) + a.state.store = _mock_pg_store() + a.state.runtime_state_repository.insert = AsyncMock() + a.state.runtime_state_repository.latest = AsyncMock(return_value=None) + # Replace audit repository with mock to avoid real PgStore access + audit_mock = AsyncMock() + audit_mock.insert = AsyncMock() + a.state.audit_repository = audit_mock + return a + + +@pytest.mark.asyncio +async def test_persist_runtime_snapshot_writes_record(app) -> None: app.state.dashboard_controls.is_running = True app.state.dashboard_controls.kill_switch.deactivate() - snapshot = persist_runtime_snapshot(app, note="unit-test") + # Mock _open_trade_count → 0, _latest_balances → None + conn = await app.state.store.pool.acquire().__aenter__() + conn.fetchrow = AsyncMock(return_value=MagicMock( + **{"__getitem__": lambda s, k: 0})) + + snapshot = await persist_runtime_snapshot(app, note="unit-test") assert snapshot is not None assert snapshot.note == "unit-test" - latest = app.state.runtime_state_repository.latest() + app.state.runtime_state_repository.latest = AsyncMock( + return_value=snapshot) + latest = await app.state.runtime_state_repository.latest() assert latest is not None assert latest.note == "unit-test" assert latest.is_running is True @pytest.mark.asyncio -async def test_restore_runtime_state_applies_snapshot(tmp_path) -> None: - app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "restore.duckdb")) - app.state.runtime_state_repository.insert( - RuntimeStateRecord( - snapshot_at=datetime.now(UTC), - is_running=False, - kill_switch_active=True, - kill_switch_reason="manual-stop", - open_trade_count=0, - last_known_balances={"USD": 100.0}, - note="seed", - ) +async def test_restore_runtime_state_applies_snapshot(app) -> None: + seed = RuntimeStateRecord( + snapshot_at=datetime.now(UTC), + is_running=False, + kill_switch_active=True, + kill_switch_reason="manual-stop", + open_trade_count=0, + last_known_balances={"USD": 100.0}, + note="seed", ) + app.state.runtime_state_repository.latest = AsyncMock(return_value=seed) report = await restore_runtime_state(app) @@ -73,36 +107,12 @@ async def test_restore_runtime_state_applies_snapshot(tmp_path) -> None: @pytest.mark.asyncio -async def test_restore_runtime_state_enables_restart_guard_for_open_trades(tmp_path) -> None: - app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "open-trades.duckdb")) - - with app.state.store.connect() as conn: - conn.execute( - """ - INSERT INTO trades ( - trade_ref, - started_at, - finished_at, - status, - realized_pnl, - estimated_pnl, - capital_used, - cycle, - leg_count - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - [ - "open-trade-1", - datetime.now(UTC), - None, - "open", - None, - 1.0, - 100.0, - "USD->BTC->ETH->USD", - 3, - ], - ) +async def test_restore_runtime_state_enables_restart_guard_for_open_trades(app) -> None: + # Simulate 1 open trade + conn = await app.state.store.pool.acquire().__aenter__() + row = MagicMock() + row.__getitem__.return_value = 1 + conn.fetchrow = AsyncMock(return_value=row) report = await restore_runtime_state(app) @@ -114,24 +124,26 @@ async def test_restore_runtime_state_enables_restart_guard_for_open_trades(tmp_p @pytest.mark.asyncio -async def test_graceful_shutdown_drains_workers_and_persists_snapshot(tmp_path) -> None: - app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "shutdown.duckdb")) +async def test_graceful_shutdown_drains_workers_and_persists_snapshot(app) -> None: worker = _FakeWorker() app.state.background_workers = [worker] app.state.dashboard_controls.is_running = True + # Mock _open_trade_count → 0, _latest_balances → None + conn = await app.state.store.pool.acquire().__aenter__() + row = MagicMock() + row.__getitem__.return_value = 0 + conn.fetchrow = AsyncMock(return_value=row) + await graceful_shutdown(app) assert worker.stopped is True assert app.state.dashboard_controls.is_running is False - latest = app.state.runtime_state_repository.latest() - assert latest is not None - assert latest.note == "graceful_shutdown" + app.state.runtime_state_repository.insert.assert_called() @pytest.mark.asyncio -async def test_restore_runtime_state_calls_startup_reconciler(tmp_path) -> None: - app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "reconciler.duckdb")) +async def test_restore_runtime_state_calls_startup_reconciler(app) -> None: reconciler = _FakeStartupReconciler() app.state.startup_reconciler = reconciler