Refactor code for improved readability and consistency
CI / lint-test-build (push) Failing after 12s

- Consolidated multiline string formatting into single-line for SQL queries in multiple files.
- Adjusted argument formatting in function calls for better alignment and readability.
- Removed unnecessary line breaks and improved spacing in various sections of the codebase.
- Updated test cases to maintain consistency in formatting and improve clarity.
This commit is contained in:
2026-06-04 19:04:30 +02:00
parent 7d18bdf316
commit c8e3daeb57
21 changed files with 377 additions and 383 deletions
+9 -12
View File
@@ -19,10 +19,12 @@ def _resolve_fee_rate(fee_rate: float | None, db_path: str | None = None) -> flo
if db_path is not None: if db_path is not None:
try: try:
conn = duckdb.connect(db_path) conn = duckdb.connect(db_path)
row = conn.execute(""" row = conn.execute(
"""
SELECT maker_fee FROM kraken_account_snapshots SELECT maker_fee FROM kraken_account_snapshots
ORDER BY snapshot_at DESC LIMIT 1 ORDER BY snapshot_at DESC LIMIT 1
""").fetchone() """
).fetchone()
conn.close() conn.close()
if row is not None and row[0] is not None: if row is not None and row[0] is not None:
return float(row[0]) return float(row[0])
@@ -51,16 +53,14 @@ def _parse_balances(raw: str) -> Mapping[str, float]:
def main() -> int: def main() -> int:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Run a deterministic replay backtest.")
description="Run a deterministic replay backtest.")
parser.add_argument("--events", type=Path, required=True) parser.add_argument("--events", type=Path, required=True)
parser.add_argument("--starting-balances", type=str, default="USD=1000.0") parser.add_argument("--starting-balances", type=str, default="USD=1000.0")
parser.add_argument("--trade-capital", type=float, default=100.0) parser.add_argument("--trade-capital", type=float, default=100.0)
parser.add_argument("--fee-rate", type=float, default=None) parser.add_argument("--fee-rate", type=float, default=None)
parser.add_argument("--slippage-bps", type=float, default=4.0) parser.add_argument("--slippage-bps", type=float, default=4.0)
parser.add_argument("--execution-latency-ms", type=float, default=20.0) parser.add_argument("--execution-latency-ms", type=float, default=20.0)
parser.add_argument("--db-path", type=str, default=None, parser.add_argument("--db-path", type=str, default=None, help="DuckDB path for fee lookup")
help="DuckDB path for fee lookup")
args = parser.parse_args() args = parser.parse_args()
cycles_by_pair, available_pairs = _build_graph() cycles_by_pair, available_pairs = _build_graph()
@@ -80,18 +80,15 @@ def main() -> int:
started_at=events[0].occurred_at if events else datetime.now(UTC), started_at=events[0].occurred_at if events else datetime.now(UTC),
) )
report = asyncio.run( report = asyncio.run(
engine.run(events, starting_balances=_parse_balances( engine.run(events, starting_balances=_parse_balances(args.starting_balances))
args.starting_balances))
) )
print("Backtest report:") print("Backtest report:")
print(f"- processed_events: {report.processed_events}") print(f"- processed_events: {report.processed_events}")
print(f"- opportunities_seen: {report.opportunities_seen}") print(f"- opportunities_seen: {report.opportunities_seen}")
print(f"- trades_executed: {report.trades_executed}") print(f"- trades_executed: {report.trades_executed}")
print( print(f"- win_rate: {report.win_rate if report.win_rate is not None else 'n/a'}")
f"- win_rate: {report.win_rate if report.win_rate is not None else 'n/a'}") print(f"- fill_rate: {report.fill_rate if report.fill_rate is not None else 'n/a'}")
print(
f"- fill_rate: {report.fill_rate if report.fill_rate is not None else 'n/a'}")
print(f"- realized_pnl_usd: {report.realized_pnl_usd:.4f}") print(f"- realized_pnl_usd: {report.realized_pnl_usd:.4f}")
print(f"- max_drawdown_usd: {report.max_drawdown_usd:.4f}") print(f"- max_drawdown_usd: {report.max_drawdown_usd:.4f}")
print(f"- miss_reasons: {dict(report.miss_reasons)}") print(f"- miss_reasons: {dict(report.miss_reasons)}")
+12 -17
View File
@@ -36,8 +36,7 @@ def _parse_float_list(raw: str) -> list[float]:
def _parse_pair_universes(raw: str) -> list[tuple[str, ...]]: def _parse_pair_universes(raw: str) -> list[tuple[str, ...]]:
universes: list[tuple[str, ...]] = [] universes: list[tuple[str, ...]] = []
for chunk in raw.split(";"): for chunk in raw.split(";"):
symbols = tuple(item.strip().upper() symbols = tuple(item.strip().upper() for item in chunk.split("|") if item.strip())
for item in chunk.split("|") if item.strip())
if symbols: if symbols:
universes.append(symbols) universes.append(symbols)
if not universes: if not universes:
@@ -75,31 +74,29 @@ def _print_top_results(results: Sequence[SweepResult], *, limit: int = 5) -> Non
def main() -> int: def main() -> int:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run backtesting parameter sweep with train/test split.") description="Run backtesting parameter sweep with train/test split."
)
parser.add_argument("--events", type=Path, required=True) parser.add_argument("--events", type=Path, required=True)
parser.add_argument("--starting-balances", type=str, default="USD=1000.0") parser.add_argument("--starting-balances", type=str, default="USD=1000.0")
parser.add_argument("--theta-values", type=str, parser.add_argument("--theta-values", type=str, default="0.0003,0.0005,0.0008")
default="0.0003,0.0005,0.0008") parser.add_argument("--trade-capital-values", type=str, default="50,100,150")
parser.add_argument("--trade-capital-values",
type=str, default="50,100,150")
parser.add_argument( parser.add_argument(
"--pair-universes", "--pair-universes",
type=str, type=str,
default="BTC/USD|ETH/BTC|ETH/USD", default="BTC/USD|ETH/BTC|ETH/USD",
help="Semicolon-separated universes, each with | delimited pairs", help="Semicolon-separated universes, each with | delimited pairs",
) )
parser.add_argument("--staleness-threshold-values", parser.add_argument("--staleness-threshold-values", type=str, default="3,5,8")
type=str, default="3,5,8")
parser.add_argument("--train-ratio", type=float, default=0.7) parser.add_argument("--train-ratio", type=float, default=0.7)
parser.add_argument("--output", type=Path, parser.add_argument(
default=Path("ops/backtesting/parameter_sweep_results.json")) "--output", type=Path, default=Path("ops/backtesting/parameter_sweep_results.json")
)
parser.add_argument("--min-test-realized-pnl-usd", type=float, default=0.0) parser.add_argument("--min-test-realized-pnl-usd", type=float, default=0.0)
parser.add_argument("--min-test-win-rate", type=float, default=0.5) parser.add_argument("--min-test-win-rate", type=float, default=0.5)
parser.add_argument("--min-test-fill-rate", type=float, default=0.9) parser.add_argument("--min-test-fill-rate", type=float, default=0.9)
parser.add_argument("--max-test-drawdown-usd", type=float, default=25.0) parser.add_argument("--max-test-drawdown-usd", type=float, default=25.0)
parser.add_argument("--max-generalization-gap-ratio", parser.add_argument("--max-generalization-gap-ratio", type=float, default=0.5)
type=float, default=0.5)
args = parser.parse_args() args = parser.parse_args()
@@ -107,15 +104,13 @@ def main() -> int:
symbols = sorted({event.symbol.upper() for event in events}) symbols = sorted({event.symbol.upper() for event in events})
cycles_by_pair = _build_graph_from_symbols(symbols) cycles_by_pair = _build_graph_from_symbols(symbols)
if not cycles_by_pair: if not cycles_by_pair:
raise SystemExit( raise SystemExit("No triangular cycles found in supplied replay events")
"No triangular cycles found in supplied replay events")
grid = build_parameter_grid( grid = build_parameter_grid(
theta_values=_parse_float_list(args.theta_values), theta_values=_parse_float_list(args.theta_values),
trade_capital_values=_parse_float_list(args.trade_capital_values), trade_capital_values=_parse_float_list(args.trade_capital_values),
pair_universes=_parse_pair_universes(args.pair_universes), pair_universes=_parse_pair_universes(args.pair_universes),
staleness_threshold_values=_parse_float_list( staleness_threshold_values=_parse_float_list(args.staleness_threshold_values),
args.staleness_threshold_values),
) )
artifacts = run_parameter_search( artifacts = run_parameter_search(
+4 -2
View File
@@ -13,11 +13,13 @@ from arbitrade.storage.db import DuckDBStore
def _python_scan_compute(store: DuckDBStore) -> tuple[float, float | None, float | None]: def _python_scan_compute(store: DuckDBStore) -> tuple[float, float | None, float | None]:
with store.connect() as conn: with store.connect() as conn:
trade_rows = conn.execute(""" trade_rows = conn.execute(
"""
SELECT started_at, finished_at, realized_pnl SELECT started_at, finished_at, realized_pnl
FROM trades FROM trades
WHERE finished_at IS NOT NULL WHERE finished_at IS NOT NULL
""").fetchall() """
).fetchall()
opportunity_rows = conn.execute("SELECT detected_at FROM opportunities").fetchall() opportunity_rows = conn.execute("SELECT detected_at FROM opportunities").fetchall()
realized = sum(float(row[2]) for row in trade_rows if row[2] is not None) realized = sum(float(row[2]) for row in trade_rows if row[2] is not None)
+4 -4
View File
@@ -29,8 +29,9 @@ def create_app(settings: Settings) -> FastAPI:
db.migrate() db.migrate()
kraken_client = KrakenRestClient(settings) kraken_client = KrakenRestClient(settings)
fee_sync_stop_event = asyncio.Event() fee_sync_stop_event = asyncio.Event()
backtest_queue: asyncio.Queue[tuple[str, str, backtest_queue: asyncio.Queue[tuple[str, str, dict[str, object] | None] | None] = (
dict[str, object] | None] | None] = asyncio.Queue() asyncio.Queue()
)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]: async def lifespan(app: FastAPI) -> AsyncIterator[None]:
@@ -75,8 +76,7 @@ def create_app(settings: Settings) -> FastAPI:
app.state.audit_repository = AuditRepository(db) app.state.audit_repository = AuditRepository(db)
app.state.runtime_state_repository = RuntimeStateRepository(db) app.state.runtime_state_repository = RuntimeStateRepository(db)
app.state.alert_notifier = build_notifier_from_settings(settings) app.state.alert_notifier = build_notifier_from_settings(settings)
app.state.configuration_service = ConfigurationService( app.state.configuration_service = ConfigurationService(settings, db, AuditRepository(db))
settings, db, AuditRepository(db))
app.state.backtest_recent_reports = [] app.state.backtest_recent_reports = []
app.state.dashboard_controls = DashboardControlState( app.state.dashboard_controls = DashboardControlState(
is_running=not settings.kill_switch_active, is_running=not settings.kill_switch_active,
+51 -42
View File
@@ -19,7 +19,12 @@ from arbitrade.api.auth import require_dashboard_auth
from arbitrade.api.control_state import DashboardControlState from arbitrade.api.control_state import DashboardControlState
from arbitrade.backtesting.replay import BacktestConfig, BacktestReplayEngine, load_replay_events from arbitrade.backtesting.replay import BacktestConfig, BacktestReplayEngine, load_replay_events
from arbitrade.detection.graph import CurrencyGraph, TriangularCycle from arbitrade.detection.graph import CurrencyGraph, TriangularCycle
from arbitrade.storage.repositories import AuditRecord, AuditRepository, BacktestJobRepository, KrakenAccountSnapshotRepository from arbitrade.storage.repositories import (
AuditRecord,
AuditRepository,
BacktestJobRepository,
KrakenAccountSnapshotRepository,
)
router = APIRouter(dependencies=[Depends(require_dashboard_auth)]) router = APIRouter(dependencies=[Depends(require_dashboard_auth)])
public_router = APIRouter() public_router = APIRouter()
@@ -27,8 +32,7 @@ public_router = APIRouter()
def _resolve_templates_directory() -> str: def _resolve_templates_directory() -> str:
# Support source layout, Docker runtime (/app), and installed package data. # Support source layout, Docker runtime (/app), and installed package data.
source_layout_path = Path( source_layout_path = Path(__file__).resolve().parents[3] / "web" / "templates"
__file__).resolve().parents[3] / "web" / "templates"
if source_layout_path.is_dir(): if source_layout_path.is_dir():
return str(source_layout_path) return str(source_layout_path)
@@ -37,8 +41,7 @@ def _resolve_templates_directory() -> str:
return str(docker_runtime_path) return str(docker_runtime_path)
try: try:
package_path = resources.files( package_path = resources.files("arbitrade").joinpath("web", "templates")
"arbitrade").joinpath("web", "templates")
if package_path.is_dir(): if package_path.is_dir():
return str(package_path) return str(package_path)
except (ModuleNotFoundError, AttributeError): except (ModuleNotFoundError, AttributeError):
@@ -101,29 +104,37 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
else: else:
open_trade_filter = "LOWER(status) NOT IN ('filled', 'closed', 'cancelled', 'canceled')" open_trade_filter = "LOWER(status) NOT IN ('filled', 'closed', 'cancelled', 'canceled')"
portfolio_row = conn.execute(""" portfolio_row = conn.execute(
"""
SELECT balances, total_value_usd SELECT balances, total_value_usd
FROM portfolio_snapshots FROM portfolio_snapshots
ORDER BY snapshot_at DESC ORDER BY snapshot_at DESC
LIMIT 1 LIMIT 1
""").fetchone() """
open_trades = conn.execute(f""" ).fetchone()
open_trades = conn.execute(
f"""
SELECT {trade_ref_expr}, status, started_at, {cycle_expr} SELECT {trade_ref_expr}, status, started_at, {cycle_expr}
FROM trades FROM trades
WHERE {open_trade_filter} WHERE {open_trade_filter}
ORDER BY started_at DESC ORDER BY started_at DESC
LIMIT 5 LIMIT 5
""").fetchall() """
rpnl = conn.execute(""" ).fetchall()
rpnl = conn.execute(
"""
SELECT COALESCE(SUM(COALESCE(realized_pnl, 0)), 0) SELECT COALESCE(SUM(COALESCE(realized_pnl, 0)), 0)
FROM trades FROM trades
""").fetchone() """
latest_opportunities = conn.execute(""" ).fetchone()
latest_opportunities = conn.execute(
"""
SELECT cycle, net_pct, est_profit, detected_at SELECT cycle, net_pct, est_profit, detected_at
FROM opportunities FROM opportunities
ORDER BY detected_at DESC ORDER BY detected_at DESC
LIMIT 5 LIMIT 5
""").fetchall() """
).fetchall()
balances_value = "" balances_value = ""
total_value = "" total_value = ""
@@ -135,8 +146,7 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
parsed = json.loads(balances_raw) parsed = json.loads(balances_raw)
if isinstance(parsed, dict): if isinstance(parsed, dict):
# Filter out zero balances, show non-zero as "AMT ASSET" # Filter out zero balances, show non-zero as "AMT ASSET"
non_zero = {k: float(v) non_zero = {k: float(v) for k, v in parsed.items() if float(v) > 0.0}
for k, v in parsed.items() if float(v) > 0.0}
if non_zero: if non_zero:
balances_value = "<br>".join( balances_value = "<br>".join(
f"{v:.6g} {k}" for k, v in sorted(non_zero.items()) f"{v:.6g} {k}" for k, v in sorted(non_zero.items())
@@ -154,12 +164,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
# Query equity from kraken_account_snapshots # Query equity from kraken_account_snapshots
try: try:
equity_row = conn.execute(""" equity_row = conn.execute(
"""
SELECT trade_balance_raw SELECT trade_balance_raw
FROM kraken_account_snapshots FROM kraken_account_snapshots
ORDER BY snapshot_at DESC ORDER BY snapshot_at DESC
LIMIT 1 LIMIT 1
""").fetchone() """
).fetchone()
if equity_row is not None and equity_row[0] is not None: if equity_row is not None and equity_row[0] is not None:
tb_raw = equity_row[0] tb_raw = equity_row[0]
if isinstance(tb_raw, str): if isinstance(tb_raw, str):
@@ -195,12 +207,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
taker_fee = "" taker_fee = ""
thirty_day_volume = "" thirty_day_volume = ""
try: try:
acct_row = conn.execute(""" acct_row = conn.execute(
"""
SELECT fee_tier, maker_fee, taker_fee, thirty_day_volume SELECT fee_tier, maker_fee, taker_fee, thirty_day_volume
FROM kraken_account_snapshots FROM kraken_account_snapshots
ORDER BY snapshot_at DESC ORDER BY snapshot_at DESC
LIMIT 1 LIMIT 1
""").fetchone() """
).fetchone()
if acct_row is not None: if acct_row is not None:
fee_tier = str(acct_row[0]) if acct_row[0] is not None else "" fee_tier = str(acct_row[0]) if acct_row[0] is not None else ""
maker_fee = f"{float(acct_row[1]):.4%}" if acct_row[1] is not None else "" maker_fee = f"{float(acct_row[1]):.4%}" if acct_row[1] is not None else ""
@@ -230,12 +244,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
def _dashboard_charts(request: Request) -> dict[str, object]: def _dashboard_charts(request: Request) -> dict[str, object]:
store = request.app.state.store store = request.app.state.store
with store.connect() as conn: with store.connect() as conn:
opportunity_rows = conn.execute(""" opportunity_rows = conn.execute(
"""
SELECT detected_at, cycle, net_pct, est_profit SELECT detected_at, cycle, net_pct, est_profit
FROM opportunities FROM opportunities
ORDER BY detected_at DESC ORDER BY detected_at DESC
LIMIT 10 LIMIT 10
""").fetchall() """
).fetchall()
cr = list(reversed(opportunity_rows)) cr = list(reversed(opportunity_rows))
labels = [] labels = []
@@ -375,12 +391,12 @@ def _dashboard_config_context(request: Request) -> dict[str, object]:
else "" else ""
) )
max_exposure_per_asset_value = ( max_exposure_per_asset_value = (
f"{float(rs.max_exposure_per_asset_usd):.2f}" if rs.max_exposure_per_asset_usd is not None else "" f"{float(rs.max_exposure_per_asset_usd):.2f}"
if rs.max_exposure_per_asset_usd is not None
else ""
) )
daily_loss_limit = ( daily_loss_limit = (
f"{float(rs.daily_loss_limit_usd):.2f} USD" f"{float(rs.daily_loss_limit_usd):.2f} USD" if rs.daily_loss_limit_usd is not None else ""
if rs.daily_loss_limit_usd is not None
else ""
) )
daily_loss_limit_value = ( daily_loss_limit_value = (
f"{float(rs.daily_loss_limit_usd):.2f}" if rs.daily_loss_limit_usd is not None else "" f"{float(rs.daily_loss_limit_usd):.2f}" if rs.daily_loss_limit_usd is not None else ""
@@ -391,20 +407,18 @@ def _dashboard_config_context(request: Request) -> dict[str, object]:
else "" else ""
) )
cumulative_loss_limit_value = ( cumulative_loss_limit_value = (
f"{float(rs.cumulative_loss_limit_usd):.2f}" if rs.cumulative_loss_limit_usd is not None else "" f"{float(rs.cumulative_loss_limit_usd):.2f}"
if rs.cumulative_loss_limit_usd is not None
else ""
) )
max_source_latency = ( max_source_latency = (
f"{float(rs.max_source_latency_ms):.1f} ms" f"{float(rs.max_source_latency_ms):.1f} ms" if rs.max_source_latency_ms is not None else ""
if rs.max_source_latency_ms is not None
else ""
) )
max_source_latency_value = ( max_source_latency_value = (
f"{float(rs.max_source_latency_ms):.1f}" if rs.max_source_latency_ms is not None else "" f"{float(rs.max_source_latency_ms):.1f}" if rs.max_source_latency_ms is not None else ""
) )
max_apply_latency = ( max_apply_latency = (
f"{float(rs.max_apply_latency_ms):.1f} ms" f"{float(rs.max_apply_latency_ms):.1f} ms" if rs.max_apply_latency_ms is not None else ""
if rs.max_apply_latency_ms is not None
else ""
) )
max_apply_latency_value = ( max_apply_latency_value = (
f"{float(rs.max_apply_latency_ms):.1f}" if rs.max_apply_latency_ms is not None else "" f"{float(rs.max_apply_latency_ms):.1f}" if rs.max_apply_latency_ms is not None else ""
@@ -415,8 +429,7 @@ def _dashboard_config_context(request: Request) -> dict[str, object]:
max_consecutive_failures_value = ( max_consecutive_failures_value = (
str(rs.max_consecutive_failures) if rs.max_consecutive_failures is not None else "" str(rs.max_consecutive_failures) if rs.max_consecutive_failures is not None else ""
) )
strategy_stat_arb_enabled = bool( strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False))
getattr(rs, "strategy_enable_stat_arb_experiment", False))
return { return {
# Runtime # Runtime
@@ -537,8 +550,7 @@ def _dashboard_controls(request: Request) -> dict[str, object]:
alerts_last_channel_results = [ alerts_last_channel_results = [
str(item) for item in cast(list[object], alert_status.get("last_channel_results", [])) str(item) for item in cast(list[object], alert_status.get("last_channel_results", []))
] ]
strategy_stat_arb_enabled = bool( strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False))
getattr(rs, "strategy_enable_stat_arb_experiment", False))
return { return {
"execution_status": "running" if ctl.is_running else "stopped", "execution_status": "running" if ctl.is_running else "stopped",
@@ -943,11 +955,9 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse:
try: try:
custom_fee_rate = ( custom_fee_rate = (
float(defaults["custom_fee_rate"]) float(defaults["custom_fee_rate"]) if defaults["custom_fee_rate"].strip() else None
if defaults["custom_fee_rate"].strip() else None
) )
fee_rate = _fee_rate_for_profile( fee_rate = _fee_rate_for_profile(defaults["fee_profile"], custom_fee_rate, request=request)
defaults["fee_profile"], custom_fee_rate, request=request)
config_dict: dict[str, object] = { config_dict: dict[str, object] = {
"source": defaults["source"], "source": defaults["source"],
@@ -1067,8 +1077,7 @@ async def dashboard_backtesting_export(request: Request, job_id: str) -> Respons
return Response( return Response(
content=orjson.dumps(payload).decode("utf-8"), content=orjson.dumps(payload).decode("utf-8"),
media_type="application/x-jsonlines", media_type="application/x-jsonlines",
headers={ headers={"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"},
"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"},
) )
+7 -16
View File
@@ -153,8 +153,7 @@ def _parse_book_levels(raw_levels: Any) -> tuple[BookLevel, ...]:
or not isinstance(raw_level[1], int | float) or not isinstance(raw_level[1], int | float)
): ):
raise ValueError("Each level must be [price, volume]") raise ValueError("Each level must be [price, volume]")
levels.append(BookLevel(price=float( levels.append(BookLevel(price=float(raw_level[0]), volume=float(raw_level[1])))
raw_level[0]), volume=float(raw_level[1])))
return tuple(levels) return tuple(levels)
@@ -173,8 +172,7 @@ def load_replay_events(path: Path) -> list[ReplayBookEvent]:
if not isinstance(timestamp_raw, str) or not isinstance(symbol_raw, str): if not isinstance(timestamp_raw, str) or not isinstance(symbol_raw, str):
raise ValueError("Each event must include timestamp and symbol") raise ValueError("Each event must include timestamp and symbol")
occurred_at = datetime.fromisoformat( occurred_at = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00")).astimezone(UTC)
timestamp_raw.replace("Z", "+00:00")).astimezone(UTC)
events.append( events.append(
ReplayBookEvent( ReplayBookEvent(
occurred_at=occurred_at, occurred_at=occurred_at,
@@ -266,10 +264,7 @@ def _parse_kraken_book_levels(
levels: list[BookLevel] = [] levels: list[BookLevel] = []
for level in raw_levels: for level in raw_levels:
if isinstance(level, dict) and "price" in level and "qty" in level: if isinstance(level, dict) and "price" in level and "qty" in level:
levels.append( levels.append(BookLevel(price=float(level["price"]), volume=float(level["qty"])))
BookLevel(price=float(level["price"]),
volume=float(level["qty"]))
)
return tuple(levels) return tuple(levels)
@@ -294,8 +289,7 @@ class BacktestReplayEngine:
min_order_size_by_pair=config.min_order_size_by_pair, min_order_size_by_pair=config.min_order_size_by_pair,
) )
self._pre_trade = PreTradeValidator() self._pre_trade = PreTradeValidator()
self._trade_limits = TradeLimitsGuard( self._trade_limits = TradeLimitsGuard(max_concurrent_trades=config.max_concurrent_trades)
max_concurrent_trades=config.max_concurrent_trades)
self._simulated_rest = _SimulatedRestClient( self._simulated_rest = _SimulatedRestClient(
self._clock, self._clock,
slippage_bps=config.slippage_bps, slippage_bps=config.slippage_bps,
@@ -330,8 +324,7 @@ class BacktestReplayEngine:
trades_executed = 0 trades_executed = 0
realized_pnl = 0.0 realized_pnl = 0.0
equity = float(starting_balances.get( equity = float(starting_balances.get(self._config.quote_asset.upper(), 0.0))
self._config.quote_asset.upper(), 0.0))
peak_equity = equity peak_equity = equity
max_drawdown = 0.0 max_drawdown = 0.0
@@ -374,8 +367,7 @@ class BacktestReplayEngine:
result = await self._sequencer.execute(opportunity) result = await self._sequencer.execute(opportunity)
self._trade_limits.close_trade(exposure) self._trade_limits.close_trade(exposure)
execution_latencies.append( execution_latencies.append(self._simulated_rest.last_trade_latency_ms)
self._simulated_rest.last_trade_latency_ms)
fill_samples.append(self._simulated_rest.last_fill_ratio) fill_samples.append(self._simulated_rest.last_fill_ratio)
if not result.success: if not result.success:
@@ -398,8 +390,7 @@ class BacktestReplayEngine:
wins = sum(1 for pnl in realized_samples if pnl > 0.0) wins = sum(1 for pnl in realized_samples if pnl > 0.0)
win_rate = (wins / len(realized_samples)) if realized_samples else None win_rate = (wins / len(realized_samples)) if realized_samples else None
fill_rate = (sum(fill_samples) / len(fill_samples) fill_rate = (sum(fill_samples) / len(fill_samples)) if fill_samples else None
) if fill_samples else None
return BacktestReport( return BacktestReport(
started_at=events[0].occurred_at if events else self._clock.now, started_at=events[0].occurred_at if events else self._clock.now,
+9 -11
View File
@@ -69,20 +69,21 @@ async def run_backtest_job(
start_str = config.get("start_time") start_str = config.get("start_time")
end_str = config.get("end_time") end_str = config.get("end_time")
if isinstance(start_str, str) and start_str: if isinstance(start_str, str) and start_str:
start_dt = datetime.fromisoformat( start_dt = datetime.fromisoformat(start_str.replace("Z", "+00:00"))
start_str.replace("Z", "+00:00"))
if isinstance(end_str, str) and end_str: if isinstance(end_str, str) and end_str:
end_dt = datetime.fromisoformat(end_str.replace("Z", "+00:00")) end_dt = datetime.fromisoformat(end_str.replace("Z", "+00:00"))
symbols: list[str] | None = None symbols: list[str] | None = None
if isinstance(symbols_raw, str) and symbols_raw.strip(): if isinstance(symbols_raw, str) and symbols_raw.strip():
symbols = [s.strip().upper() symbols = [s.strip().upper() for s in symbols_raw.split(",") if s.strip()]
for s in symbols_raw.split(",") if s.strip()]
elif isinstance(symbols_raw, list): elif isinstance(symbols_raw, list):
symbols = [str(s).upper() for s in symbols_raw] symbols = [str(s).upper() for s in symbols_raw]
events = load_replay_events_from_db( events = load_replay_events_from_db(
store, symbols=symbols, start=start_dt, end=end_dt, store,
symbols=symbols,
start=start_dt,
end=end_dt,
) )
else: else:
path = Path(events_path) path = Path(events_path)
@@ -94,14 +95,12 @@ async def run_backtest_job(
if not events: if not events:
raise ValueError("No events found for backtest") raise ValueError("No events found for backtest")
starting_balances_raw = str(config.get( starting_balances_raw = str(config.get("starting_balances", "USD=1000.0"))
"starting_balances", "USD=1000.0"))
starting_balances = _parse_balances(starting_balances_raw) starting_balances = _parse_balances(starting_balances_raw)
fee_rate = float(config.get("fee_rate", 0.0026)) fee_rate = float(config.get("fee_rate", 0.0026))
trade_capital = float(config.get("trade_capital", 100.0)) trade_capital = float(config.get("trade_capital", 100.0))
min_profit_threshold = float( min_profit_threshold = float(config.get("min_profit_threshold", 0.0005))
config.get("min_profit_threshold", 0.0005))
slippage_bps = float(config.get("slippage_bps", 4.0)) slippage_bps = float(config.get("slippage_bps", 4.0))
execution_latency_ms = float(config.get("execution_latency_ms", 20.0)) execution_latency_ms = float(config.get("execution_latency_ms", 20.0))
@@ -144,8 +143,7 @@ async def run_backtest_job(
repo.store_report(job_id, report_dict) repo.store_report(job_id, report_dict)
repo.update_status(job_id, "completed") repo.update_status(job_id, "completed")
_LOG.info("backtest_job_completed", job_id=job_id, _LOG.info("backtest_job_completed", job_id=job_id, pnl=report.realized_pnl_usd)
pnl=report.realized_pnl_usd)
except Exception as exc: except Exception as exc:
repo.update_status(job_id, "failed", error=str(exc)) repo.update_status(job_id, "failed", error=str(exc))
+11 -16
View File
@@ -91,16 +91,14 @@ def build_parameter_grid(
for theta in theta_values: for theta in theta_values:
for trade_capital in trade_capital_values: for trade_capital in trade_capital_values:
for pair_universe in pair_universes: for pair_universe in pair_universes:
normalized_universe = tuple( normalized_universe = tuple(sorted({pair.upper() for pair in pair_universe}))
sorted({pair.upper() for pair in pair_universe}))
for staleness_threshold in staleness_threshold_values: for staleness_threshold in staleness_threshold_values:
grid.append( grid.append(
SweepParameters( SweepParameters(
min_profit_threshold=float(theta), min_profit_threshold=float(theta),
trade_capital=float(trade_capital), trade_capital=float(trade_capital),
pair_universe=normalized_universe, pair_universe=normalized_universe,
staleness_threshold_seconds=float( staleness_threshold_seconds=float(staleness_threshold),
staleness_threshold),
) )
) )
return grid return grid
@@ -147,8 +145,9 @@ def _restrict_cycles_by_pair(
if normalized_pair not in pair_universe: if normalized_pair not in pair_universe:
continue continue
kept = [cycle for cycle in cycles if all( kept = [
pair.upper() in pair_universe for pair in cycle.pairs)] cycle for cycle in cycles if all(pair.upper() in pair_universe for pair in cycle.pairs)
]
if kept: if kept:
restricted[normalized_pair] = kept restricted[normalized_pair] = kept
return restricted return restricted
@@ -175,9 +174,7 @@ def _evaluate_promotion(
test = result.test_report test = result.test_report
if test.realized_pnl_usd < criteria.min_test_realized_pnl_usd: if test.realized_pnl_usd < criteria.min_test_realized_pnl_usd:
reasons.append( reasons.append("test_realized_pnl_below_threshold")
"test_realized_pnl_below_threshold"
)
if (test.win_rate or 0.0) < criteria.min_test_win_rate: if (test.win_rate or 0.0) < criteria.min_test_win_rate:
reasons.append("test_win_rate_below_threshold") reasons.append("test_win_rate_below_threshold")
if (test.fill_rate or 0.0) < criteria.min_test_fill_rate: if (test.fill_rate or 0.0) < criteria.min_test_fill_rate:
@@ -221,8 +218,7 @@ def run_parameter_search(
quote_asset: str = "USD", quote_asset: str = "USD",
) -> SweepArtifacts: ) -> SweepArtifacts:
criteria = promotion_criteria or PromotionCriteria() criteria = promotion_criteria or PromotionCriteria()
train_events, test_events = split_events_time_windows( train_events, test_events = split_events_time_windows(events, train_ratio=train_ratio)
events, train_ratio=train_ratio)
results: list[SweepResult] = [] results: list[SweepResult] = []
promoted: list[SweepResult] = [] promoted: list[SweepResult] = []
@@ -293,7 +289,8 @@ def run_parameter_search(
test_event_count=len(filtered_test), test_event_count=len(filtered_test),
) )
promotion_ready, promotion_reasons = _evaluate_promotion( promotion_ready, promotion_reasons = _evaluate_promotion(
result=base_result, criteria=criteria) result=base_result, criteria=criteria
)
completed_result = SweepResult( completed_result = SweepResult(
parameters=base_result.parameters, parameters=base_result.parameters,
train_report=base_result.train_report, train_report=base_result.train_report,
@@ -318,8 +315,7 @@ def run_parameter_search(
train_window: tuple[datetime, datetime] | None = None train_window: tuple[datetime, datetime] | None = None
test_window: tuple[datetime, datetime] | None = None test_window: tuple[datetime, datetime] | None = None
if train_events: if train_events:
train_window = (train_events[0].occurred_at, train_window = (train_events[0].occurred_at, train_events[-1].occurred_at)
train_events[-1].occurred_at)
if test_events: if test_events:
test_window = (test_events[0].occurred_at, test_events[-1].occurred_at) test_window = (test_events[0].occurred_at, test_events[-1].occurred_at)
@@ -392,5 +388,4 @@ def persist_sweep_results(path: Path, artifacts: SweepArtifacts) -> None:
} }
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(orjson.dumps( path.write_bytes(orjson.dumps(payload, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS))
payload, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS))
+14 -6
View File
@@ -63,6 +63,7 @@ class ConfigurationService:
"""Load user settings from database and merge with defaults.""" """Load user settings from database and merge with defaults."""
# Import here to avoid circular imports # Import here to avoid circular imports
from arbitrade.storage.repositories import ConfigSettingRepository from arbitrade.storage.repositories import ConfigSettingRepository
setting_repo = ConfigSettingRepository(self._store) setting_repo = ConfigSettingRepository(self._store)
# Load all settings from database # Load all settings from database
@@ -91,7 +92,8 @@ class ConfigurationService:
# Track the latest update time # Track the latest update time
if db_settings: if db_settings:
latest_updated = max( latest_updated = max(
setting.updated_at for setting in db_settings if setting.updated_at) setting.updated_at for setting in db_settings if setting.updated_at
)
self._last_updated_at = latest_updated self._last_updated_at = latest_updated
# Initialize with default values from settings model # Initialize with default values from settings model
@@ -119,6 +121,7 @@ class ConfigurationService:
"""Check if configuration has been updated since last load.""" """Check if configuration has been updated since last load."""
# Import here to avoid circular imports # Import here to avoid circular imports
from arbitrade.storage.repositories import ConfigSettingRepository from arbitrade.storage.repositories import ConfigSettingRepository
setting_repo = ConfigSettingRepository(self._store) setting_repo = ConfigSettingRepository(self._store)
# Get the latest update timestamp from database # Get the latest update timestamp from database
@@ -143,6 +146,7 @@ class ConfigurationService:
"""Set a configuration setting value and persist to database.""" """Set a configuration setting value and persist to database."""
# Import here to avoid circular imports # Import here to avoid circular imports
from arbitrade.storage.repositories import ConfigSettingRepository from arbitrade.storage.repositories import ConfigSettingRepository
setting_repo = ConfigSettingRepository(self._store) setting_repo = ConfigSettingRepository(self._store)
# Convert value to JSON string and determine type # Convert value to JSON string and determine type
@@ -159,10 +163,10 @@ class ConfigurationService:
value_json = str(value).lower() value_json = str(value).lower()
value_type = "bool" value_type = "bool"
elif isinstance(value, list): elif isinstance(value, list):
value_json = orjson.dumps(value).decode('utf-8') value_json = orjson.dumps(value).decode("utf-8")
value_type = "list" value_type = "list"
elif isinstance(value, dict): elif isinstance(value, dict):
value_json = orjson.dumps(value).decode('utf-8') value_json = orjson.dumps(value).decode("utf-8")
value_type = "dict" value_type = "dict"
else: else:
value_json = str(value) value_json = str(value)
@@ -176,7 +180,7 @@ class ConfigurationService:
value_type=value_type, value_type=value_type,
is_secret=False, is_secret=False,
is_runtime_reloadable=False, is_runtime_reloadable=False,
updated_by=updated_by updated_by=updated_by,
) )
# Check if setting exists # Check if setting exists
@@ -205,17 +209,21 @@ class ConfigurationService:
def _pairing_repo(self): def _pairing_repo(self):
from arbitrade.storage.repositories import ConfigPairingRepository from arbitrade.storage.repositories import ConfigPairingRepository
return ConfigPairingRepository(self._store) return ConfigPairingRepository(self._store)
def list_pairings(self) -> list[ConfigPairing]: def list_pairings(self) -> list[ConfigPairing]:
"""List all currency pairings.""" """List all currency pairings."""
return self._pairing_repo().list_pairings() return self._pairing_repo().list_pairings()
def create_pairing(self, base_asset: str, quote_asset: str, source: str = "manual") -> ConfigPairing: def create_pairing(
self, base_asset: str, quote_asset: str, source: str = "manual"
) -> ConfigPairing:
"""Create a new currency pairing.""" """Create a new currency pairing."""
existing = self._pairing_repo().get_pairing(base_asset, quote_asset) existing = self._pairing_repo().get_pairing(base_asset, quote_asset)
if existing: if existing:
return existing return existing
pairing = ConfigPairing( pairing = ConfigPairing(
base_asset=base_asset, quote_asset=quote_asset, enabled=True, source=source) base_asset=base_asset, quote_asset=quote_asset, enabled=True, source=source
)
return self._pairing_repo().create_pairing(pairing) return self._pairing_repo().create_pairing(pairing)
+39 -78
View File
@@ -32,72 +32,49 @@ class Settings(BaseSettings):
) )
alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED") alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED")
alert_min_severity: str = Field( alert_min_severity: str = Field(default="warning", alias="ALERT_MIN_SEVERITY")
default="warning", alias="ALERT_MIN_SEVERITY") alert_dedup_seconds: float = Field(default=30.0, alias="ALERT_DEDUP_SECONDS")
alert_dedup_seconds: float = Field( alert_on_trade_events: bool = Field(default=True, alias="ALERT_ON_TRADE_EVENTS")
default=30.0, alias="ALERT_DEDUP_SECONDS") alert_on_error_events: bool = Field(default=True, alias="ALERT_ON_ERROR_EVENTS")
alert_on_trade_events: bool = Field( alert_on_threshold_events: bool = Field(default=True, alias="ALERT_ON_THRESHOLD_EVENTS")
default=True, alias="ALERT_ON_TRADE_EVENTS") alert_on_system_events: bool = Field(default=True, alias="ALERT_ON_SYSTEM_EVENTS")
alert_on_error_events: bool = Field(
default=True, alias="ALERT_ON_ERROR_EVENTS")
alert_on_threshold_events: bool = Field(
default=True, alias="ALERT_ON_THRESHOLD_EVENTS")
alert_on_system_events: bool = Field(
default=True, alias="ALERT_ON_SYSTEM_EVENTS")
telegram_alerts_enabled: bool = Field( telegram_alerts_enabled: bool = Field(default=False, alias="TELEGRAM_ALERTS_ENABLED")
default=False, alias="TELEGRAM_ALERTS_ENABLED") telegram_bot_token: str | None = Field(default=None, alias="TELEGRAM_BOT_TOKEN")
telegram_bot_token: str | None = Field( telegram_chat_id: str | None = Field(default=None, alias="TELEGRAM_CHAT_ID")
default=None, alias="TELEGRAM_BOT_TOKEN")
telegram_chat_id: str | None = Field(
default=None, alias="TELEGRAM_CHAT_ID")
discord_alerts_enabled: bool = Field( discord_alerts_enabled: bool = Field(default=False, alias="DISCORD_ALERTS_ENABLED")
default=False, alias="DISCORD_ALERTS_ENABLED") discord_webhook_url: str | None = Field(default=None, alias="DISCORD_WEBHOOK_URL")
discord_webhook_url: str | None = Field(
default=None, alias="DISCORD_WEBHOOK_URL")
email_alerts_enabled: bool = Field( email_alerts_enabled: bool = Field(default=False, alias="EMAIL_ALERTS_ENABLED")
default=False, alias="EMAIL_ALERTS_ENABLED")
email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST") email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST")
email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT") email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT")
email_smtp_username: str | None = Field( email_smtp_username: str | None = Field(default=None, alias="EMAIL_SMTP_USERNAME")
default=None, alias="EMAIL_SMTP_USERNAME") email_smtp_password: str | None = Field(default=None, alias="EMAIL_SMTP_PASSWORD")
email_smtp_password: str | None = Field( email_alert_from: str | None = Field(default=None, alias="EMAIL_ALERT_FROM")
default=None, alias="EMAIL_SMTP_PASSWORD")
email_alert_from: str | None = Field(
default=None, alias="EMAIL_ALERT_FROM")
email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO") email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO")
email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS") email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS")
duckdb_path: Path = Field(default=Path( duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
"./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
kraken_rest_url: str = Field( kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL")
default="https://api.kraken.com", alias="KRAKEN_REST_URL") kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_ws_url: str = Field(
default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_private_rate_limit_seconds: float = Field( kraken_private_rate_limit_seconds: float = Field(
default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS"
) )
kraken_http_timeout_seconds: float = Field( kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_retry_attempts: int = Field(
default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_retry_base_delay_seconds: float = Field( kraken_retry_base_delay_seconds: float = Field(
default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS"
) )
kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY")
kraken_api_secret: str | None = Field( kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET")
default=None, alias="KRAKEN_API_SECRET")
kraken_api_key_permissions: str = Field( kraken_api_key_permissions: str = Field(
default="query,trade", default="query,trade",
alias="KRAKEN_API_KEY_PERMISSIONS", alias="KRAKEN_API_KEY_PERMISSIONS",
) )
ws_heartbeat_timeout_seconds: float = Field( ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS")
ws_max_staleness_seconds: float = Field(
default=5.0, alias="WS_MAX_STALENESS_SECONDS")
strategy_enable_stat_arb_experiment: bool = Field( strategy_enable_stat_arb_experiment: bool = Field(
default=False, default=False,
alias="STRATEGY_ENABLE_STAT_ARB_EXPERIMENT", alias="STRATEGY_ENABLE_STAT_ARB_EXPERIMENT",
@@ -120,29 +97,20 @@ class Settings(BaseSettings):
) )
paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE")
trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD") trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD")
max_trade_capital_usd: float = Field( max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD")
default=100.0, alias="MAX_TRADE_CAPITAL_USD") max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES")
max_concurrent_trades: int | None = Field(
default=None, alias="MAX_CONCURRENT_TRADES")
max_exposure_per_asset_usd: float | None = Field( max_exposure_per_asset_usd: float | None = Field(
default=None, default=None,
alias="MAX_EXPOSURE_PER_ASSET_USD", alias="MAX_EXPOSURE_PER_ASSET_USD",
) )
quote_balance_asset: str = Field( quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET")
default="USD", alias="QUOTE_BALANCE_ASSET") min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD")
min_order_size_usd: float | None = Field(
default=None, alias="MIN_ORDER_SIZE_USD")
kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE") kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE")
daily_loss_limit_usd: float | None = Field( daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD")
default=None, alias="DAILY_LOSS_LIMIT_USD") cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
cumulative_loss_limit_usd: float | None = Field( max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS")
default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS")
max_source_latency_ms: float | None = Field( max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES")
default=None, alias="MAX_SOURCE_LATENCY_MS")
max_apply_latency_ms: float | None = Field(
default=None, alias="MAX_APPLY_LATENCY_MS")
max_consecutive_failures: int | None = Field(
default=None, alias="MAX_CONSECUTIVE_FAILURES")
fernet_key: str | None = Field(default=None, alias="FERNET_KEY") fernet_key: str | None = Field(default=None, alias="FERNET_KEY")
@@ -159,8 +127,7 @@ class Settings(BaseSettings):
def _validate_log_level(cls, value: str) -> str: def _validate_log_level(cls, value: str) -> str:
normalized = value.strip().upper() normalized = value.strip().upper()
if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}: if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}:
raise ValueError( raise ValueError("LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
"LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
return normalized return normalized
@field_validator("alert_min_severity") @field_validator("alert_min_severity")
@@ -168,19 +135,16 @@ class Settings(BaseSettings):
def _validate_alert_severity(cls, value: str) -> str: def _validate_alert_severity(cls, value: str) -> str:
normalized = value.strip().lower() normalized = value.strip().lower()
if normalized not in {"info", "warning", "error", "critical"}: if normalized not in {"info", "warning", "error", "critical"}:
raise ValueError( raise ValueError("ALERT_MIN_SEVERITY must be one of: info, warning, error, critical")
"ALERT_MIN_SEVERITY must be one of: info, warning, error, critical")
return normalized return normalized
@model_validator(mode="after") @model_validator(mode="after")
def _validate_security_constraints(self) -> Settings: def _validate_security_constraints(self) -> Settings:
if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password): if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password):
raise ValueError( raise ValueError("dashboard auth requires both username and password")
"dashboard auth requires both username and password")
if bool(self.kraken_api_key) ^ bool(self.kraken_api_secret): if bool(self.kraken_api_key) ^ bool(self.kraken_api_secret):
raise ValueError( raise ValueError("Kraken API auth requires both API key and secret")
"Kraken API auth requires both API key and secret")
permissions = { permissions = {
token.strip().lower() token.strip().lower()
@@ -188,11 +152,9 @@ class Settings(BaseSettings):
if token.strip() if token.strip()
} }
if permissions and ("query" not in permissions or "trade" not in permissions): if permissions and ("query" not in permissions or "trade" not in permissions):
raise ValueError( raise ValueError("KRAKEN_API_KEY_PERMISSIONS must include query and trade")
"KRAKEN_API_KEY_PERMISSIONS must include query and trade")
if "withdraw" in permissions or "withdrawals" in permissions: if "withdraw" in permissions or "withdrawals" in permissions:
raise ValueError( raise ValueError("KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope")
"KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope")
if self.alert_dedup_seconds < 0.0: if self.alert_dedup_seconds < 0.0:
raise ValueError("ALERT_DEDUP_SECONDS must be >= 0") raise ValueError("ALERT_DEDUP_SECONDS must be >= 0")
@@ -208,8 +170,7 @@ class Settings(BaseSettings):
"STRATEGY_STAT_ARB_ENTRY_ZSCORE must be greater than STRATEGY_STAT_ARB_EXIT_ZSCORE" "STRATEGY_STAT_ARB_ENTRY_ZSCORE must be greater than STRATEGY_STAT_ARB_EXIT_ZSCORE"
) )
if self.strategy_stat_arb_max_holding_seconds <= 0.0: if self.strategy_stat_arb_max_holding_seconds <= 0.0:
raise ValueError( raise ValueError("STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0")
"STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0")
return self return self
+1 -2
View File
@@ -92,8 +92,7 @@ def run_incremental_detection_benchmark(
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Benchmark incremental detection latency")
description="Benchmark incremental detection latency")
parser.add_argument("--iterations", type=int, default=50_000) parser.add_argument("--iterations", type=int, default=50_000)
parser.add_argument("--target-ms", type=float, default=1.0) parser.add_argument("--target-ms", type=float, default=1.0)
args = parser.parse_args() args = parser.parse_args()
+7 -15
View File
@@ -43,12 +43,9 @@ async def fetch_and_store_account_snapshot(
_LOG.exception("trade_balance_fetch_failed") _LOG.exception("trade_balance_fetch_failed")
return None return None
fee_tier = volume_data.get("fee_tier") if isinstance( fee_tier = volume_data.get("fee_tier") if isinstance(volume_data, dict) else None
volume_data, dict) else None fees_dict = volume_data.get("fees") if isinstance(volume_data, dict) else None
fees_dict = volume_data.get("fees") if isinstance( fees_maker = volume_data.get("fees_maker") if isinstance(volume_data, dict) else None
volume_data, dict) else None
fees_maker = volume_data.get("fees_maker") if isinstance(
volume_data, dict) else None
currency = volume_data.get("currency") currency = volume_data.get("currency")
thirty_day_volume_str = volume_data.get("volume") thirty_day_volume_str = volume_data.get("volume")
@@ -74,9 +71,7 @@ async def fetch_and_store_account_snapshot(
if currency is not None: if currency is not None:
fee_schedule["currency"] = currency fee_schedule["currency"] = currency
thirty_day_volume = ( thirty_day_volume = float(thirty_day_volume_str) if thirty_day_volume_str is not None else None
float(thirty_day_volume_str) if thirty_day_volume_str is not None else None
)
snapshot = KrakenAccountSnapshot( snapshot = KrakenAccountSnapshot(
snapshot_at=datetime.now(timezone.utc), snapshot_at=datetime.now(timezone.utc),
@@ -84,8 +79,7 @@ async def fetch_and_store_account_snapshot(
maker_fee=maker_fee, maker_fee=maker_fee,
taker_fee=taker_fee, taker_fee=taker_fee,
thirty_day_volume=thirty_day_volume, thirty_day_volume=thirty_day_volume,
trade_balance_raw=balance_data if isinstance( trade_balance_raw=balance_data if isinstance(balance_data, dict) else None,
balance_data, dict) else None,
fee_schedule_raw=fee_schedule if fee_schedule else None, fee_schedule_raw=fee_schedule if fee_schedule else None,
) )
@@ -109,8 +103,7 @@ async def fetch_and_store_account_snapshot(
"INSERT INTO portfolio_snapshots (snapshot_at, balances, total_value_usd) VALUES (?, ?, ?)", "INSERT INTO portfolio_snapshots (snapshot_at, balances, total_value_usd) VALUES (?, ?, ?)",
( (
datetime.now(timezone.utc), datetime.now(timezone.utc),
orjson.dumps(wallet_balances).decode( orjson.dumps(wallet_balances).decode("utf-8") if wallet_balances else None,
"utf-8") if wallet_balances else None,
total_value, total_value,
), ),
) )
@@ -130,8 +123,7 @@ async def run_fee_sync_loop(
Runs until stop_event is set. Runs until stop_event is set.
""" """
_LOG.info("fee_sync_loop_started", _LOG.info("fee_sync_loop_started", interval_s=_FEE_REFRESH_INTERVAL_SECONDS)
interval_s=_FEE_REFRESH_INTERVAL_SECONDS)
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
+12 -6
View File
@@ -24,7 +24,8 @@ class MetricsCalculator:
def compute(self) -> PerformanceMetrics: def compute(self) -> PerformanceMetrics:
with self._store.connect() as conn: with self._store.connect() as conn:
tm = conn.execute(""" tm = conn.execute(
"""
SELECT SELECT
COALESCE(SUM(COALESCE(realized_pnl, 0)), 0) AS realized_pnl_usd, COALESCE(SUM(COALESCE(realized_pnl, 0)), 0) AS realized_pnl_usd,
COUNT(*) AS total_trades, COUNT(*) AS total_trades,
@@ -44,21 +45,26 @@ class MetricsCalculator:
) AS latency_p99_seconds ) AS latency_p99_seconds
FROM trades FROM trades
WHERE finished_at IS NOT NULL WHERE finished_at IS NOT NULL
""").fetchone() """
).fetchone()
om = conn.execute(""" om = conn.execute(
"""
SELECT SELECT
COUNT(*) AS opportunity_count, COUNT(*) AS opportunity_count,
MIN(detected_at) AS first_detected_at, MIN(detected_at) AS first_detected_at,
MAX(detected_at) AS last_detected_at MAX(detected_at) AS last_detected_at
FROM opportunities FROM opportunities
""").fetchone() """
).fetchone()
fm = conn.execute(""" fm = conn.execute(
"""
SELECT AVG(filled_volume / volume) AS fill_rate SELECT AVG(filled_volume / volume) AS fill_rate
FROM orders FROM orders
WHERE volume > 0 AND filled_volume IS NOT NULL WHERE volume > 0 AND filled_volume IS NOT NULL
""").fetchone() """
).fetchone()
r_pnl_usd = float(tm[0]) if tm and tm[0] is not None else 0.0 r_pnl_usd = float(tm[0]) if tm and tm[0] is not None else 0.0
tt = int(tm[1]) if tm and tm[1] is not None else 0 tt = int(tm[1]) if tm and tm[1] is not None else 0
+8 -4
View File
@@ -45,22 +45,26 @@ def _runtime_repository(app: FastAPI) -> RuntimeStateRepository | None:
def _open_trade_count(store: DuckDBStore) -> int: def _open_trade_count(store: DuckDBStore) -> int:
with store.connect() as conn: with store.connect() as conn:
row = conn.execute(""" row = conn.execute(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM trades FROM trades
WHERE finished_at IS NULL WHERE finished_at IS NULL
""").fetchone() """
).fetchone()
return int(row[0]) if row is not None else 0 return int(row[0]) if row is not None else 0
def _latest_balances(store: DuckDBStore) -> dict[str, Any] | None: def _latest_balances(store: DuckDBStore) -> dict[str, Any] | None:
with store.connect() as conn: with store.connect() as conn:
row = conn.execute(""" row = conn.execute(
"""
SELECT balances SELECT balances
FROM portfolio_snapshots FROM portfolio_snapshots
ORDER BY snapshot_at DESC ORDER BY snapshot_at DESC
LIMIT 1 LIMIT 1
""").fetchone() """
).fetchone()
if row is None or row[0] is None: if row is None or row[0] is None:
return None return None
+24 -27
View File
@@ -216,12 +216,14 @@ class DuckDBStore:
# Ensure schema_migrations table exists and get current version # Ensure schema_migrations table exists and get current version
if not self._table_exists(conn, "schema_migrations"): if not self._table_exists(conn, "schema_migrations"):
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS schema_migrations ( CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY, version INTEGER PRIMARY KEY,
applied_at TIMESTAMP DEFAULT current_timestamp applied_at TIMESTAMP DEFAULT current_timestamp
) )
""") """
)
# Get current schema version # Get current schema version
try: try:
@@ -236,30 +238,24 @@ class DuckDBStore:
if current_version < 1: if current_version < 1:
# Migration v1: Add missing columns to trades table # Migration v1: Add missing columns to trades table
# Note: DuckDB does not support ADD COLUMN with constraints # Note: DuckDB does not support ADD COLUMN with constraints
conn.execute( conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS trade_ref VARCHAR")
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS trade_ref VARCHAR") conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS estimated_pnl DOUBLE")
conn.execute( conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS capital_used DOUBLE")
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS estimated_pnl DOUBLE") conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS cycle VARCHAR")
conn.execute( conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS leg_count INTEGER")
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS capital_used DOUBLE") conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (1)")
conn.execute(
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS cycle VARCHAR")
conn.execute(
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS leg_count INTEGER")
conn.execute(
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (1)")
_LOG.info("migration_applied", version=1) _LOG.info("migration_applied", version=1)
if current_version < 2: if current_version < 2:
# Migration v2: Ensure config_backtesting_defaults table # Migration v2: Ensure config_backtesting_defaults table
# config_backtesting_defaults already created by SCHEMA_SQL # config_backtesting_defaults already created by SCHEMA_SQL
conn.execute( conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (2)")
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (2)")
_LOG.info("migration_applied", version=2) _LOG.info("migration_applied", version=2)
if current_version < 3: if current_version < 3:
# Migration v3: Add kraken_account_snapshots table # Migration v3: Add kraken_account_snapshots table
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS kraken_account_snapshots ( CREATE TABLE IF NOT EXISTS kraken_account_snapshots (
snapshot_at TIMESTAMP NOT NULL, snapshot_at TIMESTAMP NOT NULL,
fee_tier VARCHAR, fee_tier VARCHAR,
@@ -269,21 +265,22 @@ class DuckDBStore:
trade_balance_raw JSON, trade_balance_raw JSON,
fee_schedule_raw JSON fee_schedule_raw JSON
) )
""") """
conn.execute( )
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (3)") conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (3)")
_LOG.info("migration_applied", version=3) _LOG.info("migration_applied", version=3)
if current_version < 4: if current_version < 4:
# Migration v4: Add fee_source to backtesting defaults # Migration v4: Add fee_source to backtesting defaults
conn.execute( conn.execute(
"ALTER TABLE config_backtesting_defaults ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'") "ALTER TABLE config_backtesting_defaults ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'"
conn.execute( )
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (4)") conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (4)")
_LOG.info("migration_applied", version=4) _LOG.info("migration_applied", version=4)
if current_version < 5: if current_version < 5:
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS backtest_jobs ( CREATE TABLE IF NOT EXISTS backtest_jobs (
id UUID DEFAULT uuid(), id UUID DEFAULT uuid(),
status VARCHAR NOT NULL DEFAULT 'pending', status VARCHAR NOT NULL DEFAULT 'pending',
@@ -295,9 +292,9 @@ class DuckDBStore:
started_at TIMESTAMP, started_at TIMESTAMP,
finished_at TIMESTAMP finished_at TIMESTAMP
) )
""") """
conn.execute( )
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (5)") conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (5)")
_LOG.info("migration_applied", version=5) _LOG.info("migration_applied", version=5)
# Update version to current # Update version to current
+69 -51
View File
@@ -6,7 +6,12 @@ from typing import Any
import orjson import orjson
from arbitrade.config.service import ConfigBacktestingDefaults, ConfigPairing, ConfigSection, ConfigSetting from arbitrade.config.service import (
ConfigBacktestingDefaults,
ConfigPairing,
ConfigSection,
ConfigSetting,
)
from arbitrade.storage.db import DuckDBStore from arbitrade.storage.db import DuckDBStore
@@ -344,7 +349,8 @@ class RuntimeStateRepository:
def latest(self) -> RuntimeStateRecord | None: def latest(self) -> RuntimeStateRecord | None:
with self._store.connect() as conn: with self._store.connect() as conn:
row = conn.execute(""" row = conn.execute(
"""
SELECT SELECT
snapshot_at, snapshot_at,
is_running, is_running,
@@ -356,7 +362,8 @@ class RuntimeStateRepository:
FROM runtime_state_snapshots FROM runtime_state_snapshots
ORDER BY snapshot_at DESC ORDER BY snapshot_at DESC
LIMIT 1 LIMIT 1
""").fetchone() """
).fetchone()
if row is None: if row is None:
return None return None
@@ -397,12 +404,7 @@ class ConfigSectionRepository:
) )
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return ConfigSection( return ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3])
id=row[0],
name=row[1],
description=row[2],
updated_at=row[3]
)
raise ValueError("Failed to create section") raise ValueError("Failed to create section")
def get_section(self, name: str) -> ConfigSection | None: def get_section(self, name: str) -> ConfigSection | None:
@@ -418,12 +420,7 @@ class ConfigSectionRepository:
) )
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return ConfigSection( return ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3])
id=row[0],
name=row[1],
description=row[2],
updated_at=row[3]
)
return None return None
def list_sections(self) -> list[ConfigSection]: def list_sections(self) -> list[ConfigSection]:
@@ -437,12 +434,7 @@ class ConfigSectionRepository:
""" """
) )
return [ return [
ConfigSection( ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3])
id=row[0],
name=row[1],
description=row[2],
updated_at=row[3]
)
for row in cursor.fetchall() for row in cursor.fetchall()
] ]
@@ -480,7 +472,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]), is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]), is_runtime_reloadable=bool(row[5]),
updated_at=row[6], updated_at=row[6],
updated_by=row[7] updated_by=row[7],
) )
raise ValueError("Failed to create setting") raise ValueError("Failed to create setting")
@@ -505,7 +497,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]), is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]), is_runtime_reloadable=bool(row[5]),
updated_at=row[6], updated_at=row[6],
updated_by=row[7] updated_by=row[7],
) )
return None return None
@@ -539,7 +531,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]), is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]), is_runtime_reloadable=bool(row[5]),
updated_at=row[6], updated_at=row[6],
updated_by=row[7] updated_by=row[7],
) )
raise ValueError("Failed to update setting") raise ValueError("Failed to update setting")
@@ -585,7 +577,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]), is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]), is_runtime_reloadable=bool(row[5]),
updated_at=row[6], updated_at=row[6],
updated_by=row[7] updated_by=row[7],
) )
for row in cursor.fetchall() for row in cursor.fetchall()
] ]
@@ -602,7 +594,7 @@ class ConfigSettingRepository:
row = cursor.fetchone() row = cursor.fetchone()
if row and row[0]: if row and row[0]:
# Convert string timestamp to datetime # Convert string timestamp to datetime
return datetime.fromisoformat(row[0].replace('Z', '+00:00')) return datetime.fromisoformat(row[0].replace("Z", "+00:00"))
return None return None
@@ -635,7 +627,7 @@ class ConfigPairingRepository:
enabled=bool(row[3]), enabled=bool(row[3]),
source=row[4], source=row[4],
created_at=row[5], created_at=row[5],
updated_at=row[6] updated_at=row[6],
) )
raise ValueError("Failed to create pairing") raise ValueError("Failed to create pairing")
@@ -659,11 +651,13 @@ class ConfigPairingRepository:
enabled=bool(row[3]), enabled=bool(row[3]),
source=row[4], source=row[4],
created_at=row[5], created_at=row[5],
updated_at=row[6] updated_at=row[6],
) )
return None return None
def update_pairing(self, base_asset: str, quote_asset: str, pairing: ConfigPairing) -> ConfigPairing: def update_pairing(
self, base_asset: str, quote_asset: str, pairing: ConfigPairing
) -> ConfigPairing:
"""Update an existing currency pairing.""" """Update an existing currency pairing."""
with self._store.connect() as conn: with self._store.connect() as conn:
cursor = conn.execute( cursor = conn.execute(
@@ -689,7 +683,7 @@ class ConfigPairingRepository:
enabled=bool(row[3]), enabled=bool(row[3]),
source=row[4], source=row[4],
created_at=row[5], created_at=row[5],
updated_at=row[6] updated_at=row[6],
) )
raise ValueError("Failed to update pairing") raise ValueError("Failed to update pairing")
@@ -723,7 +717,7 @@ class ConfigPairingRepository:
enabled=bool(row[3]), enabled=bool(row[3]),
source=row[4], source=row[4],
created_at=row[5], created_at=row[5],
updated_at=row[6] updated_at=row[6],
) )
for row in cursor.fetchall() for row in cursor.fetchall()
] ]
@@ -743,8 +737,11 @@ class ConfigBacktestingDefaultsRepository:
RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms
""", """,
( (
orjson.dumps(defaults.starting_balances).decode( (
'utf-8') if defaults.starting_balances else None, orjson.dumps(defaults.starting_balances).decode("utf-8")
if defaults.starting_balances
else None
),
defaults.trade_capital, defaults.trade_capital,
defaults.min_profit_threshold, defaults.min_profit_threshold,
defaults.slippage_bps, defaults.slippage_bps,
@@ -758,7 +755,7 @@ class ConfigBacktestingDefaultsRepository:
trade_capital=row[2], trade_capital=row[2],
min_profit_threshold=row[3], min_profit_threshold=row[3],
slippage_bps=row[4], slippage_bps=row[4],
execution_latency_ms=row[5] execution_latency_ms=row[5],
) )
raise ValueError("Failed to create backtesting defaults") raise ValueError("Failed to create backtesting defaults")
@@ -780,7 +777,7 @@ class ConfigBacktestingDefaultsRepository:
trade_capital=row[2], trade_capital=row[2],
min_profit_threshold=row[3], min_profit_threshold=row[3],
slippage_bps=row[4], slippage_bps=row[4],
execution_latency_ms=row[5] execution_latency_ms=row[5],
) )
return None return None
@@ -797,8 +794,11 @@ class ConfigBacktestingDefaultsRepository:
RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms
""", """,
( (
orjson.dumps(defaults.starting_balances).decode( (
'utf-8') if defaults.starting_balances else None, orjson.dumps(defaults.starting_balances).decode("utf-8")
if defaults.starting_balances
else None
),
defaults.trade_capital, defaults.trade_capital,
defaults.min_profit_threshold, defaults.min_profit_threshold,
defaults.slippage_bps, defaults.slippage_bps,
@@ -812,7 +812,7 @@ class ConfigBacktestingDefaultsRepository:
trade_capital=row[2], trade_capital=row[2],
min_profit_threshold=row[3], min_profit_threshold=row[3],
slippage_bps=row[4], slippage_bps=row[4],
execution_latency_ms=row[5] execution_latency_ms=row[5],
) )
raise ValueError("Failed to update backtesting defaults") raise ValueError("Failed to update backtesting defaults")
@@ -847,10 +847,16 @@ class KrakenAccountSnapshotRepository:
snapshot.maker_fee, snapshot.maker_fee,
snapshot.taker_fee, snapshot.taker_fee,
snapshot.thirty_day_volume, snapshot.thirty_day_volume,
orjson.dumps(snapshot.trade_balance_raw).decode("utf-8") (
if snapshot.trade_balance_raw else None, orjson.dumps(snapshot.trade_balance_raw).decode("utf-8")
orjson.dumps(snapshot.fee_schedule_raw).decode("utf-8") if snapshot.trade_balance_raw
if snapshot.fee_schedule_raw else None, else None
),
(
orjson.dumps(snapshot.fee_schedule_raw).decode("utf-8")
if snapshot.fee_schedule_raw
else None
),
), ),
) )
@@ -895,7 +901,9 @@ class BacktestJobRepository:
def __init__(self, store: DuckDBStore) -> None: def __init__(self, store: DuckDBStore) -> None:
self._store = store self._store = store
def create_job(self, events_path: str, config: dict[str, Any] | None = None) -> BacktestJobRecord: def create_job(
self, events_path: str, config: dict[str, Any] | None = None
) -> BacktestJobRecord:
with self._store.connect() as conn: with self._store.connect() as conn:
row = conn.execute( row = conn.execute(
""" """
@@ -903,13 +911,14 @@ class BacktestJobRepository:
VALUES (?, ?) VALUES (?, ?)
RETURNING id, status, events_path, config, created_at RETURNING id, status, events_path, config, created_at
""", """,
(events_path, orjson.dumps(config).decode( (events_path, orjson.dumps(config).decode("utf-8") if config else None),
"utf-8") if config else None),
).fetchone() ).fetchone()
if row is None: if row is None:
raise ValueError("Failed to create backtest job") raise ValueError("Failed to create backtest job")
return BacktestJobRecord( return BacktestJobRecord(
id=str(row[0]), status=str(row[1]), events_path=str(row[2]), id=str(row[0]),
status=str(row[1]),
events_path=str(row[2]),
config=orjson.loads(row[3]) if row[3] else None, config=orjson.loads(row[3]) if row[3] else None,
created_at=row[4], created_at=row[4],
) )
@@ -950,11 +959,15 @@ class BacktestJobRepository:
if row is None: if row is None:
return None return None
return BacktestJobRecord( return BacktestJobRecord(
id=str(row[0]), status=str(row[1]), events_path=str(row[2]), id=str(row[0]),
status=str(row[1]),
events_path=str(row[2]),
config=orjson.loads(row[3]) if row[3] else None, config=orjson.loads(row[3]) if row[3] else None,
report=orjson.loads(row[4]) if row[4] else None, report=orjson.loads(row[4]) if row[4] else None,
error=str(row[5]) if row[5] else None, error=str(row[5]) if row[5] else None,
created_at=row[6], started_at=row[7], finished_at=row[8], created_at=row[6],
started_at=row[7],
finished_at=row[8],
) )
def list_jobs(self, limit: int = 20) -> list[BacktestJobRecord]: def list_jobs(self, limit: int = 20) -> list[BacktestJobRecord]:
@@ -967,11 +980,15 @@ class BacktestJobRepository:
).fetchall() ).fetchall()
return [ return [
BacktestJobRecord( BacktestJobRecord(
id=str(r[0]), status=str(r[1]), events_path=str(r[2]), id=str(r[0]),
status=str(r[1]),
events_path=str(r[2]),
config=orjson.loads(r[3]) if r[3] else None, config=orjson.loads(r[3]) if r[3] else None,
report=orjson.loads(r[4]) if r[4] else None, report=orjson.loads(r[4]) if r[4] else None,
error=str(r[5]) if r[5] else None, error=str(r[5]) if r[5] else None,
created_at=r[6], started_at=r[7], finished_at=r[8], created_at=r[6],
started_at=r[7],
finished_at=r[8],
) )
for r in rows for r in rows
] ]
@@ -979,6 +996,7 @@ class BacktestJobRepository:
def delete_job(self, job_id: str) -> bool: def delete_job(self, job_id: str) -> bool:
with self._store.connect() as conn: with self._store.connect() as conn:
result = conn.execute( result = conn.execute(
"DELETE FROM backtest_jobs WHERE id = ?", (job_id,), "DELETE FROM backtest_jobs WHERE id = ?",
(job_id,),
) )
return result.rowcount > 0 return result.rowcount > 0
+6 -12
View File
@@ -191,8 +191,7 @@ async def test_dashboard_page_and_fragment_and_sse(tmp_path) -> None:
assert "trade-open" in overview.text assert "trade-open" in overview.text
assert overview_stream.status_code == 200 assert overview_stream.status_code == 200
assert overview_stream.headers["content-type"].startswith( assert overview_stream.headers["content-type"].startswith("text/event-stream")
"text/event-stream")
assert "event: overview" in overview_stream.text assert "event: overview" in overview_stream.text
assert "trade-open" in overview_stream.text assert "trade-open" in overview_stream.text
@@ -262,8 +261,7 @@ async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> N
assert app.state.settings.max_trade_capital_usd == 300.0 assert app.state.settings.max_trade_capital_usd == 300.0
assert app.state.settings.max_concurrent_trades == 4 assert app.state.settings.max_concurrent_trades == 4
assert app.state.settings.paper_trading_mode is True assert app.state.settings.paper_trading_mode is True
assert app.state.dashboard_controls.tradable_pairs == [ assert app.state.dashboard_controls.tradable_pairs == ["BTC/USD", "ETH/BTC"]
"BTC/USD", "ETH/BTC"]
assert app.state.dashboard_controls.strategy_mode == "paper" assert app.state.dashboard_controls.strategy_mode == "paper"
assert app.state.dashboard_controls.strategy_profit_threshold == 0.0025 assert app.state.dashboard_controls.strategy_profit_threshold == 0.0025
assert app.state.dashboard_controls.strategy_max_depth_levels == 7 assert app.state.dashboard_controls.strategy_max_depth_levels == 7
@@ -275,14 +273,10 @@ async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> N
assert audit_recent.status_code == 200 assert audit_recent.status_code == 200
entries = audit_recent.json()["entries"] entries = audit_recent.json()["entries"]
assert len(entries) >= 4 assert len(entries) >= 4
assert any(entry["event_type"] == assert any(entry["event_type"] == "dashboard.control.stop" for entry in entries)
"dashboard.control.stop" for entry in entries) assert any(entry["event_type"] == "dashboard.control.start" for entry in entries)
assert any(entry["event_type"] == assert any(entry["event_type"] == "dashboard.control.kill_switch" for entry in entries)
"dashboard.control.start" for entry in entries) assert any(entry["event_type"] == "dashboard.control.config" for entry in entries)
assert any(entry["event_type"] ==
"dashboard.control.kill_switch" for entry in entries)
assert any(entry["event_type"] ==
"dashboard.control.config" for entry in entries)
async def test_dashboard_controls_emit_alerts(tmp_path) -> None: async def test_dashboard_controls_emit_alerts(tmp_path) -> None:
+1 -1
View File
@@ -24,7 +24,7 @@ def test_end_to_end_config_workflow():
assert service.get_last_updated_at() is None assert service.get_last_updated_at() is None
# Test setting a value # Test setting a value
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock() mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance mock_repo_class.return_value = mock_repo_instance
+74 -32
View File
@@ -6,10 +6,13 @@ from unittest.mock import Mock, patch
from arbitrade.storage.repositories import ( from arbitrade.storage.repositories import (
ConfigSettingRepository, ConfigSettingRepository,
ConfigPairingRepository, ConfigPairingRepository,
ConfigPairFeeRepository, ConfigBacktestingDefaultsRepository,
ConfigBacktestingDefaultsRepository )
from arbitrade.config.service import (
ConfigSetting,
ConfigPairing,
ConfigBacktestingDefaults,
) )
from arbitrade.config.service import ConfigSetting, ConfigPairing, ConfigPairFee, ConfigBacktestingDefaults
from arbitrade.storage.db import DuckDBStore from arbitrade.storage.db import DuckDBStore
@@ -31,13 +34,20 @@ def test_config_setting_repository_create_setting(mock_store):
repo = ConfigSettingRepository(mock_store) repo = ConfigSettingRepository(mock_store)
# Mock database connection # Mock database connection
with patch.object(mock_store, 'connect') as mock_connect: with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock() mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value # Mock the return value
mock_cursor.fetchone.return_value = [ mock_cursor.fetchone.return_value = [
"test_key", "test_section", "test_value", "str", False, False, "2023-01-01T00:00:00", "test_user" "test_key",
"test_section",
"test_value",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
] ]
# Create setting # Create setting
@@ -48,7 +58,7 @@ def test_config_setting_repository_create_setting(mock_store):
value_type="str", value_type="str",
is_secret=False, is_secret=False,
is_runtime_reloadable=False, is_runtime_reloadable=False,
updated_by="test_user" updated_by="test_user",
) )
result = repo.create_setting(setting) result = repo.create_setting(setting)
@@ -67,13 +77,20 @@ def test_config_setting_repository_get_setting(mock_store):
repo = ConfigSettingRepository(mock_store) repo = ConfigSettingRepository(mock_store)
# Mock database connection # Mock database connection
with patch.object(mock_store, 'connect') as mock_connect: with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock() mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value # Mock the return value
mock_cursor.fetchone.return_value = [ mock_cursor.fetchone.return_value = [
"test_key", "test_section", "test_value", "str", False, False, "2023-01-01T00:00:00", "test_user" "test_key",
"test_section",
"test_value",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
] ]
# Get setting # Get setting
@@ -93,13 +110,20 @@ def test_config_setting_repository_update_setting(mock_store):
repo = ConfigSettingRepository(mock_store) repo = ConfigSettingRepository(mock_store)
# Mock database connection # Mock database connection
with patch.object(mock_store, 'connect') as mock_connect: with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock() mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value # Mock the return value
mock_cursor.fetchone.return_value = [ mock_cursor.fetchone.return_value = [
"test_key", "test_section", "updated_value", "str", False, False, "2023-01-01T00:00:00", "test_user" "test_key",
"test_section",
"updated_value",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
] ]
# Update setting # Update setting
@@ -110,7 +134,7 @@ def test_config_setting_repository_update_setting(mock_store):
value_type="str", value_type="str",
is_secret=False, is_secret=False,
is_runtime_reloadable=False, is_runtime_reloadable=False,
updated_by="test_user" updated_by="test_user",
) )
result = repo.update_setting("test_key", setting) result = repo.update_setting("test_key", setting)
@@ -129,16 +153,32 @@ def test_config_setting_repository_list_settings(mock_store):
repo = ConfigSettingRepository(mock_store) repo = ConfigSettingRepository(mock_store)
# Mock database connection # Mock database connection
with patch.object(mock_store, 'connect') as mock_connect: with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock() mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value # Mock the return value
mock_cursor.fetchall.return_value = [ mock_cursor.fetchall.return_value = [
["test_key1", "test_section", "test_value1", "str", [
False, False, "2023-01-01T00:00:00", "test_user"], "test_key1",
["test_key2", "test_section", "test_value2", "str", "test_section",
False, False, "2023-01-01T00:00:00", "test_user"] "test_value1",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
],
[
"test_key2",
"test_section",
"test_value2",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
],
] ]
# List settings # List settings
@@ -156,7 +196,7 @@ def test_config_setting_repository_get_latest_updated_at(mock_store):
repo = ConfigSettingRepository(mock_store) repo = ConfigSettingRepository(mock_store)
# Mock database connection # Mock database connection
with patch.object(mock_store, 'connect') as mock_connect: with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock() mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor mock_connect.return_value.__enter__.return_value = mock_cursor
@@ -182,22 +222,24 @@ def test_config_pairing_repository_create_pairing(mock_store):
repo = ConfigPairingRepository(mock_store) repo = ConfigPairingRepository(mock_store)
# Mock database connection # Mock database connection
with patch.object(mock_store, 'connect') as mock_connect: with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock() mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value # Mock the return value
mock_cursor.fetchone.return_value = [ mock_cursor.fetchone.return_value = [
1, "BTC", "USD", True, "Kraken", "2023-01-01T00:00:00", "2023-01-01T00:00:00" 1,
"BTC",
"USD",
True,
"Kraken",
"2023-01-01T00:00:00",
"2023-01-01T00:00:00",
] ]
# Create pairing # Create pairing
pairing = ConfigPairing( pairing = ConfigPairing(
base_asset="BTC", base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken")
quote_asset="USD",
enabled=True,
source="Kraken"
)
result = repo.create_pairing(pairing) result = repo.create_pairing(pairing)
@@ -214,13 +256,19 @@ def test_config_pairing_repository_get_pairing(mock_store):
repo = ConfigPairingRepository(mock_store) repo = ConfigPairingRepository(mock_store)
# Mock database connection # Mock database connection
with patch.object(mock_store, 'connect') as mock_connect: with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock() mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value # Mock the return value
mock_cursor.fetchone.return_value = [ mock_cursor.fetchone.return_value = [
1, "BTC", "USD", True, "Kraken", "2023-01-01T00:00:00", "2023-01-01T00:00:00" 1,
"BTC",
"USD",
True,
"Kraken",
"2023-01-01T00:00:00",
"2023-01-01T00:00:00",
] ]
# Get pairing # Get pairing
@@ -234,12 +282,6 @@ def test_config_pairing_repository_get_pairing(mock_store):
assert result.source == "Kraken" assert result.source == "Kraken"
def test_config_pair_fee_repository_initialization(mock_store):
"""Test ConfigPairFeeRepository initialization."""
repo = ConfigPairFeeRepository(mock_store)
assert repo._store == mock_store
def test_config_backtesting_defaults_repository_initialization(mock_store): def test_config_backtesting_defaults_repository_initialization(mock_store):
"""Test ConfigBacktestingDefaultsRepository initialization.""" """Test ConfigBacktestingDefaultsRepository initialization."""
repo = ConfigBacktestingDefaultsRepository(mock_store) repo = ConfigBacktestingDefaultsRepository(mock_store)
+14 -26
View File
@@ -31,9 +31,7 @@ def mock_audit_repo():
return audit_repo return audit_repo
def test_configuration_service_initialization( def test_configuration_service_initialization(mock_settings, mock_store, mock_audit_repo):
mock_settings, mock_store, mock_audit_repo
):
"""Test that ConfigurationService initializes correctly.""" """Test that ConfigurationService initializes correctly."""
# Create service instance # Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -46,9 +44,7 @@ def test_configuration_service_initialization(
assert isinstance(service._loaded_settings, dict) assert isinstance(service._loaded_settings, dict)
def test_configuration_service_get_setting( def test_configuration_service_get_setting(mock_settings, mock_store, mock_audit_repo):
mock_settings, mock_store, mock_audit_repo
):
"""Test getting configuration settings.""" """Test getting configuration settings."""
# Create service instance # Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -65,15 +61,13 @@ def test_configuration_service_get_setting(
assert result == "default" assert result == "default"
def test_configuration_service_set_setting( def test_configuration_service_set_setting(mock_settings, mock_store, mock_audit_repo):
mock_settings, mock_store, mock_audit_repo
):
"""Test setting configuration settings.""" """Test setting configuration settings."""
# Create service instance # Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
# Mock the repository # Mock the repository
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock() mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance mock_repo_class.return_value = mock_repo_instance
@@ -90,9 +84,7 @@ def test_configuration_service_set_setting(
mock_repo_instance.create_setting.assert_called_once() mock_repo_instance.create_setting.assert_called_once()
def test_configuration_service_hot_reload_detection( def test_configuration_service_hot_reload_detection(mock_settings, mock_store, mock_audit_repo):
mock_settings, mock_store, mock_audit_repo
):
"""Test hot-reload detection functionality.""" """Test hot-reload detection functionality."""
# Create service instance # Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -101,27 +93,26 @@ def test_configuration_service_hot_reload_detection(
assert service.is_config_outdated() is False assert service.is_config_outdated() is False
# Test with mock repository that returns a timestamp # Test with mock repository that returns a timestamp
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock() mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance mock_repo_class.return_value = mock_repo_instance
# Mock the latest updated at timestamp # Mock the latest updated at timestamp
from datetime import datetime from datetime import datetime
mock_repo_instance.get_latest_updated_at.return_value = datetime.now() mock_repo_instance.get_latest_updated_at.return_value = datetime.now()
# Should detect as outdated when timestamp exists # Should detect as outdated when timestamp exists
assert service.is_config_outdated() is True assert service.is_config_outdated() is True
def test_configuration_service_reload_if_changed( def test_configuration_service_reload_if_changed(mock_settings, mock_store, mock_audit_repo):
mock_settings, mock_store, mock_audit_repo
):
"""Test hot-reload functionality.""" """Test hot-reload functionality."""
# Create service instance # Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
# Mock the repository # Mock the repository
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock() mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance mock_repo_class.return_value = mock_repo_instance
@@ -135,6 +126,7 @@ def test_configuration_service_reload_if_changed(
# Mock the latest updated at timestamp to return a value # Mock the latest updated at timestamp to return a value
from datetime import datetime from datetime import datetime
mock_repo_instance.get_latest_updated_at.return_value = datetime.now() mock_repo_instance.get_latest_updated_at.return_value = datetime.now()
# Should reload when outdated # Should reload when outdated
@@ -143,9 +135,7 @@ def test_configuration_service_reload_if_changed(
assert service.get_config_version() == 1 assert service.get_config_version() == 1
def test_configuration_service_get_config_version( def test_configuration_service_get_config_version(mock_settings, mock_store, mock_audit_repo):
mock_settings, mock_store, mock_audit_repo
):
"""Test getting configuration version.""" """Test getting configuration version."""
# Create service instance # Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -154,7 +144,7 @@ def test_configuration_service_get_config_version(
assert service.get_config_version() == 0 assert service.get_config_version() == 0
# After setting a value, version should increment # After setting a value, version should increment
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock() mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance mock_repo_class.return_value = mock_repo_instance
@@ -166,9 +156,7 @@ def test_configuration_service_get_config_version(
assert service.get_config_version() == 1 assert service.get_config_version() == 1
def test_configuration_service_get_last_updated_at( def test_configuration_service_get_last_updated_at(mock_settings, mock_store, mock_audit_repo):
mock_settings, mock_store, mock_audit_repo
):
"""Test getting last updated timestamp.""" """Test getting last updated timestamp."""
# Create service instance # Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -177,7 +165,7 @@ def test_configuration_service_get_last_updated_at(
assert service.get_last_updated_at() is None assert service.get_last_updated_at() is None
# After setting a value, should have timestamp # After setting a value, should have timestamp
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock() mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance mock_repo_class.return_value = mock_repo_instance
+1 -3
View File
@@ -12,8 +12,6 @@ def test_template_directory_resolves_to_existing_location() -> None:
def test_template_exists_in_package_resources() -> None: def test_template_exists_in_package_resources() -> None:
template_path = resources.files("arbitrade").joinpath( template_path = resources.files("arbitrade").joinpath("web", "templates", "dashboard.html")
"web", "templates", "dashboard.html"
)
assert template_path.is_file() assert template_path.is_file()