diff --git a/.env.example b/.env.example index 042c9e5..75c25a3 100644 --- a/.env.example +++ b/.env.example @@ -3,10 +3,31 @@ APP_HOST=0.0.0.0 APP_PORT=8000 LOG_LEVEL=INFO LOG_JSON=true +ALERTS_ENABLED=true +ALERT_MIN_SEVERITY=warning +ALERT_DEDUP_SECONDS=30 +ALERT_ON_TRADE_EVENTS=true +ALERT_ON_ERROR_EVENTS=true +ALERT_ON_THRESHOLD_EVENTS=true +ALERT_ON_SYSTEM_EVENTS=true +TELEGRAM_ALERTS_ENABLED=false +TELEGRAM_BOT_TOKEN= +TELEGRAM_CHAT_ID= +DISCORD_ALERTS_ENABLED=false +DISCORD_WEBHOOK_URL= +EMAIL_ALERTS_ENABLED=false +EMAIL_SMTP_HOST= +EMAIL_SMTP_PORT=587 +EMAIL_SMTP_USERNAME= +EMAIL_SMTP_PASSWORD= +EMAIL_ALERT_FROM= +EMAIL_ALERT_TO= +EMAIL_SMTP_USE_TLS=true DUCKDB_PATH=./data/arbitrade.duckdb FERNET_KEY= KRAKEN_API_KEY= KRAKEN_API_SECRET= +KRAKEN_API_KEY_PERMISSIONS=query,trade KRAKEN_REST_URL=https://api.kraken.com KRAKEN_WS_URL=wss://ws.kraken.com/v2 KRAKEN_PRIVATE_RATE_LIMIT_SECONDS=1.0 diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index c009ed3..6a704b1 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -23,6 +23,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .[dev] + pip install pip-audit - name: Ruff run: ruff check . @@ -33,6 +34,12 @@ jobs: - name: MyPy run: mypy src + - name: Dependency audit + run: pip-audit --skip-editable + + - name: Secret scan (worktree + git history) + run: python scripts/security_scan.py + - name: Tests run: pytest -q diff --git a/CHANGELOG.md b/CHANGELOG.md index cff5699..6cdfb31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,27 @@ +# Changelog + ## [Unreleased] - 2026-06-01 ### Added - Added stop-condition risk controls for abnormal source/apply latency and repeated execution failures. - Added a new stop-conditions guard and integration in market feed processing. +- Added multi-channel alerting infrastructure with Telegram, Discord webhook, and SMTP channel clients. 
+- Added alert configuration settings for severity threshold, category routing, and dedup cooldown. +- Added dashboard alert status surfacing with configured channels and last-send delivery outcome. +- Added append-only `audit_events` schema plus repository support for insert/query of recent audit records. +- Added dashboard audit fragment and protected API endpoint for recent audit entries. +- Added runtime lifecycle manager with startup recovery and graceful shutdown orchestration. +- Added `runtime_state_snapshots` persistence for control flags, open trade count, and last known balances. +- Added CI security gates for dependency auditing (`pip-audit --skip-editable`) and a repository/worktree secret scan script. +- Added strict settings validators for auth pairing, Kraken credential pairing, alert severity bounds, and key-scope policy. ### Changed - Live execution path now auto-activates the kill switch when configured stop conditions are breached. - Added configuration env keys for stop-condition thresholds. +- WebSocket client now emits system alerts for disconnect/reconnect and heartbeat staleness timeout events. +- Added explicit Kraken API key permission configuration (`KRAKEN_API_KEY_PERMISSIONS`) and docs for least-privilege key usage. ### Removed @@ -27,3 +40,13 @@ - Added dashboard controls for start/stop, config edits, and manual kill-switch triggering via HTMX POST forms. - Added Alpine.js interactivity and a Chart.js opportunity trend panel to the dashboard. - Added optional HTTP Basic authentication for dashboard routes, fragments, streams, and control endpoints. +- Added alert wiring for dashboard control actions, execution success/failure, and threshold breaches in risk guards. +- Added unit/integration tests covering alert notifier behavior and alert emission paths. +- Added critical system alert emission when live opportunity executor raises an unhandled exception. +- Added WebSocket and market-feed tests for system-event alerting paths. 
+- Added notifier status snapshot tracking and protected alert-status API endpoint for operational visibility. +- Added audit event writes for dashboard controls and detector/risk/execution decision points. +- Added tests for audit repository and dashboard audit route coverage. +- Added startup restart safety guard that halts execution when open trades are detected after restart. +- Added lifecycle tests for snapshot persistence, worker draining, recovery restore, and startup reconciliation hook. +- Added unit coverage for security-related settings validation paths. diff --git a/README.md b/README.md index a7d0fc8..ee1789d 100644 --- a/README.md +++ b/README.md @@ -105,11 +105,14 @@ DUCKDB_PATH=./data/arbitrade.duckdb FERNET_KEY= KRAKEN_API_KEY= KRAKEN_API_SECRET= +KRAKEN_API_KEY_PERMISSIONS=query,trade ``` Notes: - Leave Kraken creds empty until Kraken integration lands. +- If Kraken creds are set, both key and secret are required. +- `KRAKEN_API_KEY_PERMISSIONS` must include `query,trade` and must not include withdrawal scope. - `FERNET_KEY` optional. If empty, keyring-backed key generation used by secret helper. - On Windows, app falls back to default `asyncio` loop. On non-Windows, `uvloop` installs automatically. @@ -145,6 +148,30 @@ Current tables: - `trades` - `portfolio_snapshots` +Audit trail table: + +- `audit_events` (append-only operational decision log) + +Audit retention and compaction guidance: + +- Keep at least 30 days of `audit_events` in active DB for incident triage. +- Archive older rows to a timestamped export file before deletion. +- Example monthly archive workflow: + +```sql +COPY ( + SELECT * + FROM audit_events + WHERE occurred_at < NOW() - INTERVAL 30 DAY +) TO 'data/audit_events_archive_YYYYMM.parquet' (FORMAT PARQUET); + +DELETE FROM audit_events +WHERE occurred_at < NOW() - INTERVAL 30 DAY; +``` + +- Back up archive files and the main DuckDB file together. 
+- For production, run archive + backup as scheduled maintenance (cron/task scheduler). + ## Quality Checks Run tests: @@ -171,6 +198,18 @@ Run mypy: mypy src ``` +Run dependency vulnerability audit: + +```powershell +pip-audit --skip-editable +``` + +Run secret scan (worktree + git history): + +```powershell +python scripts/security_scan.py +``` + Install pre-commit hooks: ```powershell @@ -282,3 +321,20 @@ uv pip install -e .[dev] ``` If DuckDB file missing, start app once or create `data/` directory manually. + +## Security Hardening + +Threat model notes: + +- Primary risk surfaces: environment secrets, dashboard auth credentials, exchange API key scope, and dependency supply chain. +- Assumed attacker model: leaked repository content, leaked CI logs/artifacts, or unauthorized dashboard access. +- High-impact outcomes to prevent: credential exfiltration, unauthorized withdrawals, and unsafe live-trading control changes. + +Hardening checklist: + +- Use least-privilege Kraken API keys: query + trade only; never enable withdrawal. +- Rotate API keys immediately if secret scan flags a potential exposure. +- Keep dashboard auth enabled in non-local environments and avoid default/shared credentials. +- Run `pip-audit --skip-editable` in CI; treat vulnerability findings as release blockers. +- Run `python scripts/security_scan.py` before release and after major merges. +- Store secrets in environment/secret manager; never commit `.env` or key material. 
# Repository root: this script lives in <root>/scripts/.
WORKSPACE = Path(__file__).resolve().parents[1]

# Directory names that are never scanned. They are matched against the
# components of the repo-relative path, never against the absolute path.
EXCLUDED_DIRS = {
    ".git",
    ".venv",
    "__pycache__",
    ".mypy_cache",
    ".pytest_cache",
    "data",
    "node_modules",
}

# (rule name, compiled pattern) pairs. One hit per rule is enough to report it.
PATTERNS: list[tuple[str, re.Pattern[str]]] = [
    ("private_key", re.compile(r"-----BEGIN [A-Z ]*PRIVATE KEY-----")),
    ("aws_access_key", re.compile(r"AKIA[0-9A-Z]{16}")),
    (
        "generic_token",
        re.compile(
            r"(?i)(token|secret|password)\s*[:=]\s*['\"]?"
            r"(?=[A-Za-z0-9_\-+/=.]{20,})(?=.*[A-Za-z])(?=.*\d)[A-Za-z0-9_\-+/=.]{20,}"
        ),
    ),
]


def _git_output(*args: str) -> tuple[bool, str]:
    """Run ``git -C WORKSPACE <args>`` and return ``(succeeded, stdout)``.

    Output is decoded as UTF-8 with undecodable bytes replaced, so history that
    contains binary or legacy-encoded content cannot abort the scan with a
    ``UnicodeDecodeError`` (which strict locale decoding would raise).
    """
    try:
        completed = subprocess.run(
            ["git", "-C", str(WORKSPACE), *args],
            check=False,
            capture_output=True,
            encoding="utf-8",
            errors="replace",
        )
    except OSError:
        # git missing or not executable: report a failed scan instead of crashing.
        return False, ""
    return completed.returncode == 0, completed.stdout


def _is_excluded(rel_path: str) -> bool:
    """Return True when any component of the repo-relative path is excluded."""
    # Only the repo-relative parts are inspected. Checking the absolute path
    # would skip every file whenever the checkout itself lives under a
    # directory named e.g. "data" or "node_modules".
    return any(part in EXCLUDED_DIRS for part in Path(rel_path).parts)


def _is_probably_text(path: Path) -> bool:
    """Return True when the first 2 KiB of ``path`` contain no NUL byte.

    Unreadable files are reported as non-text so callers simply skip them.
    """
    try:
        with path.open("rb") as handle:
            sample = handle.read(2048)
    except OSError:
        return False
    return b"\x00" not in sample


def _matching_rules(text: str) -> list[str]:
    """Return the names of all rules in ``PATTERNS`` that match ``text``."""
    return [rule_name for rule_name, pattern in PATTERNS if pattern.search(text)]


def scan_worktree() -> list[str]:
    """Scan every git-tracked text file in the worktree for secret patterns.

    Returns:
        One ``worktree:<relative path>:<rule>`` entry per matching file/rule, or
        ``["worktree_scan_failed"]`` when the tracked-file list cannot be obtained.
    """
    # -z gives NUL-separated, unquoted paths; without it git C-quotes names
    # containing non-ASCII or special characters and they would not be found.
    succeeded, listing = _git_output("ls-files", "-z")
    if not succeeded:
        return ["worktree_scan_failed"]

    findings: list[str] = []
    for rel_path in listing.split("\0"):
        if not rel_path or _is_excluded(rel_path):
            continue
        path = WORKSPACE / rel_path
        if not path.is_file() or not _is_probably_text(path):
            continue

        try:
            content = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            continue

        findings.extend(
            f"worktree:{path.relative_to(WORKSPACE)}:{rule_name}"
            for rule_name in _matching_rules(content)
        )
    return findings


def scan_git_history() -> list[str]:
    """Scan the patch text of every commit on every ref for secret patterns.

    Returns:
        One ``history:<rule>`` entry per matching rule, or
        ``["history_scan_failed"]`` when ``git log`` cannot be run.
    """
    succeeded, patch_text = _git_output("log", "--all", "-p", "--pretty=format:%H")
    if not succeeded:
        return ["history_scan_failed"]
    return [f"history:{rule_name}" for rule_name in _matching_rules(patch_text)]


