diff --git a/.env.example b/.env.example index 02cfdbd..4056b1f 100644 --- a/.env.example +++ b/.env.example @@ -16,3 +16,11 @@ KRAKEN_RETRY_BASE_DELAY_SECONDS=0.25 WS_HEARTBEAT_TIMEOUT_SECONDS=20.0 WS_MAX_STALENESS_SECONDS=5.0 PAPER_TRADING_MODE=true +TRADE_CAPITAL_USD=100.0 +MAX_TRADE_CAPITAL_USD=100.0 +MAX_CONCURRENT_TRADES= +MAX_EXPOSURE_PER_ASSET_USD= +QUOTE_BALANCE_ASSET=USD +MIN_ORDER_SIZE_USD= +DAILY_LOSS_LIMIT_USD=5.0 +CUMULATIVE_LOSS_LIMIT_USD=10.0 diff --git a/src/arbitrade/config/settings.py b/src/arbitrade/config/settings.py index 8c8f225..b1938f0 100644 --- a/src/arbitrade/config/settings.py +++ b/src/arbitrade/config/settings.py @@ -8,7 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): - model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + env_ignore_empty=True, + ) app_env: str = Field(default="dev", alias="APP_ENV") app_host: str = Field(default="0.0.0.0", alias="APP_HOST") @@ -17,23 +22,48 @@ class Settings(BaseSettings): log_level: str = Field(default="INFO", alias="LOG_LEVEL") log_json: bool = Field(default=True, alias="LOG_JSON") - duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH") + duckdb_path: Path = Field(default=Path( + "./data/arbitrade.duckdb"), alias="DUCKDB_PATH") - kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL") - kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") + kraken_rest_url: str = Field( + default="https://api.kraken.com", alias="KRAKEN_REST_URL") + kraken_ws_url: str = Field( + default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") kraken_private_rate_limit_seconds: float = Field( default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" ) - kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") - kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS") + kraken_http_timeout_seconds: float = Field( + default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") + kraken_retry_attempts: int = Field( + default=3, alias="KRAKEN_RETRY_ATTEMPTS") kraken_retry_base_delay_seconds: float = Field( default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" ) kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") - kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET") - ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") - ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS") + kraken_api_secret: str | None = Field( + default=None, alias="KRAKEN_API_SECRET") + ws_heartbeat_timeout_seconds: float = Field( + default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") + ws_max_staleness_seconds: float = Field( + default=5.0, alias="WS_MAX_STALENESS_SECONDS") paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") + trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD") + max_trade_capital_usd: float = Field( + default=100.0, alias="MAX_TRADE_CAPITAL_USD") + max_concurrent_trades: int | None = Field( + default=None, alias="MAX_CONCURRENT_TRADES") + max_exposure_per_asset_usd: float | None = Field( + default=None, + alias="MAX_EXPOSURE_PER_ASSET_USD", + ) + quote_balance_asset: str = Field( + default="USD", alias="QUOTE_BALANCE_ASSET") + min_order_size_usd: float | None = Field( + default=None, alias="MIN_ORDER_SIZE_USD") + daily_loss_limit_usd: float | None = Field( + default=None, alias="DAILY_LOSS_LIMIT_USD") + cumulative_loss_limit_usd: float | None = Field( + default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") fernet_key: str | None = Field(default=None, alias="FERNET_KEY") diff --git a/src/arbitrade/detection/engine.py b/src/arbitrade/detection/engine.py index 2f12ae1..fbed5da 100644 --- a/src/arbitrade/detection/engine.py +++ b/src/arbitrade/detection/engine.py @@ -41,6 +41,7 @@ class OpportunityEvent: gross_pct: float net_pct: float est_profit: float + allocated_capital: float = 1.0 @classmethod def from_cycle_score(cls, score: CycleScore, base_capital: float = 1.0) -> OpportunityEvent: @@ -58,6 +59,7 @@ class OpportunityEvent: gross_pct=gross_pct, net_pct=net_pct, est_profit=est_profit, + allocated_capital=base_capital, ) diff --git a/src/arbitrade/market_data/feed.py b/src/arbitrade/market_data/feed.py index 2db60a3..3192b67 100644 --- a/src/arbitrade/market_data/feed.py +++ b/src/arbitrade/market_data/feed.py @@ -1,7 +1,8 @@ from __future__ import annotations import time -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass from datetime import UTC, datetime import structlog @@ -9,12 +10,21 @@ import structlog from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent from arbitrade.exchange.kraken_ws import KrakenWsClient from arbitrade.market_data.order_book import OrderBook +from arbitrade.risk.loss_limits import LossLimitGuard +from arbitrade.risk.pre_trade import PreTradeValidator +from arbitrade.risk.trade_limits import TradeLimitsGuard from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot from arbitrade.storage.opportunities import AsyncOpportunityWriter _LOG = structlog.get_logger(__name__) +@dataclass(frozen=True, slots=True) +class ExecutionOutcome: + realized_pnl: float | None = None + close_trade: bool = True + + class MarketDataFeed: def __init__( self, @@ -23,7 +33,17 @@ class MarketDataFeed: detector: IncrementalCycleDetector | None = None, opportunity_writer: AsyncOpportunityWriter | None = None, paper_trading_mode: bool = True, - opportunity_executor: Callable[[OpportunityEvent], Awaitable[None]] | None = None, + opportunity_executor: ( + Callable[[OpportunityEvent], + Awaitable[ExecutionOutcome | float | None]] | None + ) = None, + trade_capital: float = 1.0, + max_trade_capital: float | None = None, + loss_limit_guard: LossLimitGuard | None = None, + trade_limits_guard: TradeLimitsGuard | None = None, + pre_trade_validator: PreTradeValidator | None = None, + balance_provider: Callable[[], Mapping[str, float]] | None = None, + quote_balance_asset: str = "USD", ) -> None: self._ws_client = ws_client self._snapshot_writer = snapshot_writer @@ -32,11 +52,39 @@ class MarketDataFeed: self._opportunity_writer = opportunity_writer self._paper_trading_mode = paper_trading_mode self._opportunity_executor = opportunity_executor + self._trade_capital = trade_capital + self._max_trade_capital = max_trade_capital + self._loss_limit_guard = loss_limit_guard + self._trade_limits_guard = trade_limits_guard + self._pre_trade_validator = pre_trade_validator + self._balance_provider = balance_provider + self._quote_balance_asset = quote_balance_asset.upper() + + if self._trade_capital <= 0.0: + raise ValueError("trade_capital must be > 0.0") + if self._max_trade_capital is not None and self._max_trade_capital <= 0.0: + raise ValueError("max_trade_capital must be > 0.0") @property def books(self) -> dict[str, OrderBook]: return self._books + def _effective_trade_capital(self) -> float: + if self._max_trade_capital is None: + return self._trade_capital + return min(self._trade_capital, self._max_trade_capital) + + @staticmethod + def _exposure_for_event(event: OpportunityEvent) -> dict[str, float]: + currencies = [part for part in event.cycle.split("->") if part] + if len(currencies) < 2: + return {} + + start = currencies[0] + exposure_assets = { + currency for currency in currencies[1:] if currency != start} + return {asset: event.allocated_capital for asset in exposure_assets} + async def run(self) -> None: async for message in self._ws_client.connect_stream(): parse_start = time.perf_counter() @@ -73,6 +121,7 @@ class MarketDataFeed: opportunities = self._detector.opportunities_for_updated_pair( delta.symbol, self._books, + base_capital=self._effective_trade_capital(), ) _LOG.debug( "incremental_opportunity_scores", @@ -111,7 +160,71 @@ class MarketDataFeed: ) continue - await self._opportunity_executor(event) + if self._loss_limit_guard is not None and self._loss_limit_guard.is_halted: + _LOG.warning( + "live_trade_skipped_loss_limit_halted", + cycle=event.cycle, + updated_pair=event.updated_pair, + reason=self._loss_limit_guard.halted_reason, + ) + continue + + if self._pre_trade_validator is not None and self._balance_provider is not None: + required_balances = { + self._quote_balance_asset: event.allocated_capital} + balances = { + asset.upper(): amount + for asset, amount in self._balance_provider().items() + } + if not self._pre_trade_validator.validate( + balances_by_asset=balances, + required_by_asset=required_balances, + ): + _LOG.warning( + "live_trade_skipped_pre_trade_validation", + cycle=event.cycle, + updated_pair=event.updated_pair, + required_by_asset=required_balances, + ) + continue + + exposure_by_asset = self._exposure_for_event(event) + if ( + self._trade_limits_guard is not None + and not self._trade_limits_guard.is_trade_allowed(exposure_by_asset) + ): + _LOG.warning( + "live_trade_skipped_trade_limits", + cycle=event.cycle, + updated_pair=event.updated_pair, + exposure_by_asset=exposure_by_asset, + ) + continue + + if self._trade_limits_guard is not None: + self._trade_limits_guard.open_trade(exposure_by_asset) + + outcome = await self._opportunity_executor(event) + realized_pnl: float | None + close_trade = True + if isinstance(outcome, ExecutionOutcome): + realized_pnl = outcome.realized_pnl + close_trade = outcome.close_trade + else: + realized_pnl = outcome + + if realized_pnl is not None and self._loss_limit_guard is not None: + self._loss_limit_guard.register_realized_pnl( + realized_pnl) + if self._loss_limit_guard.is_halted: + _LOG.warning( + "loss_limit_halt_triggered", + reason=self._loss_limit_guard.halted_reason, + cumulative_pnl=self._loss_limit_guard.cumulative_pnl, + ) + + if self._trade_limits_guard is not None and close_trade: + self._trade_limits_guard.close_trade(exposure_by_asset) await self._snapshot_writer.enqueue( MarketSnapshot( diff --git a/src/arbitrade/risk/__init__.py b/src/arbitrade/risk/__init__.py new file mode 100644 index 0000000..512d0db --- /dev/null +++ b/src/arbitrade/risk/__init__.py @@ -0,0 +1,7 @@ +"""Risk management helpers.""" + +from arbitrade.risk.loss_limits import LossLimitGuard +from arbitrade.risk.pre_trade import PreTradeValidator +from arbitrade.risk.trade_limits import TradeLimitsGuard + +__all__ = ["LossLimitGuard", "TradeLimitsGuard", "PreTradeValidator"] diff --git a/src/arbitrade/risk/loss_limits.py b/src/arbitrade/risk/loss_limits.py new file mode 100644 index 0000000..05f859d --- /dev/null +++ b/src/arbitrade/risk/loss_limits.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from datetime import UTC, date, datetime + + +class LossLimitGuard: + def __init__( + self, + *, + daily_loss_limit: float | None = None, + cumulative_loss_limit: float | None = None, + ) -> None: + self._daily_loss_limit = daily_loss_limit + self._cumulative_loss_limit = cumulative_loss_limit + + if self._daily_loss_limit is not None and self._daily_loss_limit <= 0.0: + raise ValueError("daily_loss_limit must be > 0.0") + if self._cumulative_loss_limit is not None and self._cumulative_loss_limit <= 0.0: + raise ValueError("cumulative_loss_limit must be > 0.0") + + self._cumulative_pnl = 0.0 + self._daily_pnl: dict[date, float] = {} + self._halted_reason: str | None = None + + @property + def cumulative_pnl(self) -> float: + return self._cumulative_pnl + + @property + def halted_reason(self) -> str | None: + return self._halted_reason + + @property + def is_halted(self) -> bool: + return self._halted_reason is not None + + def daily_pnl(self, day: date) -> float: + return self._daily_pnl.get(day, 0.0) + + def register_realized_pnl(self, pnl: float, *, at: datetime | None = None) -> None: + if self.is_halted: + return + + timestamp = at or datetime.now(UTC) + day_key = timestamp.date() + + self._cumulative_pnl += pnl + self._daily_pnl[day_key] = self._daily_pnl.get(day_key, 0.0) + pnl + + if ( + self._daily_loss_limit is not None + and self._daily_pnl[day_key] <= -self._daily_loss_limit + ): + self._halted_reason = "daily_loss_limit_breached" + return + + if ( + self._cumulative_loss_limit is not None + and self._cumulative_pnl <= -self._cumulative_loss_limit + ): + self._halted_reason = "cumulative_loss_limit_breached" diff --git a/src/arbitrade/risk/pre_trade.py b/src/arbitrade/risk/pre_trade.py new file mode 100644 index 0000000..74ae2ec --- /dev/null +++ b/src/arbitrade/risk/pre_trade.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from collections.abc import Mapping + + +class PreTradeValidator: + def __init__( + self, + *, + min_order_size_by_asset: Mapping[str, float] | None = None, + ) -> None: + self._min_order_size_by_asset = { + asset.upper(): float(value) for asset, value in (min_order_size_by_asset or {}).items() + } + + for value in self._min_order_size_by_asset.values(): + if value <= 0.0: + raise ValueError("minimum order size must be > 0.0") + + def validate( + self, + *, + balances_by_asset: Mapping[str, float], + required_by_asset: Mapping[str, float], + ) -> bool: + # Minimum order size checks first to fail fast on structural invalid sizes. + for asset, required in required_by_asset.items(): + if required <= 0.0: + continue + + min_size = self._min_order_size_by_asset.get(asset.upper()) + if min_size is not None and required < min_size: + return False + + # Balance checks ensure required quantity is currently available. + for asset, required in required_by_asset.items(): + if required <= 0.0: + continue + available = balances_by_asset.get(asset.upper(), 0.0) + if available < required: + return False + + return True diff --git a/src/arbitrade/risk/trade_limits.py b/src/arbitrade/risk/trade_limits.py new file mode 100644 index 0000000..b5bc186 --- /dev/null +++ b/src/arbitrade/risk/trade_limits.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from collections.abc import Mapping + + +class TradeLimitsGuard: + def __init__( + self, + *, + max_concurrent_trades: int | None = None, + max_exposure_per_asset: float | None = None, + ) -> None: + if max_concurrent_trades is not None and max_concurrent_trades <= 0: + raise ValueError("max_concurrent_trades must be > 0") + if max_exposure_per_asset is not None and max_exposure_per_asset <= 0.0: + raise ValueError("max_exposure_per_asset must be > 0.0") + + self._max_concurrent_trades = max_concurrent_trades + self._max_exposure_per_asset = max_exposure_per_asset + self._active_trades = 0 + self._asset_exposure: dict[str, float] = {} + + @property + def active_trades(self) -> int: + return self._active_trades + + def exposure_for_asset(self, asset: str) -> float: + return self._asset_exposure.get(asset.upper(), 0.0) + + def is_trade_allowed(self, exposure_by_asset: Mapping[str, float]) -> bool: + if ( + self._max_concurrent_trades is not None + and self._active_trades >= self._max_concurrent_trades + ): + return False + + if self._max_exposure_per_asset is None: + return True + + for asset, exposure in exposure_by_asset.items(): + if exposure <= 0.0: + continue + key = asset.upper() + next_exposure = self._asset_exposure.get(key, 0.0) + exposure + if next_exposure > self._max_exposure_per_asset: + return False + + return True + + def open_trade(self, exposure_by_asset: Mapping[str, float]) -> None: + self._active_trades += 1 + for asset, exposure in exposure_by_asset.items(): + if exposure <= 0.0: + continue + key = asset.upper() + self._asset_exposure[key] = self._asset_exposure.get(key, 0.0) + exposure + + def close_trade(self, exposure_by_asset: Mapping[str, float]) -> None: + if self._active_trades > 0: + self._active_trades -= 1 + + for asset, exposure in exposure_by_asset.items(): + if exposure <= 0.0: + continue + key = asset.upper() + current = self._asset_exposure.get(key, 0.0) + next_exposure = max(current - exposure, 0.0) + if next_exposure == 0.0: + self._asset_exposure.pop(key, None) + else: + self._asset_exposure[key] = next_exposure diff --git a/tests/unit/test_detection_synthetic_books.py b/tests/unit/test_detection_synthetic_books.py index f7f5555..cf35a3b 100644 --- a/tests/unit/test_detection_synthetic_books.py +++ b/tests/unit/test_detection_synthetic_books.py @@ -19,10 +19,8 @@ def _make_book_levels( *, bids: list[tuple[float, float]], asks: list[tuple[float, float]] ) -> OrderBook: book = OrderBook() - book.apply_bids([BookLevel(price=price, volume=volume) - for price, volume in bids]) - book.apply_asks([BookLevel(price=price, volume=volume) - for price, volume in asks]) + book.apply_bids([BookLevel(price=price, volume=volume) for price, volume in bids]) + book.apply_asks([BookLevel(price=price, volume=volume) for price, volume in asks]) return book diff --git a/tests/unit/test_incremental_detector.py b/tests/unit/test_incremental_detector.py index 204ef9f..ac97907 100644 --- a/tests/unit/test_incremental_detector.py +++ b/tests/unit/test_incremental_detector.py @@ -278,3 +278,29 @@ def test_incremental_detector_emits_structured_opportunity_event() -> None: assert event.gross_pct == pytest.approx(4.0) assert event.net_pct == pytest.approx(4.0) assert event.est_profit == pytest.approx(20.0) + + +def test_incremental_detector_estimated_profit_scales_with_capital() -> None: + cycle = TriangularCycle( + currencies=("USD", "BTC", "ETH"), + pairs=("BTC/USD", "ETH/BTC", "ETH/USD"), + ) + detector = IncrementalCycleDetector( + CurrencyGraph.index_cycles_by_pair([cycle]), + min_profit_threshold=0.0, + ) + + books = { + "BTC/USD": _make_book(bid=99.9, ask=100.0), + "ETH/BTC": _make_book(bid=0.049, ask=0.05), + "ETH/USD": _make_book(bid=5.20, ask=5.21), + } + + opportunities = detector.opportunities_for_updated_pair( + "ETH/BTC", + books, + base_capital=250.0, + ) + + assert len(opportunities) == 1 + assert opportunities[0].est_profit == pytest.approx(10.0) diff --git a/tests/unit/test_loss_limits.py b/tests/unit/test_loss_limits.py new file mode 100644 index 0000000..e7b5dc3 --- /dev/null +++ b/tests/unit/test_loss_limits.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest + +from arbitrade.risk.loss_limits import LossLimitGuard + + +def test_loss_limit_guard_tracks_daily_and_cumulative_pnl() -> None: + guard = LossLimitGuard(daily_loss_limit=100.0, cumulative_loss_limit=200.0) + t0 = datetime.now(UTC) + + guard.register_realized_pnl(-40.0, at=t0) + guard.register_realized_pnl(10.0, at=t0) + + assert guard.cumulative_pnl == -30.0 + assert guard.daily_pnl(t0.date()) == -30.0 + assert not guard.is_halted + + +def test_loss_limit_guard_halts_on_daily_limit() -> None: + guard = LossLimitGuard(daily_loss_limit=50.0) + t0 = datetime.now(UTC) + + guard.register_realized_pnl(-30.0, at=t0) + guard.register_realized_pnl(-25.0, at=t0) + + assert guard.is_halted + assert guard.halted_reason == "daily_loss_limit_breached" + + +def test_loss_limit_guard_halts_on_cumulative_limit_across_days() -> None: + guard = LossLimitGuard(cumulative_loss_limit=60.0) + t0 = datetime.now(UTC) + + guard.register_realized_pnl(-40.0, at=t0) + guard.register_realized_pnl(-25.0, at=t0 + timedelta(days=1)) + + assert guard.is_halted + assert guard.halted_reason == "cumulative_loss_limit_breached" + + +def test_loss_limit_guard_rejects_invalid_limits() -> None: + with pytest.raises(ValueError, match="daily_loss_limit"): + LossLimitGuard(daily_loss_limit=0.0) + + with pytest.raises(ValueError, match="cumulative_loss_limit"): + LossLimitGuard(cumulative_loss_limit=-1.0) diff --git a/tests/unit/test_market_data_feed.py b/tests/unit/test_market_data_feed.py index 43198c5..ae8c25b 100644 --- a/tests/unit/test_market_data_feed.py +++ b/tests/unit/test_market_data_feed.py @@ -8,7 +8,10 @@ import pytest from arbitrade.detection.engine import OpportunityEvent from arbitrade.exchange.models import BookDelta, BookLevel -from arbitrade.market_data.feed import MarketDataFeed +from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed +from arbitrade.risk.loss_limits import LossLimitGuard +from arbitrade.risk.pre_trade import PreTradeValidator +from arbitrade.risk.trade_limits import TradeLimitsGuard @dataclass(slots=True) @@ -41,20 +44,47 @@ class _FakeOpportunityWriter: class _FakeDetector: def __init__(self, event: OpportunityEvent) -> None: self._event = event + self.last_base_capital: float | None = None - def opportunities_for_updated_pair(self, _updated_pair: str, _books: dict[str, object]): + def opportunities_for_updated_pair( + self, + _updated_pair: str, + _books: dict[str, object], + *, + base_capital: float, + ): + self.last_base_capital = base_capital return [self._event] class _FakeExecutor: def __init__(self) -> None: self.calls: list[OpportunityEvent] = [] + self.realized_pnls: list[float | None] = [] + self.outcomes: list[ExecutionOutcome] = [] - async def execute(self, event: OpportunityEvent) -> None: + async def execute(self, event: OpportunityEvent) -> ExecutionOutcome | float | None: self.calls.append(event) + if self.outcomes: + return self.outcomes.pop(0) + if not self.realized_pnls: + return None + return self.realized_pnls.pop(0) -def _sample_event() -> OpportunityEvent: +@dataclass(slots=True) +class _FakeWsClientTwoMessages: + delta: BookDelta + + async def connect_stream(self): + yield SimpleNamespace(payload={"channel": "book", "seq": 1}) + yield SimpleNamespace(payload={"channel": "book", "seq": 2}) + + def parse_book_delta(self, _payload: dict[str, object]) -> BookDelta: + return self.delta + + +def _sample_event(*, allocated_capital: float = 1.0) -> OpportunityEvent: return OpportunityEvent( detected_at=datetime.now(UTC), cycle="USD->BTC->ETH->USD", @@ -64,6 +94,7 @@ def _sample_event() -> OpportunityEvent: gross_pct=4.0, net_pct=3.0, est_profit=0.03, + allocated_capital=allocated_capital, ) @@ -110,3 +141,184 @@ async def test_market_data_feed_live_mode_executes_orders() -> None: assert len(executor.calls) == 1 assert executor.calls[0].cycle == "USD->BTC->ETH->USD" + + +@pytest.mark.asyncio +async def test_market_data_feed_enforces_per_trade_capital_limit() -> None: + event = _sample_event() + detector = _FakeDetector(event) + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=True, + trade_capital=250.0, + max_trade_capital=100.0, + ) + + await feed.run() + + assert detector.last_base_capital == 100.0 + + +@pytest.mark.asyncio +async def test_market_data_feed_auto_halts_on_daily_loss_limit() -> None: + event = _sample_event() + detector = _FakeDetector(event) + executor = _FakeExecutor() + executor.realized_pnls = [-60.0, -10.0] + loss_guard = LossLimitGuard(daily_loss_limit=50.0) + feed = MarketDataFeed( + ws_client=_FakeWsClientTwoMessages(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + loss_limit_guard=loss_guard, + ) + + await feed.run() + + assert len(executor.calls) == 1 + assert loss_guard.is_halted + assert loss_guard.halted_reason == "daily_loss_limit_breached" + + +@pytest.mark.asyncio +async def test_market_data_feed_auto_halts_on_cumulative_loss_limit() -> None: + event = _sample_event() + detector = _FakeDetector(event) + executor = _FakeExecutor() + executor.realized_pnls = [-40.0, -15.0] + loss_guard = LossLimitGuard(cumulative_loss_limit=50.0) + feed = MarketDataFeed( + ws_client=_FakeWsClientTwoMessages(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + loss_limit_guard=loss_guard, + ) + + await feed.run() + + assert len(executor.calls) == 2 + assert loss_guard.is_halted + assert loss_guard.halted_reason == "cumulative_loss_limit_breached" + + +@pytest.mark.asyncio +async def test_market_data_feed_enforces_max_concurrent_trades() -> None: + event = _sample_event() + detector = _FakeDetector(event) + executor = _FakeExecutor() + executor.outcomes = [ExecutionOutcome( + realized_pnl=None, close_trade=False)] + trade_guard = TradeLimitsGuard(max_concurrent_trades=1) + feed = MarketDataFeed( + ws_client=_FakeWsClientTwoMessages(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + trade_limits_guard=trade_guard, + ) + + await feed.run() + + assert len(executor.calls) == 1 + assert trade_guard.active_trades == 1 + + +@pytest.mark.asyncio +async def test_market_data_feed_enforces_per_asset_exposure_cap() -> None: + event = _sample_event(allocated_capital=100.0) + detector = _FakeDetector(event) + executor = _FakeExecutor() + trade_guard = TradeLimitsGuard(max_exposure_per_asset=50.0) + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + trade_limits_guard=trade_guard, + ) + + await feed.run() + + assert len(executor.calls) == 0 + + +@pytest.mark.asyncio +async def test_market_data_feed_blocks_when_pre_trade_balance_insufficient() -> None: + event = _sample_event(allocated_capital=100.0) + detector = _FakeDetector(event) + executor = _FakeExecutor() + validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0}) + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + pre_trade_validator=validator, + balance_provider=lambda: {"USD": 25.0}, + quote_balance_asset="USD", + ) + + await feed.run() + + assert len(executor.calls) == 0 + + +@pytest.mark.asyncio +async def test_market_data_feed_blocks_when_pre_trade_min_order_not_met() -> None: + event = _sample_event(allocated_capital=25.0) + detector = _FakeDetector(event) + executor = _FakeExecutor() + validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0}) + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + pre_trade_validator=validator, + balance_provider=lambda: {"USD": 500.0}, + quote_balance_asset="USD", + ) + + await feed.run() + + assert len(executor.calls) == 0 + + +@pytest.mark.asyncio +async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> None: + event = _sample_event(allocated_capital=75.0) + detector = _FakeDetector(event) + executor = _FakeExecutor() + validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0}) + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + pre_trade_validator=validator, + balance_provider=lambda: {"USD": 500.0}, + quote_balance_asset="USD", + ) + + await feed.run() + + assert len(executor.calls) == 1 diff --git a/tests/unit/test_pre_trade.py b/tests/unit/test_pre_trade.py new file mode 100644 index 0000000..faf16c7 --- /dev/null +++ b/tests/unit/test_pre_trade.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import pytest + +from arbitrade.risk.pre_trade import PreTradeValidator + + +def test_pre_trade_validator_accepts_when_balance_and_min_size_pass() -> None: + validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0}) + + assert validator.validate( + balances_by_asset={"USD": 100.0}, + required_by_asset={"USD": 75.0}, + ) + + +def test_pre_trade_validator_rejects_when_balance_insufficient() -> None: + validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0}) + + assert not validator.validate( + balances_by_asset={"USD": 40.0}, + required_by_asset={"USD": 75.0}, + ) + + +def test_pre_trade_validator_rejects_when_below_min_size() -> None: + validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0}) + + assert not validator.validate( + balances_by_asset={"USD": 100.0}, + required_by_asset={"USD": 30.0}, + ) + + +def test_pre_trade_validator_rejects_invalid_min_order_size_config() -> None: + with pytest.raises(ValueError, match="minimum order size"): + PreTradeValidator(min_order_size_by_asset={"USD": 0.0}) diff --git a/tests/unit/test_trade_limits.py b/tests/unit/test_trade_limits.py new file mode 100644 index 0000000..eab8366 --- /dev/null +++ b/tests/unit/test_trade_limits.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import pytest + +from arbitrade.risk.trade_limits import TradeLimitsGuard + + +def test_trade_limits_guard_blocks_when_max_concurrent_reached() -> None: + guard = TradeLimitsGuard(max_concurrent_trades=1) + + guard.open_trade({"BTC": 10.0}) + + assert not guard.is_trade_allowed({"BTC": 1.0}) + + +def test_trade_limits_guard_blocks_when_asset_exposure_would_breach_cap() -> None: + guard = TradeLimitsGuard(max_exposure_per_asset=100.0) + + guard.open_trade({"BTC": 80.0}) + + assert not guard.is_trade_allowed({"BTC": 25.0}) + assert guard.is_trade_allowed({"ETH": 25.0}) + + +def test_trade_limits_guard_releases_exposure_on_close() -> None: + guard = TradeLimitsGuard(max_concurrent_trades=2, max_exposure_per_asset=100.0) + + guard.open_trade({"BTC": 80.0}) + guard.close_trade({"BTC": 80.0}) + + assert guard.active_trades == 0 + assert guard.exposure_for_asset("BTC") == 0.0 + assert guard.is_trade_allowed({"BTC": 100.0}) + + +def test_trade_limits_guard_rejects_invalid_configuration() -> None: + with pytest.raises(ValueError, match="max_concurrent_trades"): + TradeLimitsGuard(max_concurrent_trades=0) + + with pytest.raises(ValueError, match="max_exposure_per_asset"): + TradeLimitsGuard(max_exposure_per_asset=0.0)