Add risk management features: implement loss limits, trade limits, and pre-trade validation; update settings and tests

This commit is contained in:
2026-06-01 11:16:37 +02:00
parent 9d8a8a8a45
commit 45e219d103
14 changed files with 718 additions and 20 deletions
+216 -4
View File
@@ -8,7 +8,10 @@ import pytest
from arbitrade.detection.engine import OpportunityEvent
from arbitrade.exchange.models import BookDelta, BookLevel
from arbitrade.market_data.feed import MarketDataFeed
from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.trade_limits import TradeLimitsGuard
@dataclass(slots=True)
@@ -41,20 +44,47 @@ class _FakeOpportunityWriter:
class _FakeDetector:
def __init__(self, event: OpportunityEvent) -> None:
self._event = event
self.last_base_capital: float | None = None
def opportunities_for_updated_pair(self, _updated_pair: str, _books: dict[str, object]):
def opportunities_for_updated_pair(
self,
_updated_pair: str,
_books: dict[str, object],
*,
base_capital: float,
):
self.last_base_capital = base_capital
return [self._event]
class _FakeExecutor:
def __init__(self) -> None:
self.calls: list[OpportunityEvent] = []
self.realized_pnls: list[float | None] = []
self.outcomes: list[ExecutionOutcome] = []
async def execute(self, event: OpportunityEvent) -> None:
async def execute(self, event: OpportunityEvent) -> ExecutionOutcome | float | None:
self.calls.append(event)
if self.outcomes:
return self.outcomes.pop(0)
if not self.realized_pnls:
return None
return self.realized_pnls.pop(0)
def _sample_event() -> OpportunityEvent:
@dataclass(slots=True)
class _FakeWsClientTwoMessages:
delta: BookDelta
async def connect_stream(self):
yield SimpleNamespace(payload={"channel": "book", "seq": 1})
yield SimpleNamespace(payload={"channel": "book", "seq": 2})
def parse_book_delta(self, _payload: dict[str, object]) -> BookDelta:
return self.delta
def _sample_event(*, allocated_capital: float = 1.0) -> OpportunityEvent:
return OpportunityEvent(
detected_at=datetime.now(UTC),
cycle="USD->BTC->ETH->USD",
@@ -64,6 +94,7 @@ def _sample_event() -> OpportunityEvent:
gross_pct=4.0,
net_pct=3.0,
est_profit=0.03,
allocated_capital=allocated_capital,
)
@@ -110,3 +141,184 @@ async def test_market_data_feed_live_mode_executes_orders() -> None:
assert len(executor.calls) == 1
assert executor.calls[0].cycle == "USD->BTC->ETH->USD"
@pytest.mark.asyncio
async def test_market_data_feed_enforces_per_trade_capital_limit() -> None:
event = _sample_event()
detector = _FakeDetector(event)
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=True,
trade_capital=250.0,
max_trade_capital=100.0,
)
await feed.run()
assert detector.last_base_capital == 100.0
@pytest.mark.asyncio
async def test_market_data_feed_auto_halts_on_daily_loss_limit() -> None:
event = _sample_event()
detector = _FakeDetector(event)
executor = _FakeExecutor()
executor.realized_pnls = [-60.0, -10.0]
loss_guard = LossLimitGuard(daily_loss_limit=50.0)
feed = MarketDataFeed(
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
loss_limit_guard=loss_guard,
)
await feed.run()
assert len(executor.calls) == 1
assert loss_guard.is_halted
assert loss_guard.halted_reason == "daily_loss_limit_breached"
@pytest.mark.asyncio
async def test_market_data_feed_auto_halts_on_cumulative_loss_limit() -> None:
event = _sample_event()
detector = _FakeDetector(event)
executor = _FakeExecutor()
executor.realized_pnls = [-40.0, -15.0]
loss_guard = LossLimitGuard(cumulative_loss_limit=50.0)
feed = MarketDataFeed(
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
loss_limit_guard=loss_guard,
)
await feed.run()
assert len(executor.calls) == 2
assert loss_guard.is_halted
assert loss_guard.halted_reason == "cumulative_loss_limit_breached"
@pytest.mark.asyncio
async def test_market_data_feed_enforces_max_concurrent_trades() -> None:
event = _sample_event()
detector = _FakeDetector(event)
executor = _FakeExecutor()
executor.outcomes = [ExecutionOutcome(
realized_pnl=None, close_trade=False)]
trade_guard = TradeLimitsGuard(max_concurrent_trades=1)
feed = MarketDataFeed(
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
trade_limits_guard=trade_guard,
)
await feed.run()
assert len(executor.calls) == 1
assert trade_guard.active_trades == 1
@pytest.mark.asyncio
async def test_market_data_feed_enforces_per_asset_exposure_cap() -> None:
event = _sample_event(allocated_capital=100.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
trade_guard = TradeLimitsGuard(max_exposure_per_asset=50.0)
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
trade_limits_guard=trade_guard,
)
await feed.run()
assert len(executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_blocks_when_pre_trade_balance_insufficient() -> None:
event = _sample_event(allocated_capital=100.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
pre_trade_validator=validator,
balance_provider=lambda: {"USD": 25.0},
quote_balance_asset="USD",
)
await feed.run()
assert len(executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_blocks_when_pre_trade_min_order_not_met() -> None:
event = _sample_event(allocated_capital=25.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
pre_trade_validator=validator,
balance_provider=lambda: {"USD": 500.0},
quote_balance_asset="USD",
)
await feed.run()
assert len(executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> None:
event = _sample_event(allocated_capital=75.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
pre_trade_validator=validator,
balance_provider=lambda: {"USD": 500.0},
quote_balance_asset="USD",
)
await feed.run()
assert len(executor.calls) == 1