93f4f62d42
- Add IdempotencyKeyFactory for generating unique user references based on execution legs. - Introduce OrderReconciler to reconcile order statuses with historical data. - Implement PartialFillRecovery to handle partial fills by canceling orders and placing hedges. - Create TriangularExecutionSequencer for executing triangular arbitrage strategies. - Enhance storage with new tables for trades, orders, and PnL events. - Develop AsyncExecutionWriter for asynchronous writing of execution records to the database. - Add unit tests for execution persistence, sequencer behavior, fill monitoring, and idempotency checks. - Update KrakenRestClient to ensure proper payloads for order placement and querying.
113 lines
3.5 KiB
Python
113 lines
3.5 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from arbitrade.execution.fill_monitor import FillMonitor, OrderFillState
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _FakePollClient:
|
|
responses: list[dict[str, Any]]
|
|
calls: int = 0
|
|
|
|
async def query_order(self, *, order_id: str, include_trades: bool = True) -> dict[str, Any]:
|
|
self.calls += 1
|
|
if self.responses:
|
|
return self.responses.pop(0)
|
|
return {order_id: {"status": "open", "vol_exec": "0.0", "price": "0.0"}}
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _FakeWsProvider:
|
|
states: list[OrderFillState] = field(default_factory=list)
|
|
|
|
def get(self, _order_id: str) -> OrderFillState | None:
|
|
if not self.states:
|
|
return None
|
|
return self.states.pop(0)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fill_monitor_detects_terminal_state_via_polling() -> None:
    """REST polling alone should surface the terminal ("closed") state.

    The fake client reports the order as open on its first response and
    closed on its second, so reaching "closed" needs at least two polls.
    """
    oid = "order-1"
    scripted = [
        {oid: {"status": "open", "vol_exec": "0.0", "price": "0.0"}},
        {oid: {"status": "closed", "vol_exec": "1.0", "price": "100.0"}},
    ]
    monitor = FillMonitor(
        _FakePollClient(responses=scripted),
        poll_interval_seconds=0.001,
        max_wait_seconds=0.1,
    )

    outcome = await monitor.wait_for_terminal_fill(oid)

    assert not outcome.timed_out
    final = outcome.terminal_state
    assert final is not None
    assert final.status == "closed"
    assert final.filled_volume == 1.0
    assert final.source == "rest_poll"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fill_monitor_times_out_when_no_terminal_state() -> None:
    """A monitor that never sees a terminal status must report a timeout.

    After the three scripted non-terminal responses, the fake client keeps
    answering "open", so no terminal state is ever available.
    """
    oid = "order-2"
    # (status, vol_exec) pairs for each scripted poll; price is fixed.
    progression = (("open", "0.1"), ("partial", "0.2"), ("open", "0.2"))
    scripted = [
        {oid: {"status": status, "vol_exec": executed, "price": "100.0"}}
        for status, executed in progression
    ]
    monitor = FillMonitor(
        _FakePollClient(responses=scripted),
        poll_interval_seconds=0.001,
        max_wait_seconds=0.01,
    )

    outcome = await monitor.wait_for_terminal_fill(oid)

    assert outcome.timed_out
    assert outcome.terminal_state is None
    latest = outcome.last_state
    assert latest is not None
    assert latest.status in {"open", "partial"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fill_monitor_uses_ws_status_for_fast_terminal_detection() -> None:
    """A terminal state from the WS provider should short-circuit REST polling."""
    oid = "order-3"
    ws_snapshot = OrderFillState(
        order_id=oid,
        status="closed",
        filled_volume=0.5,
        avg_price=200.0,
        updated_at=datetime.now(UTC),
        source="ws",
    )
    ws_provider = _FakeWsProvider(states=[ws_snapshot])
    rest_client = _FakePollClient(responses=[])
    monitor = FillMonitor(
        rest_client,
        poll_interval_seconds=0.001,
        max_wait_seconds=0.1,
        ws_status_provider=ws_provider.get,
    )

    outcome = await monitor.wait_for_terminal_fill(oid)

    assert not outcome.timed_out
    assert outcome.terminal_state is not None
    assert outcome.terminal_state.source == "ws"
    # The WS feed already had a terminal state, so no REST query was issued.
    assert rest_client.calls == 0
|
|
|
|
|
|
def test_fill_monitor_rejects_invalid_configuration() -> None:
    """Zero-valued timing settings must raise ``ValueError`` at construction.

    ``pytest.raises(match=...)`` searches the error message, so each case
    also checks that the message names the offending parameter.
    """
    client = _FakePollClient(responses=[])
    bad_settings = (
        ({"poll_interval_seconds": 0.0}, "poll_interval_seconds"),
        ({"max_wait_seconds": 0.0}, "max_wait_seconds"),
    )

    for overrides, expected_message in bad_settings:
        with pytest.raises(ValueError, match=expected_message):
            FillMonitor(client, **overrides)
|