from __future__ import annotations from dataclasses import dataclass from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock import pytest from arbitrade.api.app import create_app from arbitrade.config.settings import Settings from arbitrade.runtime.lifecycle import ( graceful_shutdown, persist_runtime_snapshot, restore_runtime_state, ) from arbitrade.storage.repositories import RuntimeStateRecord @dataclass(slots=True) class _FakeWorker: stopped: bool = False async def stop(self) -> None: self.stopped = True @dataclass(slots=True) class _FakeStartupReconciler: called: bool = False async def reconcile_open_trades(self) -> None: self.called = True def _mock_pg_store(): """Create a PgStore-alike with an async pool returning an AsyncMock conn.""" store = MagicMock() conn = AsyncMock() conn.fetchrow = AsyncMock() conn.fetch = AsyncMock(return_value=[]) conn.execute = AsyncMock(return_value=conn) pool_cm = AsyncMock() pool_cm.__aenter__.return_value = conn store.pool = MagicMock() store.pool.acquire.return_value = pool_cm return store @pytest.fixture def app(): """Create a test app with a mocked PgStore and audit repository.""" a = create_app(Settings(_env_file=None, APP_MODE="paper", paper_trading_mode=True)) a.state.store = _mock_pg_store() a.state.runtime_state_repository.insert = AsyncMock() a.state.runtime_state_repository.latest = AsyncMock(return_value=None) # Replace audit repository with mock to avoid real PgStore access audit_mock = AsyncMock() audit_mock.insert = AsyncMock() a.state.audit_repository = audit_mock return a @pytest.mark.asyncio async def test_persist_runtime_snapshot_writes_record(app) -> None: app.state.dashboard_controls.is_running = True app.state.dashboard_controls.kill_switch.deactivate() # Mock _open_trade_count → 0, _latest_balances → None conn = await app.state.store.pool.acquire().__aenter__() conn.fetchrow = AsyncMock(return_value=MagicMock(**{"__getitem__": lambda s, k: 0})) snapshot = await persist_runtime_snapshot(app, note="unit-test") assert snapshot is not None assert snapshot.note == "unit-test" app.state.runtime_state_repository.latest = AsyncMock(return_value=snapshot) latest = await app.state.runtime_state_repository.latest() assert latest is not None assert latest.note == "unit-test" assert latest.is_running is True @pytest.mark.asyncio async def test_restore_runtime_state_applies_snapshot(app) -> None: seed = RuntimeStateRecord( snapshot_at=datetime.now(UTC), is_running=False, kill_switch_active=True, kill_switch_reason="manual-stop", open_trade_count=0, last_known_balances={"USD": 100.0}, note="seed", ) app.state.runtime_state_repository.latest = AsyncMock(return_value=seed) report = await restore_runtime_state(app) assert report.restored_from_snapshot is True assert app.state.dashboard_controls.is_running is False assert app.state.dashboard_controls.kill_switch.is_active is True assert app.state.dashboard_controls.kill_switch.reason == "manual-stop" @pytest.mark.asyncio async def test_restore_runtime_state_enables_restart_guard_for_open_trades(app) -> None: # Simulate 1 open trade conn = await app.state.store.pool.acquire().__aenter__() row = MagicMock() row.__getitem__.return_value = 1 conn.fetchrow = AsyncMock(return_value=row) report = await restore_runtime_state(app) assert report.open_trades_detected == 1 assert report.restart_guard_active is True assert app.state.dashboard_controls.is_running is False assert app.state.dashboard_controls.kill_switch.is_active is True assert app.state.dashboard_controls.kill_switch.reason == "recovery_open_trades_detected" @pytest.mark.asyncio async def test_graceful_shutdown_drains_workers_and_persists_snapshot(app) -> None: worker = _FakeWorker() app.state.background_workers = [worker] app.state.dashboard_controls.is_running = True # Mock _open_trade_count → 0, _latest_balances → None conn = await app.state.store.pool.acquire().__aenter__() row = MagicMock() row.__getitem__.return_value = 0 conn.fetchrow = AsyncMock(return_value=row) await graceful_shutdown(app) assert worker.stopped is True assert app.state.dashboard_controls.is_running is False app.state.runtime_state_repository.insert.assert_called() @pytest.mark.asyncio async def test_restore_runtime_state_calls_startup_reconciler(app) -> None: reconciler = _FakeStartupReconciler() app.state.startup_reconciler = reconciler await restore_runtime_state(app) assert reconciler.called is True