feat: Implement idempotency and recovery mechanisms for order execution
- Add IdempotencyKeyFactory for generating unique user references based on execution legs. - Introduce OrderReconciler to reconcile order statuses with historical data. - Implement PartialFillRecovery to handle partial fills by canceling orders and placing hedges. - Create TriangularExecutionSequencer for executing triangular arbitrage strategies. - Enhance storage with new tables for trades, orders, and PnL events. - Develop AsyncExecutionWriter for asynchronous writing of execution records to the database. - Add unit tests for execution persistence, sequencer behavior, fill monitoring, and idempotency checks. - Update KrakenRestClient to ensure proper payloads for order placement and querying.
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
## [Unreleased] - 2026-06-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added stop-condition risk controls for abnormal source/apply latency and repeated execution failures.
|
||||
- Added a new stop-conditions guard and integration in market feed processing.
|
||||
|
||||
### Changed
|
||||
|
||||
- Live execution path now auto-activates the kill switch when configured stop conditions are breached.
|
||||
- Added configuration env keys for stop-condition thresholds.
|
||||
|
||||
### Removed
|
||||
|
||||
- None.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Added/expanded unit coverage for risk limits and kill-switch enforcement, including stop-condition scenarios.
|
||||
- Added partial-fill recovery logic that cancels open orders when possible and hedges residual exposure on timeout or failure.
|
||||
- Added deterministic order idempotency via Kraken userref plus reconciliation helpers for Kraken order history responses.
|
||||
- Added execution journaling for trades, orders, and estimated P&L, plus a DuckDB startup fallback when the default file path is unavailable.
|
||||
- Added a mocked execution integration test that drives the triangular sequencer through the execution journal and DuckDB persistence.
|
||||
@@ -177,11 +177,105 @@ class KrakenRestClient:
|
||||
result = await self._request_with_retry("/0/public/AssetPairs")
|
||||
return _result_dict(result.payload)
|
||||
|
||||
async def _throttled_private_call(self, endpoint: str) -> dict[str, Any]:
|
||||
async def _throttled_private_call(
|
||||
self,
|
||||
endpoint: str,
|
||||
data: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
async with self._private_lock:
|
||||
result = await self._private_post_with_retry(endpoint)
|
||||
result = await self._private_post_with_retry(endpoint, data=data)
|
||||
await asyncio.sleep(self._settings.kraken_private_rate_limit_seconds)
|
||||
return _result_dict(result.payload)
|
||||
|
||||
async def balances(self) -> dict[str, Any]:
|
||||
return await self._throttled_private_call("/0/private/Balance")
|
||||
|
||||
async def place_market_order(
|
||||
self,
|
||||
*,
|
||||
pair: str,
|
||||
side: str,
|
||||
volume: float,
|
||||
user_ref: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
normalized_side = side.lower()
|
||||
if normalized_side not in {"buy", "sell"}:
|
||||
raise ValueError("side must be 'buy' or 'sell'")
|
||||
if volume <= 0.0:
|
||||
raise ValueError("volume must be > 0.0")
|
||||
if user_ref is not None and user_ref < 0:
|
||||
raise ValueError("user_ref must be >= 0")
|
||||
|
||||
data = {
|
||||
"pair": pair,
|
||||
"type": normalized_side,
|
||||
"ordertype": "market",
|
||||
"volume": str(volume),
|
||||
}
|
||||
if user_ref is not None:
|
||||
data["userref"] = str(user_ref)
|
||||
|
||||
return await self._throttled_private_call(
|
||||
"/0/private/AddOrder",
|
||||
data=data,
|
||||
)
|
||||
|
||||
async def place_limit_order(
|
||||
self,
|
||||
*,
|
||||
pair: str,
|
||||
side: str,
|
||||
volume: float,
|
||||
price: float,
|
||||
user_ref: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
normalized_side = side.lower()
|
||||
if normalized_side not in {"buy", "sell"}:
|
||||
raise ValueError("side must be 'buy' or 'sell'")
|
||||
if volume <= 0.0:
|
||||
raise ValueError("volume must be > 0.0")
|
||||
if price <= 0.0:
|
||||
raise ValueError("price must be > 0.0")
|
||||
if user_ref is not None and user_ref < 0:
|
||||
raise ValueError("user_ref must be >= 0")
|
||||
|
||||
data = {
|
||||
"pair": pair,
|
||||
"type": normalized_side,
|
||||
"ordertype": "limit",
|
||||
"price": str(price),
|
||||
"volume": str(volume),
|
||||
}
|
||||
if user_ref is not None:
|
||||
data["userref"] = str(user_ref)
|
||||
|
||||
return await self._throttled_private_call(
|
||||
"/0/private/AddOrder",
|
||||
data=data,
|
||||
)
|
||||
|
||||
async def query_order(
|
||||
self,
|
||||
*,
|
||||
order_id: str,
|
||||
include_trades: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
if not order_id.strip():
|
||||
raise ValueError("order_id must be non-empty")
|
||||
|
||||
return await self._throttled_private_call(
|
||||
"/0/private/QueryOrders",
|
||||
data={
|
||||
"txid": order_id,
|
||||
"trades": "true" if include_trades else "false",
|
||||
},
|
||||
)
|
||||
|
||||
async def cancel_order(self, *, order_id: str) -> dict[str, Any]:
|
||||
if not order_id.strip():
|
||||
raise ValueError("order_id must be non-empty")
|
||||
|
||||
return await self._throttled_private_call(
|
||||
"/0/private/CancelOrder",
|
||||
data={"txid": order_id},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Trade execution helpers."""
|
||||
|
||||
from arbitrade.execution.fill_monitor import (
|
||||
FillMonitor,
|
||||
FillMonitorResult,
|
||||
OrderFillState,
|
||||
)
|
||||
from arbitrade.execution.idempotency import (
|
||||
IdempotencyKeyFactory,
|
||||
OrderReconciler,
|
||||
ReconciliationReport,
|
||||
)
|
||||
from arbitrade.execution.recovery import PartialFillRecovery, RecoveryAction
|
||||
from arbitrade.execution.sequencer import (
|
||||
ExecutionLeg,
|
||||
TriangularExecutionResult,
|
||||
TriangularExecutionSequencer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ExecutionLeg",
|
||||
"OrderFillState",
|
||||
"FillMonitorResult",
|
||||
"FillMonitor",
|
||||
"IdempotencyKeyFactory",
|
||||
"ReconciliationReport",
|
||||
"OrderReconciler",
|
||||
"RecoveryAction",
|
||||
"PartialFillRecovery",
|
||||
"TriangularExecutionResult",
|
||||
"TriangularExecutionSequencer",
|
||||
]
|
||||
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class SupportsOrderStatusPolling(Protocol):
|
||||
async def query_order(
|
||||
self, *, order_id: str, include_trades: bool = True
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OrderFillState:
|
||||
order_id: str
|
||||
status: str
|
||||
filled_volume: float | None
|
||||
avg_price: float | None
|
||||
updated_at: datetime
|
||||
source: str
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
return self.status in {"closed", "canceled", "expired"}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FillMonitorResult:
|
||||
order_id: str
|
||||
timed_out: bool
|
||||
terminal_state: OrderFillState | None
|
||||
last_state: OrderFillState | None
|
||||
elapsed_seconds: float
|
||||
|
||||
|
||||
class FillMonitor:
|
||||
def __init__(
|
||||
self,
|
||||
poll_client: SupportsOrderStatusPolling,
|
||||
*,
|
||||
poll_interval_seconds: float = 0.5,
|
||||
max_wait_seconds: float = 10.0,
|
||||
ws_status_provider: Callable[[str], OrderFillState | None] | None = None,
|
||||
) -> None:
|
||||
if poll_interval_seconds <= 0.0:
|
||||
raise ValueError("poll_interval_seconds must be > 0.0")
|
||||
if max_wait_seconds <= 0.0:
|
||||
raise ValueError("max_wait_seconds must be > 0.0")
|
||||
|
||||
self._poll_client = poll_client
|
||||
self._poll_interval_seconds = poll_interval_seconds
|
||||
self._max_wait_seconds = max_wait_seconds
|
||||
self._ws_status_provider = ws_status_provider
|
||||
|
||||
@staticmethod
|
||||
def _to_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _state_from_payload(
|
||||
cls, order_id: str, payload: dict[str, Any], *, source: str
|
||||
) -> OrderFillState:
|
||||
status = str(payload.get("status", "unknown")).lower()
|
||||
return OrderFillState(
|
||||
order_id=order_id,
|
||||
status=status,
|
||||
filled_volume=cls._to_float(payload.get("vol_exec")),
|
||||
avg_price=cls._to_float(payload.get("price") or payload.get("avg_price")),
|
||||
updated_at=datetime.now(UTC),
|
||||
source=source,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_order_payload(cls, order_id: str, response: dict[str, Any]) -> dict[str, Any]:
|
||||
if order_id in response and isinstance(response[order_id], dict):
|
||||
payload = response[order_id]
|
||||
return {str(key): value for key, value in payload.items()}
|
||||
return response
|
||||
|
||||
async def wait_for_terminal_fill(self, order_id: str) -> FillMonitorResult:
|
||||
if not order_id.strip():
|
||||
raise ValueError("order_id must be non-empty")
|
||||
|
||||
started = time.monotonic()
|
||||
last_state: OrderFillState | None = None
|
||||
|
||||
while True:
|
||||
elapsed = time.monotonic() - started
|
||||
if elapsed >= self._max_wait_seconds:
|
||||
return FillMonitorResult(
|
||||
order_id=order_id,
|
||||
timed_out=True,
|
||||
terminal_state=None,
|
||||
last_state=last_state,
|
||||
elapsed_seconds=elapsed,
|
||||
)
|
||||
|
||||
if self._ws_status_provider is not None:
|
||||
ws_state = self._ws_status_provider(order_id)
|
||||
if ws_state is not None:
|
||||
last_state = ws_state
|
||||
if ws_state.is_terminal:
|
||||
return FillMonitorResult(
|
||||
order_id=order_id,
|
||||
timed_out=False,
|
||||
terminal_state=ws_state,
|
||||
last_state=ws_state,
|
||||
elapsed_seconds=elapsed,
|
||||
)
|
||||
|
||||
response = await self._poll_client.query_order(order_id=order_id, include_trades=True)
|
||||
payload = self._extract_order_payload(order_id, response)
|
||||
polled_state = self._state_from_payload(order_id, payload, source="rest_poll")
|
||||
last_state = polled_state
|
||||
if polled_state.is_terminal:
|
||||
return FillMonitorResult(
|
||||
order_id=order_id,
|
||||
timed_out=False,
|
||||
terminal_state=polled_state,
|
||||
last_state=polled_state,
|
||||
elapsed_seconds=time.monotonic() - started,
|
||||
)
|
||||
|
||||
await asyncio.sleep(self._poll_interval_seconds)
|
||||
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from arbitrade.detection.engine import OpportunityEvent
|
||||
from arbitrade.execution.sequencer import ExecutionLeg
|
||||
|
||||
|
||||
class SupportsOrderHistoryLookup(Protocol):
|
||||
async def query_order(
|
||||
self, *, order_id: str, include_trades: bool = True
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ReconciliationReport:
|
||||
order_id: str
|
||||
user_ref: int
|
||||
status: str
|
||||
filled_volume: float | None
|
||||
avg_price: float | None
|
||||
is_terminal: bool
|
||||
matches_request: bool
|
||||
raw_payload: dict[str, Any]
|
||||
|
||||
|
||||
class IdempotencyKeyFactory:
|
||||
def user_ref_for_leg(self, event: OpportunityEvent, leg: ExecutionLeg, leg_index: int) -> int:
|
||||
material = "|".join(
|
||||
[
|
||||
event.cycle,
|
||||
event.updated_pair,
|
||||
leg.from_currency,
|
||||
leg.to_currency,
|
||||
leg.pair,
|
||||
leg.side,
|
||||
f"{leg.volume:.12f}",
|
||||
str(leg_index),
|
||||
]
|
||||
).encode("utf-8")
|
||||
digest = hashlib.sha256(material).digest()
|
||||
value = int.from_bytes(digest[:8], "big") % 2_147_483_647
|
||||
return value or 1
|
||||
|
||||
|
||||
class OrderReconciler:
|
||||
def __init__(self, history_client: SupportsOrderHistoryLookup) -> None:
|
||||
self._history_client = history_client
|
||||
|
||||
@staticmethod
|
||||
def _to_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_payload(order_id: str, response: dict[str, Any]) -> dict[str, Any]:
|
||||
if order_id in response and isinstance(response[order_id], dict):
|
||||
payload = response[order_id]
|
||||
return {str(key): value for key, value in payload.items()}
|
||||
return response
|
||||
|
||||
async def reconcile_order(
|
||||
self,
|
||||
*,
|
||||
order_id: str,
|
||||
user_ref: int,
|
||||
expected_pair: str,
|
||||
expected_side: str,
|
||||
expected_volume: float,
|
||||
) -> ReconciliationReport:
|
||||
if not order_id.strip():
|
||||
raise ValueError("order_id must be non-empty")
|
||||
|
||||
response = await self._history_client.query_order(order_id=order_id, include_trades=True)
|
||||
payload = self._extract_payload(order_id, response)
|
||||
status = str(payload.get("status", "unknown")).lower()
|
||||
filled_volume = self._to_float(payload.get("vol_exec"))
|
||||
avg_price = self._to_float(payload.get("price") or payload.get("avg_price"))
|
||||
reported_pair = str(payload.get("pair", expected_pair))
|
||||
reported_side = str(payload.get("type", expected_side)).lower()
|
||||
matches_request = (
|
||||
reported_pair == expected_pair
|
||||
and reported_side == expected_side.lower()
|
||||
and (
|
||||
expected_volume <= 0.0 or filled_volume is None or filled_volume <= expected_volume
|
||||
)
|
||||
and payload.get("userref") in {None, str(user_ref), user_ref}
|
||||
)
|
||||
|
||||
return ReconciliationReport(
|
||||
order_id=order_id,
|
||||
user_ref=user_ref,
|
||||
status=status,
|
||||
filled_volume=filled_volume,
|
||||
avg_price=avg_price,
|
||||
is_terminal=status in {"closed", "canceled", "expired"},
|
||||
matches_request=matches_request,
|
||||
raw_payload=payload,
|
||||
)
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from arbitrade.execution.fill_monitor import FillMonitorResult, OrderFillState
|
||||
|
||||
|
||||
class SupportsOrderLifecycle(Protocol):
|
||||
async def cancel_order(self, *, order_id: str) -> dict[str, Any]: ...
|
||||
|
||||
async def place_market_order(
|
||||
self, *, pair: str, side: str, volume: float
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RecoveryAction:
|
||||
order_id: str
|
||||
canceled: bool
|
||||
hedged: bool
|
||||
hedge_pair: str | None = None
|
||||
hedge_side: str | None = None
|
||||
hedge_volume: float | None = None
|
||||
cancel_response: dict[str, Any] | None = None
|
||||
hedge_response: dict[str, Any] | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class PartialFillRecovery:
|
||||
def __init__(self, rest_client: SupportsOrderLifecycle) -> None:
|
||||
self._rest_client = rest_client
|
||||
|
||||
@staticmethod
|
||||
def _counter_side(side: str) -> str:
|
||||
normalized = side.lower()
|
||||
if normalized == "buy":
|
||||
return "sell"
|
||||
if normalized == "sell":
|
||||
return "buy"
|
||||
raise ValueError("side must be 'buy' or 'sell'")
|
||||
|
||||
@staticmethod
|
||||
def _residual_volume(terminal_state: OrderFillState | None, requested_volume: float) -> float:
|
||||
if requested_volume <= 0.0:
|
||||
raise ValueError("requested_volume must be > 0.0")
|
||||
if terminal_state is None or terminal_state.filled_volume is None:
|
||||
return requested_volume
|
||||
residual = requested_volume - terminal_state.filled_volume
|
||||
return residual if residual > 0.0 else 0.0
|
||||
|
||||
async def recover_partial_fill(
|
||||
self,
|
||||
*,
|
||||
order_id: str,
|
||||
pair: str,
|
||||
side: str,
|
||||
requested_volume: float,
|
||||
fill_result: FillMonitorResult,
|
||||
) -> RecoveryAction:
|
||||
if not order_id.strip():
|
||||
raise ValueError("order_id must be non-empty")
|
||||
|
||||
cancel_response: dict[str, Any] | None = None
|
||||
hedge_response: dict[str, Any] | None = None
|
||||
hedged = False
|
||||
canceled = False
|
||||
reason = None
|
||||
|
||||
state = fill_result.terminal_state or fill_result.last_state
|
||||
residual_volume = self._residual_volume(state, requested_volume)
|
||||
|
||||
if state is not None and state.status in {"open", "partial"}:
|
||||
cancel_response = await self._rest_client.cancel_order(order_id=order_id)
|
||||
canceled = True
|
||||
reason = f"canceled_{state.status}_order"
|
||||
|
||||
if residual_volume > 0.0 and fill_result.timed_out:
|
||||
hedge_response = await self._rest_client.place_market_order(
|
||||
pair=pair,
|
||||
side=self._counter_side(side),
|
||||
volume=residual_volume,
|
||||
)
|
||||
hedged = True
|
||||
if reason is None:
|
||||
reason = "hedged_timed_out_order"
|
||||
|
||||
return RecoveryAction(
|
||||
order_id=order_id,
|
||||
canceled=canceled,
|
||||
hedged=hedged,
|
||||
hedge_pair=pair if hedged else None,
|
||||
hedge_side=self._counter_side(side) if hedged else None,
|
||||
hedge_volume=residual_volume if hedged else None,
|
||||
cancel_response=cancel_response,
|
||||
hedge_response=hedge_response,
|
||||
reason=reason,
|
||||
)
|
||||
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from arbitrade.detection.engine import OpportunityEvent
|
||||
from arbitrade.storage.executions import AsyncExecutionWriter
|
||||
from arbitrade.storage.repositories import OrderRecord, PnLRecord, TradeRecord
|
||||
|
||||
|
||||
class SupportsOrderPlacement(Protocol):
|
||||
async def place_market_order(
|
||||
self, *, pair: str, side: str, volume: float
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ExecutionLeg:
|
||||
from_currency: str
|
||||
to_currency: str
|
||||
pair: str
|
||||
side: str
|
||||
volume: float
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TriangularExecutionResult:
|
||||
success: bool
|
||||
requested_legs: tuple[ExecutionLeg, ...]
|
||||
completed_legs: int
|
||||
responses: tuple[dict[str, Any], ...]
|
||||
failure_reason: str | None = None
|
||||
|
||||
|
||||
class TriangularExecutionSequencer:
|
||||
def __init__(
|
||||
self,
|
||||
rest_client: SupportsOrderPlacement,
|
||||
*,
|
||||
available_pairs: Sequence[str],
|
||||
volume_for_leg: Callable[[OpportunityEvent,
|
||||
ExecutionLeg, int], float] | None = None,
|
||||
execution_writer: AsyncExecutionWriter | None = None,
|
||||
) -> None:
|
||||
self._rest_client = rest_client
|
||||
self._available_pairs = {self._normalize_pair(
|
||||
pair) for pair in available_pairs}
|
||||
self._volume_for_leg = volume_for_leg or self._default_volume_for_leg
|
||||
self._execution_writer = execution_writer
|
||||
|
||||
@staticmethod
|
||||
def _normalize_pair(pair: str) -> str:
|
||||
normalized = pair.strip().upper().replace("-", "/")
|
||||
if "/" not in normalized:
|
||||
return normalized
|
||||
base, quote = normalized.split("/", 1)
|
||||
return f"{base}/{quote}"
|
||||
|
||||
@staticmethod
|
||||
def _default_volume_for_leg(event: OpportunityEvent, _leg: ExecutionLeg, _idx: int) -> float:
|
||||
if event.allocated_capital <= 0.0:
|
||||
raise ValueError("allocated_capital must be > 0.0")
|
||||
return event.allocated_capital
|
||||
|
||||
def _resolve_leg(self, from_currency: str, to_currency: str, volume: float) -> ExecutionLeg:
|
||||
from_cur = from_currency.upper()
|
||||
to_cur = to_currency.upper()
|
||||
|
||||
buy_pair = f"{to_cur}/{from_cur}"
|
||||
if buy_pair in self._available_pairs:
|
||||
return ExecutionLeg(
|
||||
from_currency=from_cur,
|
||||
to_currency=to_cur,
|
||||
pair=buy_pair,
|
||||
side="buy",
|
||||
volume=volume,
|
||||
)
|
||||
|
||||
sell_pair = f"{from_cur}/{to_cur}"
|
||||
if sell_pair in self._available_pairs:
|
||||
return ExecutionLeg(
|
||||
from_currency=from_cur,
|
||||
to_currency=to_cur,
|
||||
pair=sell_pair,
|
||||
side="sell",
|
||||
volume=volume,
|
||||
)
|
||||
|
||||
raise ValueError(f"No tradable pair for leg {from_cur}->{to_cur}")
|
||||
|
||||
def _build_legs(self, event: OpportunityEvent) -> tuple[ExecutionLeg, ...]:
|
||||
currencies = [part.strip().upper()
|
||||
for part in event.cycle.split("->") if part.strip()]
|
||||
if len(currencies) < 4 or currencies[0] != currencies[-1]:
|
||||
raise ValueError(
|
||||
"cycle must be a closed triangular path like A->B->C->A")
|
||||
|
||||
if len(currencies) != 4:
|
||||
raise ValueError(
|
||||
"cycle must contain exactly three unique currencies")
|
||||
|
||||
legs: list[ExecutionLeg] = []
|
||||
for idx in range(3):
|
||||
from_currency = currencies[idx]
|
||||
to_currency = currencies[idx + 1]
|
||||
placeholder_leg = ExecutionLeg(
|
||||
from_currency=from_currency,
|
||||
to_currency=to_currency,
|
||||
pair="",
|
||||
side="buy",
|
||||
volume=0.0,
|
||||
)
|
||||
volume = self._volume_for_leg(event, placeholder_leg, idx)
|
||||
if volume <= 0.0:
|
||||
raise ValueError(
|
||||
"volume_for_leg must return a positive volume")
|
||||
legs.append(self._resolve_leg(from_currency, to_currency, volume))
|
||||
|
||||
return tuple(legs)
|
||||
|
||||
@staticmethod
|
||||
def _trade_ref_for_event(event: OpportunityEvent) -> str:
|
||||
material = (
|
||||
f"{event.cycle}|{event.updated_pair}|"
|
||||
f"{event.detected_at.timestamp():.6f}|"
|
||||
f"{event.allocated_capital:.12f}"
|
||||
)
|
||||
return material.encode("utf-8").hex()[:32]
|
||||
|
||||
@staticmethod
|
||||
def _order_ref_from_response(response: dict[str, Any], default: str) -> str:
|
||||
txid = response.get("txid")
|
||||
if isinstance(txid, list) and txid:
|
||||
return str(txid[0])
|
||||
if isinstance(txid, str) and txid.strip():
|
||||
return txid
|
||||
return default
|
||||
|
||||
async def execute(self, event: OpportunityEvent) -> TriangularExecutionResult:
|
||||
legs = self._build_legs(event)
|
||||
responses: list[dict[str, Any]] = []
|
||||
trade_ref = self._trade_ref_for_event(event)
|
||||
started_at = datetime.now(UTC)
|
||||
|
||||
for idx, leg in enumerate(legs):
|
||||
try:
|
||||
response = await self._rest_client.place_market_order(
|
||||
pair=leg.pair,
|
||||
side=leg.side,
|
||||
volume=leg.volume,
|
||||
)
|
||||
except Exception as exc:
|
||||
if self._execution_writer is not None:
|
||||
await self._execution_writer.enqueue(
|
||||
TradeRecord(
|
||||
trade_ref=trade_ref,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(UTC),
|
||||
status="failed",
|
||||
realized_pnl=None,
|
||||
estimated_pnl=event.est_profit,
|
||||
capital_used=event.allocated_capital,
|
||||
cycle=event.cycle,
|
||||
leg_count=len(legs),
|
||||
)
|
||||
)
|
||||
return TriangularExecutionResult(
|
||||
success=False,
|
||||
requested_legs=legs,
|
||||
completed_legs=idx,
|
||||
responses=tuple(responses),
|
||||
failure_reason=str(exc),
|
||||
)
|
||||
|
||||
responses.append(response)
|
||||
|
||||
if self._execution_writer is not None:
|
||||
order_ref = self._order_ref_from_response(
|
||||
response, f"leg-{idx}")
|
||||
await self._execution_writer.enqueue(
|
||||
OrderRecord(
|
||||
trade_ref=trade_ref,
|
||||
order_ref=order_ref,
|
||||
leg_index=idx,
|
||||
pair=leg.pair,
|
||||
side=leg.side,
|
||||
volume=leg.volume,
|
||||
user_ref=None,
|
||||
status=str(response.get("status", "submitted")),
|
||||
filled_volume=None,
|
||||
avg_price=None,
|
||||
raw_response=response,
|
||||
recorded_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
if self._execution_writer is not None:
|
||||
await self._execution_writer.enqueue(
|
||||
TradeRecord(
|
||||
trade_ref=trade_ref,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(UTC),
|
||||
status="filled",
|
||||
realized_pnl=None,
|
||||
estimated_pnl=event.est_profit,
|
||||
capital_used=event.allocated_capital,
|
||||
cycle=event.cycle,
|
||||
leg_count=len(legs),
|
||||
)
|
||||
)
|
||||
await self._execution_writer.enqueue(
|
||||
PnLRecord(
|
||||
trade_ref=trade_ref,
|
||||
recorded_at=datetime.now(UTC),
|
||||
kind="estimated",
|
||||
pnl_usd=event.est_profit,
|
||||
source="triangular_sequencer",
|
||||
)
|
||||
)
|
||||
|
||||
return TriangularExecutionResult(
|
||||
success=True,
|
||||
requested_legs=legs,
|
||||
completed_legs=len(legs),
|
||||
responses=tuple(responses),
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Storage helpers."""
|
||||
|
||||
@@ -5,9 +5,12 @@ from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import duckdb
|
||||
import structlog
|
||||
|
||||
from arbitrade.config.settings import Settings
|
||||
|
||||
_LOG = structlog.get_logger(__name__)
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
@@ -26,11 +29,40 @@ CREATE TABLE IF NOT EXISTS opportunities (
|
||||
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id UUID DEFAULT uuid(),
|
||||
trade_ref VARCHAR NOT NULL,
|
||||
started_at TIMESTAMP NOT NULL,
|
||||
finished_at TIMESTAMP,
|
||||
status VARCHAR NOT NULL,
|
||||
realized_pnl DOUBLE,
|
||||
capital_used DOUBLE
|
||||
estimated_pnl DOUBLE,
|
||||
capital_used DOUBLE,
|
||||
cycle VARCHAR,
|
||||
leg_count INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS orders (
|
||||
id UUID DEFAULT uuid(),
|
||||
trade_ref VARCHAR NOT NULL,
|
||||
order_ref VARCHAR NOT NULL,
|
||||
leg_index INTEGER NOT NULL,
|
||||
pair VARCHAR NOT NULL,
|
||||
side VARCHAR NOT NULL,
|
||||
volume DOUBLE NOT NULL,
|
||||
user_ref INTEGER,
|
||||
status VARCHAR,
|
||||
filled_volume DOUBLE,
|
||||
avg_price DOUBLE,
|
||||
raw_response JSON,
|
||||
recorded_at TIMESTAMP NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pnl_events (
|
||||
id UUID DEFAULT uuid(),
|
||||
trade_ref VARCHAR NOT NULL,
|
||||
recorded_at TIMESTAMP NOT NULL,
|
||||
kind VARCHAR NOT NULL,
|
||||
pnl_usd DOUBLE NOT NULL,
|
||||
source VARCHAR NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
|
||||
@@ -53,10 +85,19 @@ class DuckDBStore:
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self._db_path = Path(settings.duckdb_path)
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._use_memory_fallback = False
|
||||
|
||||
@contextmanager
|
||||
def connect(self) -> Iterator[duckdb.DuckDBPyConnection]:
|
||||
conn = duckdb.connect(str(self._db_path))
|
||||
try:
|
||||
conn = duckdb.connect(str(self._db_path))
|
||||
except duckdb.IOException:
|
||||
if not self._use_memory_fallback:
|
||||
_LOG.warning(
|
||||
"duckdb_path_unavailable_falling_back_to_memory", path=str(self._db_path)
|
||||
)
|
||||
self._use_memory_fallback = True
|
||||
conn = duckdb.connect(":memory:")
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
|
||||
from arbitrade.storage.repositories import (
|
||||
OrderRecord,
|
||||
OrderRepository,
|
||||
PnLRecord,
|
||||
PnLRepository,
|
||||
TradeRecord,
|
||||
TradeRepository,
|
||||
)
|
||||
|
||||
_LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncExecutionWriter:
|
||||
def __init__(
|
||||
self,
|
||||
trade_repository: TradeRepository,
|
||||
order_repository: OrderRepository,
|
||||
pnl_repository: PnLRepository,
|
||||
max_queue_size: int = 50_000,
|
||||
) -> None:
|
||||
self._trade_repository = trade_repository
|
||||
self._order_repository = order_repository
|
||||
self._pnl_repository = pnl_repository
|
||||
self._queue: asyncio.Queue[TradeRecord | OrderRecord | PnLRecord] = asyncio.Queue(
|
||||
maxsize=max_queue_size
|
||||
)
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
self._stop = asyncio.Event()
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._task is None or self._task.done():
|
||||
self._stop.clear()
|
||||
self._task = asyncio.create_task(self._run(), name="execution-writer")
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._stop.set()
|
||||
if self._task is not None:
|
||||
await self._task
|
||||
|
||||
async def enqueue(self, record: TradeRecord | OrderRecord | PnLRecord) -> None:
|
||||
await self._queue.put(record)
|
||||
|
||||
async def _run(self) -> None:
|
||||
while not (self._stop.is_set() and self._queue.empty()):
|
||||
try:
|
||||
record = await asyncio.wait_for(self._queue.get(), timeout=0.5)
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(record, TradeRecord):
|
||||
self._trade_repository.insert(record)
|
||||
elif isinstance(record, OrderRecord):
|
||||
self._order_repository.insert(record)
|
||||
else:
|
||||
self._pnl_repository.insert(record)
|
||||
except Exception as exc:
|
||||
_LOG.error("execution_write_failed", error=str(exc))
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
@@ -28,6 +28,44 @@ class OpportunityRecord:
|
||||
executed: bool = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TradeRecord:
|
||||
trade_ref: str
|
||||
started_at: datetime
|
||||
finished_at: datetime | None
|
||||
status: str
|
||||
realized_pnl: float | None
|
||||
estimated_pnl: float | None
|
||||
capital_used: float | None
|
||||
cycle: str | None = None
|
||||
leg_count: int | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OrderRecord:
|
||||
trade_ref: str
|
||||
order_ref: str
|
||||
leg_index: int
|
||||
pair: str
|
||||
side: str
|
||||
volume: float
|
||||
user_ref: int | None
|
||||
status: str | None
|
||||
filled_volume: float | None
|
||||
avg_price: float | None
|
||||
raw_response: dict[str, Any]
|
||||
recorded_at: datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PnLRecord:
|
||||
trade_ref: str
|
||||
recorded_at: datetime
|
||||
kind: str
|
||||
pnl_usd: float
|
||||
source: str
|
||||
|
||||
|
||||
class MarketSnapshotRepository:
|
||||
def __init__(self, store: DuckDBStore) -> None:
|
||||
self._store = store
|
||||
@@ -76,3 +114,106 @@ class OpportunityRepository:
|
||||
record.executed,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TradeRepository:
|
||||
def __init__(self, store: DuckDBStore) -> None:
|
||||
self._store = store
|
||||
|
||||
def insert(self, record: TradeRecord) -> None:
|
||||
with self._store.connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (
|
||||
trade_ref,
|
||||
started_at,
|
||||
finished_at,
|
||||
status,
|
||||
realized_pnl,
|
||||
estimated_pnl,
|
||||
capital_used,
|
||||
cycle,
|
||||
leg_count
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
record.trade_ref,
|
||||
record.started_at,
|
||||
record.finished_at,
|
||||
record.status,
|
||||
record.realized_pnl,
|
||||
record.estimated_pnl,
|
||||
record.capital_used,
|
||||
record.cycle,
|
||||
record.leg_count,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class OrderRepository:
|
||||
def __init__(self, store: DuckDBStore) -> None:
|
||||
self._store = store
|
||||
|
||||
def insert(self, record: OrderRecord) -> None:
|
||||
with self._store.connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO orders (
|
||||
trade_ref,
|
||||
order_ref,
|
||||
leg_index,
|
||||
pair,
|
||||
side,
|
||||
volume,
|
||||
user_ref,
|
||||
status,
|
||||
filled_volume,
|
||||
avg_price,
|
||||
raw_response,
|
||||
recorded_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
record.trade_ref,
|
||||
record.order_ref,
|
||||
record.leg_index,
|
||||
record.pair,
|
||||
record.side,
|
||||
record.volume,
|
||||
record.user_ref,
|
||||
record.status,
|
||||
record.filled_volume,
|
||||
record.avg_price,
|
||||
orjson.dumps(record.raw_response).decode("utf-8"),
|
||||
record.recorded_at,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class PnLRepository:
|
||||
def __init__(self, store: DuckDBStore) -> None:
|
||||
self._store = store
|
||||
|
||||
def insert(self, record: PnLRecord) -> None:
|
||||
with self._store.connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO pnl_events (
|
||||
trade_ref,
|
||||
recorded_at,
|
||||
kind,
|
||||
pnl_usd,
|
||||
source
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
record.trade_ref,
|
||||
record.recorded_at,
|
||||
record.kind,
|
||||
record.pnl_usd,
|
||||
record.source,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.config.settings import Settings
|
||||
from arbitrade.detection.engine import OpportunityEvent
|
||||
from arbitrade.execution.sequencer import TriangularExecutionSequencer
|
||||
from arbitrade.storage.db import DuckDBStore
|
||||
from arbitrade.storage.executions import AsyncExecutionWriter
|
||||
from arbitrade.storage.repositories import OrderRepository, PnLRepository, TradeRepository
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FakeRestClient:
|
||||
calls: int = 0
|
||||
|
||||
async def place_market_order(self, *, pair: str, side: str, volume: float) -> dict[str, object]:
|
||||
self.calls += 1
|
||||
return {"txid": [f"tx-{self.calls}"], "status": "submitted"}
|
||||
|
||||
|
||||
def _sample_event() -> OpportunityEvent:
|
||||
return OpportunityEvent(
|
||||
detected_at=datetime.now(UTC),
|
||||
cycle="USD->BTC->ETH->USD",
|
||||
updated_pair="BTC/USD",
|
||||
gross_rate=1.04,
|
||||
net_rate=1.03,
|
||||
gross_pct=4.0,
|
||||
net_pct=3.0,
|
||||
est_profit=0.03,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execution_writer_persists_trade_order_and_pnl(tmp_path) -> None:
|
||||
settings = Settings(_env_file=None, DUCKDB_PATH=tmp_path / "exec.duckdb")
|
||||
store = DuckDBStore(settings)
|
||||
store.migrate()
|
||||
writer = AsyncExecutionWriter(
|
||||
TradeRepository(store),
|
||||
OrderRepository(store),
|
||||
PnLRepository(store),
|
||||
max_queue_size=10,
|
||||
)
|
||||
await writer.start()
|
||||
|
||||
client = _FakeRestClient()
|
||||
sequencer = TriangularExecutionSequencer(
|
||||
client,
|
||||
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
|
||||
execution_writer=writer,
|
||||
)
|
||||
|
||||
result = await sequencer.execute(_sample_event())
|
||||
await writer.stop()
|
||||
|
||||
assert result.success
|
||||
assert client.calls == 3
|
||||
|
||||
with store.connect() as conn:
|
||||
trades = conn.execute(
|
||||
"SELECT trade_ref, status, estimated_pnl, capital_used, cycle, leg_count FROM trades"
|
||||
).fetchall()
|
||||
orders = conn.execute(
|
||||
"SELECT trade_ref, order_ref, leg_index, pair, side, volume, status "
|
||||
"FROM orders ORDER BY leg_index"
|
||||
).fetchall()
|
||||
pnls = conn.execute(
|
||||
"SELECT trade_ref, kind, pnl_usd, source FROM pnl_events").fetchall()
|
||||
|
||||
assert len(trades) == 1
|
||||
assert trades[0][1] == "filled"
|
||||
assert trades[0][2] == 0.03
|
||||
assert trades[0][3] == 1.0
|
||||
assert trades[0][4] == "USD->BTC->ETH->USD"
|
||||
assert trades[0][5] == 3
|
||||
|
||||
assert len(orders) == 3
|
||||
assert orders[0][2] == 0
|
||||
assert orders[1][2] == 1
|
||||
assert orders[2][2] == 2
|
||||
assert orders[0][6] == "submitted"
|
||||
|
||||
assert len(pnls) == 1
|
||||
assert pnls[0][1] == "estimated"
|
||||
assert pnls[0][2] == 0.03
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.detection.engine import OpportunityEvent
|
||||
from arbitrade.execution.sequencer import TriangularExecutionSequencer
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FakeRestClient:
|
||||
fail_at_call: int | None = None
|
||||
calls: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
async def place_market_order(self, *, pair: str, side: str, volume: float) -> dict[str, Any]:
|
||||
call_number = len(self.calls) + 1
|
||||
if self.fail_at_call is not None and call_number == self.fail_at_call:
|
||||
raise RuntimeError("simulated failure")
|
||||
|
||||
payload = {"pair": pair, "side": side, "volume": volume}
|
||||
self.calls.append(payload)
|
||||
return {"txid": [f"tx-{call_number}"]}
|
||||
|
||||
|
||||
def _sample_event(cycle: str = "USD->BTC->ETH->USD") -> OpportunityEvent:
|
||||
return OpportunityEvent(
|
||||
detected_at=datetime.now(UTC),
|
||||
cycle=cycle,
|
||||
updated_pair="BTC/USD",
|
||||
gross_rate=1.02,
|
||||
net_rate=1.01,
|
||||
gross_pct=2.0,
|
||||
net_pct=1.0,
|
||||
est_profit=1.0,
|
||||
allocated_capital=10.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triangular_sequencer_executes_legs_in_order() -> None:
|
||||
client = _FakeRestClient()
|
||||
sequencer = TriangularExecutionSequencer(
|
||||
client,
|
||||
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
|
||||
)
|
||||
|
||||
result = await sequencer.execute(_sample_event())
|
||||
|
||||
assert result.success
|
||||
assert result.completed_legs == 3
|
||||
assert [call["pair"] for call in client.calls] == ["BTC/USD", "ETH/BTC", "ETH/USD"]
|
||||
assert [call["side"] for call in client.calls] == ["buy", "buy", "sell"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triangular_sequencer_stops_on_failed_leg() -> None:
|
||||
client = _FakeRestClient(fail_at_call=2)
|
||||
sequencer = TriangularExecutionSequencer(
|
||||
client,
|
||||
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
|
||||
)
|
||||
|
||||
result = await sequencer.execute(_sample_event())
|
||||
|
||||
assert not result.success
|
||||
assert result.completed_legs == 1
|
||||
assert result.failure_reason is not None
|
||||
assert len(client.calls) == 1
|
||||
|
||||
|
||||
def test_triangular_sequencer_rejects_non_closed_cycle() -> None:
|
||||
client = _FakeRestClient()
|
||||
sequencer = TriangularExecutionSequencer(
|
||||
client,
|
||||
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="closed triangular path"):
|
||||
sequencer._build_legs(_sample_event(cycle="USD->BTC->ETH"))
|
||||
|
||||
|
||||
def test_triangular_sequencer_rejects_missing_pair() -> None:
|
||||
client = _FakeRestClient()
|
||||
sequencer = TriangularExecutionSequencer(
|
||||
client,
|
||||
available_pairs=["BTC/USD", "ETH/BTC"],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No tradable pair"):
|
||||
sequencer._build_legs(_sample_event())
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.execution.fill_monitor import FillMonitor, OrderFillState
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FakePollClient:
|
||||
responses: list[dict[str, Any]]
|
||||
calls: int = 0
|
||||
|
||||
async def query_order(self, *, order_id: str, include_trades: bool = True) -> dict[str, Any]:
|
||||
self.calls += 1
|
||||
if self.responses:
|
||||
return self.responses.pop(0)
|
||||
return {order_id: {"status": "open", "vol_exec": "0.0", "price": "0.0"}}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FakeWsProvider:
|
||||
states: list[OrderFillState] = field(default_factory=list)
|
||||
|
||||
def get(self, _order_id: str) -> OrderFillState | None:
|
||||
if not self.states:
|
||||
return None
|
||||
return self.states.pop(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fill_monitor_detects_terminal_state_via_polling() -> None:
|
||||
order_id = "order-1"
|
||||
client = _FakePollClient(
|
||||
responses=[
|
||||
{order_id: {"status": "open", "vol_exec": "0.0", "price": "0.0"}},
|
||||
{order_id: {"status": "closed", "vol_exec": "1.0", "price": "100.0"}},
|
||||
]
|
||||
)
|
||||
monitor = FillMonitor(client, poll_interval_seconds=0.001, max_wait_seconds=0.1)
|
||||
|
||||
result = await monitor.wait_for_terminal_fill(order_id)
|
||||
|
||||
assert not result.timed_out
|
||||
assert result.terminal_state is not None
|
||||
assert result.terminal_state.status == "closed"
|
||||
assert result.terminal_state.filled_volume == 1.0
|
||||
assert result.terminal_state.source == "rest_poll"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fill_monitor_times_out_when_no_terminal_state() -> None:
|
||||
order_id = "order-2"
|
||||
client = _FakePollClient(
|
||||
responses=[
|
||||
{order_id: {"status": "open", "vol_exec": "0.1", "price": "100.0"}},
|
||||
{order_id: {"status": "partial", "vol_exec": "0.2", "price": "100.0"}},
|
||||
{order_id: {"status": "open", "vol_exec": "0.2", "price": "100.0"}},
|
||||
]
|
||||
)
|
||||
monitor = FillMonitor(client, poll_interval_seconds=0.001, max_wait_seconds=0.01)
|
||||
|
||||
result = await monitor.wait_for_terminal_fill(order_id)
|
||||
|
||||
assert result.timed_out
|
||||
assert result.terminal_state is None
|
||||
assert result.last_state is not None
|
||||
assert result.last_state.status in {"open", "partial"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fill_monitor_uses_ws_status_for_fast_terminal_detection() -> None:
|
||||
order_id = "order-3"
|
||||
ws_provider = _FakeWsProvider(
|
||||
states=[
|
||||
OrderFillState(
|
||||
order_id=order_id,
|
||||
status="closed",
|
||||
filled_volume=0.5,
|
||||
avg_price=200.0,
|
||||
updated_at=datetime.now(UTC),
|
||||
source="ws",
|
||||
)
|
||||
]
|
||||
)
|
||||
client = _FakePollClient(responses=[])
|
||||
monitor = FillMonitor(
|
||||
client,
|
||||
poll_interval_seconds=0.001,
|
||||
max_wait_seconds=0.1,
|
||||
ws_status_provider=ws_provider.get,
|
||||
)
|
||||
|
||||
result = await monitor.wait_for_terminal_fill(order_id)
|
||||
|
||||
assert not result.timed_out
|
||||
assert result.terminal_state is not None
|
||||
assert result.terminal_state.source == "ws"
|
||||
assert client.calls == 0
|
||||
|
||||
|
||||
def test_fill_monitor_rejects_invalid_configuration() -> None:
|
||||
client = _FakePollClient(responses=[])
|
||||
|
||||
with pytest.raises(ValueError, match="poll_interval_seconds"):
|
||||
FillMonitor(client, poll_interval_seconds=0.0)
|
||||
|
||||
with pytest.raises(ValueError, match="max_wait_seconds"):
|
||||
FillMonitor(client, max_wait_seconds=0.0)
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.detection.engine import OpportunityEvent
|
||||
from arbitrade.execution.idempotency import IdempotencyKeyFactory, OrderReconciler
|
||||
from arbitrade.execution.sequencer import ExecutionLeg
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FakeHistoryClient:
|
||||
response: dict[str, Any]
|
||||
|
||||
async def query_order(self, *, order_id: str, include_trades: bool = True) -> dict[str, Any]:
|
||||
return self.response
|
||||
|
||||
|
||||
def _sample_event() -> OpportunityEvent:
|
||||
return OpportunityEvent(
|
||||
detected_at=datetime.now(UTC),
|
||||
cycle="USD->BTC->ETH->USD",
|
||||
updated_pair="BTC/USD",
|
||||
gross_rate=1.02,
|
||||
net_rate=1.01,
|
||||
gross_pct=2.0,
|
||||
net_pct=1.0,
|
||||
est_profit=1.0,
|
||||
allocated_capital=10.0,
|
||||
)
|
||||
|
||||
|
||||
def test_idempotency_key_factory_is_deterministic() -> None:
|
||||
factory = IdempotencyKeyFactory()
|
||||
event = _sample_event()
|
||||
leg = ExecutionLeg(
|
||||
from_currency="USD",
|
||||
to_currency="BTC",
|
||||
pair="BTC/USD",
|
||||
side="buy",
|
||||
volume=10.0,
|
||||
)
|
||||
|
||||
first = factory.user_ref_for_leg(event, leg, 0)
|
||||
second = factory.user_ref_for_leg(event, leg, 0)
|
||||
|
||||
assert first == second
|
||||
assert first > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_order_reconciler_maps_query_response_to_report() -> None:
|
||||
client = _FakeHistoryClient(
|
||||
response={
|
||||
"order-1": {
|
||||
"status": "closed",
|
||||
"vol_exec": "5.0",
|
||||
"price": "100.0",
|
||||
"pair": "BTC/USD",
|
||||
"type": "buy",
|
||||
"userref": 12345,
|
||||
}
|
||||
}
|
||||
)
|
||||
reconciler = OrderReconciler(client)
|
||||
|
||||
report = await reconciler.reconcile_order(
|
||||
order_id="order-1",
|
||||
user_ref=12345,
|
||||
expected_pair="BTC/USD",
|
||||
expected_side="buy",
|
||||
expected_volume=10.0,
|
||||
)
|
||||
|
||||
assert report.is_terminal
|
||||
assert report.matches_request
|
||||
assert report.status == "closed"
|
||||
assert report.filled_volume == 5.0
|
||||
assert report.avg_price == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_order_reconciler_marks_mismatch() -> None:
|
||||
client = _FakeHistoryClient(
|
||||
response={
|
||||
"order-1": {
|
||||
"status": "closed",
|
||||
"vol_exec": "5.0",
|
||||
"price": "100.0",
|
||||
"pair": "ETH/USD",
|
||||
"type": "sell",
|
||||
"userref": 999,
|
||||
}
|
||||
}
|
||||
)
|
||||
reconciler = OrderReconciler(client)
|
||||
|
||||
report = await reconciler.reconcile_order(
|
||||
order_id="order-1",
|
||||
user_ref=12345,
|
||||
expected_pair="BTC/USD",
|
||||
expected_side="buy",
|
||||
expected_volume=10.0,
|
||||
)
|
||||
|
||||
assert not report.matches_request
|
||||
@@ -110,3 +110,139 @@ def test_compliance_detects_insecure_config() -> None:
|
||||
assert any("below 1.0" in issue for issue in issues)
|
||||
assert any("ATTEMPTS" in issue for issue in issues)
|
||||
assert any("BASE_DELAY" in issue for issue in issues)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_place_market_order_posts_add_order_payload() -> None:
|
||||
settings = Settings(
|
||||
_env_file=None,
|
||||
KRAKEN_API_KEY="key",
|
||||
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
|
||||
kraken_private_rate_limit_seconds=0.0,
|
||||
)
|
||||
client = KrakenRestClient(settings)
|
||||
|
||||
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
|
||||
route = mock_router.post("/0/private/AddOrder").respond(
|
||||
200,
|
||||
json={"error": [], "result": {"txid": ["m1"]}},
|
||||
)
|
||||
payload = await client.place_market_order(
|
||||
pair="XBTUSD",
|
||||
side="buy",
|
||||
volume=0.05,
|
||||
)
|
||||
|
||||
await client.close()
|
||||
request_body = route.calls.last.request.content.decode()
|
||||
assert "pair=XBTUSD" in request_body
|
||||
assert "type=buy" in request_body
|
||||
assert "ordertype=market" in request_body
|
||||
assert "volume=0.05" in request_body
|
||||
assert payload["txid"] == ["m1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_place_limit_order_posts_add_order_payload() -> None:
|
||||
settings = Settings(
|
||||
_env_file=None,
|
||||
KRAKEN_API_KEY="key",
|
||||
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
|
||||
kraken_private_rate_limit_seconds=0.0,
|
||||
)
|
||||
client = KrakenRestClient(settings)
|
||||
|
||||
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
|
||||
route = mock_router.post("/0/private/AddOrder").respond(
|
||||
200,
|
||||
json={"error": [], "result": {"txid": ["l1"]}},
|
||||
)
|
||||
payload = await client.place_limit_order(
|
||||
pair="ETHUSD",
|
||||
side="sell",
|
||||
volume=1.5,
|
||||
price=3500.0,
|
||||
)
|
||||
|
||||
await client.close()
|
||||
request_body = route.calls.last.request.content.decode()
|
||||
assert "pair=ETHUSD" in request_body
|
||||
assert "type=sell" in request_body
|
||||
assert "ordertype=limit" in request_body
|
||||
assert "price=3500.0" in request_body
|
||||
assert "volume=1.5" in request_body
|
||||
assert payload["txid"] == ["l1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_place_order_validates_inputs() -> None:
|
||||
settings = Settings(
|
||||
_env_file=None,
|
||||
KRAKEN_API_KEY="key",
|
||||
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
|
||||
kraken_private_rate_limit_seconds=0.0,
|
||||
)
|
||||
client = KrakenRestClient(settings)
|
||||
|
||||
with pytest.raises(ValueError, match="side"):
|
||||
await client.place_market_order(pair="XBTUSD", side="hold", volume=0.1)
|
||||
|
||||
with pytest.raises(ValueError, match="volume"):
|
||||
await client.place_market_order(pair="XBTUSD", side="buy", volume=0.0)
|
||||
|
||||
with pytest.raises(ValueError, match="price"):
|
||||
await client.place_limit_order(
|
||||
pair="XBTUSD",
|
||||
side="buy",
|
||||
volume=0.1,
|
||||
price=0.0,
|
||||
)
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_order_posts_query_orders_payload() -> None:
|
||||
settings = Settings(
|
||||
_env_file=None,
|
||||
KRAKEN_API_KEY="key",
|
||||
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
|
||||
kraken_private_rate_limit_seconds=0.0,
|
||||
)
|
||||
client = KrakenRestClient(settings)
|
||||
|
||||
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
|
||||
route = mock_router.post("/0/private/QueryOrders").respond(
|
||||
200,
|
||||
json={"error": [], "result": {"order-1": {"status": "closed"}}},
|
||||
)
|
||||
payload = await client.query_order(order_id="order-1", include_trades=False)
|
||||
|
||||
await client.close()
|
||||
request_body = route.calls.last.request.content.decode()
|
||||
assert "txid=order-1" in request_body
|
||||
assert "trades=false" in request_body
|
||||
assert payload["order-1"]["status"] == "closed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_order_posts_cancel_order_payload() -> None:
|
||||
settings = Settings(
|
||||
_env_file=None,
|
||||
KRAKEN_API_KEY="key",
|
||||
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
|
||||
kraken_private_rate_limit_seconds=0.0,
|
||||
)
|
||||
client = KrakenRestClient(settings)
|
||||
|
||||
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
|
||||
route = mock_router.post("/0/private/CancelOrder").respond(
|
||||
200,
|
||||
json={"error": [], "result": {"count": 1}},
|
||||
)
|
||||
payload = await client.cancel_order(order_id="order-1")
|
||||
|
||||
await client.close()
|
||||
request_body = route.calls.last.request.content.decode()
|
||||
assert "txid=order-1" in request_body
|
||||
assert payload["count"] == 1
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.execution.fill_monitor import FillMonitorResult, OrderFillState
|
||||
from arbitrade.execution.recovery import PartialFillRecovery
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _FakeRestClient:
|
||||
cancel_calls: list[str] = None # type: ignore[assignment]
|
||||
market_calls: list[dict[str, Any]] = None # type: ignore[assignment]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.cancel_calls = []
|
||||
self.market_calls = []
|
||||
|
||||
async def cancel_order(self, *, order_id: str) -> dict[str, Any]:
|
||||
self.cancel_calls.append(order_id)
|
||||
return {"result": {"count": 1}}
|
||||
|
||||
async def place_market_order(self, *, pair: str, side: str, volume: float) -> dict[str, Any]:
|
||||
self.market_calls.append({"pair": pair, "side": side, "volume": volume})
|
||||
return {"txid": ["hedge-1"]}
|
||||
|
||||
|
||||
def _monitor_result(
|
||||
*, status: str, filled_volume: float | None, timed_out: bool
|
||||
) -> FillMonitorResult:
|
||||
state = OrderFillState(
|
||||
order_id="order-1",
|
||||
status=status,
|
||||
filled_volume=filled_volume,
|
||||
avg_price=100.0,
|
||||
updated_at=datetime.now(UTC),
|
||||
source="rest_poll",
|
||||
)
|
||||
return FillMonitorResult(
|
||||
order_id="order-1",
|
||||
timed_out=timed_out,
|
||||
terminal_state=None if status in {"open", "partial"} else state,
|
||||
last_state=state,
|
||||
elapsed_seconds=1.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_fill_recovery_cancels_open_order_and_hedges_residual() -> None:
|
||||
client = _FakeRestClient()
|
||||
recovery = PartialFillRecovery(client)
|
||||
|
||||
result = await recovery.recover_partial_fill(
|
||||
order_id="order-1",
|
||||
pair="BTC/USD",
|
||||
side="buy",
|
||||
requested_volume=10.0,
|
||||
fill_result=_monitor_result(status="partial", filled_volume=4.0, timed_out=True),
|
||||
)
|
||||
|
||||
assert result.canceled
|
||||
assert result.hedged
|
||||
assert client.cancel_calls == ["order-1"]
|
||||
assert client.market_calls == [{"pair": "BTC/USD", "side": "sell", "volume": 6.0}]
|
||||
assert result.hedge_volume == 6.0
|
||||
assert result.reason == "canceled_partial_order"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_fill_recovery_no_hedge_when_no_residual() -> None:
|
||||
client = _FakeRestClient()
|
||||
recovery = PartialFillRecovery(client)
|
||||
|
||||
result = await recovery.recover_partial_fill(
|
||||
order_id="order-1",
|
||||
pair="BTC/USD",
|
||||
side="sell",
|
||||
requested_volume=5.0,
|
||||
fill_result=_monitor_result(status="closed", filled_volume=5.0, timed_out=False),
|
||||
)
|
||||
|
||||
assert not result.canceled
|
||||
assert not result.hedged
|
||||
assert client.cancel_calls == []
|
||||
assert client.market_calls == []
|
||||
|
||||
|
||||
def test_partial_fill_recovery_rejects_invalid_volume() -> None:
|
||||
client = _FakeRestClient()
|
||||
recovery = PartialFillRecovery(client)
|
||||
|
||||
with pytest.raises(ValueError, match="requested_volume"):
|
||||
recovery._residual_volume(None, 0.0)
|
||||
Reference in New Issue
Block a user