Refactor code for improved readability and consistency
CI / lint-test-build (push) Failing after 12s

- Consolidated multiline string formatting into single-line for SQL queries in multiple files.
- Adjusted argument formatting in function calls for better alignment and readability.
- Removed unnecessary line breaks and improved spacing in various sections of the codebase.
- Updated test cases to maintain consistency in formatting and improve clarity.
This commit is contained in:
2026-06-04 19:04:30 +02:00
parent 7d18bdf316
commit c8e3daeb57
21 changed files with 377 additions and 383 deletions
+9 -12
View File
@@ -19,10 +19,12 @@ def _resolve_fee_rate(fee_rate: float | None, db_path: str | None = None) -> flo
if db_path is not None:
try:
conn = duckdb.connect(db_path)
row = conn.execute("""
row = conn.execute(
"""
SELECT maker_fee FROM kraken_account_snapshots
ORDER BY snapshot_at DESC LIMIT 1
""").fetchone()
"""
).fetchone()
conn.close()
if row is not None and row[0] is not None:
return float(row[0])
@@ -51,16 +53,14 @@ def _parse_balances(raw: str) -> Mapping[str, float]:
def main() -> int:
parser = argparse.ArgumentParser(
description="Run a deterministic replay backtest.")
parser = argparse.ArgumentParser(description="Run a deterministic replay backtest.")
parser.add_argument("--events", type=Path, required=True)
parser.add_argument("--starting-balances", type=str, default="USD=1000.0")
parser.add_argument("--trade-capital", type=float, default=100.0)
parser.add_argument("--fee-rate", type=float, default=None)
parser.add_argument("--slippage-bps", type=float, default=4.0)
parser.add_argument("--execution-latency-ms", type=float, default=20.0)
parser.add_argument("--db-path", type=str, default=None,
help="DuckDB path for fee lookup")
parser.add_argument("--db-path", type=str, default=None, help="DuckDB path for fee lookup")
args = parser.parse_args()
cycles_by_pair, available_pairs = _build_graph()
@@ -80,18 +80,15 @@ def main() -> int:
started_at=events[0].occurred_at if events else datetime.now(UTC),
)
report = asyncio.run(
engine.run(events, starting_balances=_parse_balances(
args.starting_balances))
engine.run(events, starting_balances=_parse_balances(args.starting_balances))
)
print("Backtest report:")
print(f"- processed_events: {report.processed_events}")
print(f"- opportunities_seen: {report.opportunities_seen}")
print(f"- trades_executed: {report.trades_executed}")
print(
f"- win_rate: {report.win_rate if report.win_rate is not None else 'n/a'}")
print(
f"- fill_rate: {report.fill_rate if report.fill_rate is not None else 'n/a'}")
print(f"- win_rate: {report.win_rate if report.win_rate is not None else 'n/a'}")
print(f"- fill_rate: {report.fill_rate if report.fill_rate is not None else 'n/a'}")
print(f"- realized_pnl_usd: {report.realized_pnl_usd:.4f}")
print(f"- max_drawdown_usd: {report.max_drawdown_usd:.4f}")
print(f"- miss_reasons: {dict(report.miss_reasons)}")
+12 -17
View File
@@ -36,8 +36,7 @@ def _parse_float_list(raw: str) -> list[float]:
def _parse_pair_universes(raw: str) -> list[tuple[str, ...]]:
universes: list[tuple[str, ...]] = []
for chunk in raw.split(";"):
symbols = tuple(item.strip().upper()
for item in chunk.split("|") if item.strip())
symbols = tuple(item.strip().upper() for item in chunk.split("|") if item.strip())
if symbols:
universes.append(symbols)
if not universes:
@@ -75,31 +74,29 @@ def _print_top_results(results: Sequence[SweepResult], *, limit: int = 5) -> Non
def main() -> int:
parser = argparse.ArgumentParser(
description="Run backtesting parameter sweep with train/test split.")
description="Run backtesting parameter sweep with train/test split."
)
parser.add_argument("--events", type=Path, required=True)
parser.add_argument("--starting-balances", type=str, default="USD=1000.0")
parser.add_argument("--theta-values", type=str,
default="0.0003,0.0005,0.0008")
parser.add_argument("--trade-capital-values",
type=str, default="50,100,150")
parser.add_argument("--theta-values", type=str, default="0.0003,0.0005,0.0008")
parser.add_argument("--trade-capital-values", type=str, default="50,100,150")
parser.add_argument(
"--pair-universes",
type=str,
default="BTC/USD|ETH/BTC|ETH/USD",
help="Semicolon-separated universes, each with | delimited pairs",
)
parser.add_argument("--staleness-threshold-values",
type=str, default="3,5,8")
parser.add_argument("--staleness-threshold-values", type=str, default="3,5,8")
parser.add_argument("--train-ratio", type=float, default=0.7)
parser.add_argument("--output", type=Path,
default=Path("ops/backtesting/parameter_sweep_results.json"))
parser.add_argument(
"--output", type=Path, default=Path("ops/backtesting/parameter_sweep_results.json")
)
parser.add_argument("--min-test-realized-pnl-usd", type=float, default=0.0)
parser.add_argument("--min-test-win-rate", type=float, default=0.5)
parser.add_argument("--min-test-fill-rate", type=float, default=0.9)
parser.add_argument("--max-test-drawdown-usd", type=float, default=25.0)
parser.add_argument("--max-generalization-gap-ratio",
type=float, default=0.5)
parser.add_argument("--max-generalization-gap-ratio", type=float, default=0.5)
args = parser.parse_args()
@@ -107,15 +104,13 @@ def main() -> int:
symbols = sorted({event.symbol.upper() for event in events})
cycles_by_pair = _build_graph_from_symbols(symbols)
if not cycles_by_pair:
raise SystemExit(
"No triangular cycles found in supplied replay events")
raise SystemExit("No triangular cycles found in supplied replay events")
grid = build_parameter_grid(
theta_values=_parse_float_list(args.theta_values),
trade_capital_values=_parse_float_list(args.trade_capital_values),
pair_universes=_parse_pair_universes(args.pair_universes),
staleness_threshold_values=_parse_float_list(
args.staleness_threshold_values),
staleness_threshold_values=_parse_float_list(args.staleness_threshold_values),
)
artifacts = run_parameter_search(
+4 -2
View File
@@ -13,11 +13,13 @@ from arbitrade.storage.db import DuckDBStore
def _python_scan_compute(store: DuckDBStore) -> tuple[float, float | None, float | None]:
with store.connect() as conn:
trade_rows = conn.execute("""
trade_rows = conn.execute(
"""
SELECT started_at, finished_at, realized_pnl
FROM trades
WHERE finished_at IS NOT NULL
""").fetchall()
"""
).fetchall()
opportunity_rows = conn.execute("SELECT detected_at FROM opportunities").fetchall()
realized = sum(float(row[2]) for row in trade_rows if row[2] is not None)
+4 -4
View File
@@ -29,8 +29,9 @@ def create_app(settings: Settings) -> FastAPI:
db.migrate()
kraken_client = KrakenRestClient(settings)
fee_sync_stop_event = asyncio.Event()
backtest_queue: asyncio.Queue[tuple[str, str,
dict[str, object] | None] | None] = asyncio.Queue()
backtest_queue: asyncio.Queue[tuple[str, str, dict[str, object] | None] | None] = (
asyncio.Queue()
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
@@ -75,8 +76,7 @@ def create_app(settings: Settings) -> FastAPI:
app.state.audit_repository = AuditRepository(db)
app.state.runtime_state_repository = RuntimeStateRepository(db)
app.state.alert_notifier = build_notifier_from_settings(settings)
app.state.configuration_service = ConfigurationService(
settings, db, AuditRepository(db))
app.state.configuration_service = ConfigurationService(settings, db, AuditRepository(db))
app.state.backtest_recent_reports = []
app.state.dashboard_controls = DashboardControlState(
is_running=not settings.kill_switch_active,
+51 -42
View File
@@ -19,7 +19,12 @@ from arbitrade.api.auth import require_dashboard_auth
from arbitrade.api.control_state import DashboardControlState
from arbitrade.backtesting.replay import BacktestConfig, BacktestReplayEngine, load_replay_events
from arbitrade.detection.graph import CurrencyGraph, TriangularCycle
from arbitrade.storage.repositories import AuditRecord, AuditRepository, BacktestJobRepository, KrakenAccountSnapshotRepository
from arbitrade.storage.repositories import (
AuditRecord,
AuditRepository,
BacktestJobRepository,
KrakenAccountSnapshotRepository,
)
router = APIRouter(dependencies=[Depends(require_dashboard_auth)])
public_router = APIRouter()
@@ -27,8 +32,7 @@ public_router = APIRouter()
def _resolve_templates_directory() -> str:
# Support source layout, Docker runtime (/app), and installed package data.
source_layout_path = Path(
__file__).resolve().parents[3] / "web" / "templates"
source_layout_path = Path(__file__).resolve().parents[3] / "web" / "templates"
if source_layout_path.is_dir():
return str(source_layout_path)
@@ -37,8 +41,7 @@ def _resolve_templates_directory() -> str:
return str(docker_runtime_path)
try:
package_path = resources.files(
"arbitrade").joinpath("web", "templates")
package_path = resources.files("arbitrade").joinpath("web", "templates")
if package_path.is_dir():
return str(package_path)
except (ModuleNotFoundError, AttributeError):
@@ -101,29 +104,37 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
else:
open_trade_filter = "LOWER(status) NOT IN ('filled', 'closed', 'cancelled', 'canceled')"
portfolio_row = conn.execute("""
portfolio_row = conn.execute(
"""
SELECT balances, total_value_usd
FROM portfolio_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
""").fetchone()
open_trades = conn.execute(f"""
"""
).fetchone()
open_trades = conn.execute(
f"""
SELECT {trade_ref_expr}, status, started_at, {cycle_expr}
FROM trades
WHERE {open_trade_filter}
ORDER BY started_at DESC
LIMIT 5
""").fetchall()
rpnl = conn.execute("""
"""
).fetchall()
rpnl = conn.execute(
"""
SELECT COALESCE(SUM(COALESCE(realized_pnl, 0)), 0)
FROM trades
""").fetchone()
latest_opportunities = conn.execute("""
"""
).fetchone()
latest_opportunities = conn.execute(
"""
SELECT cycle, net_pct, est_profit, detected_at
FROM opportunities
ORDER BY detected_at DESC
LIMIT 5
""").fetchall()
"""
).fetchall()
balances_value = ""
total_value = ""
@@ -135,8 +146,7 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
parsed = json.loads(balances_raw)
if isinstance(parsed, dict):
# Filter out zero balances, show non-zero as "AMT ASSET"
non_zero = {k: float(v)
for k, v in parsed.items() if float(v) > 0.0}
non_zero = {k: float(v) for k, v in parsed.items() if float(v) > 0.0}
if non_zero:
balances_value = "<br>".join(
f"{v:.6g} {k}" for k, v in sorted(non_zero.items())
@@ -154,12 +164,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
# Query equity from kraken_account_snapshots
try:
equity_row = conn.execute("""
equity_row = conn.execute(
"""
SELECT trade_balance_raw
FROM kraken_account_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
""").fetchone()
"""
).fetchone()
if equity_row is not None and equity_row[0] is not None:
tb_raw = equity_row[0]
if isinstance(tb_raw, str):
@@ -195,12 +207,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
taker_fee = ""
thirty_day_volume = ""
try:
acct_row = conn.execute("""
acct_row = conn.execute(
"""
SELECT fee_tier, maker_fee, taker_fee, thirty_day_volume
FROM kraken_account_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
""").fetchone()
"""
).fetchone()
if acct_row is not None:
fee_tier = str(acct_row[0]) if acct_row[0] is not None else ""
maker_fee = f"{float(acct_row[1]):.4%}" if acct_row[1] is not None else ""
@@ -230,12 +244,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
def _dashboard_charts(request: Request) -> dict[str, object]:
store = request.app.state.store
with store.connect() as conn:
opportunity_rows = conn.execute("""
opportunity_rows = conn.execute(
"""
SELECT detected_at, cycle, net_pct, est_profit
FROM opportunities
ORDER BY detected_at DESC
LIMIT 10
""").fetchall()
"""
).fetchall()
cr = list(reversed(opportunity_rows))
labels = []
@@ -375,12 +391,12 @@ def _dashboard_config_context(request: Request) -> dict[str, object]:
else ""
)
max_exposure_per_asset_value = (
f"{float(rs.max_exposure_per_asset_usd):.2f}" if rs.max_exposure_per_asset_usd is not None else ""
f"{float(rs.max_exposure_per_asset_usd):.2f}"
if rs.max_exposure_per_asset_usd is not None
else ""
)
daily_loss_limit = (
f"{float(rs.daily_loss_limit_usd):.2f} USD"
if rs.daily_loss_limit_usd is not None
else ""
f"{float(rs.daily_loss_limit_usd):.2f} USD" if rs.daily_loss_limit_usd is not None else ""
)
daily_loss_limit_value = (
f"{float(rs.daily_loss_limit_usd):.2f}" if rs.daily_loss_limit_usd is not None else ""
@@ -391,20 +407,18 @@ def _dashboard_config_context(request: Request) -> dict[str, object]:
else ""
)
cumulative_loss_limit_value = (
f"{float(rs.cumulative_loss_limit_usd):.2f}" if rs.cumulative_loss_limit_usd is not None else ""
f"{float(rs.cumulative_loss_limit_usd):.2f}"
if rs.cumulative_loss_limit_usd is not None
else ""
)
max_source_latency = (
f"{float(rs.max_source_latency_ms):.1f} ms"
if rs.max_source_latency_ms is not None
else ""
f"{float(rs.max_source_latency_ms):.1f} ms" if rs.max_source_latency_ms is not None else ""
)
max_source_latency_value = (
f"{float(rs.max_source_latency_ms):.1f}" if rs.max_source_latency_ms is not None else ""
)
max_apply_latency = (
f"{float(rs.max_apply_latency_ms):.1f} ms"
if rs.max_apply_latency_ms is not None
else ""
f"{float(rs.max_apply_latency_ms):.1f} ms" if rs.max_apply_latency_ms is not None else ""
)
max_apply_latency_value = (
f"{float(rs.max_apply_latency_ms):.1f}" if rs.max_apply_latency_ms is not None else ""
@@ -415,8 +429,7 @@ def _dashboard_config_context(request: Request) -> dict[str, object]:
max_consecutive_failures_value = (
str(rs.max_consecutive_failures) if rs.max_consecutive_failures is not None else ""
)
strategy_stat_arb_enabled = bool(
getattr(rs, "strategy_enable_stat_arb_experiment", False))
strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False))
return {
# Runtime
@@ -537,8 +550,7 @@ def _dashboard_controls(request: Request) -> dict[str, object]:
alerts_last_channel_results = [
str(item) for item in cast(list[object], alert_status.get("last_channel_results", []))
]
strategy_stat_arb_enabled = bool(
getattr(rs, "strategy_enable_stat_arb_experiment", False))
strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False))
return {
"execution_status": "running" if ctl.is_running else "stopped",
@@ -943,11 +955,9 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse:
try:
custom_fee_rate = (
float(defaults["custom_fee_rate"])
if defaults["custom_fee_rate"].strip() else None
float(defaults["custom_fee_rate"]) if defaults["custom_fee_rate"].strip() else None
)
fee_rate = _fee_rate_for_profile(
defaults["fee_profile"], custom_fee_rate, request=request)
fee_rate = _fee_rate_for_profile(defaults["fee_profile"], custom_fee_rate, request=request)
config_dict: dict[str, object] = {
"source": defaults["source"],
@@ -1067,8 +1077,7 @@ async def dashboard_backtesting_export(request: Request, job_id: str) -> Respons
return Response(
content=orjson.dumps(payload).decode("utf-8"),
media_type="application/x-jsonlines",
headers={
"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"},
headers={"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"},
)
+7 -16
View File
@@ -153,8 +153,7 @@ def _parse_book_levels(raw_levels: Any) -> tuple[BookLevel, ...]:
or not isinstance(raw_level[1], int | float)
):
raise ValueError("Each level must be [price, volume]")
levels.append(BookLevel(price=float(
raw_level[0]), volume=float(raw_level[1])))
levels.append(BookLevel(price=float(raw_level[0]), volume=float(raw_level[1])))
return tuple(levels)
@@ -173,8 +172,7 @@ def load_replay_events(path: Path) -> list[ReplayBookEvent]:
if not isinstance(timestamp_raw, str) or not isinstance(symbol_raw, str):
raise ValueError("Each event must include timestamp and symbol")
occurred_at = datetime.fromisoformat(
timestamp_raw.replace("Z", "+00:00")).astimezone(UTC)
occurred_at = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00")).astimezone(UTC)
events.append(
ReplayBookEvent(
occurred_at=occurred_at,
@@ -266,10 +264,7 @@ def _parse_kraken_book_levels(
levels: list[BookLevel] = []
for level in raw_levels:
if isinstance(level, dict) and "price" in level and "qty" in level:
levels.append(
BookLevel(price=float(level["price"]),
volume=float(level["qty"]))
)
levels.append(BookLevel(price=float(level["price"]), volume=float(level["qty"])))
return tuple(levels)
@@ -294,8 +289,7 @@ class BacktestReplayEngine:
min_order_size_by_pair=config.min_order_size_by_pair,
)
self._pre_trade = PreTradeValidator()
self._trade_limits = TradeLimitsGuard(
max_concurrent_trades=config.max_concurrent_trades)
self._trade_limits = TradeLimitsGuard(max_concurrent_trades=config.max_concurrent_trades)
self._simulated_rest = _SimulatedRestClient(
self._clock,
slippage_bps=config.slippage_bps,
@@ -330,8 +324,7 @@ class BacktestReplayEngine:
trades_executed = 0
realized_pnl = 0.0
equity = float(starting_balances.get(
self._config.quote_asset.upper(), 0.0))
equity = float(starting_balances.get(self._config.quote_asset.upper(), 0.0))
peak_equity = equity
max_drawdown = 0.0
@@ -374,8 +367,7 @@ class BacktestReplayEngine:
result = await self._sequencer.execute(opportunity)
self._trade_limits.close_trade(exposure)
execution_latencies.append(
self._simulated_rest.last_trade_latency_ms)
execution_latencies.append(self._simulated_rest.last_trade_latency_ms)
fill_samples.append(self._simulated_rest.last_fill_ratio)
if not result.success:
@@ -398,8 +390,7 @@ class BacktestReplayEngine:
wins = sum(1 for pnl in realized_samples if pnl > 0.0)
win_rate = (wins / len(realized_samples)) if realized_samples else None
fill_rate = (sum(fill_samples) / len(fill_samples)
) if fill_samples else None
fill_rate = (sum(fill_samples) / len(fill_samples)) if fill_samples else None
return BacktestReport(
started_at=events[0].occurred_at if events else self._clock.now,
+9 -11
View File
@@ -69,20 +69,21 @@ async def run_backtest_job(
start_str = config.get("start_time")
end_str = config.get("end_time")
if isinstance(start_str, str) and start_str:
start_dt = datetime.fromisoformat(
start_str.replace("Z", "+00:00"))
start_dt = datetime.fromisoformat(start_str.replace("Z", "+00:00"))
if isinstance(end_str, str) and end_str:
end_dt = datetime.fromisoformat(end_str.replace("Z", "+00:00"))
symbols: list[str] | None = None
if isinstance(symbols_raw, str) and symbols_raw.strip():
symbols = [s.strip().upper()
for s in symbols_raw.split(",") if s.strip()]
symbols = [s.strip().upper() for s in symbols_raw.split(",") if s.strip()]
elif isinstance(symbols_raw, list):
symbols = [str(s).upper() for s in symbols_raw]
events = load_replay_events_from_db(
store, symbols=symbols, start=start_dt, end=end_dt,
store,
symbols=symbols,
start=start_dt,
end=end_dt,
)
else:
path = Path(events_path)
@@ -94,14 +95,12 @@ async def run_backtest_job(
if not events:
raise ValueError("No events found for backtest")
starting_balances_raw = str(config.get(
"starting_balances", "USD=1000.0"))
starting_balances_raw = str(config.get("starting_balances", "USD=1000.0"))
starting_balances = _parse_balances(starting_balances_raw)
fee_rate = float(config.get("fee_rate", 0.0026))
trade_capital = float(config.get("trade_capital", 100.0))
min_profit_threshold = float(
config.get("min_profit_threshold", 0.0005))
min_profit_threshold = float(config.get("min_profit_threshold", 0.0005))
slippage_bps = float(config.get("slippage_bps", 4.0))
execution_latency_ms = float(config.get("execution_latency_ms", 20.0))
@@ -144,8 +143,7 @@ async def run_backtest_job(
repo.store_report(job_id, report_dict)
repo.update_status(job_id, "completed")
_LOG.info("backtest_job_completed", job_id=job_id,
pnl=report.realized_pnl_usd)
_LOG.info("backtest_job_completed", job_id=job_id, pnl=report.realized_pnl_usd)
except Exception as exc:
repo.update_status(job_id, "failed", error=str(exc))
+11 -16
View File
@@ -91,16 +91,14 @@ def build_parameter_grid(
for theta in theta_values:
for trade_capital in trade_capital_values:
for pair_universe in pair_universes:
normalized_universe = tuple(
sorted({pair.upper() for pair in pair_universe}))
normalized_universe = tuple(sorted({pair.upper() for pair in pair_universe}))
for staleness_threshold in staleness_threshold_values:
grid.append(
SweepParameters(
min_profit_threshold=float(theta),
trade_capital=float(trade_capital),
pair_universe=normalized_universe,
staleness_threshold_seconds=float(
staleness_threshold),
staleness_threshold_seconds=float(staleness_threshold),
)
)
return grid
@@ -147,8 +145,9 @@ def _restrict_cycles_by_pair(
if normalized_pair not in pair_universe:
continue
kept = [cycle for cycle in cycles if all(
pair.upper() in pair_universe for pair in cycle.pairs)]
kept = [
cycle for cycle in cycles if all(pair.upper() in pair_universe for pair in cycle.pairs)
]
if kept:
restricted[normalized_pair] = kept
return restricted
@@ -175,9 +174,7 @@ def _evaluate_promotion(
test = result.test_report
if test.realized_pnl_usd < criteria.min_test_realized_pnl_usd:
reasons.append(
"test_realized_pnl_below_threshold"
)
reasons.append("test_realized_pnl_below_threshold")
if (test.win_rate or 0.0) < criteria.min_test_win_rate:
reasons.append("test_win_rate_below_threshold")
if (test.fill_rate or 0.0) < criteria.min_test_fill_rate:
@@ -221,8 +218,7 @@ def run_parameter_search(
quote_asset: str = "USD",
) -> SweepArtifacts:
criteria = promotion_criteria or PromotionCriteria()
train_events, test_events = split_events_time_windows(
events, train_ratio=train_ratio)
train_events, test_events = split_events_time_windows(events, train_ratio=train_ratio)
results: list[SweepResult] = []
promoted: list[SweepResult] = []
@@ -293,7 +289,8 @@ def run_parameter_search(
test_event_count=len(filtered_test),
)
promotion_ready, promotion_reasons = _evaluate_promotion(
result=base_result, criteria=criteria)
result=base_result, criteria=criteria
)
completed_result = SweepResult(
parameters=base_result.parameters,
train_report=base_result.train_report,
@@ -318,8 +315,7 @@ def run_parameter_search(
train_window: tuple[datetime, datetime] | None = None
test_window: tuple[datetime, datetime] | None = None
if train_events:
train_window = (train_events[0].occurred_at,
train_events[-1].occurred_at)
train_window = (train_events[0].occurred_at, train_events[-1].occurred_at)
if test_events:
test_window = (test_events[0].occurred_at, test_events[-1].occurred_at)
@@ -392,5 +388,4 @@ def persist_sweep_results(path: Path, artifacts: SweepArtifacts) -> None:
}
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(orjson.dumps(
payload, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS))
path.write_bytes(orjson.dumps(payload, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS))
+14 -6
View File
@@ -63,6 +63,7 @@ class ConfigurationService:
"""Load user settings from database and merge with defaults."""
# Import here to avoid circular imports
from arbitrade.storage.repositories import ConfigSettingRepository
setting_repo = ConfigSettingRepository(self._store)
# Load all settings from database
@@ -91,7 +92,8 @@ class ConfigurationService:
# Track the latest update time
if db_settings:
latest_updated = max(
setting.updated_at for setting in db_settings if setting.updated_at)
setting.updated_at for setting in db_settings if setting.updated_at
)
self._last_updated_at = latest_updated
# Initialize with default values from settings model
@@ -119,6 +121,7 @@ class ConfigurationService:
"""Check if configuration has been updated since last load."""
# Import here to avoid circular imports
from arbitrade.storage.repositories import ConfigSettingRepository
setting_repo = ConfigSettingRepository(self._store)
# Get the latest update timestamp from database
@@ -143,6 +146,7 @@ class ConfigurationService:
"""Set a configuration setting value and persist to database."""
# Import here to avoid circular imports
from arbitrade.storage.repositories import ConfigSettingRepository
setting_repo = ConfigSettingRepository(self._store)
# Convert value to JSON string and determine type
@@ -159,10 +163,10 @@ class ConfigurationService:
value_json = str(value).lower()
value_type = "bool"
elif isinstance(value, list):
value_json = orjson.dumps(value).decode('utf-8')
value_json = orjson.dumps(value).decode("utf-8")
value_type = "list"
elif isinstance(value, dict):
value_json = orjson.dumps(value).decode('utf-8')
value_json = orjson.dumps(value).decode("utf-8")
value_type = "dict"
else:
value_json = str(value)
@@ -176,7 +180,7 @@ class ConfigurationService:
value_type=value_type,
is_secret=False,
is_runtime_reloadable=False,
updated_by=updated_by
updated_by=updated_by,
)
# Check if setting exists
@@ -205,17 +209,21 @@ class ConfigurationService:
def _pairing_repo(self):
from arbitrade.storage.repositories import ConfigPairingRepository
return ConfigPairingRepository(self._store)
def list_pairings(self) -> list[ConfigPairing]:
"""List all currency pairings."""
return self._pairing_repo().list_pairings()
def create_pairing(self, base_asset: str, quote_asset: str, source: str = "manual") -> ConfigPairing:
def create_pairing(
self, base_asset: str, quote_asset: str, source: str = "manual"
) -> ConfigPairing:
"""Create a new currency pairing."""
existing = self._pairing_repo().get_pairing(base_asset, quote_asset)
if existing:
return existing
pairing = ConfigPairing(
base_asset=base_asset, quote_asset=quote_asset, enabled=True, source=source)
base_asset=base_asset, quote_asset=quote_asset, enabled=True, source=source
)
return self._pairing_repo().create_pairing(pairing)
+39 -78
View File
@@ -32,72 +32,49 @@ class Settings(BaseSettings):
)
alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED")
alert_min_severity: str = Field(
default="warning", alias="ALERT_MIN_SEVERITY")
alert_dedup_seconds: float = Field(
default=30.0, alias="ALERT_DEDUP_SECONDS")
alert_on_trade_events: bool = Field(
default=True, alias="ALERT_ON_TRADE_EVENTS")
alert_on_error_events: bool = Field(
default=True, alias="ALERT_ON_ERROR_EVENTS")
alert_on_threshold_events: bool = Field(
default=True, alias="ALERT_ON_THRESHOLD_EVENTS")
alert_on_system_events: bool = Field(
default=True, alias="ALERT_ON_SYSTEM_EVENTS")
alert_min_severity: str = Field(default="warning", alias="ALERT_MIN_SEVERITY")
alert_dedup_seconds: float = Field(default=30.0, alias="ALERT_DEDUP_SECONDS")
alert_on_trade_events: bool = Field(default=True, alias="ALERT_ON_TRADE_EVENTS")
alert_on_error_events: bool = Field(default=True, alias="ALERT_ON_ERROR_EVENTS")
alert_on_threshold_events: bool = Field(default=True, alias="ALERT_ON_THRESHOLD_EVENTS")
alert_on_system_events: bool = Field(default=True, alias="ALERT_ON_SYSTEM_EVENTS")
telegram_alerts_enabled: bool = Field(
default=False, alias="TELEGRAM_ALERTS_ENABLED")
telegram_bot_token: str | None = Field(
default=None, alias="TELEGRAM_BOT_TOKEN")
telegram_chat_id: str | None = Field(
default=None, alias="TELEGRAM_CHAT_ID")
telegram_alerts_enabled: bool = Field(default=False, alias="TELEGRAM_ALERTS_ENABLED")
telegram_bot_token: str | None = Field(default=None, alias="TELEGRAM_BOT_TOKEN")
telegram_chat_id: str | None = Field(default=None, alias="TELEGRAM_CHAT_ID")
discord_alerts_enabled: bool = Field(
default=False, alias="DISCORD_ALERTS_ENABLED")
discord_webhook_url: str | None = Field(
default=None, alias="DISCORD_WEBHOOK_URL")
discord_alerts_enabled: bool = Field(default=False, alias="DISCORD_ALERTS_ENABLED")
discord_webhook_url: str | None = Field(default=None, alias="DISCORD_WEBHOOK_URL")
email_alerts_enabled: bool = Field(
default=False, alias="EMAIL_ALERTS_ENABLED")
email_alerts_enabled: bool = Field(default=False, alias="EMAIL_ALERTS_ENABLED")
email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST")
email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT")
email_smtp_username: str | None = Field(
default=None, alias="EMAIL_SMTP_USERNAME")
email_smtp_password: str | None = Field(
default=None, alias="EMAIL_SMTP_PASSWORD")
email_alert_from: str | None = Field(
default=None, alias="EMAIL_ALERT_FROM")
email_smtp_username: str | None = Field(default=None, alias="EMAIL_SMTP_USERNAME")
email_smtp_password: str | None = Field(default=None, alias="EMAIL_SMTP_PASSWORD")
email_alert_from: str | None = Field(default=None, alias="EMAIL_ALERT_FROM")
email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO")
email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS")
duckdb_path: Path = Field(default=Path(
"./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH")
kraken_rest_url: str = Field(
default="https://api.kraken.com", alias="KRAKEN_REST_URL")
kraken_ws_url: str = Field(
default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL")
kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_private_rate_limit_seconds: float = Field(
default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS"
)
kraken_http_timeout_seconds: float = Field(
default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
kraken_retry_attempts: int = Field(
default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_retry_base_delay_seconds: float = Field(
default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS"
)
kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY")
kraken_api_secret: str | None = Field(
default=None, alias="KRAKEN_API_SECRET")
kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET")
kraken_api_key_permissions: str = Field(
default="query,trade",
alias="KRAKEN_API_KEY_PERMISSIONS",
)
ws_heartbeat_timeout_seconds: float = Field(
default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
ws_max_staleness_seconds: float = Field(
default=5.0, alias="WS_MAX_STALENESS_SECONDS")
ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS")
strategy_enable_stat_arb_experiment: bool = Field(
default=False,
alias="STRATEGY_ENABLE_STAT_ARB_EXPERIMENT",
@@ -120,29 +97,20 @@ class Settings(BaseSettings):
)
paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE")
trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD")
max_trade_capital_usd: float = Field(
default=100.0, alias="MAX_TRADE_CAPITAL_USD")
max_concurrent_trades: int | None = Field(
default=None, alias="MAX_CONCURRENT_TRADES")
max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD")
max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES")
max_exposure_per_asset_usd: float | None = Field(
default=None,
alias="MAX_EXPOSURE_PER_ASSET_USD",
)
quote_balance_asset: str = Field(
default="USD", alias="QUOTE_BALANCE_ASSET")
min_order_size_usd: float | None = Field(
default=None, alias="MIN_ORDER_SIZE_USD")
quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET")
min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD")
kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE")
daily_loss_limit_usd: float | None = Field(
default=None, alias="DAILY_LOSS_LIMIT_USD")
cumulative_loss_limit_usd: float | None = Field(
default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
max_source_latency_ms: float | None = Field(
default=None, alias="MAX_SOURCE_LATENCY_MS")
max_apply_latency_ms: float | None = Field(
default=None, alias="MAX_APPLY_LATENCY_MS")
max_consecutive_failures: int | None = Field(
default=None, alias="MAX_CONSECUTIVE_FAILURES")
daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD")
cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS")
max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS")
max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES")
fernet_key: str | None = Field(default=None, alias="FERNET_KEY")
@@ -159,8 +127,7 @@ class Settings(BaseSettings):
def _validate_log_level(cls, value: str) -> str:
normalized = value.strip().upper()
if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}:
raise ValueError(
"LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
raise ValueError("LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
return normalized
@field_validator("alert_min_severity")
@@ -168,19 +135,16 @@ class Settings(BaseSettings):
def _validate_alert_severity(cls, value: str) -> str:
normalized = value.strip().lower()
if normalized not in {"info", "warning", "error", "critical"}:
raise ValueError(
"ALERT_MIN_SEVERITY must be one of: info, warning, error, critical")
raise ValueError("ALERT_MIN_SEVERITY must be one of: info, warning, error, critical")
return normalized
@model_validator(mode="after")
def _validate_security_constraints(self) -> Settings:
if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password):
raise ValueError(
"dashboard auth requires both username and password")
raise ValueError("dashboard auth requires both username and password")
if bool(self.kraken_api_key) ^ bool(self.kraken_api_secret):
raise ValueError(
"Kraken API auth requires both API key and secret")
raise ValueError("Kraken API auth requires both API key and secret")
permissions = {
token.strip().lower()
@@ -188,11 +152,9 @@ class Settings(BaseSettings):
if token.strip()
}
if permissions and ("query" not in permissions or "trade" not in permissions):
raise ValueError(
"KRAKEN_API_KEY_PERMISSIONS must include query and trade")
raise ValueError("KRAKEN_API_KEY_PERMISSIONS must include query and trade")
if "withdraw" in permissions or "withdrawals" in permissions:
raise ValueError(
"KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope")
raise ValueError("KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope")
if self.alert_dedup_seconds < 0.0:
raise ValueError("ALERT_DEDUP_SECONDS must be >= 0")
@@ -208,8 +170,7 @@ class Settings(BaseSettings):
"STRATEGY_STAT_ARB_ENTRY_ZSCORE must be greater than STRATEGY_STAT_ARB_EXIT_ZSCORE"
)
if self.strategy_stat_arb_max_holding_seconds <= 0.0:
raise ValueError(
"STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0")
raise ValueError("STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0")
return self
+1 -2
View File
@@ -92,8 +92,7 @@ def run_incremental_detection_benchmark(
def main() -> None:
parser = argparse.ArgumentParser(
description="Benchmark incremental detection latency")
parser = argparse.ArgumentParser(description="Benchmark incremental detection latency")
parser.add_argument("--iterations", type=int, default=50_000)
parser.add_argument("--target-ms", type=float, default=1.0)
args = parser.parse_args()
+7 -15
View File
@@ -43,12 +43,9 @@ async def fetch_and_store_account_snapshot(
_LOG.exception("trade_balance_fetch_failed")
return None
fee_tier = volume_data.get("fee_tier") if isinstance(
volume_data, dict) else None
fees_dict = volume_data.get("fees") if isinstance(
volume_data, dict) else None
fees_maker = volume_data.get("fees_maker") if isinstance(
volume_data, dict) else None
fee_tier = volume_data.get("fee_tier") if isinstance(volume_data, dict) else None
fees_dict = volume_data.get("fees") if isinstance(volume_data, dict) else None
fees_maker = volume_data.get("fees_maker") if isinstance(volume_data, dict) else None
currency = volume_data.get("currency")
thirty_day_volume_str = volume_data.get("volume")
@@ -74,9 +71,7 @@ async def fetch_and_store_account_snapshot(
if currency is not None:
fee_schedule["currency"] = currency
thirty_day_volume = (
float(thirty_day_volume_str) if thirty_day_volume_str is not None else None
)
thirty_day_volume = float(thirty_day_volume_str) if thirty_day_volume_str is not None else None
snapshot = KrakenAccountSnapshot(
snapshot_at=datetime.now(timezone.utc),
@@ -84,8 +79,7 @@ async def fetch_and_store_account_snapshot(
maker_fee=maker_fee,
taker_fee=taker_fee,
thirty_day_volume=thirty_day_volume,
trade_balance_raw=balance_data if isinstance(
balance_data, dict) else None,
trade_balance_raw=balance_data if isinstance(balance_data, dict) else None,
fee_schedule_raw=fee_schedule if fee_schedule else None,
)
@@ -109,8 +103,7 @@ async def fetch_and_store_account_snapshot(
"INSERT INTO portfolio_snapshots (snapshot_at, balances, total_value_usd) VALUES (?, ?, ?)",
(
datetime.now(timezone.utc),
orjson.dumps(wallet_balances).decode(
"utf-8") if wallet_balances else None,
orjson.dumps(wallet_balances).decode("utf-8") if wallet_balances else None,
total_value,
),
)
@@ -130,8 +123,7 @@ async def run_fee_sync_loop(
Runs until stop_event is set.
"""
_LOG.info("fee_sync_loop_started",
interval_s=_FEE_REFRESH_INTERVAL_SECONDS)
_LOG.info("fee_sync_loop_started", interval_s=_FEE_REFRESH_INTERVAL_SECONDS)
while not stop_event.is_set():
try:
+12 -6
View File
@@ -24,7 +24,8 @@ class MetricsCalculator:
def compute(self) -> PerformanceMetrics:
with self._store.connect() as conn:
tm = conn.execute("""
tm = conn.execute(
"""
SELECT
COALESCE(SUM(COALESCE(realized_pnl, 0)), 0) AS realized_pnl_usd,
COUNT(*) AS total_trades,
@@ -44,21 +45,26 @@ class MetricsCalculator:
) AS latency_p99_seconds
FROM trades
WHERE finished_at IS NOT NULL
""").fetchone()
"""
).fetchone()
om = conn.execute("""
om = conn.execute(
"""
SELECT
COUNT(*) AS opportunity_count,
MIN(detected_at) AS first_detected_at,
MAX(detected_at) AS last_detected_at
FROM opportunities
""").fetchone()
"""
).fetchone()
fm = conn.execute("""
fm = conn.execute(
"""
SELECT AVG(filled_volume / volume) AS fill_rate
FROM orders
WHERE volume > 0 AND filled_volume IS NOT NULL
""").fetchone()
"""
).fetchone()
r_pnl_usd = float(tm[0]) if tm and tm[0] is not None else 0.0
tt = int(tm[1]) if tm and tm[1] is not None else 0
+8 -4
View File
@@ -45,22 +45,26 @@ def _runtime_repository(app: FastAPI) -> RuntimeStateRepository | None:
def _open_trade_count(store: DuckDBStore) -> int:
with store.connect() as conn:
row = conn.execute("""
row = conn.execute(
"""
SELECT COUNT(*)
FROM trades
WHERE finished_at IS NULL
""").fetchone()
"""
).fetchone()
return int(row[0]) if row is not None else 0
def _latest_balances(store: DuckDBStore) -> dict[str, Any] | None:
with store.connect() as conn:
row = conn.execute("""
row = conn.execute(
"""
SELECT balances
FROM portfolio_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
""").fetchone()
"""
).fetchone()
if row is None or row[0] is None:
return None
+24 -27
View File
@@ -216,12 +216,14 @@ class DuckDBStore:
# Ensure schema_migrations table exists and get current version
if not self._table_exists(conn, "schema_migrations"):
conn.execute("""
conn.execute(
"""
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TIMESTAMP DEFAULT current_timestamp
)
""")
"""
)
# Get current schema version
try:
@@ -236,30 +238,24 @@ class DuckDBStore:
if current_version < 1:
# Migration v1: Add missing columns to trades table
# Note: DuckDB does not support ADD COLUMN with constraints
conn.execute(
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS trade_ref VARCHAR")
conn.execute(
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS estimated_pnl DOUBLE")
conn.execute(
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS capital_used DOUBLE")
conn.execute(
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS cycle VARCHAR")
conn.execute(
"ALTER TABLE trades ADD COLUMN IF NOT EXISTS leg_count INTEGER")
conn.execute(
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (1)")
conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS trade_ref VARCHAR")
conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS estimated_pnl DOUBLE")
conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS capital_used DOUBLE")
conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS cycle VARCHAR")
conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS leg_count INTEGER")
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (1)")
_LOG.info("migration_applied", version=1)
if current_version < 2:
# Migration v2: Ensure config_backtesting_defaults table
# config_backtesting_defaults already created by SCHEMA_SQL
conn.execute(
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (2)")
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (2)")
_LOG.info("migration_applied", version=2)
if current_version < 3:
# Migration v3: Add kraken_account_snapshots table
conn.execute("""
conn.execute(
"""
CREATE TABLE IF NOT EXISTS kraken_account_snapshots (
snapshot_at TIMESTAMP NOT NULL,
fee_tier VARCHAR,
@@ -269,21 +265,22 @@ class DuckDBStore:
trade_balance_raw JSON,
fee_schedule_raw JSON
)
""")
conn.execute(
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (3)")
"""
)
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (3)")
_LOG.info("migration_applied", version=3)
if current_version < 4:
# Migration v4: Add fee_source to backtesting defaults
conn.execute(
"ALTER TABLE config_backtesting_defaults ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'")
conn.execute(
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (4)")
"ALTER TABLE config_backtesting_defaults ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'"
)
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (4)")
_LOG.info("migration_applied", version=4)
if current_version < 5:
conn.execute("""
conn.execute(
"""
CREATE TABLE IF NOT EXISTS backtest_jobs (
id UUID DEFAULT uuid(),
status VARCHAR NOT NULL DEFAULT 'pending',
@@ -295,9 +292,9 @@ class DuckDBStore:
started_at TIMESTAMP,
finished_at TIMESTAMP
)
""")
conn.execute(
"INSERT OR IGNORE INTO schema_migrations (version) VALUES (5)")
"""
)
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (5)")
_LOG.info("migration_applied", version=5)
# Update version to current
+67 -49
View File
@@ -6,7 +6,12 @@ from typing import Any
import orjson
from arbitrade.config.service import ConfigBacktestingDefaults, ConfigPairing, ConfigSection, ConfigSetting
from arbitrade.config.service import (
ConfigBacktestingDefaults,
ConfigPairing,
ConfigSection,
ConfigSetting,
)
from arbitrade.storage.db import DuckDBStore
@@ -344,7 +349,8 @@ class RuntimeStateRepository:
def latest(self) -> RuntimeStateRecord | None:
with self._store.connect() as conn:
row = conn.execute("""
row = conn.execute(
"""
SELECT
snapshot_at,
is_running,
@@ -356,7 +362,8 @@ class RuntimeStateRepository:
FROM runtime_state_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
""").fetchone()
"""
).fetchone()
if row is None:
return None
@@ -397,12 +404,7 @@ class ConfigSectionRepository:
)
row = cursor.fetchone()
if row:
return ConfigSection(
id=row[0],
name=row[1],
description=row[2],
updated_at=row[3]
)
return ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3])
raise ValueError("Failed to create section")
def get_section(self, name: str) -> ConfigSection | None:
@@ -418,12 +420,7 @@ class ConfigSectionRepository:
)
row = cursor.fetchone()
if row:
return ConfigSection(
id=row[0],
name=row[1],
description=row[2],
updated_at=row[3]
)
return ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3])
return None
def list_sections(self) -> list[ConfigSection]:
@@ -437,12 +434,7 @@ class ConfigSectionRepository:
"""
)
return [
ConfigSection(
id=row[0],
name=row[1],
description=row[2],
updated_at=row[3]
)
ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3])
for row in cursor.fetchall()
]
@@ -480,7 +472,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]),
updated_at=row[6],
updated_by=row[7]
updated_by=row[7],
)
raise ValueError("Failed to create setting")
@@ -505,7 +497,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]),
updated_at=row[6],
updated_by=row[7]
updated_by=row[7],
)
return None
@@ -539,7 +531,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]),
updated_at=row[6],
updated_by=row[7]
updated_by=row[7],
)
raise ValueError("Failed to update setting")
@@ -585,7 +577,7 @@ class ConfigSettingRepository:
is_secret=bool(row[4]),
is_runtime_reloadable=bool(row[5]),
updated_at=row[6],
updated_by=row[7]
updated_by=row[7],
)
for row in cursor.fetchall()
]
@@ -602,7 +594,7 @@ class ConfigSettingRepository:
row = cursor.fetchone()
if row and row[0]:
# Convert string timestamp to datetime
return datetime.fromisoformat(row[0].replace('Z', '+00:00'))
return datetime.fromisoformat(row[0].replace("Z", "+00:00"))
return None
@@ -635,7 +627,7 @@ class ConfigPairingRepository:
enabled=bool(row[3]),
source=row[4],
created_at=row[5],
updated_at=row[6]
updated_at=row[6],
)
raise ValueError("Failed to create pairing")
@@ -659,11 +651,13 @@ class ConfigPairingRepository:
enabled=bool(row[3]),
source=row[4],
created_at=row[5],
updated_at=row[6]
updated_at=row[6],
)
return None
def update_pairing(self, base_asset: str, quote_asset: str, pairing: ConfigPairing) -> ConfigPairing:
def update_pairing(
self, base_asset: str, quote_asset: str, pairing: ConfigPairing
) -> ConfigPairing:
"""Update an existing currency pairing."""
with self._store.connect() as conn:
cursor = conn.execute(
@@ -689,7 +683,7 @@ class ConfigPairingRepository:
enabled=bool(row[3]),
source=row[4],
created_at=row[5],
updated_at=row[6]
updated_at=row[6],
)
raise ValueError("Failed to update pairing")
@@ -723,7 +717,7 @@ class ConfigPairingRepository:
enabled=bool(row[3]),
source=row[4],
created_at=row[5],
updated_at=row[6]
updated_at=row[6],
)
for row in cursor.fetchall()
]
@@ -743,8 +737,11 @@ class ConfigBacktestingDefaultsRepository:
RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms
""",
(
orjson.dumps(defaults.starting_balances).decode(
'utf-8') if defaults.starting_balances else None,
(
orjson.dumps(defaults.starting_balances).decode("utf-8")
if defaults.starting_balances
else None
),
defaults.trade_capital,
defaults.min_profit_threshold,
defaults.slippage_bps,
@@ -758,7 +755,7 @@ class ConfigBacktestingDefaultsRepository:
trade_capital=row[2],
min_profit_threshold=row[3],
slippage_bps=row[4],
execution_latency_ms=row[5]
execution_latency_ms=row[5],
)
raise ValueError("Failed to create backtesting defaults")
@@ -780,7 +777,7 @@ class ConfigBacktestingDefaultsRepository:
trade_capital=row[2],
min_profit_threshold=row[3],
slippage_bps=row[4],
execution_latency_ms=row[5]
execution_latency_ms=row[5],
)
return None
@@ -797,8 +794,11 @@ class ConfigBacktestingDefaultsRepository:
RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms
""",
(
orjson.dumps(defaults.starting_balances).decode(
'utf-8') if defaults.starting_balances else None,
(
orjson.dumps(defaults.starting_balances).decode("utf-8")
if defaults.starting_balances
else None
),
defaults.trade_capital,
defaults.min_profit_threshold,
defaults.slippage_bps,
@@ -812,7 +812,7 @@ class ConfigBacktestingDefaultsRepository:
trade_capital=row[2],
min_profit_threshold=row[3],
slippage_bps=row[4],
execution_latency_ms=row[5]
execution_latency_ms=row[5],
)
raise ValueError("Failed to update backtesting defaults")
@@ -847,10 +847,16 @@ class KrakenAccountSnapshotRepository:
snapshot.maker_fee,
snapshot.taker_fee,
snapshot.thirty_day_volume,
(
orjson.dumps(snapshot.trade_balance_raw).decode("utf-8")
if snapshot.trade_balance_raw else None,
if snapshot.trade_balance_raw
else None
),
(
orjson.dumps(snapshot.fee_schedule_raw).decode("utf-8")
if snapshot.fee_schedule_raw else None,
if snapshot.fee_schedule_raw
else None
),
),
)
@@ -895,7 +901,9 @@ class BacktestJobRepository:
def __init__(self, store: DuckDBStore) -> None:
self._store = store
def create_job(self, events_path: str, config: dict[str, Any] | None = None) -> BacktestJobRecord:
def create_job(
self, events_path: str, config: dict[str, Any] | None = None
) -> BacktestJobRecord:
with self._store.connect() as conn:
row = conn.execute(
"""
@@ -903,13 +911,14 @@ class BacktestJobRepository:
VALUES (?, ?)
RETURNING id, status, events_path, config, created_at
""",
(events_path, orjson.dumps(config).decode(
"utf-8") if config else None),
(events_path, orjson.dumps(config).decode("utf-8") if config else None),
).fetchone()
if row is None:
raise ValueError("Failed to create backtest job")
return BacktestJobRecord(
id=str(row[0]), status=str(row[1]), events_path=str(row[2]),
id=str(row[0]),
status=str(row[1]),
events_path=str(row[2]),
config=orjson.loads(row[3]) if row[3] else None,
created_at=row[4],
)
@@ -950,11 +959,15 @@ class BacktestJobRepository:
if row is None:
return None
return BacktestJobRecord(
id=str(row[0]), status=str(row[1]), events_path=str(row[2]),
id=str(row[0]),
status=str(row[1]),
events_path=str(row[2]),
config=orjson.loads(row[3]) if row[3] else None,
report=orjson.loads(row[4]) if row[4] else None,
error=str(row[5]) if row[5] else None,
created_at=row[6], started_at=row[7], finished_at=row[8],
created_at=row[6],
started_at=row[7],
finished_at=row[8],
)
def list_jobs(self, limit: int = 20) -> list[BacktestJobRecord]:
@@ -967,11 +980,15 @@ class BacktestJobRepository:
).fetchall()
return [
BacktestJobRecord(
id=str(r[0]), status=str(r[1]), events_path=str(r[2]),
id=str(r[0]),
status=str(r[1]),
events_path=str(r[2]),
config=orjson.loads(r[3]) if r[3] else None,
report=orjson.loads(r[4]) if r[4] else None,
error=str(r[5]) if r[5] else None,
created_at=r[6], started_at=r[7], finished_at=r[8],
created_at=r[6],
started_at=r[7],
finished_at=r[8],
)
for r in rows
]
@@ -979,6 +996,7 @@ class BacktestJobRepository:
def delete_job(self, job_id: str) -> bool:
with self._store.connect() as conn:
result = conn.execute(
"DELETE FROM backtest_jobs WHERE id = ?", (job_id,),
"DELETE FROM backtest_jobs WHERE id = ?",
(job_id,),
)
return result.rowcount > 0
+6 -12
View File
@@ -191,8 +191,7 @@ async def test_dashboard_page_and_fragment_and_sse(tmp_path) -> None:
assert "trade-open" in overview.text
assert overview_stream.status_code == 200
assert overview_stream.headers["content-type"].startswith(
"text/event-stream")
assert overview_stream.headers["content-type"].startswith("text/event-stream")
assert "event: overview" in overview_stream.text
assert "trade-open" in overview_stream.text
@@ -262,8 +261,7 @@ async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> N
assert app.state.settings.max_trade_capital_usd == 300.0
assert app.state.settings.max_concurrent_trades == 4
assert app.state.settings.paper_trading_mode is True
assert app.state.dashboard_controls.tradable_pairs == [
"BTC/USD", "ETH/BTC"]
assert app.state.dashboard_controls.tradable_pairs == ["BTC/USD", "ETH/BTC"]
assert app.state.dashboard_controls.strategy_mode == "paper"
assert app.state.dashboard_controls.strategy_profit_threshold == 0.0025
assert app.state.dashboard_controls.strategy_max_depth_levels == 7
@@ -275,14 +273,10 @@ async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> N
assert audit_recent.status_code == 200
entries = audit_recent.json()["entries"]
assert len(entries) >= 4
assert any(entry["event_type"] ==
"dashboard.control.stop" for entry in entries)
assert any(entry["event_type"] ==
"dashboard.control.start" for entry in entries)
assert any(entry["event_type"] ==
"dashboard.control.kill_switch" for entry in entries)
assert any(entry["event_type"] ==
"dashboard.control.config" for entry in entries)
assert any(entry["event_type"] == "dashboard.control.stop" for entry in entries)
assert any(entry["event_type"] == "dashboard.control.start" for entry in entries)
assert any(entry["event_type"] == "dashboard.control.kill_switch" for entry in entries)
assert any(entry["event_type"] == "dashboard.control.config" for entry in entries)
async def test_dashboard_controls_emit_alerts(tmp_path) -> None:
+1 -1
View File
@@ -24,7 +24,7 @@ def test_end_to_end_config_workflow():
assert service.get_last_updated_at() is None
# Test setting a value
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class:
with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance
+74 -32
View File
@@ -6,10 +6,13 @@ from unittest.mock import Mock, patch
from arbitrade.storage.repositories import (
ConfigSettingRepository,
ConfigPairingRepository,
ConfigPairFeeRepository,
ConfigBacktestingDefaultsRepository
ConfigBacktestingDefaultsRepository,
)
from arbitrade.config.service import (
ConfigSetting,
ConfigPairing,
ConfigBacktestingDefaults,
)
from arbitrade.config.service import ConfigSetting, ConfigPairing, ConfigPairFee, ConfigBacktestingDefaults
from arbitrade.storage.db import DuckDBStore
@@ -31,13 +34,20 @@ def test_config_setting_repository_create_setting(mock_store):
repo = ConfigSettingRepository(mock_store)
# Mock database connection
with patch.object(mock_store, 'connect') as mock_connect:
with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value
mock_cursor.fetchone.return_value = [
"test_key", "test_section", "test_value", "str", False, False, "2023-01-01T00:00:00", "test_user"
"test_key",
"test_section",
"test_value",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
]
# Create setting
@@ -48,7 +58,7 @@ def test_config_setting_repository_create_setting(mock_store):
value_type="str",
is_secret=False,
is_runtime_reloadable=False,
updated_by="test_user"
updated_by="test_user",
)
result = repo.create_setting(setting)
@@ -67,13 +77,20 @@ def test_config_setting_repository_get_setting(mock_store):
repo = ConfigSettingRepository(mock_store)
# Mock database connection
with patch.object(mock_store, 'connect') as mock_connect:
with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value
mock_cursor.fetchone.return_value = [
"test_key", "test_section", "test_value", "str", False, False, "2023-01-01T00:00:00", "test_user"
"test_key",
"test_section",
"test_value",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
]
# Get setting
@@ -93,13 +110,20 @@ def test_config_setting_repository_update_setting(mock_store):
repo = ConfigSettingRepository(mock_store)
# Mock database connection
with patch.object(mock_store, 'connect') as mock_connect:
with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value
mock_cursor.fetchone.return_value = [
"test_key", "test_section", "updated_value", "str", False, False, "2023-01-01T00:00:00", "test_user"
"test_key",
"test_section",
"updated_value",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
]
# Update setting
@@ -110,7 +134,7 @@ def test_config_setting_repository_update_setting(mock_store):
value_type="str",
is_secret=False,
is_runtime_reloadable=False,
updated_by="test_user"
updated_by="test_user",
)
result = repo.update_setting("test_key", setting)
@@ -129,16 +153,32 @@ def test_config_setting_repository_list_settings(mock_store):
repo = ConfigSettingRepository(mock_store)
# Mock database connection
with patch.object(mock_store, 'connect') as mock_connect:
with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value
mock_cursor.fetchall.return_value = [
["test_key1", "test_section", "test_value1", "str",
False, False, "2023-01-01T00:00:00", "test_user"],
["test_key2", "test_section", "test_value2", "str",
False, False, "2023-01-01T00:00:00", "test_user"]
[
"test_key1",
"test_section",
"test_value1",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
],
[
"test_key2",
"test_section",
"test_value2",
"str",
False,
False,
"2023-01-01T00:00:00",
"test_user",
],
]
# List settings
@@ -156,7 +196,7 @@ def test_config_setting_repository_get_latest_updated_at(mock_store):
repo = ConfigSettingRepository(mock_store)
# Mock database connection
with patch.object(mock_store, 'connect') as mock_connect:
with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor
@@ -182,22 +222,24 @@ def test_config_pairing_repository_create_pairing(mock_store):
repo = ConfigPairingRepository(mock_store)
# Mock database connection
with patch.object(mock_store, 'connect') as mock_connect:
with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value
mock_cursor.fetchone.return_value = [
1, "BTC", "USD", True, "Kraken", "2023-01-01T00:00:00", "2023-01-01T00:00:00"
1,
"BTC",
"USD",
True,
"Kraken",
"2023-01-01T00:00:00",
"2023-01-01T00:00:00",
]
# Create pairing
pairing = ConfigPairing(
base_asset="BTC",
quote_asset="USD",
enabled=True,
source="Kraken"
)
base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken")
result = repo.create_pairing(pairing)
@@ -214,13 +256,19 @@ def test_config_pairing_repository_get_pairing(mock_store):
repo = ConfigPairingRepository(mock_store)
# Mock database connection
with patch.object(mock_store, 'connect') as mock_connect:
with patch.object(mock_store, "connect") as mock_connect:
mock_cursor = Mock()
mock_connect.return_value.__enter__.return_value = mock_cursor
# Mock the return value
mock_cursor.fetchone.return_value = [
1, "BTC", "USD", True, "Kraken", "2023-01-01T00:00:00", "2023-01-01T00:00:00"
1,
"BTC",
"USD",
True,
"Kraken",
"2023-01-01T00:00:00",
"2023-01-01T00:00:00",
]
# Get pairing
@@ -234,12 +282,6 @@ def test_config_pairing_repository_get_pairing(mock_store):
assert result.source == "Kraken"
def test_config_pair_fee_repository_initialization(mock_store):
"""Test ConfigPairFeeRepository initialization."""
repo = ConfigPairFeeRepository(mock_store)
assert repo._store == mock_store
def test_config_backtesting_defaults_repository_initialization(mock_store):
"""Test ConfigBacktestingDefaultsRepository initialization."""
repo = ConfigBacktestingDefaultsRepository(mock_store)
+14 -26
View File
@@ -31,9 +31,7 @@ def mock_audit_repo():
return audit_repo
def test_configuration_service_initialization(
mock_settings, mock_store, mock_audit_repo
):
def test_configuration_service_initialization(mock_settings, mock_store, mock_audit_repo):
"""Test that ConfigurationService initializes correctly."""
# Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -46,9 +44,7 @@ def test_configuration_service_initialization(
assert isinstance(service._loaded_settings, dict)
def test_configuration_service_get_setting(
mock_settings, mock_store, mock_audit_repo
):
def test_configuration_service_get_setting(mock_settings, mock_store, mock_audit_repo):
"""Test getting configuration settings."""
# Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -65,15 +61,13 @@ def test_configuration_service_get_setting(
assert result == "default"
def test_configuration_service_set_setting(
mock_settings, mock_store, mock_audit_repo
):
def test_configuration_service_set_setting(mock_settings, mock_store, mock_audit_repo):
"""Test setting configuration settings."""
# Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
# Mock the repository
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class:
with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance
@@ -90,9 +84,7 @@ def test_configuration_service_set_setting(
mock_repo_instance.create_setting.assert_called_once()
def test_configuration_service_hot_reload_detection(
mock_settings, mock_store, mock_audit_repo
):
def test_configuration_service_hot_reload_detection(mock_settings, mock_store, mock_audit_repo):
"""Test hot-reload detection functionality."""
# Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -101,27 +93,26 @@ def test_configuration_service_hot_reload_detection(
assert service.is_config_outdated() is False
# Test with mock repository that returns a timestamp
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class:
with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance
# Mock the latest updated at timestamp
from datetime import datetime
mock_repo_instance.get_latest_updated_at.return_value = datetime.now()
# Should detect as outdated when timestamp exists
assert service.is_config_outdated() is True
def test_configuration_service_reload_if_changed(
mock_settings, mock_store, mock_audit_repo
):
def test_configuration_service_reload_if_changed(mock_settings, mock_store, mock_audit_repo):
"""Test hot-reload functionality."""
# Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
# Mock the repository
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class:
with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance
@@ -135,6 +126,7 @@ def test_configuration_service_reload_if_changed(
# Mock the latest updated at timestamp to return a value
from datetime import datetime
mock_repo_instance.get_latest_updated_at.return_value = datetime.now()
# Should reload when outdated
@@ -143,9 +135,7 @@ def test_configuration_service_reload_if_changed(
assert service.get_config_version() == 1
def test_configuration_service_get_config_version(
mock_settings, mock_store, mock_audit_repo
):
def test_configuration_service_get_config_version(mock_settings, mock_store, mock_audit_repo):
"""Test getting configuration version."""
# Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -154,7 +144,7 @@ def test_configuration_service_get_config_version(
assert service.get_config_version() == 0
# After setting a value, version should increment
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class:
with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance
@@ -166,9 +156,7 @@ def test_configuration_service_get_config_version(
assert service.get_config_version() == 1
def test_configuration_service_get_last_updated_at(
mock_settings, mock_store, mock_audit_repo
):
def test_configuration_service_get_last_updated_at(mock_settings, mock_store, mock_audit_repo):
"""Test getting last updated timestamp."""
# Create service instance
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -177,7 +165,7 @@ def test_configuration_service_get_last_updated_at(
assert service.get_last_updated_at() is None
# After setting a value, should have timestamp
with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class:
with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class:
mock_repo_instance = Mock()
mock_repo_class.return_value = mock_repo_instance
+1 -3
View File
@@ -12,8 +12,6 @@ def test_template_directory_resolves_to_existing_location() -> None:
def test_template_exists_in_package_resources() -> None:
template_path = resources.files("arbitrade").joinpath(
"web", "templates", "dashboard.html"
)
template_path = resources.files("arbitrade").joinpath("web", "templates", "dashboard.html")
assert template_path.is_file()