Add risk management features: implement loss limits, trade limits, and pre-trade validation; update settings and tests
This commit is contained in:
@@ -16,3 +16,11 @@ KRAKEN_RETRY_BASE_DELAY_SECONDS=0.25
|
|||||||
WS_HEARTBEAT_TIMEOUT_SECONDS=20.0
|
WS_HEARTBEAT_TIMEOUT_SECONDS=20.0
|
||||||
WS_MAX_STALENESS_SECONDS=5.0
|
WS_MAX_STALENESS_SECONDS=5.0
|
||||||
PAPER_TRADING_MODE=true
|
PAPER_TRADING_MODE=true
|
||||||
|
TRADE_CAPITAL_USD=100.0
|
||||||
|
MAX_TRADE_CAPITAL_USD=100.0
|
||||||
|
MAX_CONCURRENT_TRADES=
|
||||||
|
MAX_EXPOSURE_PER_ASSET_USD=
|
||||||
|
QUOTE_BALANCE_ASSET=USD
|
||||||
|
MIN_ORDER_SIZE_USD=
|
||||||
|
DAILY_LOSS_LIMIT_USD=5.0
|
||||||
|
CUMULATIVE_LOSS_LIMIT_USD=10.0
|
||||||
|
|||||||
@@ -8,7 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
env_ignore_empty=True,
|
||||||
|
)
|
||||||
|
|
||||||
app_env: str = Field(default="dev", alias="APP_ENV")
|
app_env: str = Field(default="dev", alias="APP_ENV")
|
||||||
app_host: str = Field(default="0.0.0.0", alias="APP_HOST")
|
app_host: str = Field(default="0.0.0.0", alias="APP_HOST")
|
||||||
@@ -17,23 +22,48 @@ class Settings(BaseSettings):
|
|||||||
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
|
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
|
||||||
log_json: bool = Field(default=True, alias="LOG_JSON")
|
log_json: bool = Field(default=True, alias="LOG_JSON")
|
||||||
|
|
||||||
duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
|
duckdb_path: Path = Field(default=Path(
|
||||||
|
"./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
|
||||||
|
|
||||||
kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL")
|
kraken_rest_url: str = Field(
|
||||||
kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
|
default="https://api.kraken.com", alias="KRAKEN_REST_URL")
|
||||||
|
kraken_ws_url: str = Field(
|
||||||
|
default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
|
||||||
kraken_private_rate_limit_seconds: float = Field(
|
kraken_private_rate_limit_seconds: float = Field(
|
||||||
default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS"
|
default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS"
|
||||||
)
|
)
|
||||||
kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
|
kraken_http_timeout_seconds: float = Field(
|
||||||
kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS")
|
default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
|
||||||
|
kraken_retry_attempts: int = Field(
|
||||||
|
default=3, alias="KRAKEN_RETRY_ATTEMPTS")
|
||||||
kraken_retry_base_delay_seconds: float = Field(
|
kraken_retry_base_delay_seconds: float = Field(
|
||||||
default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS"
|
default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS"
|
||||||
)
|
)
|
||||||
kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY")
|
kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY")
|
||||||
kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET")
|
kraken_api_secret: str | None = Field(
|
||||||
ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
|
default=None, alias="KRAKEN_API_SECRET")
|
||||||
ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS")
|
ws_heartbeat_timeout_seconds: float = Field(
|
||||||
|
default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
|
||||||
|
ws_max_staleness_seconds: float = Field(
|
||||||
|
default=5.0, alias="WS_MAX_STALENESS_SECONDS")
|
||||||
paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE")
|
paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE")
|
||||||
|
trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD")
|
||||||
|
max_trade_capital_usd: float = Field(
|
||||||
|
default=100.0, alias="MAX_TRADE_CAPITAL_USD")
|
||||||
|
max_concurrent_trades: int | None = Field(
|
||||||
|
default=None, alias="MAX_CONCURRENT_TRADES")
|
||||||
|
max_exposure_per_asset_usd: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
alias="MAX_EXPOSURE_PER_ASSET_USD",
|
||||||
|
)
|
||||||
|
quote_balance_asset: str = Field(
|
||||||
|
default="USD", alias="QUOTE_BALANCE_ASSET")
|
||||||
|
min_order_size_usd: float | None = Field(
|
||||||
|
default=None, alias="MIN_ORDER_SIZE_USD")
|
||||||
|
daily_loss_limit_usd: float | None = Field(
|
||||||
|
default=None, alias="DAILY_LOSS_LIMIT_USD")
|
||||||
|
cumulative_loss_limit_usd: float | None = Field(
|
||||||
|
default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
|
||||||
|
|
||||||
fernet_key: str | None = Field(default=None, alias="FERNET_KEY")
|
fernet_key: str | None = Field(default=None, alias="FERNET_KEY")
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class OpportunityEvent:
|
|||||||
gross_pct: float
|
gross_pct: float
|
||||||
net_pct: float
|
net_pct: float
|
||||||
est_profit: float
|
est_profit: float
|
||||||
|
allocated_capital: float = 1.0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cycle_score(cls, score: CycleScore, base_capital: float = 1.0) -> OpportunityEvent:
|
def from_cycle_score(cls, score: CycleScore, base_capital: float = 1.0) -> OpportunityEvent:
|
||||||
@@ -58,6 +59,7 @@ class OpportunityEvent:
|
|||||||
gross_pct=gross_pct,
|
gross_pct=gross_pct,
|
||||||
net_pct=net_pct,
|
net_pct=net_pct,
|
||||||
est_profit=est_profit,
|
est_profit=est_profit,
|
||||||
|
allocated_capital=base_capital,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable, Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
@@ -9,12 +10,21 @@ import structlog
|
|||||||
from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent
|
from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent
|
||||||
from arbitrade.exchange.kraken_ws import KrakenWsClient
|
from arbitrade.exchange.kraken_ws import KrakenWsClient
|
||||||
from arbitrade.market_data.order_book import OrderBook
|
from arbitrade.market_data.order_book import OrderBook
|
||||||
|
from arbitrade.risk.loss_limits import LossLimitGuard
|
||||||
|
from arbitrade.risk.pre_trade import PreTradeValidator
|
||||||
|
from arbitrade.risk.trade_limits import TradeLimitsGuard
|
||||||
from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot
|
from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot
|
||||||
from arbitrade.storage.opportunities import AsyncOpportunityWriter
|
from arbitrade.storage.opportunities import AsyncOpportunityWriter
|
||||||
|
|
||||||
_LOG = structlog.get_logger(__name__)
|
_LOG = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class ExecutionOutcome:
|
||||||
|
realized_pnl: float | None = None
|
||||||
|
close_trade: bool = True
|
||||||
|
|
||||||
|
|
||||||
class MarketDataFeed:
|
class MarketDataFeed:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -23,7 +33,17 @@ class MarketDataFeed:
|
|||||||
detector: IncrementalCycleDetector | None = None,
|
detector: IncrementalCycleDetector | None = None,
|
||||||
opportunity_writer: AsyncOpportunityWriter | None = None,
|
opportunity_writer: AsyncOpportunityWriter | None = None,
|
||||||
paper_trading_mode: bool = True,
|
paper_trading_mode: bool = True,
|
||||||
opportunity_executor: Callable[[OpportunityEvent], Awaitable[None]] | None = None,
|
opportunity_executor: (
|
||||||
|
Callable[[OpportunityEvent],
|
||||||
|
Awaitable[ExecutionOutcome | float | None]] | None
|
||||||
|
) = None,
|
||||||
|
trade_capital: float = 1.0,
|
||||||
|
max_trade_capital: float | None = None,
|
||||||
|
loss_limit_guard: LossLimitGuard | None = None,
|
||||||
|
trade_limits_guard: TradeLimitsGuard | None = None,
|
||||||
|
pre_trade_validator: PreTradeValidator | None = None,
|
||||||
|
balance_provider: Callable[[], Mapping[str, float]] | None = None,
|
||||||
|
quote_balance_asset: str = "USD",
|
||||||
) -> None:
|
) -> None:
|
||||||
self._ws_client = ws_client
|
self._ws_client = ws_client
|
||||||
self._snapshot_writer = snapshot_writer
|
self._snapshot_writer = snapshot_writer
|
||||||
@@ -32,11 +52,39 @@ class MarketDataFeed:
|
|||||||
self._opportunity_writer = opportunity_writer
|
self._opportunity_writer = opportunity_writer
|
||||||
self._paper_trading_mode = paper_trading_mode
|
self._paper_trading_mode = paper_trading_mode
|
||||||
self._opportunity_executor = opportunity_executor
|
self._opportunity_executor = opportunity_executor
|
||||||
|
self._trade_capital = trade_capital
|
||||||
|
self._max_trade_capital = max_trade_capital
|
||||||
|
self._loss_limit_guard = loss_limit_guard
|
||||||
|
self._trade_limits_guard = trade_limits_guard
|
||||||
|
self._pre_trade_validator = pre_trade_validator
|
||||||
|
self._balance_provider = balance_provider
|
||||||
|
self._quote_balance_asset = quote_balance_asset.upper()
|
||||||
|
|
||||||
|
if self._trade_capital <= 0.0:
|
||||||
|
raise ValueError("trade_capital must be > 0.0")
|
||||||
|
if self._max_trade_capital is not None and self._max_trade_capital <= 0.0:
|
||||||
|
raise ValueError("max_trade_capital must be > 0.0")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def books(self) -> dict[str, OrderBook]:
|
def books(self) -> dict[str, OrderBook]:
|
||||||
return self._books
|
return self._books
|
||||||
|
|
||||||
|
def _effective_trade_capital(self) -> float:
|
||||||
|
if self._max_trade_capital is None:
|
||||||
|
return self._trade_capital
|
||||||
|
return min(self._trade_capital, self._max_trade_capital)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _exposure_for_event(event: OpportunityEvent) -> dict[str, float]:
|
||||||
|
currencies = [part for part in event.cycle.split("->") if part]
|
||||||
|
if len(currencies) < 2:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
start = currencies[0]
|
||||||
|
exposure_assets = {
|
||||||
|
currency for currency in currencies[1:] if currency != start}
|
||||||
|
return {asset: event.allocated_capital for asset in exposure_assets}
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
async for message in self._ws_client.connect_stream():
|
async for message in self._ws_client.connect_stream():
|
||||||
parse_start = time.perf_counter()
|
parse_start = time.perf_counter()
|
||||||
@@ -73,6 +121,7 @@ class MarketDataFeed:
|
|||||||
opportunities = self._detector.opportunities_for_updated_pair(
|
opportunities = self._detector.opportunities_for_updated_pair(
|
||||||
delta.symbol,
|
delta.symbol,
|
||||||
self._books,
|
self._books,
|
||||||
|
base_capital=self._effective_trade_capital(),
|
||||||
)
|
)
|
||||||
_LOG.debug(
|
_LOG.debug(
|
||||||
"incremental_opportunity_scores",
|
"incremental_opportunity_scores",
|
||||||
@@ -111,7 +160,71 @@ class MarketDataFeed:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self._opportunity_executor(event)
|
if self._loss_limit_guard is not None and self._loss_limit_guard.is_halted:
|
||||||
|
_LOG.warning(
|
||||||
|
"live_trade_skipped_loss_limit_halted",
|
||||||
|
cycle=event.cycle,
|
||||||
|
updated_pair=event.updated_pair,
|
||||||
|
reason=self._loss_limit_guard.halted_reason,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self._pre_trade_validator is not None and self._balance_provider is not None:
|
||||||
|
required_balances = {
|
||||||
|
self._quote_balance_asset: event.allocated_capital}
|
||||||
|
balances = {
|
||||||
|
asset.upper(): amount
|
||||||
|
for asset, amount in self._balance_provider().items()
|
||||||
|
}
|
||||||
|
if not self._pre_trade_validator.validate(
|
||||||
|
balances_by_asset=balances,
|
||||||
|
required_by_asset=required_balances,
|
||||||
|
):
|
||||||
|
_LOG.warning(
|
||||||
|
"live_trade_skipped_pre_trade_validation",
|
||||||
|
cycle=event.cycle,
|
||||||
|
updated_pair=event.updated_pair,
|
||||||
|
required_by_asset=required_balances,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
exposure_by_asset = self._exposure_for_event(event)
|
||||||
|
if (
|
||||||
|
self._trade_limits_guard is not None
|
||||||
|
and not self._trade_limits_guard.is_trade_allowed(exposure_by_asset)
|
||||||
|
):
|
||||||
|
_LOG.warning(
|
||||||
|
"live_trade_skipped_trade_limits",
|
||||||
|
cycle=event.cycle,
|
||||||
|
updated_pair=event.updated_pair,
|
||||||
|
exposure_by_asset=exposure_by_asset,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self._trade_limits_guard is not None:
|
||||||
|
self._trade_limits_guard.open_trade(exposure_by_asset)
|
||||||
|
|
||||||
|
outcome = await self._opportunity_executor(event)
|
||||||
|
realized_pnl: float | None
|
||||||
|
close_trade = True
|
||||||
|
if isinstance(outcome, ExecutionOutcome):
|
||||||
|
realized_pnl = outcome.realized_pnl
|
||||||
|
close_trade = outcome.close_trade
|
||||||
|
else:
|
||||||
|
realized_pnl = outcome
|
||||||
|
|
||||||
|
if realized_pnl is not None and self._loss_limit_guard is not None:
|
||||||
|
self._loss_limit_guard.register_realized_pnl(
|
||||||
|
realized_pnl)
|
||||||
|
if self._loss_limit_guard.is_halted:
|
||||||
|
_LOG.warning(
|
||||||
|
"loss_limit_halt_triggered",
|
||||||
|
reason=self._loss_limit_guard.halted_reason,
|
||||||
|
cumulative_pnl=self._loss_limit_guard.cumulative_pnl,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._trade_limits_guard is not None and close_trade:
|
||||||
|
self._trade_limits_guard.close_trade(exposure_by_asset)
|
||||||
|
|
||||||
await self._snapshot_writer.enqueue(
|
await self._snapshot_writer.enqueue(
|
||||||
MarketSnapshot(
|
MarketSnapshot(
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
"""Risk management helpers."""
|
||||||
|
|
||||||
|
from arbitrade.risk.loss_limits import LossLimitGuard
|
||||||
|
from arbitrade.risk.pre_trade import PreTradeValidator
|
||||||
|
from arbitrade.risk.trade_limits import TradeLimitsGuard
|
||||||
|
|
||||||
|
__all__ = ["LossLimitGuard", "TradeLimitsGuard", "PreTradeValidator"]
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, date, datetime
|
||||||
|
|
||||||
|
|
||||||
|
class LossLimitGuard:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
daily_loss_limit: float | None = None,
|
||||||
|
cumulative_loss_limit: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._daily_loss_limit = daily_loss_limit
|
||||||
|
self._cumulative_loss_limit = cumulative_loss_limit
|
||||||
|
|
||||||
|
if self._daily_loss_limit is not None and self._daily_loss_limit <= 0.0:
|
||||||
|
raise ValueError("daily_loss_limit must be > 0.0")
|
||||||
|
if self._cumulative_loss_limit is not None and self._cumulative_loss_limit <= 0.0:
|
||||||
|
raise ValueError("cumulative_loss_limit must be > 0.0")
|
||||||
|
|
||||||
|
self._cumulative_pnl = 0.0
|
||||||
|
self._daily_pnl: dict[date, float] = {}
|
||||||
|
self._halted_reason: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cumulative_pnl(self) -> float:
|
||||||
|
return self._cumulative_pnl
|
||||||
|
|
||||||
|
@property
|
||||||
|
def halted_reason(self) -> str | None:
|
||||||
|
return self._halted_reason
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_halted(self) -> bool:
|
||||||
|
return self._halted_reason is not None
|
||||||
|
|
||||||
|
def daily_pnl(self, day: date) -> float:
|
||||||
|
return self._daily_pnl.get(day, 0.0)
|
||||||
|
|
||||||
|
def register_realized_pnl(self, pnl: float, *, at: datetime | None = None) -> None:
|
||||||
|
if self.is_halted:
|
||||||
|
return
|
||||||
|
|
||||||
|
timestamp = at or datetime.now(UTC)
|
||||||
|
day_key = timestamp.date()
|
||||||
|
|
||||||
|
self._cumulative_pnl += pnl
|
||||||
|
self._daily_pnl[day_key] = self._daily_pnl.get(day_key, 0.0) + pnl
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._daily_loss_limit is not None
|
||||||
|
and self._daily_pnl[day_key] <= -self._daily_loss_limit
|
||||||
|
):
|
||||||
|
self._halted_reason = "daily_loss_limit_breached"
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._cumulative_loss_limit is not None
|
||||||
|
and self._cumulative_pnl <= -self._cumulative_loss_limit
|
||||||
|
):
|
||||||
|
self._halted_reason = "cumulative_loss_limit_breached"
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
|
|
||||||
|
class PreTradeValidator:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
min_order_size_by_asset: Mapping[str, float] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._min_order_size_by_asset = {
|
||||||
|
asset.upper(): float(value) for asset, value in (min_order_size_by_asset or {}).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
for value in self._min_order_size_by_asset.values():
|
||||||
|
if value <= 0.0:
|
||||||
|
raise ValueError("minimum order size must be > 0.0")
|
||||||
|
|
||||||
|
def validate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
balances_by_asset: Mapping[str, float],
|
||||||
|
required_by_asset: Mapping[str, float],
|
||||||
|
) -> bool:
|
||||||
|
# Minimum order size checks first to fail fast on structural invalid sizes.
|
||||||
|
for asset, required in required_by_asset.items():
|
||||||
|
if required <= 0.0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
min_size = self._min_order_size_by_asset.get(asset.upper())
|
||||||
|
if min_size is not None and required < min_size:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Balance checks ensure required quantity is currently available.
|
||||||
|
for asset, required in required_by_asset.items():
|
||||||
|
if required <= 0.0:
|
||||||
|
continue
|
||||||
|
available = balances_by_asset.get(asset.upper(), 0.0)
|
||||||
|
if available < required:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
|
|
||||||
|
class TradeLimitsGuard:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
max_concurrent_trades: int | None = None,
|
||||||
|
max_exposure_per_asset: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
if max_concurrent_trades is not None and max_concurrent_trades <= 0:
|
||||||
|
raise ValueError("max_concurrent_trades must be > 0")
|
||||||
|
if max_exposure_per_asset is not None and max_exposure_per_asset <= 0.0:
|
||||||
|
raise ValueError("max_exposure_per_asset must be > 0.0")
|
||||||
|
|
||||||
|
self._max_concurrent_trades = max_concurrent_trades
|
||||||
|
self._max_exposure_per_asset = max_exposure_per_asset
|
||||||
|
self._active_trades = 0
|
||||||
|
self._asset_exposure: dict[str, float] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_trades(self) -> int:
|
||||||
|
return self._active_trades
|
||||||
|
|
||||||
|
def exposure_for_asset(self, asset: str) -> float:
|
||||||
|
return self._asset_exposure.get(asset.upper(), 0.0)
|
||||||
|
|
||||||
|
def is_trade_allowed(self, exposure_by_asset: Mapping[str, float]) -> bool:
|
||||||
|
if (
|
||||||
|
self._max_concurrent_trades is not None
|
||||||
|
and self._active_trades >= self._max_concurrent_trades
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self._max_exposure_per_asset is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for asset, exposure in exposure_by_asset.items():
|
||||||
|
if exposure <= 0.0:
|
||||||
|
continue
|
||||||
|
key = asset.upper()
|
||||||
|
next_exposure = self._asset_exposure.get(key, 0.0) + exposure
|
||||||
|
if next_exposure > self._max_exposure_per_asset:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def open_trade(self, exposure_by_asset: Mapping[str, float]) -> None:
|
||||||
|
self._active_trades += 1
|
||||||
|
for asset, exposure in exposure_by_asset.items():
|
||||||
|
if exposure <= 0.0:
|
||||||
|
continue
|
||||||
|
key = asset.upper()
|
||||||
|
self._asset_exposure[key] = self._asset_exposure.get(key, 0.0) + exposure
|
||||||
|
|
||||||
|
def close_trade(self, exposure_by_asset: Mapping[str, float]) -> None:
|
||||||
|
if self._active_trades > 0:
|
||||||
|
self._active_trades -= 1
|
||||||
|
|
||||||
|
for asset, exposure in exposure_by_asset.items():
|
||||||
|
if exposure <= 0.0:
|
||||||
|
continue
|
||||||
|
key = asset.upper()
|
||||||
|
current = self._asset_exposure.get(key, 0.0)
|
||||||
|
next_exposure = max(current - exposure, 0.0)
|
||||||
|
if next_exposure == 0.0:
|
||||||
|
self._asset_exposure.pop(key, None)
|
||||||
|
else:
|
||||||
|
self._asset_exposure[key] = next_exposure
|
||||||
@@ -19,10 +19,8 @@ def _make_book_levels(
|
|||||||
*, bids: list[tuple[float, float]], asks: list[tuple[float, float]]
|
*, bids: list[tuple[float, float]], asks: list[tuple[float, float]]
|
||||||
) -> OrderBook:
|
) -> OrderBook:
|
||||||
book = OrderBook()
|
book = OrderBook()
|
||||||
book.apply_bids([BookLevel(price=price, volume=volume)
|
book.apply_bids([BookLevel(price=price, volume=volume) for price, volume in bids])
|
||||||
for price, volume in bids])
|
book.apply_asks([BookLevel(price=price, volume=volume) for price, volume in asks])
|
||||||
book.apply_asks([BookLevel(price=price, volume=volume)
|
|
||||||
for price, volume in asks])
|
|
||||||
return book
|
return book
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -278,3 +278,29 @@ def test_incremental_detector_emits_structured_opportunity_event() -> None:
|
|||||||
assert event.gross_pct == pytest.approx(4.0)
|
assert event.gross_pct == pytest.approx(4.0)
|
||||||
assert event.net_pct == pytest.approx(4.0)
|
assert event.net_pct == pytest.approx(4.0)
|
||||||
assert event.est_profit == pytest.approx(20.0)
|
assert event.est_profit == pytest.approx(20.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_incremental_detector_estimated_profit_scales_with_capital() -> None:
|
||||||
|
cycle = TriangularCycle(
|
||||||
|
currencies=("USD", "BTC", "ETH"),
|
||||||
|
pairs=("BTC/USD", "ETH/BTC", "ETH/USD"),
|
||||||
|
)
|
||||||
|
detector = IncrementalCycleDetector(
|
||||||
|
CurrencyGraph.index_cycles_by_pair([cycle]),
|
||||||
|
min_profit_threshold=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
books = {
|
||||||
|
"BTC/USD": _make_book(bid=99.9, ask=100.0),
|
||||||
|
"ETH/BTC": _make_book(bid=0.049, ask=0.05),
|
||||||
|
"ETH/USD": _make_book(bid=5.20, ask=5.21),
|
||||||
|
}
|
||||||
|
|
||||||
|
opportunities = detector.opportunities_for_updated_pair(
|
||||||
|
"ETH/BTC",
|
||||||
|
books,
|
||||||
|
base_capital=250.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(opportunities) == 1
|
||||||
|
assert opportunities[0].est_profit == pytest.approx(10.0)
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbitrade.risk.loss_limits import LossLimitGuard
|
||||||
|
|
||||||
|
|
||||||
|
def test_loss_limit_guard_tracks_daily_and_cumulative_pnl() -> None:
|
||||||
|
guard = LossLimitGuard(daily_loss_limit=100.0, cumulative_loss_limit=200.0)
|
||||||
|
t0 = datetime.now(UTC)
|
||||||
|
|
||||||
|
guard.register_realized_pnl(-40.0, at=t0)
|
||||||
|
guard.register_realized_pnl(10.0, at=t0)
|
||||||
|
|
||||||
|
assert guard.cumulative_pnl == -30.0
|
||||||
|
assert guard.daily_pnl(t0.date()) == -30.0
|
||||||
|
assert not guard.is_halted
|
||||||
|
|
||||||
|
|
||||||
|
def test_loss_limit_guard_halts_on_daily_limit() -> None:
|
||||||
|
guard = LossLimitGuard(daily_loss_limit=50.0)
|
||||||
|
t0 = datetime.now(UTC)
|
||||||
|
|
||||||
|
guard.register_realized_pnl(-30.0, at=t0)
|
||||||
|
guard.register_realized_pnl(-25.0, at=t0)
|
||||||
|
|
||||||
|
assert guard.is_halted
|
||||||
|
assert guard.halted_reason == "daily_loss_limit_breached"
|
||||||
|
|
||||||
|
|
||||||
|
def test_loss_limit_guard_halts_on_cumulative_limit_across_days() -> None:
|
||||||
|
guard = LossLimitGuard(cumulative_loss_limit=60.0)
|
||||||
|
t0 = datetime.now(UTC)
|
||||||
|
|
||||||
|
guard.register_realized_pnl(-40.0, at=t0)
|
||||||
|
guard.register_realized_pnl(-25.0, at=t0 + timedelta(days=1))
|
||||||
|
|
||||||
|
assert guard.is_halted
|
||||||
|
assert guard.halted_reason == "cumulative_loss_limit_breached"
|
||||||
|
|
||||||
|
|
||||||
|
def test_loss_limit_guard_rejects_invalid_limits() -> None:
|
||||||
|
with pytest.raises(ValueError, match="daily_loss_limit"):
|
||||||
|
LossLimitGuard(daily_loss_limit=0.0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="cumulative_loss_limit"):
|
||||||
|
LossLimitGuard(cumulative_loss_limit=-1.0)
|
||||||
@@ -8,7 +8,10 @@ import pytest
|
|||||||
|
|
||||||
from arbitrade.detection.engine import OpportunityEvent
|
from arbitrade.detection.engine import OpportunityEvent
|
||||||
from arbitrade.exchange.models import BookDelta, BookLevel
|
from arbitrade.exchange.models import BookDelta, BookLevel
|
||||||
from arbitrade.market_data.feed import MarketDataFeed
|
from arbitrade.market_data.feed import ExecutionOutcome, MarketDataFeed
|
||||||
|
from arbitrade.risk.loss_limits import LossLimitGuard
|
||||||
|
from arbitrade.risk.pre_trade import PreTradeValidator
|
||||||
|
from arbitrade.risk.trade_limits import TradeLimitsGuard
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@@ -41,20 +44,47 @@ class _FakeOpportunityWriter:
|
|||||||
class _FakeDetector:
|
class _FakeDetector:
|
||||||
def __init__(self, event: OpportunityEvent) -> None:
|
def __init__(self, event: OpportunityEvent) -> None:
|
||||||
self._event = event
|
self._event = event
|
||||||
|
self.last_base_capital: float | None = None
|
||||||
|
|
||||||
def opportunities_for_updated_pair(self, _updated_pair: str, _books: dict[str, object]):
|
def opportunities_for_updated_pair(
|
||||||
|
self,
|
||||||
|
_updated_pair: str,
|
||||||
|
_books: dict[str, object],
|
||||||
|
*,
|
||||||
|
base_capital: float,
|
||||||
|
):
|
||||||
|
self.last_base_capital = base_capital
|
||||||
return [self._event]
|
return [self._event]
|
||||||
|
|
||||||
|
|
||||||
class _FakeExecutor:
|
class _FakeExecutor:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.calls: list[OpportunityEvent] = []
|
self.calls: list[OpportunityEvent] = []
|
||||||
|
self.realized_pnls: list[float | None] = []
|
||||||
|
self.outcomes: list[ExecutionOutcome] = []
|
||||||
|
|
||||||
async def execute(self, event: OpportunityEvent) -> None:
|
async def execute(self, event: OpportunityEvent) -> ExecutionOutcome | float | None:
|
||||||
self.calls.append(event)
|
self.calls.append(event)
|
||||||
|
if self.outcomes:
|
||||||
|
return self.outcomes.pop(0)
|
||||||
|
if not self.realized_pnls:
|
||||||
|
return None
|
||||||
|
return self.realized_pnls.pop(0)
|
||||||
|
|
||||||
|
|
||||||
def _sample_event() -> OpportunityEvent:
|
@dataclass(slots=True)
|
||||||
|
class _FakeWsClientTwoMessages:
|
||||||
|
delta: BookDelta
|
||||||
|
|
||||||
|
async def connect_stream(self):
|
||||||
|
yield SimpleNamespace(payload={"channel": "book", "seq": 1})
|
||||||
|
yield SimpleNamespace(payload={"channel": "book", "seq": 2})
|
||||||
|
|
||||||
|
def parse_book_delta(self, _payload: dict[str, object]) -> BookDelta:
|
||||||
|
return self.delta
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_event(*, allocated_capital: float = 1.0) -> OpportunityEvent:
|
||||||
return OpportunityEvent(
|
return OpportunityEvent(
|
||||||
detected_at=datetime.now(UTC),
|
detected_at=datetime.now(UTC),
|
||||||
cycle="USD->BTC->ETH->USD",
|
cycle="USD->BTC->ETH->USD",
|
||||||
@@ -64,6 +94,7 @@ def _sample_event() -> OpportunityEvent:
|
|||||||
gross_pct=4.0,
|
gross_pct=4.0,
|
||||||
net_pct=3.0,
|
net_pct=3.0,
|
||||||
est_profit=0.03,
|
est_profit=0.03,
|
||||||
|
allocated_capital=allocated_capital,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -110,3 +141,184 @@ async def test_market_data_feed_live_mode_executes_orders() -> None:
|
|||||||
|
|
||||||
assert len(executor.calls) == 1
|
assert len(executor.calls) == 1
|
||||||
assert executor.calls[0].cycle == "USD->BTC->ETH->USD"
|
assert executor.calls[0].cycle == "USD->BTC->ETH->USD"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_enforces_per_trade_capital_limit() -> None:
|
||||||
|
event = _sample_event()
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClient(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=True,
|
||||||
|
trade_capital=250.0,
|
||||||
|
max_trade_capital=100.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert detector.last_base_capital == 100.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_auto_halts_on_daily_loss_limit() -> None:
|
||||||
|
event = _sample_event()
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
executor = _FakeExecutor()
|
||||||
|
executor.realized_pnls = [-60.0, -10.0]
|
||||||
|
loss_guard = LossLimitGuard(daily_loss_limit=50.0)
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=False,
|
||||||
|
opportunity_executor=executor.execute,
|
||||||
|
loss_limit_guard=loss_guard,
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert len(executor.calls) == 1
|
||||||
|
assert loss_guard.is_halted
|
||||||
|
assert loss_guard.halted_reason == "daily_loss_limit_breached"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_auto_halts_on_cumulative_loss_limit() -> None:
|
||||||
|
event = _sample_event()
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
executor = _FakeExecutor()
|
||||||
|
executor.realized_pnls = [-40.0, -15.0]
|
||||||
|
loss_guard = LossLimitGuard(cumulative_loss_limit=50.0)
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=False,
|
||||||
|
opportunity_executor=executor.execute,
|
||||||
|
loss_limit_guard=loss_guard,
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert len(executor.calls) == 2
|
||||||
|
assert loss_guard.is_halted
|
||||||
|
assert loss_guard.halted_reason == "cumulative_loss_limit_breached"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_enforces_max_concurrent_trades() -> None:
|
||||||
|
event = _sample_event()
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
executor = _FakeExecutor()
|
||||||
|
executor.outcomes = [ExecutionOutcome(
|
||||||
|
realized_pnl=None, close_trade=False)]
|
||||||
|
trade_guard = TradeLimitsGuard(max_concurrent_trades=1)
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClientTwoMessages(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=False,
|
||||||
|
opportunity_executor=executor.execute,
|
||||||
|
trade_limits_guard=trade_guard,
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert len(executor.calls) == 1
|
||||||
|
assert trade_guard.active_trades == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_enforces_per_asset_exposure_cap() -> None:
|
||||||
|
event = _sample_event(allocated_capital=100.0)
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
executor = _FakeExecutor()
|
||||||
|
trade_guard = TradeLimitsGuard(max_exposure_per_asset=50.0)
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClient(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=False,
|
||||||
|
opportunity_executor=executor.execute,
|
||||||
|
trade_limits_guard=trade_guard,
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert len(executor.calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_blocks_when_pre_trade_balance_insufficient() -> None:
|
||||||
|
event = _sample_event(allocated_capital=100.0)
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
executor = _FakeExecutor()
|
||||||
|
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClient(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=False,
|
||||||
|
opportunity_executor=executor.execute,
|
||||||
|
pre_trade_validator=validator,
|
||||||
|
balance_provider=lambda: {"USD": 25.0},
|
||||||
|
quote_balance_asset="USD",
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert len(executor.calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_blocks_when_pre_trade_min_order_not_met() -> None:
|
||||||
|
event = _sample_event(allocated_capital=25.0)
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
executor = _FakeExecutor()
|
||||||
|
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClient(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=False,
|
||||||
|
opportunity_executor=executor.execute,
|
||||||
|
pre_trade_validator=validator,
|
||||||
|
balance_provider=lambda: {"USD": 500.0},
|
||||||
|
quote_balance_asset="USD",
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert len(executor.calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_market_data_feed_allows_when_pre_trade_validation_passes() -> None:
|
||||||
|
event = _sample_event(allocated_capital=75.0)
|
||||||
|
detector = _FakeDetector(event)
|
||||||
|
executor = _FakeExecutor()
|
||||||
|
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||||
|
feed = MarketDataFeed(
|
||||||
|
ws_client=_FakeWsClient(_sample_delta()),
|
||||||
|
snapshot_writer=_FakeSnapshotWriter(),
|
||||||
|
detector=detector,
|
||||||
|
opportunity_writer=_FakeOpportunityWriter(),
|
||||||
|
paper_trading_mode=False,
|
||||||
|
opportunity_executor=executor.execute,
|
||||||
|
pre_trade_validator=validator,
|
||||||
|
balance_provider=lambda: {"USD": 500.0},
|
||||||
|
quote_balance_asset="USD",
|
||||||
|
)
|
||||||
|
|
||||||
|
await feed.run()
|
||||||
|
|
||||||
|
assert len(executor.calls) == 1
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbitrade.risk.pre_trade import PreTradeValidator
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_trade_validator_accepts_when_balance_and_min_size_pass() -> None:
|
||||||
|
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||||
|
|
||||||
|
assert validator.validate(
|
||||||
|
balances_by_asset={"USD": 100.0},
|
||||||
|
required_by_asset={"USD": 75.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_trade_validator_rejects_when_balance_insufficient() -> None:
|
||||||
|
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||||
|
|
||||||
|
assert not validator.validate(
|
||||||
|
balances_by_asset={"USD": 40.0},
|
||||||
|
required_by_asset={"USD": 75.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_trade_validator_rejects_when_below_min_size() -> None:
|
||||||
|
validator = PreTradeValidator(min_order_size_by_asset={"USD": 50.0})
|
||||||
|
|
||||||
|
assert not validator.validate(
|
||||||
|
balances_by_asset={"USD": 100.0},
|
||||||
|
required_by_asset={"USD": 30.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_trade_validator_rejects_invalid_min_order_size_config() -> None:
|
||||||
|
with pytest.raises(ValueError, match="minimum order size"):
|
||||||
|
PreTradeValidator(min_order_size_by_asset={"USD": 0.0})
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbitrade.risk.trade_limits import TradeLimitsGuard
|
||||||
|
|
||||||
|
|
||||||
|
def test_trade_limits_guard_blocks_when_max_concurrent_reached() -> None:
|
||||||
|
guard = TradeLimitsGuard(max_concurrent_trades=1)
|
||||||
|
|
||||||
|
guard.open_trade({"BTC": 10.0})
|
||||||
|
|
||||||
|
assert not guard.is_trade_allowed({"BTC": 1.0})
|
||||||
|
|
||||||
|
|
||||||
|
def test_trade_limits_guard_blocks_when_asset_exposure_would_breach_cap() -> None:
|
||||||
|
guard = TradeLimitsGuard(max_exposure_per_asset=100.0)
|
||||||
|
|
||||||
|
guard.open_trade({"BTC": 80.0})
|
||||||
|
|
||||||
|
assert not guard.is_trade_allowed({"BTC": 25.0})
|
||||||
|
assert guard.is_trade_allowed({"ETH": 25.0})
|
||||||
|
|
||||||
|
|
||||||
|
def test_trade_limits_guard_releases_exposure_on_close() -> None:
|
||||||
|
guard = TradeLimitsGuard(max_concurrent_trades=2, max_exposure_per_asset=100.0)
|
||||||
|
|
||||||
|
guard.open_trade({"BTC": 80.0})
|
||||||
|
guard.close_trade({"BTC": 80.0})
|
||||||
|
|
||||||
|
assert guard.active_trades == 0
|
||||||
|
assert guard.exposure_for_asset("BTC") == 0.0
|
||||||
|
assert guard.is_trade_allowed({"BTC": 100.0})
|
||||||
|
|
||||||
|
|
||||||
|
def test_trade_limits_guard_rejects_invalid_configuration() -> None:
|
||||||
|
with pytest.raises(ValueError, match="max_concurrent_trades"):
|
||||||
|
TradeLimitsGuard(max_concurrent_trades=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="max_exposure_per_asset"):
|
||||||
|
TradeLimitsGuard(max_exposure_per_asset=0.0)
|
||||||
Reference in New Issue
Block a user