Add risk management features: implement KillSwitch and StopConditionsGuard; update settings and tests

This commit is contained in:
2026-06-01 11:22:17 +02:00
parent 45e219d103
commit 240a591a64
9 changed files with 402 additions and 40 deletions
+4
View File
@@ -22,5 +22,9 @@ MAX_CONCURRENT_TRADES=
MAX_EXPOSURE_PER_ASSET_USD=
QUOTE_BALANCE_ASSET=USD
MIN_ORDER_SIZE_USD=
KILL_SWITCH_ACTIVE=false
DAILY_LOSS_LIMIT_USD=5.0
CUMULATIVE_LOSS_LIMIT_USD=10.0
MAX_SOURCE_LATENCY_MS=
MAX_APPLY_LATENCY_MS=
MAX_CONSECUTIVE_FAILURES=
+18 -28
View File
@@ -22,48 +22,38 @@ class Settings(BaseSettings):
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
log_json: bool = Field(default=True, alias="LOG_JSON")
duckdb_path: Path = Field(default=Path(
"./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
kraken_rest_url: str = Field(
default="https://api.kraken.com", alias="KRAKEN_REST_URL")
kraken_ws_url: str = Field(
default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL")
kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_private_rate_limit_seconds: float = Field(
default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS"
)
kraken_http_timeout_seconds: float = Field(
default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
kraken_retry_attempts: int = Field(
default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_retry_base_delay_seconds: float = Field(
default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS"
)
kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY")
kraken_api_secret: str | None = Field(
default=None, alias="KRAKEN_API_SECRET")
ws_heartbeat_timeout_seconds: float = Field(
default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
ws_max_staleness_seconds: float = Field(
default=5.0, alias="WS_MAX_STALENESS_SECONDS")
kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET")
ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS")
paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE")
trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD")
max_trade_capital_usd: float = Field(
default=100.0, alias="MAX_TRADE_CAPITAL_USD")
max_concurrent_trades: int | None = Field(
default=None, alias="MAX_CONCURRENT_TRADES")
max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD")
max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES")
max_exposure_per_asset_usd: float | None = Field(
default=None,
alias="MAX_EXPOSURE_PER_ASSET_USD",
)
quote_balance_asset: str = Field(
default="USD", alias="QUOTE_BALANCE_ASSET")
min_order_size_usd: float | None = Field(
default=None, alias="MIN_ORDER_SIZE_USD")
daily_loss_limit_usd: float | None = Field(
default=None, alias="DAILY_LOSS_LIMIT_USD")
cumulative_loss_limit_usd: float | None = Field(
default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET")
min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD")
kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE")
daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD")
cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS")
max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS")
max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES")
fernet_key: str | None = Field(default=None, alias="FERNET_KEY")
+82 -9
View File
@@ -10,8 +10,10 @@ import structlog
from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent
from arbitrade.exchange.kraken_ws import KrakenWsClient
from arbitrade.market_data.order_book import OrderBook
from arbitrade.risk.kill_switch import KillSwitch
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.stop_conditions import StopConditionsGuard
from arbitrade.risk.trade_limits import TradeLimitsGuard
from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot
from arbitrade.storage.opportunities import AsyncOpportunityWriter
@@ -34,8 +36,7 @@ class MarketDataFeed:
opportunity_writer: AsyncOpportunityWriter | None = None,
paper_trading_mode: bool = True,
opportunity_executor: (
Callable[[OpportunityEvent],
Awaitable[ExecutionOutcome | float | None]] | None
Callable[[OpportunityEvent], Awaitable[ExecutionOutcome | float | None]] | None
) = None,
trade_capital: float = 1.0,
max_trade_capital: float | None = None,
@@ -44,6 +45,8 @@ class MarketDataFeed:
pre_trade_validator: PreTradeValidator | None = None,
balance_provider: Callable[[], Mapping[str, float]] | None = None,
quote_balance_asset: str = "USD",
kill_switch: KillSwitch | None = None,
stop_conditions_guard: StopConditionsGuard | None = None,
) -> None:
self._ws_client = ws_client
self._snapshot_writer = snapshot_writer
@@ -59,6 +62,8 @@ class MarketDataFeed:
self._pre_trade_validator = pre_trade_validator
self._balance_provider = balance_provider
self._quote_balance_asset = quote_balance_asset.upper()
self._kill_switch = kill_switch
self._stop_conditions_guard = stop_conditions_guard
if self._trade_capital <= 0.0:
raise ValueError("trade_capital must be > 0.0")
@@ -81,8 +86,7 @@ class MarketDataFeed:
return {}
start = currencies[0]
exposure_assets = {
currency for currency in currencies[1:] if currency != start}
exposure_assets = {currency for currency in currencies[1:] if currency != start}
return {asset: event.allocated_capital for asset in exposure_assets}
async def run(self) -> None:
@@ -117,6 +121,23 @@ class MarketDataFeed:
source_latency_ms=source_latency_ms,
)
if self._stop_conditions_guard is not None:
self._stop_conditions_guard.observe_latency(
source_latency_ms=source_latency_ms,
apply_latency_ms=apply_latency_ms,
)
if self._stop_conditions_guard.is_halted:
if self._kill_switch is not None and not self._kill_switch.is_active:
self._kill_switch.activate(
reason=self._stop_conditions_guard.halted_reason
or "stop_conditions_halted",
)
_LOG.warning(
"stop_condition_halt_triggered",
reason=self._stop_conditions_guard.halted_reason,
symbol=delta.symbol,
)
if self._detector is not None:
opportunities = self._detector.opportunities_for_updated_pair(
delta.symbol,
@@ -160,6 +181,27 @@ class MarketDataFeed:
)
continue
if self._kill_switch is not None and self._kill_switch.is_active:
_LOG.warning(
"live_trade_skipped_kill_switch",
cycle=event.cycle,
updated_pair=event.updated_pair,
reason=self._kill_switch.reason,
)
continue
if (
self._stop_conditions_guard is not None
and self._stop_conditions_guard.is_halted
):
_LOG.warning(
"live_trade_skipped_stop_condition_halt",
cycle=event.cycle,
updated_pair=event.updated_pair,
reason=self._stop_conditions_guard.halted_reason,
)
continue
if self._loss_limit_guard is not None and self._loss_limit_guard.is_halted:
_LOG.warning(
"live_trade_skipped_loss_limit_halted",
@@ -170,8 +212,7 @@ class MarketDataFeed:
continue
if self._pre_trade_validator is not None and self._balance_provider is not None:
required_balances = {
self._quote_balance_asset: event.allocated_capital}
required_balances = {self._quote_balance_asset: event.allocated_capital}
balances = {
asset.upper(): amount
for asset, amount in self._balance_provider().items()
@@ -204,7 +245,40 @@ class MarketDataFeed:
if self._trade_limits_guard is not None:
self._trade_limits_guard.open_trade(exposure_by_asset)
outcome = await self._opportunity_executor(event)
try:
outcome = await self._opportunity_executor(event)
except Exception:
if self._trade_limits_guard is not None:
self._trade_limits_guard.close_trade(exposure_by_asset)
if self._stop_conditions_guard is not None:
self._stop_conditions_guard.register_failure()
if self._stop_conditions_guard.is_halted:
if (
self._kill_switch is not None
and not self._kill_switch.is_active
):
self._kill_switch.activate(
reason=self._stop_conditions_guard.halted_reason
or "stop_conditions_halted",
)
_LOG.warning(
"stop_condition_halt_triggered",
reason=self._stop_conditions_guard.halted_reason,
cycle=event.cycle,
updated_pair=event.updated_pair,
)
_LOG.exception(
"live_trade_execution_failed",
cycle=event.cycle,
updated_pair=event.updated_pair,
)
continue
if self._stop_conditions_guard is not None:
self._stop_conditions_guard.register_success()
realized_pnl: float | None
close_trade = True
if isinstance(outcome, ExecutionOutcome):
@@ -214,8 +288,7 @@ class MarketDataFeed:
realized_pnl = outcome
if realized_pnl is not None and self._loss_limit_guard is not None:
self._loss_limit_guard.register_realized_pnl(
realized_pnl)
self._loss_limit_guard.register_realized_pnl(realized_pnl)
if self._loss_limit_guard.is_halted:
_LOG.warning(
"loss_limit_halt_triggered",
+9 -1
View File
@@ -1,7 +1,15 @@
"""Risk management helpers."""
from arbitrade.risk.kill_switch import KillSwitch
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.stop_conditions import StopConditionsGuard
from arbitrade.risk.trade_limits import TradeLimitsGuard
__all__ = ["LossLimitGuard", "TradeLimitsGuard", "PreTradeValidator"]
__all__ = [
"LossLimitGuard",
"TradeLimitsGuard",
"PreTradeValidator",
"KillSwitch",
"StopConditionsGuard",
]
+23
View File
@@ -0,0 +1,23 @@
from __future__ import annotations
class KillSwitch:
    """On/off halt flag that remembers why it was switched on.

    While the switch is inactive ``reason`` is ``None``; while it is active
    ``reason`` holds the reason given at construction or at the most recent
    ``activate`` call.
    """

    def __init__(self, *, active: bool = False, reason: str | None = None) -> None:
        """Create the switch in the requested state.

        Args:
            active: Whether the switch starts switched on.
            reason: Why the switch starts active. Falls back to ``"manual"``
                when omitted or empty. Ignored when ``active`` is false.
        """
        self._active = active
        # Keep the same invariant as deactivate(): an inactive switch carries
        # no reason, so a caller can never read a stale one.
        self._reason = (reason or "manual") if active else None

    def __repr__(self) -> str:
        return f"{type(self).__name__}(active={self._active!r}, reason={self._reason!r})"

    @property
    def is_active(self) -> bool:
        """Whether the switch is currently on."""
        return self._active

    @property
    def reason(self) -> str | None:
        """Reason recorded for the current activation, or ``None`` when off."""
        return self._reason

    def activate(self, *, reason: str = "manual") -> None:
        """Switch on and record ``reason``, replacing any earlier reason."""
        self._active = True
        self._reason = reason

    def deactivate(self) -> None:
        """Switch off and forget the recorded reason."""
        self._active = False
        self._reason = None
+72
View File
@@ -0,0 +1,72 @@
from __future__ import annotations
class StopConditionsGuard:
    """Latch a halt when a latency or consecutive-failure limit is breached.

    Every limit is optional; a limit left as ``None`` is never checked. The
    first breach stores a reason string, after which every further
    observation is ignored, so the first reason is the one that is kept.
    The class offers no method to clear a halt.
    """

    def __init__(
        self,
        *,
        max_source_latency_ms: float | None = None,
        max_apply_latency_ms: float | None = None,
        max_consecutive_failures: int | None = None,
    ) -> None:
        """Validate and store the configured limits.

        Args:
            max_source_latency_ms: Highest tolerated source latency in
                milliseconds, or ``None`` to skip the check.
            max_apply_latency_ms: Highest tolerated apply latency in
                milliseconds, or ``None`` to skip the check.
            max_consecutive_failures: Number of consecutive failures that
                triggers a halt, or ``None`` to skip the check.

        Raises:
            ValueError: If a supplied limit is zero or negative, or if a
                latency limit is NaN.
        """
        # ``not (x > 0.0)`` instead of ``x <= 0.0``: NaN fails both comparisons,
        # so the latter would accept a NaN limit that can never be breached and
        # would silently disable the check.
        if max_source_latency_ms is not None and not (max_source_latency_ms > 0.0):
            raise ValueError("max_source_latency_ms must be > 0.0")
        if max_apply_latency_ms is not None and not (max_apply_latency_ms > 0.0):
            raise ValueError("max_apply_latency_ms must be > 0.0")
        if max_consecutive_failures is not None and max_consecutive_failures <= 0:
            raise ValueError("max_consecutive_failures must be > 0")

        self._max_source_latency_ms = max_source_latency_ms
        self._max_apply_latency_ms = max_apply_latency_ms
        self._max_consecutive_failures = max_consecutive_failures
        self._consecutive_failures = 0
        self._halted_reason: str | None = None

    @property
    def halted_reason(self) -> str | None:
        """Reason string of the breach that halted the guard, else ``None``."""
        return self._halted_reason

    @property
    def is_halted(self) -> bool:
        """Whether any limit has been breached so far."""
        return self._halted_reason is not None

    @property
    def consecutive_failures(self) -> int:
        """Failures registered since the last registered success."""
        return self._consecutive_failures

    def observe_latency(
        self,
        *,
        source_latency_ms: float | None,
        apply_latency_ms: float,
    ) -> None:
        """Halt if either latency strictly exceeds its configured limit.

        The source latency is checked first, so when both limits are
        breached at once the recorded reason is the source one. A
        ``source_latency_ms`` of ``None`` skips the source check.
        """
        if self.is_halted:
            return

        if (
            self._max_source_latency_ms is not None
            and source_latency_ms is not None
            and source_latency_ms > self._max_source_latency_ms
        ):
            self._halted_reason = "source_latency_limit_breached"
            return

        if self._max_apply_latency_ms is not None and apply_latency_ms > self._max_apply_latency_ms:
            self._halted_reason = "apply_latency_limit_breached"

    def register_failure(self) -> None:
        """Count one failure and halt once the configured count is reached."""
        if self.is_halted:
            return

        self._consecutive_failures += 1
        if (
            self._max_consecutive_failures is not None
            and self._consecutive_failures >= self._max_consecutive_failures
        ):
            self._halted_reason = "consecutive_failures_limit_breached"

    def register_success(self) -> None:
        """Reset the failure streak; does nothing once the guard is halted."""
        if self.is_halted:
            return
        self._consecutive_failures = 0
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from arbitrade.risk.kill_switch import KillSwitch
def test_kill_switch_can_activate_and_deactivate() -> None:
    """A default switch is off, turns on with a reason, and clears on deactivate."""
    switch = KillSwitch()

    # Default construction: off, with no reason recorded.
    assert not switch.is_active
    assert switch.reason is None

    # Switching on stores the supplied reason.
    switch.activate(reason="manual")
    assert switch.is_active
    assert switch.reason == "manual"

    # Switching off drops the reason again.
    switch.deactivate()
    assert not switch.is_active
    assert switch.reason is None
def test_kill_switch_active_on_init_sets_reason() -> None:
    """Constructing the switch as active without a reason reports "manual"."""
    preactivated = KillSwitch(active=True)

    assert preactivated.is_active
    assert preactivated.reason == "manual"
+110 -2
View File
@@ -9,8 +9,10 @@ import pytest
from arbitrade.detection.engine import OpportunityEvent
from arbitrade.exchange.models import BookDelta, BookLevel
from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed
from arbitrade.risk.kill_switch import KillSwitch
from arbitrade.risk.loss_limits import LossLimitGuard
from arbitrade.risk.pre_trade import PreTradeValidator
from arbitrade.risk.stop_conditions import StopConditionsGuard
from arbitrade.risk.trade_limits import TradeLimitsGuard
@@ -72,6 +74,15 @@ class _FakeExecutor:
return self.realized_pnls.pop(0)
class _FakeFailingExecutor:
    """Executor double whose ``execute`` counts the call and then always raises."""

    def __init__(self) -> None:
        # Number of times ``execute`` has been awaited.
        self.calls: int = 0

    async def execute(self, _event: OpportunityEvent) -> None:
        """Record the attempt, then fail with ``RuntimeError``."""
        self.calls = self.calls + 1
        failure = RuntimeError("executor failure")
        raise failure
@dataclass(slots=True)
class _FakeWsClientTwoMessages:
delta: BookDelta
@@ -215,8 +226,7 @@ async def test_market_data_feed_enforces_max_concurrent_trades() -> None:
event = _sample_event()
detector = _FakeDetector(event)
executor = _FakeExecutor()
executor.outcomes = [ExecutionOutcome(
realized_pnl=None, close_trade=False)]
executor.outcomes = [ExecutionOutcome(realized_pnl=None, close_trade=False)]
trade_guard = TradeLimitsGuard(max_concurrent_trades=1)
feed = MarketDataFeed(
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
@@ -322,3 +332,101 @@ async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> Non
await feed.run()
assert len(executor.calls) == 1
@pytest.mark.asyncio
async def test_market_data_feed_blocks_when_kill_switch_active() -> None:
    """With the kill switch already on, the live executor is never called."""
    opportunity = _sample_event(allocated_capital=75.0)
    fake_detector = _FakeDetector(opportunity)
    fake_executor = _FakeExecutor()
    engaged_switch = KillSwitch(active=True, reason="manual")

    market_feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=fake_detector,
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=fake_executor.execute,
        kill_switch=engaged_switch,
    )
    await market_feed.run()

    assert len(fake_executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_allows_when_kill_switch_inactive() -> None:
    """With the kill switch off, the opportunity reaches the executor once."""
    opportunity = _sample_event(allocated_capital=75.0)
    fake_detector = _FakeDetector(opportunity)
    fake_executor = _FakeExecutor()
    idle_switch = KillSwitch(active=False)

    market_feed = MarketDataFeed(
        ws_client=_FakeWsClient(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=fake_detector,
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=fake_executor.execute,
        kill_switch=idle_switch,
    )
    await market_feed.run()

    assert len(fake_executor.calls) == 1
@pytest.mark.asyncio
async def test_market_data_feed_halts_on_abnormal_source_latency() -> None:
    """A source-latency breach halts the guard, trips the switch, and blocks the trade."""
    opportunity = _sample_event(allocated_capital=75.0)
    fake_detector = _FakeDetector(opportunity)
    fake_executor = _FakeExecutor()
    switch = KillSwitch(active=False)
    latency_guard = StopConditionsGuard(max_source_latency_ms=1.0)

    # Zero out the delta's source timestamp; the assertions below expect the
    # 1 ms source-latency limit to be breached as a result.
    stale_delta = _sample_delta()
    stale_delta.source_timestamp_ms = 0

    market_feed = MarketDataFeed(
        ws_client=_FakeWsClient(stale_delta),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=fake_detector,
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=fake_executor.execute,
        kill_switch=switch,
        stop_conditions_guard=latency_guard,
    )
    await market_feed.run()

    # The guard latched with the latency reason ...
    assert latency_guard.is_halted
    assert latency_guard.halted_reason == "source_latency_limit_breached"
    # ... the reason was propagated to the kill switch ...
    assert switch.is_active
    assert switch.reason == "source_latency_limit_breached"
    # ... and no trade was executed.
    assert len(fake_executor.calls) == 0
@pytest.mark.asyncio
async def test_market_data_feed_halts_on_repeated_execution_failures() -> None:
    """Two failing executions in a row halt the guard and trip the kill switch."""
    opportunity = _sample_event(allocated_capital=75.0)
    fake_detector = _FakeDetector(opportunity)
    failing_executor = _FakeFailingExecutor()
    switch = KillSwitch(active=False)
    failure_guard = StopConditionsGuard(max_consecutive_failures=2)

    market_feed = MarketDataFeed(
        ws_client=_FakeWsClientTwoMessages(_sample_delta()),
        snapshot_writer=_FakeSnapshotWriter(),
        detector=fake_detector,
        opportunity_writer=_FakeOpportunityWriter(),
        paper_trading_mode=False,
        opportunity_executor=failing_executor.execute,
        kill_switch=switch,
        stop_conditions_guard=failure_guard,
    )
    await market_feed.run()

    # Both attempts reached the executor and failed.
    assert failing_executor.calls == 2
    # The second failure hit the limit of 2 and latched the guard ...
    assert failure_guard.is_halted
    assert failure_guard.halted_reason == "consecutive_failures_limit_breached"
    # ... which in turn activated the kill switch with the same reason.
    assert switch.is_active
    assert switch.reason == "consecutive_failures_limit_breached"
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
import pytest
from arbitrade.risk.stop_conditions import StopConditionsGuard
def test_stop_conditions_guard_halts_on_source_latency_breach() -> None:
    """A source latency above the configured maximum latches the guard."""
    latency_guard = StopConditionsGuard(max_source_latency_ms=50.0)

    # 75 ms exceeds the 50 ms ceiling; no apply-latency limit is configured.
    latency_guard.observe_latency(source_latency_ms=75.0, apply_latency_ms=1.0)

    assert latency_guard.is_halted
    assert latency_guard.halted_reason == "source_latency_limit_breached"
def test_stop_conditions_guard_halts_on_apply_latency_breach() -> None:
    """An apply latency above the configured maximum latches the guard."""
    latency_guard = StopConditionsGuard(max_apply_latency_ms=2.0)

    # 3.5 ms exceeds the 2 ms ceiling; the source latency is passed as None.
    latency_guard.observe_latency(source_latency_ms=None, apply_latency_ms=3.5)

    assert latency_guard.is_halted
    assert latency_guard.halted_reason == "apply_latency_limit_breached"
def test_stop_conditions_guard_halts_on_consecutive_failures() -> None:
    """The guard halts on the failure that reaches the configured count."""
    failure_guard = StopConditionsGuard(max_consecutive_failures=2)

    # First failure: still below the limit of 2.
    failure_guard.register_failure()
    assert not failure_guard.is_halted

    # Second failure: the limit is reached and the guard latches.
    failure_guard.register_failure()
    assert failure_guard.is_halted
    assert failure_guard.halted_reason == "consecutive_failures_limit_breached"
def test_stop_conditions_guard_resets_failures_after_success() -> None:
    """A success between failures restarts the consecutive-failure count."""
    failure_guard = StopConditionsGuard(max_consecutive_failures=3)

    # failure -> success -> failure leaves a streak of exactly one.
    failure_guard.register_failure()
    failure_guard.register_success()
    failure_guard.register_failure()

    assert failure_guard.consecutive_failures == 1
    assert not failure_guard.is_halted
def test_stop_conditions_guard_rejects_invalid_configuration() -> None:
    """Zero or negative limits are rejected with an error naming the parameter."""
    # (parameter name, invalid value) pairs, checked in this order.
    invalid_limits = (
        ("max_source_latency_ms", 0.0),
        ("max_apply_latency_ms", -1.0),
        ("max_consecutive_failures", 0),
    )

    for parameter, bad_value in invalid_limits:
        with pytest.raises(ValueError, match=parameter):
            StopConditionsGuard(**{parameter: bad_value})