diff --git a/.env.example b/.env.example index 4056b1f..042c9e5 100644 --- a/.env.example +++ b/.env.example @@ -22,5 +22,9 @@ MAX_CONCURRENT_TRADES= MAX_EXPOSURE_PER_ASSET_USD= QUOTE_BALANCE_ASSET=USD MIN_ORDER_SIZE_USD= +KILL_SWITCH_ACTIVE=false DAILY_LOSS_LIMIT_USD=5.0 CUMULATIVE_LOSS_LIMIT_USD=10.0 +MAX_SOURCE_LATENCY_MS= +MAX_APPLY_LATENCY_MS= +MAX_CONSECUTIVE_FAILURES= diff --git a/src/arbitrade/config/settings.py b/src/arbitrade/config/settings.py index b1938f0..5bd2801 100644 --- a/src/arbitrade/config/settings.py +++ b/src/arbitrade/config/settings.py @@ -22,48 +22,38 @@ class Settings(BaseSettings): log_level: str = Field(default="INFO", alias="LOG_LEVEL") log_json: bool = Field(default=True, alias="LOG_JSON") - duckdb_path: Path = Field(default=Path( - "./data/arbitrade.duckdb"), alias="DUCKDB_PATH") + duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH") - kraken_rest_url: str = Field( - default="https://api.kraken.com", alias="KRAKEN_REST_URL") - kraken_ws_url: str = Field( - default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") + kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL") + kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") kraken_private_rate_limit_seconds: float = Field( default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" ) - kraken_http_timeout_seconds: float = Field( - default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") - kraken_retry_attempts: int = Field( - default=3, alias="KRAKEN_RETRY_ATTEMPTS") + kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") + kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS") kraken_retry_base_delay_seconds: float = Field( default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" ) kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") - kraken_api_secret: str | None = Field( - default=None, alias="KRAKEN_API_SECRET") - ws_heartbeat_timeout_seconds: float = Field( - default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") - ws_max_staleness_seconds: float = Field( - default=5.0, alias="WS_MAX_STALENESS_SECONDS") + kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET") + ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") + ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS") paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD") - max_trade_capital_usd: float = Field( - default=100.0, alias="MAX_TRADE_CAPITAL_USD") - max_concurrent_trades: int | None = Field( - default=None, alias="MAX_CONCURRENT_TRADES") + max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD") + max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES") max_exposure_per_asset_usd: float | None = Field( default=None, alias="MAX_EXPOSURE_PER_ASSET_USD", ) - quote_balance_asset: str = Field( - default="USD", alias="QUOTE_BALANCE_ASSET") - min_order_size_usd: float | None = Field( - default=None, alias="MIN_ORDER_SIZE_USD") - daily_loss_limit_usd: float | None = Field( - default=None, alias="DAILY_LOSS_LIMIT_USD") - cumulative_loss_limit_usd: float | None = Field( - default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") + quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET") + min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD") + kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE") + daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD") + cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") + max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS") + max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS") + max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES") fernet_key: str | None = Field(default=None, alias="FERNET_KEY") diff --git a/src/arbitrade/market_data/feed.py b/src/arbitrade/market_data/feed.py index 3192b67..601bd94 100644 --- a/src/arbitrade/market_data/feed.py +++ b/src/arbitrade/market_data/feed.py @@ -10,8 +10,10 @@ import structlog from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent from arbitrade.exchange.kraken_ws import KrakenWsClient from arbitrade.market_data.order_book import OrderBook +from arbitrade.risk.kill_switch import KillSwitch from arbitrade.risk.loss_limits import LossLimitGuard from arbitrade.risk.pre_trade import PreTradeValidator +from arbitrade.risk.stop_conditions import StopConditionsGuard from arbitrade.risk.trade_limits import TradeLimitsGuard from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot from arbitrade.storage.opportunities import AsyncOpportunityWriter @@ -34,8 +36,7 @@ class MarketDataFeed: opportunity_writer: AsyncOpportunityWriter | None = None, paper_trading_mode: bool = True, opportunity_executor: ( - Callable[[OpportunityEvent], - Awaitable[ExecutionOutcome | float | None]] | None + Callable[[OpportunityEvent], Awaitable[ExecutionOutcome | float | None]] | None ) = None, trade_capital: float = 1.0, max_trade_capital: float | None = None, @@ -44,6 +45,8 @@ class MarketDataFeed: pre_trade_validator: PreTradeValidator | None = None, balance_provider: Callable[[], Mapping[str, float]] | None = None, quote_balance_asset: str = "USD", + kill_switch: KillSwitch | None = None, + stop_conditions_guard: StopConditionsGuard | None = None, ) -> None: self._ws_client = ws_client self._snapshot_writer = snapshot_writer @@ -59,6 +62,8 @@ class MarketDataFeed: self._pre_trade_validator = pre_trade_validator self._balance_provider = balance_provider self._quote_balance_asset = quote_balance_asset.upper() + self._kill_switch = kill_switch + self._stop_conditions_guard = stop_conditions_guard if self._trade_capital <= 0.0: raise ValueError("trade_capital must be > 0.0") @@ -81,8 +86,7 @@ class MarketDataFeed: return {} start = currencies[0] - exposure_assets = { - currency for currency in currencies[1:] if currency != start} + exposure_assets = {currency for currency in currencies[1:] if currency != start} return {asset: event.allocated_capital for asset in exposure_assets} async def run(self) -> None: @@ -117,6 +121,23 @@ class MarketDataFeed: source_latency_ms=source_latency_ms, ) + if self._stop_conditions_guard is not None: + self._stop_conditions_guard.observe_latency( + source_latency_ms=source_latency_ms, + apply_latency_ms=apply_latency_ms, + ) + if self._stop_conditions_guard.is_halted: + if self._kill_switch is not None and not self._kill_switch.is_active: + self._kill_switch.activate( + reason=self._stop_conditions_guard.halted_reason + or "stop_conditions_halted", + ) + _LOG.warning( + "stop_condition_halt_triggered", + reason=self._stop_conditions_guard.halted_reason, + symbol=delta.symbol, + ) + if self._detector is not None: opportunities = self._detector.opportunities_for_updated_pair( delta.symbol, @@ -160,6 +181,27 @@ class MarketDataFeed: ) continue + if self._kill_switch is not None and self._kill_switch.is_active: + _LOG.warning( + "live_trade_skipped_kill_switch", + cycle=event.cycle, + updated_pair=event.updated_pair, + reason=self._kill_switch.reason, + ) + continue + + if ( + self._stop_conditions_guard is not None + and self._stop_conditions_guard.is_halted + ): + _LOG.warning( + "live_trade_skipped_stop_condition_halt", + cycle=event.cycle, + updated_pair=event.updated_pair, + reason=self._stop_conditions_guard.halted_reason, + ) + continue + if self._loss_limit_guard is not None and self._loss_limit_guard.is_halted: _LOG.warning( "live_trade_skipped_loss_limit_halted", @@ -170,8 +212,7 @@ class MarketDataFeed: continue if self._pre_trade_validator is not None and self._balance_provider is not None: - required_balances = { - self._quote_balance_asset: event.allocated_capital} + required_balances = {self._quote_balance_asset: event.allocated_capital} balances = { asset.upper(): amount for asset, amount in self._balance_provider().items() @@ -204,7 +245,40 @@ class MarketDataFeed: if self._trade_limits_guard is not None: self._trade_limits_guard.open_trade(exposure_by_asset) - outcome = await self._opportunity_executor(event) + try: + outcome = await self._opportunity_executor(event) + except Exception: + if self._trade_limits_guard is not None: + self._trade_limits_guard.close_trade(exposure_by_asset) + + if self._stop_conditions_guard is not None: + self._stop_conditions_guard.register_failure() + if self._stop_conditions_guard.is_halted: + if ( + self._kill_switch is not None + and not self._kill_switch.is_active + ): + self._kill_switch.activate( + reason=self._stop_conditions_guard.halted_reason + or "stop_conditions_halted", + ) + _LOG.warning( + "stop_condition_halt_triggered", + reason=self._stop_conditions_guard.halted_reason, + cycle=event.cycle, + updated_pair=event.updated_pair, + ) + + _LOG.exception( + "live_trade_execution_failed", + cycle=event.cycle, + updated_pair=event.updated_pair, + ) + continue + + if self._stop_conditions_guard is not None: + self._stop_conditions_guard.register_success() + realized_pnl: float | None close_trade = True if isinstance(outcome, ExecutionOutcome): @@ -214,8 +288,7 @@ class MarketDataFeed: realized_pnl = outcome if realized_pnl is not None and self._loss_limit_guard is not None: - self._loss_limit_guard.register_realized_pnl( - realized_pnl) + self._loss_limit_guard.register_realized_pnl(realized_pnl) if self._loss_limit_guard.is_halted: _LOG.warning( "loss_limit_halt_triggered", diff --git a/src/arbitrade/risk/__init__.py b/src/arbitrade/risk/__init__.py index 512d0db..f90c40f 100644 --- a/src/arbitrade/risk/__init__.py +++ b/src/arbitrade/risk/__init__.py @@ -1,7 +1,15 @@ """Risk management helpers.""" +from arbitrade.risk.kill_switch import KillSwitch from arbitrade.risk.loss_limits import LossLimitGuard from arbitrade.risk.pre_trade import PreTradeValidator +from arbitrade.risk.stop_conditions import StopConditionsGuard from arbitrade.risk.trade_limits import TradeLimitsGuard -__all__ = ["LossLimitGuard", "TradeLimitsGuard", "PreTradeValidator"] +__all__ = [ + "LossLimitGuard", + "TradeLimitsGuard", + "PreTradeValidator", + "KillSwitch", + "StopConditionsGuard", +] diff --git a/src/arbitrade/risk/kill_switch.py b/src/arbitrade/risk/kill_switch.py new file mode 100644 index 0000000..3b3182c --- /dev/null +++ b/src/arbitrade/risk/kill_switch.py @@ -0,0 +1,23 @@ +from __future__ import annotations + + +class KillSwitch: + def __init__(self, *, active: bool = False, reason: str | None = None) -> None: + self._active = active + self._reason = reason or ("manual" if active else None) + + @property + def is_active(self) -> bool: + return self._active + + @property + def reason(self) -> str | None: + return self._reason + + def activate(self, *, reason: str = "manual") -> None: + self._active = True + self._reason = reason + + def deactivate(self) -> None: + self._active = False + self._reason = None diff --git a/src/arbitrade/risk/stop_conditions.py b/src/arbitrade/risk/stop_conditions.py new file mode 100644 index 0000000..703be77 --- /dev/null +++ b/src/arbitrade/risk/stop_conditions.py @@ -0,0 +1,72 @@ +from __future__ import annotations + + +class StopConditionsGuard: + def __init__( + self, + *, + max_source_latency_ms: float | None = None, + max_apply_latency_ms: float | None = None, + max_consecutive_failures: int | None = None, + ) -> None: + if max_source_latency_ms is not None and max_source_latency_ms <= 0.0: + raise ValueError("max_source_latency_ms must be > 0.0") + if max_apply_latency_ms is not None and max_apply_latency_ms <= 0.0: + raise ValueError("max_apply_latency_ms must be > 0.0") + if max_consecutive_failures is not None and max_consecutive_failures <= 0: + raise ValueError("max_consecutive_failures must be > 0") + + self._max_source_latency_ms = max_source_latency_ms + self._max_apply_latency_ms = max_apply_latency_ms + self._max_consecutive_failures = max_consecutive_failures + + self._consecutive_failures = 0 + self._halted_reason: str | None = None + + @property + def halted_reason(self) -> str | None: + return self._halted_reason + + @property + def is_halted(self) -> bool: + return self._halted_reason is not None + + @property + def consecutive_failures(self) -> int: + return self._consecutive_failures + + def observe_latency( + self, + *, + source_latency_ms: float | None, + apply_latency_ms: float, + ) -> None: + if self.is_halted: + return + + if ( + self._max_source_latency_ms is not None + and source_latency_ms is not None + and source_latency_ms > self._max_source_latency_ms + ): + self._halted_reason = "source_latency_limit_breached" + return + + if self._max_apply_latency_ms is not None and apply_latency_ms > self._max_apply_latency_ms: + self._halted_reason = "apply_latency_limit_breached" + + def register_failure(self) -> None: + if self.is_halted: + return + + self._consecutive_failures += 1 + if ( + self._max_consecutive_failures is not None + and self._consecutive_failures >= self._max_consecutive_failures + ): + self._halted_reason = "consecutive_failures_limit_breached" + + def register_success(self) -> None: + if self.is_halted: + return + self._consecutive_failures = 0 diff --git a/tests/unit/test_kill_switch.py b/tests/unit/test_kill_switch.py new file mode 100644 index 0000000..c00ba50 --- /dev/null +++ b/tests/unit/test_kill_switch.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from arbitrade.risk.kill_switch import KillSwitch + + +def test_kill_switch_can_activate_and_deactivate() -> None: + kill_switch = KillSwitch() + + assert not kill_switch.is_active + assert kill_switch.reason is None + + kill_switch.activate(reason="manual") + + assert kill_switch.is_active + assert kill_switch.reason == "manual" + + kill_switch.deactivate() + + assert not kill_switch.is_active + assert kill_switch.reason is None + + +def test_kill_switch_active_on_init_sets_reason() -> None: + kill_switch = KillSwitch(active=True) + + assert kill_switch.is_active + assert kill_switch.reason == "manual" diff --git a/tests/unit/test_market_data_feed.py b/tests/unit/test_market_data_feed.py index ae8c25b..4b42ecd 100644 --- a/tests/unit/test_market_data_feed.py +++ b/tests/unit/test_market_data_feed.py @@ -9,8 +9,10 @@ import pytest from arbitrade.detection.engine import OpportunityEvent from arbitrade.exchange.models import BookDelta, BookLevel from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed +from arbitrade.risk.kill_switch import KillSwitch from arbitrade.risk.loss_limits import LossLimitGuard from arbitrade.risk.pre_trade import PreTradeValidator +from arbitrade.risk.stop_conditions import StopConditionsGuard from arbitrade.risk.trade_limits import TradeLimitsGuard @@ -72,6 +74,15 @@ class _FakeExecutor: return self.realized_pnls.pop(0) +class _FakeFailingExecutor: + def __init__(self) -> None: + self.calls: int = 0 + + async def execute(self, _event: OpportunityEvent) -> None: + self.calls += 1 + raise RuntimeError("executor failure") + + @dataclass(slots=True) class _FakeWsClientTwoMessages: delta: BookDelta @@ -215,8 +226,7 @@ async def test_market_data_feed_enforces_max_concurrent_trades() -> None: event = _sample_event() detector = _FakeDetector(event) executor = _FakeExecutor() - executor.outcomes = [ExecutionOutcome( - realized_pnl=None, close_trade=False)] + executor.outcomes = [ExecutionOutcome(realized_pnl=None, close_trade=False)] trade_guard = TradeLimitsGuard(max_concurrent_trades=1) feed = MarketDataFeed( ws_client=_FakeWsClientTwoMessages(_sample_delta()), @@ -322,3 +332,101 @@ async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> Non await feed.run() assert len(executor.calls) == 1 + + +@pytest.mark.asyncio +async def test_market_data_feed_blocks_when_kill_switch_active() -> None: + event = _sample_event(allocated_capital=75.0) + detector = _FakeDetector(event) + executor = _FakeExecutor() + kill_switch = KillSwitch(active=True, reason="manual") + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + kill_switch=kill_switch, + ) + + await feed.run() + + assert len(executor.calls) == 0 + + +@pytest.mark.asyncio +async def test_market_data_feed_allows_when_kill_switch_inactive() -> None: + event = _sample_event(allocated_capital=75.0) + detector = _FakeDetector(event) + executor = _FakeExecutor() + kill_switch = KillSwitch(active=False) + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + kill_switch=kill_switch, + ) + + await feed.run() + + assert len(executor.calls) == 1 + + +@pytest.mark.asyncio +async def test_market_data_feed_halts_on_abnormal_source_latency() -> None: + event = _sample_event(allocated_capital=75.0) + detector = _FakeDetector(event) + executor = _FakeExecutor() + kill_switch = KillSwitch(active=False) + stop_guard = StopConditionsGuard(max_source_latency_ms=1.0) + delta = _sample_delta() + delta.source_timestamp_ms = 0 + feed = MarketDataFeed( + ws_client=_FakeWsClient(delta), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + kill_switch=kill_switch, + stop_conditions_guard=stop_guard, + ) + + await feed.run() + + assert stop_guard.is_halted + assert stop_guard.halted_reason == "source_latency_limit_breached" + assert kill_switch.is_active + assert kill_switch.reason == "source_latency_limit_breached" + assert len(executor.calls) == 0 + + +@pytest.mark.asyncio +async def test_market_data_feed_halts_on_repeated_execution_failures() -> None: + event = _sample_event(allocated_capital=75.0) + detector = _FakeDetector(event) + executor = _FakeFailingExecutor() + kill_switch = KillSwitch(active=False) + stop_guard = StopConditionsGuard(max_consecutive_failures=2) + feed = MarketDataFeed( + ws_client=_FakeWsClientTwoMessages(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + kill_switch=kill_switch, + stop_conditions_guard=stop_guard, + ) + + await feed.run() + + assert executor.calls == 2 + assert stop_guard.is_halted + assert stop_guard.halted_reason == "consecutive_failures_limit_breached" + assert kill_switch.is_active + assert kill_switch.reason == "consecutive_failures_limit_breached" diff --git a/tests/unit/test_stop_conditions.py b/tests/unit/test_stop_conditions.py new file mode 100644 index 0000000..2dd6b35 --- /dev/null +++ b/tests/unit/test_stop_conditions.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest + +from arbitrade.risk.stop_conditions import StopConditionsGuard + + +def test_stop_conditions_guard_halts_on_source_latency_breach() -> None: + guard = StopConditionsGuard(max_source_latency_ms=50.0) + + guard.observe_latency(source_latency_ms=75.0, apply_latency_ms=1.0) + + assert guard.is_halted + assert guard.halted_reason == "source_latency_limit_breached" + + +def test_stop_conditions_guard_halts_on_apply_latency_breach() -> None: + guard = StopConditionsGuard(max_apply_latency_ms=2.0) + + guard.observe_latency(source_latency_ms=None, apply_latency_ms=3.5) + + assert guard.is_halted + assert guard.halted_reason == "apply_latency_limit_breached" + + +def test_stop_conditions_guard_halts_on_consecutive_failures() -> None: + guard = StopConditionsGuard(max_consecutive_failures=2) + + guard.register_failure() + assert not guard.is_halted + + guard.register_failure() + + assert guard.is_halted + assert guard.halted_reason == "consecutive_failures_limit_breached" + + +def test_stop_conditions_guard_resets_failures_after_success() -> None: + guard = StopConditionsGuard(max_consecutive_failures=3) + + guard.register_failure() + guard.register_success() + guard.register_failure() + + assert guard.consecutive_failures == 1 + assert not guard.is_halted + + +def test_stop_conditions_guard_rejects_invalid_configuration() -> None: + with pytest.raises(ValueError, match="max_source_latency_ms"): + StopConditionsGuard(max_source_latency_ms=0.0) + + with pytest.raises(ValueError, match="max_apply_latency_ms"): + StopConditionsGuard(max_apply_latency_ms=-1.0) + + with pytest.raises(ValueError, match="max_consecutive_failures"): + StopConditionsGuard(max_consecutive_failures=0)