From 7728f9a8cdb5016ceb6810a37fcec0f859afc103 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Thu, 4 Jun 2026 18:39:17 +0200 Subject: [PATCH] feat: enhance backtesting functionality with database integration and UI updates --- src/arbitrade/api/routes.py | 60 +++++++--- src/arbitrade/backtesting/replay.py | 104 +++++++++++++++++- src/arbitrade/backtesting/runner.py | 62 ++++++++--- .../templates/partials/backtesting_panel.html | 27 ++++- .../web/templates/partials/overview.html | 4 +- 5 files changed, 214 insertions(+), 43 deletions(-) diff --git a/src/arbitrade/api/routes.py b/src/arbitrade/api/routes.py index 6a84cd7..0d79ece 100644 --- a/src/arbitrade/api/routes.py +++ b/src/arbitrade/api/routes.py @@ -10,7 +10,7 @@ from typing import cast from urllib.parse import parse_qs import duckdb -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, Request, Response from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from fastapi.templating import Jinja2Templates @@ -138,7 +138,7 @@ def _dashboard_overview(request: Request) -> dict[str, object]: non_zero = {k: float(v) for k, v in parsed.items() if float(v) > 0.0} if non_zero: - balances_value = ", ".join( + balances_value = "
".join( f"{v:.6g} {k}" for k, v in sorted(non_zero.items()) ) else: @@ -720,7 +720,9 @@ def _backtesting_panel_context( defaults: dict[str, str] | None = None, ) -> dict[str, object]: default_values = { - "events_path": "", + "symbols": "", + "start_time": "", + "end_time": "", "starting_balances": "USD=1000.0", "trade_capital": "100.0", "min_profit_threshold": "0.0005", @@ -926,7 +928,6 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse: """Submit a backtest job to the async queue. Returns panel with job list.""" form = _parse_form_body(await request.body()) defaults = { - "events_path": form.get("events_path", ""), "starting_balances": form.get("starting_balances", "USD=1000.0"), "trade_capital": form.get("trade_capital", "100.0"), "min_profit_threshold": form.get("min_profit_threshold", "0.0005"), @@ -934,14 +935,13 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse: "custom_fee_rate": form.get("custom_fee_rate", ""), "slippage_bps": form.get("slippage_bps", "4.0"), "execution_latency_ms": form.get("execution_latency_ms", "20.0"), + "start_time": form.get("start_time", ""), + "end_time": form.get("end_time", ""), + "symbols": form.get("symbols", ""), + "source": form.get("source", "db"), } try: - events_path = _resolve_workspace_path(defaults["events_path"]) - if not events_path.exists() or not events_path.is_file(): - raise ValueError( - "events_path must reference an existing JSONL file") - custom_fee_rate = ( float(defaults["custom_fee_rate"]) if defaults["custom_fee_rate"].strip() else None @@ -949,8 +949,8 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse: fee_rate = _fee_rate_for_profile( defaults["fee_profile"], custom_fee_rate, request=request) - config_dict = { - "events_path": _display_path(events_path), + config_dict: dict[str, object] = { + "source": defaults["source"], "starting_balances": defaults["starting_balances"], "trade_capital": float(defaults["trade_capital"]), "min_profit_threshold": float(defaults["min_profit_threshold"]), @@ -958,23 +958,26 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse: "fee_profile": defaults["fee_profile"], "slippage_bps": float(defaults["slippage_bps"]), "execution_latency_ms": float(defaults["execution_latency_ms"]), + "start_time": defaults["start_time"], + "end_time": defaults["end_time"], + "symbols": defaults["symbols"], } store = request.app.state.store repo = BacktestJobRepository(store) - job = repo.create_job(str(events_path), config_dict) + events_label = defaults["symbols"] if defaults["symbols"] else "DB-sourced" + job = repo.create_job(events_label, config_dict) msg_job = job.id[:8] if job.id else "unknown" queue = request.app.state.backtest_queue - await queue.put((job.id or "", str(events_path), config_dict)) + await queue.put((job.id or "", config_dict)) _record_audit( request, actor="dashboard_user", event_type="dashboard.backtesting.submit", decision="queued", - payload={"job_id": job.id, - "events_path": _display_path(events_path)}, + payload={"job_id": job.id, "source": defaults["source"]}, ) context = _backtesting_panel_context( @@ -1042,6 +1045,33 @@ async def dashboard_backtesting_job_detail(request: Request, job_id: str) -> HTM return HTMLResponse(report_html) +@router.get("/dashboard/backtesting/job/{job_id}/export", response_class=Response) +async def dashboard_backtesting_export(request: Request, job_id: str) -> Response: + store = request.app.state.store + repo = BacktestJobRepository(store) + job = repo.get_job(job_id) + if job is None: + return Response("Job not found", status_code=404) + + payload: dict[str, object] = { + "job_id": job_id, + "status": job.status, + "events_path": job.events_path, + "created_at": job.created_at.isoformat() if job.created_at else None, + } + if job.report: + payload["report"] = job.report + if job.config: + payload["config"] = job.config + + return Response( + content=orjson.dumps(payload).decode("utf-8"), + media_type="application/x-jsonlines", + headers={ + "Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"}, + ) + + @router.post("/dashboard/control/start", response_class=HTMLResponse) async def dashboard_control_start(request: Request) -> HTMLResponse: controls = _dashboard_controls_state(request) diff --git a/src/arbitrade/backtesting/replay.py b/src/arbitrade/backtesting/replay.py index 4daf083..d4fdd54 100644 --- a/src/arbitrade/backtesting/replay.py +++ b/src/arbitrade/backtesting/replay.py @@ -153,7 +153,8 @@ def _parse_book_levels(raw_levels: Any) -> tuple[BookLevel, ...]: or not isinstance(raw_level[1], int | float) ): raise ValueError("Each level must be [price, volume]") - levels.append(BookLevel(price=float(raw_level[0]), volume=float(raw_level[1]))) + levels.append(BookLevel(price=float( + raw_level[0]), volume=float(raw_level[1]))) return tuple(levels) @@ -172,7 +173,8 @@ def load_replay_events(path: Path) -> list[ReplayBookEvent]: if not isinstance(timestamp_raw, str) or not isinstance(symbol_raw, str): raise ValueError("Each event must include timestamp and symbol") - occurred_at = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00")).astimezone(UTC) + occurred_at = datetime.fromisoformat( + timestamp_raw.replace("Z", "+00:00")).astimezone(UTC) events.append( ReplayBookEvent( occurred_at=occurred_at, @@ -185,6 +187,92 @@ def load_replay_events(path: Path) -> list[ReplayBookEvent]: return sorted(events, key=lambda event: event.occurred_at) +def load_replay_events_from_db( + store: object, + *, + symbols: list[str] | None = None, + start: datetime | None = None, + end: datetime | None = None, +) -> list[ReplayBookEvent]: + """Load replay events from market_snapshots table. + + Each market_snapshots row has snapshot_at, symbol, payload (raw Kraken WS). + Payload format: {channel, symbol, data: [{bids: [{price, qty}], asks: [{price, qty}]}]} + """ + with store.connect() as conn: # type: ignore[union-attr] + query = "SELECT snapshot_at, symbol, payload FROM market_snapshots WHERE 1=1" + params: list[object] = [] + + if symbols: + placeholders = ",".join("?" for _ in symbols) + query += f" AND symbol IN ({placeholders})" + params.extend(symbols) + + if start is not None: + query += " AND snapshot_at >= ?" + params.append(start) + + if end is not None: + query += " AND snapshot_at <= ?" + params.append(end) + + query += " ORDER BY snapshot_at ASC" + # type: ignore[union-attr] + rows = conn.execute(query, params).fetchall() + + events: list[ReplayBookEvent] = [] + for row in rows: + snapshot_at: datetime = row[0] + symbol: str = row[1] + payload_raw = row[2] + + if isinstance(payload_raw, str): + payload = orjson.loads(payload_raw) + elif isinstance(payload_raw, dict): + payload = payload_raw + else: + continue + + data = payload.get("data") + if not isinstance(data, list) or not data: + continue + + first = data[0] + if not isinstance(first, dict): + continue + + bids = _parse_kraken_book_levels(first.get("bids")) + asks = _parse_kraken_book_levels(first.get("asks")) + + if bids or asks: + events.append( + ReplayBookEvent( + occurred_at=snapshot_at, + symbol=symbol, + bids=bids, + asks=asks, + ) + ) + + return events + + +def _parse_kraken_book_levels( + raw_levels: object | None, +) -> tuple[BookLevel, ...]: + """Parse Kraken WS book level format: [{price, qty}, ...].""" + if not isinstance(raw_levels, list): + return () + levels: list[BookLevel] = [] + for level in raw_levels: + if isinstance(level, dict) and "price" in level and "qty" in level: + levels.append( + BookLevel(price=float(level["price"]), + volume=float(level["qty"])) + ) + return tuple(levels) + + class BacktestReplayEngine: def __init__( self, @@ -206,7 +294,8 @@ class BacktestReplayEngine: min_order_size_by_pair=config.min_order_size_by_pair, ) self._pre_trade = PreTradeValidator() - self._trade_limits = TradeLimitsGuard(max_concurrent_trades=config.max_concurrent_trades) + self._trade_limits = TradeLimitsGuard( + max_concurrent_trades=config.max_concurrent_trades) self._simulated_rest = _SimulatedRestClient( self._clock, slippage_bps=config.slippage_bps, @@ -241,7 +330,8 @@ class BacktestReplayEngine: trades_executed = 0 realized_pnl = 0.0 - equity = float(starting_balances.get(self._config.quote_asset.upper(), 0.0)) + equity = float(starting_balances.get( + self._config.quote_asset.upper(), 0.0)) peak_equity = equity max_drawdown = 0.0 @@ -284,7 +374,8 @@ class BacktestReplayEngine: result = await self._sequencer.execute(opportunity) self._trade_limits.close_trade(exposure) - execution_latencies.append(self._simulated_rest.last_trade_latency_ms) + execution_latencies.append( + self._simulated_rest.last_trade_latency_ms) fill_samples.append(self._simulated_rest.last_fill_ratio) if not result.success: @@ -307,7 +398,8 @@ class BacktestReplayEngine: wins = sum(1 for pnl in realized_samples if pnl > 0.0) win_rate = (wins / len(realized_samples)) if realized_samples else None - fill_rate = (sum(fill_samples) / len(fill_samples)) if fill_samples else None + fill_rate = (sum(fill_samples) / len(fill_samples) + ) if fill_samples else None return BacktestReport( started_at=events[0].occurred_at if events else self._clock.now, diff --git a/src/arbitrade/backtesting/runner.py b/src/arbitrade/backtesting/runner.py index 260d9d2..2dbd904 100644 --- a/src/arbitrade/backtesting/runner.py +++ b/src/arbitrade/backtesting/runner.py @@ -8,7 +8,12 @@ from pathlib import Path import structlog -from arbitrade.backtesting.replay import BacktestConfig, BacktestReplayEngine, load_replay_events +from arbitrade.backtesting.replay import ( + BacktestConfig, + BacktestReplayEngine, + load_replay_events, + load_replay_events_from_db, +) from arbitrade.detection.graph import CurrencyGraph, TriangularCycle from arbitrade.storage.db import DuckDBStore from arbitrade.storage.repositories import BacktestJobRepository @@ -44,26 +49,51 @@ def _parse_balances(raw: str) -> dict[str, float]: async def run_backtest_job( job_id: str, - events_path: str, config_dict: dict[str, object] | None, store: DuckDBStore, ) -> None: - """Execute a single backtest job: load events, run engine, store report to DB.""" + """Execute a single backtest job: load events from DB or file, run engine, store report.""" repo = BacktestJobRepository(store) repo.update_status(job_id, "running") - _LOG.info("backtest_job_started", job_id=job_id, events_path=events_path) + _LOG.info("backtest_job_started", job_id=job_id) try: - path = Path(events_path) - if not path.is_absolute(): - path = Path("data") / path - path = path.resolve() - - events = load_replay_events(path) - if not events: - raise ValueError(f"No events found in {path}") - config = config_dict or {} + events_path = str(config.get("events_path", "")) + symbols_raw = config.get("symbols") + source = str(config.get("source", "db")) + + start_dt = None + end_dt = None + if source == "db": + start_str = config.get("start_time") + end_str = config.get("end_time") + if isinstance(start_str, str) and start_str: + start_dt = datetime.fromisoformat( + start_str.replace("Z", "+00:00")) + if isinstance(end_str, str) and end_str: + end_dt = datetime.fromisoformat(end_str.replace("Z", "+00:00")) + + symbols: list[str] | None = None + if isinstance(symbols_raw, str) and symbols_raw.strip(): + symbols = [s.strip().upper() + for s in symbols_raw.split(",") if s.strip()] + elif isinstance(symbols_raw, list): + symbols = [str(s).upper() for s in symbols_raw] + + events = load_replay_events_from_db( + store, symbols=symbols, start=start_dt, end=end_dt, + ) + else: + path = Path(events_path) + if not path.is_absolute(): + path = Path("data") / path + path = path.resolve() + events = load_replay_events(path) + + if not events: + raise ValueError("No events found for backtest") + starting_balances_raw = str(config.get( "starting_balances", "USD=1000.0")) starting_balances = _parse_balances(starting_balances_raw) @@ -123,7 +153,7 @@ async def run_backtest_job( async def backtest_worker( - queue: asyncio.Queue[tuple[str, str, dict[str, object] | None] | None], + queue: asyncio.Queue[tuple[str, dict[str, object] | None] | None], store: DuckDBStore, ) -> None: """Worker coroutine: pull jobs from queue and execute them one at a time.""" @@ -133,9 +163,9 @@ async def backtest_worker( if item is None: queue.task_done() break - job_id, events_path, config = item + job_id, config = item try: - await run_backtest_job(job_id, events_path, config, store) + await run_backtest_job(job_id, config, store) except Exception: _LOG.exception("backtest_worker_unhandled_error", job_id=job_id) finally: diff --git a/src/arbitrade/web/templates/partials/backtesting_panel.html b/src/arbitrade/web/templates/partials/backtesting_panel.html index bcccd8a..18e1be4 100644 --- a/src/arbitrade/web/templates/partials/backtesting_panel.html +++ b/src/arbitrade/web/templates/partials/backtesting_panel.html @@ -42,13 +42,32 @@ hx-target="#backtesting-shell" hx-swap="outerHTML" > + + +