From dc99f1604e42c2a16eb70bcd59d8ead03b63abde Mon Sep 17 00:00:00 2001 From: zwitschi Date: Sun, 7 Jun 2026 21:59:09 +0200 Subject: [PATCH] Refactor code for improved readability and consistency - Cleaned up multiline statements and removed unnecessary line breaks in various files. - Ensured consistent formatting in function definitions and calls across the codebase. - Updated docstrings and comments for clarity where applicable. - Removed trailing newlines in module docstrings. - Enhanced logging statements for better clarity in maintenance tasks. --- scripts/benchmark_metrics_compute.py | 3 +- src/arbitrade/api/routes.py | 76 +++++------ src/arbitrade/config/settings.py | 114 ++++++----------- src/arbitrade/exchange/fee_service.py | 21 +-- src/arbitrade/execution/sequencer.py | 21 +-- src/arbitrade/logging/__init__.py | 2 +- src/arbitrade/logging/db_sink.py | 14 +- src/arbitrade/logging/maintenance.py | 3 +- src/arbitrade/market_data/feed.py | 15 +-- src/arbitrade/market_data/order_book.py | 3 +- src/arbitrade/metrics.py | 63 +++++---- src/arbitrade/runtime/lifecycle.py | 4 +- src/arbitrade/storage/executions.py | 3 +- src/arbitrade/storage/market_snapshots.py | 9 +- src/arbitrade/storage/opportunities.py | 6 +- src/arbitrade/storage/pg_store.py | 4 +- src/arbitrade/storage/repositories.py | 112 ++++++++++------ tests/integration/__init__.py | 2 +- tests/integration/conftest.py | 8 +- tests/integration/test_metrics.py | 70 ++++++++-- tests/integration/test_postgresql_schema.py | 134 +++++++++++++++----- tests/unit/test_config_e2e.py | 3 +- tests/unit/test_config_repositories.py | 10 +- tests/unit/test_config_service.py | 23 ++-- tests/unit/test_runtime_lifecycle.py | 10 +- 25 files changed, 409 insertions(+), 324 deletions(-) diff --git a/scripts/benchmark_metrics_compute.py b/scripts/benchmark_metrics_compute.py index e5535f6..45338a8 100644 --- a/scripts/benchmark_metrics_compute.py +++ b/scripts/benchmark_metrics_compute.py @@ -67,8 +67,7 @@ async def _seed_dataset(store: PgStore) -> None: opportunity_rows: list[tuple[object, ...]] = [] for i in range(5000): detected_at = now + timedelta(milliseconds=200 * i) - opportunity_rows.append( - (detected_at, "USD->BTC->ETH->USD", 2.5, 1.2, 0.03, bool(i % 2))) + opportunity_rows.append((detected_at, "USD->BTC->ETH->USD", 2.5, 1.2, 0.03, bool(i % 2))) order_rows: list[tuple[object, ...]] = [] for i in range(3500): diff --git a/src/arbitrade/api/routes.py b/src/arbitrade/api/routes.py index facd69b..0a95d68 100644 --- a/src/arbitrade/api/routes.py +++ b/src/arbitrade/api/routes.py @@ -36,8 +36,7 @@ public_router = APIRouter() def _resolve_templates_directory() -> str: # Support source layout, Docker runtime (/app), and installed package data. - source_layout_path = Path( - __file__).resolve().parents[3] / "web" / "templates" + source_layout_path = Path(__file__).resolve().parents[3] / "web" / "templates" if source_layout_path.is_dir(): return str(source_layout_path) @@ -46,8 +45,7 @@ def _resolve_templates_directory() -> str: return str(docker_runtime_path) try: - package_path = resources.files( - "arbitrade").joinpath("web", "templates") + package_path = resources.files("arbitrade").joinpath("web", "templates") if package_path.is_dir(): return str(package_path) except (ModuleNotFoundError, AttributeError): @@ -153,12 +151,22 @@ async def _dashboard_overview(request: Request) -> dict[str, object]: LIMIT 1 """) if acct_row is not None: - fee_tier = str( - acct_row["fee_tier"]) if acct_row["fee_tier"] is not None else "—" - maker_fee = f"{float(acct_row['maker_fee']):.4%}" if acct_row["maker_fee"] is not None else "—" - taker_fee = f"{float(acct_row['taker_fee']):.4%}" if acct_row["taker_fee"] is not None else "—" - thirty_day_volume = f"{float(acct_row['thirty_day_volume']):.2f}" if acct_row[ - "thirty_day_volume"] is not None else "—" + fee_tier = str(acct_row["fee_tier"]) if acct_row["fee_tier"] is not None else "—" + maker_fee = ( + f"{float(acct_row['maker_fee']):.4%}" + if acct_row["maker_fee"] is not None + else "—" + ) + taker_fee = ( + f"{float(acct_row['taker_fee']):.4%}" + if acct_row["taker_fee"] is not None + else "—" + ) + thirty_day_volume = ( + f"{float(acct_row['thirty_day_volume']):.2f}" + if acct_row["thirty_day_volume"] is not None + else "—" + ) except Exception: pass @@ -171,8 +179,7 @@ async def _dashboard_overview(request: Request) -> dict[str, object]: try: parsed = json.loads(balances_raw) if isinstance(parsed, dict): - non_zero = {k: float(v) - for k, v in parsed.items() if float(v) > 0.0} + non_zero = {k: float(v) for k, v in parsed.items() if float(v) > 0.0} if non_zero: balances_value = "
".join( f"{v:.6g} {k}" for k, v in sorted(non_zero.items()) @@ -192,7 +199,9 @@ async def _dashboard_overview(request: Request) -> dict[str, object]: { "trade_ref": str(r["trade_ref"]), "status": str(r["status"]), - "started_at": r["started_at"].isoformat() if isinstance(r["started_at"], datetime) else "—", + "started_at": ( + r["started_at"].isoformat() if isinstance(r["started_at"], datetime) else "—" + ), "cycle": str(r["cycle"]) if r["cycle"] is not None else "—", } for r in open_trades @@ -201,8 +210,12 @@ async def _dashboard_overview(request: Request) -> dict[str, object]: { "cycle": str(r["cycle"]), "net_pct": f"{float(r['net_pct']):.2f}%" if r["net_pct"] is not None else "—", - "est_profit": f"{float(r['est_profit']):.2f} USD" if r["est_profit"] is not None else "—", - "detected_at": r["detected_at"].isoformat() if isinstance(r["detected_at"], datetime) else "—", + "est_profit": ( + f"{float(r['est_profit']):.2f} USD" if r["est_profit"] is not None else "—" + ), + "detected_at": ( + r["detected_at"].isoformat() if isinstance(r["detected_at"], datetime) else "—" + ), } for r in latest_opportunities ] @@ -242,10 +255,8 @@ async def _dashboard_charts(request: Request) -> dict[str, object]: labels.append(row["detected_at"].isoformat()) else: labels.append(f"opportunity-{index + 1}") - np = [float(row["net_pct"]) if row["net_pct"] - is not None else 0.0 for row in cr] - ep = [float(row["est_profit"]) if row["est_profit"] - is not None else 0.0 for row in cr] + np = [float(row["net_pct"]) if row["net_pct"] is not None else 0.0 for row in cr] + ep = [float(row["est_profit"]) if row["est_profit"] is not None else 0.0 for row in cr] cycles = [str(row["cycle"]) for row in cr] return { @@ -411,8 +422,7 @@ async def _dashboard_config_context(request: Request) -> dict[str, object]: max_consecutive_failures_value = ( str(rs.max_consecutive_failures) if rs.max_consecutive_failures is not None else "" ) - strategy_stat_arb_enabled = bool( - getattr(rs, "strategy_enable_stat_arb_experiment", False)) + strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False)) return { # Runtime @@ -533,8 +543,7 @@ def _dashboard_controls(request: Request) -> dict[str, object]: alerts_last_channel_results = [ str(item) for item in cast(list[object], alert_status.get("last_channel_results", [])) ] - strategy_stat_arb_enabled = bool( - getattr(rs, "strategy_enable_stat_arb_experiment", False)) + strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False)) return { "execution_status": "running" if ctl.is_running else "stopped", @@ -813,7 +822,6 @@ async def dashboard_backtesting_page(request: Request) -> HTMLResponse: @router.get("/dashboard/fragment/backtesting", response_class=HTMLResponse) async def dashboard_backtesting_fragment(request: Request) -> HTMLResponse: - d_context = await _dashboard_config_context(request) ctx = await _backtesting_panel_context(request) ctx["flash_message"] = "" # Check if any pairings are enabled @@ -992,19 +1000,18 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse: try: custom_fee_rate = ( - float(defaults["custom_fee_rate"] - ) if defaults["custom_fee_rate"].strip() else None + float(defaults["custom_fee_rate"]) if defaults["custom_fee_rate"].strip() else None + ) + fee_rate = await _fee_rate_for_profile( + defaults["fee_profile"], custom_fee_rate, request=request ) - fee_rate = await _fee_rate_for_profile(defaults["fee_profile"], custom_fee_rate, request=request) # Use enabled pairings from DB when none selected symbols_str = defaults["symbols"] if not symbols_str.strip(): pairing_repo = ConfigPairingRepository(request.app.state.store) enabled = await pairing_repo.list_pairings(enabled_only=True) - symbols_str = ",".join( - f"{p.base_asset}/{p.quote_asset}" for p in enabled - ) + symbols_str = ",".join(f"{p.base_asset}/{p.quote_asset}" for p in enabled) config_dict: dict[str, object] = { "source": defaults["source"], @@ -1133,8 +1140,7 @@ async def dashboard_backtesting_export(request: Request, job_id: str) -> Respons return Response( content=orjson.dumps(payload).decode("utf-8"), media_type="application/x-jsonlines", - headers={ - "Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"}, + headers={"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"}, ) @@ -1383,11 +1389,9 @@ async def dashboard_api_pairings( if source: pairings = [p for p in pairings if p.source.lower() == source.lower()] if base: - pairings = [p for p in pairings if p.base_asset.lower() == - base.lower()] + pairings = [p for p in pairings if p.base_asset.lower() == base.lower()] if quote: - pairings = [p for p in pairings if p.quote_asset.lower() == - quote.lower()] + pairings = [p for p in pairings if p.quote_asset.lower() == quote.lower()] # Sort reverse = order.lower() == "desc" diff --git a/src/arbitrade/config/settings.py b/src/arbitrade/config/settings.py index bf7feec..8426a94 100644 --- a/src/arbitrade/config/settings.py +++ b/src/arbitrade/config/settings.py @@ -31,41 +31,26 @@ class Settings(BaseSettings): ) alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED") - alert_min_severity: str = Field( - default="warning", alias="ALERT_MIN_SEVERITY") - alert_dedup_seconds: float = Field( - default=30.0, alias="ALERT_DEDUP_SECONDS") - alert_on_trade_events: bool = Field( - default=True, alias="ALERT_ON_TRADE_EVENTS") - alert_on_error_events: bool = Field( - default=True, alias="ALERT_ON_ERROR_EVENTS") - alert_on_threshold_events: bool = Field( - default=True, alias="ALERT_ON_THRESHOLD_EVENTS") - alert_on_system_events: bool = Field( - default=True, alias="ALERT_ON_SYSTEM_EVENTS") + alert_min_severity: str = Field(default="warning", alias="ALERT_MIN_SEVERITY") + alert_dedup_seconds: float = Field(default=30.0, alias="ALERT_DEDUP_SECONDS") + alert_on_trade_events: bool = Field(default=True, alias="ALERT_ON_TRADE_EVENTS") + alert_on_error_events: bool = Field(default=True, alias="ALERT_ON_ERROR_EVENTS") + alert_on_threshold_events: bool = Field(default=True, alias="ALERT_ON_THRESHOLD_EVENTS") + alert_on_system_events: bool = Field(default=True, alias="ALERT_ON_SYSTEM_EVENTS") - telegram_alerts_enabled: bool = Field( - default=False, alias="TELEGRAM_ALERTS_ENABLED") - telegram_bot_token: str | None = Field( - default=None, alias="TELEGRAM_BOT_TOKEN") - telegram_chat_id: str | None = Field( - default=None, alias="TELEGRAM_CHAT_ID") + telegram_alerts_enabled: bool = Field(default=False, alias="TELEGRAM_ALERTS_ENABLED") + telegram_bot_token: str | None = Field(default=None, alias="TELEGRAM_BOT_TOKEN") + telegram_chat_id: str | None = Field(default=None, alias="TELEGRAM_CHAT_ID") - discord_alerts_enabled: bool = Field( - default=False, alias="DISCORD_ALERTS_ENABLED") - discord_webhook_url: str | None = Field( - default=None, alias="DISCORD_WEBHOOK_URL") + discord_alerts_enabled: bool = Field(default=False, alias="DISCORD_ALERTS_ENABLED") + discord_webhook_url: str | None = Field(default=None, alias="DISCORD_WEBHOOK_URL") - email_alerts_enabled: bool = Field( - default=False, alias="EMAIL_ALERTS_ENABLED") + email_alerts_enabled: bool = Field(default=False, alias="EMAIL_ALERTS_ENABLED") email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST") email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT") - email_smtp_username: str | None = Field( - default=None, alias="EMAIL_SMTP_USERNAME") - email_smtp_password: str | None = Field( - default=None, alias="EMAIL_SMTP_PASSWORD") - email_alert_from: str | None = Field( - default=None, alias="EMAIL_ALERT_FROM") + email_smtp_username: str | None = Field(default=None, alias="EMAIL_SMTP_USERNAME") + email_smtp_password: str | None = Field(default=None, alias="EMAIL_SMTP_PASSWORD") + email_alert_from: str | None = Field(default=None, alias="EMAIL_ALERT_FROM") email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO") email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS") @@ -78,31 +63,24 @@ class Settings(BaseSettings): pg_min_connections: int = Field(default=2, alias="PG_MIN_CONNECTIONS") pg_max_connections: int = Field(default=10, alias="PG_MAX_CONNECTIONS") - kraken_rest_url: str = Field( - default="https://api.kraken.com", alias="KRAKEN_REST_URL") - kraken_ws_url: str = Field( - default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") + kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL") + kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL") kraken_private_rate_limit_seconds: float = Field( default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" ) - kraken_http_timeout_seconds: float = Field( - default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") - kraken_retry_attempts: int = Field( - default=3, alias="KRAKEN_RETRY_ATTEMPTS") + kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") + kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS") kraken_retry_base_delay_seconds: float = Field( default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" ) kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") - kraken_api_secret: str | None = Field( - default=None, alias="KRAKEN_API_SECRET") + kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET") kraken_api_key_permissions: str = Field( default="query,trade", alias="KRAKEN_API_KEY_PERMISSIONS", ) - ws_heartbeat_timeout_seconds: float = Field( - default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") - ws_max_staleness_seconds: float = Field( - default=5.0, alias="WS_MAX_STALENESS_SECONDS") + ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") + ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS") strategy_enable_stat_arb_experiment: bool = Field( default=False, alias="STRATEGY_ENABLE_STAT_ARB_EXPERIMENT", @@ -125,29 +103,20 @@ class Settings(BaseSettings): ) paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD") - max_trade_capital_usd: float = Field( - default=100.0, alias="MAX_TRADE_CAPITAL_USD") - max_concurrent_trades: int | None = Field( - default=None, alias="MAX_CONCURRENT_TRADES") + max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD") + max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES") max_exposure_per_asset_usd: float | None = Field( default=None, alias="MAX_EXPOSURE_PER_ASSET_USD", ) - quote_balance_asset: str = Field( - default="USD", alias="QUOTE_BALANCE_ASSET") - min_order_size_usd: float | None = Field( - default=None, alias="MIN_ORDER_SIZE_USD") + quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET") + min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD") kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE") - daily_loss_limit_usd: float | None = Field( - default=None, alias="DAILY_LOSS_LIMIT_USD") - cumulative_loss_limit_usd: float | None = Field( - default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") - max_source_latency_ms: float | None = Field( - default=None, alias="MAX_SOURCE_LATENCY_MS") - max_apply_latency_ms: float | None = Field( - default=None, alias="MAX_APPLY_LATENCY_MS") - max_consecutive_failures: int | None = Field( - default=None, alias="MAX_CONSECUTIVE_FAILURES") + daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD") + cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") + max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS") + max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS") + max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES") fernet_key: str | None = Field(default=None, alias="FERNET_KEY") @@ -164,8 +133,7 @@ class Settings(BaseSettings): def _validate_log_level(cls, value: str) -> str: normalized = value.strip().upper() if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}: - raise ValueError( - "LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL") + raise ValueError("LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL") return normalized @field_validator("alert_min_severity") @@ -173,19 +141,16 @@ class Settings(BaseSettings): def _validate_alert_severity(cls, value: str) -> str: normalized = value.strip().lower() if normalized not in {"info", "warning", "error", "critical"}: - raise ValueError( - "ALERT_MIN_SEVERITY must be one of: info, warning, error, critical") + raise ValueError("ALERT_MIN_SEVERITY must be one of: info, warning, error, critical") return normalized @model_validator(mode="after") def _validate_security_constraints(self) -> Settings: if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password): - raise ValueError( - "dashboard auth requires both username and password") + raise ValueError("dashboard auth requires both username and password") if bool(self.kraken_api_key) ^ bool(self.kraken_api_secret): - raise ValueError( - "Kraken API auth requires both API key and secret") + raise ValueError("Kraken API auth requires both API key and secret") permissions = { token.strip().lower() @@ -193,11 +158,9 @@ class Settings(BaseSettings): if token.strip() } if permissions and ("query" not in permissions or "trade" not in permissions): - raise ValueError( - "KRAKEN_API_KEY_PERMISSIONS must include query and trade") + raise ValueError("KRAKEN_API_KEY_PERMISSIONS must include query and trade") if "withdraw" in permissions or "withdrawals" in permissions: - raise ValueError( - "KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope") + raise ValueError("KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope") if self.alert_dedup_seconds < 0.0: raise ValueError("ALERT_DEDUP_SECONDS must be >= 0") @@ -213,8 +176,7 @@ class Settings(BaseSettings): "STRATEGY_STAT_ARB_ENTRY_ZSCORE must be greater than STRATEGY_STAT_ARB_EXIT_ZSCORE" ) if self.strategy_stat_arb_max_holding_seconds <= 0.0: - raise ValueError( - "STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0") + raise ValueError("STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0") return self diff --git a/src/arbitrade/exchange/fee_service.py b/src/arbitrade/exchange/fee_service.py index e76b093..b548c83 100644 --- a/src/arbitrade/exchange/fee_service.py +++ b/src/arbitrade/exchange/fee_service.py @@ -42,12 +42,9 @@ async def fetch_and_store_account_snapshot( _LOG.exception("trade_balance_fetch_failed") return None - fee_tier = volume_data.get("fee_tier") if isinstance( - volume_data, dict) else None - fees_dict = volume_data.get("fees") if isinstance( - volume_data, dict) else None - fees_maker = volume_data.get("fees_maker") if isinstance( - volume_data, dict) else None + fee_tier = volume_data.get("fee_tier") if isinstance(volume_data, dict) else None + fees_dict = volume_data.get("fees") if isinstance(volume_data, dict) else None + fees_maker = volume_data.get("fees_maker") if isinstance(volume_data, dict) else None currency = volume_data.get("currency") thirty_day_volume_str = volume_data.get("volume") @@ -73,8 +70,7 @@ async def fetch_and_store_account_snapshot( if currency is not None: fee_schedule["currency"] = currency - thirty_day_volume = float( - thirty_day_volume_str) if thirty_day_volume_str is not None else None + thirty_day_volume = float(thirty_day_volume_str) if thirty_day_volume_str is not None else None snapshot = KrakenAccountSnapshot( snapshot_at=datetime.now(UTC), @@ -82,8 +78,7 @@ async def fetch_and_store_account_snapshot( maker_fee=maker_fee, taker_fee=taker_fee, thirty_day_volume=thirty_day_volume, - trade_balance_raw=balance_data if isinstance( - balance_data, dict) else None, + trade_balance_raw=balance_data if isinstance(balance_data, dict) else None, fee_schedule_raw=fee_schedule if fee_schedule else None, ) @@ -107,8 +102,7 @@ async def fetch_and_store_account_snapshot( "INSERT INTO portfolio_snapshots" " (snapshot_at, balances, total_value_usd) VALUES ($1, $2, $3)", datetime.now(UTC), - orjson.dumps(wallet_balances).decode( - "utf-8") if wallet_balances else None, + orjson.dumps(wallet_balances).decode("utf-8") if wallet_balances else None, total_value, ) _LOG.info("portfolio_snapshot_stored", total_value_usd=total_value) @@ -127,8 +121,7 @@ async def run_fee_sync_loop( Runs until stop_event is set. """ - _LOG.info("fee_sync_loop_started", - interval_s=_FEE_REFRESH_INTERVAL_SECONDS) + _LOG.info("fee_sync_loop_started", interval_s=_FEE_REFRESH_INTERVAL_SECONDS) while not stop_event.is_set(): try: diff --git a/src/arbitrade/execution/sequencer.py b/src/arbitrade/execution/sequencer.py index 9d9fc2f..0dd72b2 100644 --- a/src/arbitrade/execution/sequencer.py +++ b/src/arbitrade/execution/sequencer.py @@ -47,15 +47,13 @@ class TriangularExecutionSequencer: rest_client: SupportsOrderPlacement, *, available_pairs: Sequence[str], - volume_for_leg: Callable[[OpportunityEvent, - ExecutionLeg, int], float] | None = None, + volume_for_leg: Callable[[OpportunityEvent, ExecutionLeg, int], float] | None = None, execution_writer: AsyncExecutionWriter | None = None, alert_notifier: SupportsAlerts | None = None, audit_repository: AuditRepository | None = None, ) -> None: self._rest_client = rest_client - self._available_pairs = {self._normalize_pair( - pair) for pair in available_pairs} + self._available_pairs = {self._normalize_pair(pair) for pair in available_pairs} self._volume_for_leg = volume_for_leg or self._default_volume_for_leg self._execution_writer = execution_writer self._alert_notifier = alert_notifier @@ -102,15 +100,12 @@ class TriangularExecutionSequencer: raise ValueError(f"No tradable pair for leg {from_cur}->{to_cur}") def _build_legs(self, event: OpportunityEvent) -> tuple[ExecutionLeg, ...]: - currencies = [part.strip().upper() - for part in event.cycle.split("->") if part.strip()] + currencies = [part.strip().upper() for part in event.cycle.split("->") if part.strip()] if len(currencies) < 4 or currencies[0] != currencies[-1]: - raise ValueError( - "cycle must be a closed triangular path like A->B->C->A") + raise ValueError("cycle must be a closed triangular path like A->B->C->A") if len(currencies) != 4: - raise ValueError( - "cycle must contain exactly three unique currencies") + raise ValueError("cycle must contain exactly three unique currencies") legs: list[ExecutionLeg] = [] for idx in range(3): @@ -125,8 +120,7 @@ class TriangularExecutionSequencer: ) volume = self._volume_for_leg(event, placeholder_leg, idx) if volume <= 0.0: - raise ValueError( - "volume_for_leg must return a positive volume") + raise ValueError("volume_for_leg must return a positive volume") legs.append(self._resolve_leg(from_currency, to_currency, volume)) return tuple(legs) @@ -215,8 +209,7 @@ class TriangularExecutionSequencer: responses.append(response) if self._execution_writer is not None: - order_ref = self._order_ref_from_response( - response, f"leg-{idx}") + order_ref = self._order_ref_from_response(response, f"leg-{idx}") await self._execution_writer.enqueue( OrderRecord( trade_ref=trade_ref, diff --git a/src/arbitrade/logging/__init__.py b/src/arbitrade/logging/__init__.py index d781b63..5d7da6c 100644 --- a/src/arbitrade/logging/__init__.py +++ b/src/arbitrade/logging/__init__.py @@ -1 +1 @@ -"""Logging package — DB sink, maintenance tasks.""" \ No newline at end of file +"""Logging package — DB sink, maintenance tasks.""" diff --git a/src/arbitrade/logging/db_sink.py b/src/arbitrade/logging/db_sink.py index 5d7a251..0889993 100644 --- a/src/arbitrade/logging/db_sink.py +++ b/src/arbitrade/logging/db_sink.py @@ -22,8 +22,7 @@ class DbSinkProcessor: """ def __init__(self) -> None: - self._queue: asyncio.Queue[dict[str, Any] - ] = asyncio.Queue(maxsize=2000) + self._queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=2000) self._consumer_task: asyncio.Task[None] | None = None def __call__(self, logger: Any, method_name: str, event_dict: dict[str, Any]) -> dict[str, Any]: @@ -38,9 +37,7 @@ class DbSinkProcessor: """Start background consumer task.""" if self._consumer_task is not None and not self._consumer_task.done(): return - self._consumer_task = asyncio.create_task( - self._consume(store), name="log_db_sink" - ) + self._consumer_task = asyncio.create_task(self._consume(store), name="log_db_sink") async def stop_consumer(self) -> None: """Drain queue and cancel consumer.""" @@ -81,8 +78,7 @@ class DbSinkProcessor: level = str(event.pop("level", "info")).upper() logger = str(event.pop("logger", "root")) message = str(event.pop("event", event.pop("message", ""))) - context = {k: v for k, v in event.items( - ) if not k.startswith("_")} if event else None + context = {k: v for k, v in event.items() if not k.startswith("_")} if event else None record = LogRecord( recorded_at=recorded_at, @@ -118,8 +114,6 @@ def get_db_sink() -> DbSinkProcessor: return _db_sink -def db_sink_processor( - logger: Any, method_name: str, event_dict: dict[str, Any] -) -> dict[str, Any]: +def db_sink_processor(logger: Any, method_name: str, event_dict: dict[str, Any]) -> dict[str, Any]: """Standalone processor function wrapping the singleton.""" return _db_sink(logger, method_name, event_dict) diff --git a/src/arbitrade/logging/maintenance.py b/src/arbitrade/logging/maintenance.py index 6ef3726..1a1a89f 100644 --- a/src/arbitrade/logging/maintenance.py +++ b/src/arbitrade/logging/maintenance.py @@ -36,8 +36,7 @@ async def run_log_archive(store: PgStore, retention_days: int = _RETENTION_DAYS) repo = LogArchiveRepository(store) count = await repo.archive_before(cutoff) if count > 0: - _LOG.info("log_archive_complete", - cutoff=cutoff.isoformat(), archived=count) + _LOG.info("log_archive_complete", cutoff=cutoff.isoformat(), archived=count) return count diff --git a/src/arbitrade/market_data/feed.py b/src/arbitrade/market_data/feed.py index 9cd5f66..8fdca89 100644 --- a/src/arbitrade/market_data/feed.py +++ b/src/arbitrade/market_data/feed.py @@ -38,8 +38,7 @@ class MarketDataFeed: opportunity_writer: AsyncOpportunityWriter | None = None, paper_trading_mode: bool = True, opportunity_executor: ( - Callable[[OpportunityEvent], - Awaitable[ExecutionOutcome | float | None]] | None + Callable[[OpportunityEvent], Awaitable[ExecutionOutcome | float | None]] | None ) = None, trade_capital: float = 1.0, max_trade_capital: float | None = None, @@ -93,8 +92,7 @@ class MarketDataFeed: return {} start = currencies[0] - exposure_assets = { - currency for currency in currencies[1:] if currency != start} + exposure_assets = {currency for currency in currencies[1:] if currency != start} return {asset: event.allocated_capital for asset in exposure_assets} async def run(self) -> None: @@ -315,8 +313,7 @@ class MarketDataFeed: continue if self._pre_trade_validator is not None and self._balance_provider is not None: - required_balances = { - self._quote_balance_asset: event.allocated_capital} + required_balances = {self._quote_balance_asset: event.allocated_capital} balances = { asset.upper(): amount for asset, amount in self._balance_provider().items() @@ -384,8 +381,7 @@ class MarketDataFeed: outcome = await self._opportunity_executor(event) except Exception as exc: if self._trade_limits_guard is not None: - self._trade_limits_guard.close_trade( - exposure_by_asset) + self._trade_limits_guard.close_trade(exposure_by_asset) dispatch_alert_nowait( self._alert_notifier, @@ -451,8 +447,7 @@ class MarketDataFeed: realized_pnl = outcome if realized_pnl is not None and self._loss_limit_guard is not None: - self._loss_limit_guard.register_realized_pnl( - realized_pnl) + self._loss_limit_guard.register_realized_pnl(realized_pnl) if self._loss_limit_guard.is_halted: _LOG.warning( "loss_limit_halt_triggered", diff --git a/src/arbitrade/market_data/order_book.py b/src/arbitrade/market_data/order_book.py index a95803b..a4ba86a 100644 --- a/src/arbitrade/market_data/order_book.py +++ b/src/arbitrade/market_data/order_book.py @@ -86,8 +86,7 @@ class OrderBook: BookLevel(price=price, volume=self._bids[price]) for price in reversed(bid_keys[-depth:]) ] - asks = [BookLevel(price=price, volume=self._asks[price]) - for price in ask_keys[:depth]] + asks = [BookLevel(price=price, volume=self._asks[price]) for price in ask_keys[:depth]] return bids, asks def compute_checksum(self, depth: int = 10) -> int: diff --git a/src/arbitrade/metrics.py b/src/arbitrade/metrics.py index ca39d97..f143dff 100644 --- a/src/arbitrade/metrics.py +++ b/src/arbitrade/metrics.py @@ -51,23 +51,34 @@ class MetricsCalculator: WHERE volume > 0 AND filled_volume IS NOT NULL """) - r_pnl_usd = float( - tm["realized_pnl_usd"]) if tm and tm["realized_pnl_usd"] is not None else 0.0 - tt = int(tm["total_trades"] - ) if tm and tm["total_trades"] is not None else 0 - wt = int(tm["winning_trades"] - ) if tm and tm["winning_trades"] is not None else 0 + r_pnl_usd = ( + float(tm["realized_pnl_usd"]) if tm and tm["realized_pnl_usd"] is not None else 0.0 + ) + tt = int(tm["total_trades"]) if tm and tm["total_trades"] is not None else 0 + wt = int(tm["winning_trades"]) if tm and tm["winning_trades"] is not None else 0 wr = wt / tt if tt > 0 else None - atd = float(tm["avg_trade_duration_seconds"] - ) if tm and tm["avg_trade_duration_seconds"] is not None else None + atd = ( + float(tm["avg_trade_duration_seconds"]) + if tm and tm["avg_trade_duration_seconds"] is not None + else None + ) - oc = int(om["opportunity_count"] - ) if om is not None and om["opportunity_count"] is not None else 0 - fo = om["first_detected_at"] if om is not None and isinstance( - om["first_detected_at"], datetime) else None - lo = om["last_detected_at"] if om is not None and isinstance( - om["last_detected_at"], datetime) else None + oc = ( + int(om["opportunity_count"]) + if om is not None and om["opportunity_count"] is not None + else 0 + ) + fo = ( + om["first_detected_at"] + if om is not None and isinstance(om["first_detected_at"], datetime) + else None + ) + lo = ( + om["last_detected_at"] + if om is not None and isinstance(om["last_detected_at"], datetime) + else None + ) opportunities_per_minute: float | None if oc >= 2 and fo is not None and lo is not None: @@ -80,15 +91,23 @@ class MetricsCalculator: else: opportunities_per_minute = None - fill_rate = float( - fm["fill_rate"]) if fm and fm["fill_rate"] is not None else None + fill_rate = float(fm["fill_rate"]) if fm and fm["fill_rate"] is not None else None - lp50 = float(tm["latency_p50_seconds"] - ) if tm and tm["latency_p50_seconds"] is not None else None - lp95 = float(tm["latency_p95_seconds"] - ) if tm and tm["latency_p95_seconds"] is not None else None - lp99 = float(tm["latency_p99_seconds"] - ) if tm and tm["latency_p99_seconds"] is not None else None + lp50 = ( + float(tm["latency_p50_seconds"]) + if tm and tm["latency_p50_seconds"] is not None + else None + ) + lp95 = ( + float(tm["latency_p95_seconds"]) + if tm and tm["latency_p95_seconds"] is not None + else None + ) + lp99 = ( + float(tm["latency_p99_seconds"]) + if tm and tm["latency_p99_seconds"] is not None + else None + ) return PerformanceMetrics( realized_pnl_usd=r_pnl_usd, diff --git a/src/arbitrade/runtime/lifecycle.py b/src/arbitrade/runtime/lifecycle.py index 0f6738b..718fd6a 100644 --- a/src/arbitrade/runtime/lifecycle.py +++ b/src/arbitrade/runtime/lifecycle.py @@ -106,7 +106,9 @@ async def _run_startup_reconciler(app: FastAPI) -> None: await result -async def persist_runtime_snapshot(app: FastAPI, *, note: str | None = None) -> RuntimeStateRecord | None: +async def persist_runtime_snapshot( + app: FastAPI, *, note: str | None = None +) -> RuntimeStateRecord | None: repository = _runtime_repository(app) if repository is None: return None diff --git a/src/arbitrade/storage/executions.py b/src/arbitrade/storage/executions.py index 502ab2d..ce24832 100644 --- a/src/arbitrade/storage/executions.py +++ b/src/arbitrade/storage/executions.py @@ -36,8 +36,7 @@ class AsyncExecutionWriter: async def start(self) -> None: if self._task is None or self._task.done(): self._stop.clear() - self._task = asyncio.create_task( - self._run(), name="execution-writer") + self._task = asyncio.create_task(self._run(), name="execution-writer") async def stop(self) -> None: self._stop.set() diff --git a/src/arbitrade/storage/market_snapshots.py b/src/arbitrade/storage/market_snapshots.py index c3e4493..a08f6ac 100644 --- a/src/arbitrade/storage/market_snapshots.py +++ b/src/arbitrade/storage/market_snapshots.py @@ -24,16 +24,14 @@ class MarketSnapshot: class AsyncMarketSnapshotWriter: def __init__(self, repository: MarketSnapshotRepository, max_queue_size: int = 50_000) -> None: self._repository = repository - self._queue: asyncio.Queue[MarketSnapshot] = asyncio.Queue( - maxsize=max_queue_size) + self._queue: asyncio.Queue[MarketSnapshot] = asyncio.Queue(maxsize=max_queue_size) self._task: asyncio.Task[None] | None = None self._stop = asyncio.Event() async def start(self) -> None: if self._task is None or self._task.done(): self._stop.clear() - self._task = asyncio.create_task( - self._run(), name="market-snapshot-writer") + self._task = asyncio.create_task(self._run(), name="market-snapshot-writer") async def stop(self) -> None: self._stop.set() @@ -61,7 +59,6 @@ class AsyncMarketSnapshotWriter: ) ) except Exception as exc: - _LOG.error("market_snapshot_write_failed", - error=str(exc), symbol=item.symbol) + _LOG.error("market_snapshot_write_failed", error=str(exc), symbol=item.symbol) finally: self._queue.task_done() diff --git a/src/arbitrade/storage/opportunities.py b/src/arbitrade/storage/opportunities.py index 419520d..2fcba12 100644 --- a/src/arbitrade/storage/opportunities.py +++ b/src/arbitrade/storage/opportunities.py @@ -13,16 +13,14 @@ _LOG = structlog.get_logger(__name__) class AsyncOpportunityWriter: def __init__(self, repository: OpportunityRepository, max_queue_size: int = 50_000) -> None: self._repository = repository - self._queue: asyncio.Queue[OpportunityEvent] = asyncio.Queue( - maxsize=max_queue_size) + self._queue: asyncio.Queue[OpportunityEvent] = asyncio.Queue(maxsize=max_queue_size) self._task: asyncio.Task[None] | None = None self._stop = asyncio.Event() async def start(self) -> None: if self._task is None or self._task.done(): self._stop.clear() - self._task = asyncio.create_task( - self._run(), name="opportunity-writer") + self._task = asyncio.create_task(self._run(), name="opportunity-writer") async def stop(self) -> None: self._stop.set() diff --git a/src/arbitrade/storage/pg_store.py b/src/arbitrade/storage/pg_store.py index 58c45f7..a2b8b34 100644 --- a/src/arbitrade/storage/pg_store.py +++ b/src/arbitrade/storage/pg_store.py @@ -128,7 +128,5 @@ class PgStore: col_name = column_def.split()[0] if col_name not in existing: async with self.pool.acquire() as conn: - await conn.execute( - f"ALTER TABLE {table_name} ADD COLUMN {column_def}" - ) + await conn.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_def}") _LOG.info("pg_column_added", table=table_name, column=col_name) diff --git a/src/arbitrade/storage/repositories.py b/src/arbitrade/storage/repositories.py index f91310d..6ecee68 100644 --- a/src/arbitrade/storage/repositories.py +++ b/src/arbitrade/storage/repositories.py @@ -258,11 +258,7 @@ class AuditRepository: record.actor, record.event_type, record.decision, - ( - None - if record.payload is None - else orjson.dumps(record.payload).decode("utf-8") - ), + (None if record.payload is None else orjson.dumps(record.payload).decode("utf-8")), record.correlation_id, ) @@ -294,8 +290,9 @@ class AuditRepository: event_type=str(row["event_type"]), decision=str(row["decision"]), payload=payload, - correlation_id=str( - row["correlation_id"]) if row["correlation_id"] is not None else None, + correlation_id=( + str(row["correlation_id"]) if row["correlation_id"] is not None else None + ), ) ) @@ -364,8 +361,9 @@ class RuntimeStateRepository: snapshot_at=row["snapshot_at"], is_running=bool(row["is_running"]), kill_switch_active=bool(row["kill_switch_active"]), - kill_switch_reason=str( - row["kill_switch_reason"]) if row["kill_switch_reason"] is not None else None, + kill_switch_reason=( + str(row["kill_switch_reason"]) if row["kill_switch_reason"] is not None else None + ), open_trade_count=int(row["open_trade_count"]), last_known_balances=balances, note=str(row["note"]) if row["note"] is not None else None, @@ -386,10 +384,16 @@ class ConfigSectionRepository: VALUES ($1, $2) RETURNING id, name, description, updated_at """, - section.name, section.description, + section.name, + section.description, ) if row: - return ConfigSection(id=row["id"], name=row["name"], description=row["description"], updated_at=row["updated_at"]) + return ConfigSection( + id=row["id"], + name=row["name"], + description=row["description"], + updated_at=row["updated_at"], + ) raise ValueError("Failed to create section") async def get_section(self, name: str) -> ConfigSection | None: @@ -404,7 +408,12 @@ class ConfigSectionRepository: name, ) if row: - return ConfigSection(id=row["id"], name=row["name"], description=row["description"], updated_at=row["updated_at"]) + return ConfigSection( + id=row["id"], + name=row["name"], + description=row["description"], + updated_at=row["updated_at"], + ) return None async def list_sections(self) -> list[ConfigSection]: @@ -417,7 +426,11 @@ class ConfigSectionRepository: """) return [ ConfigSection( - id=r["id"], name=r["name"], description=r["description"], updated_at=r["updated_at"]) + id=r["id"], + name=r["name"], + description=r["description"], + updated_at=r["updated_at"], + ) for r in rows ] @@ -571,7 +584,7 @@ class ConfigSettingRepository: ts = row["latest_updated_at"] if isinstance(ts, str): return datetime.fromisoformat(ts.replace("Z", "+00:00")) - return ts # type: ignore[no-any-return] + return ts # type: ignore[no-any-return] return None @@ -614,7 +627,8 @@ class ConfigPairingRepository: FROM config_pairings WHERE base_asset = $1 AND quote_asset = $2 """, - base_asset, quote_asset, + base_asset, + quote_asset, ) if row: return ConfigPairing( @@ -665,7 +679,8 @@ class ConfigPairingRepository: DELETE FROM config_pairings WHERE base_asset = $1 AND quote_asset = $2 """, - base_asset, quote_asset, + base_asset, + quote_asset, ) if result is None: return False @@ -732,7 +747,9 @@ class ConfigBacktestingDefaultsRepository: def __init__(self, store: PgStore) -> None: self._store = store - async def create_defaults(self, defaults: ConfigBacktestingDefaults) -> ConfigBacktestingDefaults: + async def create_defaults( + self, defaults: ConfigBacktestingDefaults + ) -> ConfigBacktestingDefaults: """Create new backtesting defaults.""" async with self._store.pool.acquire() as conn: balances_json = ( @@ -754,8 +771,9 @@ class ConfigBacktestingDefaultsRepository: ) if row: return ConfigBacktestingDefaults( - starting_balances=orjson.loads( - row["starting_balances"]) if row["starting_balances"] else None, + starting_balances=( + orjson.loads(row["starting_balances"]) if row["starting_balances"] else None + ), trade_capital=row["trade_capital"], min_profit_threshold=row["min_profit_threshold"], slippage_bps=row["slippage_bps"], @@ -774,8 +792,9 @@ class ConfigBacktestingDefaultsRepository: """) if row: return ConfigBacktestingDefaults( - starting_balances=orjson.loads( - row["starting_balances"]) if row["starting_balances"] else None, + starting_balances=( + orjson.loads(row["starting_balances"]) if row["starting_balances"] else None + ), trade_capital=row["trade_capital"], min_profit_threshold=row["min_profit_threshold"], slippage_bps=row["slippage_bps"], @@ -783,7 +802,9 @@ class ConfigBacktestingDefaultsRepository: ) return None - async def update_defaults(self, defaults: ConfigBacktestingDefaults) -> ConfigBacktestingDefaults: + async def update_defaults( + self, defaults: ConfigBacktestingDefaults + ) -> ConfigBacktestingDefaults: """Update the backtesting defaults.""" async with self._store.pool.acquire() as conn: starting_balances_json = ( @@ -808,8 +829,9 @@ class ConfigBacktestingDefaultsRepository: ) if row: return ConfigBacktestingDefaults( - starting_balances=orjson.loads( - row["starting_balances"]) if row["starting_balances"] else None, + starting_balances=( + orjson.loads(row["starting_balances"]) if row["starting_balances"] else None + ), trade_capital=row["trade_capital"], min_profit_threshold=row["min_profit_threshold"], slippage_bps=row["slippage_bps"], @@ -878,10 +900,12 @@ class KrakenAccountSnapshotRepository: maker_fee=row["maker_fee"], taker_fee=row["taker_fee"], thirty_day_volume=row["thirty_day_volume"], - trade_balance_raw=orjson.loads( - row["trade_balance_raw"]) if row["trade_balance_raw"] else None, - fee_schedule_raw=orjson.loads( - row["fee_schedule_raw"]) if row["fee_schedule_raw"] else None, + trade_balance_raw=( + orjson.loads(row["trade_balance_raw"]) if row["trade_balance_raw"] else None + ), + fee_schedule_raw=( + orjson.loads(row["fee_schedule_raw"]) if row["fee_schedule_raw"] else None + ), ) @@ -916,7 +940,8 @@ class BacktestJobRepository: VALUES ($1, $2) RETURNING id, status, events_path, config, created_at """, - events_path, job_config_json, + events_path, + job_config_json, ) if row is None: raise ValueError("Failed to create backtest job") @@ -933,24 +958,30 @@ class BacktestJobRepository: if status == "running": await conn.execute( "UPDATE backtest_jobs SET status = $1, started_at = CURRENT_TIMESTAMP WHERE id = $2", - status, job_id, + status, + job_id, ) elif status in ("completed", "failed"): await conn.execute( "UPDATE backtest_jobs SET status = $1, finished_at = CURRENT_TIMESTAMP, error = $2 WHERE id = $3", - status, error, job_id, + status, + error, + job_id, ) else: await conn.execute( "UPDATE backtest_jobs SET status = $1, error = $2 WHERE id = $3", - status, error, job_id, + status, + error, + job_id, ) async def store_report(self, job_id: str, report: dict[str, Any]) -> None: async with self._store.pool.acquire() as conn: await conn.execute( "UPDATE backtest_jobs SET report = $1 WHERE id = $2", - orjson.dumps(report).decode("utf-8"), job_id, + orjson.dumps(report).decode("utf-8"), + job_id, ) async def get_job(self, job_id: str) -> BacktestJobRecord | None: @@ -1082,7 +1113,9 @@ class LogRepository: ORDER BY recorded_at DESC LIMIT ${idx} OFFSET ${idx + 1} """, - *params, limit, offset, + *params, + limit, + offset, ) return [ LogRecord( @@ -1153,9 +1186,7 @@ class LogArchiveRepository: cutoff, ) # Delete originals - await conn.execute( - "DELETE FROM app_logs WHERE recorded_at < $1", cutoff - ) + await conn.execute("DELETE FROM app_logs WHERE recorded_at < $1", cutoff) if isinstance(result, str): parts = result.split() if len(parts) == 2 and parts[0] == "INSERT": @@ -1221,7 +1252,9 @@ class LogAggregationRepository: ORDER BY bucket_start DESC LIMIT $3 """, - period, level.upper(), limit, + period, + level.upper(), + limit, ) else: rows = await conn.fetch( @@ -1232,7 +1265,8 @@ class LogAggregationRepository: ORDER BY bucket_start DESC LIMIT $2 """, - period, limit, + period, + limit, ) return [ LogAggregateRecord( diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index ecab37f..0d29326 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -1 +1 @@ -"""Integration tests for PostgreSQL schema and connectivity.""" \ No newline at end of file +"""Integration tests for PostgreSQL schema and connectivity.""" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7b5a4a5..385e9b1 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -11,13 +11,9 @@ import pathlib import pytest -def pytest_ignore_collect( - collection_path: pathlib.Path, config: pytest.Config -) -> bool: +def pytest_ignore_collect(collection_path: pathlib.Path, config: pytest.Config) -> bool: """Skip integration tests unless --integration is passed.""" - if "integration" in str(collection_path) and not config.getoption( - "--integration", False - ): + if "integration" in str(collection_path) and not config.getoption("--integration", False): return True return False diff --git a/tests/integration/test_metrics.py b/tests/integration/test_metrics.py index 4e7336a..99e5081 100644 --- a/tests/integration/test_metrics.py +++ b/tests/integration/test_metrics.py @@ -42,9 +42,24 @@ async def test_metrics_calculator_summarizes_execution_data() -> None: ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9), ($10, $11, $12, $13, $14, $15, $16, $17, $18) """, - "trade-1", started, finished, "filled", 12.5, 10.0, 100.0, "USD->BTC->ETH->USD", 3, - "trade-2", started_two, finished_two, "filled", - - 4.5, -2.0, 200.0, "USD->ETH->BTC->USD", 3, + "trade-1", + started, + finished, + "filled", + 12.5, + 10.0, + 100.0, + "USD->BTC->ETH->USD", + 3, + "trade-2", + started_two, + finished_two, + "filled", + -4.5, + -2.0, + 200.0, + "USD->ETH->BTC->USD", + 3, ) await conn.execute( """ @@ -53,11 +68,24 @@ async def test_metrics_calculator_summarizes_execution_data() -> None: ($7, $8, $9, $10, $11, $12), ($13, $14, $15, $16, $17, $18) """, - started, "USD->BTC->ETH->USD", 4.0, 3.0, 0.03, True, - started_two, "USD->ETH->BTC->USD", 2.0, 1.0, 0.01, False, - started_two + - timedelta( - seconds=30), "USD->BTC->ETH->USD", 5.0, 4.0, 0.04, True, + started, + "USD->BTC->ETH->USD", + 4.0, + 3.0, + 0.03, + True, + started_two, + "USD->ETH->BTC->USD", + 2.0, + 1.0, + 0.01, + False, + started_two + timedelta(seconds=30), + "USD->BTC->ETH->USD", + 5.0, + 4.0, + 0.04, + True, ) await conn.execute( """ @@ -67,8 +95,30 @@ async def test_metrics_calculator_summarizes_execution_data() -> None: ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12), ($13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24) """, - "trade-1", "order-1", 0, "BTC/USD", "buy", 2.0, 101, "closed", 2.0, 100.0, "{}", started, - "trade-2", "order-2", 0, "ETH/USD", "sell", 4.0, 202, "closed", 3.0, 200.0, "{}", started_two, + "trade-1", + "order-1", + 0, + "BTC/USD", + "buy", + 2.0, + 101, + "closed", + 2.0, + 100.0, + "{}", + started, + "trade-2", + "order-2", + 0, + "ETH/USD", + "sell", + 4.0, + 202, + "closed", + 3.0, + 200.0, + "{}", + started_two, ) metrics = await MetricsCalculator(store).compute() diff --git a/tests/integration/test_postgresql_schema.py b/tests/integration/test_postgresql_schema.py index c9c7c9a..7c04f96 100644 --- a/tests/integration/test_postgresql_schema.py +++ b/tests/integration/test_postgresql_schema.py @@ -25,50 +25,116 @@ EXPECTED_TABLES: dict[str, list[str]] = { "schema_migrations": ["version", "applied_at"], "config_sections": ["id", "name", "description", "updated_at"], "config_settings": [ - "key", "section", "value_json", "value_type", "is_secret", - "is_runtime_reloadable", "updated_at", "updated_by", + "key", + "section", + "value_json", + "value_type", + "is_secret", + "is_runtime_reloadable", + "updated_at", + "updated_by", ], "config_pairings": [ - "id", "base_asset", "quote_asset", "enabled", "source", - "created_at", "updated_at", + "id", + "base_asset", + "quote_asset", + "enabled", + "source", + "created_at", + "updated_at", ], "config_backtesting_defaults": [ - "id", "starting_balances", "trade_capital", "min_profit_threshold", - "slippage_bps", "execution_latency_ms", "fee_source", + "id", + "starting_balances", + "trade_capital", + "min_profit_threshold", + "slippage_bps", + "execution_latency_ms", + "fee_source", ], "opportunities": [ - "id", "detected_at", "cycle", "gross_pct", "net_pct", - "est_profit", "executed", + "id", + "detected_at", + "cycle", + "gross_pct", + "net_pct", + "est_profit", + "executed", ], "trades": [ - "id", "trade_ref", "started_at", "finished_at", "status", - "realized_pnl", "estimated_pnl", "capital_used", "cycle", "leg_count", + "id", + "trade_ref", + "started_at", + "finished_at", + "status", + "realized_pnl", + "estimated_pnl", + "capital_used", + "cycle", + "leg_count", ], "orders": [ - "id", "trade_ref", "order_ref", "leg_index", "pair", "side", - "volume", "user_ref", "status", "filled_volume", "avg_price", - "raw_response", "recorded_at", + "id", + "trade_ref", + "order_ref", + "leg_index", + "pair", + "side", + "volume", + "user_ref", + "status", + "filled_volume", + "avg_price", + "raw_response", + "recorded_at", ], "pnl_events": [ - "id", "trade_ref", "recorded_at", "kind", "pnl_usd", "source", + "id", + "trade_ref", + "recorded_at", + "kind", + "pnl_usd", + "source", ], "portfolio_snapshots": ["snapshot_at", "balances", "total_value_usd"], "market_snapshots": ["snapshot_at", "symbol", "source", "payload", "latency_ms"], "audit_events": [ - "id", "occurred_at", "actor", "event_type", "decision", - "payload", "correlation_id", + "id", + "occurred_at", + "actor", + "event_type", + "decision", + "payload", + "correlation_id", ], "runtime_state_snapshots": [ - "snapshot_at", "is_running", "kill_switch_active", "kill_switch_reason", - "open_trade_count", "last_known_balances", "note", + "snapshot_at", + "is_running", + "kill_switch_active", + "kill_switch_reason", + "open_trade_count", + "last_known_balances", + "note", ], "kraken_account_snapshots": [ - "snapshot_at", "fee_tier", "maker_fee", "taker_fee", - "thirty_day_volume", "trade_balance_raw", "fee_schedule_raw", + "snapshot_at", + "fee_tier", + "maker_fee", + "taker_fee", + "thirty_day_volume", + "trade_balance_raw", + "fee_schedule_raw", ], "backtest_jobs": [ - "id", "status", "events_path", "config", "report", "error", - "created_at", "started_at", "finished_at", + "id", + "status", + "events_path", + "config", + "report", + "error", + "created_at", + "started_at", + "finished_at", ], } @@ -96,6 +162,7 @@ TABLES_WITH_UNIQUE_CONSTRAINTS: dict[str, list[str]] = { # ── fixtures ──────────────────────────────────────────────────────────────── + @asynccontextmanager async def _pg_lifecycle() -> AsyncIterator[PgStore]: """Connect, yield store, then disconnect.""" @@ -116,6 +183,7 @@ async def pg_fixture() -> AsyncIterator[PgStore]: # ── helpers ───────────────────────────────────────────────────────────────── + async def _get_actual_tables(store: PgStore) -> dict[str, list[str]]: """Return {table_name: [column_name, ...]} for the public schema.""" actual: dict[str, list[str]] = {} @@ -139,6 +207,7 @@ async def _table_row_count(store: PgStore, table: str) -> int: # ── tests ─────────────────────────────────────────────────────────────────── + @pytest.mark.asyncio async def test_pg_connect(pg: PgStore) -> None: """Can connect to PostgreSQL and ping the server.""" @@ -165,8 +234,7 @@ async def test_schema_migration_applies(pg: PgStore) -> None: for table in EXPECTED_TABLES: assert table in actual, ( - f"Table '{table}' missing after migration. " - f"Found tables: {sorted(actual)}" + f"Table '{table}' missing after migration. " f"Found tables: {sorted(actual)}" ) @@ -190,8 +258,7 @@ async def test_table_columns(pg: PgStore) -> None: actual_cols = actual.get(table, []) for col in expected_cols: assert col in actual_cols, ( - f"Column '{col}' missing from table '{table}'. " - f"Actual columns: {actual_cols}" + f"Column '{col}' missing from table '{table}'. " f"Actual columns: {actual_cols}" ) @@ -250,8 +317,7 @@ async def test_table_row_count_is_zero(pg: PgStore) -> None: for table in EXPECTED_TABLES: count = await _table_row_count(pg, table) assert count == 0, ( - f"Table '{table}' should be empty after migration, " - f"but has {count} rows" + f"Table '{table}' should be empty after migration, " f"but has {count} rows" ) @@ -262,13 +328,10 @@ async def test_schema_migration_version_recorded(pg: PgStore) -> None: await pg.migrate() async with pg.pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT MAX(version) AS v FROM schema_migrations" - ) + row = await conn.fetchrow("SELECT MAX(version) AS v FROM schema_migrations") assert row is not None assert row["v"] == SCHEMA_VERSION, ( - f"Expected schema version {SCHEMA_VERSION}, " - f"got {row['v']}" + f"Expected schema version {SCHEMA_VERSION}, " f"got {row['v']}" ) @@ -280,7 +343,8 @@ async def test_create_and_query_row(pg: PgStore) -> None: # ConfigSections round-trip await conn.execute( "INSERT INTO config_sections (name, description) VALUES ($1, $2)", - "test_section", "A test section for integration test", + "test_section", + "A test section for integration test", ) row = await conn.fetchrow( "SELECT name, description FROM config_sections WHERE name = $1", @@ -357,4 +421,4 @@ async def test_audit_list_recent(pg: PgStore) -> None: # Verify payload serialization worked first = recent[0] if first.payload: - assert "index" in first.payload \ No newline at end of file + assert "index" in first.payload diff --git a/tests/unit/test_config_e2e.py b/tests/unit/test_config_e2e.py index 5972d6e..5c95b33 100644 --- a/tests/unit/test_config_e2e.py +++ b/tests/unit/test_config_e2e.py @@ -39,8 +39,7 @@ async def test_end_to_end_config_workflow(): # Mock the setting creation mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting = AsyncMock( - return_value=mock_created_setting) + mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting) mock_repo_instance.get_setting = AsyncMock(return_value=None) mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=None) mock_repo_instance.list_settings = AsyncMock(return_value=[]) diff --git a/tests/unit/test_config_repositories.py b/tests/unit/test_config_repositories.py index 92d1a96..5e22719 100644 --- a/tests/unit/test_config_repositories.py +++ b/tests/unit/test_config_repositories.py @@ -136,10 +136,8 @@ async def test_config_setting_repository_list_settings(mock_store): repo = ConfigSettingRepository(mock_store) conn = await mock_store.pool.acquire().__aenter__() - row1 = _make_row({**SETTING_ROW, "key": "test_key1", - "value_json": "test_value1"}) - row2 = _make_row({**SETTING_ROW, "key": "test_key2", - "value_json": "test_value2"}) + row1 = _make_row({**SETTING_ROW, "key": "test_key1", "value_json": "test_value1"}) + row2 = _make_row({**SETTING_ROW, "key": "test_key2", "value_json": "test_value2"}) conn.fetch = AsyncMock(return_value=[row1, row2]) result = await repo.list_settings() @@ -176,9 +174,7 @@ async def test_config_pairing_repository_create_pairing(mock_store): conn = await mock_store.pool.acquire().__aenter__() conn.fetchrow = AsyncMock(return_value=_make_row(PAIRING_ROW)) - pairing = ConfigPairing( - base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken" - ) + pairing = ConfigPairing(base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken") result = await repo.create_pairing(pairing) diff --git a/tests/unit/test_config_service.py b/tests/unit/test_config_service.py index 58eedcb..8bad649 100644 --- a/tests/unit/test_config_service.py +++ b/tests/unit/test_config_service.py @@ -63,8 +63,7 @@ async def test_configuration_service_set_setting(mock_settings, mock_store, mock mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting = AsyncMock( - return_value=mock_created_setting) + mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting) mock_repo_instance.get_setting = AsyncMock(return_value=None) await service.set_setting("test_key", "test_value", "test_user") @@ -73,7 +72,9 @@ async def test_configuration_service_set_setting(mock_settings, mock_store, mock @pytest.mark.asyncio -async def test_configuration_service_hot_reload_detection(mock_settings, mock_store, mock_audit_repo): +async def test_configuration_service_hot_reload_detection( + mock_settings, mock_store, mock_audit_repo +): """Test hot-reload detection functionality.""" service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) @@ -86,8 +87,7 @@ async def test_configuration_service_hot_reload_detection(mock_settings, mock_st from datetime import datetime - mock_repo_instance.get_latest_updated_at = AsyncMock( - return_value=datetime.now()) + mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=datetime.now()) assert await service.is_config_outdated() is True @@ -105,8 +105,7 @@ async def test_configuration_service_reload_if_changed(mock_settings, mock_store from datetime import datetime - mock_repo_instance.get_latest_updated_at = AsyncMock( - return_value=datetime.now()) + mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=datetime.now()) result = await service.reload_if_changed() assert result is True @@ -125,8 +124,7 @@ async def test_configuration_service_get_config_version(mock_settings, mock_stor mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting = AsyncMock( - return_value=mock_created_setting) + mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting) mock_repo_instance.get_setting = AsyncMock(return_value=None) await service.set_setting("test_key", "test_value", "test_user") @@ -134,7 +132,9 @@ async def test_configuration_service_get_config_version(mock_settings, mock_stor @pytest.mark.asyncio -async def test_configuration_service_get_last_updated_at(mock_settings, mock_store, mock_audit_repo): +async def test_configuration_service_get_last_updated_at( + mock_settings, mock_store, mock_audit_repo +): """Test getting last updated timestamp.""" service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) assert service.get_last_updated_at() is None @@ -145,8 +145,7 @@ async def test_configuration_service_get_last_updated_at(mock_settings, mock_sto mock_created_setting = Mock() mock_created_setting.updated_at = "2023-01-01T00:00:00" - mock_repo_instance.create_setting = AsyncMock( - return_value=mock_created_setting) + mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting) mock_repo_instance.get_setting = AsyncMock(return_value=None) await service.set_setting("test_key", "test_value", "test_user") diff --git a/tests/unit/test_runtime_lifecycle.py b/tests/unit/test_runtime_lifecycle.py index b774768..6c945cf 100644 --- a/tests/unit/test_runtime_lifecycle.py +++ b/tests/unit/test_runtime_lifecycle.py @@ -49,9 +49,7 @@ def _mock_pg_store(): @pytest.fixture def app(): """Create a test app with a mocked PgStore and audit repository.""" - a = create_app( - Settings(_env_file=None, APP_MODE="paper", paper_trading_mode=True) - ) + a = create_app(Settings(_env_file=None, APP_MODE="paper", paper_trading_mode=True)) a.state.store = _mock_pg_store() a.state.runtime_state_repository.insert = AsyncMock() a.state.runtime_state_repository.latest = AsyncMock(return_value=None) @@ -69,16 +67,14 @@ async def test_persist_runtime_snapshot_writes_record(app) -> None: # Mock _open_trade_count → 0, _latest_balances → None conn = await app.state.store.pool.acquire().__aenter__() - conn.fetchrow = AsyncMock(return_value=MagicMock( - **{"__getitem__": lambda s, k: 0})) + conn.fetchrow = AsyncMock(return_value=MagicMock(**{"__getitem__": lambda s, k: 0})) snapshot = await persist_runtime_snapshot(app, note="unit-test") assert snapshot is not None assert snapshot.note == "unit-test" - app.state.runtime_state_repository.latest = AsyncMock( - return_value=snapshot) + app.state.runtime_state_repository.latest = AsyncMock(return_value=snapshot) latest = await app.state.runtime_state_repository.latest() assert latest is not None assert latest.note == "unit-test"