"""Tests for ``FillMonitor`` terminal-fill detection.

The fakes below stand in for the REST client and the websocket status
lookup that are handed to ``FillMonitor``, so every scenario is driven by
canned, scripted data.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any

import pytest

from arbitrade.execution.fill_monitor import FillMonitor, OrderFillState


@dataclass(slots=True)
class _FakePollClient:
    """Scripted stand-in for the REST client polled by ``FillMonitor``.

    Each ``query_order`` call consumes the next entry of ``responses`` in
    FIFO order. Once the script is exhausted, every further call reports the
    order as still open with nothing executed.
    """

    responses: list[dict[str, Any]]
    # Number of times ``query_order`` has been invoked.
    calls: int = 0

    async def query_order(
        self, *, order_id: str, include_trades: bool = True
    ) -> dict[str, Any]:
        """Count the call and return the next scripted response.

        ``include_trades`` is accepted for signature compatibility but ignored.
        """
        self.calls += 1
        if not self.responses:
            # Script exhausted: keep reporting an untouched open order.
            return {order_id: {"status": "open", "vol_exec": "0.0", "price": "0.0"}}
        return self.responses.pop(0)


@dataclass(slots=True)
class _FakeWsProvider:
    """Scripted stand-in for the websocket order-status lookup.

    ``get`` hands out the queued states one at a time, whatever order id is
    requested, and returns ``None`` once the queue is empty.
    """

    states: list[OrderFillState] = field(default_factory=list)

    def get(self, _order_id: str) -> OrderFillState | None:
        """Pop and return the next queued state, or ``None`` if none remain."""
        return self.states.pop(0) if self.states else None


@pytest.mark.asyncio
async def test_fill_monitor_detects_terminal_state_via_polling() -> None:
    """A ``closed`` status returned by the second REST poll ends the wait."""
    oid = "order-1"
    scripted = [
        {oid: {"status": "open", "vol_exec": "0.0", "price": "0.0"}},
        {oid: {"status": "closed", "vol_exec": "1.0", "price": "100.0"}},
    ]
    client = _FakePollClient(responses=scripted)
    monitor = FillMonitor(client, poll_interval_seconds=0.001, max_wait_seconds=0.1)

    result = await monitor.wait_for_terminal_fill(oid)

    assert not result.timed_out
    terminal = result.terminal_state
    assert terminal is not None
    assert terminal.status == "closed"
    assert terminal.filled_volume == 1.0
    assert terminal.source == "rest_poll"


@pytest.mark.asyncio
async def test_fill_monitor_times_out_when_no_terminal_state() -> None:
    """Only non-terminal polls: the wait times out but keeps the last state seen."""
    oid = "order-2"
    client = _FakePollClient(
        responses=[
            {oid: {"status": status, "vol_exec": volume, "price": "100.0"}}
            for status, volume in (("open", "0.1"), ("partial", "0.2"), ("open", "0.2"))
        ]
    )
    monitor = FillMonitor(client, poll_interval_seconds=0.001, max_wait_seconds=0.01)

    result = await monitor.wait_for_terminal_fill(oid)

    assert result.timed_out
    assert result.terminal_state is None
    last = result.last_state
    assert last is not None
    assert last.status in {"open", "partial"}


@pytest.mark.asyncio
async def test_fill_monitor_uses_ws_status_for_fast_terminal_detection() -> None:
    """A terminal websocket state is returned without any REST poll being made."""
    oid = "order-3"
    ws_closed = OrderFillState(
        order_id=oid,
        status="closed",
        filled_volume=0.5,
        avg_price=200.0,
        updated_at=datetime.now(UTC),
        source="ws",
    )
    ws_provider = _FakeWsProvider(states=[ws_closed])
    client = _FakePollClient(responses=[])
    monitor = FillMonitor(
        client,
        poll_interval_seconds=0.001,
        max_wait_seconds=0.1,
        ws_status_provider=ws_provider.get,
    )

    result = await monitor.wait_for_terminal_fill(oid)

    assert not result.timed_out
    terminal = result.terminal_state
    assert terminal is not None
    assert terminal.source == "ws"
    # The websocket answer should make REST polling unnecessary.
    assert client.calls == 0


def test_fill_monitor_rejects_invalid_configuration() -> None:
    """A zero poll interval or a zero max wait raises ``ValueError`` on construction."""
    client = _FakePollClient(responses=[])
    invalid_cases = (
        ({"poll_interval_seconds": 0.0}, "poll_interval_seconds"),
        ({"max_wait_seconds": 0.0}, "max_wait_seconds"),
    )
    for overrides, expected_message in invalid_cases:
        with pytest.raises(ValueError, match=expected_message):
            FillMonitor(client, **overrides)