feat: Implement idempotency and recovery mechanisms for order execution

- Add IdempotencyKeyFactory for generating unique user references based on execution legs.
- Introduce OrderReconciler to reconcile order statuses with historical data.
- Implement PartialFillRecovery to handle partial fills by canceling orders and placing hedges.
- Create TriangularExecutionSequencer for executing triangular arbitrage strategies.
- Enhance storage with new tables for trades, orders, and PnL events.
- Develop AsyncExecutionWriter for asynchronous writing of execution records to the database.
- Add unit tests for execution persistence, sequencer behavior, fill monitoring, and idempotency checks.
- Update KrakenRestClient to ensure proper payloads for order placement and querying.
This commit is contained in:
2026-06-01 11:59:13 +02:00
parent 240a591a64
commit 93f4f62d42
17 changed files with 1602 additions and 4 deletions
+23
View File
@@ -0,0 +1,23 @@
## [Unreleased] - 2026-06-01
### Added
- Added stop-condition risk controls for abnormal source/apply latency and repeated execution failures.
- Added a new stop-conditions guard and integration in market feed processing.
### Changed
- Live execution path now auto-activates the kill switch when configured stop conditions are breached.
- Added configuration env keys for stop-condition thresholds.
### Removed
- None.
### Fixed
- Added/expanded unit coverage for risk limits and kill-switch enforcement, including stop-condition scenarios.
- Added partial-fill recovery logic that cancels open orders when possible and hedges residual exposure on timeout or failure.
- Added deterministic order idempotency via Kraken userref plus reconciliation helpers for Kraken order history responses.
- Added execution journaling for trades, orders, and estimated P&L, plus a DuckDB startup fallback when the default file path is unavailable.
- Added a mocked execution integration test that drives the triangular sequencer through the execution journal and DuckDB persistence.
+96 -2
View File
@@ -177,11 +177,105 @@ class KrakenRestClient:
result = await self._request_with_retry("/0/public/AssetPairs")
return _result_dict(result.payload)
async def _throttled_private_call(self, endpoint: str) -> dict[str, Any]:
async def _throttled_private_call(
self,
endpoint: str,
data: dict[str, str] | None = None,
) -> dict[str, Any]:
async with self._private_lock:
result = await self._private_post_with_retry(endpoint)
result = await self._private_post_with_retry(endpoint, data=data)
await asyncio.sleep(self._settings.kraken_private_rate_limit_seconds)
return _result_dict(result.payload)
async def balances(self) -> dict[str, Any]:
return await self._throttled_private_call("/0/private/Balance")
async def place_market_order(
self,
*,
pair: str,
side: str,
volume: float,
user_ref: int | None = None,
) -> dict[str, Any]:
normalized_side = side.lower()
if normalized_side not in {"buy", "sell"}:
raise ValueError("side must be 'buy' or 'sell'")
if volume <= 0.0:
raise ValueError("volume must be > 0.0")
if user_ref is not None and user_ref < 0:
raise ValueError("user_ref must be >= 0")
data = {
"pair": pair,
"type": normalized_side,
"ordertype": "market",
"volume": str(volume),
}
if user_ref is not None:
data["userref"] = str(user_ref)
return await self._throttled_private_call(
"/0/private/AddOrder",
data=data,
)
async def place_limit_order(
self,
*,
pair: str,
side: str,
volume: float,
price: float,
user_ref: int | None = None,
) -> dict[str, Any]:
normalized_side = side.lower()
if normalized_side not in {"buy", "sell"}:
raise ValueError("side must be 'buy' or 'sell'")
if volume <= 0.0:
raise ValueError("volume must be > 0.0")
if price <= 0.0:
raise ValueError("price must be > 0.0")
if user_ref is not None and user_ref < 0:
raise ValueError("user_ref must be >= 0")
data = {
"pair": pair,
"type": normalized_side,
"ordertype": "limit",
"price": str(price),
"volume": str(volume),
}
if user_ref is not None:
data["userref"] = str(user_ref)
return await self._throttled_private_call(
"/0/private/AddOrder",
data=data,
)
async def query_order(
self,
*,
order_id: str,
include_trades: bool = True,
) -> dict[str, Any]:
if not order_id.strip():
raise ValueError("order_id must be non-empty")
return await self._throttled_private_call(
"/0/private/QueryOrders",
data={
"txid": order_id,
"trades": "true" if include_trades else "false",
},
)
async def cancel_order(self, *, order_id: str) -> dict[str, Any]:
if not order_id.strip():
raise ValueError("order_id must be non-empty")
return await self._throttled_private_call(
"/0/private/CancelOrder",
data={"txid": order_id},
)
+32
View File
@@ -0,0 +1,32 @@
"""Trade execution helpers."""
from arbitrade.execution.fill_monitor import (
FillMonitor,
FillMonitorResult,
OrderFillState,
)
from arbitrade.execution.idempotency import (
IdempotencyKeyFactory,
OrderReconciler,
ReconciliationReport,
)
from arbitrade.execution.recovery import PartialFillRecovery, RecoveryAction
from arbitrade.execution.sequencer import (
ExecutionLeg,
TriangularExecutionResult,
TriangularExecutionSequencer,
)
__all__ = [
"ExecutionLeg",
"OrderFillState",
"FillMonitorResult",
"FillMonitor",
"IdempotencyKeyFactory",
"ReconciliationReport",
"OrderReconciler",
"RecoveryAction",
"PartialFillRecovery",
"TriangularExecutionResult",
"TriangularExecutionSequencer",
]
+133
View File
@@ -0,0 +1,133 @@
from __future__ import annotations
import asyncio
import time
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Protocol
class SupportsOrderStatusPolling(Protocol):
async def query_order(
self, *, order_id: str, include_trades: bool = True
) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class OrderFillState:
order_id: str
status: str
filled_volume: float | None
avg_price: float | None
updated_at: datetime
source: str
@property
def is_terminal(self) -> bool:
return self.status in {"closed", "canceled", "expired"}
@dataclass(frozen=True, slots=True)
class FillMonitorResult:
order_id: str
timed_out: bool
terminal_state: OrderFillState | None
last_state: OrderFillState | None
elapsed_seconds: float
class FillMonitor:
def __init__(
self,
poll_client: SupportsOrderStatusPolling,
*,
poll_interval_seconds: float = 0.5,
max_wait_seconds: float = 10.0,
ws_status_provider: Callable[[str], OrderFillState | None] | None = None,
) -> None:
if poll_interval_seconds <= 0.0:
raise ValueError("poll_interval_seconds must be > 0.0")
if max_wait_seconds <= 0.0:
raise ValueError("max_wait_seconds must be > 0.0")
self._poll_client = poll_client
self._poll_interval_seconds = poll_interval_seconds
self._max_wait_seconds = max_wait_seconds
self._ws_status_provider = ws_status_provider
@staticmethod
def _to_float(value: Any) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
@classmethod
def _state_from_payload(
cls, order_id: str, payload: dict[str, Any], *, source: str
) -> OrderFillState:
status = str(payload.get("status", "unknown")).lower()
return OrderFillState(
order_id=order_id,
status=status,
filled_volume=cls._to_float(payload.get("vol_exec")),
avg_price=cls._to_float(payload.get("price") or payload.get("avg_price")),
updated_at=datetime.now(UTC),
source=source,
)
@classmethod
def _extract_order_payload(cls, order_id: str, response: dict[str, Any]) -> dict[str, Any]:
if order_id in response and isinstance(response[order_id], dict):
payload = response[order_id]
return {str(key): value for key, value in payload.items()}
return response
async def wait_for_terminal_fill(self, order_id: str) -> FillMonitorResult:
if not order_id.strip():
raise ValueError("order_id must be non-empty")
started = time.monotonic()
last_state: OrderFillState | None = None
while True:
elapsed = time.monotonic() - started
if elapsed >= self._max_wait_seconds:
return FillMonitorResult(
order_id=order_id,
timed_out=True,
terminal_state=None,
last_state=last_state,
elapsed_seconds=elapsed,
)
if self._ws_status_provider is not None:
ws_state = self._ws_status_provider(order_id)
if ws_state is not None:
last_state = ws_state
if ws_state.is_terminal:
return FillMonitorResult(
order_id=order_id,
timed_out=False,
terminal_state=ws_state,
last_state=ws_state,
elapsed_seconds=elapsed,
)
response = await self._poll_client.query_order(order_id=order_id, include_trades=True)
payload = self._extract_order_payload(order_id, response)
polled_state = self._state_from_payload(order_id, payload, source="rest_poll")
last_state = polled_state
if polled_state.is_terminal:
return FillMonitorResult(
order_id=order_id,
timed_out=False,
terminal_state=polled_state,
last_state=polled_state,
elapsed_seconds=time.monotonic() - started,
)
await asyncio.sleep(self._poll_interval_seconds)
+105
View File
@@ -0,0 +1,105 @@
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Any, Protocol
from arbitrade.detection.engine import OpportunityEvent
from arbitrade.execution.sequencer import ExecutionLeg
class SupportsOrderHistoryLookup(Protocol):
async def query_order(
self, *, order_id: str, include_trades: bool = True
) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class ReconciliationReport:
order_id: str
user_ref: int
status: str
filled_volume: float | None
avg_price: float | None
is_terminal: bool
matches_request: bool
raw_payload: dict[str, Any]
class IdempotencyKeyFactory:
def user_ref_for_leg(self, event: OpportunityEvent, leg: ExecutionLeg, leg_index: int) -> int:
material = "|".join(
[
event.cycle,
event.updated_pair,
leg.from_currency,
leg.to_currency,
leg.pair,
leg.side,
f"{leg.volume:.12f}",
str(leg_index),
]
).encode("utf-8")
digest = hashlib.sha256(material).digest()
value = int.from_bytes(digest[:8], "big") % 2_147_483_647
return value or 1
class OrderReconciler:
def __init__(self, history_client: SupportsOrderHistoryLookup) -> None:
self._history_client = history_client
@staticmethod
def _to_float(value: Any) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
@staticmethod
def _extract_payload(order_id: str, response: dict[str, Any]) -> dict[str, Any]:
if order_id in response and isinstance(response[order_id], dict):
payload = response[order_id]
return {str(key): value for key, value in payload.items()}
return response
async def reconcile_order(
self,
*,
order_id: str,
user_ref: int,
expected_pair: str,
expected_side: str,
expected_volume: float,
) -> ReconciliationReport:
if not order_id.strip():
raise ValueError("order_id must be non-empty")
response = await self._history_client.query_order(order_id=order_id, include_trades=True)
payload = self._extract_payload(order_id, response)
status = str(payload.get("status", "unknown")).lower()
filled_volume = self._to_float(payload.get("vol_exec"))
avg_price = self._to_float(payload.get("price") or payload.get("avg_price"))
reported_pair = str(payload.get("pair", expected_pair))
reported_side = str(payload.get("type", expected_side)).lower()
matches_request = (
reported_pair == expected_pair
and reported_side == expected_side.lower()
and (
expected_volume <= 0.0 or filled_volume is None or filled_volume <= expected_volume
)
and payload.get("userref") in {None, str(user_ref), user_ref}
)
return ReconciliationReport(
order_id=order_id,
user_ref=user_ref,
status=status,
filled_volume=filled_volume,
avg_price=avg_price,
is_terminal=status in {"closed", "canceled", "expired"},
matches_request=matches_request,
raw_payload=payload,
)
+98
View File
@@ -0,0 +1,98 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Protocol
from arbitrade.execution.fill_monitor import FillMonitorResult, OrderFillState
class SupportsOrderLifecycle(Protocol):
async def cancel_order(self, *, order_id: str) -> dict[str, Any]: ...
async def place_market_order(
self, *, pair: str, side: str, volume: float
) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class RecoveryAction:
order_id: str
canceled: bool
hedged: bool
hedge_pair: str | None = None
hedge_side: str | None = None
hedge_volume: float | None = None
cancel_response: dict[str, Any] | None = None
hedge_response: dict[str, Any] | None = None
reason: str | None = None
class PartialFillRecovery:
def __init__(self, rest_client: SupportsOrderLifecycle) -> None:
self._rest_client = rest_client
@staticmethod
def _counter_side(side: str) -> str:
normalized = side.lower()
if normalized == "buy":
return "sell"
if normalized == "sell":
return "buy"
raise ValueError("side must be 'buy' or 'sell'")
@staticmethod
def _residual_volume(terminal_state: OrderFillState | None, requested_volume: float) -> float:
if requested_volume <= 0.0:
raise ValueError("requested_volume must be > 0.0")
if terminal_state is None or terminal_state.filled_volume is None:
return requested_volume
residual = requested_volume - terminal_state.filled_volume
return residual if residual > 0.0 else 0.0
async def recover_partial_fill(
self,
*,
order_id: str,
pair: str,
side: str,
requested_volume: float,
fill_result: FillMonitorResult,
) -> RecoveryAction:
if not order_id.strip():
raise ValueError("order_id must be non-empty")
cancel_response: dict[str, Any] | None = None
hedge_response: dict[str, Any] | None = None
hedged = False
canceled = False
reason = None
state = fill_result.terminal_state or fill_result.last_state
residual_volume = self._residual_volume(state, requested_volume)
if state is not None and state.status in {"open", "partial"}:
cancel_response = await self._rest_client.cancel_order(order_id=order_id)
canceled = True
reason = f"canceled_{state.status}_order"
if residual_volume > 0.0 and fill_result.timed_out:
hedge_response = await self._rest_client.place_market_order(
pair=pair,
side=self._counter_side(side),
volume=residual_volume,
)
hedged = True
if reason is None:
reason = "hedged_timed_out_order"
return RecoveryAction(
order_id=order_id,
canceled=canceled,
hedged=hedged,
hedge_pair=pair if hedged else None,
hedge_side=self._counter_side(side) if hedged else None,
hedge_volume=residual_volume if hedged else None,
cancel_response=cancel_response,
hedge_response=hedge_response,
reason=reason,
)
+228
View File
@@ -0,0 +1,228 @@
from __future__ import annotations
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Protocol
from arbitrade.detection.engine import OpportunityEvent
from arbitrade.storage.executions import AsyncExecutionWriter
from arbitrade.storage.repositories import OrderRecord, PnLRecord, TradeRecord
class SupportsOrderPlacement(Protocol):
async def place_market_order(
self, *, pair: str, side: str, volume: float
) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class ExecutionLeg:
from_currency: str
to_currency: str
pair: str
side: str
volume: float
@dataclass(frozen=True, slots=True)
class TriangularExecutionResult:
success: bool
requested_legs: tuple[ExecutionLeg, ...]
completed_legs: int
responses: tuple[dict[str, Any], ...]
failure_reason: str | None = None
class TriangularExecutionSequencer:
def __init__(
self,
rest_client: SupportsOrderPlacement,
*,
available_pairs: Sequence[str],
volume_for_leg: Callable[[OpportunityEvent,
ExecutionLeg, int], float] | None = None,
execution_writer: AsyncExecutionWriter | None = None,
) -> None:
self._rest_client = rest_client
self._available_pairs = {self._normalize_pair(
pair) for pair in available_pairs}
self._volume_for_leg = volume_for_leg or self._default_volume_for_leg
self._execution_writer = execution_writer
@staticmethod
def _normalize_pair(pair: str) -> str:
normalized = pair.strip().upper().replace("-", "/")
if "/" not in normalized:
return normalized
base, quote = normalized.split("/", 1)
return f"{base}/{quote}"
@staticmethod
def _default_volume_for_leg(event: OpportunityEvent, _leg: ExecutionLeg, _idx: int) -> float:
if event.allocated_capital <= 0.0:
raise ValueError("allocated_capital must be > 0.0")
return event.allocated_capital
def _resolve_leg(self, from_currency: str, to_currency: str, volume: float) -> ExecutionLeg:
from_cur = from_currency.upper()
to_cur = to_currency.upper()
buy_pair = f"{to_cur}/{from_cur}"
if buy_pair in self._available_pairs:
return ExecutionLeg(
from_currency=from_cur,
to_currency=to_cur,
pair=buy_pair,
side="buy",
volume=volume,
)
sell_pair = f"{from_cur}/{to_cur}"
if sell_pair in self._available_pairs:
return ExecutionLeg(
from_currency=from_cur,
to_currency=to_cur,
pair=sell_pair,
side="sell",
volume=volume,
)
raise ValueError(f"No tradable pair for leg {from_cur}->{to_cur}")
def _build_legs(self, event: OpportunityEvent) -> tuple[ExecutionLeg, ...]:
currencies = [part.strip().upper()
for part in event.cycle.split("->") if part.strip()]
if len(currencies) < 4 or currencies[0] != currencies[-1]:
raise ValueError(
"cycle must be a closed triangular path like A->B->C->A")
if len(currencies) != 4:
raise ValueError(
"cycle must contain exactly three unique currencies")
legs: list[ExecutionLeg] = []
for idx in range(3):
from_currency = currencies[idx]
to_currency = currencies[idx + 1]
placeholder_leg = ExecutionLeg(
from_currency=from_currency,
to_currency=to_currency,
pair="",
side="buy",
volume=0.0,
)
volume = self._volume_for_leg(event, placeholder_leg, idx)
if volume <= 0.0:
raise ValueError(
"volume_for_leg must return a positive volume")
legs.append(self._resolve_leg(from_currency, to_currency, volume))
return tuple(legs)
@staticmethod
def _trade_ref_for_event(event: OpportunityEvent) -> str:
material = (
f"{event.cycle}|{event.updated_pair}|"
f"{event.detected_at.timestamp():.6f}|"
f"{event.allocated_capital:.12f}"
)
return material.encode("utf-8").hex()[:32]
@staticmethod
def _order_ref_from_response(response: dict[str, Any], default: str) -> str:
txid = response.get("txid")
if isinstance(txid, list) and txid:
return str(txid[0])
if isinstance(txid, str) and txid.strip():
return txid
return default
async def execute(self, event: OpportunityEvent) -> TriangularExecutionResult:
legs = self._build_legs(event)
responses: list[dict[str, Any]] = []
trade_ref = self._trade_ref_for_event(event)
started_at = datetime.now(UTC)
for idx, leg in enumerate(legs):
try:
response = await self._rest_client.place_market_order(
pair=leg.pair,
side=leg.side,
volume=leg.volume,
)
except Exception as exc:
if self._execution_writer is not None:
await self._execution_writer.enqueue(
TradeRecord(
trade_ref=trade_ref,
started_at=started_at,
finished_at=datetime.now(UTC),
status="failed",
realized_pnl=None,
estimated_pnl=event.est_profit,
capital_used=event.allocated_capital,
cycle=event.cycle,
leg_count=len(legs),
)
)
return TriangularExecutionResult(
success=False,
requested_legs=legs,
completed_legs=idx,
responses=tuple(responses),
failure_reason=str(exc),
)
responses.append(response)
if self._execution_writer is not None:
order_ref = self._order_ref_from_response(
response, f"leg-{idx}")
await self._execution_writer.enqueue(
OrderRecord(
trade_ref=trade_ref,
order_ref=order_ref,
leg_index=idx,
pair=leg.pair,
side=leg.side,
volume=leg.volume,
user_ref=None,
status=str(response.get("status", "submitted")),
filled_volume=None,
avg_price=None,
raw_response=response,
recorded_at=datetime.now(UTC),
)
)
if self._execution_writer is not None:
await self._execution_writer.enqueue(
TradeRecord(
trade_ref=trade_ref,
started_at=started_at,
finished_at=datetime.now(UTC),
status="filled",
realized_pnl=None,
estimated_pnl=event.est_profit,
capital_used=event.allocated_capital,
cycle=event.cycle,
leg_count=len(legs),
)
)
await self._execution_writer.enqueue(
PnLRecord(
trade_ref=trade_ref,
recorded_at=datetime.now(UTC),
kind="estimated",
pnl_usd=event.est_profit,
source="triangular_sequencer",
)
)
return TriangularExecutionResult(
success=True,
requested_legs=legs,
completed_legs=len(legs),
responses=tuple(responses),
)
+1
View File
@@ -0,0 +1 @@
"""Storage helpers."""
+43 -2
View File
@@ -5,9 +5,12 @@ from contextlib import contextmanager
from pathlib import Path
import duckdb
import structlog
from arbitrade.config.settings import Settings
_LOG = structlog.get_logger(__name__)
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
@@ -26,11 +29,40 @@ CREATE TABLE IF NOT EXISTS opportunities (
CREATE TABLE IF NOT EXISTS trades (
id UUID DEFAULT uuid(),
trade_ref VARCHAR NOT NULL,
started_at TIMESTAMP NOT NULL,
finished_at TIMESTAMP,
status VARCHAR NOT NULL,
realized_pnl DOUBLE,
capital_used DOUBLE
estimated_pnl DOUBLE,
capital_used DOUBLE,
cycle VARCHAR,
leg_count INTEGER
);
CREATE TABLE IF NOT EXISTS orders (
id UUID DEFAULT uuid(),
trade_ref VARCHAR NOT NULL,
order_ref VARCHAR NOT NULL,
leg_index INTEGER NOT NULL,
pair VARCHAR NOT NULL,
side VARCHAR NOT NULL,
volume DOUBLE NOT NULL,
user_ref INTEGER,
status VARCHAR,
filled_volume DOUBLE,
avg_price DOUBLE,
raw_response JSON,
recorded_at TIMESTAMP NOT NULL
);
CREATE TABLE IF NOT EXISTS pnl_events (
id UUID DEFAULT uuid(),
trade_ref VARCHAR NOT NULL,
recorded_at TIMESTAMP NOT NULL,
kind VARCHAR NOT NULL,
pnl_usd DOUBLE NOT NULL,
source VARCHAR NOT NULL
);
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
@@ -53,10 +85,19 @@ class DuckDBStore:
def __init__(self, settings: Settings) -> None:
self._db_path = Path(settings.duckdb_path)
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._use_memory_fallback = False
@contextmanager
def connect(self) -> Iterator[duckdb.DuckDBPyConnection]:
conn = duckdb.connect(str(self._db_path))
try:
conn = duckdb.connect(str(self._db_path))
except duckdb.IOException:
if not self._use_memory_fallback:
_LOG.warning(
"duckdb_path_unavailable_falling_back_to_memory", path=str(self._db_path)
)
self._use_memory_fallback = True
conn = duckdb.connect(":memory:")
try:
yield conn
finally:
+66
View File
@@ -0,0 +1,66 @@
from __future__ import annotations
import asyncio
import structlog
from arbitrade.storage.repositories import (
OrderRecord,
OrderRepository,
PnLRecord,
PnLRepository,
TradeRecord,
TradeRepository,
)
_LOG = structlog.get_logger(__name__)
class AsyncExecutionWriter:
def __init__(
self,
trade_repository: TradeRepository,
order_repository: OrderRepository,
pnl_repository: PnLRepository,
max_queue_size: int = 50_000,
) -> None:
self._trade_repository = trade_repository
self._order_repository = order_repository
self._pnl_repository = pnl_repository
self._queue: asyncio.Queue[TradeRecord | OrderRecord | PnLRecord] = asyncio.Queue(
maxsize=max_queue_size
)
self._task: asyncio.Task[None] | None = None
self._stop = asyncio.Event()
async def start(self) -> None:
if self._task is None or self._task.done():
self._stop.clear()
self._task = asyncio.create_task(self._run(), name="execution-writer")
async def stop(self) -> None:
self._stop.set()
if self._task is not None:
await self._task
async def enqueue(self, record: TradeRecord | OrderRecord | PnLRecord) -> None:
await self._queue.put(record)
async def _run(self) -> None:
while not (self._stop.is_set() and self._queue.empty()):
try:
record = await asyncio.wait_for(self._queue.get(), timeout=0.5)
except TimeoutError:
continue
try:
if isinstance(record, TradeRecord):
self._trade_repository.insert(record)
elif isinstance(record, OrderRecord):
self._order_repository.insert(record)
else:
self._pnl_repository.insert(record)
except Exception as exc:
_LOG.error("execution_write_failed", error=str(exc))
finally:
self._queue.task_done()
+141
View File
@@ -28,6 +28,44 @@ class OpportunityRecord:
executed: bool = False
@dataclass(slots=True)
class TradeRecord:
trade_ref: str
started_at: datetime
finished_at: datetime | None
status: str
realized_pnl: float | None
estimated_pnl: float | None
capital_used: float | None
cycle: str | None = None
leg_count: int | None = None
@dataclass(slots=True)
class OrderRecord:
trade_ref: str
order_ref: str
leg_index: int
pair: str
side: str
volume: float
user_ref: int | None
status: str | None
filled_volume: float | None
avg_price: float | None
raw_response: dict[str, Any]
recorded_at: datetime
@dataclass(slots=True)
class PnLRecord:
trade_ref: str
recorded_at: datetime
kind: str
pnl_usd: float
source: str
class MarketSnapshotRepository:
def __init__(self, store: DuckDBStore) -> None:
self._store = store
@@ -76,3 +114,106 @@ class OpportunityRepository:
record.executed,
],
)
class TradeRepository:
def __init__(self, store: DuckDBStore) -> None:
self._store = store
def insert(self, record: TradeRecord) -> None:
with self._store.connect() as conn:
conn.execute(
"""
INSERT INTO trades (
trade_ref,
started_at,
finished_at,
status,
realized_pnl,
estimated_pnl,
capital_used,
cycle,
leg_count
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
[
record.trade_ref,
record.started_at,
record.finished_at,
record.status,
record.realized_pnl,
record.estimated_pnl,
record.capital_used,
record.cycle,
record.leg_count,
],
)
class OrderRepository:
def __init__(self, store: DuckDBStore) -> None:
self._store = store
def insert(self, record: OrderRecord) -> None:
with self._store.connect() as conn:
conn.execute(
"""
INSERT INTO orders (
trade_ref,
order_ref,
leg_index,
pair,
side,
volume,
user_ref,
status,
filled_volume,
avg_price,
raw_response,
recorded_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
[
record.trade_ref,
record.order_ref,
record.leg_index,
record.pair,
record.side,
record.volume,
record.user_ref,
record.status,
record.filled_volume,
record.avg_price,
orjson.dumps(record.raw_response).decode("utf-8"),
record.recorded_at,
],
)
class PnLRepository:
def __init__(self, store: DuckDBStore) -> None:
self._store = store
def insert(self, record: PnLRecord) -> None:
with self._store.connect() as conn:
conn.execute(
"""
INSERT INTO pnl_events (
trade_ref,
recorded_at,
kind,
pnl_usd,
source
)
VALUES (?, ?, ?, ?, ?)
""",
[
record.trade_ref,
record.recorded_at,
record.kind,
record.pnl_usd,
record.source,
],
)
+90
View File
@@ -0,0 +1,90 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
import pytest
from arbitrade.config.settings import Settings
from arbitrade.detection.engine import OpportunityEvent
from arbitrade.execution.sequencer import TriangularExecutionSequencer
from arbitrade.storage.db import DuckDBStore
from arbitrade.storage.executions import AsyncExecutionWriter
from arbitrade.storage.repositories import OrderRepository, PnLRepository, TradeRepository
@dataclass(slots=True)
class _FakeRestClient:
calls: int = 0
async def place_market_order(self, *, pair: str, side: str, volume: float) -> dict[str, object]:
self.calls += 1
return {"txid": [f"tx-{self.calls}"], "status": "submitted"}
def _sample_event() -> OpportunityEvent:
return OpportunityEvent(
detected_at=datetime.now(UTC),
cycle="USD->BTC->ETH->USD",
updated_pair="BTC/USD",
gross_rate=1.04,
net_rate=1.03,
gross_pct=4.0,
net_pct=3.0,
est_profit=0.03,
)
@pytest.mark.asyncio
async def test_execution_writer_persists_trade_order_and_pnl(tmp_path) -> None:
settings = Settings(_env_file=None, DUCKDB_PATH=tmp_path / "exec.duckdb")
store = DuckDBStore(settings)
store.migrate()
writer = AsyncExecutionWriter(
TradeRepository(store),
OrderRepository(store),
PnLRepository(store),
max_queue_size=10,
)
await writer.start()
client = _FakeRestClient()
sequencer = TriangularExecutionSequencer(
client,
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
execution_writer=writer,
)
result = await sequencer.execute(_sample_event())
await writer.stop()
assert result.success
assert client.calls == 3
with store.connect() as conn:
trades = conn.execute(
"SELECT trade_ref, status, estimated_pnl, capital_used, cycle, leg_count FROM trades"
).fetchall()
orders = conn.execute(
"SELECT trade_ref, order_ref, leg_index, pair, side, volume, status "
"FROM orders ORDER BY leg_index"
).fetchall()
pnls = conn.execute(
"SELECT trade_ref, kind, pnl_usd, source FROM pnl_events").fetchall()
assert len(trades) == 1
assert trades[0][1] == "filled"
assert trades[0][2] == 0.03
assert trades[0][3] == 1.0
assert trades[0][4] == "USD->BTC->ETH->USD"
assert trades[0][5] == 3
assert len(orders) == 3
assert orders[0][2] == 0
assert orders[1][2] == 1
assert orders[2][2] == 2
assert orders[0][6] == "submitted"
assert len(pnls) == 1
assert pnls[0][1] == "estimated"
assert pnls[0][2] == 0.03
+93
View File
@@ -0,0 +1,93 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
import pytest
from arbitrade.detection.engine import OpportunityEvent
from arbitrade.execution.sequencer import TriangularExecutionSequencer
@dataclass(slots=True)
class _FakeRestClient:
fail_at_call: int | None = None
calls: list[dict[str, Any]] = field(default_factory=list)
async def place_market_order(self, *, pair: str, side: str, volume: float) -> dict[str, Any]:
call_number = len(self.calls) + 1
if self.fail_at_call is not None and call_number == self.fail_at_call:
raise RuntimeError("simulated failure")
payload = {"pair": pair, "side": side, "volume": volume}
self.calls.append(payload)
return {"txid": [f"tx-{call_number}"]}
def _sample_event(cycle: str = "USD->BTC->ETH->USD") -> OpportunityEvent:
return OpportunityEvent(
detected_at=datetime.now(UTC),
cycle=cycle,
updated_pair="BTC/USD",
gross_rate=1.02,
net_rate=1.01,
gross_pct=2.0,
net_pct=1.0,
est_profit=1.0,
allocated_capital=10.0,
)
@pytest.mark.asyncio
async def test_triangular_sequencer_executes_legs_in_order() -> None:
client = _FakeRestClient()
sequencer = TriangularExecutionSequencer(
client,
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
)
result = await sequencer.execute(_sample_event())
assert result.success
assert result.completed_legs == 3
assert [call["pair"] for call in client.calls] == ["BTC/USD", "ETH/BTC", "ETH/USD"]
assert [call["side"] for call in client.calls] == ["buy", "buy", "sell"]
@pytest.mark.asyncio
async def test_triangular_sequencer_stops_on_failed_leg() -> None:
client = _FakeRestClient(fail_at_call=2)
sequencer = TriangularExecutionSequencer(
client,
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
)
result = await sequencer.execute(_sample_event())
assert not result.success
assert result.completed_legs == 1
assert result.failure_reason is not None
assert len(client.calls) == 1
def test_triangular_sequencer_rejects_non_closed_cycle() -> None:
client = _FakeRestClient()
sequencer = TriangularExecutionSequencer(
client,
available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"],
)
with pytest.raises(ValueError, match="closed triangular path"):
sequencer._build_legs(_sample_event(cycle="USD->BTC->ETH"))
def test_triangular_sequencer_rejects_missing_pair() -> None:
client = _FakeRestClient()
sequencer = TriangularExecutionSequencer(
client,
available_pairs=["BTC/USD", "ETH/BTC"],
)
with pytest.raises(ValueError, match="No tradable pair"):
sequencer._build_legs(_sample_event())
+112
View File
@@ -0,0 +1,112 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
import pytest
from arbitrade.execution.fill_monitor import FillMonitor, OrderFillState
@dataclass(slots=True)
class _FakePollClient:
responses: list[dict[str, Any]]
calls: int = 0
async def query_order(self, *, order_id: str, include_trades: bool = True) -> dict[str, Any]:
self.calls += 1
if self.responses:
return self.responses.pop(0)
return {order_id: {"status": "open", "vol_exec": "0.0", "price": "0.0"}}
@dataclass(slots=True)
class _FakeWsProvider:
states: list[OrderFillState] = field(default_factory=list)
def get(self, _order_id: str) -> OrderFillState | None:
if not self.states:
return None
return self.states.pop(0)
@pytest.mark.asyncio
async def test_fill_monitor_detects_terminal_state_via_polling() -> None:
order_id = "order-1"
client = _FakePollClient(
responses=[
{order_id: {"status": "open", "vol_exec": "0.0", "price": "0.0"}},
{order_id: {"status": "closed", "vol_exec": "1.0", "price": "100.0"}},
]
)
monitor = FillMonitor(client, poll_interval_seconds=0.001, max_wait_seconds=0.1)
result = await monitor.wait_for_terminal_fill(order_id)
assert not result.timed_out
assert result.terminal_state is not None
assert result.terminal_state.status == "closed"
assert result.terminal_state.filled_volume == 1.0
assert result.terminal_state.source == "rest_poll"
@pytest.mark.asyncio
async def test_fill_monitor_times_out_when_no_terminal_state() -> None:
order_id = "order-2"
client = _FakePollClient(
responses=[
{order_id: {"status": "open", "vol_exec": "0.1", "price": "100.0"}},
{order_id: {"status": "partial", "vol_exec": "0.2", "price": "100.0"}},
{order_id: {"status": "open", "vol_exec": "0.2", "price": "100.0"}},
]
)
monitor = FillMonitor(client, poll_interval_seconds=0.001, max_wait_seconds=0.01)
result = await monitor.wait_for_terminal_fill(order_id)
assert result.timed_out
assert result.terminal_state is None
assert result.last_state is not None
assert result.last_state.status in {"open", "partial"}
@pytest.mark.asyncio
async def test_fill_monitor_uses_ws_status_for_fast_terminal_detection() -> None:
order_id = "order-3"
ws_provider = _FakeWsProvider(
states=[
OrderFillState(
order_id=order_id,
status="closed",
filled_volume=0.5,
avg_price=200.0,
updated_at=datetime.now(UTC),
source="ws",
)
]
)
client = _FakePollClient(responses=[])
monitor = FillMonitor(
client,
poll_interval_seconds=0.001,
max_wait_seconds=0.1,
ws_status_provider=ws_provider.get,
)
result = await monitor.wait_for_terminal_fill(order_id)
assert not result.timed_out
assert result.terminal_state is not None
assert result.terminal_state.source == "ws"
assert client.calls == 0
def test_fill_monitor_rejects_invalid_configuration() -> None:
client = _FakePollClient(responses=[])
with pytest.raises(ValueError, match="poll_interval_seconds"):
FillMonitor(client, poll_interval_seconds=0.0)
with pytest.raises(ValueError, match="max_wait_seconds"):
FillMonitor(client, max_wait_seconds=0.0)
+109
View File
@@ -0,0 +1,109 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
import pytest
from arbitrade.detection.engine import OpportunityEvent
from arbitrade.execution.idempotency import IdempotencyKeyFactory, OrderReconciler
from arbitrade.execution.sequencer import ExecutionLeg
@dataclass(slots=True)
class _FakeHistoryClient:
response: dict[str, Any]
async def query_order(self, *, order_id: str, include_trades: bool = True) -> dict[str, Any]:
return self.response
def _sample_event() -> OpportunityEvent:
return OpportunityEvent(
detected_at=datetime.now(UTC),
cycle="USD->BTC->ETH->USD",
updated_pair="BTC/USD",
gross_rate=1.02,
net_rate=1.01,
gross_pct=2.0,
net_pct=1.0,
est_profit=1.0,
allocated_capital=10.0,
)
def test_idempotency_key_factory_is_deterministic() -> None:
factory = IdempotencyKeyFactory()
event = _sample_event()
leg = ExecutionLeg(
from_currency="USD",
to_currency="BTC",
pair="BTC/USD",
side="buy",
volume=10.0,
)
first = factory.user_ref_for_leg(event, leg, 0)
second = factory.user_ref_for_leg(event, leg, 0)
assert first == second
assert first > 0
@pytest.mark.asyncio
async def test_order_reconciler_maps_query_response_to_report() -> None:
client = _FakeHistoryClient(
response={
"order-1": {
"status": "closed",
"vol_exec": "5.0",
"price": "100.0",
"pair": "BTC/USD",
"type": "buy",
"userref": 12345,
}
}
)
reconciler = OrderReconciler(client)
report = await reconciler.reconcile_order(
order_id="order-1",
user_ref=12345,
expected_pair="BTC/USD",
expected_side="buy",
expected_volume=10.0,
)
assert report.is_terminal
assert report.matches_request
assert report.status == "closed"
assert report.filled_volume == 5.0
assert report.avg_price == 100.0
@pytest.mark.asyncio
async def test_order_reconciler_marks_mismatch() -> None:
client = _FakeHistoryClient(
response={
"order-1": {
"status": "closed",
"vol_exec": "5.0",
"price": "100.0",
"pair": "ETH/USD",
"type": "sell",
"userref": 999,
}
}
)
reconciler = OrderReconciler(client)
report = await reconciler.reconcile_order(
order_id="order-1",
user_ref=12345,
expected_pair="BTC/USD",
expected_side="buy",
expected_volume=10.0,
)
assert not report.matches_request
+136
View File
@@ -110,3 +110,139 @@ def test_compliance_detects_insecure_config() -> None:
assert any("below 1.0" in issue for issue in issues)
assert any("ATTEMPTS" in issue for issue in issues)
assert any("BASE_DELAY" in issue for issue in issues)
@pytest.mark.asyncio
async def test_place_market_order_posts_add_order_payload() -> None:
settings = Settings(
_env_file=None,
KRAKEN_API_KEY="key",
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
kraken_private_rate_limit_seconds=0.0,
)
client = KrakenRestClient(settings)
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
route = mock_router.post("/0/private/AddOrder").respond(
200,
json={"error": [], "result": {"txid": ["m1"]}},
)
payload = await client.place_market_order(
pair="XBTUSD",
side="buy",
volume=0.05,
)
await client.close()
request_body = route.calls.last.request.content.decode()
assert "pair=XBTUSD" in request_body
assert "type=buy" in request_body
assert "ordertype=market" in request_body
assert "volume=0.05" in request_body
assert payload["txid"] == ["m1"]
@pytest.mark.asyncio
async def test_place_limit_order_posts_add_order_payload() -> None:
settings = Settings(
_env_file=None,
KRAKEN_API_KEY="key",
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
kraken_private_rate_limit_seconds=0.0,
)
client = KrakenRestClient(settings)
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
route = mock_router.post("/0/private/AddOrder").respond(
200,
json={"error": [], "result": {"txid": ["l1"]}},
)
payload = await client.place_limit_order(
pair="ETHUSD",
side="sell",
volume=1.5,
price=3500.0,
)
await client.close()
request_body = route.calls.last.request.content.decode()
assert "pair=ETHUSD" in request_body
assert "type=sell" in request_body
assert "ordertype=limit" in request_body
assert "price=3500.0" in request_body
assert "volume=1.5" in request_body
assert payload["txid"] == ["l1"]
@pytest.mark.asyncio
async def test_place_order_validates_inputs() -> None:
settings = Settings(
_env_file=None,
KRAKEN_API_KEY="key",
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
kraken_private_rate_limit_seconds=0.0,
)
client = KrakenRestClient(settings)
with pytest.raises(ValueError, match="side"):
await client.place_market_order(pair="XBTUSD", side="hold", volume=0.1)
with pytest.raises(ValueError, match="volume"):
await client.place_market_order(pair="XBTUSD", side="buy", volume=0.0)
with pytest.raises(ValueError, match="price"):
await client.place_limit_order(
pair="XBTUSD",
side="buy",
volume=0.1,
price=0.0,
)
await client.close()
@pytest.mark.asyncio
async def test_query_order_posts_query_orders_payload() -> None:
settings = Settings(
_env_file=None,
KRAKEN_API_KEY="key",
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
kraken_private_rate_limit_seconds=0.0,
)
client = KrakenRestClient(settings)
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
route = mock_router.post("/0/private/QueryOrders").respond(
200,
json={"error": [], "result": {"order-1": {"status": "closed"}}},
)
payload = await client.query_order(order_id="order-1", include_trades=False)
await client.close()
request_body = route.calls.last.request.content.decode()
assert "txid=order-1" in request_body
assert "trades=false" in request_body
assert payload["order-1"]["status"] == "closed"
@pytest.mark.asyncio
async def test_cancel_order_posts_cancel_order_payload() -> None:
settings = Settings(
_env_file=None,
KRAKEN_API_KEY="key",
KRAKEN_API_SECRET="c2VjcmV0", # base64("secret")
kraken_private_rate_limit_seconds=0.0,
)
client = KrakenRestClient(settings)
with respx.mock(base_url=settings.kraken_rest_url) as mock_router:
route = mock_router.post("/0/private/CancelOrder").respond(
200,
json={"error": [], "result": {"count": 1}},
)
payload = await client.cancel_order(order_id="order-1")
await client.close()
request_body = route.calls.last.request.content.decode()
assert "txid=order-1" in request_body
assert payload["count"] == 1
+96
View File
@@ -0,0 +1,96 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
import pytest
from arbitrade.execution.fill_monitor import FillMonitorResult, OrderFillState
from arbitrade.execution.recovery import PartialFillRecovery
@dataclass(slots=True)
class _FakeRestClient:
cancel_calls: list[str] = None # type: ignore[assignment]
market_calls: list[dict[str, Any]] = None # type: ignore[assignment]
def __post_init__(self) -> None:
self.cancel_calls = []
self.market_calls = []
async def cancel_order(self, *, order_id: str) -> dict[str, Any]:
self.cancel_calls.append(order_id)
return {"result": {"count": 1}}
async def place_market_order(self, *, pair: str, side: str, volume: float) -> dict[str, Any]:
self.market_calls.append({"pair": pair, "side": side, "volume": volume})
return {"txid": ["hedge-1"]}
def _monitor_result(
*, status: str, filled_volume: float | None, timed_out: bool
) -> FillMonitorResult:
state = OrderFillState(
order_id="order-1",
status=status,
filled_volume=filled_volume,
avg_price=100.0,
updated_at=datetime.now(UTC),
source="rest_poll",
)
return FillMonitorResult(
order_id="order-1",
timed_out=timed_out,
terminal_state=None if status in {"open", "partial"} else state,
last_state=state,
elapsed_seconds=1.0,
)
@pytest.mark.asyncio
async def test_partial_fill_recovery_cancels_open_order_and_hedges_residual() -> None:
client = _FakeRestClient()
recovery = PartialFillRecovery(client)
result = await recovery.recover_partial_fill(
order_id="order-1",
pair="BTC/USD",
side="buy",
requested_volume=10.0,
fill_result=_monitor_result(status="partial", filled_volume=4.0, timed_out=True),
)
assert result.canceled
assert result.hedged
assert client.cancel_calls == ["order-1"]
assert client.market_calls == [{"pair": "BTC/USD", "side": "sell", "volume": 6.0}]
assert result.hedge_volume == 6.0
assert result.reason == "canceled_partial_order"
@pytest.mark.asyncio
async def test_partial_fill_recovery_no_hedge_when_no_residual() -> None:
client = _FakeRestClient()
recovery = PartialFillRecovery(client)
result = await recovery.recover_partial_fill(
order_id="order-1",
pair="BTC/USD",
side="sell",
requested_volume=5.0,
fill_result=_monitor_result(status="closed", filled_volume=5.0, timed_out=False),
)
assert not result.canceled
assert not result.hedged
assert client.cancel_calls == []
assert client.market_calls == []
def test_partial_fill_recovery_rejects_invalid_volume() -> None:
client = _FakeRestClient()
recovery = PartialFillRecovery(client)
with pytest.raises(ValueError, match="requested_volume"):
recovery._residual_volume(None, 0.0)