Add risk management features: implement loss limits, trade limits, and pre-trade validation; update settings and tests

This commit is contained in:
2026-06-01 11:16:37 +02:00
parent 9d8a8a8a45
commit 45e219d103
14 changed files with 718 additions and 20 deletions
+8
View File
@@ -16,3 +16,11 @@ KRAKEN_RETRY_BASE_DELAY_SECONDS=0.25
WS_HEARTBEAT_TIMEOUT_SECONDS=20.0 WS_HEARTBEAT_TIMEOUT_SECONDS=20.0
WS_MAX_STALENESS_SECONDS=5.0 WS_MAX_STALENESS_SECONDS=5.0
PAPER_TRADING_MODE=true PAPER_TRADING_MODE=true
TRADE_CAPITAL_USD=100.0
MAX_TRADE_CAPITAL_USD=100.0
MAX_CONCURRENT_TRADES=
MAX_EXPOSURE_PER_ASSET_USD=
QUOTE_BALANCE_ASSET=USD
MIN_ORDER_SIZE_USD=
DAILY_LOSS_LIMIT_USD=5.0
CUMULATIVE_LOSS_LIMIT_USD=10.0
+39 -9
View File
@@ -8,7 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
env_ignore_empty=True,
)
app_env: str = Field(default="dev", alias="APP_ENV") app_env: str = Field(default="dev", alias="APP_ENV")
app_host: str = Field(default="0.0.0.0", alias="APP_HOST") app_host: str = Field(default="0.0.0.0", alias="APP_HOST")
@@ -17,23 +22,48 @@ class Settings(BaseSettings):
log_level: str = Field(default="INFO", alias="LOG_LEVEL") log_level: str = Field(default="INFO", alias="LOG_LEVEL")
log_json: bool = Field(default=True, alias="LOG_JSON") log_json: bool = Field(default=True, alias="LOG_JSON")
duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH") duckdb_path: Path = Field(default=Path(
"./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL") kraken_rest_url: str = Field(
kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") default="https://api.kraken.com", alias="KRAKEN_REST_URL")
kraken_ws_url: str = Field(
default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_private_rate_limit_seconds: float = Field( kraken_private_rate_limit_seconds: float = Field(
default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS"
) )
kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") kraken_http_timeout_seconds: float = Field(
kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS") default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
kraken_retry_attempts: int = Field(
default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_retry_base_delay_seconds: float = Field( kraken_retry_base_delay_seconds: float = Field(
default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS"
) )
kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY")
kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET") kraken_api_secret: str | None = Field(
ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") default=None, alias="KRAKEN_API_SECRET")
ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS") ws_heartbeat_timeout_seconds: float = Field(
default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
ws_max_staleness_seconds: float = Field(
default=5.0, alias="WS_MAX_STALENESS_SECONDS")
paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE")
trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD")
max_trade_capital_usd: float = Field(
default=100.0, alias="MAX_TRADE_CAPITAL_USD")
max_concurrent_trades: int | None = Field(
default=None, alias="MAX_CONCURRENT_TRADES")
max_exposure_per_asset_usd: float | None = Field(
default=None,
alias="MAX_EXPOSURE_PER_ASSET_USD",
)
quote_balance_asset: str = Field(
default="USD", alias="QUOTE_BALANCE_ASSET")
min_order_size_usd: float | None = Field(
default=None, alias="MIN_ORDER_SIZE_USD")
daily_loss_limit_usd: float | None = Field(
default=None, alias="DAILY_LOSS_LIMIT_USD")
cumulative_loss_limit_usd: float | None = Field(
default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
fernet_key: str | None = Field(default=None, alias="FERNET_KEY") fernet_key: str | None = Field(default=None, alias="FERNET_KEY")
+2
View File
@@ -41,6 +41,7 @@ class OpportunityEvent:
gross_pct: float gross_pct: float
net_pct: float net_pct: float
est_profit: float est_profit: float
allocated_capital: float = 1.0
@classmethod @classmethod
def from_cycle_score(cls, score: CycleScore, base_capital: float = 1.0) -> OpportunityEvent: def from_cycle_score(cls, score: CycleScore, base_capital: float = 1.0) -> OpportunityEvent:
@@ -58,6 +59,7 @@ class OpportunityEvent:
gross_pct=gross_pct, gross_pct=gross_pct,
net_pct=net_pct, net_pct=net_pct,
est_profit=est_profit, est_profit=est_profit,
allocated_capital=base_capital,
) )
+116 -3
View File
@@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
import structlog import structlog
@@ -9,12 +10,21 @@ import structlog
from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent
from arbitrade.exchange.kraken_ws import KrakenWsClient from arbitrade.exchange.kraken_ws import KrakenWsClient
from arbitrade.market_data.order_book import OrderBook from arbitrade.market_data.order_book import OrderBook
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.trade_limits import TradeLimitsGuard
from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot
from arbitrade.storage.opportunities import AsyncOpportunityWriter from arbitrade.storage.opportunities import AsyncOpportunityWriter
_LOG = structlog.get_logger(__name__) _LOG = structlog.get_logger(__name__)
@dataclass(frozen=True, slots=True)
class ExecutionOutcome:
realized_pnl: float | None = None
close_trade: bool = True
class MarketDataFeed: class MarketDataFeed:
def __init__( def __init__(
self, self,
@@ -23,7 +33,17 @@ class MarketDataFeed:
detector: IncrementalCycleDetector | None = None, detector: IncrementalCycleDetector | None = None,
opportunity_writer: AsyncOpportunityWriter | None = None, opportunity_writer: AsyncOpportunityWriter | None = None,
paper_trading_mode: bool = True, paper_trading_mode: bool = True,
opportunity_executor: Callable[[OpportunityEvent], Awaitable[None]] | None = None, opportunity_executor: (
Callable[[OpportunityEvent],
Awaitable[ExecutionOutcome | float | None]] | None
) = None,
trade_capital: float = 1.0,
max_trade_capital: float | None = None,
loss_limit_guard: LossLimitGuard | None = None,
trade_limits_guard: TradeLimitsGuard | None = None,
pre_trade_validator: PreTradeValidator | None = None,
balance_provider: Callable[[], Mapping[str, float]] | None = None,
quote_balance_asset: str = "USD",
) -> None: ) -> None:
self._ws_client = ws_client self._ws_client = ws_client
self._snapshot_writer = snapshot_writer self._snapshot_writer = snapshot_writer
@@ -32,11 +52,39 @@ class MarketDataFeed:
self._opportunity_writer = opportunity_writer self._opportunity_writer = opportunity_writer
self._paper_trading_mode = paper_trading_mode self._paper_trading_mode = paper_trading_mode
self._opportunity_executor = opportunity_executor self._opportunity_executor = opportunity_executor
self._trade_capital = trade_capital
self._max_trade_capital = max_trade_capital
self._loss_limit_guard = loss_limit_guard
self._trade_limits_guard = trade_limits_guard
self._pre_trade_validator = pre_trade_validator
self._balance_provider = balance_provider
self._quote_balance_asset = quote_balance_asset.upper()
if self._trade_capital <= 0.0:
raise ValueError("trade_capital must be > 0.0")
if self._max_trade_capital is not None and self._max_trade_capital <= 0.0:
raise ValueError("max_trade_capital must be > 0.0")
@property @property
def books(self) -> dict[str, OrderBook]: def books(self) -> dict[str, OrderBook]:
return self._books return self._books
def _effective_trade_capital(self) -> float:
if self._max_trade_capital is None:
return self._trade_capital
return min(self._trade_capital, self._max_trade_capital)
@staticmethod
def _exposure_for_event(event: OpportunityEvent) -> dict[str, float]:
currencies = [part for part in event.cycle.split("->") if part]
if len(currencies) < 2:
return {}
start = currencies[0]
exposure_assets = {
currency for currency in currencies[1:] if currency != start}
return {asset: event.allocated_capital for asset in exposure_assets}
async def run(self) -> None: async def run(self) -> None:
async for message in self._ws_client.connect_stream(): async for message in self._ws_client.connect_stream():
parse_start = time.perf_counter() parse_start = time.perf_counter()
@@ -73,6 +121,7 @@ class MarketDataFeed:
opportunities = self._detector.opportunities_for_updated_pair( opportunities = self._detector.opportunities_for_updated_pair(
delta.symbol, delta.symbol,
self._books, self._books,
base_capital=self._effective_trade_capital(),
) )
_LOG.debug( _LOG.debug(
"incremental_opportunity_scores", "incremental_opportunity_scores",
@@ -111,7 +160,71 @@ class MarketDataFeed:
) )
continue continue
await self._opportunity_executor(event) if self._loss_limit_guard is not None and self._loss_limit_guard.is_halted:
_LOG.warning(
"live_trade_skipped_loss_limit_halted",
cycle=event.cycle,
updated_pair=event.updated_pair,
reason=self._loss_limit_guard.halted_reason,
)
continue
if self._pre_trade_validator is not None and self._balance_provider is not None:
required_balances = {
self._quote_balance_asset: event.allocated_capital}
balances = {
asset.upper(): amount
for asset, amount in self._balance_provider().items()
}
if not self._pre_trade_validator.validate(
balances_by_asset=balances,
required_by_asset=required_balances,
):
_LOG.warning(
"live_trade_skipped_pre_trade_validation",
cycle=event.cycle,
updated_pair=event.updated_pair,
required_by_asset=required_balances,
)
continue
exposure_by_asset = self._exposure_for_event(event)
if (
self._trade_limits_guard is not None
and not self._trade_limits_guard.is_trade_allowed(exposure_by_asset)
):
_LOG.warning(
"live_trade_skipped_trade_limits",
cycle=event.cycle,
updated_pair=event.updated_pair,
exposure_by_asset=exposure_by_asset,
)
continue
if self._trade_limits_guard is not None:
self._trade_limits_guard.open_trade(exposure_by_asset)
outcome = await self._opportunity_executor(event)
realized_pnl: float | None
close_trade = True
if isinstance(outcome, ExecutionOutcome):
realized_pnl = outcome.realized_pnl
close_trade = outcome.close_trade
else:
realized_pnl = outcome
if realized_pnl is not None and self._loss_limit_guard is not None:
self._loss_limit_guard.register_realized_pnl(
realized_pnl)
if self._loss_limit_guard.is_halted:
_LOG.warning(
"loss_limit_halt_triggered",
reason=self._loss_limit_guard.halted_reason,
cumulative_pnl=self._loss_limit_guard.cumulative_pnl,
)
if self._trade_limits_guard is not None and close_trade:
self._trade_limits_guard.close_trade(exposure_by_asset)
await self._snapshot_writer.enqueue( await self._snapshot_writer.enqueue(
MarketSnapshot( MarketSnapshot(
+7
View File
@@ -0,0 +1,7 @@
"""Risk management helpers."""
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.trade_limits import TradeLimitsGuard
__all__ = ["LossLimitGuard", "TradeLimitsGuard", "PreTradeValidator"]
+61
View File
@@ -0,0 +1,61 @@
from __future__ import annotations
from datetime import UTC, date, datetime
class LossLimitGuard:
def __init__(
self,
*,
daily_loss_limit: float | None = None,
cumulative_loss_limit: float | None = None,
) -> None:
self._daily_loss_limit = daily_loss_limit
self._cumulative_loss_limit = cumulative_loss_limit
if self._daily_loss_limit is not None and self._daily_loss_limit <= 0.0:
raise ValueError("daily_loss_limit must be > 0.0")
if self._cumulative_loss_limit is not None and self._cumulative_loss_limit <= 0.0:
raise ValueError("cumulative_loss_limit must be > 0.0")
self._cumulative_pnl = 0.0
self._daily_pnl: dict[date, float] = {}
self._halted_reason: str | None = None
@property
def cumulative_pnl(self) -> float:
return self._cumulative_pnl
@property
def halted_reason(self) -> str | None:
return self._halted_reason
@property
def is_halted(self) -> bool:
return self._halted_reason is not None
def daily_pnl(self, day: date) -> float:
return self._daily_pnl.get(day, 0.0)
def register_realized_pnl(self, pnl: float, *, at: datetime | None = None) -> None:
if self.is_halted:
return
timestamp = at or datetime.now(UTC)
day_key = timestamp.date()
self._cumulative_pnl += pnl
self._daily_pnl[day_key] = self._daily_pnl.get(day_key, 0.0) + pnl
if (
self._daily_loss_limit is not None
and self._daily_pnl[day_key] <= -self._daily_loss_limit
):
self._halted_reason = "daily_loss_limit_breached"
return
if (
self._cumulative_loss_limit is not None
and self._cumulative_pnl <= -self._cumulative_loss_limit
):
self._halted_reason = "cumulative_loss_limit_breached"
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
from collections.abc import Mapping
class PreTradeValidator:
def __init__(
self,
*,
min_order_size_by_asset: Mapping[str, float] | None = None,
) -> None:
self._min_order_size_by_asset = {
asset.upper(): float(value) for asset, value in (min_order_size_by_asset or {}).items()
}
for value in self._min_order_size_by_asset.values():
if value <= 0.0:
raise ValueError("minimum order size must be > 0.0")
def validate(
self,
*,
balances_by_asset: Mapping[str, float],
required_by_asset: Mapping[str, float],
) -> bool:
# Minimum order size checks first to fail fast on structural invalid sizes.
for asset, required in required_by_asset.items():
if required <= 0.0:
continue
min_size = self._min_order_size_by_asset.get(asset.upper())
if min_size is not None and required < min_size:
return False
# Balance checks ensure required quantity is currently available.
for asset, required in required_by_asset.items():
if required <= 0.0:
continue
available = balances_by_asset.get(asset.upper(), 0.0)
if available < required:
return False
return True
+71
View File
@@ -0,0 +1,71 @@
from __future__ import annotations
from collections.abc import Mapping
class TradeLimitsGuard:
def __init__(
self,
*,
max_concurrent_trades: int | None = None,
max_exposure_per_asset: float | None = None,
) -> None:
if max_concurrent_trades is not None and max_concurrent_trades <= 0:
raise ValueError("max_concurrent_trades must be > 0")
if max_exposure_per_asset is not None and max_exposure_per_asset <= 0.0:
raise ValueError("max_exposure_per_asset must be > 0.0")
self._max_concurrent_trades = max_concurrent_trades
self._max_exposure_per_asset = max_exposure_per_asset
self._active_trades = 0
self._asset_exposure: dict[str, float] = {}
@property
def active_trades(self) -> int:
return self._active_trades
def exposure_for_asset(self, asset: str) -> float:
return self._asset_exposure.get(asset.upper(), 0.0)
def is_trade_allowed(self, exposure_by_asset: Mapping[str, float]) -> bool:
if (
self._max_concurrent_trades is not None
and self._active_trades >= self._max_concurrent_trades
):
return False
if self._max_exposure_per_asset is None:
return True
for asset, exposure in exposure_by_asset.items():
if exposure <= 0.0:
continue
key = asset.upper()
next_exposure = self._asset_exposure.get(key, 0.0) + exposure
if next_exposure > self._max_exposure_per_asset:
return False
return True
def open_trade(self, exposure_by_asset: Mapping[str, float]) -> None:
self._active_trades += 1
for asset, exposure in exposure_by_asset.items():
if exposure <= 0.0:
continue
key = asset.upper()
self._asset_exposure[key] = self._asset_exposure.get(key, 0.0) + exposure
def close_trade(self, exposure_by_asset: Mapping[str, float]) -> None:
if self._active_trades > 0:
self._active_trades -= 1
for asset, exposure in exposure_by_asset.items():
if exposure <= 0.0:
continue
key = asset.upper()
current = self._asset_exposure.get(key, 0.0)
next_exposure = max(current - exposure, 0.0)
if next_exposure == 0.0:
self._asset_exposure.pop(key, None)
else:
self._asset_exposure[key] = next_exposure
+2 -4
View File
@@ -19,10 +19,8 @@ def _make_book_levels(
*, bids: list[tuple[float, float]], asks: list[tuple[float, float]] *, bids: list[tuple[float, float]], asks: list[tuple[float, float]]
) -> OrderBook: ) -> OrderBook:
book = OrderBook() book = OrderBook()
book.apply_bids([BookLevel(price=price, volume=volume) book.apply_bids([BookLevel(price=price, volume=volume) for price, volume in bids])
for price, volume in bids]) book.apply_asks([BookLevel(price=price, volume=volume) for price, volume in asks])
book.apply_asks([BookLevel(price=price, volume=volume)
for price, volume in asks])
return book return book
+26
View File
@@ -278,3 +278,29 @@ def test_incremental_detector_emits_structured_opportunity_event() -> None:
assert event.gross_pct == pytest.approx(4.0) assert event.gross_pct == pytest.approx(4.0)
assert event.net_pct == pytest.approx(4.0) assert event.net_pct == pytest.approx(4.0)
assert event.est_profit == pytest.approx(20.0) assert event.est_profit == pytest.approx(20.0)
def test_incremental_detector_estimated_profit_scales_with_capital() -> None:
cycle = TriangularCycle(
currencies=("USD", "BTC", "ETH"),
pairs=("BTC/USD", "ETH/BTC", "ETH/USD"),
)
detector = IncrementalCycleDetector(
CurrencyGraph.index_cycles_by_pair([cycle]),
min_profit_threshold=0.0,
)
books = {
"BTC/USD": _make_book(bid=99.9, ask=100.0),
"ETH/BTC": _make_book(bid=0.049, ask=0.05),
"ETH/USD": _make_book(bid=5.20, ask=5.21),
}
opportunities = detector.opportunities_for_updated_pair(
"ETH/BTC",
books,
base_capital=250.0,
)
assert len(opportunities) == 1
assert opportunities[0].est_profit == pytest.approx(10.0)
+49
View File
@@ -0,0 +1,49 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import pytest
from arbitrade.risk.loss_limits import LossLimitGuard
def test_loss_limit_guard_tracks_daily_and_cumulative_pnl() -> None:
guard = LossLimitGuard(daily_loss_limit=100.0, cumulative_loss_limit=200.0)
t0 = datetime.now(UTC)
guard.register_realized_pnl(-40.0, at=t0)
guard.register_realized_pnl(10.0, at=t0)
assert guard.cumulative_pnl == -30.0
assert guard.daily_pnl(t0.date()) == -30.0
assert not guard.is_halted
def test_loss_limit_guard_halts_on_daily_limit() -> None:
guard = LossLimitGuard(daily_loss_limit=50.0)
t0 = datetime.now(UTC)
guard.register_realized_pnl(-30.0, at=t0)
guard.register_realized_pnl(-25.0, at=t0)
assert guard.is_halted
assert guard.halted_reason == "daily_loss_limit_breached"
def test_loss_limit_guard_halts_on_cumulative_limit_across_days() -> None:
guard = LossLimitGuard(cumulative_loss_limit=60.0)
t0 = datetime.now(UTC)
guard.register_realized_pnl(-40.0, at=t0)
guard.register_realized_pnl(-25.0, at=t0 + timedelta(days=1))
assert guard.is_halted
assert guard.halted_reason == "cumulative_loss_limit_breached"
def test_loss_limit_guard_rejects_invalid_limits() -> None:
with pytest.raises(ValueError, match="daily_loss_limit"):
LossLimitGuard(daily_loss_limit=0.0)
with pytest.raises(ValueError, match="cumulative_loss_limit"):
LossLimitGuard(cumulative_loss_limit=-1.0)
+216 -4
View File
@@ -8,7 +8,10 @@ import pytest
from arbitrade.detection.engine import OpportunityEvent from arbitrade.detection.engine import OpportunityEvent
from arbitrade.exchange.models import BookDelta, BookLevel from arbitrade.exchange.models import BookDelta, BookLevel
from arbitrade.market_data.feed import MarketDataFeed from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.trade_limits import TradeLimitsGuard
@dataclass(slots=True) @dataclass(slots=True)
@@ -41,20 +44,47 @@ class _FakeOpportunityWriter:
class _FakeDetector: class _FakeDetector:
def __init__(self, event: OpportunityEvent) -> None: def __init__(self, event: OpportunityEvent) -> None:
self._event = event self._event = event
self.last_base_capital: float | None = None
def opportunities_for_updated_pair(self, _updated_pair: str, _books: dict[str, object]): def opportunities_for_updated_pair(
self,
_updated_pair: str,
_books: dict[str, object],
*,
base_capital: float,
):
self.last_base_capital = base_capital
return [self._event] return [self._event]
class _FakeExecutor: class _FakeExecutor:
def __init__(self) -> None: def __init__(self) -> None:
self.calls: list[OpportunityEvent] = [] self.calls: list[OpportunityEvent] = []
self.realized_pnls: list[float | None] = []
self.outcomes: list[ExecutionOutcome] = []
async def execute(self, event: OpportunityEvent) -> None: async def execute(self, event: OpportunityEvent) -> ExecutionOutcome | float | None:
self.calls.append(event) self.calls.append(event)
if self.outcomes:
return self.outcomes.pop(0)
if not self.realized_pnls:
return None
return self.realized_pnls.pop(0)
def _sample_event() -> OpportunityEvent: @dataclass(slots=True)
class _FakeWsClientTwoMessages:
delta: BookDelta
async def connect_stream(self):
yield SimpleNamespace(payload={"channel": "book", "seq": 1})
yield SimpleNamespace(payload={"channel": "book", "seq": 2})
def parse_book_delta(self, _payload: dict[str, object]) -> BookDelta:
return self.delta
def _sample_event(*, allocated_capital: float = 1.0) -> OpportunityEvent:
return OpportunityEvent( return OpportunityEvent(
detected_at=datetime.now(UTC), detected_at=datetime.now(UTC),
cycle="USD->BTC->ETH->USD", cycle="USD->BTC->ETH->USD",
@@ -64,6 +94,7 @@ def _sample_event() -> OpportunityEvent:
gross_pct=4.0, gross_pct=4.0,
net_pct=3.0, net_pct=3.0,
est_profit=0.03, est_profit=0.03,
allocated_capital=allocated_capital,
) )
@@ -110,3 +141,184 @@ async def test_market_data_feed_live_mode_executes_orders() -> None:
assert len(executor.calls) == 1 assert len(executor.calls) == 1
assert executor.calls[0].cycle == "USD->BTC->ETH->USD" assert executor.calls[0].cycle == "USD->BTC->ETH->USD"
@pytest.mark.asyncio
async def test_market_data_feed_enforces_per_trade_capital_limit() -> None:
event = _sample_event()
detector = _FakeDetector(event)
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=True,
trade_capital=250.0,
max_trade_capital=100.0,
)
await feed.run()
assert detector.last_base_capital == 100.0
@pytest.mark.asyncio
async def test_market_data_feed_auto_halts_on_daily_loss_limit() -> None:
event = _sample_event()
detector = _FakeDetector(event)
executor = _FakeExecutor()
executor.realized_pnls = [-60.0, -10.0]
loss_guard = LossLimitGuard(daily_loss_limit=50.0)
feed = MarketDataFeed(
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
loss_limit_guard=loss_guard,
)
await feed.run()
assert len(executor.calls) == 1
assert loss_guard.is_halted
assert loss_guard.halted_reason == "daily_loss_limit_breached"
@pytest.mark.asyncio
async def test_market_data_feed_auto_halts_on_cumulative_loss_limit() -> None:
event = _sample_event()
detector = _FakeDetector(event)
executor = _FakeExecutor()
executor.realized_pnls = [-40.0, -15.0]
loss_guard = LossLimitGuard(cumulative_loss_limit=50.0)
feed = MarketDataFeed(
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
loss_limit_guard=loss_guard,
)
await feed.run()
assert len(executor.calls) == 2
assert loss_guard.is_halted
assert loss_guard.halted_reason == "cumulative_loss_limit_breached"
@pytest.mark.asyncio
async def test_market_data_feed_enforces_max_concurrent_trades() -> None:
event = _sample_event()
detector = _FakeDetector(event)
executor = _FakeExecutor()
executor.outcomes = [ExecutionOutcome(
realized_pnl=None, close_trade=False)]
trade_guard = TradeLimitsGuard(max_concurrent_trades=1)
feed = MarketDataFeed(
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
trade_limits_guard=trade_guard,
)
await feed.run()
assert len(executor.calls) == 1
assert trade_guard.active_trades == 1
@pytest.mark.asyncio
async def test_market_data_feed_enforces_per_asset_exposure_cap() -> None:
event = _sample_event(allocated_capital=100.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
trade_guard = TradeLimitsGuard(max_exposure_per_asset=50.0)
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
trade_limits_guard=trade_guard,
)
await feed.run()
assert len(executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_blocks_when_pre_trade_balance_insufficient() -> None:
event = _sample_event(allocated_capital=100.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
pre_trade_validator=validator,
balance_provider=lambda: {"USD": 25.0},
quote_balance_asset="USD",
)
await feed.run()
assert len(executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_blocks_when_pre_trade_min_order_not_met() -> None:
event = _sample_event(allocated_capital=25.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
pre_trade_validator=validator,
balance_provider=lambda: {"USD": 500.0},
quote_balance_asset="USD",
)
await feed.run()
assert len(executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> None:
event = _sample_event(allocated_capital=75.0)
detector = _FakeDetector(event)
executor = _FakeExecutor()
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
feed = MarketDataFeed(
ws_client=_FakeWsClient(_sample_delta()),
snapshot_writer=_FakeSnapshotWriter(),
detector=detector,
opportunity_writer=_FakeOpportunityWriter(),
paper_trading_mode=False,
opportunity_executor=executor.execute,
pre_trade_validator=validator,
balance_provider=lambda: {"USD": 500.0},
quote_balance_asset="USD",
)
await feed.run()
assert len(executor.calls) == 1
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
import pytest
from arbitrade.risk.pre_trade import PreTradeValidator
def test_pre_trade_validator_accepts_when_balance_and_min_size_pass() -> None:
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
assert validator.validate(
balances_by_asset={"USD": 100.0},
required_by_asset={"USD": 75.0},
)
def test_pre_trade_validator_rejects_when_balance_insufficient() -> None:
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
assert not validator.validate(
balances_by_asset={"USD": 40.0},
required_by_asset={"USD": 75.0},
)
def test_pre_trade_validator_rejects_when_below_min_size() -> None:
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
assert not validator.validate(
balances_by_asset={"USD": 100.0},
required_by_asset={"USD": 30.0},
)
def test_pre_trade_validator_rejects_invalid_min_order_size_config() -> None:
with pytest.raises(ValueError, match="minimum order size"):
PreTradeValidator(min_order_size_by_asset={"USD": 0.0})
+41
View File
@@ -0,0 +1,41 @@
from __future__ import annotations
import pytest
from arbitrade.risk.trade_limits import TradeLimitsGuard
def test_trade_limits_guard_blocks_when_max_concurrent_reached() -> None:
guard = TradeLimitsGuard(max_concurrent_trades=1)
guard.open_trade({"BTC": 10.0})
assert not guard.is_trade_allowed({"BTC": 1.0})
def test_trade_limits_guard_blocks_when_asset_exposure_would_breach_cap() -> None:
guard = TradeLimitsGuard(max_exposure_per_asset=100.0)
guard.open_trade({"BTC": 80.0})
assert not guard.is_trade_allowed({"BTC": 25.0})
assert guard.is_trade_allowed({"ETH": 25.0})
def test_trade_limits_guard_releases_exposure_on_close() -> None:
guard = TradeLimitsGuard(max_concurrent_trades=2, max_exposure_per_asset=100.0)
guard.open_trade({"BTC": 80.0})
guard.close_trade({"BTC": 80.0})
assert guard.active_trades == 0
assert guard.exposure_for_asset("BTC") == 0.0
assert guard.is_trade_allowed({"BTC": 100.0})
def test_trade_limits_guard_rejects_invalid_configuration() -> None:
with pytest.raises(ValueError, match="max_concurrent_trades"):
TradeLimitsGuard(max_concurrent_trades=0)
with pytest.raises(ValueError, match="max_exposure_per_asset"):
TradeLimitsGuard(max_exposure_per_asset=0.0)