def main() -> int:
    """Run both scans, print a report, and return the process exit code.

    Returns:
        ``1`` when any finding (including a failed scan) is reported, else ``0``.
    """
    findings = [*scan_worktree(), *scan_git_history()]
    if findings:
        print("Security scan found potential secrets:")
        for finding in findings:
            print(f"- {finding}")
        print("Rotate any exposed credentials immediately.")
        return 1

    print("Security scan passed: no obvious secrets detected in worktree/history.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
category: str + severity: AlertSeverity + title: str + message: str + occurred_at: datetime + details: dict[str, str] + + +class AlertChannel(Protocol): + async def send(self, event: AlertEvent) -> None: ... + + +class SupportsAlerts(Protocol): + async def notify( + self, + *, + category: str, + severity: AlertSeverity, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: ... + + +@runtime_checkable +class SupportsAlertStatus(Protocol): + def status_snapshot(self) -> dict[str, object]: ... + + +class AlertNotifier: + def __init__( + self, + channels: list[AlertChannel], + *, + enabled: bool = True, + min_severity: AlertSeverity = "info", + dedup_seconds: float = 0.0, + category_flags: dict[str, bool] | None = None, + ) -> None: + if dedup_seconds < 0.0: + raise ValueError("dedup_seconds must be >= 0.0") + self._channels = channels + self._enabled = enabled + self._min_severity: AlertSeverity = min_severity + self._dedup_seconds = dedup_seconds + self._category_flags = {key.lower(): value for key, value in (category_flags or {}).items()} + self._last_sent_at: dict[str, datetime] = {} + self._last_result: str = "never" + self._last_attempted_at: datetime | None = None + self._last_success_at: datetime | None = None + self._last_error: str | None = None + self._last_event_title: str | None = None + self._last_event_category: str | None = None + self._last_event_severity: AlertSeverity | None = None + self._last_channel_results: list[str] = [] + + @property + def has_channels(self) -> bool: + return bool(self._channels) + + async def notify( + self, + *, + category: str, + severity: AlertSeverity, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: + if not self._enabled or not self._channels: + self._last_result = "skipped_disabled" if not self._enabled else "skipped_no_channels" + return False + + normalized_category = category.strip().lower() + if self._category_flags and not 
self._category_flags.get(normalized_category, True): + self._last_result = "skipped_category" + return False + + if _SEVERITY_RANK[severity] < _SEVERITY_RANK[self._min_severity]: + self._last_result = "skipped_severity" + return False + + dedup_key = f"{normalized_category}|{severity}|{title}|{message}" + now = datetime.now(UTC) + if self._dedup_seconds > 0.0: + previous = self._last_sent_at.get(dedup_key) + if previous is not None: + elapsed = (now - previous).total_seconds() + if elapsed < self._dedup_seconds: + self._last_result = "skipped_dedup" + return False + + event = AlertEvent( + category=normalized_category, + severity=severity, + title=title, + message=message, + occurred_at=now, + details=details or {}, + ) + + results = await asyncio.gather( + *(channel.send(event) for channel in self._channels), + return_exceptions=True, + ) + self._last_attempted_at = now + self._last_event_title = title + self._last_event_category = normalized_category + self._last_event_severity = severity + self._last_channel_results = [] + for channel, result in zip(self._channels, results, strict=False): + channel_name = type(channel).__name__ + if isinstance(result, Exception): + self._last_channel_results.append(f"{channel_name}: error") + else: + self._last_channel_results.append(f"{channel_name}: ok") + + if all(isinstance(result, Exception) for result in results): + self._last_result = "failed" + self._last_error = "all channels failed" + return False + + self._last_result = ( + "partial_success" + if any(isinstance(result, Exception) for result in results) + else "success" + ) + self._last_error = None + self._last_success_at = now + + self._last_sent_at[dedup_key] = now + return True + + def status_snapshot(self) -> dict[str, object]: + return { + "enabled": self._enabled, + "has_channels": self.has_channels, + "configured_channels": [type(channel).__name__ for channel in self._channels], + "min_severity": self._min_severity, + "dedup_seconds": self._dedup_seconds, + 
"last_result": self._last_result, + "last_attempted_at": ( + self._last_attempted_at.isoformat() if self._last_attempted_at is not None else None + ), + "last_success_at": ( + self._last_success_at.isoformat() if self._last_success_at is not None else None + ), + "last_error": self._last_error, + "last_event": ( + None + if self._last_event_title is None + else { + "title": self._last_event_title, + "category": self._last_event_category, + "severity": self._last_event_severity, + } + ), + "last_channel_results": self._last_channel_results, + } + + +def dispatch_alert_nowait( + notifier: SupportsAlerts | None, + *, + category: str, + severity: AlertSeverity, + title: str, + message: str, + details: dict[str, str] | None = None, +) -> None: + if notifier is None: + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + + loop.create_task( + notifier.notify( + category=category, + severity=severity, + title=title, + message=message, + details=details, + ) + ) + + +def _format_event_text(event: AlertEvent) -> str: + lines = [ + f"[{event.severity.upper()}] {event.title}", + f"Category: {event.category}", + f"Time: {event.occurred_at.isoformat()}", + event.message, + ] + if event.details: + lines.append("Details:") + for key, value in sorted(event.details.items()): + lines.append(f"- {key}: {value}") + return "\n".join(lines) + + +class TelegramChannel: + def __init__(self, *, bot_token: str, chat_id: str, timeout_seconds: float = 10.0) -> None: + self._bot_token = bot_token + self._chat_id = chat_id + self._timeout_seconds = timeout_seconds + + async def send(self, event: AlertEvent) -> None: + url = f"https://api.telegram.org/bot{self._bot_token}/sendMessage" + payload = { + "chat_id": self._chat_id, + "text": _format_event_text(event), + "disable_web_page_preview": True, + } + timeout = httpx.Timeout(self._timeout_seconds) + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post(url, json=payload) + 
class DiscordWebhookChannel:
    """Alert channel that posts the formatted event text to a Discord webhook."""

    def __init__(self, *, webhook_url: str, timeout_seconds: float = 10.0) -> None:
        self._webhook_url = webhook_url
        self._timeout_seconds = timeout_seconds

    async def send(self, event: AlertEvent) -> None:
        """POST the event as JSON ``content``; raises on a non-2xx response."""
        payload = {"content": _format_event_text(event)}
        timeout = httpx.Timeout(self._timeout_seconds)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(self._webhook_url, json=payload)
            response.raise_for_status()


class EmailSmtpChannel:
    """Alert channel that delivers events as plain-text e-mail over SMTP."""

    def __init__(
        self,
        *,
        host: str,
        port: int,
        sender: str,
        recipients: list[str],
        username: str | None = None,
        password: str | None = None,
        use_tls: bool = True,
        timeout_seconds: float = 10.0,
    ) -> None:
        """Store the SMTP connection settings.

        Args:
            host: SMTP server host name.
            port: SMTP server port.
            sender: Address used for the ``From`` header.
            recipients: Addresses used for the ``To`` header.
            username: Login name; login is attempted only with a password too.
            password: Login password.
            use_tls: Upgrade the connection with STARTTLS before sending.
            timeout_seconds: Socket timeout for the SMTP connection.

        Raises:
            ValueError: If ``recipients`` is empty.
        """
        if not recipients:
            raise ValueError("recipients must not be empty")

        self._host = host
        self._port = port
        self._sender = sender
        # Copy so later mutation of the caller's list cannot change (or empty)
        # the recipient set that was validated above.
        self._recipients = list(recipients)
        self._username = username
        self._password = password
        self._use_tls = use_tls
        self._timeout_seconds = timeout_seconds

    async def send(self, event: AlertEvent) -> None:
        """Build the message and send it on a worker thread (smtplib blocks)."""
        message = EmailMessage()
        message["From"] = self._sender
        message["To"] = ", ".join(self._recipients)
        message["Subject"] = f"[{event.severity.upper()}] {event.title}"
        message.set_content(_format_event_text(event))

        await asyncio.to_thread(self._send_sync, message)

    def _send_sync(self, message: EmailMessage) -> None:
        """Blocking SMTP delivery: connect, optional STARTTLS and login, send."""
        # Function-scope import keeps this block self-contained.
        import ssl

        with smtplib.SMTP(self._host, self._port, timeout=self._timeout_seconds) as client:
            if self._use_tls:
                # Pass an explicit verifying context: without one smtplib falls
                # back to a context that skips certificate and hostname checks,
                # which would let a man-in-the-middle capture the credentials.
                client.starttls(context=ssl.create_default_context())
            if self._username and self._password:
                client.login(self._username, self._password)
            client.send_message(message)
def _as_alert_severity(value: str) -> AlertSeverity:
    """Normalize a configured severity string to an ``AlertSeverity`` literal.

    Surrounding whitespace and letter case are ignored.

    Raises:
        ValueError: If the value is not one of info, warning, error, critical.
    """
    normalized = value.strip().lower()
    # Explicit literal returns keep the result typed as a Literal.
    if normalized == "info":
        return "info"
    if normalized == "warning":
        return "warning"
    if normalized == "error":
        return "error"
    if normalized == "critical":
        return "critical"
    raise ValueError("alert_min_severity must be one of: info, warning, error, critical")


def build_channels_from_settings(settings: _AlertSettings) -> list[AlertChannel]:
    """Create one channel client per enabled channel in ``settings``.

    Returns:
        Channels in the fixed order Telegram, Discord, e-mail; empty when no
        channel is enabled.

    Raises:
        ValueError: If an enabled channel is missing a required setting.
    """
    channels: list[AlertChannel] = []

    if settings.telegram_alerts_enabled:
        if not settings.telegram_bot_token or not settings.telegram_chat_id:
            raise ValueError("telegram alerts require bot token and chat id")
        channels.append(
            TelegramChannel(
                bot_token=settings.telegram_bot_token,
                chat_id=settings.telegram_chat_id,
            )
        )

    if settings.discord_alerts_enabled:
        if not settings.discord_webhook_url:
            raise ValueError("discord alerts require webhook url")
        channels.append(DiscordWebhookChannel(webhook_url=settings.discord_webhook_url))

    if settings.email_alerts_enabled:
        if not settings.email_smtp_host:
            raise ValueError("email alerts require SMTP host")
        if not settings.email_alert_from:
            raise ValueError("email alerts require sender address")

        # Validate the parsed list, not the raw string: a value such as " , "
        # is non-empty but contains no address, and must fail with the same
        # configuration error as a missing value.
        recipients = [
            address.strip()
            for address in (settings.email_alert_to or "").split(",")
            if address.strip()
        ]
        if not recipients:
            raise ValueError("email alerts require recipient list")

        channels.append(
            EmailSmtpChannel(
                host=settings.email_smtp_host,
                port=settings.email_smtp_port,
                sender=settings.email_alert_from,
                recipients=recipients,
                username=settings.email_smtp_username,
                password=settings.email_smtp_password,
                use_tls=settings.email_smtp_use_tls,
            )
        )

    return channels


def build_notifier_from_settings(settings: _AlertSettings) -> AlertNotifier:
    """Build the application's ``AlertNotifier`` from settings.

    Raises:
        ValueError: If the minimum severity or any enabled channel is misconfigured.
    """
    severity = _as_alert_severity(settings.alert_min_severity)
    channels = build_channels_from_settings(settings)
    category_flags = {
        "trade": settings.alert_on_trade_events,
        "error": settings.alert_on_error_events,
        "threshold": settings.alert_on_threshold_events,
        "system": settings.alert_on_system_events,
    }
    return AlertNotifier(
        channels,
        enabled=settings.alerts_enabled,
        min_severity=severity,
        dedup_seconds=settings.alert_dedup_seconds,
        category_flags=category_flags,
    )
lifespan(app: FastAPI) -> AsyncIterator[None]: + await restore_runtime_state(app) + yield + await graceful_shutdown(app) + + app = FastAPI(title="arbitrade", version="0.1.0", lifespan=lifespan) app.state.settings = settings app.state.store = db app.state.metrics = MetricsCalculator(db) + app.state.audit_repository = AuditRepository(db) + app.state.runtime_state_repository = RuntimeStateRepository(db) + app.state.alert_notifier = build_notifier_from_settings(settings) app.state.dashboard_controls = DashboardControlState( is_running=not settings.kill_switch_active, ) diff --git a/src/arbitrade/api/routes.py b/src/arbitrade/api/routes.py index c3fff79..312b888 100644 --- a/src/arbitrade/api/routes.py +++ b/src/arbitrade/api/routes.py @@ -12,8 +12,10 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from fastapi.templating import Jinja2Templates +from arbitrade.alerting.notifier import SupportsAlerts, SupportsAlertStatus from arbitrade.api.auth import require_dashboard_auth from arbitrade.api.control_state import DashboardControlState +from arbitrade.storage.repositories import AuditRecord, AuditRepository router = APIRouter(dependencies=[Depends(require_dashboard_auth)]) public_router = APIRouter() @@ -146,14 +148,11 @@ def _dashboard_charts(request: Request) -> dict[str, object]: chart_rows = list(reversed(opportunity_rows)) labels = [ - row[0].isoformat() if isinstance( - row[0], datetime) else f"opportunity-{index + 1}" + row[0].isoformat() if isinstance(row[0], datetime) else f"opportunity-{index + 1}" for index, row in enumerate(chart_rows) ] - net_pct_values = [float(row[2]) if row[2] - is not None else 0.0 for row in chart_rows] - est_profit_values = [float(row[3]) if row[3] - is not None else 0.0 for row in chart_rows] + net_pct_values = [float(row[2]) if row[2] is not None else 0.0 for row in chart_rows] + est_profit_values = [float(row[3]) if row[3] is not None else 0.0 for row in 
chart_rows] cycles = [str(row[1]) for row in chart_rows] return { @@ -170,9 +169,109 @@ def _dashboard_controls_state(request: Request) -> DashboardControlState: return cast(DashboardControlState, request.app.state.dashboard_controls) +def _audit_repository(request: Request) -> AuditRepository | None: + repository = getattr(request.app.state, "audit_repository", None) + return cast(AuditRepository | None, repository) + + +def _record_audit( + request: Request, + *, + actor: str, + event_type: str, + decision: str, + payload: dict[str, object] | None = None, +) -> None: + repository = _audit_repository(request) + if repository is None: + return + correlation_id = request.headers.get("x-request-id") + repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor=actor, + event_type=event_type, + decision=decision, + payload=None if payload is None else {str(key): payload[key] for key in payload}, + correlation_id=correlation_id, + ) + ) + + +def _dashboard_audit(request: Request, *, limit: int = 15) -> dict[str, object]: + repository = _audit_repository(request) + if repository is None: + return { + "entries": [], + "generated_at": datetime.now(UTC).isoformat(), + } + + records = repository.list_recent(limit=limit) + entries: list[dict[str, str]] = [] + for record in records: + payload_text = "—" + if record.payload: + payload_text = json.dumps(record.payload) + entries.append( + { + "occurred_at": record.occurred_at.isoformat(), + "actor": record.actor, + "event_type": record.event_type, + "decision": record.decision, + "payload": payload_text, + "correlation_id": record.correlation_id or "—", + } + ) + + return { + "entries": entries, + "generated_at": datetime.now(UTC).isoformat(), + } + + +def _alert_notifier(request: Request) -> SupportsAlerts | None: + notifier = getattr(request.app.state, "alert_notifier", None) + return cast(SupportsAlerts | None, notifier) + + +def _alert_status_snapshot(request: Request) -> dict[str, object]: + notifier = 
getattr(request.app.state, "alert_notifier", None) + if isinstance(notifier, SupportsAlertStatus): + return notifier.status_snapshot() + return { + "enabled": False, + "has_channels": False, + "configured_channels": [], + "min_severity": "—", + "dedup_seconds": 0.0, + "last_result": "unavailable", + "last_attempted_at": None, + "last_success_at": None, + "last_error": None, + "last_event": None, + "last_channel_results": [], + } + + def _dashboard_controls(request: Request) -> dict[str, object]: controls = _dashboard_controls_state(request) settings = request.app.state.settings + alert_status = _alert_status_snapshot(request) + last_event = alert_status.get("last_event") + last_event_title = "—" + if isinstance(last_event, dict): + title_value = last_event.get("title") + if isinstance(title_value, str): + last_event_title = title_value + + configured_channels = alert_status.get("configured_channels") + channels_display = "—" + if isinstance(configured_channels, list) and configured_channels: + channels_display = ", ".join(str(channel) for channel in configured_channels) + + dedup_seconds_raw = alert_status.get("dedup_seconds", 0.0) + dedup_seconds = float(dedup_seconds_raw) if isinstance(dedup_seconds_raw, int | float) else 0.0 + return { "execution_status": "running" if controls.is_running else "stopped", "kill_switch_status": "active" if controls.kill_switch.is_active else "inactive", @@ -191,13 +290,23 @@ def _dashboard_controls(request: Request) -> dict[str, object]: else f"{float(settings.max_trade_capital_usd):.2f}" ), "max_concurrent_trades": ( - "—" if settings.max_concurrent_trades is None else str( - settings.max_concurrent_trades) + "—" if settings.max_concurrent_trades is None else str(settings.max_concurrent_trades) ), "max_concurrent_trades_value": ( - "" if settings.max_concurrent_trades is None else str( - settings.max_concurrent_trades) + "" if settings.max_concurrent_trades is None else str(settings.max_concurrent_trades) ), + "alerts_enabled": 
"enabled" if bool(alert_status.get("enabled", False)) else "disabled", + "alerts_channels": channels_display, + "alerts_min_severity": str(alert_status.get("min_severity", "—")), + "alerts_dedup_seconds": f"{dedup_seconds:.0f}", + "alerts_last_result": str(alert_status.get("last_result", "unavailable")), + "alerts_last_attempted_at": str(alert_status.get("last_attempted_at") or "—"), + "alerts_last_success_at": str(alert_status.get("last_success_at") or "—"), + "alerts_last_event_title": last_event_title, + "alerts_last_error": str(alert_status.get("last_error") or "—"), + "alerts_last_channel_results": [ + str(item) for item in cast(list[object], alert_status.get("last_channel_results", [])) + ], "updated_at": controls.updated_at.isoformat(), "start_endpoint": "/dashboard/control/start", "stop_endpoint": "/dashboard/control/stop", @@ -218,7 +327,9 @@ def _form_bool(value: str | None) -> bool: return value.lower() in {"1", "true", "yes", "on"} -async def _dashboard_response(request: Request, template_name: str = "dashboard.html") -> HTMLResponse: +async def _dashboard_response( + request: Request, template_name: str = "dashboard.html" +) -> HTMLResponse: return templates.TemplateResponse( request=request, name=template_name, @@ -229,6 +340,7 @@ async def _dashboard_response(request: Request, template_name: str = "dashboard. 
"overview_endpoint": "/dashboard/fragment/overview", "controls_endpoint": "/dashboard/fragment/controls", "charts_endpoint": "/dashboard/fragment/charts", + "audit_endpoint": "/dashboard/fragment/audit", "stream_endpoint": "/dashboard/stream/metrics", "overview_stream_endpoint": "/dashboard/stream/overview", }, @@ -281,11 +393,45 @@ async def dashboard_charts(request: Request) -> HTMLResponse: ) +@router.get("/dashboard/fragment/audit", response_class=HTMLResponse) +async def dashboard_audit(request: Request) -> HTMLResponse: + return templates.TemplateResponse( + request=request, + name="partials/audit.html", + context={"request": request, **_dashboard_audit(request)}, + ) + + +@router.get("/dashboard/api/alerts/status", response_class=JSONResponse) +async def dashboard_alert_status(request: Request) -> JSONResponse: + return JSONResponse(_alert_status_snapshot(request)) + + +@router.get("/dashboard/api/audit/recent", response_class=JSONResponse) +async def dashboard_audit_recent(request: Request) -> JSONResponse: + return JSONResponse(_dashboard_audit(request, limit=25)) + + @router.post("/dashboard/control/start", response_class=HTMLResponse) async def dashboard_control_start(request: Request) -> HTMLResponse: controls = _dashboard_controls_state(request) controls.is_running = True controls.mark_updated() + notifier = _alert_notifier(request) + if notifier is not None: + await notifier.notify( + category="system", + severity="info", + title="Execution started", + message="Dashboard control started execution.", + ) + _record_audit( + request, + actor="dashboard_user", + event_type="dashboard.control.start", + decision="approved", + payload={"execution_status": "running"}, + ) return templates.TemplateResponse( request=request, name="partials/controls.html", @@ -298,6 +444,21 @@ async def dashboard_control_stop(request: Request) -> HTMLResponse: controls = _dashboard_controls_state(request) controls.is_running = False controls.mark_updated() + notifier = 
_alert_notifier(request) + if notifier is not None: + await notifier.notify( + category="system", + severity="warning", + title="Execution stopped", + message="Dashboard control stopped execution.", + ) + _record_audit( + request, + actor="dashboard_user", + event_type="dashboard.control.stop", + decision="approved", + payload={"execution_status": "stopped"}, + ) return templates.TemplateResponse( request=request, name="partials/controls.html", @@ -313,6 +474,22 @@ async def dashboard_control_kill_switch(request: Request) -> HTMLResponse: controls.kill_switch.activate(reason=reason) controls.is_running = False controls.mark_updated() + notifier = _alert_notifier(request) + if notifier is not None: + await notifier.notify( + category="threshold", + severity="critical", + title="Kill switch activated", + message="Kill switch triggered from dashboard control.", + details={"reason": reason}, + ) + _record_audit( + request, + actor="dashboard_user", + event_type="dashboard.control.kill_switch", + decision="approved", + payload={"reason": reason}, + ) return templates.TemplateResponse( request=request, name="partials/controls.html", @@ -335,12 +512,46 @@ async def dashboard_control_config(request: Request) -> HTMLResponse: ) if "max_concurrent_trades" in form: max_concurrent_value = form["max_concurrent_trades"].strip() - settings.max_concurrent_trades = int( - max_concurrent_value) if max_concurrent_value else None + settings.max_concurrent_trades = int(max_concurrent_value) if max_concurrent_value else None settings.paper_trading_mode = _form_bool(form.get("paper_trading_mode")) controls.mark_updated() + notifier = _alert_notifier(request) + if notifier is not None: + await notifier.notify( + category="system", + severity="info", + title="Runtime config updated", + message="Dashboard control updated runtime risk and execution settings.", + details={ + "trade_capital_usd": f"{settings.trade_capital_usd}", + "max_trade_capital_usd": ( + "none" + if 
settings.max_trade_capital_usd is None + else f"{settings.max_trade_capital_usd}" + ), + "max_concurrent_trades": ( + "none" + if settings.max_concurrent_trades is None + else f"{settings.max_concurrent_trades}" + ), + "paper_trading_mode": "true" if settings.paper_trading_mode else "false", + }, + ) + _record_audit( + request, + actor="dashboard_user", + event_type="dashboard.control.config", + decision="approved", + payload={ + "trade_capital_usd": settings.trade_capital_usd, + "max_trade_capital_usd": settings.max_trade_capital_usd, + "max_concurrent_trades": settings.max_concurrent_trades, + "paper_trading_mode": settings.paper_trading_mode, + }, + ) + return templates.TemplateResponse( request=request, name="partials/controls.html", diff --git a/src/arbitrade/config/settings.py b/src/arbitrade/config/settings.py index 0bd450d..b503e42 100644 --- a/src/arbitrade/config/settings.py +++ b/src/arbitrade/config/settings.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import lru_cache from pathlib import Path -from pydantic import Field +from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -31,58 +31,116 @@ class Settings(BaseSettings): alias="DASHBOARD_AUTH_PASSWORD", ) - duckdb_path: Path = Field(default=Path( - "./data/arbitrade.duckdb"), alias="DUCKDB_PATH") + alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED") + alert_min_severity: str = Field(default="warning", alias="ALERT_MIN_SEVERITY") + alert_dedup_seconds: float = Field(default=30.0, alias="ALERT_DEDUP_SECONDS") + alert_on_trade_events: bool = Field(default=True, alias="ALERT_ON_TRADE_EVENTS") + alert_on_error_events: bool = Field(default=True, alias="ALERT_ON_ERROR_EVENTS") + alert_on_threshold_events: bool = Field(default=True, alias="ALERT_ON_THRESHOLD_EVENTS") + alert_on_system_events: bool = Field(default=True, alias="ALERT_ON_SYSTEM_EVENTS") - kraken_rest_url: str = Field( - 
default="https://api.kraken.com", alias="KRAKEN_REST_URL") - kraken_ws_url: str = Field( - default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") + telegram_alerts_enabled: bool = Field(default=False, alias="TELEGRAM_ALERTS_ENABLED") + telegram_bot_token: str | None = Field(default=None, alias="TELEGRAM_BOT_TOKEN") + telegram_chat_id: str | None = Field(default=None, alias="TELEGRAM_CHAT_ID") + + discord_alerts_enabled: bool = Field(default=False, alias="DISCORD_ALERTS_ENABLED") + discord_webhook_url: str | None = Field(default=None, alias="DISCORD_WEBHOOK_URL") + + email_alerts_enabled: bool = Field(default=False, alias="EMAIL_ALERTS_ENABLED") + email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST") + email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT") + email_smtp_username: str | None = Field(default=None, alias="EMAIL_SMTP_USERNAME") + email_smtp_password: str | None = Field(default=None, alias="EMAIL_SMTP_PASSWORD") + email_alert_from: str | None = Field(default=None, alias="EMAIL_ALERT_FROM") + email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO") + email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS") + + duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH") + + kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL") + kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") kraken_private_rate_limit_seconds: float = Field( default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" ) - kraken_http_timeout_seconds: float = Field( - default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") - kraken_retry_attempts: int = Field( - default=3, alias="KRAKEN_RETRY_ATTEMPTS") + kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") + kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS") kraken_retry_base_delay_seconds: float = Field( default=0.25, 
alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" ) kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") - kraken_api_secret: str | None = Field( - default=None, alias="KRAKEN_API_SECRET") - ws_heartbeat_timeout_seconds: float = Field( - default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") - ws_max_staleness_seconds: float = Field( - default=5.0, alias="WS_MAX_STALENESS_SECONDS") + kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET") + kraken_api_key_permissions: str = Field( + default="query,trade", + alias="KRAKEN_API_KEY_PERMISSIONS", + ) + ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") + ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS") paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD") - max_trade_capital_usd: float = Field( - default=100.0, alias="MAX_TRADE_CAPITAL_USD") - max_concurrent_trades: int | None = Field( - default=None, alias="MAX_CONCURRENT_TRADES") + max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD") + max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES") max_exposure_per_asset_usd: float | None = Field( default=None, alias="MAX_EXPOSURE_PER_ASSET_USD", ) - quote_balance_asset: str = Field( - default="USD", alias="QUOTE_BALANCE_ASSET") - min_order_size_usd: float | None = Field( - default=None, alias="MIN_ORDER_SIZE_USD") + quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET") + min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD") kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE") - daily_loss_limit_usd: float | None = Field( - default=None, alias="DAILY_LOSS_LIMIT_USD") - cumulative_loss_limit_usd: float | None = Field( - default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") - 
max_source_latency_ms: float | None = Field( - default=None, alias="MAX_SOURCE_LATENCY_MS") - max_apply_latency_ms: float | None = Field( - default=None, alias="MAX_APPLY_LATENCY_MS") - max_consecutive_failures: int | None = Field( - default=None, alias="MAX_CONSECUTIVE_FAILURES") + daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD") + cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") + max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS") + max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS") + max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES") fernet_key: str | None = Field(default=None, alias="FERNET_KEY") + @field_validator("app_env") + @classmethod + def _validate_app_env(cls, value: str) -> str: + normalized = value.strip().lower() + if normalized not in {"dev", "test", "prod"}: + raise ValueError("APP_ENV must be one of: dev, test, prod") + return normalized + + @field_validator("log_level") + @classmethod + def _validate_log_level(cls, value: str) -> str: + normalized = value.strip().upper() + if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}: + raise ValueError("LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL") + return normalized + + @field_validator("alert_min_severity") + @classmethod + def _validate_alert_severity(cls, value: str) -> str: + normalized = value.strip().lower() + if normalized not in {"info", "warning", "error", "critical"}: + raise ValueError("ALERT_MIN_SEVERITY must be one of: info, warning, error, critical") + return normalized + + @model_validator(mode="after") + def _validate_security_constraints(self) -> Settings: + if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password): + raise ValueError("dashboard auth requires both username and password") + + if bool(self.kraken_api_key) ^ 
bool(self.kraken_api_secret): + raise ValueError("Kraken API auth requires both API key and secret") + + permissions = { + token.strip().lower() + for token in self.kraken_api_key_permissions.split(",") + if token.strip() + } + if permissions and ("query" not in permissions or "trade" not in permissions): + raise ValueError("KRAKEN_API_KEY_PERMISSIONS must include query and trade") + if "withdraw" in permissions or "withdrawals" in permissions: + raise ValueError("KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope") + + if self.alert_dedup_seconds < 0.0: + raise ValueError("ALERT_DEDUP_SECONDS must be >= 0") + + return self + @lru_cache(maxsize=1) def get_settings() -> Settings: diff --git a/src/arbitrade/exchange/kraken_ws.py b/src/arbitrade/exchange/kraken_ws.py index 2c473de..e962228 100644 --- a/src/arbitrade/exchange/kraken_ws.py +++ b/src/arbitrade/exchange/kraken_ws.py @@ -11,6 +11,7 @@ import orjson import structlog import websockets +from arbitrade.alerting.notifier import AlertSeverity, SupportsAlerts from arbitrade.config.settings import Settings from arbitrade.exchange.models import BookDelta, BookLevel @@ -24,10 +25,13 @@ class WsMessage: class KrakenWsClient: - def __init__(self, settings: Settings) -> None: + def __init__(self, settings: Settings, *, alert_notifier: SupportsAlerts | None = None) -> None: self._settings = settings self._last_message_at: datetime | None = None self._stop = asyncio.Event() + self._alert_notifier = alert_notifier + self._has_connected_once = False + self._was_disconnected = False @property def is_stale(self) -> bool: @@ -48,20 +52,55 @@ class KrakenWsClient: self._settings.kraken_ws_url, max_size=2_000_000 ) as ws: _LOG.info("kraken_ws_connected", url=self._settings.kraken_ws_url) + if self._has_connected_once and self._was_disconnected: + await self._notify( + category="system", + severity="info", + title="WebSocket reconnected", + message="Kraken WebSocket connection restored.", + details={"url": 
self._settings.kraken_ws_url}, + ) + self._has_connected_once = True + self._was_disconnected = False delay = 1.0 async for raw in self._recv_loop(ws): yield raw except Exception as exc: _LOG.warning("kraken_ws_disconnected", error=str(exc), reconnect_in=delay) + self._was_disconnected = True + await self._notify( + category="system", + severity="warning", + title="WebSocket disconnected", + message="Kraken WebSocket disconnected, reconnect scheduled.", + details={ + "error": str(exc), + "reconnect_in_seconds": f"{delay}", + }, + ) await asyncio.sleep(delay) delay = min(delay * 2, 30.0) async def _recv_loop(self, ws: Any) -> AsyncIterator[WsMessage]: while not self._stop.is_set(): t0 = time.perf_counter() - raw = await asyncio.wait_for( - ws.recv(), timeout=self._settings.ws_heartbeat_timeout_seconds - ) + try: + raw = await asyncio.wait_for( + ws.recv(), timeout=self._settings.ws_heartbeat_timeout_seconds + ) + except TimeoutError: + await self._notify( + category="system", + severity="critical", + title="WebSocket staleness abort", + message="No WebSocket heartbeat within configured timeout; reconnecting.", + details={ + "heartbeat_timeout_seconds": ( + f"{self._settings.ws_heartbeat_timeout_seconds}" + ), + }, + ) + raise parse_start = time.perf_counter() payload = orjson.loads(raw) self._last_message_at = datetime.now(UTC) @@ -74,6 +113,25 @@ class KrakenWsClient: if isinstance(payload, dict): yield WsMessage(received_at=self._last_message_at, payload=payload) + async def _notify( + self, + *, + category: str, + severity: AlertSeverity, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> None: + if self._alert_notifier is None: + return + await self._alert_notifier.notify( + category=category, + severity=severity, + title=title, + message=message, + details=details, + ) + @staticmethod def parse_book_delta(message: dict[str, Any]) -> BookDelta | None: # Kraken v2 book update shape can vary by channel; keep parser defensive. 
diff --git a/src/arbitrade/execution/sequencer.py b/src/arbitrade/execution/sequencer.py index 45e4eb7..35f7236 100644 --- a/src/arbitrade/execution/sequencer.py +++ b/src/arbitrade/execution/sequencer.py @@ -5,9 +5,16 @@ from dataclasses import dataclass from datetime import UTC, datetime from typing import Any, Protocol +from arbitrade.alerting.notifier import SupportsAlerts from arbitrade.detection.engine import OpportunityEvent from arbitrade.storage.executions import AsyncExecutionWriter -from arbitrade.storage.repositories import OrderRecord, PnLRecord, TradeRecord +from arbitrade.storage.repositories import ( + AuditRecord, + AuditRepository, + OrderRecord, + PnLRecord, + TradeRecord, +) class SupportsOrderPlacement(Protocol): @@ -42,11 +49,15 @@ class TriangularExecutionSequencer: available_pairs: Sequence[str], volume_for_leg: Callable[[OpportunityEvent, ExecutionLeg, int], float] | None = None, execution_writer: AsyncExecutionWriter | None = None, + alert_notifier: SupportsAlerts | None = None, + audit_repository: AuditRepository | None = None, ) -> None: self._rest_client = rest_client self._available_pairs = {self._normalize_pair(pair) for pair in available_pairs} self._volume_for_leg = volume_for_leg or self._default_volume_for_leg self._execution_writer = execution_writer + self._alert_notifier = alert_notifier + self._audit_repository = audit_repository @staticmethod def _normalize_pair(pair: str) -> str: @@ -146,6 +157,33 @@ class TriangularExecutionSequencer: volume=leg.volume, ) except Exception as exc: + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="execution_engine", + event_type="execution.trade.failed", + decision="rejected", + payload={ + "cycle": event.cycle, + "failed_leg_index": idx, + "error": str(exc), + }, + correlation_id=trade_ref, + ) + ) + if self._alert_notifier is not None: + await self._alert_notifier.notify( + category="error", + severity="error", 
+ title="Trade execution failed", + message="Triangular execution failed before completing all legs.", + details={ + "cycle": event.cycle, + "failed_leg_index": str(idx), + "error": str(exc), + }, + ) if self._execution_writer is not None: await self._execution_writer.enqueue( TradeRecord( @@ -213,6 +251,35 @@ class TriangularExecutionSequencer: ) ) + if self._alert_notifier is not None: + await self._alert_notifier.notify( + category="trade", + severity="warning" if event.est_profit < 0.0 else "info", + title="Trade execution completed", + message="Triangular execution completed all requested legs.", + details={ + "cycle": event.cycle, + "completed_legs": str(len(legs)), + "estimated_pnl_usd": f"{event.est_profit}", + }, + ) + + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="execution_engine", + event_type="execution.trade.completed", + decision="approved", + payload={ + "cycle": event.cycle, + "completed_legs": len(legs), + "estimated_pnl_usd": event.est_profit, + }, + correlation_id=trade_ref, + ) + ) + return TriangularExecutionResult( success=True, requested_legs=legs, diff --git a/src/arbitrade/market_data/feed.py b/src/arbitrade/market_data/feed.py index 601bd94..a633789 100644 --- a/src/arbitrade/market_data/feed.py +++ b/src/arbitrade/market_data/feed.py @@ -7,6 +7,7 @@ from datetime import UTC, datetime import structlog +from arbitrade.alerting.notifier import SupportsAlerts, dispatch_alert_nowait from arbitrade.detection.engine import IncrementalCycleDetector, OpportunityEvent from arbitrade.exchange.kraken_ws import KrakenWsClient from arbitrade.market_data.order_book import OrderBook @@ -17,6 +18,7 @@ from arbitrade.risk.stop_conditions import StopConditionsGuard from arbitrade.risk.trade_limits import TradeLimitsGuard from arbitrade.storage.market_snapshots import AsyncMarketSnapshotWriter, MarketSnapshot from arbitrade.storage.opportunities import AsyncOpportunityWriter 
+from arbitrade.storage.repositories import AuditRecord, AuditRepository _LOG = structlog.get_logger(__name__) @@ -47,6 +49,8 @@ class MarketDataFeed: quote_balance_asset: str = "USD", kill_switch: KillSwitch | None = None, stop_conditions_guard: StopConditionsGuard | None = None, + alert_notifier: SupportsAlerts | None = None, + audit_repository: AuditRepository | None = None, ) -> None: self._ws_client = ws_client self._snapshot_writer = snapshot_writer @@ -64,6 +68,8 @@ class MarketDataFeed: self._quote_balance_asset = quote_balance_asset.upper() self._kill_switch = kill_switch self._stop_conditions_guard = stop_conditions_guard + self._alert_notifier = alert_notifier + self._audit_repository = audit_repository if self._trade_capital <= 0.0: raise ValueError("trade_capital must be > 0.0") @@ -137,6 +143,20 @@ class MarketDataFeed: reason=self._stop_conditions_guard.halted_reason, symbol=delta.symbol, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="risk_manager", + event_type="risk.stop_condition_halt", + decision="rejected", + payload={ + "reason": self._stop_conditions_guard.halted_reason + or "unknown", + "symbol": delta.symbol, + }, + ) + ) if self._detector is not None: opportunities = self._detector.opportunities_for_updated_pair( @@ -151,6 +171,21 @@ class MarketDataFeed: ) for event in opportunities: + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="detector", + event_type="detector.opportunity", + decision="scored", + payload={ + "cycle": event.cycle, + "updated_pair": event.updated_pair, + "net_pct": event.net_pct, + "est_profit": event.est_profit, + }, + ) + ) _LOG.info( "opportunity_detected", cycle=event.cycle, @@ -171,6 +206,19 @@ class MarketDataFeed: updated_pair=event.updated_pair, net_pct=event.net_pct, ) + if self._audit_repository is not None: + self._audit_repository.insert( + 
AuditRecord( + occurred_at=datetime.now(UTC), + actor="execution_engine", + event_type="execution.paper_trade", + decision="skipped", + payload={ + "cycle": event.cycle, + "updated_pair": event.updated_pair, + }, + ) + ) continue if self._opportunity_executor is None: @@ -179,6 +227,19 @@ class MarketDataFeed: cycle=event.cycle, updated_pair=event.updated_pair, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="execution_engine", + event_type="execution.live_trade", + decision="rejected", + payload={ + "reason": "missing_executor", + "cycle": event.cycle, + }, + ) + ) continue if self._kill_switch is not None and self._kill_switch.is_active: @@ -188,6 +249,19 @@ class MarketDataFeed: updated_pair=event.updated_pair, reason=self._kill_switch.reason, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="risk_manager", + event_type="risk.kill_switch", + decision="rejected", + payload={ + "reason": self._kill_switch.reason or "manual", + "cycle": event.cycle, + }, + ) + ) continue if ( @@ -200,6 +274,20 @@ class MarketDataFeed: updated_pair=event.updated_pair, reason=self._stop_conditions_guard.halted_reason, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="risk_manager", + event_type="risk.stop_condition", + decision="rejected", + payload={ + "reason": self._stop_conditions_guard.halted_reason + or "halted", + "cycle": event.cycle, + }, + ) + ) continue if self._loss_limit_guard is not None and self._loss_limit_guard.is_halted: @@ -209,6 +297,19 @@ class MarketDataFeed: updated_pair=event.updated_pair, reason=self._loss_limit_guard.halted_reason, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="risk_manager", + event_type="risk.loss_limit", + 
decision="rejected", + payload={ + "reason": self._loss_limit_guard.halted_reason or "halted", + "cycle": event.cycle, + }, + ) + ) continue if self._pre_trade_validator is not None and self._balance_provider is not None: @@ -227,6 +328,22 @@ class MarketDataFeed: updated_pair=event.updated_pair, required_by_asset=required_balances, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="risk_manager", + event_type="risk.pre_trade_validation", + decision="rejected", + payload={ + "cycle": event.cycle, + "required_by_asset": { + key: required_balances[key] + for key in required_balances + }, + }, + ) + ) continue exposure_by_asset = self._exposure_for_event(event) @@ -240,6 +357,21 @@ class MarketDataFeed: updated_pair=event.updated_pair, exposure_by_asset=exposure_by_asset, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="risk_manager", + event_type="risk.trade_limits", + decision="rejected", + payload={ + "cycle": event.cycle, + "exposure_by_asset": { + key: exposure_by_asset[key] for key in exposure_by_asset + }, + }, + ) + ) continue if self._trade_limits_guard is not None: @@ -247,10 +379,23 @@ class MarketDataFeed: try: outcome = await self._opportunity_executor(event) - except Exception: + except Exception as exc: if self._trade_limits_guard is not None: self._trade_limits_guard.close_trade(exposure_by_asset) + dispatch_alert_nowait( + self._alert_notifier, + category="system", + severity="critical", + title="Critical execution exception", + message="Unhandled exception raised by opportunity executor.", + details={ + "cycle": event.cycle, + "updated_pair": event.updated_pair, + "error": str(exc), + }, + ) + if self._stop_conditions_guard is not None: self._stop_conditions_guard.register_failure() if self._stop_conditions_guard.is_halted: @@ -274,6 +419,20 @@ class MarketDataFeed: cycle=event.cycle, 
updated_pair=event.updated_pair, ) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="execution_engine", + event_type="execution.live_trade", + decision="error", + payload={ + "cycle": event.cycle, + "updated_pair": event.updated_pair, + "error": str(exc), + }, + ) + ) continue if self._stop_conditions_guard is not None: @@ -299,6 +458,22 @@ class MarketDataFeed: if self._trade_limits_guard is not None and close_trade: self._trade_limits_guard.close_trade(exposure_by_asset) + if self._audit_repository is not None: + self._audit_repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="execution_engine", + event_type="execution.live_trade", + decision="approved", + payload={ + "cycle": event.cycle, + "updated_pair": event.updated_pair, + "realized_pnl": realized_pnl, + "close_trade": close_trade, + }, + ) + ) + await self._snapshot_writer.enqueue( MarketSnapshot( snapshot_at=datetime.now(UTC), diff --git a/src/arbitrade/risk/loss_limits.py b/src/arbitrade/risk/loss_limits.py index 05f859d..ab33199 100644 --- a/src/arbitrade/risk/loss_limits.py +++ b/src/arbitrade/risk/loss_limits.py @@ -2,6 +2,8 @@ from __future__ import annotations from datetime import UTC, date, datetime +from arbitrade.alerting.notifier import SupportsAlerts, dispatch_alert_nowait + class LossLimitGuard: def __init__( @@ -9,6 +11,7 @@ class LossLimitGuard: *, daily_loss_limit: float | None = None, cumulative_loss_limit: float | None = None, + alert_notifier: SupportsAlerts | None = None, ) -> None: self._daily_loss_limit = daily_loss_limit self._cumulative_loss_limit = cumulative_loss_limit @@ -21,6 +24,7 @@ class LossLimitGuard: self._cumulative_pnl = 0.0 self._daily_pnl: dict[date, float] = {} self._halted_reason: str | None = None + self._alert_notifier = alert_notifier @property def cumulative_pnl(self) -> float: @@ -52,6 +56,17 @@ class LossLimitGuard: and self._daily_pnl[day_key] <= 
-self._daily_loss_limit ): self._halted_reason = "daily_loss_limit_breached" + dispatch_alert_nowait( + self._alert_notifier, + category="threshold", + severity="critical", + title="Daily loss limit breached", + message="Trading halted because daily realized PnL crossed configured loss limit.", + details={ + "daily_pnl": f"{self._daily_pnl[day_key]}", + "daily_loss_limit": f"{self._daily_loss_limit}", + }, + ) return if ( @@ -59,3 +74,17 @@ class LossLimitGuard: and self._cumulative_pnl <= -self._cumulative_loss_limit ): self._halted_reason = "cumulative_loss_limit_breached" + dispatch_alert_nowait( + self._alert_notifier, + category="threshold", + severity="critical", + title="Cumulative loss limit breached", + message=( + "Trading halted because cumulative realized PnL crossed " + "configured loss limit." + ), + details={ + "cumulative_pnl": f"{self._cumulative_pnl}", + "cumulative_loss_limit": f"{self._cumulative_loss_limit}", + }, + ) diff --git a/src/arbitrade/risk/stop_conditions.py b/src/arbitrade/risk/stop_conditions.py index 703be77..1691787 100644 --- a/src/arbitrade/risk/stop_conditions.py +++ b/src/arbitrade/risk/stop_conditions.py @@ -1,5 +1,7 @@ from __future__ import annotations +from arbitrade.alerting.notifier import SupportsAlerts, dispatch_alert_nowait + class StopConditionsGuard: def __init__( @@ -8,6 +10,7 @@ class StopConditionsGuard: max_source_latency_ms: float | None = None, max_apply_latency_ms: float | None = None, max_consecutive_failures: int | None = None, + alert_notifier: SupportsAlerts | None = None, ) -> None: if max_source_latency_ms is not None and max_source_latency_ms <= 0.0: raise ValueError("max_source_latency_ms must be > 0.0") @@ -22,6 +25,7 @@ class StopConditionsGuard: self._consecutive_failures = 0 self._halted_reason: str | None = None + self._alert_notifier = alert_notifier @property def halted_reason(self) -> str | None: @@ -50,10 +54,32 @@ class StopConditionsGuard: and source_latency_ms > self._max_source_latency_ms 
): self._halted_reason = "source_latency_limit_breached" + dispatch_alert_nowait( + self._alert_notifier, + category="threshold", + severity="critical", + title="Source latency limit breached", + message="Trading halted because source latency exceeded configured limit.", + details={ + "source_latency_ms": f"{source_latency_ms}", + "max_source_latency_ms": f"{self._max_source_latency_ms}", + }, + ) return if self._max_apply_latency_ms is not None and apply_latency_ms > self._max_apply_latency_ms: self._halted_reason = "apply_latency_limit_breached" + dispatch_alert_nowait( + self._alert_notifier, + category="threshold", + severity="critical", + title="Apply latency limit breached", + message="Trading halted because apply latency exceeded configured limit.", + details={ + "apply_latency_ms": f"{apply_latency_ms}", + "max_apply_latency_ms": f"{self._max_apply_latency_ms}", + }, + ) def register_failure(self) -> None: if self.is_halted: @@ -65,6 +91,17 @@ class StopConditionsGuard: and self._consecutive_failures >= self._max_consecutive_failures ): self._halted_reason = "consecutive_failures_limit_breached" + dispatch_alert_nowait( + self._alert_notifier, + category="threshold", + severity="critical", + title="Consecutive failures limit breached", + message="Trading halted because consecutive failures exceeded configured limit.", + details={ + "consecutive_failures": f"{self._consecutive_failures}", + "max_consecutive_failures": f"{self._max_consecutive_failures}", + }, + ) def register_success(self) -> None: if self.is_halted: diff --git a/src/arbitrade/risk/trade_limits.py b/src/arbitrade/risk/trade_limits.py index b5bc186..978142b 100644 --- a/src/arbitrade/risk/trade_limits.py +++ b/src/arbitrade/risk/trade_limits.py @@ -2,6 +2,8 @@ from __future__ import annotations from collections.abc import Mapping +from arbitrade.alerting.notifier import SupportsAlerts, dispatch_alert_nowait + class TradeLimitsGuard: def __init__( @@ -9,6 +11,7 @@ class TradeLimitsGuard: *, 
max_concurrent_trades: int | None = None, max_exposure_per_asset: float | None = None, + alert_notifier: SupportsAlerts | None = None, ) -> None: if max_concurrent_trades is not None and max_concurrent_trades <= 0: raise ValueError("max_concurrent_trades must be > 0") @@ -19,6 +22,7 @@ class TradeLimitsGuard: self._max_exposure_per_asset = max_exposure_per_asset self._active_trades = 0 self._asset_exposure: dict[str, float] = {} + self._alert_notifier = alert_notifier @property def active_trades(self) -> int: @@ -32,6 +36,17 @@ class TradeLimitsGuard: self._max_concurrent_trades is not None and self._active_trades >= self._max_concurrent_trades ): + dispatch_alert_nowait( + self._alert_notifier, + category="threshold", + severity="warning", + title="Concurrent trade limit reached", + message="Trade rejected by concurrent trade cap.", + details={ + "active_trades": f"{self._active_trades}", + "max_concurrent_trades": f"{self._max_concurrent_trades}", + }, + ) return False if self._max_exposure_per_asset is None: @@ -43,6 +58,18 @@ class TradeLimitsGuard: key = asset.upper() next_exposure = self._asset_exposure.get(key, 0.0) + exposure if next_exposure > self._max_exposure_per_asset: + dispatch_alert_nowait( + self._alert_notifier, + category="threshold", + severity="warning", + title="Asset exposure limit reached", + message="Trade rejected by per-asset exposure cap.", + details={ + "asset": key, + "next_exposure": f"{next_exposure}", + "max_exposure_per_asset": f"{self._max_exposure_per_asset}", + }, + ) return False return True diff --git a/src/arbitrade/runtime/__init__.py b/src/arbitrade/runtime/__init__.py new file mode 100644 index 0000000..210b16c --- /dev/null +++ b/src/arbitrade/runtime/__init__.py @@ -0,0 +1,15 @@ +"""Runtime lifecycle and recovery helpers.""" + +from arbitrade.runtime.lifecycle import ( + RuntimeRecoveryReport, + graceful_shutdown, + persist_runtime_snapshot, + restore_runtime_state, +) + +__all__ = [ + "RuntimeRecoveryReport", + 
"graceful_shutdown", + "persist_runtime_snapshot", + "restore_runtime_state", +] diff --git a/src/arbitrade/runtime/lifecycle.py b/src/arbitrade/runtime/lifecycle.py new file mode 100644 index 0000000..1da5f00 --- /dev/null +++ b/src/arbitrade/runtime/lifecycle.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any, cast + +from fastapi import FastAPI + +from arbitrade.api.control_state import DashboardControlState +from arbitrade.storage.db import DuckDBStore +from arbitrade.storage.repositories import ( + AuditRecord, + AuditRepository, + RuntimeStateRecord, + RuntimeStateRepository, +) + + +@dataclass(slots=True) +class RuntimeRecoveryReport: + restored_from_snapshot: bool + snapshot_at: str | None + open_trades_detected: int + restart_guard_active: bool + + +def _controls(app: FastAPI) -> DashboardControlState: + return cast(DashboardControlState, app.state.dashboard_controls) + + +def _store(app: FastAPI) -> DuckDBStore: + return cast(DuckDBStore, app.state.store) + + +def _audit_repository(app: FastAPI) -> AuditRepository | None: + repository = getattr(app.state, "audit_repository", None) + return repository if isinstance(repository, AuditRepository) else None + + +def _runtime_repository(app: FastAPI) -> RuntimeStateRepository | None: + repository = getattr(app.state, "runtime_state_repository", None) + return repository if isinstance(repository, RuntimeStateRepository) else None + + +def _open_trade_count(store: DuckDBStore) -> int: + with store.connect() as conn: + row = conn.execute(""" + SELECT COUNT(*) + FROM trades + WHERE finished_at IS NULL + """).fetchone() + return int(row[0]) if row is not None else 0 + + +def _latest_balances(store: DuckDBStore) -> dict[str, Any] | None: + with store.connect() as conn: + row = conn.execute(""" + SELECT balances + FROM portfolio_snapshots + ORDER BY snapshot_at DESC + LIMIT 1 + """).fetchone() + + if 
row is None or row[0] is None: + return None + raw_balances = row[0] + if isinstance(raw_balances, str): + return {"raw": raw_balances} + return {"raw": str(raw_balances)} + + +def _record_audit( + app: FastAPI, + *, + event_type: str, + decision: str, + payload: dict[str, Any] | None = None, +) -> None: + repository = _audit_repository(app) + if repository is None: + return + repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="runtime", + event_type=event_type, + decision=decision, + payload=payload, + correlation_id=None, + ) + ) + + +async def _run_startup_reconciler(app: FastAPI) -> None: + reconciler = getattr(app.state, "startup_reconciler", None) + if reconciler is None: + return + + reconcile_member = getattr(reconciler, "reconcile_open_trades", None) + if reconcile_member is None or not callable(reconcile_member): + return + + result = reconcile_member() + if inspect.isawaitable(result): + await result + + +def persist_runtime_snapshot(app: FastAPI, *, note: str | None = None) -> RuntimeStateRecord | None: + repository = _runtime_repository(app) + if repository is None: + return None + + controls = _controls(app) + store = _store(app) + snapshot = RuntimeStateRecord( + snapshot_at=datetime.now(UTC), + is_running=controls.is_running, + kill_switch_active=controls.kill_switch.is_active, + kill_switch_reason=controls.kill_switch.reason, + open_trade_count=_open_trade_count(store), + last_known_balances=_latest_balances(store), + note=note, + ) + repository.insert(snapshot) + return snapshot + + +async def restore_runtime_state(app: FastAPI) -> RuntimeRecoveryReport: + controls = _controls(app) + store = _store(app) + runtime_repository = _runtime_repository(app) + + restored_from_snapshot = False + snapshot_at: str | None = None + + latest = runtime_repository.latest() if runtime_repository is not None else None + if latest is not None: + restored_from_snapshot = True + snapshot_at = latest.snapshot_at.isoformat() + controls.is_running 
= latest.is_running + if latest.kill_switch_active: + controls.kill_switch.activate( + reason=latest.kill_switch_reason or "recovered") + else: + controls.kill_switch.deactivate() + controls.mark_updated() + + open_trades = _open_trade_count(store) + restart_guard_active = False + if open_trades > 0: + controls.is_running = False + if not controls.kill_switch.is_active: + controls.kill_switch.activate( + reason="recovery_open_trades_detected") + controls.mark_updated() + restart_guard_active = True + + report = RuntimeRecoveryReport( + restored_from_snapshot=restored_from_snapshot, + snapshot_at=snapshot_at, + open_trades_detected=open_trades, + restart_guard_active=restart_guard_active, + ) + app.state.recovery_report = report + + _record_audit( + app, + event_type="runtime.startup_recovery", + decision="applied", + payload={ + "restored_from_snapshot": restored_from_snapshot, + "open_trades_detected": open_trades, + "restart_guard_active": restart_guard_active, + }, + ) + + await _run_startup_reconciler(app) + + return report + + +async def drain_background_workers(app: FastAPI) -> None: + workers: list[object] = [] + + declared = getattr(app.state, "background_workers", None) + if isinstance(declared, list): + workers.extend(declared) + + for attr_name in ("execution_writer", "opportunity_writer", "snapshot_writer"): + worker = getattr(app.state, attr_name, None) + if worker is not None: + workers.append(worker) + + seen: set[int] = set() + for worker in workers: + worker_id = id(worker) + if worker_id in seen: + continue + seen.add(worker_id) + + stop_member = getattr(worker, "stop", None) + if stop_member is None or not callable(stop_member): + continue + + result = stop_member() + if inspect.isawaitable(result): + await result + + +async def graceful_shutdown(app: FastAPI) -> None: + controls = _controls(app) + controls.is_running = False + controls.mark_updated() + + _record_audit( + app, + event_type="runtime.shutdown", + decision="initiated", + 
payload={"execution_status": "stopped"}, + ) + + await drain_background_workers(app) + persist_runtime_snapshot(app, note="graceful_shutdown") diff --git a/src/arbitrade/storage/db.py b/src/arbitrade/storage/db.py index 841575b..63ce1bf 100644 --- a/src/arbitrade/storage/db.py +++ b/src/arbitrade/storage/db.py @@ -78,6 +78,26 @@ CREATE TABLE IF NOT EXISTS market_snapshots ( payload JSON NOT NULL, latency_ms DOUBLE ); + +CREATE TABLE IF NOT EXISTS audit_events ( + id UUID DEFAULT uuid(), + occurred_at TIMESTAMP NOT NULL, + actor VARCHAR NOT NULL, + event_type VARCHAR NOT NULL, + decision VARCHAR NOT NULL, + payload JSON, + correlation_id VARCHAR +); + +CREATE TABLE IF NOT EXISTS runtime_state_snapshots ( + snapshot_at TIMESTAMP NOT NULL, + is_running BOOLEAN NOT NULL, + kill_switch_active BOOLEAN NOT NULL, + kill_switch_reason VARCHAR, + open_trade_count INTEGER NOT NULL, + last_known_balances JSON, + note VARCHAR +); """ diff --git a/src/arbitrade/storage/repositories.py b/src/arbitrade/storage/repositories.py index f4de618..5977f61 100644 --- a/src/arbitrade/storage/repositories.py +++ b/src/arbitrade/storage/repositories.py @@ -66,6 +66,27 @@ class PnLRecord: source: str +@dataclass(slots=True) +class AuditRecord: + occurred_at: datetime + actor: str + event_type: str + decision: str + payload: dict[str, Any] | None = None + correlation_id: str | None = None + + +@dataclass(slots=True) +class RuntimeStateRecord: + snapshot_at: datetime + is_running: bool + kill_switch_active: bool + kill_switch_reason: str | None + open_trade_count: int + last_known_balances: dict[str, Any] | None = None + note: str | None = None + + class MarketSnapshotRepository: def __init__(self, store: DuckDBStore) -> None: self._store = store @@ -217,3 +238,141 @@ class PnLRepository: record.source, ], ) + + +class AuditRepository: + def __init__(self, store: DuckDBStore) -> None: + self._store = store + + def insert(self, record: AuditRecord) -> None: + with self._store.connect() as conn: 
+ conn.execute( + """ + INSERT INTO audit_events ( + occurred_at, + actor, + event_type, + decision, + payload, + correlation_id + ) + VALUES (?, ?, ?, ?, ?, ?) + """, + [ + record.occurred_at, + record.actor, + record.event_type, + record.decision, + ( + None + if record.payload is None + else orjson.dumps(record.payload).decode("utf-8") + ), + record.correlation_id, + ], + ) + + def list_recent(self, *, limit: int = 25) -> list[AuditRecord]: + with self._store.connect() as conn: + rows = conn.execute( + """ + SELECT occurred_at, actor, event_type, decision, payload, correlation_id + FROM audit_events + ORDER BY occurred_at DESC + LIMIT ? + """, + [limit], + ).fetchall() + + records: list[AuditRecord] = [] + for row in rows: + payload: dict[str, Any] | None = None + raw_payload = row[4] + if isinstance(raw_payload, str) and raw_payload: + decoded = orjson.loads(raw_payload) + if isinstance(decoded, dict): + payload = {str(k): decoded[k] for k in decoded} + + records.append( + AuditRecord( + occurred_at=row[0], + actor=str(row[1]), + event_type=str(row[2]), + decision=str(row[3]), + payload=payload, + correlation_id=str(row[5]) if row[5] is not None else None, + ) + ) + + return records + + +class RuntimeStateRepository: + def __init__(self, store: DuckDBStore) -> None: + self._store = store + + def insert(self, record: RuntimeStateRecord) -> None: + with self._store.connect() as conn: + conn.execute( + """ + INSERT INTO runtime_state_snapshots ( + snapshot_at, + is_running, + kill_switch_active, + kill_switch_reason, + open_trade_count, + last_known_balances, + note + ) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + record.snapshot_at, + record.is_running, + record.kill_switch_active, + record.kill_switch_reason, + record.open_trade_count, + ( + None + if record.last_known_balances is None + else orjson.dumps(record.last_known_balances).decode("utf-8") + ), + record.note, + ], + ) + + def latest(self) -> RuntimeStateRecord | None: + with self._store.connect() as conn: + row = conn.execute(""" + SELECT + snapshot_at, + is_running, + kill_switch_active, + kill_switch_reason, + open_trade_count, + last_known_balances, + note + FROM runtime_state_snapshots + ORDER BY snapshot_at DESC + LIMIT 1 + """).fetchone() + + if row is None: + return None + + balances: dict[str, Any] | None = None + raw_balances = row[5] + if isinstance(raw_balances, str) and raw_balances: + decoded = orjson.loads(raw_balances) + if isinstance(decoded, dict): + balances = {str(key): decoded[key] for key in decoded} + + return RuntimeStateRecord( + snapshot_at=row[0], + is_running=bool(row[1]), + kill_switch_active=bool(row[2]), + kill_switch_reason=str(row[3]) if row[3] is not None else None, + open_trade_count=int(row[4]), + last_known_balances=balances, + note=str(row[6]) if row[6] is not None else None, + ) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 337e7ac..c2f5b18 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta +from typing import Any import httpx @@ -8,6 +9,31 @@ from arbitrade.api.app import create_app from arbitrade.config.settings import Settings +class _FakeAlertNotifier: + def __init__(self) -> None: + self.events: list[dict[str, Any]] = [] + + async def notify( + self, + *, + category: str, + severity: str, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: + self.events.append( + { + "category": category, + "severity": severity, + "title": title, + "message": message, + "details": details or {}, + } + ) + 
return True + + def _seed_metrics_data(app) -> None: store = app.state.store started = datetime.now(UTC) @@ -135,6 +161,7 @@ async def test_dashboard_page_and_fragment_and_sse(tmp_path) -> None: overview_stream = await client.get("/dashboard/stream/overview") controls = await client.get("/dashboard/fragment/controls") charts = await client.get("/dashboard/fragment/charts") + audit = await client.get("/dashboard/fragment/audit") assert page.status_code == 200 assert "EventSource" in page.text @@ -143,6 +170,7 @@ async def test_dashboard_page_and_fragment_and_sse(tmp_path) -> None: assert 'hx-get="/dashboard/fragment/metrics"' in page.text assert 'hx-get="/dashboard/fragment/controls"' in page.text assert 'hx-get="/dashboard/fragment/charts"' in page.text + assert 'hx-get="/dashboard/fragment/audit"' in page.text assert fragment.status_code == 200 assert "Realized P&L" in fragment.text @@ -163,14 +191,15 @@ async def test_dashboard_page_and_fragment_and_sse(tmp_path) -> None: assert "trade-open" in overview.text assert overview_stream.status_code == 200 - assert overview_stream.headers["content-type"].startswith( - "text/event-stream") + assert overview_stream.headers["content-type"].startswith("text/event-stream") assert "event: overview" in overview_stream.text assert "trade-open" in overview_stream.text assert controls.status_code == 200 assert "Runtime Status" in controls.text assert ">running<" in controls.text + assert "Alerting" in controls.text + assert "Last result" in controls.text assert "Paper trading mode" in controls.text assert "Trade capital USD" in controls.text @@ -179,6 +208,9 @@ async def test_dashboard_page_and_fragment_and_sse(tmp_path) -> None: assert "opportunity-chart" in charts.text assert "Hide chart" in charts.text or "Show chart" in charts.text + assert audit.status_code == 200 + assert "Audit Trail" in audit.text + async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> None: app = 
create_app(Settings(DUCKDB_PATH=tmp_path / "controls.duckdb")) @@ -220,6 +252,36 @@ async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> N assert app.state.settings.max_concurrent_trades == 4 assert app.state.settings.paper_trading_mode is True + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + audit_recent = await client.get("/dashboard/api/audit/recent") + + assert audit_recent.status_code == 200 + entries = audit_recent.json()["entries"] + assert len(entries) >= 4 + assert any(entry["event_type"] == "dashboard.control.stop" for entry in entries) + assert any(entry["event_type"] == "dashboard.control.start" for entry in entries) + assert any(entry["event_type"] == "dashboard.control.kill_switch" for entry in entries) + assert any(entry["event_type"] == "dashboard.control.config" for entry in entries) + + +async def test_dashboard_controls_emit_alerts(tmp_path) -> None: + app = create_app(Settings(DUCKDB_PATH=tmp_path / "alerts.duckdb")) + fake_notifier = _FakeAlertNotifier() + app.state.alert_notifier = fake_notifier + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + await client.post("/dashboard/control/start") + await client.post("/dashboard/control/stop") + await client.post("/dashboard/control/kill-switch", data={"reason": "manual-test"}) + + assert len(fake_notifier.events) == 3 + assert fake_notifier.events[0]["title"] == "Execution started" + assert fake_notifier.events[1]["title"] == "Execution stopped" + assert fake_notifier.events[2]["title"] == "Kill switch activated" + assert fake_notifier.events[2]["details"]["reason"] == "manual-test" + async def test_dashboard_requires_basic_auth_when_configured(tmp_path) -> None: app = create_app( @@ -243,3 +305,17 @@ async def test_dashboard_requires_basic_auth_when_configured(tmp_path) -> None: assert 
unauthenticated.headers["www-authenticate"] == 'Basic realm="Arbitrade Dashboard"' assert authenticated.status_code == 200 assert health.status_code == 200 + + +async def test_dashboard_alert_status_api_exposes_notifier_snapshot(tmp_path) -> None: + app = create_app(Settings(DUCKDB_PATH=tmp_path / "alerts-status.duckdb")) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/dashboard/api/alerts/status") + + assert response.status_code == 200 + payload = response.json() + assert payload["enabled"] is True + assert "configured_channels" in payload + assert "last_result" in payload diff --git a/tests/unit/test_alert_notifier.py b/tests/unit/test_alert_notifier.py new file mode 100644 index 0000000..c3aa743 --- /dev/null +++ b/tests/unit/test_alert_notifier.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from arbitrade.alerting.notifier import AlertEvent, AlertNotifier + + +@dataclass(slots=True) +class _FakeChannel: + events: list[AlertEvent] = field(default_factory=list) + fail: bool = False + + async def send(self, event: AlertEvent) -> None: + if self.fail: + raise RuntimeError("channel send failed") + self.events.append(event) + + +@pytest.mark.asyncio +async def test_alert_notifier_sends_event_when_enabled() -> None: + channel = _FakeChannel() + notifier = AlertNotifier([channel], enabled=True, min_severity="info") + + sent = await notifier.notify( + category="trade", + severity="info", + title="Trade complete", + message="Completed all legs.", + ) + + assert sent is True + assert len(channel.events) == 1 + assert channel.events[0].category == "trade" + + +@pytest.mark.asyncio +async def test_alert_notifier_respects_severity_and_category_filters() -> None: + channel = _FakeChannel() + notifier = AlertNotifier( + [channel], + enabled=True, + min_severity="error", + 
category_flags={"trade": False, "error": True}, + ) + + low = await notifier.notify( + category="error", + severity="warning", + title="Low", + message="Ignored by severity.", + ) + filtered = await notifier.notify( + category="trade", + severity="critical", + title="Trade", + message="Ignored by category.", + ) + high = await notifier.notify( + category="error", + severity="critical", + title="High", + message="Delivered.", + ) + + assert low is False + assert filtered is False + assert high is True + assert len(channel.events) == 1 + assert channel.events[0].title == "High" + + +@pytest.mark.asyncio +async def test_alert_notifier_applies_dedup_window() -> None: + channel = _FakeChannel() + notifier = AlertNotifier([channel], dedup_seconds=60.0) + + first = await notifier.notify( + category="error", + severity="error", + title="Burst", + message="Same message", + ) + second = await notifier.notify( + category="error", + severity="error", + title="Burst", + message="Same message", + ) + + assert first is True + assert second is False + assert len(channel.events) == 1 + + +@pytest.mark.asyncio +async def test_alert_notifier_returns_false_when_all_channels_fail() -> None: + notifier = AlertNotifier([_FakeChannel(fail=True), _FakeChannel(fail=True)]) + + sent = await notifier.notify( + category="error", + severity="critical", + title="Failure", + message="Both channels fail.", + ) + + assert sent is False + + +@pytest.mark.asyncio +async def test_alert_notifier_exposes_status_snapshot_for_dashboard() -> None: + channel = _FakeChannel() + notifier = AlertNotifier([channel], enabled=True, min_severity="info", dedup_seconds=30.0) + + await notifier.notify( + category="system", + severity="warning", + title="Reconnect", + message="Socket restored.", + ) + + status = notifier.status_snapshot() + + assert status["enabled"] is True + assert status["has_channels"] is True + assert status["configured_channels"] == ["_FakeChannel"] + assert status["last_result"] == "success" + 
assert status["last_attempted_at"] is not None + assert status["last_success_at"] is not None + assert status["last_event"] is not None diff --git a/tests/unit/test_audit_repository.py b/tests/unit/test_audit_repository.py new file mode 100644 index 0000000..f8317e8 --- /dev/null +++ b/tests/unit/test_audit_repository.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from arbitrade.config.settings import Settings +from arbitrade.storage.db import DuckDBStore +from arbitrade.storage.repositories import AuditRecord, AuditRepository + + +def test_audit_repository_inserts_and_lists_recent(tmp_path) -> None: + settings = Settings(_env_file=None, DUCKDB_PATH=tmp_path / "audit.duckdb") + store = DuckDBStore(settings) + store.migrate() + repository = AuditRepository(store) + + repository.insert( + AuditRecord( + occurred_at=datetime.now(UTC), + actor="dashboard_user", + event_type="dashboard.control.start", + decision="approved", + payload={"execution_status": "running"}, + correlation_id="req-1", + ) + ) + + recent = repository.list_recent(limit=5) + + assert len(recent) == 1 + assert recent[0].actor == "dashboard_user" + assert recent[0].event_type == "dashboard.control.start" + assert recent[0].decision == "approved" + assert recent[0].payload == {"execution_status": "running"} + assert recent[0].correlation_id == "req-1" diff --git a/tests/unit/test_execution_sequencer.py b/tests/unit/test_execution_sequencer.py index dba2f1e..d53adee 100644 --- a/tests/unit/test_execution_sequencer.py +++ b/tests/unit/test_execution_sequencer.py @@ -10,6 +10,31 @@ from arbitrade.detection.engine import OpportunityEvent from arbitrade.execution.sequencer import TriangularExecutionSequencer +@dataclass(slots=True) +class _FakeAlertNotifier: + events: list[dict[str, str]] = field(default_factory=list) + + async def notify( + self, + *, + category: str, + severity: str, + title: str, + message: str, + details: dict[str, str] | None = None, + 
) -> bool: + self.events.append( + { + "category": category, + "severity": severity, + "title": title, + "message": message, + **(details or {}), + } + ) + return True + + @dataclass(slots=True) class _FakeRestClient: fail_at_call: int | None = None @@ -42,9 +67,11 @@ def _sample_event(cycle: str = "USD->BTC->ETH->USD") -> OpportunityEvent: @pytest.mark.asyncio async def test_triangular_sequencer_executes_legs_in_order() -> None: client = _FakeRestClient() + notifier = _FakeAlertNotifier() sequencer = TriangularExecutionSequencer( client, available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"], + alert_notifier=notifier, ) result = await sequencer.execute(_sample_event()) @@ -53,14 +80,19 @@ async def test_triangular_sequencer_executes_legs_in_order() -> None: assert result.completed_legs == 3 assert [call["pair"] for call in client.calls] == ["BTC/USD", "ETH/BTC", "ETH/USD"] assert [call["side"] for call in client.calls] == ["buy", "buy", "sell"] + assert len(notifier.events) == 1 + assert notifier.events[0]["category"] == "trade" + assert notifier.events[0]["title"] == "Trade execution completed" @pytest.mark.asyncio async def test_triangular_sequencer_stops_on_failed_leg() -> None: client = _FakeRestClient(fail_at_call=2) + notifier = _FakeAlertNotifier() sequencer = TriangularExecutionSequencer( client, available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"], + alert_notifier=notifier, ) result = await sequencer.execute(_sample_event()) @@ -69,6 +101,9 @@ async def test_triangular_sequencer_stops_on_failed_leg() -> None: assert result.completed_legs == 1 assert result.failure_reason is not None assert len(client.calls) == 1 + assert len(notifier.events) == 1 + assert notifier.events[0]["category"] == "error" + assert notifier.events[0]["title"] == "Trade execution failed" def test_triangular_sequencer_rejects_non_closed_cycle() -> None: diff --git a/tests/unit/test_kraken_ws.py b/tests/unit/test_kraken_ws.py index a6c348a..cc521dc 100644 --- a/tests/unit/test_kraken_ws.py 
+++ b/tests/unit/test_kraken_ws.py @@ -1,7 +1,66 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import Any + +import orjson +import pytest + from arbitrade.config.settings import Settings from arbitrade.exchange.kraken_ws import KrakenWsClient +@dataclass(slots=True) +class _FakeAlertNotifier: + events: list[dict[str, str]] = field(default_factory=list) + + async def notify( + self, + *, + category: str, + severity: str, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: + self.events.append( + { + "category": category, + "severity": severity, + "title": title, + "message": message, + **(details or {}), + } + ) + return True + + +class _FakeWebSocket: + def __init__(self, messages: list[Any]) -> None: + self._messages = messages + + async def recv(self) -> str: + if not self._messages: + await asyncio.sleep(0) + return orjson.dumps({"channel": "heartbeat"}).decode("utf-8") + next_item = self._messages.pop(0) + if isinstance(next_item, Exception): + raise next_item + return next_item + + +class _FakeConnectContext: + def __init__(self, ws: _FakeWebSocket) -> None: + self._ws = ws + + async def __aenter__(self) -> _FakeWebSocket: + return self._ws + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> bool: + return False + + def test_parse_book_delta() -> None: client = KrakenWsClient(Settings()) message = { @@ -24,3 +83,59 @@ def test_parse_book_delta() -> None: assert len(delta.bids) == 1 assert len(delta.asks) == 1 assert delta.checksum == 123 + + +@pytest.mark.asyncio +async def test_connect_stream_emits_disconnect_and_reconnect_alerts( + monkeypatch: pytest.MonkeyPatch, +) -> None: + notifier = _FakeAlertNotifier() + settings = Settings(_env_file=None, WS_HEARTBEAT_TIMEOUT_SECONDS=1.0) + client = KrakenWsClient(settings, alert_notifier=notifier) + + first_payload = orjson.dumps( + {"channel": "book", "symbol": "BTC/USD", "data": [{"bids": [], 
"asks": []}]} + ).decode("utf-8") + second_payload = orjson.dumps( + {"channel": "book", "symbol": "ETH/USD", "data": [{"bids": [], "asks": []}]} + ).decode("utf-8") + + sessions = [ + _FakeWebSocket([first_payload, RuntimeError("socket dropped")]), + _FakeWebSocket([second_payload]), + ] + + def _fake_connect(*_args: object, **_kwargs: object) -> _FakeConnectContext: + return _FakeConnectContext(sessions.pop(0)) + + monkeypatch.setattr("arbitrade.exchange.kraken_ws.websockets.connect", _fake_connect) + + stream = client.connect_stream() + first = await anext(stream) + second = await anext(stream) + await client.stop() + await stream.aclose() + + assert first.payload["symbol"] == "BTC/USD" + assert second.payload["symbol"] == "ETH/USD" + titles = [event["title"] for event in notifier.events] + assert "WebSocket disconnected" in titles + assert "WebSocket reconnected" in titles + + +@pytest.mark.asyncio +async def test_recv_loop_emits_staleness_alert_on_timeout() -> None: + notifier = _FakeAlertNotifier() + settings = Settings(_env_file=None, WS_HEARTBEAT_TIMEOUT_SECONDS=0.001) + client = KrakenWsClient(settings, alert_notifier=notifier) + + class _NeverReturnsWebSocket: + async def recv(self) -> str: + await asyncio.sleep(1) + return "{}" + + with pytest.raises(TimeoutError): + await anext(client._recv_loop(_NeverReturnsWebSocket())) + + assert len(notifier.events) == 1 + assert notifier.events[0]["title"] == "WebSocket staleness abort" diff --git a/tests/unit/test_loss_limits.py b/tests/unit/test_loss_limits.py index e7b5dc3..083a326 100644 --- a/tests/unit/test_loss_limits.py +++ b/tests/unit/test_loss_limits.py @@ -1,12 +1,39 @@ from __future__ import annotations +import asyncio from datetime import UTC, datetime, timedelta +from typing import Any import pytest from arbitrade.risk.loss_limits import LossLimitGuard +class _FakeAlertNotifier: + def __init__(self) -> None: + self.events: list[dict[str, Any]] = [] + + async def notify( + self, + *, + category: str, 
+ severity: str, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: + self.events.append( + { + "category": category, + "severity": severity, + "title": title, + "message": message, + "details": details or {}, + } + ) + return True + + def test_loss_limit_guard_tracks_daily_and_cumulative_pnl() -> None: guard = LossLimitGuard(daily_loss_limit=100.0, cumulative_loss_limit=200.0) t0 = datetime.now(UTC) @@ -47,3 +74,17 @@ def test_loss_limit_guard_rejects_invalid_limits() -> None: with pytest.raises(ValueError, match="cumulative_loss_limit"): LossLimitGuard(cumulative_loss_limit=-1.0) + + +@pytest.mark.asyncio +async def test_loss_limit_guard_emits_alert_on_breach() -> None: + notifier = _FakeAlertNotifier() + guard = LossLimitGuard(daily_loss_limit=50.0, alert_notifier=notifier) + + guard.register_realized_pnl(-60.0, at=datetime.now(UTC)) + await asyncio.sleep(0) + + assert guard.is_halted + assert len(notifier.events) == 1 + assert notifier.events[0]["category"] == "threshold" + assert notifier.events[0]["title"] == "Daily loss limit breached" diff --git a/tests/unit/test_market_data_feed.py b/tests/unit/test_market_data_feed.py index 4b42ecd..4f67cbb 100644 --- a/tests/unit/test_market_data_feed.py +++ b/tests/unit/test_market_data_feed.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from datetime import UTC, datetime from types import SimpleNamespace @@ -83,6 +84,31 @@ class _FakeFailingExecutor: raise RuntimeError("executor failure") +class _FakeAlertNotifier: + def __init__(self) -> None: + self.events: list[dict[str, str]] = [] + + async def notify( + self, + *, + category: str, + severity: str, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: + self.events.append( + { + "category": category, + "severity": severity, + "title": title, + "message": message, + **(details or {}), + } + ) + return True + + @dataclass(slots=True) class 
_FakeWsClientTwoMessages: delta: BookDelta @@ -430,3 +456,29 @@ async def test_market_data_feed_halts_on_repeated_execution_failures() -> None: assert stop_guard.halted_reason == "consecutive_failures_limit_breached" assert kill_switch.is_active assert kill_switch.reason == "consecutive_failures_limit_breached" + + +@pytest.mark.asyncio +async def test_market_data_feed_emits_critical_alert_on_executor_exception() -> None: + event = _sample_event(allocated_capital=75.0) + detector = _FakeDetector(event) + executor = _FakeFailingExecutor() + notifier = _FakeAlertNotifier() + feed = MarketDataFeed( + ws_client=_FakeWsClient(_sample_delta()), + snapshot_writer=_FakeSnapshotWriter(), + detector=detector, + opportunity_writer=_FakeOpportunityWriter(), + paper_trading_mode=False, + opportunity_executor=executor.execute, + alert_notifier=notifier, + ) + + await feed.run() + await asyncio.sleep(0) + + assert executor.calls == 1 + assert len(notifier.events) == 1 + assert notifier.events[0]["category"] == "system" + assert notifier.events[0]["severity"] == "critical" + assert notifier.events[0]["title"] == "Critical execution exception" diff --git a/tests/unit/test_runtime_lifecycle.py b/tests/unit/test_runtime_lifecycle.py new file mode 100644 index 0000000..94cc753 --- /dev/null +++ b/tests/unit/test_runtime_lifecycle.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime + +import pytest + +from arbitrade.api.app import create_app +from arbitrade.config.settings import Settings +from arbitrade.runtime.lifecycle import ( + graceful_shutdown, + persist_runtime_snapshot, + restore_runtime_state, +) +from arbitrade.storage.repositories import RuntimeStateRecord + + +@dataclass(slots=True) +class _FakeWorker: + stopped: bool = False + + async def stop(self) -> None: + self.stopped = True + + +@dataclass(slots=True) +class _FakeStartupReconciler: + called: bool = False + + async def 
reconcile_open_trades(self) -> None: + self.called = True + + +@pytest.mark.asyncio +async def test_persist_runtime_snapshot_writes_record(tmp_path) -> None: + app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "runtime.duckdb")) + + app.state.dashboard_controls.is_running = True + app.state.dashboard_controls.kill_switch.deactivate() + + snapshot = persist_runtime_snapshot(app, note="unit-test") + + assert snapshot is not None + assert snapshot.note == "unit-test" + + latest = app.state.runtime_state_repository.latest() + assert latest is not None + assert latest.note == "unit-test" + assert latest.is_running is True + + +@pytest.mark.asyncio +async def test_restore_runtime_state_applies_snapshot(tmp_path) -> None: + app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "restore.duckdb")) + app.state.runtime_state_repository.insert( + RuntimeStateRecord( + snapshot_at=datetime.now(UTC), + is_running=False, + kill_switch_active=True, + kill_switch_reason="manual-stop", + open_trade_count=0, + last_known_balances={"USD": 100.0}, + note="seed", + ) + ) + + report = await restore_runtime_state(app) + + assert report.restored_from_snapshot is True + assert app.state.dashboard_controls.is_running is False + assert app.state.dashboard_controls.kill_switch.is_active is True + assert app.state.dashboard_controls.kill_switch.reason == "manual-stop" + + +@pytest.mark.asyncio +async def test_restore_runtime_state_enables_restart_guard_for_open_trades(tmp_path) -> None: + app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "open-trades.duckdb")) + + with app.state.store.connect() as conn: + conn.execute( + """ + INSERT INTO trades ( + trade_ref, + started_at, + finished_at, + status, + realized_pnl, + estimated_pnl, + capital_used, + cycle, + leg_count + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + [ + "open-trade-1", + datetime.now(UTC), + None, + "open", + None, + 1.0, + 100.0, + "USD->BTC->ETH->USD", + 3, + ], + ) + + report = await restore_runtime_state(app) + + assert report.open_trades_detected == 1 + assert report.restart_guard_active is True + assert app.state.dashboard_controls.is_running is False + assert app.state.dashboard_controls.kill_switch.is_active is True + assert app.state.dashboard_controls.kill_switch.reason == "recovery_open_trades_detected" + + +@pytest.mark.asyncio +async def test_graceful_shutdown_drains_workers_and_persists_snapshot(tmp_path) -> None: + app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "shutdown.duckdb")) + worker = _FakeWorker() + app.state.background_workers = [worker] + app.state.dashboard_controls.is_running = True + + await graceful_shutdown(app) + + assert worker.stopped is True + assert app.state.dashboard_controls.is_running is False + latest = app.state.runtime_state_repository.latest() + assert latest is not None + assert latest.note == "graceful_shutdown" + + +@pytest.mark.asyncio +async def test_restore_runtime_state_calls_startup_reconciler(tmp_path) -> None: + app = create_app(Settings(_env_file=None, DUCKDB_PATH=tmp_path / "reconciler.duckdb")) + reconciler = _FakeStartupReconciler() + app.state.startup_reconciler = reconciler + + await restore_runtime_state(app) + + assert reconciler.called is True diff --git a/tests/unit/test_settings_validation.py b/tests/unit/test_settings_validation.py new file mode 100644 index 0000000..9dd77cd --- /dev/null +++ b/tests/unit/test_settings_validation.py @@ -0,0 +1,51 @@ +import pytest +from pydantic import ValidationError + +from arbitrade.config.settings import Settings + + +def test_dashboard_auth_requires_both_fields() -> None: + with pytest.raises(ValidationError): + Settings(_env_file=None, DASHBOARD_AUTH_USERNAME="admin") + + +def test_kraken_api_auth_requires_key_and_secret() -> None: + with pytest.raises(ValidationError): + 
Settings(_env_file=None, KRAKEN_API_KEY="key-only") + + +def test_kraken_permissions_require_query_and_trade() -> None: + with pytest.raises(ValidationError): + Settings( + _env_file=None, + KRAKEN_API_KEY="k", + KRAKEN_API_SECRET="s", + KRAKEN_API_KEY_PERMISSIONS="query", + ) + + +def test_kraken_permissions_forbid_withdrawal_scope() -> None: + with pytest.raises(ValidationError): + Settings( + _env_file=None, + KRAKEN_API_KEY="k", + KRAKEN_API_SECRET="s", + KRAKEN_API_KEY_PERMISSIONS="query,trade,withdraw", + ) + + +def test_alert_min_severity_is_validated() -> None: + with pytest.raises(ValidationError): + Settings(_env_file=None, ALERT_MIN_SEVERITY="nope") + + +def test_valid_security_configuration_passes() -> None: + settings = Settings( + _env_file=None, + KRAKEN_API_KEY="k", + KRAKEN_API_SECRET="s", + KRAKEN_API_KEY_PERMISSIONS="query,trade", + ALERT_MIN_SEVERITY="warning", + ) + + assert settings.kraken_api_key_permissions == "query,trade" diff --git a/tests/unit/test_stop_conditions.py b/tests/unit/test_stop_conditions.py index 2dd6b35..5ff40ff 100644 --- a/tests/unit/test_stop_conditions.py +++ b/tests/unit/test_stop_conditions.py @@ -1,10 +1,38 @@ from __future__ import annotations +import asyncio +from typing import Any + import pytest from arbitrade.risk.stop_conditions import StopConditionsGuard +class _FakeAlertNotifier: + def __init__(self) -> None: + self.events: list[dict[str, Any]] = [] + + async def notify( + self, + *, + category: str, + severity: str, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: + self.events.append( + { + "category": category, + "severity": severity, + "title": title, + "message": message, + "details": details or {}, + } + ) + return True + + def test_stop_conditions_guard_halts_on_source_latency_breach() -> None: guard = StopConditionsGuard(max_source_latency_ms=50.0) @@ -55,3 +83,17 @@ def test_stop_conditions_guard_rejects_invalid_configuration() -> None: with 
pytest.raises(ValueError, match="max_consecutive_failures"): StopConditionsGuard(max_consecutive_failures=0) + + +@pytest.mark.asyncio +async def test_stop_conditions_guard_emits_alert_on_failure_threshold() -> None: + notifier = _FakeAlertNotifier() + guard = StopConditionsGuard(max_consecutive_failures=1, alert_notifier=notifier) + + guard.register_failure() + await asyncio.sleep(0) + + assert guard.is_halted + assert len(notifier.events) == 1 + assert notifier.events[0]["category"] == "threshold" + assert notifier.events[0]["title"] == "Consecutive failures limit breached" diff --git a/tests/unit/test_trade_limits.py b/tests/unit/test_trade_limits.py index eab8366..9533c87 100644 --- a/tests/unit/test_trade_limits.py +++ b/tests/unit/test_trade_limits.py @@ -1,10 +1,38 @@ from __future__ import annotations +import asyncio +from typing import Any + import pytest from arbitrade.risk.trade_limits import TradeLimitsGuard +class _FakeAlertNotifier: + def __init__(self) -> None: + self.events: list[dict[str, Any]] = [] + + async def notify( + self, + *, + category: str, + severity: str, + title: str, + message: str, + details: dict[str, str] | None = None, + ) -> bool: + self.events.append( + { + "category": category, + "severity": severity, + "title": title, + "message": message, + "details": details or {}, + } + ) + return True + + def test_trade_limits_guard_blocks_when_max_concurrent_reached() -> None: guard = TradeLimitsGuard(max_concurrent_trades=1) @@ -39,3 +67,18 @@ def test_trade_limits_guard_rejects_invalid_configuration() -> None: with pytest.raises(ValueError, match="max_exposure_per_asset"): TradeLimitsGuard(max_exposure_per_asset=0.0) + + +@pytest.mark.asyncio +async def test_trade_limits_guard_emits_alert_when_rejecting_trade() -> None: + notifier = _FakeAlertNotifier() + guard = TradeLimitsGuard(max_concurrent_trades=1, alert_notifier=notifier) + + guard.open_trade({"BTC": 10.0}) + allowed = guard.is_trade_allowed({"BTC": 1.0}) + await asyncio.sleep(0) 
+ + assert not allowed + assert len(notifier.events) == 1 + assert notifier.events[0]["category"] == "threshold" + assert notifier.events[0]["title"] == "Concurrent trade limit reached" diff --git a/web/templates/dashboard.html b/web/templates/dashboard.html index e7c2ca1..4f00d09 100644 --- a/web/templates/dashboard.html +++ b/web/templates/dashboard.html @@ -60,6 +60,16 @@ head_scripts %} > {% include "partials/charts.html" %} + +
+ {% include "partials/audit.html" %} +
{% endblock %} {% block scripts %}