From c8e3daeb571ad0a960f133bcf74d765a5b4c2a95 Mon Sep 17 00:00:00 2001 From: zwitschi Date: Thu, 4 Jun 2026 19:04:30 +0200 Subject: [PATCH] Refactor code for improved readability and consistency - Consolidated multiline string formatting into single-line for SQL queries in multiple files. - Adjusted argument formatting in function calls for better alignment and readability. - Removed unnecessary line breaks and improved spacing in various sections of the codebase. - Updated test cases to maintain consistency in formatting and improve clarity. --- scripts/backtest_replay.py | 21 ++--- scripts/backtest_sweep.py | 29 +++--- scripts/benchmark_metrics_compute.py | 6 +- src/arbitrade/api/app.py | 8 +- src/arbitrade/api/routes.py | 93 ++++++++++--------- src/arbitrade/backtesting/replay.py | 23 ++--- src/arbitrade/backtesting/runner.py | 20 ++--- src/arbitrade/backtesting/sweep.py | 27 +++--- src/arbitrade/config/service.py | 20 +++-- src/arbitrade/config/settings.py | 117 ++++++++---------------- src/arbitrade/detection/benchmark.py | 3 +- src/arbitrade/exchange/fee_service.py | 22 ++--- src/arbitrade/metrics.py | 18 ++-- src/arbitrade/runtime/lifecycle.py | 12 ++- src/arbitrade/storage/db.py | 51 +++++------ src/arbitrade/storage/repositories.py | 120 ++++++++++++++----------- tests/test_dashboard.py | 18 ++-- tests/unit/test_config_e2e.py | 2 +- tests/unit/test_config_repositories.py | 106 +++++++++++++++------- tests/unit/test_config_service.py | 40 +++------ tests/unit/test_template_resolution.py | 4 +- 21 files changed, 377 insertions(+), 383 deletions(-) diff --git a/scripts/backtest_replay.py b/scripts/backtest_replay.py index e0c704f..7acddfd 100644 --- a/scripts/backtest_replay.py +++ b/scripts/backtest_replay.py @@ -19,10 +19,12 @@ def _resolve_fee_rate(fee_rate: float | None, db_path: str | None = None) -> flo if db_path is not None: try: conn = duckdb.connect(db_path) - row = conn.execute(""" + row = conn.execute( + """ SELECT maker_fee FROM kraken_account_snapshots ORDER BY snapshot_at DESC LIMIT 1 - """).fetchone() + """ + ).fetchone() conn.close() if row is not None and row[0] is not None: return float(row[0]) @@ -51,16 +53,14 @@ def _parse_balances(raw: str) -> Mapping[str, float]: def main() -> int: - parser = argparse.ArgumentParser( - description="Run a deterministic replay backtest.") + parser = argparse.ArgumentParser(description="Run a deterministic replay backtest.") parser.add_argument("--events", type=Path, required=True) parser.add_argument("--starting-balances", type=str, default="USD=1000.0") parser.add_argument("--trade-capital", type=float, default=100.0) parser.add_argument("--fee-rate", type=float, default=None) parser.add_argument("--slippage-bps", type=float, default=4.0) parser.add_argument("--execution-latency-ms", type=float, default=20.0) - parser.add_argument("--db-path", type=str, default=None, - help="DuckDB path for fee lookup") + parser.add_argument("--db-path", type=str, default=None, help="DuckDB path for fee lookup") args = parser.parse_args() cycles_by_pair, available_pairs = _build_graph() @@ -80,18 +80,15 @@ def main() -> int: started_at=events[0].occurred_at if events else datetime.now(UTC), ) report = asyncio.run( - engine.run(events, starting_balances=_parse_balances( - args.starting_balances)) + engine.run(events, starting_balances=_parse_balances(args.starting_balances)) ) print("Backtest report:") print(f"- processed_events: {report.processed_events}") print(f"- opportunities_seen: {report.opportunities_seen}") print(f"- trades_executed: {report.trades_executed}") - print( - f"- win_rate: {report.win_rate if report.win_rate is not None else 'n/a'}") - print( - f"- fill_rate: {report.fill_rate if report.fill_rate is not None else 'n/a'}") + print(f"- win_rate: {report.win_rate if report.win_rate is not None else 'n/a'}") + print(f"- fill_rate: {report.fill_rate if report.fill_rate is not None else 'n/a'}") print(f"- realized_pnl_usd: {report.realized_pnl_usd:.4f}") print(f"- max_drawdown_usd: {report.max_drawdown_usd:.4f}") print(f"- miss_reasons: {dict(report.miss_reasons)}") diff --git a/scripts/backtest_sweep.py b/scripts/backtest_sweep.py index e7376a8..1e96b51 100644 --- a/scripts/backtest_sweep.py +++ b/scripts/backtest_sweep.py @@ -36,8 +36,7 @@ def _parse_float_list(raw: str) -> list[float]: def _parse_pair_universes(raw: str) -> list[tuple[str, ...]]: universes: list[tuple[str, ...]] = [] for chunk in raw.split(";"): - symbols = tuple(item.strip().upper() - for item in chunk.split("|") if item.strip()) + symbols = tuple(item.strip().upper() for item in chunk.split("|") if item.strip()) if symbols: universes.append(symbols) if not universes: @@ -75,31 +74,29 @@ def _print_top_results(results: Sequence[SweepResult], *, limit: int = 5) -> Non def main() -> int: parser = argparse.ArgumentParser( - description="Run backtesting parameter sweep with train/test split.") + description="Run backtesting parameter sweep with train/test split." + ) parser.add_argument("--events", type=Path, required=True) parser.add_argument("--starting-balances", type=str, default="USD=1000.0") - parser.add_argument("--theta-values", type=str, - default="0.0003,0.0005,0.0008") - parser.add_argument("--trade-capital-values", - type=str, default="50,100,150") + parser.add_argument("--theta-values", type=str, default="0.0003,0.0005,0.0008") + parser.add_argument("--trade-capital-values", type=str, default="50,100,150") parser.add_argument( "--pair-universes", type=str, default="BTC/USD|ETH/BTC|ETH/USD", help="Semicolon-separated universes, each with | delimited pairs", ) - parser.add_argument("--staleness-threshold-values", - type=str, default="3,5,8") + parser.add_argument("--staleness-threshold-values", type=str, default="3,5,8") parser.add_argument("--train-ratio", type=float, default=0.7) - parser.add_argument("--output", type=Path, - default=Path("ops/backtesting/parameter_sweep_results.json")) + parser.add_argument( + "--output", type=Path, default=Path("ops/backtesting/parameter_sweep_results.json") + ) parser.add_argument("--min-test-realized-pnl-usd", type=float, default=0.0) parser.add_argument("--min-test-win-rate", type=float, default=0.5) parser.add_argument("--min-test-fill-rate", type=float, default=0.9) parser.add_argument("--max-test-drawdown-usd", type=float, default=25.0) - parser.add_argument("--max-generalization-gap-ratio", - type=float, default=0.5) + parser.add_argument("--max-generalization-gap-ratio", type=float, default=0.5) args = parser.parse_args() @@ -107,15 +104,13 @@ def main() -> int: symbols = sorted({event.symbol.upper() for event in events}) cycles_by_pair = _build_graph_from_symbols(symbols) if not cycles_by_pair: - raise SystemExit( - "No triangular cycles found in supplied replay events") + raise SystemExit("No triangular cycles found in supplied replay events") grid = build_parameter_grid( theta_values=_parse_float_list(args.theta_values), trade_capital_values=_parse_float_list(args.trade_capital_values), pair_universes=_parse_pair_universes(args.pair_universes), - staleness_threshold_values=_parse_float_list( - args.staleness_threshold_values), + staleness_threshold_values=_parse_float_list(args.staleness_threshold_values), ) artifacts = run_parameter_search( diff --git a/scripts/benchmark_metrics_compute.py b/scripts/benchmark_metrics_compute.py index 6486b07..92ac1e7 100644 --- a/scripts/benchmark_metrics_compute.py +++ b/scripts/benchmark_metrics_compute.py @@ -13,11 +13,13 @@ from arbitrade.storage.db import DuckDBStore def _python_scan_compute(store: DuckDBStore) -> tuple[float, float | None, float | None]: with store.connect() as conn: - trade_rows = conn.execute(""" + trade_rows = conn.execute( + """ SELECT started_at, finished_at, realized_pnl FROM trades WHERE finished_at IS NOT NULL - """).fetchall() + """ + ).fetchall() opportunity_rows = conn.execute("SELECT detected_at FROM opportunities").fetchall() realized = sum(float(row[2]) for row in trade_rows if row[2] is not None) diff --git a/src/arbitrade/api/app.py b/src/arbitrade/api/app.py index 0737048..603908a 100644 --- a/src/arbitrade/api/app.py +++ b/src/arbitrade/api/app.py @@ -29,8 +29,9 @@ def create_app(settings: Settings) -> FastAPI: db.migrate() kraken_client = KrakenRestClient(settings) fee_sync_stop_event = asyncio.Event() - backtest_queue: asyncio.Queue[tuple[str, str, - dict[str, object] | None] | None] = asyncio.Queue() + backtest_queue: asyncio.Queue[tuple[str, str, dict[str, object] | None] | None] = ( + asyncio.Queue() + ) @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[None]: @@ -75,8 +76,7 @@ def create_app(settings: Settings) -> FastAPI: app.state.audit_repository = AuditRepository(db) app.state.runtime_state_repository = RuntimeStateRepository(db) app.state.alert_notifier = build_notifier_from_settings(settings) - app.state.configuration_service = ConfigurationService( - settings, db, AuditRepository(db)) + app.state.configuration_service = ConfigurationService(settings, db, AuditRepository(db)) app.state.backtest_recent_reports = [] app.state.dashboard_controls = DashboardControlState( is_running=not settings.kill_switch_active, diff --git a/src/arbitrade/api/routes.py b/src/arbitrade/api/routes.py index 0d79ece..8dbe8b9 100644 --- a/src/arbitrade/api/routes.py +++ b/src/arbitrade/api/routes.py @@ -19,7 +19,12 @@ from arbitrade.api.auth import require_dashboard_auth from arbitrade.api.control_state import DashboardControlState from arbitrade.backtesting.replay import BacktestConfig, BacktestReplayEngine, load_replay_events from arbitrade.detection.graph import CurrencyGraph, TriangularCycle -from arbitrade.storage.repositories import AuditRecord, AuditRepository, BacktestJobRepository, KrakenAccountSnapshotRepository +from arbitrade.storage.repositories import ( + AuditRecord, + AuditRepository, + BacktestJobRepository, + KrakenAccountSnapshotRepository, +) router = APIRouter(dependencies=[Depends(require_dashboard_auth)]) public_router = APIRouter() @@ -27,8 +32,7 @@ public_router = APIRouter() def _resolve_templates_directory() -> str: # Support source layout, Docker runtime (/app), and installed package data. - source_layout_path = Path( - __file__).resolve().parents[3] / "web" / "templates" + source_layout_path = Path(__file__).resolve().parents[3] / "web" / "templates" if source_layout_path.is_dir(): return str(source_layout_path) @@ -37,8 +41,7 @@ def _resolve_templates_directory() -> str: return str(docker_runtime_path) try: - package_path = resources.files( - "arbitrade").joinpath("web", "templates") + package_path = resources.files("arbitrade").joinpath("web", "templates") if package_path.is_dir(): return str(package_path) except (ModuleNotFoundError, AttributeError): @@ -101,29 +104,37 @@ def _dashboard_overview(request: Request) -> dict[str, object]: else: open_trade_filter = "LOWER(status) NOT IN ('filled', 'closed', 'cancelled', 'canceled')" - portfolio_row = conn.execute(""" + portfolio_row = conn.execute( + """ SELECT balances, total_value_usd FROM portfolio_snapshots ORDER BY snapshot_at DESC LIMIT 1 - """).fetchone() - open_trades = conn.execute(f""" + """ + ).fetchone() + open_trades = conn.execute( + f""" SELECT {trade_ref_expr}, status, started_at, {cycle_expr} FROM trades WHERE {open_trade_filter} ORDER BY started_at DESC LIMIT 5 - """).fetchall() - rpnl = conn.execute(""" + """ + ).fetchall() + rpnl = conn.execute( + """ SELECT COALESCE(SUM(COALESCE(realized_pnl, 0)), 0) FROM trades - """).fetchone() - latest_opportunities = conn.execute(""" + """ + ).fetchone() + latest_opportunities = conn.execute( + """ SELECT cycle, net_pct, est_profit, detected_at FROM opportunities ORDER BY detected_at DESC LIMIT 5 - """).fetchall() + """ + ).fetchall() balances_value = "—" total_value = "—" @@ -135,8 +146,7 @@ def _dashboard_overview(request: Request) -> dict[str, object]: parsed = json.loads(balances_raw) if isinstance(parsed, dict): # Filter out zero balances, show non-zero as "AMT ASSET" - non_zero = {k: float(v) - for k, v in parsed.items() if float(v) > 0.0} + non_zero = {k: float(v) for k, v in parsed.items() if float(v) > 0.0} if non_zero: balances_value = "
".join( f"{v:.6g} {k}" for k, v in sorted(non_zero.items()) @@ -154,12 +164,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]: # Query equity from kraken_account_snapshots try: - equity_row = conn.execute(""" + equity_row = conn.execute( + """ SELECT trade_balance_raw FROM kraken_account_snapshots ORDER BY snapshot_at DESC LIMIT 1 - """).fetchone() + """ + ).fetchone() if equity_row is not None and equity_row[0] is not None: tb_raw = equity_row[0] if isinstance(tb_raw, str): @@ -195,12 +207,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]: taker_fee = "—" thirty_day_volume = "—" try: - acct_row = conn.execute(""" + acct_row = conn.execute( + """ SELECT fee_tier, maker_fee, taker_fee, thirty_day_volume FROM kraken_account_snapshots ORDER BY snapshot_at DESC LIMIT 1 - """).fetchone() + """ + ).fetchone() if acct_row is not None: fee_tier = str(acct_row[0]) if acct_row[0] is not None else "—" maker_fee = f"{float(acct_row[1]):.4%}" if acct_row[1] is not None else "—" @@ -230,12 +244,14 @@ def _dashboard_overview(request: Request) -> dict[str, object]: def _dashboard_charts(request: Request) -> dict[str, object]: store = request.app.state.store with store.connect() as conn: - opportunity_rows = conn.execute(""" + opportunity_rows = conn.execute( + """ SELECT detected_at, cycle, net_pct, est_profit FROM opportunities ORDER BY detected_at DESC LIMIT 10 - """).fetchall() + """ + ).fetchall() cr = list(reversed(opportunity_rows)) labels = [] @@ -375,12 +391,12 @@ def _dashboard_config_context(request: Request) -> dict[str, object]: else "—" ) max_exposure_per_asset_value = ( - f"{float(rs.max_exposure_per_asset_usd):.2f}" if rs.max_exposure_per_asset_usd is not None else "" + f"{float(rs.max_exposure_per_asset_usd):.2f}" + if rs.max_exposure_per_asset_usd is not None + else "" ) daily_loss_limit = ( - f"{float(rs.daily_loss_limit_usd):.2f} USD" - if rs.daily_loss_limit_usd is not None - else "—" + f"{float(rs.daily_loss_limit_usd):.2f} USD" if rs.daily_loss_limit_usd is not None else "—" ) daily_loss_limit_value = ( f"{float(rs.daily_loss_limit_usd):.2f}" if rs.daily_loss_limit_usd is not None else "" @@ -391,20 +407,18 @@ def _dashboard_config_context(request: Request) -> dict[str, object]: else "—" ) cumulative_loss_limit_value = ( - f"{float(rs.cumulative_loss_limit_usd):.2f}" if rs.cumulative_loss_limit_usd is not None else "" + f"{float(rs.cumulative_loss_limit_usd):.2f}" + if rs.cumulative_loss_limit_usd is not None + else "" ) max_source_latency = ( - f"{float(rs.max_source_latency_ms):.1f} ms" - if rs.max_source_latency_ms is not None - else "—" + f"{float(rs.max_source_latency_ms):.1f} ms" if rs.max_source_latency_ms is not None else "—" ) max_source_latency_value = ( f"{float(rs.max_source_latency_ms):.1f}" if rs.max_source_latency_ms is not None else "" ) max_apply_latency = ( - f"{float(rs.max_apply_latency_ms):.1f} ms" - if rs.max_apply_latency_ms is not None - else "—" + f"{float(rs.max_apply_latency_ms):.1f} ms" if rs.max_apply_latency_ms is not None else "—" ) max_apply_latency_value = ( f"{float(rs.max_apply_latency_ms):.1f}" if rs.max_apply_latency_ms is not None else "" @@ -415,8 +429,7 @@ def _dashboard_config_context(request: Request) -> dict[str, object]: max_consecutive_failures_value = ( str(rs.max_consecutive_failures) if rs.max_consecutive_failures is not None else "" ) - strategy_stat_arb_enabled = bool( - getattr(rs, "strategy_enable_stat_arb_experiment", False)) + strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False)) return { # Runtime @@ -537,8 +550,7 @@ def _dashboard_controls(request: Request) -> dict[str, object]: alerts_last_channel_results = [ str(item) for item in cast(list[object], alert_status.get("last_channel_results", [])) ] - strategy_stat_arb_enabled = bool( - getattr(rs, "strategy_enable_stat_arb_experiment", False)) + strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False)) return { "execution_status": "running" if ctl.is_running else "stopped", @@ -943,11 +955,9 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse: try: custom_fee_rate = ( - float(defaults["custom_fee_rate"]) - if defaults["custom_fee_rate"].strip() else None + float(defaults["custom_fee_rate"]) if defaults["custom_fee_rate"].strip() else None ) - fee_rate = _fee_rate_for_profile( - defaults["fee_profile"], custom_fee_rate, request=request) + fee_rate = _fee_rate_for_profile(defaults["fee_profile"], custom_fee_rate, request=request) config_dict: dict[str, object] = { "source": defaults["source"], @@ -1067,8 +1077,7 @@ async def dashboard_backtesting_export(request: Request, job_id: str) -> Respons return Response( content=orjson.dumps(payload).decode("utf-8"), media_type="application/x-jsonlines", - headers={ - "Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"}, + headers={"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"}, ) diff --git a/src/arbitrade/backtesting/replay.py b/src/arbitrade/backtesting/replay.py index d4fdd54..6aa2883 100644 --- a/src/arbitrade/backtesting/replay.py +++ b/src/arbitrade/backtesting/replay.py @@ -153,8 +153,7 @@ def _parse_book_levels(raw_levels: Any) -> tuple[BookLevel, ...]: or not isinstance(raw_level[1], int | float) ): raise ValueError("Each level must be [price, volume]") - levels.append(BookLevel(price=float( - raw_level[0]), volume=float(raw_level[1]))) + levels.append(BookLevel(price=float(raw_level[0]), volume=float(raw_level[1]))) return tuple(levels) @@ -173,8 +172,7 @@ def load_replay_events(path: Path) -> list[ReplayBookEvent]: if not isinstance(timestamp_raw, str) or not isinstance(symbol_raw, str): raise ValueError("Each event must include timestamp and symbol") - occurred_at = datetime.fromisoformat( - timestamp_raw.replace("Z", "+00:00")).astimezone(UTC) + occurred_at = datetime.fromisoformat(timestamp_raw.replace("Z", "+00:00")).astimezone(UTC) events.append( ReplayBookEvent( occurred_at=occurred_at, @@ -266,10 +264,7 @@ def _parse_kraken_book_levels( levels: list[BookLevel] = [] for level in raw_levels: if isinstance(level, dict) and "price" in level and "qty" in level: - levels.append( - BookLevel(price=float(level["price"]), - volume=float(level["qty"])) - ) + levels.append(BookLevel(price=float(level["price"]), volume=float(level["qty"]))) return tuple(levels) @@ -294,8 +289,7 @@ class BacktestReplayEngine: min_order_size_by_pair=config.min_order_size_by_pair, ) self._pre_trade = PreTradeValidator() - self._trade_limits = TradeLimitsGuard( - max_concurrent_trades=config.max_concurrent_trades) + self._trade_limits = TradeLimitsGuard(max_concurrent_trades=config.max_concurrent_trades) self._simulated_rest = _SimulatedRestClient( self._clock, slippage_bps=config.slippage_bps, @@ -330,8 +324,7 @@ class BacktestReplayEngine: trades_executed = 0 realized_pnl = 0.0 - equity = float(starting_balances.get( - self._config.quote_asset.upper(), 0.0)) + equity = float(starting_balances.get(self._config.quote_asset.upper(), 0.0)) peak_equity = equity max_drawdown = 0.0 @@ -374,8 +367,7 @@ class BacktestReplayEngine: result = await self._sequencer.execute(opportunity) self._trade_limits.close_trade(exposure) - execution_latencies.append( - self._simulated_rest.last_trade_latency_ms) + execution_latencies.append(self._simulated_rest.last_trade_latency_ms) fill_samples.append(self._simulated_rest.last_fill_ratio) if not result.success: @@ -398,8 +390,7 @@ class BacktestReplayEngine: wins = sum(1 for pnl in realized_samples if pnl > 0.0) win_rate = (wins / len(realized_samples)) if realized_samples else None - fill_rate = (sum(fill_samples) / len(fill_samples) - ) if fill_samples else None + fill_rate = (sum(fill_samples) / len(fill_samples)) if fill_samples else None return BacktestReport( started_at=events[0].occurred_at if events else self._clock.now, diff --git a/src/arbitrade/backtesting/runner.py b/src/arbitrade/backtesting/runner.py index 2dbd904..3d1be7d 100644 --- a/src/arbitrade/backtesting/runner.py +++ b/src/arbitrade/backtesting/runner.py @@ -69,20 +69,21 @@ async def run_backtest_job( start_str = config.get("start_time") end_str = config.get("end_time") if isinstance(start_str, str) and start_str: - start_dt = datetime.fromisoformat( - start_str.replace("Z", "+00:00")) + start_dt = datetime.fromisoformat(start_str.replace("Z", "+00:00")) if isinstance(end_str, str) and end_str: end_dt = datetime.fromisoformat(end_str.replace("Z", "+00:00")) symbols: list[str] | None = None if isinstance(symbols_raw, str) and symbols_raw.strip(): - symbols = [s.strip().upper() - for s in symbols_raw.split(",") if s.strip()] + symbols = [s.strip().upper() for s in symbols_raw.split(",") if s.strip()] elif isinstance(symbols_raw, list): symbols = [str(s).upper() for s in symbols_raw] events = load_replay_events_from_db( - store, symbols=symbols, start=start_dt, end=end_dt, + store, + symbols=symbols, + start=start_dt, + end=end_dt, ) else: path = Path(events_path) @@ -94,14 +95,12 @@ async def run_backtest_job( if not events: raise ValueError("No events found for backtest") - starting_balances_raw = str(config.get( - "starting_balances", "USD=1000.0")) + starting_balances_raw = str(config.get("starting_balances", "USD=1000.0")) starting_balances = _parse_balances(starting_balances_raw) fee_rate = float(config.get("fee_rate", 0.0026)) trade_capital = float(config.get("trade_capital", 100.0)) - min_profit_threshold = float( - config.get("min_profit_threshold", 0.0005)) + min_profit_threshold = float(config.get("min_profit_threshold", 0.0005)) slippage_bps = float(config.get("slippage_bps", 4.0)) execution_latency_ms = float(config.get("execution_latency_ms", 20.0)) @@ -144,8 +143,7 @@ async def run_backtest_job( repo.store_report(job_id, report_dict) repo.update_status(job_id, "completed") - _LOG.info("backtest_job_completed", job_id=job_id, - pnl=report.realized_pnl_usd) + _LOG.info("backtest_job_completed", job_id=job_id, pnl=report.realized_pnl_usd) except Exception as exc: repo.update_status(job_id, "failed", error=str(exc)) diff --git a/src/arbitrade/backtesting/sweep.py b/src/arbitrade/backtesting/sweep.py index 67c44a7..2b3fc33 100644 --- a/src/arbitrade/backtesting/sweep.py +++ b/src/arbitrade/backtesting/sweep.py @@ -91,16 +91,14 @@ def build_parameter_grid( for theta in theta_values: for trade_capital in trade_capital_values: for pair_universe in pair_universes: - normalized_universe = tuple( - sorted({pair.upper() for pair in pair_universe})) + normalized_universe = tuple(sorted({pair.upper() for pair in pair_universe})) for staleness_threshold in staleness_threshold_values: grid.append( SweepParameters( min_profit_threshold=float(theta), trade_capital=float(trade_capital), pair_universe=normalized_universe, - staleness_threshold_seconds=float( - staleness_threshold), + staleness_threshold_seconds=float(staleness_threshold), ) ) return grid @@ -147,8 +145,9 @@ def _restrict_cycles_by_pair( if normalized_pair not in pair_universe: continue - kept = [cycle for cycle in cycles if all( - pair.upper() in pair_universe for pair in cycle.pairs)] + kept = [ + cycle for cycle in cycles if all(pair.upper() in pair_universe for pair in cycle.pairs) + ] if kept: restricted[normalized_pair] = kept return restricted @@ -175,9 +174,7 @@ def _evaluate_promotion( test = result.test_report if test.realized_pnl_usd < criteria.min_test_realized_pnl_usd: - reasons.append( - "test_realized_pnl_below_threshold" - ) + reasons.append("test_realized_pnl_below_threshold") if (test.win_rate or 0.0) < criteria.min_test_win_rate: reasons.append("test_win_rate_below_threshold") if (test.fill_rate or 0.0) < criteria.min_test_fill_rate: @@ -221,8 +218,7 @@ def run_parameter_search( quote_asset: str = "USD", ) -> SweepArtifacts: criteria = promotion_criteria or PromotionCriteria() - train_events, test_events = split_events_time_windows( - events, train_ratio=train_ratio) + train_events, test_events = split_events_time_windows(events, train_ratio=train_ratio) results: list[SweepResult] = [] promoted: list[SweepResult] = [] @@ -293,7 +289,8 @@ def run_parameter_search( test_event_count=len(filtered_test), ) promotion_ready, promotion_reasons = _evaluate_promotion( - result=base_result, criteria=criteria) + result=base_result, criteria=criteria + ) completed_result = SweepResult( parameters=base_result.parameters, train_report=base_result.train_report, @@ -318,8 +315,7 @@ def run_parameter_search( train_window: tuple[datetime, datetime] | None = None test_window: tuple[datetime, datetime] | None = None if train_events: - train_window = (train_events[0].occurred_at, - train_events[-1].occurred_at) + train_window = (train_events[0].occurred_at, train_events[-1].occurred_at) if test_events: test_window = (test_events[0].occurred_at, test_events[-1].occurred_at) @@ -392,5 +388,4 @@ def persist_sweep_results(path: Path, artifacts: SweepArtifacts) -> None: } path.parent.mkdir(parents=True, exist_ok=True) - path.write_bytes(orjson.dumps( - payload, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS)) + path.write_bytes(orjson.dumps(payload, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS)) diff --git a/src/arbitrade/config/service.py b/src/arbitrade/config/service.py index 0f778a6..66babe4 100644 --- a/src/arbitrade/config/service.py +++ b/src/arbitrade/config/service.py @@ -63,6 +63,7 @@ class ConfigurationService: """Load user settings from database and merge with defaults.""" # Import here to avoid circular imports from arbitrade.storage.repositories import ConfigSettingRepository + setting_repo = ConfigSettingRepository(self._store) # Load all settings from database @@ -91,7 +92,8 @@ class ConfigurationService: # Track the latest update time if db_settings: latest_updated = max( - setting.updated_at for setting in db_settings if setting.updated_at) + setting.updated_at for setting in db_settings if setting.updated_at + ) self._last_updated_at = latest_updated # Initialize with default values from settings model @@ -119,6 +121,7 @@ class ConfigurationService: """Check if configuration has been updated since last load.""" # Import here to avoid circular imports from arbitrade.storage.repositories import ConfigSettingRepository + setting_repo = ConfigSettingRepository(self._store) # Get the latest update timestamp from database @@ -143,6 +146,7 @@ class ConfigurationService: """Set a configuration setting value and persist to database.""" # Import here to avoid circular imports from arbitrade.storage.repositories import ConfigSettingRepository + setting_repo = ConfigSettingRepository(self._store) # Convert value to JSON string and determine type @@ -159,10 +163,10 @@ class ConfigurationService: value_json = str(value).lower() value_type = "bool" elif isinstance(value, list): - value_json = orjson.dumps(value).decode('utf-8') + value_json = orjson.dumps(value).decode("utf-8") value_type = "list" elif isinstance(value, dict): - value_json = orjson.dumps(value).decode('utf-8') + value_json = orjson.dumps(value).decode("utf-8") value_type = "dict" else: value_json = str(value) @@ -176,7 +180,7 @@ class ConfigurationService: value_type=value_type, is_secret=False, is_runtime_reloadable=False, - updated_by=updated_by + updated_by=updated_by, ) # Check if setting exists @@ -205,17 +209,21 @@ class ConfigurationService: def _pairing_repo(self): from arbitrade.storage.repositories import ConfigPairingRepository + return ConfigPairingRepository(self._store) def list_pairings(self) -> list[ConfigPairing]: """List all currency pairings.""" return self._pairing_repo().list_pairings() - def create_pairing(self, base_asset: str, quote_asset: str, source: str = "manual") -> ConfigPairing: + def create_pairing( + self, base_asset: str, quote_asset: str, source: str = "manual" + ) -> ConfigPairing: """Create a new currency pairing.""" existing = self._pairing_repo().get_pairing(base_asset, quote_asset) if existing: return existing pairing = ConfigPairing( - base_asset=base_asset, quote_asset=quote_asset, enabled=True, source=source) + base_asset=base_asset, quote_asset=quote_asset, enabled=True, source=source + ) return self._pairing_repo().create_pairing(pairing) diff --git a/src/arbitrade/config/settings.py b/src/arbitrade/config/settings.py index 11ac3f0..d0d9c2e 100644 --- a/src/arbitrade/config/settings.py +++ b/src/arbitrade/config/settings.py @@ -32,72 +32,49 @@ class Settings(BaseSettings): ) alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED") - alert_min_severity: str = Field( - default="warning", alias="ALERT_MIN_SEVERITY") - alert_dedup_seconds: float = Field( - default=30.0, alias="ALERT_DEDUP_SECONDS") - alert_on_trade_events: bool = Field( - default=True, alias="ALERT_ON_TRADE_EVENTS") - alert_on_error_events: bool = Field( - default=True, alias="ALERT_ON_ERROR_EVENTS") - alert_on_threshold_events: bool = Field( - default=True, alias="ALERT_ON_THRESHOLD_EVENTS") - alert_on_system_events: bool = Field( - default=True, alias="ALERT_ON_SYSTEM_EVENTS") + alert_min_severity: str = Field(default="warning", alias="ALERT_MIN_SEVERITY") + alert_dedup_seconds: float = Field(default=30.0, alias="ALERT_DEDUP_SECONDS") + alert_on_trade_events: bool = Field(default=True, alias="ALERT_ON_TRADE_EVENTS") + alert_on_error_events: bool = Field(default=True, alias="ALERT_ON_ERROR_EVENTS") + alert_on_threshold_events: bool = Field(default=True, alias="ALERT_ON_THRESHOLD_EVENTS") + alert_on_system_events: bool = Field(default=True, alias="ALERT_ON_SYSTEM_EVENTS") - telegram_alerts_enabled: bool = Field( - default=False, alias="TELEGRAM_ALERTS_ENABLED") - telegram_bot_token: str | None = Field( - default=None, alias="TELEGRAM_BOT_TOKEN") - telegram_chat_id: str | None = Field( - default=None, alias="TELEGRAM_CHAT_ID") + telegram_alerts_enabled: bool = Field(default=False, alias="TELEGRAM_ALERTS_ENABLED") + telegram_bot_token: str | None = Field(default=None, alias="TELEGRAM_BOT_TOKEN") + telegram_chat_id: str | None = Field(default=None, alias="TELEGRAM_CHAT_ID") - discord_alerts_enabled: bool = Field( - default=False, alias="DISCORD_ALERTS_ENABLED") - discord_webhook_url: str | None = Field( - default=None, alias="DISCORD_WEBHOOK_URL") + discord_alerts_enabled: bool = Field(default=False, alias="DISCORD_ALERTS_ENABLED") + discord_webhook_url: str | None = Field(default=None, alias="DISCORD_WEBHOOK_URL") - email_alerts_enabled: bool = Field( - default=False, alias="EMAIL_ALERTS_ENABLED") + email_alerts_enabled: bool = Field(default=False, alias="EMAIL_ALERTS_ENABLED") email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST") email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT") - email_smtp_username: str | None = Field( - default=None, alias="EMAIL_SMTP_USERNAME") - email_smtp_password: str | None = Field( - default=None, alias="EMAIL_SMTP_PASSWORD") - email_alert_from: str | None = Field( - default=None, alias="EMAIL_ALERT_FROM") + email_smtp_username: str | None = Field(default=None, alias="EMAIL_SMTP_USERNAME") + email_smtp_password: str | None = Field(default=None, alias="EMAIL_SMTP_PASSWORD") + email_alert_from: str | None = Field(default=None, alias="EMAIL_ALERT_FROM") email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO") email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS") - duckdb_path: Path = Field(default=Path( - "./data/arbitrade.duckdb"), alias="DUCKDB_PATH") + duckdb_path: Path = Field(default=Path("./data/arbitrade.duckdb"), alias="DUCKDB_PATH") - kraken_rest_url: str = Field( - default="https://api.kraken.com", alias="KRAKEN_REST_URL") - kraken_ws_url: str = Field( - default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") + kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL") + kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") kraken_private_rate_limit_seconds: float = Field( default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" ) - kraken_http_timeout_seconds: float = Field( - default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") - kraken_retry_attempts: int = Field( - default=3, alias="KRAKEN_RETRY_ATTEMPTS") + kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") + kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS") kraken_retry_base_delay_seconds: float = Field( default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" ) kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") - kraken_api_secret: str | None = Field( - default=None, alias="KRAKEN_API_SECRET") + kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET") kraken_api_key_permissions: str = Field( default="query,trade", alias="KRAKEN_API_KEY_PERMISSIONS", ) - ws_heartbeat_timeout_seconds: float = Field( - default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") - ws_max_staleness_seconds: float = Field( - default=5.0, alias="WS_MAX_STALENESS_SECONDS") + ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") + ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS") strategy_enable_stat_arb_experiment: bool = Field( default=False, alias="STRATEGY_ENABLE_STAT_ARB_EXPERIMENT", @@ -120,29 +97,20 @@ class Settings(BaseSettings): ) paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD") - max_trade_capital_usd: float = Field( - default=100.0, alias="MAX_TRADE_CAPITAL_USD") - max_concurrent_trades: int | None = Field( - default=None, alias="MAX_CONCURRENT_TRADES") + max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD") + max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES") max_exposure_per_asset_usd: float | None = Field( default=None, alias="MAX_EXPOSURE_PER_ASSET_USD", ) - quote_balance_asset: str = Field( - default="USD", alias="QUOTE_BALANCE_ASSET") - min_order_size_usd: float | None = Field( - default=None, alias="MIN_ORDER_SIZE_USD") + quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET") + min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD") kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE") - daily_loss_limit_usd: float | None = Field( - default=None, alias="DAILY_LOSS_LIMIT_USD") - cumulative_loss_limit_usd: float | None = Field( - default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") - max_source_latency_ms: float | None = Field( - default=None, alias="MAX_SOURCE_LATENCY_MS") - max_apply_latency_ms: float | None = Field( - default=None, alias="MAX_APPLY_LATENCY_MS") - max_consecutive_failures: int | None = Field( - default=None, alias="MAX_CONSECUTIVE_FAILURES") + daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD") + cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") + max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS") + max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS") + max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES") fernet_key: str | None = Field(default=None, alias="FERNET_KEY") @@ -159,8 +127,7 @@ class Settings(BaseSettings): def _validate_log_level(cls, value: str) -> str: normalized = value.strip().upper() if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}: - raise ValueError( - "LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL") + raise ValueError("LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL") return normalized @field_validator("alert_min_severity") @@ -168,19 +135,16 @@ class Settings(BaseSettings): def _validate_alert_severity(cls, value: str) -> str: normalized = value.strip().lower() if normalized not in {"info", "warning", "error", "critical"}: - raise ValueError( - "ALERT_MIN_SEVERITY must be one of: info, warning, error, critical") + raise ValueError("ALERT_MIN_SEVERITY must be one of: info, warning, error, critical") return normalized @model_validator(mode="after") def _validate_security_constraints(self) -> Settings: if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password): - raise ValueError( - "dashboard auth requires both username and password") + raise ValueError("dashboard auth requires both username and password") if bool(self.kraken_api_key) ^ bool(self.kraken_api_secret): - raise ValueError( - "Kraken API auth requires both API key and secret") + raise ValueError("Kraken API auth requires both API key and secret") permissions = { token.strip().lower() @@ -188,11 +152,9 @@ class Settings(BaseSettings): if token.strip() } if permissions and ("query" not in permissions or "trade" not in permissions): - raise ValueError( - "KRAKEN_API_KEY_PERMISSIONS must include query and trade") + raise ValueError("KRAKEN_API_KEY_PERMISSIONS must include query and trade") if "withdraw" in permissions or "withdrawals" in permissions: - raise ValueError( - "KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope") + raise ValueError("KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope") if self.alert_dedup_seconds < 0.0: raise ValueError("ALERT_DEDUP_SECONDS must be >= 0") @@ -208,8 +170,7 @@ class Settings(BaseSettings): "STRATEGY_STAT_ARB_ENTRY_ZSCORE must be greater than STRATEGY_STAT_ARB_EXIT_ZSCORE" ) if self.strategy_stat_arb_max_holding_seconds <= 0.0: - raise ValueError( - "STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0") + raise ValueError("STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0") return self diff --git a/src/arbitrade/detection/benchmark.py b/src/arbitrade/detection/benchmark.py index bfd053c..2573364 100644 --- a/src/arbitrade/detection/benchmark.py +++ b/src/arbitrade/detection/benchmark.py @@ -92,8 +92,7 @@ def run_incremental_detection_benchmark( def main() -> None: - parser = argparse.ArgumentParser( - description="Benchmark incremental detection latency") + parser = argparse.ArgumentParser(description="Benchmark incremental detection latency") parser.add_argument("--iterations", type=int, default=50_000) parser.add_argument("--target-ms", type=float, default=1.0) args = parser.parse_args() diff --git a/src/arbitrade/exchange/fee_service.py b/src/arbitrade/exchange/fee_service.py index c9289ab..aa2ef20 100644 --- a/src/arbitrade/exchange/fee_service.py +++ b/src/arbitrade/exchange/fee_service.py @@ -43,12 +43,9 @@ async def fetch_and_store_account_snapshot( _LOG.exception("trade_balance_fetch_failed") return None - fee_tier = volume_data.get("fee_tier") if isinstance( - volume_data, dict) else None - fees_dict = volume_data.get("fees") if isinstance( - volume_data, dict) else None - fees_maker = volume_data.get("fees_maker") if isinstance( - volume_data, dict) else None + fee_tier = volume_data.get("fee_tier") if isinstance(volume_data, dict) else None + fees_dict = volume_data.get("fees") if isinstance(volume_data, dict) else None + fees_maker = volume_data.get("fees_maker") if isinstance(volume_data, dict) else None currency = volume_data.get("currency") thirty_day_volume_str = volume_data.get("volume") @@ -74,9 +71,7 @@ async def fetch_and_store_account_snapshot( if currency is not None: fee_schedule["currency"] = currency - thirty_day_volume = ( - float(thirty_day_volume_str) if thirty_day_volume_str is not None else None - ) + thirty_day_volume = float(thirty_day_volume_str) if thirty_day_volume_str is not None else None snapshot = KrakenAccountSnapshot( snapshot_at=datetime.now(timezone.utc), @@ -84,8 +79,7 @@ async def fetch_and_store_account_snapshot( maker_fee=maker_fee, taker_fee=taker_fee, thirty_day_volume=thirty_day_volume, - trade_balance_raw=balance_data if isinstance( - balance_data, dict) else None, + trade_balance_raw=balance_data if isinstance(balance_data, dict) else None, fee_schedule_raw=fee_schedule if fee_schedule else None, ) @@ -109,8 +103,7 @@ async def fetch_and_store_account_snapshot( "INSERT INTO portfolio_snapshots (snapshot_at, balances, total_value_usd) VALUES (?, ?, ?)", ( datetime.now(timezone.utc), - orjson.dumps(wallet_balances).decode( - "utf-8") if wallet_balances else None, + orjson.dumps(wallet_balances).decode("utf-8") if wallet_balances else None, total_value, ), ) @@ -130,8 +123,7 @@ async def run_fee_sync_loop( Runs until stop_event is set. """ - _LOG.info("fee_sync_loop_started", - interval_s=_FEE_REFRESH_INTERVAL_SECONDS) + _LOG.info("fee_sync_loop_started", interval_s=_FEE_REFRESH_INTERVAL_SECONDS) while not stop_event.is_set(): try: diff --git a/src/arbitrade/metrics.py b/src/arbitrade/metrics.py index aadaf32..995ef56 100644 --- a/src/arbitrade/metrics.py +++ b/src/arbitrade/metrics.py @@ -24,7 +24,8 @@ class MetricsCalculator: def compute(self) -> PerformanceMetrics: with self._store.connect() as conn: - tm = conn.execute(""" + tm = conn.execute( + """ SELECT COALESCE(SUM(COALESCE(realized_pnl, 0)), 0) AS realized_pnl_usd, COUNT(*) AS total_trades, @@ -44,21 +45,26 @@ class MetricsCalculator: ) AS latency_p99_seconds FROM trades WHERE finished_at IS NOT NULL - """).fetchone() + """ + ).fetchone() - om = conn.execute(""" + om = conn.execute( + """ SELECT COUNT(*) AS opportunity_count, MIN(detected_at) AS first_detected_at, MAX(detected_at) AS last_detected_at FROM opportunities - """).fetchone() + """ + ).fetchone() - fm = conn.execute(""" + fm = conn.execute( + """ SELECT AVG(filled_volume / volume) AS fill_rate FROM orders WHERE volume > 0 AND filled_volume IS NOT NULL - """).fetchone() + """ + ).fetchone() r_pnl_usd = float(tm[0]) if tm and tm[0] is not None else 0.0 tt = int(tm[1]) if tm and tm[1] is not None else 0 diff --git a/src/arbitrade/runtime/lifecycle.py b/src/arbitrade/runtime/lifecycle.py index c00a0e0..021c277 100644 --- a/src/arbitrade/runtime/lifecycle.py +++ b/src/arbitrade/runtime/lifecycle.py @@ -45,22 +45,26 @@ def _runtime_repository(app: FastAPI) -> RuntimeStateRepository | None: def _open_trade_count(store: DuckDBStore) -> int: with store.connect() as conn: - row = conn.execute(""" + row = conn.execute( + """ SELECT COUNT(*) FROM trades WHERE finished_at IS NULL - """).fetchone() + """ + ).fetchone() return int(row[0]) if row is not None else 0 def _latest_balances(store: DuckDBStore) -> dict[str, Any] | None: with store.connect() as conn: - row = conn.execute(""" + row = conn.execute( + """ SELECT balances FROM portfolio_snapshots ORDER BY snapshot_at DESC LIMIT 1 - """).fetchone() + """ + ).fetchone() if row is None or row[0] is None: return None diff --git a/src/arbitrade/storage/db.py b/src/arbitrade/storage/db.py index e33afa7..2d28ab6 100644 --- a/src/arbitrade/storage/db.py +++ b/src/arbitrade/storage/db.py @@ -216,12 +216,14 @@ class DuckDBStore: # Ensure schema_migrations table exists and get current version if not self._table_exists(conn, "schema_migrations"): - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, applied_at TIMESTAMP DEFAULT current_timestamp ) - """) + """ + ) # Get current schema version try: @@ -236,30 +238,24 @@ class DuckDBStore: if current_version < 1: # Migration v1: Add missing columns to trades table # Note: DuckDB does not support ADD COLUMN with constraints - conn.execute( - "ALTER TABLE trades ADD COLUMN IF NOT EXISTS trade_ref VARCHAR") - conn.execute( - "ALTER TABLE trades ADD COLUMN IF NOT EXISTS estimated_pnl DOUBLE") - conn.execute( - "ALTER TABLE trades ADD COLUMN IF NOT EXISTS capital_used DOUBLE") - conn.execute( - "ALTER TABLE trades ADD COLUMN IF NOT EXISTS cycle VARCHAR") - conn.execute( - "ALTER TABLE trades ADD COLUMN IF NOT EXISTS leg_count INTEGER") - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES (1)") + conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS trade_ref VARCHAR") + conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS estimated_pnl DOUBLE") + conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS capital_used DOUBLE") + conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS cycle VARCHAR") + conn.execute("ALTER TABLE trades ADD COLUMN IF NOT EXISTS leg_count INTEGER") + conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (1)") _LOG.info("migration_applied", version=1) if current_version < 2: # Migration v2: Ensure config_backtesting_defaults table # config_backtesting_defaults already created by SCHEMA_SQL - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES (2)") + conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (2)") _LOG.info("migration_applied", version=2) if current_version < 3: # Migration v3: Add kraken_account_snapshots table - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS kraken_account_snapshots ( snapshot_at TIMESTAMP NOT NULL, fee_tier VARCHAR, @@ -269,21 +265,22 @@ class DuckDBStore: trade_balance_raw JSON, fee_schedule_raw JSON ) - """) - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES (3)") + """ + ) + conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (3)") _LOG.info("migration_applied", version=3) if current_version < 4: # Migration v4: Add fee_source to backtesting defaults conn.execute( - "ALTER TABLE config_backtesting_defaults ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'") - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES (4)") + "ALTER TABLE config_backtesting_defaults ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'" + ) + conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (4)") _LOG.info("migration_applied", version=4) if current_version < 5: - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS backtest_jobs ( id UUID DEFAULT uuid(), status VARCHAR NOT NULL DEFAULT 'pending', @@ -295,9 +292,9 @@ class DuckDBStore: started_at TIMESTAMP, finished_at TIMESTAMP ) - """) - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES (5)") + """ + ) + conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (5)") _LOG.info("migration_applied", version=5) # Update version to current diff --git a/src/arbitrade/storage/repositories.py b/src/arbitrade/storage/repositories.py index daaa98c..418f83d 100644 --- a/src/arbitrade/storage/repositories.py +++ b/src/arbitrade/storage/repositories.py @@ -6,7 +6,12 @@ from typing import Any import orjson -from arbitrade.config.service import ConfigBacktestingDefaults, ConfigPairing, ConfigSection, ConfigSetting +from arbitrade.config.service import ( + ConfigBacktestingDefaults, + ConfigPairing, + ConfigSection, + ConfigSetting, +) from arbitrade.storage.db import DuckDBStore @@ -344,7 +349,8 @@ class RuntimeStateRepository: def latest(self) -> RuntimeStateRecord | None: with self._store.connect() as conn: - row = conn.execute(""" + row = conn.execute( + """ SELECT snapshot_at, is_running, @@ -356,7 +362,8 @@ class RuntimeStateRepository: FROM runtime_state_snapshots ORDER BY snapshot_at DESC LIMIT 1 - """).fetchone() + """ + ).fetchone() if row is None: return None @@ -397,12 +404,7 @@ class ConfigSectionRepository: ) row = cursor.fetchone() if row: - return ConfigSection( - id=row[0], - name=row[1], - description=row[2], - updated_at=row[3] - ) + return ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3]) raise ValueError("Failed to create section") def get_section(self, name: str) -> ConfigSection | None: @@ -418,12 +420,7 @@ class ConfigSectionRepository: ) row = cursor.fetchone() if row: - return ConfigSection( - id=row[0], - name=row[1], - description=row[2], - updated_at=row[3] - ) + return ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3]) return None def list_sections(self) -> list[ConfigSection]: @@ -437,12 +434,7 @@ class ConfigSectionRepository: """ ) return [ - ConfigSection( - id=row[0], - name=row[1], - description=row[2], - updated_at=row[3] - ) + ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3]) for row in cursor.fetchall() ] @@ -480,7 +472,7 @@ class ConfigSettingRepository: is_secret=bool(row[4]), is_runtime_reloadable=bool(row[5]), updated_at=row[6], - updated_by=row[7] + updated_by=row[7], ) raise ValueError("Failed to create setting") @@ -505,7 +497,7 @@ class ConfigSettingRepository: is_secret=bool(row[4]), is_runtime_reloadable=bool(row[5]), updated_at=row[6], - updated_by=row[7] + updated_by=row[7], ) return None @@ -539,7 +531,7 @@ class ConfigSettingRepository: is_secret=bool(row[4]), is_runtime_reloadable=bool(row[5]), updated_at=row[6], - updated_by=row[7] + updated_by=row[7], ) raise ValueError("Failed to update setting") @@ -585,7 +577,7 @@ class ConfigSettingRepository: is_secret=bool(row[4]), is_runtime_reloadable=bool(row[5]), updated_at=row[6], - updated_by=row[7] + updated_by=row[7], ) for row in cursor.fetchall() ] @@ -602,7 +594,7 @@ class ConfigSettingRepository: row = cursor.fetchone() if row and row[0]: # Convert string timestamp to datetime - return datetime.fromisoformat(row[0].replace('Z', '+00:00')) + return datetime.fromisoformat(row[0].replace("Z", "+00:00")) return None @@ -635,7 +627,7 @@ class ConfigPairingRepository: enabled=bool(row[3]), source=row[4], created_at=row[5], - updated_at=row[6] + updated_at=row[6], ) raise ValueError("Failed to create pairing") @@ -659,11 +651,13 @@ class ConfigPairingRepository: enabled=bool(row[3]), source=row[4], created_at=row[5], - updated_at=row[6] + updated_at=row[6], ) return None - def update_pairing(self, base_asset: str, quote_asset: str, pairing: ConfigPairing) -> ConfigPairing: + def update_pairing( + self, base_asset: str, quote_asset: str, pairing: ConfigPairing + ) -> ConfigPairing: """Update an existing currency pairing.""" with self._store.connect() as conn: cursor = conn.execute( @@ -689,7 +683,7 @@ class ConfigPairingRepository: enabled=bool(row[3]), source=row[4], created_at=row[5], - updated_at=row[6] + updated_at=row[6], ) raise ValueError("Failed to update pairing") @@ -723,7 +717,7 @@ class ConfigPairingRepository: enabled=bool(row[3]), source=row[4], created_at=row[5], - updated_at=row[6] + updated_at=row[6], ) for row in cursor.fetchall() ] @@ -743,8 +737,11 @@ class ConfigBacktestingDefaultsRepository: RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms """, ( - orjson.dumps(defaults.starting_balances).decode( - 'utf-8') if defaults.starting_balances else None, + ( + orjson.dumps(defaults.starting_balances).decode("utf-8") + if defaults.starting_balances + else None + ), defaults.trade_capital, defaults.min_profit_threshold, defaults.slippage_bps, @@ -758,7 +755,7 @@ class ConfigBacktestingDefaultsRepository: trade_capital=row[2], min_profit_threshold=row[3], slippage_bps=row[4], - execution_latency_ms=row[5] + execution_latency_ms=row[5], ) raise ValueError("Failed to create backtesting defaults") @@ -780,7 +777,7 @@ class ConfigBacktestingDefaultsRepository: trade_capital=row[2], min_profit_threshold=row[3], slippage_bps=row[4], - execution_latency_ms=row[5] + execution_latency_ms=row[5], ) return None @@ -797,8 +794,11 @@ class ConfigBacktestingDefaultsRepository: RETURNING id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms """, ( - orjson.dumps(defaults.starting_balances).decode( - 'utf-8') if defaults.starting_balances else None, + ( + orjson.dumps(defaults.starting_balances).decode("utf-8") + if defaults.starting_balances + else None + ), defaults.trade_capital, defaults.min_profit_threshold, defaults.slippage_bps, @@ -812,7 +812,7 @@ class ConfigBacktestingDefaultsRepository: trade_capital=row[2], min_profit_threshold=row[3], slippage_bps=row[4], - execution_latency_ms=row[5] + execution_latency_ms=row[5], ) raise ValueError("Failed to update backtesting defaults") @@ -847,10 +847,16 @@ class KrakenAccountSnapshotRepository: snapshot.maker_fee, snapshot.taker_fee, snapshot.thirty_day_volume, - orjson.dumps(snapshot.trade_balance_raw).decode("utf-8") - if snapshot.trade_balance_raw else None, - orjson.dumps(snapshot.fee_schedule_raw).decode("utf-8") - if snapshot.fee_schedule_raw else None, + ( + orjson.dumps(snapshot.trade_balance_raw).decode("utf-8") + if snapshot.trade_balance_raw + else None + ), + ( + orjson.dumps(snapshot.fee_schedule_raw).decode("utf-8") + if snapshot.fee_schedule_raw + else None + ), ), ) @@ -895,7 +901,9 @@ class BacktestJobRepository: def __init__(self, store: DuckDBStore) -> None: self._store = store - def create_job(self, events_path: str, config: dict[str, Any] | None = None) -> BacktestJobRecord: + def create_job( + self, events_path: str, config: dict[str, Any] | None = None + ) -> BacktestJobRecord: with self._store.connect() as conn: row = conn.execute( """ @@ -903,13 +911,14 @@ class BacktestJobRepository: VALUES (?, ?) RETURNING id, status, events_path, config, created_at """, - (events_path, orjson.dumps(config).decode( - "utf-8") if config else None), + (events_path, orjson.dumps(config).decode("utf-8") if config else None), ).fetchone() if row is None: raise ValueError("Failed to create backtest job") return BacktestJobRecord( - id=str(row[0]), status=str(row[1]), events_path=str(row[2]), + id=str(row[0]), + status=str(row[1]), + events_path=str(row[2]), config=orjson.loads(row[3]) if row[3] else None, created_at=row[4], ) @@ -950,11 +959,15 @@ class BacktestJobRepository: if row is None: return None return BacktestJobRecord( - id=str(row[0]), status=str(row[1]), events_path=str(row[2]), + id=str(row[0]), + status=str(row[1]), + events_path=str(row[2]), config=orjson.loads(row[3]) if row[3] else None, report=orjson.loads(row[4]) if row[4] else None, error=str(row[5]) if row[5] else None, - created_at=row[6], started_at=row[7], finished_at=row[8], + created_at=row[6], + started_at=row[7], + finished_at=row[8], ) def list_jobs(self, limit: int = 20) -> list[BacktestJobRecord]: @@ -967,11 +980,15 @@ class BacktestJobRepository: ).fetchall() return [ BacktestJobRecord( - id=str(r[0]), status=str(r[1]), events_path=str(r[2]), + id=str(r[0]), + status=str(r[1]), + events_path=str(r[2]), config=orjson.loads(r[3]) if r[3] else None, report=orjson.loads(r[4]) if r[4] else None, error=str(r[5]) if r[5] else None, - created_at=r[6], started_at=r[7], finished_at=r[8], + created_at=r[6], + started_at=r[7], + finished_at=r[8], ) for r in rows ] @@ -979,6 +996,7 @@ class BacktestJobRepository: def delete_job(self, job_id: str) -> bool: with self._store.connect() as conn: result = conn.execute( - "DELETE FROM backtest_jobs WHERE id = ?", (job_id,), + "DELETE FROM backtest_jobs WHERE id = ?", + (job_id,), ) return result.rowcount > 0 diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index e023590..8ee2623 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -191,8 +191,7 @@ async def test_dashboard_page_and_fragment_and_sse(tmp_path) -> None: assert "trade-open" in overview.text assert overview_stream.status_code == 200 - assert overview_stream.headers["content-type"].startswith( - "text/event-stream") + assert overview_stream.headers["content-type"].startswith("text/event-stream") assert "event: overview" in overview_stream.text assert "trade-open" in overview_stream.text @@ -262,8 +261,7 @@ async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> N assert app.state.settings.max_trade_capital_usd == 300.0 assert app.state.settings.max_concurrent_trades == 4 assert app.state.settings.paper_trading_mode is True - assert app.state.dashboard_controls.tradable_pairs == [ - "BTC/USD", "ETH/BTC"] + assert app.state.dashboard_controls.tradable_pairs == ["BTC/USD", "ETH/BTC"] assert app.state.dashboard_controls.strategy_mode == "paper" assert app.state.dashboard_controls.strategy_profit_threshold == 0.0025 assert app.state.dashboard_controls.strategy_max_depth_levels == 7 @@ -275,14 +273,10 @@ async def test_dashboard_controls_update_runtime_state_and_config(tmp_path) -> N assert audit_recent.status_code == 200 entries = audit_recent.json()["entries"] assert len(entries) >= 4 - assert any(entry["event_type"] == - "dashboard.control.stop" for entry in entries) - assert any(entry["event_type"] == - "dashboard.control.start" for entry in entries) - assert any(entry["event_type"] == - "dashboard.control.kill_switch" for entry in entries) - assert any(entry["event_type"] == - "dashboard.control.config" for entry in entries) + assert any(entry["event_type"] == "dashboard.control.stop" for entry in entries) + assert any(entry["event_type"] == "dashboard.control.start" for entry in entries) + assert any(entry["event_type"] == "dashboard.control.kill_switch" for entry in entries) + assert any(entry["event_type"] == "dashboard.control.config" for entry in entries) async def test_dashboard_controls_emit_alerts(tmp_path) -> None: diff --git a/tests/unit/test_config_e2e.py b/tests/unit/test_config_e2e.py index ab3eccb..f8354fd 100644 --- a/tests/unit/test_config_e2e.py +++ b/tests/unit/test_config_e2e.py @@ -24,7 +24,7 @@ def test_end_to_end_config_workflow(): assert service.get_last_updated_at() is None # Test setting a value - with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: + with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance diff --git a/tests/unit/test_config_repositories.py b/tests/unit/test_config_repositories.py index 6b39f32..eae5388 100644 --- a/tests/unit/test_config_repositories.py +++ b/tests/unit/test_config_repositories.py @@ -6,10 +6,13 @@ from unittest.mock import Mock, patch from arbitrade.storage.repositories import ( ConfigSettingRepository, ConfigPairingRepository, - ConfigPairFeeRepository, - ConfigBacktestingDefaultsRepository + ConfigBacktestingDefaultsRepository, +) +from arbitrade.config.service import ( + ConfigSetting, + ConfigPairing, + ConfigBacktestingDefaults, ) -from arbitrade.config.service import ConfigSetting, ConfigPairing, ConfigPairFee, ConfigBacktestingDefaults from arbitrade.storage.db import DuckDBStore @@ -31,13 +34,20 @@ def test_config_setting_repository_create_setting(mock_store): repo = ConfigSettingRepository(mock_store) # Mock database connection - with patch.object(mock_store, 'connect') as mock_connect: + with patch.object(mock_store, "connect") as mock_connect: mock_cursor = Mock() mock_connect.return_value.__enter__.return_value = mock_cursor # Mock the return value mock_cursor.fetchone.return_value = [ - "test_key", "test_section", "test_value", "str", False, False, "2023-01-01T00:00:00", "test_user" + "test_key", + "test_section", + "test_value", + "str", + False, + False, + "2023-01-01T00:00:00", + "test_user", ] # Create setting @@ -48,7 +58,7 @@ def test_config_setting_repository_create_setting(mock_store): value_type="str", is_secret=False, is_runtime_reloadable=False, - updated_by="test_user" + updated_by="test_user", ) result = repo.create_setting(setting) @@ -67,13 +77,20 @@ def test_config_setting_repository_get_setting(mock_store): repo = ConfigSettingRepository(mock_store) # Mock database connection - with patch.object(mock_store, 'connect') as mock_connect: + with patch.object(mock_store, "connect") as mock_connect: mock_cursor = Mock() mock_connect.return_value.__enter__.return_value = mock_cursor # Mock the return value mock_cursor.fetchone.return_value = [ - "test_key", "test_section", "test_value", "str", False, False, "2023-01-01T00:00:00", "test_user" + "test_key", + "test_section", + "test_value", + "str", + False, + False, + "2023-01-01T00:00:00", + "test_user", ] # Get setting @@ -93,13 +110,20 @@ def test_config_setting_repository_update_setting(mock_store): repo = ConfigSettingRepository(mock_store) # Mock database connection - with patch.object(mock_store, 'connect') as mock_connect: + with patch.object(mock_store, "connect") as mock_connect: mock_cursor = Mock() mock_connect.return_value.__enter__.return_value = mock_cursor # Mock the return value mock_cursor.fetchone.return_value = [ - "test_key", "test_section", "updated_value", "str", False, False, "2023-01-01T00:00:00", "test_user" + "test_key", + "test_section", + "updated_value", + "str", + False, + False, + "2023-01-01T00:00:00", + "test_user", ] # Update setting @@ -110,7 +134,7 @@ def test_config_setting_repository_update_setting(mock_store): value_type="str", is_secret=False, is_runtime_reloadable=False, - updated_by="test_user" + updated_by="test_user", ) result = repo.update_setting("test_key", setting) @@ -129,16 +153,32 @@ def test_config_setting_repository_list_settings(mock_store): repo = ConfigSettingRepository(mock_store) # Mock database connection - with patch.object(mock_store, 'connect') as mock_connect: + with patch.object(mock_store, "connect") as mock_connect: mock_cursor = Mock() mock_connect.return_value.__enter__.return_value = mock_cursor # Mock the return value mock_cursor.fetchall.return_value = [ - ["test_key1", "test_section", "test_value1", "str", - False, False, "2023-01-01T00:00:00", "test_user"], - ["test_key2", "test_section", "test_value2", "str", - False, False, "2023-01-01T00:00:00", "test_user"] + [ + "test_key1", + "test_section", + "test_value1", + "str", + False, + False, + "2023-01-01T00:00:00", + "test_user", + ], + [ + "test_key2", + "test_section", + "test_value2", + "str", + False, + False, + "2023-01-01T00:00:00", + "test_user", + ], ] # List settings @@ -156,7 +196,7 @@ def test_config_setting_repository_get_latest_updated_at(mock_store): repo = ConfigSettingRepository(mock_store) # Mock database connection - with patch.object(mock_store, 'connect') as mock_connect: + with patch.object(mock_store, "connect") as mock_connect: mock_cursor = Mock() mock_connect.return_value.__enter__.return_value = mock_cursor @@ -182,22 +222,24 @@ def test_config_pairing_repository_create_pairing(mock_store): repo = ConfigPairingRepository(mock_store) # Mock database connection - with patch.object(mock_store, 'connect') as mock_connect: + with patch.object(mock_store, "connect") as mock_connect: mock_cursor = Mock() mock_connect.return_value.__enter__.return_value = mock_cursor # Mock the return value mock_cursor.fetchone.return_value = [ - 1, "BTC", "USD", True, "Kraken", "2023-01-01T00:00:00", "2023-01-01T00:00:00" + 1, + "BTC", + "USD", + True, + "Kraken", + "2023-01-01T00:00:00", + "2023-01-01T00:00:00", ] # Create pairing pairing = ConfigPairing( - base_asset="BTC", - quote_asset="USD", - enabled=True, - source="Kraken" - ) + base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken") result = repo.create_pairing(pairing) @@ -214,13 +256,19 @@ def test_config_pairing_repository_get_pairing(mock_store): repo = ConfigPairingRepository(mock_store) # Mock database connection - with patch.object(mock_store, 'connect') as mock_connect: + with patch.object(mock_store, "connect") as mock_connect: mock_cursor = Mock() mock_connect.return_value.__enter__.return_value = mock_cursor # Mock the return value mock_cursor.fetchone.return_value = [ - 1, "BTC", "USD", True, "Kraken", "2023-01-01T00:00:00", "2023-01-01T00:00:00" + 1, + "BTC", + "USD", + True, + "Kraken", + "2023-01-01T00:00:00", + "2023-01-01T00:00:00", ] # Get pairing @@ -234,12 +282,6 @@ def test_config_pairing_repository_get_pairing(mock_store): assert result.source == "Kraken" -def test_config_pair_fee_repository_initialization(mock_store): - """Test ConfigPairFeeRepository initialization.""" - repo = ConfigPairFeeRepository(mock_store) - assert repo._store == mock_store - - def test_config_backtesting_defaults_repository_initialization(mock_store): """Test ConfigBacktestingDefaultsRepository initialization.""" repo = ConfigBacktestingDefaultsRepository(mock_store) diff --git a/tests/unit/test_config_service.py b/tests/unit/test_config_service.py index 14eb66e..6fb1c8f 100644 --- a/tests/unit/test_config_service.py +++ b/tests/unit/test_config_service.py @@ -31,9 +31,7 @@ def mock_audit_repo(): return audit_repo -def test_configuration_service_initialization( - mock_settings, mock_store, mock_audit_repo -): +def test_configuration_service_initialization(mock_settings, mock_store, mock_audit_repo): """Test that ConfigurationService initializes correctly.""" # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) @@ -46,9 +44,7 @@ def test_configuration_service_initialization( assert isinstance(service._loaded_settings, dict) -def test_configuration_service_get_setting( - mock_settings, mock_store, mock_audit_repo -): +def test_configuration_service_get_setting(mock_settings, mock_store, mock_audit_repo): """Test getting configuration settings.""" # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) @@ -65,15 +61,13 @@ def test_configuration_service_get_setting( assert result == "default" -def test_configuration_service_set_setting( - mock_settings, mock_store, mock_audit_repo -): +def test_configuration_service_set_setting(mock_settings, mock_store, mock_audit_repo): """Test setting configuration settings.""" # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) # Mock the repository - with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: + with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance @@ -90,9 +84,7 @@ def test_configuration_service_set_setting( mock_repo_instance.create_setting.assert_called_once() -def test_configuration_service_hot_reload_detection( - mock_settings, mock_store, mock_audit_repo -): +def test_configuration_service_hot_reload_detection(mock_settings, mock_store, mock_audit_repo): """Test hot-reload detection functionality.""" # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) @@ -101,27 +93,26 @@ def test_configuration_service_hot_reload_detection( assert service.is_config_outdated() is False # Test with mock repository that returns a timestamp - with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: + with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance # Mock the latest updated at timestamp from datetime import datetime + mock_repo_instance.get_latest_updated_at.return_value = datetime.now() # Should detect as outdated when timestamp exists assert service.is_config_outdated() is True -def test_configuration_service_reload_if_changed( - mock_settings, mock_store, mock_audit_repo -): +def test_configuration_service_reload_if_changed(mock_settings, mock_store, mock_audit_repo): """Test hot-reload functionality.""" # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) # Mock the repository - with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: + with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance @@ -135,6 +126,7 @@ def test_configuration_service_reload_if_changed( # Mock the latest updated at timestamp to return a value from datetime import datetime + mock_repo_instance.get_latest_updated_at.return_value = datetime.now() # Should reload when outdated @@ -143,9 +135,7 @@ def test_configuration_service_reload_if_changed( assert service.get_config_version() == 1 -def test_configuration_service_get_config_version( - mock_settings, mock_store, mock_audit_repo -): +def test_configuration_service_get_config_version(mock_settings, mock_store, mock_audit_repo): """Test getting configuration version.""" # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) @@ -154,7 +144,7 @@ def test_configuration_service_get_config_version( assert service.get_config_version() == 0 # After setting a value, version should increment - with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: + with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance @@ -166,9 +156,7 @@ def test_configuration_service_get_config_version( assert service.get_config_version() == 1 -def test_configuration_service_get_last_updated_at( - mock_settings, mock_store, mock_audit_repo -): +def test_configuration_service_get_last_updated_at(mock_settings, mock_store, mock_audit_repo): """Test getting last updated timestamp.""" # Create service instance service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) @@ -177,7 +165,7 @@ def test_configuration_service_get_last_updated_at( assert service.get_last_updated_at() is None # After setting a value, should have timestamp - with patch('arbitrade.config.service.ConfigSettingRepository') as mock_repo_class: + with patch("arbitrade.config.service.ConfigSettingRepository") as mock_repo_class: mock_repo_instance = Mock() mock_repo_class.return_value = mock_repo_instance diff --git a/tests/unit/test_template_resolution.py b/tests/unit/test_template_resolution.py index 26cca3f..568b3a0 100644 --- a/tests/unit/test_template_resolution.py +++ b/tests/unit/test_template_resolution.py @@ -12,8 +12,6 @@ def test_template_directory_resolves_to_existing_location() -> None: def test_template_exists_in_package_resources() -> None: - template_path = resources.files("arbitrade").joinpath( - "web", "templates", "dashboard.html" - ) + template_path = resources.files("arbitrade").joinpath("web", "templates", "dashboard.html") assert template_path.is_file()