from __future__ import annotations

from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any

import pytest

from arbitrade.execution.fill_monitor import FillMonitorResult, OrderFillState
from arbitrade.execution.recovery import PartialFillRecovery


@dataclass(slots=True)
class _FakeRestClient:
    """In-memory stand-in for the REST client that records every call.

    Nothing is sent anywhere: each method appends its arguments to a list
    and returns a fixed, canned payload so tests can assert on the calls.
    """

    # default_factory gives every instance its own fresh list; this replaces
    # the previous ``= None`` placeholder that was patched in __post_init__.
    cancel_calls: list[str] = field(default_factory=list)
    market_calls: list[dict[str, Any]] = field(default_factory=list)

    async def cancel_order(self, *, order_id: str) -> dict[str, Any]:
        """Record ``order_id`` and return a canned cancel response."""
        self.cancel_calls.append(order_id)
        return {"result": {"count": 1}}

    async def place_market_order(
        self, *, pair: str, side: str, volume: float
    ) -> dict[str, Any]:
        """Record the market-order parameters and return a canned txid."""
        self.market_calls.append({"pair": pair, "side": side, "volume": volume})
        return {"txid": ["hedge-1"]}


def _monitor_result(
    *, status: str, filled_volume: float | None, timed_out: bool
) -> FillMonitorResult:
    """Build a ``FillMonitorResult`` for ``order-1`` with the given fill data.

    The same ``OrderFillState`` is used as ``last_state`` and, unless
    ``status`` is ``"open"`` or ``"partial"``, also as ``terminal_state``
    (which is ``None`` for those two statuses).
    """
    state = OrderFillState(
        order_id="order-1",
        status=status,
        filled_volume=filled_volume,
        avg_price=100.0,
        updated_at=datetime.now(UTC),
        source="rest_poll",
    )
    is_terminal = status not in {"open", "partial"}
    return FillMonitorResult(
        order_id="order-1",
        timed_out=timed_out,
        terminal_state=state if is_terminal else None,
        last_state=state,
        elapsed_seconds=1.0,
    )


@pytest.mark.asyncio
async def test_partial_fill_recovery_cancels_open_order_and_hedges_residual() -> None:
    """A timed-out partial buy (4 of 10) is canceled and 6.0 is sold as a hedge."""
    client = _FakeRestClient()
    recovery = PartialFillRecovery(client)

    result = await recovery.recover_partial_fill(
        order_id="order-1",
        pair="BTC/USD",
        side="buy",
        requested_volume=10.0,
        fill_result=_monitor_result(status="partial", filled_volume=4.0, timed_out=True),
    )

    assert result.canceled
    assert result.hedged
    assert client.cancel_calls == ["order-1"]
    assert client.market_calls == [{"pair": "BTC/USD", "side": "sell", "volume": 6.0}]
    assert result.hedge_volume == 6.0
    assert result.reason == "canceled_partial_order"


@pytest.mark.asyncio
async def test_partial_fill_recovery_no_hedge_when_no_residual() -> None:
    """A fully filled, closed order triggers neither a cancel nor a hedge."""
    client = _FakeRestClient()
    recovery = PartialFillRecovery(client)

    result = await recovery.recover_partial_fill(
        order_id="order-1",
        pair="BTC/USD",
        side="sell",
        requested_volume=5.0,
        fill_result=_monitor_result(status="closed", filled_volume=5.0, timed_out=False),
    )

    assert not result.canceled
    assert not result.hedged
    assert client.cancel_calls == []
    assert client.market_calls == []


def test_partial_fill_recovery_rejects_invalid_volume() -> None:
    """``_residual_volume(None, 0.0)`` raises ``ValueError`` naming requested_volume."""
    client = _FakeRestClient()
    recovery = PartialFillRecovery(client)

    # NOTE(review): presumably the positional arguments are
    # (filled_volume, requested_volume) — confirm against PartialFillRecovery.
    with pytest.raises(ValueError, match="requested_volume"):
        recovery._residual_volume(None, 0.0)