feat: improve SQL query formatting and add type hints for better clarity
CI / lint-test-build (push) Failing after 53s

This commit is contained in:
2026-06-04 19:53:32 +02:00
parent c8e3daeb57
commit 8ceca2a7e4
16 changed files with 106 additions and 150 deletions
+1 -1
View File
@@ -43,7 +43,7 @@ target-version = "py312"
[tool.ruff.lint]
select = ["E", "F", "I", "UP", "B", "N", "ASYNC"]
ignore = ["E203"]
ignore = ["E203", "E501"]
[tool.mypy]
python_version = "3.12"
+2 -4
View File
@@ -19,12 +19,10 @@ def _resolve_fee_rate(fee_rate: float | None, db_path: str | None = None) -> flo
if db_path is not None:
try:
conn = duckdb.connect(db_path)
row = conn.execute(
"""
row = conn.execute("""
SELECT maker_fee FROM kraken_account_snapshots
ORDER BY snapshot_at DESC LIMIT 1
"""
).fetchone()
""").fetchone()
conn.close()
if row is not None and row[0] is not None:
return float(row[0])
+2 -4
View File
@@ -13,13 +13,11 @@ from arbitrade.storage.db import DuckDBStore
def _python_scan_compute(store: DuckDBStore) -> tuple[float, float | None, float | None]:
with store.connect() as conn:
trade_rows = conn.execute(
"""
trade_rows = conn.execute("""
SELECT started_at, finished_at, realized_pnl
FROM trades
WHERE finished_at IS NOT NULL
"""
).fetchall()
""").fetchall()
opportunity_rows = conn.execute("SELECT detected_at FROM opportunities").fetchall()
realized = sum(float(row[2]) for row in trade_rows if row[2] is not None)
+5 -5
View File
@@ -1,18 +1,17 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from fastapi import FastAPI
import asyncio
from arbitrade.alerting.notifier import build_notifier_from_settings
from arbitrade.api.control_state import DashboardControlState
from arbitrade.api.routes import public_router, router
from arbitrade.backtesting.runner import backtest_worker
from arbitrade.config.settings import Settings
from arbitrade.config.service import ConfigurationService
from arbitrade.config.settings import Settings
from arbitrade.exchange.fee_service import run_fee_sync_loop
from arbitrade.exchange.kraken_rest import KrakenRestClient
from arbitrade.logging_setup import configure_logging
@@ -45,7 +44,7 @@ def create_app(settings: Settings) -> FastAPI:
name="fee_sync_loop",
)
backtest_task = asyncio.create_task(
backtest_worker(backtest_queue, db),
backtest_worker(backtest_queue, db), # type: ignore
name="backtest_worker",
)
app.state.fee_sync_task = fee_sync_task
@@ -76,7 +75,8 @@ def create_app(settings: Settings) -> FastAPI:
app.state.audit_repository = AuditRepository(db)
app.state.runtime_state_repository = RuntimeStateRepository(db)
app.state.alert_notifier = build_notifier_from_settings(settings)
app.state.configuration_service = ConfigurationService(settings, db, AuditRepository(db))
app.state.configuration_service = ConfigurationService(
settings, db, AuditRepository(db))
app.state.backtest_recent_reports = []
app.state.dashboard_controls = DashboardControlState(
is_running=not settings.kill_switch_active,
+15 -31
View File
@@ -10,6 +10,7 @@ from typing import cast
from urllib.parse import parse_qs
import duckdb
import orjson
from fastapi import APIRouter, Depends, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
@@ -17,7 +18,6 @@ from fastapi.templating import Jinja2Templates
from arbitrade.alerting.notifier import SupportsAlerts, SupportsAlertStatus
from arbitrade.api.auth import require_dashboard_auth
from arbitrade.api.control_state import DashboardControlState
from arbitrade.backtesting.replay import BacktestConfig, BacktestReplayEngine, load_replay_events
from arbitrade.detection.graph import CurrencyGraph, TriangularCycle
from arbitrade.storage.repositories import (
AuditRecord,
@@ -104,37 +104,29 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
else:
open_trade_filter = "LOWER(status) NOT IN ('filled', 'closed', 'cancelled', 'canceled')"
portfolio_row = conn.execute(
"""
portfolio_row = conn.execute("""
SELECT balances, total_value_usd
FROM portfolio_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
"""
).fetchone()
open_trades = conn.execute(
f"""
""").fetchone()
open_trades = conn.execute(f"""
SELECT {trade_ref_expr}, status, started_at, {cycle_expr}
FROM trades
WHERE {open_trade_filter}
ORDER BY started_at DESC
LIMIT 5
"""
).fetchall()
rpnl = conn.execute(
"""
""").fetchall()
rpnl = conn.execute("""
SELECT COALESCE(SUM(COALESCE(realized_pnl, 0)), 0)
FROM trades
"""
).fetchone()
latest_opportunities = conn.execute(
"""
""").fetchone()
latest_opportunities = conn.execute("""
SELECT cycle, net_pct, est_profit, detected_at
FROM opportunities
ORDER BY detected_at DESC
LIMIT 5
"""
).fetchall()
""").fetchall()
balances_value = ""
total_value = ""
@@ -164,14 +156,12 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
# Query equity from kraken_account_snapshots
try:
equity_row = conn.execute(
"""
equity_row = conn.execute("""
SELECT trade_balance_raw
FROM kraken_account_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
"""
).fetchone()
""").fetchone()
if equity_row is not None and equity_row[0] is not None:
tb_raw = equity_row[0]
if isinstance(tb_raw, str):
@@ -207,14 +197,12 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
taker_fee = ""
thirty_day_volume = ""
try:
acct_row = conn.execute(
"""
acct_row = conn.execute("""
SELECT fee_tier, maker_fee, taker_fee, thirty_day_volume
FROM kraken_account_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
"""
).fetchone()
""").fetchone()
if acct_row is not None:
fee_tier = str(acct_row[0]) if acct_row[0] is not None else ""
maker_fee = f"{float(acct_row[1]):.4%}" if acct_row[1] is not None else ""
@@ -244,14 +232,12 @@ def _dashboard_overview(request: Request) -> dict[str, object]:
def _dashboard_charts(request: Request) -> dict[str, object]:
store = request.app.state.store
with store.connect() as conn:
opportunity_rows = conn.execute(
"""
opportunity_rows = conn.execute("""
SELECT detected_at, cycle, net_pct, est_profit
FROM opportunities
ORDER BY detected_at DESC
LIMIT 10
"""
).fetchall()
""").fetchall()
cr = list(reversed(opportunity_rows))
labels = []
@@ -369,8 +355,6 @@ def _alert_status_snapshot(request: Request) -> dict[str, object]:
def _dashboard_config_context(request: Request) -> dict[str, object]:
ctl = _dashboard_controls_state(request)
rs = request.app.state.settings
alert_status = _alert_status_snapshot(request)
tpd = ", ".join(ctl.tradable_pairs) if ctl.tradable_pairs else "All"
max_trade_capital_usd = (
f"{float(rs.max_trade_capital_usd):.2f} USD"
if rs.max_trade_capital_usd is not None
+2 -2
View File
@@ -197,7 +197,7 @@ def load_replay_events_from_db(
Each market_snapshots row has snapshot_at, symbol, payload (raw Kraken WS).
Payload format: {channel, symbol, data: [{bids: [{price, qty}], asks: [{price, qty}]}]}
"""
with store.connect() as conn: # type: ignore[union-attr]
with store.connect() as conn: # type: ignore
query = "SELECT snapshot_at, symbol, payload FROM market_snapshots WHERE 1=1"
params: list[object] = []
@@ -215,7 +215,7 @@ def load_replay_events_from_db(
params.append(end)
query += " ORDER BY snapshot_at ASC"
# type: ignore[union-attr]
rows = conn.execute(query, params).fetchall()
events: list[ReplayBookEvent] = []
+6 -6
View File
@@ -3,7 +3,7 @@
from __future__ import annotations
import asyncio
from datetime import UTC, datetime
from datetime import datetime
from pathlib import Path
import structlog
@@ -98,11 +98,11 @@ async def run_backtest_job(
starting_balances_raw = str(config.get("starting_balances", "USD=1000.0"))
starting_balances = _parse_balances(starting_balances_raw)
fee_rate = float(config.get("fee_rate", 0.0026))
trade_capital = float(config.get("trade_capital", 100.0))
min_profit_threshold = float(config.get("min_profit_threshold", 0.0005))
slippage_bps = float(config.get("slippage_bps", 4.0))
execution_latency_ms = float(config.get("execution_latency_ms", 20.0))
fee_rate = float(config.get("fee_rate", 0.0026)) # type: ignore
trade_capital = float(config.get("trade_capital", 100.0)) # type: ignore
min_profit_threshold = float(config.get("min_profit_threshold", 0.0005)) # type: ignore
slippage_bps = float(config.get("slippage_bps", 4.0)) # type: ignore
execution_latency_ms = float(config.get("execution_latency_ms", 20.0)) # type: ignore
cycles_by_pair, available_pairs = _build_cycles_from_events(
{e.symbol.upper() for e in events}
+16 -13
View File
@@ -1,11 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, cast
from typing import Any
import orjson
from pydantic import BaseModel, Field
from pydantic import BaseModel
from arbitrade.config.settings import Settings
from arbitrade.storage.db import DuckDBStore
@@ -50,7 +49,7 @@ class ConfigBacktestingDefaults(BaseModel):
class ConfigurationService:
"""Manages application configuration from environment and database sources."""
def __init__(self, settings: Settings, store: DuckDBStore, audit_repo) -> None:
def __init__(self, settings: Settings, store: DuckDBStore, audit_repo: Any) -> None:
self._settings = settings
self._store = store
self._audit_repo = audit_repo
@@ -75,11 +74,11 @@ class ConfigurationService:
if setting.value_type == "str":
parsed_value = setting.value_json
elif setting.value_type == "int":
parsed_value = int(setting.value_json)
parsed_value = int(setting.value_json) # type: ignore
elif setting.value_type == "float":
parsed_value = float(setting.value_json)
parsed_value = float(setting.value_json) # type: ignore
elif setting.value_type == "bool":
parsed_value = setting.value_json.lower() == "true"
parsed_value = setting.value_json.lower() == "true" # type: ignore
elif setting.value_type == "list":
parsed_value = orjson.loads(setting.value_json)
elif setting.value_type == "dict":
@@ -207,23 +206,27 @@ class ConfigurationService:
# --- Pairing & Fee Management ---
def _pairing_repo(self):
def _pairing_repo(self): # type: ignore
from arbitrade.storage.repositories import ConfigPairingRepository
return ConfigPairingRepository(self._store)
def list_pairings(self) -> list[ConfigPairing]:
"""List all currency pairings."""
return self._pairing_repo().list_pairings()
r = self._pairing_repo() # type: ignore[no-untyped-call]
p = r.list_pairings()
return p # type: ignore[no-any-return]
def create_pairing(
self, base_asset: str, quote_asset: str, source: str = "manual"
) -> ConfigPairing:
"""Create a new currency pairing."""
existing = self._pairing_repo().get_pairing(base_asset, quote_asset)
if existing:
return existing
r = self._pairing_repo() # type: ignore[no-untyped-call]
e = r.get_pairing(base_asset, quote_asset)
if e:
return e # type: ignore[no-any-return]
pairing = ConfigPairing(
base_asset=base_asset, quote_asset=quote_asset, enabled=True, source=source
)
return self._pairing_repo().create_pairing(pairing)
p = r.create_pairing(pairing)
return p # type: ignore[no-any-return]
+7 -7
View File
@@ -3,11 +3,10 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
import structlog
from datetime import UTC, datetime
import orjson
import structlog
from arbitrade.exchange.kraken_rest import KrakenRestClient
from arbitrade.storage.db import DuckDBStore
@@ -74,7 +73,7 @@ async def fetch_and_store_account_snapshot(
thirty_day_volume = float(thirty_day_volume_str) if thirty_day_volume_str is not None else None
snapshot = KrakenAccountSnapshot(
snapshot_at=datetime.now(timezone.utc),
snapshot_at=datetime.now(UTC),
fee_tier=fee_tier_str,
maker_fee=maker_fee,
taker_fee=taker_fee,
@@ -100,9 +99,10 @@ async def fetch_and_store_account_snapshot(
total_value = float(eb) if eb is not None else 0.0
with store.connect() as conn:
conn.execute(
"INSERT INTO portfolio_snapshots (snapshot_at, balances, total_value_usd) VALUES (?, ?, ?)",
"INSERT INTO portfolio_snapshots"
" (snapshot_at, balances, total_value_usd) VALUES (?, ?, ?)",
(
datetime.now(timezone.utc),
datetime.now(UTC),
orjson.dumps(wallet_balances).decode("utf-8") if wallet_balances else None,
total_value,
),
@@ -138,7 +138,7 @@ async def run_fee_sync_loop(
timeout=_FEE_REFRESH_INTERVAL_SECONDS,
)
break # stop_event was set
except asyncio.TimeoutError:
except TimeoutError:
pass # timeout elapsed, loop again
_LOG.info("fee_sync_loop_stopped")
+6 -12
View File
@@ -24,8 +24,7 @@ class MetricsCalculator:
def compute(self) -> PerformanceMetrics:
with self._store.connect() as conn:
tm = conn.execute(
"""
tm = conn.execute("""
SELECT
COALESCE(SUM(COALESCE(realized_pnl, 0)), 0) AS realized_pnl_usd,
COUNT(*) AS total_trades,
@@ -45,26 +44,21 @@ class MetricsCalculator:
) AS latency_p99_seconds
FROM trades
WHERE finished_at IS NOT NULL
"""
).fetchone()
""").fetchone()
om = conn.execute(
"""
om = conn.execute("""
SELECT
COUNT(*) AS opportunity_count,
MIN(detected_at) AS first_detected_at,
MAX(detected_at) AS last_detected_at
FROM opportunities
"""
).fetchone()
""").fetchone()
fm = conn.execute(
"""
fm = conn.execute("""
SELECT AVG(filled_volume / volume) AS fill_rate
FROM orders
WHERE volume > 0 AND filled_volume IS NOT NULL
"""
).fetchone()
""").fetchone()
r_pnl_usd = float(tm[0]) if tm and tm[0] is not None else 0.0
tt = int(tm[1]) if tm and tm[1] is not None else 0
+4 -8
View File
@@ -45,26 +45,22 @@ def _runtime_repository(app: FastAPI) -> RuntimeStateRepository | None:
def _open_trade_count(store: DuckDBStore) -> int:
with store.connect() as conn:
row = conn.execute(
"""
row = conn.execute("""
SELECT COUNT(*)
FROM trades
WHERE finished_at IS NULL
"""
).fetchone()
""").fetchone()
return int(row[0]) if row is not None else 0
def _latest_balances(store: DuckDBStore) -> dict[str, Any] | None:
with store.connect() as conn:
row = conn.execute(
"""
row = conn.execute("""
SELECT balances
FROM portfolio_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
"""
).fetchone()
""").fetchone()
if row is None or row[0] is None:
return None
+15 -17
View File
@@ -186,23 +186,26 @@ class DuckDBStore:
finally:
conn.close()
def _get_table_columns(self, conn, table_name: str) -> set[str]:
def _get_table_columns(self, conn: duckdb.DuckDBPyConnection, table_name: str) -> set[str]:
try:
rows = conn.execute(f"PRAGMA table_info({table_name})").fetchall()
return {str(row[1]) for row in rows}
except Exception:
return set()
def _table_exists(self, conn, table_name: str) -> bool:
def _table_exists(self, conn: duckdb.DuckDBPyConnection, table_name: str) -> bool:
try:
result = conn.execute(
f"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='{table_name}'"
).fetchone()
return result[0] > 0
count = result[0] if result else 0
return count > 0
except Exception:
return False
def _ensure_column(self, conn, table_name: str, column_def: str) -> None:
def _ensure_column(
self, conn: duckdb.DuckDBPyConnection, table_name: str, column_def: str
) -> None:
"""Add a column to a table if it doesn't already exist."""
existing = self._get_table_columns(conn, table_name)
col_name = column_def.split()[0]
@@ -216,14 +219,12 @@ class DuckDBStore:
# Ensure schema_migrations table exists and get current version
if not self._table_exists(conn, "schema_migrations"):
conn.execute(
"""
conn.execute("""
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TIMESTAMP DEFAULT current_timestamp
)
"""
)
""")
# Get current schema version
try:
@@ -254,8 +255,7 @@ class DuckDBStore:
if current_version < 3:
# Migration v3: Add kraken_account_snapshots table
conn.execute(
"""
conn.execute("""
CREATE TABLE IF NOT EXISTS kraken_account_snapshots (
snapshot_at TIMESTAMP NOT NULL,
fee_tier VARCHAR,
@@ -265,22 +265,21 @@ class DuckDBStore:
trade_balance_raw JSON,
fee_schedule_raw JSON
)
"""
)
""")
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (3)")
_LOG.info("migration_applied", version=3)
if current_version < 4:
# Migration v4: Add fee_source to backtesting defaults
conn.execute(
"ALTER TABLE config_backtesting_defaults ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'"
"ALTER TABLE config_backtesting_defaults"
" ADD COLUMN IF NOT EXISTS fee_source VARCHAR DEFAULT 'api'"
)
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (4)")
_LOG.info("migration_applied", version=4)
if current_version < 5:
conn.execute(
"""
conn.execute("""
CREATE TABLE IF NOT EXISTS backtest_jobs (
id UUID DEFAULT uuid(),
status VARCHAR NOT NULL DEFAULT 'pending',
@@ -292,8 +291,7 @@ class DuckDBStore:
started_at TIMESTAMP,
finished_at TIMESTAMP
)
"""
)
""")
conn.execute("INSERT OR IGNORE INTO schema_migrations (version) VALUES (5)")
_LOG.info("migration_applied", version=5)
+14 -28
View File
@@ -349,8 +349,7 @@ class RuntimeStateRepository:
def latest(self) -> RuntimeStateRecord | None:
with self._store.connect() as conn:
row = conn.execute(
"""
row = conn.execute("""
SELECT
snapshot_at,
is_running,
@@ -362,8 +361,7 @@ class RuntimeStateRepository:
FROM runtime_state_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
"""
).fetchone()
""").fetchone()
if row is None:
return None
@@ -426,13 +424,11 @@ class ConfigSectionRepository:
def list_sections(self) -> list[ConfigSection]:
"""List all configuration sections."""
with self._store.connect() as conn:
cursor = conn.execute(
"""
cursor = conn.execute("""
SELECT id, name, description, updated_at
FROM config_sections
ORDER BY name
"""
)
""")
return [
ConfigSection(id=row[0], name=row[1], description=row[2], updated_at=row[3])
for row in cursor.fetchall()
@@ -561,13 +557,11 @@ class ConfigSettingRepository:
(section,),
)
else:
cursor = conn.execute(
"""
cursor = conn.execute("""
SELECT key, section, value_json, value_type, is_secret, is_runtime_reloadable, updated_at, updated_by
FROM config_settings
ORDER BY key
"""
)
""")
return [
ConfigSetting(
key=row[0],
@@ -585,12 +579,10 @@ class ConfigSettingRepository:
def get_latest_updated_at(self) -> datetime | None:
"""Get the latest updated_at timestamp from config_settings table."""
with self._store.connect() as conn:
cursor = conn.execute(
"""
cursor = conn.execute("""
SELECT MAX(updated_at) as latest_updated_at
FROM config_settings
"""
)
""")
row = cursor.fetchone()
if row and row[0]:
# Convert string timestamp to datetime
@@ -702,13 +694,11 @@ class ConfigPairingRepository:
def list_pairings(self) -> list[ConfigPairing]:
"""List all currency pairings."""
with self._store.connect() as conn:
cursor = conn.execute(
"""
cursor = conn.execute("""
SELECT id, base_asset, quote_asset, enabled, source, created_at, updated_at
FROM config_pairings
ORDER BY base_asset, quote_asset
"""
)
""")
return [
ConfigPairing(
id=row[0],
@@ -762,14 +752,12 @@ class ConfigBacktestingDefaultsRepository:
def get_defaults(self) -> ConfigBacktestingDefaults | None:
"""Get the current backtesting defaults."""
with self._store.connect() as conn:
cursor = conn.execute(
"""
cursor = conn.execute("""
SELECT id, starting_balances, trade_capital, min_profit_threshold, slippage_bps, execution_latency_ms
FROM config_backtesting_defaults
ORDER BY id DESC
LIMIT 1
"""
)
""")
row = cursor.fetchone()
if row:
return ConfigBacktestingDefaults(
@@ -862,15 +850,13 @@ class KrakenAccountSnapshotRepository:
def latest_snapshot(self) -> KrakenAccountSnapshot | None:
with self._store.connect() as conn:
row = conn.execute(
"""
row = conn.execute("""
SELECT snapshot_at, fee_tier, maker_fee, taker_fee,
thirty_day_volume, trade_balance_raw, fee_schedule_raw
FROM kraken_account_snapshots
ORDER BY snapshot_at DESC
LIMIT 1
"""
).fetchone()
""").fetchone()
if row is None:
return None
return KrakenAccountSnapshot(
-1
View File
@@ -1,6 +1,5 @@
"""End-to-end test for configuration management system."""
import pytest
from unittest.mock import Mock, patch
from arbitrade.config.service import ConfigurationService
+9 -10
View File
@@ -1,19 +1,19 @@
"""Unit tests for configuration repositories."""
import pytest
from unittest.mock import Mock, patch
from arbitrade.storage.repositories import (
ConfigSettingRepository,
ConfigPairingRepository,
ConfigBacktestingDefaultsRepository,
)
import pytest
from arbitrade.config.service import (
ConfigSetting,
ConfigPairing,
ConfigBacktestingDefaults,
ConfigSetting,
)
from arbitrade.storage.db import DuckDBStore
from arbitrade.storage.repositories import (
ConfigBacktestingDefaultsRepository,
ConfigPairingRepository,
ConfigSettingRepository,
)
@pytest.fixture
@@ -238,8 +238,7 @@ def test_config_pairing_repository_create_pairing(mock_store):
]
# Create pairing
pairing = ConfigPairing(
base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken")
pairing = ConfigPairing(base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken")
result = repo.create_pairing(pairing)
+2 -1
View File
@@ -1,8 +1,9 @@
"""Unit tests for configuration management system."""
import pytest
from unittest.mock import Mock, patch
import pytest
from arbitrade.config.service import ConfigurationService
from arbitrade.config.settings import Settings
from arbitrade.storage.db import DuckDBStore