Add risk management features: implement loss limits, trade limits, and pre-trade validation; update settings and tests
This commit is contained in:
@@ -19,10 +19,8 @@ def _make_book_levels(
|
||||
*, bids: list[tuple[float, float]], asks: list[tuple[float, float]]
|
||||
) -> OrderBook:
|
||||
book = OrderBook()
|
||||
book.apply_bids([BookLevel(price=price, volume=volume)
|
||||
for price, volume in bids])
|
||||
book.apply_asks([BookLevel(price=price, volume=volume)
|
||||
for price, volume in asks])
|
||||
book.apply_bids([BookLevel(price=price, volume=volume) for price, volume in bids])
|
||||
book.apply_asks([BookLevel(price=price, volume=volume) for price, volume in asks])
|
||||
return book
|
||||
|
||||
|
||||
|
||||
@@ -278,3 +278,29 @@ def test_incremental_detector_emits_structured_opportunity_event() -> None:
|
||||
assert event.gross_pct == pytest.approx(4.0)
|
||||
assert event.net_pct == pytest.approx(4.0)
|
||||
assert event.est_profit == pytest.approx(20.0)
|
||||
|
||||
|
||||
def test_incremental_detector_estimated_profit_scales_with_capital() -> None:
|
||||
cycle = TriangularCycle(
|
||||
currencies=("USD", "BTC", "ETH"),
|
||||
pairs=("BTC/USD", "ETH/BTC", "ETH/USD"),
|
||||
)
|
||||
detector = IncrementalCycleDetector(
|
||||
CurrencyGraph.index_cycles_by_pair([cycle]),
|
||||
min_profit_threshold=0.0,
|
||||
)
|
||||
|
||||
books = {
|
||||
"BTC/USD": _make_book(bid=99.9, ask=100.0),
|
||||
"ETH/BTC": _make_book(bid=0.049, ask=0.05),
|
||||
"ETH/USD": _make_book(bid=5.20, ask=5.21),
|
||||
}
|
||||
|
||||
opportunities = detector.opportunities_for_updated_pair(
|
||||
"ETH/BTC",
|
||||
books,
|
||||
base_capital=250.0,
|
||||
)
|
||||
|
||||
assert len(opportunities) == 1
|
||||
assert opportunities[0].est_profit == pytest.approx(10.0)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.risk.loss_limits import LossLimitGuard
|
||||
|
||||
|
||||
def test_loss_limit_guard_tracks_daily_and_cumulative_pnl() -> None:
|
||||
guard = LossLimitGuard(daily_loss_limit=100.0, cumulative_loss_limit=200.0)
|
||||
t0 = datetime.now(UTC)
|
||||
|
||||
guard.register_realized_pnl(-40.0, at=t0)
|
||||
guard.register_realized_pnl(10.0, at=t0)
|
||||
|
||||
assert guard.cumulative_pnl == -30.0
|
||||
assert guard.daily_pnl(t0.date()) == -30.0
|
||||
assert not guard.is_halted
|
||||
|
||||
|
||||
def test_loss_limit_guard_halts_on_daily_limit() -> None:
|
||||
guard = LossLimitGuard(daily_loss_limit=50.0)
|
||||
t0 = datetime.now(UTC)
|
||||
|
||||
guard.register_realized_pnl(-30.0, at=t0)
|
||||
guard.register_realized_pnl(-25.0, at=t0)
|
||||
|
||||
assert guard.is_halted
|
||||
assert guard.halted_reason == "daily_loss_limit_breached"
|
||||
|
||||
|
||||
def test_loss_limit_guard_halts_on_cumulative_limit_across_days() -> None:
|
||||
guard = LossLimitGuard(cumulative_loss_limit=60.0)
|
||||
t0 = datetime.now(UTC)
|
||||
|
||||
guard.register_realized_pnl(-40.0, at=t0)
|
||||
guard.register_realized_pnl(-25.0, at=t0 + timedelta(days=1))
|
||||
|
||||
assert guard.is_halted
|
||||
assert guard.halted_reason == "cumulative_loss_limit_breached"
|
||||
|
||||
|
||||
def test_loss_limit_guard_rejects_invalid_limits() -> None:
|
||||
with pytest.raises(ValueError, match="daily_loss_limit"):
|
||||
LossLimitGuard(daily_loss_limit=0.0)
|
||||
|
||||
with pytest.raises(ValueError, match="cumulative_loss_limit"):
|
||||
LossLimitGuard(cumulative_loss_limit=-1.0)
|
||||
@@ -8,7 +8,10 @@ import pytest
|
||||
|
||||
from arbitrade.detection.engine import OpportunityEvent
|
||||
from arbitrade.exchange.models import BookDelta, BookLevel
|
||||
from arbitrade.market_data.feed import MarketDataFeed
|
||||
from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed
|
||||
from arbitrade.risk.loss_limits import LossLimitGuard
|
||||
from arbitrade.risk.pre_trade import PreTradeValidator
|
||||
from arbitrade.risk.trade_limits import TradeLimitsGuard
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -41,20 +44,47 @@ class _FakeOpportunityWriter:
|
||||
class _FakeDetector:
|
||||
def __init__(self, event: OpportunityEvent) -> None:
|
||||
self._event = event
|
||||
self.last_base_capital: float | None = None
|
||||
|
||||
def opportunities_for_updated_pair(self, _updated_pair: str, _books: dict[str, object]):
|
||||
def opportunities_for_updated_pair(
|
||||
self,
|
||||
_updated_pair: str,
|
||||
_books: dict[str, object],
|
||||
*,
|
||||
base_capital: float,
|
||||
):
|
||||
self.last_base_capital = base_capital
|
||||
return [self._event]
|
||||
|
||||
|
||||
class _FakeExecutor:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[OpportunityEvent] = []
|
||||
self.realized_pnls: list[float | None] = []
|
||||
self.outcomes: list[ExecutionOutcome] = []
|
||||
|
||||
async def execute(self, event: OpportunityEvent) -> None:
|
||||
async def execute(self, event: OpportunityEvent) -> ExecutionOutcome | float | None:
|
||||
self.calls.append(event)
|
||||
if self.outcomes:
|
||||
return self.outcomes.pop(0)
|
||||
if not self.realized_pnls:
|
||||
return None
|
||||
return self.realized_pnls.pop(0)
|
||||
|
||||
|
||||
def _sample_event() -> OpportunityEvent:
|
||||
@dataclass(slots=True)
|
||||
class _FakeWsClientTwoMessages:
|
||||
delta: BookDelta
|
||||
|
||||
async def connect_stream(self):
|
||||
yield SimpleNamespace(payload={"channel": "book", "seq": 1})
|
||||
yield SimpleNamespace(payload={"channel": "book", "seq": 2})
|
||||
|
||||
def parse_book_delta(self, _payload: dict[str, object]) -> BookDelta:
|
||||
return self.delta
|
||||
|
||||
|
||||
def _sample_event(*, allocated_capital: float = 1.0) -> OpportunityEvent:
|
||||
return OpportunityEvent(
|
||||
detected_at=datetime.now(UTC),
|
||||
cycle="USD->BTC->ETH->USD",
|
||||
@@ -64,6 +94,7 @@ def _sample_event() -> OpportunityEvent:
|
||||
gross_pct=4.0,
|
||||
net_pct=3.0,
|
||||
est_profit=0.03,
|
||||
allocated_capital=allocated_capital,
|
||||
)
|
||||
|
||||
|
||||
@@ -110,3 +141,184 @@ async def test_market_data_feed_live_mode_executes_orders() -> None:
|
||||
|
||||
assert len(executor.calls) == 1
|
||||
assert executor.calls[0].cycle == "USD->BTC->ETH->USD"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_enforces_per_trade_capital_limit() -> None:
|
||||
event = _sample_event()
|
||||
detector = _FakeDetector(event)
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClient(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=True,
|
||||
trade_capital=250.0,
|
||||
max_trade_capital=100.0,
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert detector.last_base_capital == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_auto_halts_on_daily_loss_limit() -> None:
|
||||
event = _sample_event()
|
||||
detector = _FakeDetector(event)
|
||||
executor = _FakeExecutor()
|
||||
executor.realized_pnls = [-60.0, -10.0]
|
||||
loss_guard = LossLimitGuard(daily_loss_limit=50.0)
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=False,
|
||||
opportunity_executor=executor.execute,
|
||||
loss_limit_guard=loss_guard,
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert len(executor.calls) == 1
|
||||
assert loss_guard.is_halted
|
||||
assert loss_guard.halted_reason == "daily_loss_limit_breached"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_auto_halts_on_cumulative_loss_limit() -> None:
|
||||
event = _sample_event()
|
||||
detector = _FakeDetector(event)
|
||||
executor = _FakeExecutor()
|
||||
executor.realized_pnls = [-40.0, -15.0]
|
||||
loss_guard = LossLimitGuard(cumulative_loss_limit=50.0)
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=False,
|
||||
opportunity_executor=executor.execute,
|
||||
loss_limit_guard=loss_guard,
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert len(executor.calls) == 2
|
||||
assert loss_guard.is_halted
|
||||
assert loss_guard.halted_reason == "cumulative_loss_limit_breached"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_enforces_max_concurrent_trades() -> None:
|
||||
event = _sample_event()
|
||||
detector = _FakeDetector(event)
|
||||
executor = _FakeExecutor()
|
||||
executor.outcomes = [ExecutionOutcome(
|
||||
realized_pnl=None, close_trade=False)]
|
||||
trade_guard = TradeLimitsGuard(max_concurrent_trades=1)
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=False,
|
||||
opportunity_executor=executor.execute,
|
||||
trade_limits_guard=trade_guard,
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert len(executor.calls) == 1
|
||||
assert trade_guard.active_trades == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_enforces_per_asset_exposure_cap() -> None:
|
||||
event = _sample_event(allocated_capital=100.0)
|
||||
detector = _FakeDetector(event)
|
||||
executor = _FakeExecutor()
|
||||
trade_guard = TradeLimitsGuard(max_exposure_per_asset=50.0)
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClient(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=False,
|
||||
opportunity_executor=executor.execute,
|
||||
trade_limits_guard=trade_guard,
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert len(executor.calls) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_blocks_when_pre_trade_balance_insufficient() -> None:
|
||||
event = _sample_event(allocated_capital=100.0)
|
||||
detector = _FakeDetector(event)
|
||||
executor = _FakeExecutor()
|
||||
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClient(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=False,
|
||||
opportunity_executor=executor.execute,
|
||||
pre_trade_validator=validator,
|
||||
balance_provider=lambda: {"USD": 25.0},
|
||||
quote_balance_asset="USD",
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert len(executor.calls) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_blocks_when_pre_trade_min_order_not_met() -> None:
|
||||
event = _sample_event(allocated_capital=25.0)
|
||||
detector = _FakeDetector(event)
|
||||
executor = _FakeExecutor()
|
||||
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClient(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=False,
|
||||
opportunity_executor=executor.execute,
|
||||
pre_trade_validator=validator,
|
||||
balance_provider=lambda: {"USD": 500.0},
|
||||
quote_balance_asset="USD",
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert len(executor.calls) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> None:
|
||||
event = _sample_event(allocated_capital=75.0)
|
||||
detector = _FakeDetector(event)
|
||||
executor = _FakeExecutor()
|
||||
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||
feed = MarketDataFeed(
|
||||
ws_client=_FakeWsClient(_sample_delta()),
|
||||
snapshot_writer=_FakeSnapshotWriter(),
|
||||
detector=detector,
|
||||
opportunity_writer=_FakeOpportunityWriter(),
|
||||
paper_trading_mode=False,
|
||||
opportunity_executor=executor.execute,
|
||||
pre_trade_validator=validator,
|
||||
balance_provider=lambda: {"USD": 500.0},
|
||||
quote_balance_asset="USD",
|
||||
)
|
||||
|
||||
await feed.run()
|
||||
|
||||
assert len(executor.calls) == 1
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.risk.pre_trade import PreTradeValidator
|
||||
|
||||
|
||||
def test_pre_trade_validator_accepts_when_balance_and_min_size_pass() -> None:
|
||||
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||
|
||||
assert validator.validate(
|
||||
balances_by_asset={"USD": 100.0},
|
||||
required_by_asset={"USD": 75.0},
|
||||
)
|
||||
|
||||
|
||||
def test_pre_trade_validator_rejects_when_balance_insufficient() -> None:
|
||||
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||
|
||||
assert not validator.validate(
|
||||
balances_by_asset={"USD": 40.0},
|
||||
required_by_asset={"USD": 75.0},
|
||||
)
|
||||
|
||||
|
||||
def test_pre_trade_validator_rejects_when_below_min_size() -> None:
|
||||
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||
|
||||
assert not validator.validate(
|
||||
balances_by_asset={"USD": 100.0},
|
||||
required_by_asset={"USD": 30.0},
|
||||
)
|
||||
|
||||
|
||||
def test_pre_trade_validator_rejects_invalid_min_order_size_config() -> None:
|
||||
with pytest.raises(ValueError, match="minimum order size"):
|
||||
PreTradeValidator(min_order_size_by_asset={"USD": 0.0})
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from arbitrade.risk.trade_limits import TradeLimitsGuard
|
||||
|
||||
|
||||
def test_trade_limits_guard_blocks_when_max_concurrent_reached() -> None:
|
||||
guard = TradeLimitsGuard(max_concurrent_trades=1)
|
||||
|
||||
guard.open_trade({"BTC": 10.0})
|
||||
|
||||
assert not guard.is_trade_allowed({"BTC": 1.0})
|
||||
|
||||
|
||||
def test_trade_limits_guard_blocks_when_asset_exposure_would_breach_cap() -> None:
|
||||
guard = TradeLimitsGuard(max_exposure_per_asset=100.0)
|
||||
|
||||
guard.open_trade({"BTC": 80.0})
|
||||
|
||||
assert not guard.is_trade_allowed({"BTC": 25.0})
|
||||
assert guard.is_trade_allowed({"ETH": 25.0})
|
||||
|
||||
|
||||
def test_trade_limits_guard_releases_exposure_on_close() -> None:
|
||||
guard = TradeLimitsGuard(max_concurrent_trades=2, max_exposure_per_asset=100.0)
|
||||
|
||||
guard.open_trade({"BTC": 80.0})
|
||||
guard.close_trade({"BTC": 80.0})
|
||||
|
||||
assert guard.active_trades == 0
|
||||
assert guard.exposure_for_asset("BTC") == 0.0
|
||||
assert guard.is_trade_allowed({"BTC": 100.0})
|
||||
|
||||
|
||||
def test_trade_limits_guard_rejects_invalid_configuration() -> None:
|
||||
with pytest.raises(ValueError, match="max_concurrent_trades"):
|
||||
TradeLimitsGuard(max_concurrent_trades=0)
|
||||
|
||||
with pytest.raises(ValueError, match="max_exposure_per_asset"):
|
||||
TradeLimitsGuard(max_exposure_per_asset=0.0)
|
||||
Reference in New Issue
Block a user