dc99f1604e
CI / lint-test-build (push) Successful in 54s
- Cleaned up multiline statements and removed unnecessary line breaks in various files. - Ensured consistent formatting in function definitions and calls across the codebase. - Updated docstrings and comments for clarity where applicable. - Removed trailing newlines in module docstrings. - Enhanced logging statements for better clarity in maintenance tasks.
149 lines
4.7 KiB
Python
149 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from arbitrade.api.app import create_app
|
|
from arbitrade.config.settings import Settings
|
|
from arbitrade.runtime.lifecycle import (
|
|
graceful_shutdown,
|
|
persist_runtime_snapshot,
|
|
restore_runtime_state,
|
|
)
|
|
from arbitrade.storage.repositories import RuntimeStateRecord
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _FakeWorker:
|
|
stopped: bool = False
|
|
|
|
async def stop(self) -> None:
|
|
self.stopped = True
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _FakeStartupReconciler:
|
|
called: bool = False
|
|
|
|
async def reconcile_open_trades(self) -> None:
|
|
self.called = True
|
|
|
|
|
|
def _mock_pg_store():
    """Return a PgStore look-alike whose pool always yields one shared mock connection.

    ``store.pool.acquire()`` returns an async context manager. Entering it gives back
    the same ``AsyncMock`` connection every time, so a test can grab it with
    ``await store.pool.acquire().__aenter__()`` and reconfigure ``fetchrow`` etc.
    ``fetch`` resolves to an empty list by default.
    """
    connection = AsyncMock()
    connection.fetchrow = AsyncMock()
    connection.fetch = AsyncMock(return_value=[])
    # NOTE(review): execute resolves to the connection object itself rather than a
    # status string; kept as-is, confirm whether any caller depends on it.
    connection.execute = AsyncMock(return_value=connection)

    acquire_ctx = AsyncMock()
    acquire_ctx.__aenter__.return_value = connection

    fake_store = MagicMock()
    fake_store.pool = MagicMock()
    fake_store.pool.acquire.return_value = acquire_ctx
    return fake_store
|
|
|
|
|
|
@pytest.fixture
def app():
    """Build a paper-trading app whose persistence layer is replaced with mocks.

    The Postgres store, the runtime-state repository's ``insert``/``latest`` and the
    audit repository are swapped out so no test reaches a real database.
    ``latest`` resolves to ``None`` unless a test overrides it.
    """
    settings = Settings(_env_file=None, APP_MODE="paper", paper_trading_mode=True)
    application = create_app(settings)
    application.state.store = _mock_pg_store()

    runtime_repo = application.state.runtime_state_repository
    runtime_repo.insert = AsyncMock()
    runtime_repo.latest = AsyncMock(return_value=None)

    # Swap in a mock audit repository so auditing never touches the real PgStore.
    audit_repo = AsyncMock()
    audit_repo.insert = AsyncMock()
    application.state.audit_repository = audit_repo
    return application
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_persist_runtime_snapshot_writes_record(app) -> None:
    """persist_runtime_snapshot captures the live controls and hands the record to the repo.

    The original version only re-read the snapshot through a ``latest`` stub it had
    just configured to return that snapshot, which proved nothing about persistence.
    This version asserts on the returned snapshot directly and checks that
    ``runtime_state_repository.insert`` was actually awaited.
    """
    controls = app.state.dashboard_controls
    controls.is_running = True
    controls.kill_switch.deactivate()

    # Any fetchrow row indexes to 0, so the open-trade count query reports no trades.
    conn = await app.state.store.pool.acquire().__aenter__()
    zero_row = MagicMock()
    zero_row.__getitem__.return_value = 0
    conn.fetchrow = AsyncMock(return_value=zero_row)

    snapshot = await persist_runtime_snapshot(app, note="unit-test")

    assert snapshot is not None
    assert snapshot.note == "unit-test"
    assert snapshot.is_running is True
    # The kill switch was deactivated above, so the snapshot must record it as off.
    assert snapshot.kill_switch_active is False
    # "writes_record": the snapshot must reach the repository, not merely be returned.
    app.state.runtime_state_repository.insert.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_restore_runtime_state_applies_snapshot(app) -> None:
    """A stored snapshot with no live open trades is applied to the dashboard controls."""
    seed = RuntimeStateRecord(
        snapshot_at=datetime.now(UTC),
        is_running=False,
        kill_switch_active=True,
        kill_switch_reason="manual-stop",
        open_trade_count=0,
        last_known_balances={"USD": 100.0},
        note="seed",
    )
    app.state.runtime_state_repository.latest = AsyncMock(return_value=seed)

    # Pin the live open-trade query to 0, as the other restore/shutdown tests pin
    # their counts. Otherwise the result depends on what the fixture's unconfigured
    # AsyncMock row yields when indexed, and the open-trade restart guard could
    # replace the expected kill-switch reason.
    conn = await app.state.store.pool.acquire().__aenter__()
    no_open_trades = MagicMock()
    no_open_trades.__getitem__.return_value = 0
    conn.fetchrow = AsyncMock(return_value=no_open_trades)

    report = await restore_runtime_state(app)

    assert report.restored_from_snapshot is True
    assert app.state.dashboard_controls.is_running is False
    assert app.state.dashboard_controls.kill_switch.is_active is True
    assert app.state.dashboard_controls.kill_switch.reason == "manual-stop"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_restore_runtime_state_enables_restart_guard_for_open_trades(app) -> None:
    """One open trade on restore arms the restart guard: bot halted, kill switch on."""
    # Every fetchrow row indexes to 1, i.e. the store reports a single open trade.
    one_open_trade = MagicMock()
    one_open_trade.__getitem__.return_value = 1
    conn = await app.state.store.pool.acquire().__aenter__()
    conn.fetchrow = AsyncMock(return_value=one_open_trade)

    report = await restore_runtime_state(app)

    assert report.open_trades_detected == 1
    assert report.restart_guard_active is True
    controls = app.state.dashboard_controls
    assert controls.is_running is False
    assert controls.kill_switch.is_active is True
    assert controls.kill_switch.reason == "recovery_open_trades_detected"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graceful_shutdown_drains_workers_and_persists_snapshot(app) -> None:
    """graceful_shutdown stops background workers, halts the bot and writes a snapshot."""
    fake_worker = _FakeWorker()
    app.state.background_workers = [fake_worker]
    app.state.dashboard_controls.is_running = True

    # Any fetchrow row indexes to 0, so the shutdown snapshot sees no open trades.
    zero_row = MagicMock()
    zero_row.__getitem__.return_value = 0
    conn = await app.state.store.pool.acquire().__aenter__()
    conn.fetchrow = AsyncMock(return_value=zero_row)

    await graceful_shutdown(app)

    assert fake_worker.stopped is True
    assert app.state.dashboard_controls.is_running is False
    app.state.runtime_state_repository.insert.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_restore_runtime_state_calls_startup_reconciler(app) -> None:
    """restore_runtime_state invokes the reconciler registered on ``app.state``."""
    spy = _FakeStartupReconciler()
    app.state.startup_reconciler = spy

    await restore_runtime_state(app)

    assert spy.called is True
|