from __future__ import annotations import asyncio from typing import Any import pytest from arbitrade.risk.trade_limits import TradeLimitsGuard class _FakeAlertNotifier: def __init__(self) -> None: self.events: list[dict[str, Any]] = [] async def notify( self, *, category: str, severity: str, title: str, message: str, details: dict[str, str] | None = None, ) -> bool: self.events.append( { "category": category, "severity": severity, "title": title, "message": message, "details": details or {}, } ) return True def test_trade_limits_guard_blocks_when_max_concurrent_reached() -> None: guard = TradeLimitsGuard(max_concurrent_trades=1) guard.open_trade({"BTC": 10.0}) assert not guard.is_trade_allowed({"BTC": 1.0}) def test_trade_limits_guard_blocks_when_asset_exposure_would_breach_cap() -> None: guard = TradeLimitsGuard(max_exposure_per_asset=100.0) guard.open_trade({"BTC": 80.0}) assert not guard.is_trade_allowed({"BTC": 25.0}) assert guard.is_trade_allowed({"ETH": 25.0}) def test_trade_limits_guard_releases_exposure_on_close() -> None: guard = TradeLimitsGuard(max_concurrent_trades=2, max_exposure_per_asset=100.0) guard.open_trade({"BTC": 80.0}) guard.close_trade({"BTC": 80.0}) assert guard.active_trades == 0 assert guard.exposure_for_asset("BTC") == 0.0 assert guard.is_trade_allowed({"BTC": 100.0}) def test_trade_limits_guard_rejects_invalid_configuration() -> None: with pytest.raises(ValueError, match="max_concurrent_trades"): TradeLimitsGuard(max_concurrent_trades=0) with pytest.raises(ValueError, match="max_exposure_per_asset"): TradeLimitsGuard(max_exposure_per_asset=0.0) @pytest.mark.asyncio async def test_trade_limits_guard_emits_alert_when_rejecting_trade() -> None: notifier = _FakeAlertNotifier() guard = TradeLimitsGuard(max_concurrent_trades=1, alert_notifier=notifier) guard.open_trade({"BTC": 10.0}) allowed = guard.is_trade_allowed({"BTC": 1.0}) await asyncio.sleep(0) assert not allowed assert len(notifier.events) == 1 assert notifier.events[0]["category"] == "threshold" assert notifier.events[0]["title"] == "Concurrent trade limit reached"