feat: enhance backtesting functionality with database integration and UI updates
CI / lint-test-build (push) Failing after 1m20s

This commit is contained in:
2026-06-04 18:39:17 +02:00
parent a83d231d06
commit 7728f9a8cd
5 changed files with 214 additions and 43 deletions
+45 -15
View File
@@ -10,7 +10,7 @@ from typing import cast
from urllib.parse import parse_qs from urllib.parse import parse_qs
import duckdb import duckdb
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
@@ -138,7 +138,7 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
non_zero = {k: float(v) non_zero = {k: float(v)
for k, v in parsed.items() if float(v) > 0.0} for k, v in parsed.items() if float(v) > 0.0}
if non_zero: if non_zero:
balances_value = ", ".join( balances_value = "<br>".join(
f"{v:.6g} {k}" for k, v in sorted(non_zero.items()) f"{v:.6g} {k}" for k, v in sorted(non_zero.items())
) )
else: else:
@@ -720,7 +720,9 @@ def _backtesting_panel_context(
defaults: dict[str, str] | None = None, defaults: dict[str, str] | None = None,
) -> dict[str, object]: ) -> dict[str, object]:
default_values = { default_values = {
"events_path": "", "symbols": "",
"start_time": "",
"end_time": "",
"starting_balances": "USD=1000.0", "starting_balances": "USD=1000.0",
"trade_capital": "100.0", "trade_capital": "100.0",
"min_profit_threshold": "0.0005", "min_profit_threshold": "0.0005",
@@ -926,7 +928,6 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse:
"""Submit a backtest job to the async queue. Returns panel with job list.""" """Submit a backtest job to the async queue. Returns panel with job list."""
form = _parse_form_body(await request.body()) form = _parse_form_body(await request.body())
defaults = { defaults = {
"events_path": form.get("events_path", ""),
"starting_balances": form.get("starting_balances", "USD=1000.0"), "starting_balances": form.get("starting_balances", "USD=1000.0"),
"trade_capital": form.get("trade_capital", "100.0"), "trade_capital": form.get("trade_capital", "100.0"),
"min_profit_threshold": form.get("min_profit_threshold", "0.0005"), "min_profit_threshold": form.get("min_profit_threshold", "0.0005"),
@@ -934,14 +935,13 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse:
"custom_fee_rate": form.get("custom_fee_rate", ""), "custom_fee_rate": form.get("custom_fee_rate", ""),
"slippage_bps": form.get("slippage_bps", "4.0"), "slippage_bps": form.get("slippage_bps", "4.0"),
"execution_latency_ms": form.get("execution_latency_ms", "20.0"), "execution_latency_ms": form.get("execution_latency_ms", "20.0"),
"start_time": form.get("start_time", ""),
"end_time": form.get("end_time", ""),
"symbols": form.get("symbols", ""),
"source": form.get("source", "db"),
} }
try: try:
events_path = _resolve_workspace_path(defaults["events_path"])
if not events_path.exists() or not events_path.is_file():
raise ValueError(
"events_path must reference an existing JSONL file")
custom_fee_rate = ( custom_fee_rate = (
float(defaults["custom_fee_rate"]) float(defaults["custom_fee_rate"])
if defaults["custom_fee_rate"].strip() else None if defaults["custom_fee_rate"].strip() else None
@@ -949,8 +949,8 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse:
fee_rate = _fee_rate_for_profile( fee_rate = _fee_rate_for_profile(
defaults["fee_profile"], custom_fee_rate, request=request) defaults["fee_profile"], custom_fee_rate, request=request)
config_dict = { config_dict: dict[str, object] = {
"events_path": _display_path(events_path), "source": defaults["source"],
"starting_balances": defaults["starting_balances"], "starting_balances": defaults["starting_balances"],
"trade_capital": float(defaults["trade_capital"]), "trade_capital": float(defaults["trade_capital"]),
"min_profit_threshold": float(defaults["min_profit_threshold"]), "min_profit_threshold": float(defaults["min_profit_threshold"]),
@@ -958,23 +958,26 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse:
"fee_profile": defaults["fee_profile"], "fee_profile": defaults["fee_profile"],
"slippage_bps": float(defaults["slippage_bps"]), "slippage_bps": float(defaults["slippage_bps"]),
"execution_latency_ms": float(defaults["execution_latency_ms"]), "execution_latency_ms": float(defaults["execution_latency_ms"]),
"start_time": defaults["start_time"],
"end_time": defaults["end_time"],
"symbols": defaults["symbols"],
} }
store = request.app.state.store store = request.app.state.store
repo = BacktestJobRepository(store) repo = BacktestJobRepository(store)
job = repo.create_job(str(events_path), config_dict) events_label = defaults["symbols"] if defaults["symbols"] else "DB-sourced"
job = repo.create_job(events_label, config_dict)
msg_job = job.id[:8] if job.id else "unknown" msg_job = job.id[:8] if job.id else "unknown"
queue = request.app.state.backtest_queue queue = request.app.state.backtest_queue
await queue.put((job.id or "", str(events_path), config_dict)) await queue.put((job.id or "", config_dict))
_record_audit( _record_audit(
request, request,
actor="dashboard_user", actor="dashboard_user",
event_type="dashboard.backtesting.submit", event_type="dashboard.backtesting.submit",
decision="queued", decision="queued",
payload={"job_id": job.id, payload={"job_id": job.id, "source": defaults["source"]},
"events_path": _display_path(events_path)},
) )
context = _backtesting_panel_context( context = _backtesting_panel_context(
@@ -1042,6 +1045,33 @@ async def dashboard_backtesting_job_detail(request: Request, job_id: str) -> HTM
return HTMLResponse(report_html) return HTMLResponse(report_html)
@router.get("/dashboard/backtesting/job/{job_id}/export", response_class=Response)
async def dashboard_backtesting_export(request: Request, job_id: str) -> Response:
    """Serve a stored backtest job as a downloadable JSON attachment.

    The job is looked up through ``BacktestJobRepository`` using the store on
    ``request.app.state``. When no job matches, a plain 404 response is
    returned. Otherwise the response body is a single JSON object holding the
    job id, status, events path and creation timestamp, plus ``report`` and
    ``config`` when the job has truthy values for them.
    """
    job = BacktestJobRepository(request.app.state.store).get_job(job_id)
    if job is None:
        return Response("Job not found", status_code=404)

    created_at = job.created_at.isoformat() if job.created_at else None
    export: dict[str, object] = {
        "job_id": job_id,
        "status": job.status,
        "events_path": job.events_path,
        "created_at": created_at,
    }
    # Optional sections are only included when the job actually carries them.
    for section in ("report", "config"):
        value = getattr(job, section)
        if value:
            export[section] = value

    body = orjson.dumps(export).decode("utf-8")
    # Only the first eight characters of the id go into the download name.
    disposition = f"attachment; filename=backtest_{job_id[:8]}.jsonl"
    return Response(
        content=body,
        media_type="application/x-jsonlines",
        headers={"Content-Disposition": disposition},
    )
@router.post("/dashboard/control/start", response_class=HTMLResponse) @router.post("/dashboard/control/start", response_class=HTMLResponse)
async def dashboard_control_start(request: Request) -> HTMLResponse: async def dashboard_control_start(request: Request) -> HTMLResponse:
controls = _dashboard_controls_state(request) controls = _dashboard_controls_state(request)
+98 -6
View File
@@ -153,7 +153,8 @@ def _parse_book_levels(raw_levels: Any) -> tuple[BookLevel, ...]:
or not isinstance(raw_level[1], int | float) or not isinstance(raw_level[1], int | float)
): ):
raise ValueError("Each level must be [price, volume]") raise ValueError("Each level must be [price, volume]")
levels.append(BookLevel(price=float(raw_level[0]), volume=float(raw_level[1]))) levels.append(BookLevel(price=float(
raw_level[0]), volume=float(raw_level[1])))
return tuple(levels) return tuple(levels)
@@ -172,7 +173,8 @@ def load_replay_events(path: Path) -> list[ReplayBookEvent]:
if not isinstance(timestamp_raw, str) or not isinstance(symbol_raw, str): if not isinstance(timestamp_raw, str) or not isinstance(symbol_raw, str):
raise ValueError("Each event must include timestamp and symbol") raise ValueError("Each event must include timestamp and symbol")
occurred_at = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00")).astimezone(UTC) occurred_at = datetime.fromisoformat(
timestamp_raw.replace("Z", "+00:00")).astimezone(UTC)
events.append( events.append(
ReplayBookEvent( ReplayBookEvent(
occurred_at=occurred_at, occurred_at=occurred_at,
@@ -185,6 +187,92 @@ def load_replay_events(path: Path) -> list[ReplayBookEvent]:
return sorted(events, key=lambda event: event.occurred_at) return sorted(events, key=lambda event: event.occurred_at)
def load_replay_events_from_db(
    store: object,
    *,
    symbols: list[str] | None = None,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[ReplayBookEvent]:
    """Load replay events from the ``market_snapshots`` table.

    Each row provides ``snapshot_at``, ``symbol`` and ``payload`` (raw Kraken
    WS message). Payload format:
    ``{channel, symbol, data: [{bids: [{price, qty}], asks: [{price, qty}]}]}``.
    Only the first entry of ``data`` is read.

    Args:
        store: Object exposing ``connect()`` as a context manager that yields
            a connection with ``execute(query, params).fetchall()``.
        symbols: Restrict rows to these symbols; ``None`` or an empty list
            applies no symbol filter.
        start: Inclusive lower bound on ``snapshot_at``; ``None`` for no bound.
        end: Inclusive upper bound on ``snapshot_at``; ``None`` for no bound.

    Returns:
        Events in ascending ``snapshot_at`` order. Rows whose payload cannot
        be decoded, is not a JSON object, has no usable ``data`` entry, or
        yields neither bids nor asks are skipped.

    NOTE(review): ``snapshot_at`` is passed through exactly as the database
    returns it, whereas ``load_replay_events`` normalizes timestamps to UTC.
    Confirm the stored timestamps are timezone-aware UTC.
    """
    query = "SELECT snapshot_at, symbol, payload FROM market_snapshots WHERE 1=1"
    params: list[object] = []

    # Values are always bound as parameters; only "?" placeholders are
    # interpolated into the SQL text.
    if symbols:
        placeholders = ",".join("?" for _ in symbols)
        query += f" AND symbol IN ({placeholders})"
        params.extend(symbols)
    if start is not None:
        query += " AND snapshot_at >= ?"
        params.append(start)
    if end is not None:
        query += " AND snapshot_at <= ?"
        params.append(end)
    query += " ORDER BY snapshot_at ASC"

    with store.connect() as conn:  # type: ignore[union-attr]
        rows = conn.execute(query, params).fetchall()

    events: list[ReplayBookEvent] = []
    for snapshot_at, symbol, payload_raw in rows:
        if isinstance(payload_raw, (str, bytes, bytearray)):
            try:
                payload = orjson.loads(payload_raw)
            except orjson.JSONDecodeError:
                # One corrupt snapshot must not abort the whole replay load.
                continue
        else:
            payload = payload_raw

        # Decoded JSON may be a list or scalar; only objects carry "data".
        if not isinstance(payload, dict):
            continue

        data = payload.get("data")
        if not isinstance(data, list) or not data:
            continue

        first = data[0]
        if not isinstance(first, dict):
            continue

        bids = _parse_kraken_book_levels(first.get("bids"))
        asks = _parse_kraken_book_levels(first.get("asks"))
        if bids or asks:
            events.append(
                ReplayBookEvent(
                    occurred_at=snapshot_at,
                    symbol=symbol,
                    bids=bids,
                    asks=asks,
                )
            )

    return events
def _parse_kraken_book_levels(
raw_levels: object | None,
) -> tuple[BookLevel, ...]:
"""Parse Kraken WS book level format: [{price, qty}, ...]."""
if not isinstance(raw_levels, list):
return ()
levels: list[BookLevel] = []
for level in raw_levels:
if isinstance(level, dict) and "price" in level and "qty" in level:
levels.append(
BookLevel(price=float(level["price"]),
volume=float(level["qty"]))
)
return tuple(levels)
class BacktestReplayEngine: class BacktestReplayEngine:
def __init__( def __init__(
self, self,
@@ -206,7 +294,8 @@ class BacktestReplayEngine:
min_order_size_by_pair=config.min_order_size_by_pair, min_order_size_by_pair=config.min_order_size_by_pair,
) )
self._pre_trade = PreTradeValidator() self._pre_trade = PreTradeValidator()
self._trade_limits = TradeLimitsGuard(max_concurrent_trades=config.max_concurrent_trades) self._trade_limits = TradeLimitsGuard(
max_concurrent_trades=config.max_concurrent_trades)
self._simulated_rest = _SimulatedRestClient( self._simulated_rest = _SimulatedRestClient(
self._clock, self._clock,
slippage_bps=config.slippage_bps, slippage_bps=config.slippage_bps,
@@ -241,7 +330,8 @@ class BacktestReplayEngine:
trades_executed = 0 trades_executed = 0
realized_pnl = 0.0 realized_pnl = 0.0
equity = float(starting_balances.get(self._config.quote_asset.upper(), 0.0)) equity = float(starting_balances.get(
self._config.quote_asset.upper(), 0.0))
peak_equity = equity peak_equity = equity
max_drawdown = 0.0 max_drawdown = 0.0
@@ -284,7 +374,8 @@ class BacktestReplayEngine:
result = await self._sequencer.execute(opportunity) result = await self._sequencer.execute(opportunity)
self._trade_limits.close_trade(exposure) self._trade_limits.close_trade(exposure)
execution_latencies.append(self._simulated_rest.last_trade_latency_ms) execution_latencies.append(
self._simulated_rest.last_trade_latency_ms)
fill_samples.append(self._simulated_rest.last_fill_ratio) fill_samples.append(self._simulated_rest.last_fill_ratio)
if not result.success: if not result.success:
@@ -307,7 +398,8 @@ class BacktestReplayEngine:
wins = sum(1 for pnl in realized_samples if pnl > 0.0) wins = sum(1 for pnl in realized_samples if pnl > 0.0)
win_rate = (wins / len(realized_samples)) if realized_samples else None win_rate = (wins / len(realized_samples)) if realized_samples else None
fill_rate = (sum(fill_samples) / len(fill_samples)) if fill_samples else None fill_rate = (sum(fill_samples) / len(fill_samples)
) if fill_samples else None
return BacktestReport( return BacktestReport(
started_at=events[0].occurred_at if events else self._clock.now, started_at=events[0].occurred_at if events else self._clock.now,
+46 -16
View File
@@ -8,7 +8,12 @@ from pathlib import Path
import structlog import structlog
from arbitrade.backtesting.replay import BacktestConfig, BacktestReplayEngine, load_replay_events from arbitrade.backtesting.replay import (
BacktestConfig,
BacktestReplayEngine,
load_replay_events,
load_replay_events_from_db,
)
from arbitrade.detection.graph import CurrencyGraph, TriangularCycle from arbitrade.detection.graph import CurrencyGraph, TriangularCycle
from arbitrade.storage.db import DuckDBStore from arbitrade.storage.db import DuckDBStore
from arbitrade.storage.repositories import BacktestJobRepository from arbitrade.storage.repositories import BacktestJobRepository
@@ -44,26 +49,51 @@ def _parse_balances(raw: str) -> dict[str, float]:
async def run_backtest_job( async def run_backtest_job(
job_id: str, job_id: str,
events_path: str,
config_dict: dict[str, object] | None, config_dict: dict[str, object] | None,
store: DuckDBStore, store: DuckDBStore,
) -> None: ) -> None:
"""Execute a single backtest job: load events, run engine, store report to DB.""" """Execute a single backtest job: load events from DB or file, run engine, store report."""
repo = BacktestJobRepository(store) repo = BacktestJobRepository(store)
repo.update_status(job_id, "running") repo.update_status(job_id, "running")
_LOG.info("backtest_job_started", job_id=job_id, events_path=events_path) _LOG.info("backtest_job_started", job_id=job_id)
try: try:
path = Path(events_path)
if not path.is_absolute():
path = Path("data") / path
path = path.resolve()
events = load_replay_events(path)
if not events:
raise ValueError(f"No events found in {path}")
config = config_dict or {} config = config_dict or {}
events_path = str(config.get("events_path", ""))
symbols_raw = config.get("symbols")
source = str(config.get("source", "db"))
start_dt = None
end_dt = None
if source == "db":
start_str = config.get("start_time")
end_str = config.get("end_time")
if isinstance(start_str, str) and start_str:
start_dt = datetime.fromisoformat(
start_str.replace("Z", "+00:00"))
if isinstance(end_str, str) and end_str:
end_dt = datetime.fromisoformat(end_str.replace("Z", "+00:00"))
symbols: list[str] | None = None
if isinstance(symbols_raw, str) and symbols_raw.strip():
symbols = [s.strip().upper()
for s in symbols_raw.split(",") if s.strip()]
elif isinstance(symbols_raw, list):
symbols = [str(s).upper() for s in symbols_raw]
events = load_replay_events_from_db(
store, symbols=symbols, start=start_dt, end=end_dt,
)
else:
path = Path(events_path)
if not path.is_absolute():
path = Path("data") / path
path = path.resolve()
events = load_replay_events(path)
if not events:
raise ValueError("No events found for backtest")
starting_balances_raw = str(config.get( starting_balances_raw = str(config.get(
"starting_balances", "USD=1000.0")) "starting_balances", "USD=1000.0"))
starting_balances = _parse_balances(starting_balances_raw) starting_balances = _parse_balances(starting_balances_raw)
@@ -123,7 +153,7 @@ async def run_backtest_job(
async def backtest_worker( async def backtest_worker(
queue: asyncio.Queue[tuple[str, str, dict[str, object] | None] | None], queue: asyncio.Queue[tuple[str, dict[str, object] | None] | None],
store: DuckDBStore, store: DuckDBStore,
) -> None: ) -> None:
"""Worker coroutine: pull jobs from queue and execute them one at a time.""" """Worker coroutine: pull jobs from queue and execute them one at a time."""
@@ -133,9 +163,9 @@ async def backtest_worker(
if item is None: if item is None:
queue.task_done() queue.task_done()
break break
job_id, events_path, config = item job_id, config = item
try: try:
await run_backtest_job(job_id, events_path, config, store) await run_backtest_job(job_id, config, store)
except Exception: except Exception:
_LOG.exception("backtest_worker_unhandled_error", job_id=job_id) _LOG.exception("backtest_worker_unhandled_error", job_id=job_id)
finally: finally:
@@ -42,13 +42,32 @@
hx-target="#backtesting-shell" hx-target="#backtesting-shell"
hx-swap="outerHTML" hx-swap="outerHTML"
> >
<input type="hidden" name="source" value="db" />
<label class="field"> <label class="field">
<span>Replay events path (JSONL)</span> <span>Symbols (comma-separated, blank=all)</span>
<input <input
name="events_path" name="symbols"
type="text" type="text"
value="{{ events_path }}" value="{{ symbols | default('') }}"
placeholder="data/replay.jsonl" placeholder="BTC/USD,ETH/BTC"
/>
</label>
<label class="field">
<span>Start time (ISO datetime, optional)</span>
<input
name="start_time"
type="text"
value="{{ start_time | default('') }}"
placeholder="2025-01-01T00:00:00"
/>
</label>
<label class="field">
<span>End time (ISO datetime, optional)</span>
<input
name="end_time"
type="text"
value="{{ end_time | default('') }}"
placeholder="2025-01-02T00:00:00"
/> />
</label> </label>
<label class="field"> <label class="field">
@@ -6,7 +6,7 @@
</article> </article>
<article class="card"> <article class="card">
<div class="label">Balances</div> <div class="label">Balances</div>
<div class="value">{{ balances }}</div> <div class="value">{{ balances | safe }}</div>
</article> </article>
<article class="card"> <article class="card">
<div class="label">Open Trades</div> <div class="label">Open Trades</div>
@@ -51,7 +51,7 @@
class="value" class="value"
style="font-size: 1rem; font-weight: 500; word-break: break-word" style="font-size: 1rem; font-weight: 500; word-break: break-word"
> >
{{ balances }} {{ balances | safe }}
</div> </div>
<div class="meta">Total value {{ total_value }}</div> <div class="meta">Total value {{ total_value }}</div>
<div class="meta">Equity {{ equity }}</div> <div class="meta">Equity {{ equity }}</div>