"""Tests for how ``MarketDataFeed`` gates opportunity execution, using in-memory fakes."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace

import pytest

from arbitrade.detection.engine import OpportunityEvent
from arbitrade.exchange.models import BookDelta, BookLevel
from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.trade_limits import TradeLimitsGuard


@dataclass(slots=True)
class _FakeWsClient:
    """Websocket stand-in that streams a single message and parses it to ``delta``."""

    # Book delta handed back for every payload, regardless of its content.
    delta: BookDelta

    async def connect_stream(self):
        """Yield exactly one message object carrying a ``payload`` attribute."""
        yield SimpleNamespace(payload={"channel": "book"})

    def parse_book_delta(self, _payload: dict[str, object]) -> BookDelta:
        """Ignore the payload and return the preconfigured delta."""
        return self.delta


class _FakeSnapshotWriter:
    """Snapshot sink that records everything enqueued, in order."""

    def __init__(self) -> None:
        self.items: list[object] = []

    async def enqueue(self, snapshot: object) -> None:
        """Append ``snapshot`` to ``items``."""
        self.items.append(snapshot)


class _FakeOpportunityWriter:
    """Opportunity sink that records every enqueued event, in order."""

    def __init__(self) -> None:
        self.items: list[OpportunityEvent] = []

    async def enqueue(self, event: OpportunityEvent) -> None:
        """Append ``event`` to ``items``."""
        self.items.append(event)


class _FakeDetector:
    """Detector stub that always reports the same single opportunity."""

    def __init__(self, event: OpportunityEvent) -> None:
        self.last_base_capital: float | None = None
        self._event = event

    def opportunities_for_updated_pair(
        self,
        _updated_pair: str,
        _books: dict[str, object],
        *,
        base_capital: float,
    ) -> list[OpportunityEvent]:
        """Remember the ``base_capital`` passed in and return the canned event."""
        self.last_base_capital = base_capital
        return [self._event]


class _FakeExecutor:
    """Executor stub that logs calls and replays scripted results.

    Results are taken from ``outcomes`` first, then from ``realized_pnls``;
    when both are empty ``execute`` returns ``None``.
    """

    def __init__(self) -> None:
        self.calls: list[OpportunityEvent] = []
        self.realized_pnls: list[float | None] = []
        self.outcomes: list[ExecutionOutcome] = []

    async def execute(self, event: OpportunityEvent) -> ExecutionOutcome | float | None:
        """Record ``event`` and return the next scripted result, if any."""
        self.calls.append(event)
        # Priority order matters: scripted outcomes win over bare PnL values.
        for scripted in (self.outcomes, self.realized_pnls):
            if scripted:
                return scripted.pop(0)
        return None


@dataclass(slots=True)
class _FakeWsClientTwoMessages:
    """Websocket stand-in that streams two messages (``seq`` 1 then 2)."""

    # Book delta handed back for every payload, regardless of its content.
    delta: BookDelta

    async def connect_stream(self):
        """Yield two message objects whose payloads differ only in ``seq``."""
        for seq in (1, 2):
            yield SimpleNamespace(payload={"channel": "book", "seq": seq})

    def parse_book_delta(self, _payload: dict[str, object]) -> BookDelta:
        """Ignore the payload and return the preconfigured delta."""
        return self.delta


def _sample_event(*, allocated_capital: float = 1.0) -> OpportunityEvent:
    """Build a fixed ``USD->BTC->ETH->USD`` opportunity stamped with the current UTC time."""
    detected_at = datetime.now(UTC)
    return OpportunityEvent(
        detected_at=detected_at,
        cycle="USD->BTC->ETH->USD",
        updated_pair="BTC/USD",
        gross_rate=1.04,
        net_rate=1.03,
        gross_pct=4.0,
        net_pct=3.0,
        est_profit=0.03,
        allocated_capital=allocated_capital,
    )


def _sample_delta() -> BookDelta:
    """Build a one-level ``BTC/USD`` book delta (bid 100.0, ask 100.5, volume 1.0 each)."""
    bid_levels = [BookLevel(price=100.0, volume=1.0)]
    ask_levels = [BookLevel(price=100.5, volume=1.0)]
    return BookDelta(symbol="BTC/USD", bids=bid_levels, asks=ask_levels)


@pytest.mark.asyncio
async def test_market_data_feed_dry_run_does_not_execute_orders() -> None:
    """With ``paper_trading_mode=True`` the executor callback is never invoked."""
    recorder = _FakeExecutor()
    feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event()),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=True,
        opportunity_executor=recorder.execute,
    )

    await feed.run()

    assert recorder.calls == []


@pytest.mark.asyncio
async def test_market_data_feed_live_mode_executes_orders() -> None:
    """With ``paper_trading_mode=False`` the detected opportunity is executed once."""
    recorder = _FakeExecutor()
    feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event()),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
    )

    await feed.run()

    # Exactly one call, and it is for the sample cycle.
    assert [call.cycle for call in recorder.calls] == ["USD->BTC->ETH->USD"]


@pytest.mark.asyncio
async def test_market_data_feed_enforces_per_trade_capital_limit() -> None:
    """The detector is given 100.0 when ``trade_capital`` is 250.0 and the max is 100.0."""
    capital_probe = _FakeDetector(_sample_event())
    feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=capital_probe,
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=True,
        trade_capital=250.0,
        max_trade_capital=100.0,
    )

    await feed.run()

    assert capital_probe.last_base_capital == 100.0
@pytest.mark.asyncio
async def test_market_data_feed_auto_halts_on_daily_loss_limit() -> None:
    """A -60.0 result against a 50.0 daily limit halts the guard after one execution."""
    recorder = _FakeExecutor()
    recorder.realized_pnls = [-60.0, -10.0]
    guard = LossLimitGuard(daily_loss_limit=50.0)
    feed = MarketDataFeed(
        ws_client=_FakeWsClientTwoMessages(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event()),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
        loss_limit_guard=guard,
    )

    await feed.run()

    # Two messages were streamed, but only the first may reach the executor.
    assert len(recorder.calls) == 1
    assert guard.is_halted
    assert guard.halted_reason == "daily_loss_limit_breached"


@pytest.mark.asyncio
async def test_market_data_feed_auto_halts_on_cumulative_loss_limit() -> None:
    """Results of -40.0 then -15.0 against a 50.0 cumulative limit halt after two executions."""
    recorder = _FakeExecutor()
    recorder.realized_pnls = [-40.0, -15.0]
    guard = LossLimitGuard(cumulative_loss_limit=50.0)
    feed = MarketDataFeed(
        ws_client=_FakeWsClientTwoMessages(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event()),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
        loss_limit_guard=guard,
    )

    await feed.run()

    # Both messages execute; the halt is expected only once both results are in.
    assert len(recorder.calls) == 2
    assert guard.is_halted
    assert guard.halted_reason == "cumulative_loss_limit_breached"


@pytest.mark.asyncio
async def test_market_data_feed_enforces_max_concurrent_trades() -> None:
    """With one trade left open and a limit of one, the second message is not executed."""
    recorder = _FakeExecutor()
    # The scripted outcome reports ``close_trade=False`` for the first execution.
    still_open = ExecutionOutcome(realized_pnl=None, close_trade=False)
    recorder.outcomes = [still_open]
    limits = TradeLimitsGuard(max_concurrent_trades=1)
    feed = MarketDataFeed(
        ws_client=_FakeWsClientTwoMessages(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event()),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
        trade_limits_guard=limits,
    )

    await feed.run()

    assert len(recorder.calls) == 1
    assert limits.active_trades == 1


@pytest.mark.asyncio
async def test_market_data_feed_enforces_per_asset_exposure_cap() -> None:
    """An opportunity allocating 100.0 is not executed under a 50.0 per-asset cap."""
    recorder = _FakeExecutor()
    limits = TradeLimitsGuard(max_exposure_per_asset=50.0)
    feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event(allocated_capital=100.0)),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
        trade_limits_guard=limits,
    )

    await feed.run()

    assert recorder.calls == []


@pytest.mark.asyncio
async def test_market_data_feed_blocks_when_pre_trade_balance_insufficient() -> None:
    """A 100.0 allocation is not executed when the reported USD balance is 25.0."""
    recorder = _FakeExecutor()
    checker = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
    feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event(allocated_capital=100.0)),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
        pre_trade_validator=checker,
        balance_provider=lambda: {"USD": 25.0},
        quote_balance_asset="USD",
    )

    await feed.run()

    assert recorder.calls == []


@pytest.mark.asyncio
async def test_market_data_feed_blocks_when_pre_trade_min_order_not_met() -> None:
    """A 25.0 allocation is not executed when the USD minimum order size is 50.0."""
    recorder = _FakeExecutor()
    checker = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
    feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event(allocated_capital=25.0)),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
        pre_trade_validator=checker,
        balance_provider=lambda: {"USD": 500.0},
        quote_balance_asset="USD",
    )

    await feed.run()

    assert recorder.calls == []


@pytest.mark.asyncio
async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> None:
    """A 75.0 allocation executes once with a 50.0 minimum and a 500.0 USD balance."""
    recorder = _FakeExecutor()
    checker = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
    feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=_FakeDetector(_sample_event(allocated_capital=75.0)),
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=recorder.execute,
        pre_trade_validator=checker,
        balance_provider=lambda: {"USD": 500.0},
        quote_balance_asset="USD",
    )

    await feed.run()

    assert len(recorder.calls) == 1