Refactor code for improved readability and consistency
CI / lint-test-build (push) Successful in 54s

- Cleaned up multiline statements and removed unnecessary line breaks in various files.
- Ensured consistent formatting in function definitions and calls across the codebase.
- Updated docstrings and comments for clarity where applicable.
- Removed trailing newlines in module docstrings.
- Enhanced logging statements for better clarity in maintenance tasks.
This commit is contained in:
2026-06-07 21:59:09 +02:00
parent f221464daa
commit dc99f1604e
25 changed files with 409 additions and 324 deletions
+1 -2
View File
@@ -67,8 +67,7 @@ async def _seed_dataset(store: PgStore) -> None:
opportunity_rows: list[tuple[object, ...]] = [] opportunity_rows: list[tuple[object, ...]] = []
for i in range(5000): for i in range(5000):
detected_at = now + timedelta(milliseconds=200 * i) detected_at = now + timedelta(milliseconds=200 * i)
opportunity_rows.append( opportunity_rows.append((detected_at, "USD->BTC->ETH->USD", 2.5, 1.2, 0.03, bool(i % 2)))
(detected_at, "USD->BTC->ETH->USD", 2.5, 1.2, 0.03, bool(i % 2)))
order_rows: list[tuple[object, ...]] = [] order_rows: list[tuple[object, ...]] = []
for i in range(3500): for i in range(3500):
+40 -36
View File
@@ -36,8 +36,7 @@ public_router = APIRouter()
def _resolve_templates_directory() -> str: def _resolve_templates_directory() -> str:
# Support source layout, Docker runtime (/app), and installed package data. # Support source layout, Docker runtime (/app), and installed package data.
source_layout_path = Path( source_layout_path = Path(__file__).resolve().parents[3] / "web" / "templates"
__file__).resolve().parents[3] / "web" / "templates"
if source_layout_path.is_dir(): if source_layout_path.is_dir():
return str(source_layout_path) return str(source_layout_path)
@@ -46,8 +45,7 @@ def _resolve_templates_directory() -> str:
return str(docker_runtime_path) return str(docker_runtime_path)
try: try:
package_path = resources.files( package_path = resources.files("arbitrade").joinpath("web", "templates")
"arbitrade").joinpath("web", "templates")
if package_path.is_dir(): if package_path.is_dir():
return str(package_path) return str(package_path)
except (ModuleNotFoundError, AttributeError): except (ModuleNotFoundError, AttributeError):
@@ -153,12 +151,22 @@ async def _dashboard_overview(request: Request) -> dict[str, object]:
LIMIT 1 LIMIT 1
""") """)
if acct_row is not None: if acct_row is not None:
fee_tier = str( fee_tier = str(acct_row["fee_tier"]) if acct_row["fee_tier"] is not None else ""
acct_row["fee_tier"]) if acct_row["fee_tier"] is not None else "" maker_fee = (
maker_fee = f"{float(acct_row['maker_fee']):.4%}" if acct_row["maker_fee"] is not None else "" f"{float(acct_row['maker_fee']):.4%}"
taker_fee = f"{float(acct_row['taker_fee']):.4%}" if acct_row["taker_fee"] is not None else "" if acct_row["maker_fee"] is not None
thirty_day_volume = f"{float(acct_row['thirty_day_volume']):.2f}" if acct_row[ else ""
"thirty_day_volume"] is not None else "" )
taker_fee = (
f"{float(acct_row['taker_fee']):.4%}"
if acct_row["taker_fee"] is not None
else ""
)
thirty_day_volume = (
f"{float(acct_row['thirty_day_volume']):.2f}"
if acct_row["thirty_day_volume"] is not None
else ""
)
except Exception: except Exception:
pass pass
@@ -171,8 +179,7 @@ async def _dashboard_overview(request: Request) -> dict[str, object]:
try: try:
parsed = json.loads(balances_raw) parsed = json.loads(balances_raw)
if isinstance(parsed, dict): if isinstance(parsed, dict):
non_zero = {k: float(v) non_zero = {k: float(v) for k, v in parsed.items() if float(v) > 0.0}
for k, v in parsed.items() if float(v) > 0.0}
if non_zero: if non_zero:
balances_value = "<br>".join( balances_value = "<br>".join(
f"{v:.6g} {k}" for k, v in sorted(non_zero.items()) f"{v:.6g} {k}" for k, v in sorted(non_zero.items())
@@ -192,7 +199,9 @@ async def _dashboard_overview(request: Request) -> dict[str, object]:
{ {
"trade_ref": str(r["trade_ref"]), "trade_ref": str(r["trade_ref"]),
"status": str(r["status"]), "status": str(r["status"]),
"started_at": r["started_at"].isoformat() if isinstance(r["started_at"], datetime) else "", "started_at": (
r["started_at"].isoformat() if isinstance(r["started_at"], datetime) else ""
),
"cycle": str(r["cycle"]) if r["cycle"] is not None else "", "cycle": str(r["cycle"]) if r["cycle"] is not None else "",
} }
for r in open_trades for r in open_trades
@@ -201,8 +210,12 @@ async def _dashboard_overview(request: Request) -> dict[str, object]:
{ {
"cycle": str(r["cycle"]), "cycle": str(r["cycle"]),
"net_pct": f"{float(r['net_pct']):.2f}%" if r["net_pct"] is not None else "", "net_pct": f"{float(r['net_pct']):.2f}%" if r["net_pct"] is not None else "",
"est_profit": f"{float(r['est_profit']):.2f} USD" if r["est_profit"] is not None else "", "est_profit": (
"detected_at": r["detected_at"].isoformat() if isinstance(r["detected_at"], datetime) else "", f"{float(r['est_profit']):.2f} USD" if r["est_profit"] is not None else ""
),
"detected_at": (
r["detected_at"].isoformat() if isinstance(r["detected_at"], datetime) else ""
),
} }
for r in latest_opportunities for r in latest_opportunities
] ]
@@ -242,10 +255,8 @@ async def _dashboard_charts(request: Request) -> dict[str, object]:
labels.append(row["detected_at"].isoformat()) labels.append(row["detected_at"].isoformat())
else: else:
labels.append(f"opportunity-{index + 1}") labels.append(f"opportunity-{index + 1}")
np = [float(row["net_pct"]) if row["net_pct"] np = [float(row["net_pct"]) if row["net_pct"] is not None else 0.0 for row in cr]
is not None else 0.0 for row in cr] ep = [float(row["est_profit"]) if row["est_profit"] is not None else 0.0 for row in cr]
ep = [float(row["est_profit"]) if row["est_profit"]
is not None else 0.0 for row in cr]
cycles = [str(row["cycle"]) for row in cr] cycles = [str(row["cycle"]) for row in cr]
return { return {
@@ -411,8 +422,7 @@ async def _dashboard_config_context(request: Request) -> dict[str, object]:
max_consecutive_failures_value = ( max_consecutive_failures_value = (
str(rs.max_consecutive_failures) if rs.max_consecutive_failures is not None else "" str(rs.max_consecutive_failures) if rs.max_consecutive_failures is not None else ""
) )
strategy_stat_arb_enabled = bool( strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False))
getattr(rs, "strategy_enable_stat_arb_experiment", False))
return { return {
# Runtime # Runtime
@@ -533,8 +543,7 @@ def _dashboard_controls(request: Request) -> dict[str, object]:
alerts_last_channel_results = [ alerts_last_channel_results = [
str(item) for item in cast(list[object], alert_status.get("last_channel_results", [])) str(item) for item in cast(list[object], alert_status.get("last_channel_results", []))
] ]
strategy_stat_arb_enabled = bool( strategy_stat_arb_enabled = bool(getattr(rs, "strategy_enable_stat_arb_experiment", False))
getattr(rs, "strategy_enable_stat_arb_experiment", False))
return { return {
"execution_status": "running" if ctl.is_running else "stopped", "execution_status": "running" if ctl.is_running else "stopped",
@@ -813,7 +822,6 @@ async def dashboard_backtesting_page(request: Request) -> HTMLResponse:
@router.get("/dashboard/fragment/backtesting", response_class=HTMLResponse) @router.get("/dashboard/fragment/backtesting", response_class=HTMLResponse)
async def dashboard_backtesting_fragment(request: Request) -> HTMLResponse: async def dashboard_backtesting_fragment(request: Request) -> HTMLResponse:
d_context = await _dashboard_config_context(request)
ctx = await _backtesting_panel_context(request) ctx = await _backtesting_panel_context(request)
ctx["flash_message"] = "" ctx["flash_message"] = ""
# Check if any pairings are enabled # Check if any pairings are enabled
@@ -992,19 +1000,18 @@ async def dashboard_backtesting_run(request: Request) -> HTMLResponse:
try: try:
custom_fee_rate = ( custom_fee_rate = (
float(defaults["custom_fee_rate"] float(defaults["custom_fee_rate"]) if defaults["custom_fee_rate"].strip() else None
) if defaults["custom_fee_rate"].strip() else None )
fee_rate = await _fee_rate_for_profile(
defaults["fee_profile"], custom_fee_rate, request=request
) )
fee_rate = await _fee_rate_for_profile(defaults["fee_profile"], custom_fee_rate, request=request)
# Use enabled pairings from DB when none selected # Use enabled pairings from DB when none selected
symbols_str = defaults["symbols"] symbols_str = defaults["symbols"]
if not symbols_str.strip(): if not symbols_str.strip():
pairing_repo = ConfigPairingRepository(request.app.state.store) pairing_repo = ConfigPairingRepository(request.app.state.store)
enabled = await pairing_repo.list_pairings(enabled_only=True) enabled = await pairing_repo.list_pairings(enabled_only=True)
symbols_str = ",".join( symbols_str = ",".join(f"{p.base_asset}/{p.quote_asset}" for p in enabled)
f"{p.base_asset}/{p.quote_asset}" for p in enabled
)
config_dict: dict[str, object] = { config_dict: dict[str, object] = {
"source": defaults["source"], "source": defaults["source"],
@@ -1133,8 +1140,7 @@ async def dashboard_backtesting_export(request: Request, job_id: str) -> Respons
return Response( return Response(
content=orjson.dumps(payload).decode("utf-8"), content=orjson.dumps(payload).decode("utf-8"),
media_type="application/x-jsonlines", media_type="application/x-jsonlines",
headers={ headers={"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"},
"Content-Disposition": f"attachment; filename=backtest_{job_id[:8]}.jsonl"},
) )
@@ -1383,11 +1389,9 @@ async def dashboard_api_pairings(
if source: if source:
pairings = [p for p in pairings if p.source.lower() == source.lower()] pairings = [p for p in pairings if p.source.lower() == source.lower()]
if base: if base:
pairings = [p for p in pairings if p.base_asset.lower() == pairings = [p for p in pairings if p.base_asset.lower() == base.lower()]
base.lower()]
if quote: if quote:
pairings = [p for p in pairings if p.quote_asset.lower() == pairings = [p for p in pairings if p.quote_asset.lower() == quote.lower()]
quote.lower()]
# Sort # Sort
reverse = order.lower() == "desc" reverse = order.lower() == "desc"
+38 -76
View File
@@ -31,41 +31,26 @@ class Settings(BaseSettings):
) )
alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED") alerts_enabled: bool = Field(default=True, alias="ALERTS_ENABLED")
alert_min_severity: str = Field( alert_min_severity: str = Field(default="warning", alias="ALERT_MIN_SEVERITY")
default="warning", alias="ALERT_MIN_SEVERITY") alert_dedup_seconds: float = Field(default=30.0, alias="ALERT_DEDUP_SECONDS")
alert_dedup_seconds: float = Field( alert_on_trade_events: bool = Field(default=True, alias="ALERT_ON_TRADE_EVENTS")
default=30.0, alias="ALERT_DEDUP_SECONDS") alert_on_error_events: bool = Field(default=True, alias="ALERT_ON_ERROR_EVENTS")
alert_on_trade_events: bool = Field( alert_on_threshold_events: bool = Field(default=True, alias="ALERT_ON_THRESHOLD_EVENTS")
default=True, alias="ALERT_ON_TRADE_EVENTS") alert_on_system_events: bool = Field(default=True, alias="ALERT_ON_SYSTEM_EVENTS")
alert_on_error_events: bool = Field(
default=True, alias="ALERT_ON_ERROR_EVENTS")
alert_on_threshold_events: bool = Field(
default=True, alias="ALERT_ON_THRESHOLD_EVENTS")
alert_on_system_events: bool = Field(
default=True, alias="ALERT_ON_SYSTEM_EVENTS")
telegram_alerts_enabled: bool = Field( telegram_alerts_enabled: bool = Field(default=False, alias="TELEGRAM_ALERTS_ENABLED")
default=False, alias="TELEGRAM_ALERTS_ENABLED") telegram_bot_token: str | None = Field(default=None, alias="TELEGRAM_BOT_TOKEN")
telegram_bot_token: str | None = Field( telegram_chat_id: str | None = Field(default=None, alias="TELEGRAM_CHAT_ID")
default=None, alias="TELEGRAM_BOT_TOKEN")
telegram_chat_id: str | None = Field(
default=None, alias="TELEGRAM_CHAT_ID")
discord_alerts_enabled: bool = Field( discord_alerts_enabled: bool = Field(default=False, alias="DISCORD_ALERTS_ENABLED")
default=False, alias="DISCORD_ALERTS_ENABLED") discord_webhook_url: str | None = Field(default=None, alias="DISCORD_WEBHOOK_URL")
discord_webhook_url: str | None = Field(
default=None, alias="DISCORD_WEBHOOK_URL")
email_alerts_enabled: bool = Field( email_alerts_enabled: bool = Field(default=False, alias="EMAIL_ALERTS_ENABLED")
default=False, alias="EMAIL_ALERTS_ENABLED")
email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST") email_smtp_host: str | None = Field(default=None, alias="EMAIL_SMTP_HOST")
email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT") email_smtp_port: int = Field(default=587, alias="EMAIL_SMTP_PORT")
email_smtp_username: str | None = Field( email_smtp_username: str | None = Field(default=None, alias="EMAIL_SMTP_USERNAME")
default=None, alias="EMAIL_SMTP_USERNAME") email_smtp_password: str | None = Field(default=None, alias="EMAIL_SMTP_PASSWORD")
email_smtp_password: str | None = Field( email_alert_from: str | None = Field(default=None, alias="EMAIL_ALERT_FROM")
default=None, alias="EMAIL_SMTP_PASSWORD")
email_alert_from: str | None = Field(
default=None, alias="EMAIL_ALERT_FROM")
email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO") email_alert_to: str | None = Field(default=None, alias="EMAIL_ALERT_TO")
email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS") email_smtp_use_tls: bool = Field(default=True, alias="EMAIL_SMTP_USE_TLS")
@@ -78,31 +63,24 @@ class Settings(BaseSettings):
pg_min_connections: int = Field(default=2, alias="PG_MIN_CONNECTIONS") pg_min_connections: int = Field(default=2, alias="PG_MIN_CONNECTIONS")
pg_max_connections: int = Field(default=10, alias="PG_MAX_CONNECTIONS") pg_max_connections: int = Field(default=10, alias="PG_MAX_CONNECTIONS")
kraken_rest_url: str = Field( kraken_rest_url: str = Field(default="https://api.kraken.com", alias="KRAKEN_REST_URL")
default="https://api.kraken.com", alias="KRAKEN_REST_URL") kraken_ws_url: str = Field(default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_ws_url: str = Field(
default="wss://ws.kraken.com/v2", alias="KRAKEN_WS_URL")
kraken_private_rate_limit_seconds: float = Field( kraken_private_rate_limit_seconds: float = Field(
default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS" default=1.0, alias="KRAKEN_PRIVATE_RATE_LIMIT_SECONDS"
) )
kraken_http_timeout_seconds: float = Field( kraken_http_timeout_seconds: float = Field(default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS")
default=10.0, alias="KRAKEN_HTTP_TIMEOUT_SECONDS") kraken_retry_attempts: int = Field(default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_retry_attempts: int = Field(
default=3, alias="KRAKEN_RETRY_ATTEMPTS")
kraken_retry_base_delay_seconds: float = Field( kraken_retry_base_delay_seconds: float = Field(
default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS" default=0.25, alias="KRAKEN_RETRY_BASE_DELAY_SECONDS"
) )
kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY") kraken_api_key: str | None = Field(default=None, alias="KRAKEN_API_KEY")
kraken_api_secret: str | None = Field( kraken_api_secret: str | None = Field(default=None, alias="KRAKEN_API_SECRET")
default=None, alias="KRAKEN_API_SECRET")
kraken_api_key_permissions: str = Field( kraken_api_key_permissions: str = Field(
default="query,trade", default="query,trade",
alias="KRAKEN_API_KEY_PERMISSIONS", alias="KRAKEN_API_KEY_PERMISSIONS",
) )
ws_heartbeat_timeout_seconds: float = Field( ws_heartbeat_timeout_seconds: float = Field(default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS")
default=20.0, alias="WS_HEARTBEAT_TIMEOUT_SECONDS") ws_max_staleness_seconds: float = Field(default=5.0, alias="WS_MAX_STALENESS_SECONDS")
ws_max_staleness_seconds: float = Field(
default=5.0, alias="WS_MAX_STALENESS_SECONDS")
strategy_enable_stat_arb_experiment: bool = Field( strategy_enable_stat_arb_experiment: bool = Field(
default=False, default=False,
alias="STRATEGY_ENABLE_STAT_ARB_EXPERIMENT", alias="STRATEGY_ENABLE_STAT_ARB_EXPERIMENT",
@@ -125,29 +103,20 @@ class Settings(BaseSettings):
) )
paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE") paper_trading_mode: bool = Field(default=True, alias="PAPER_TRADING_MODE")
trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD") trade_capital_usd: float = Field(default=100.0, alias="TRADE_CAPITAL_USD")
max_trade_capital_usd: float = Field( max_trade_capital_usd: float = Field(default=100.0, alias="MAX_TRADE_CAPITAL_USD")
default=100.0, alias="MAX_TRADE_CAPITAL_USD") max_concurrent_trades: int | None = Field(default=None, alias="MAX_CONCURRENT_TRADES")
max_concurrent_trades: int | None = Field(
default=None, alias="MAX_CONCURRENT_TRADES")
max_exposure_per_asset_usd: float | None = Field( max_exposure_per_asset_usd: float | None = Field(
default=None, default=None,
alias="MAX_EXPOSURE_PER_ASSET_USD", alias="MAX_EXPOSURE_PER_ASSET_USD",
) )
quote_balance_asset: str = Field( quote_balance_asset: str = Field(default="USD", alias="QUOTE_BALANCE_ASSET")
default="USD", alias="QUOTE_BALANCE_ASSET") min_order_size_usd: float | None = Field(default=None, alias="MIN_ORDER_SIZE_USD")
min_order_size_usd: float | None = Field(
default=None, alias="MIN_ORDER_SIZE_USD")
kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE") kill_switch_active: bool = Field(default=False, alias="KILL_SWITCH_ACTIVE")
daily_loss_limit_usd: float | None = Field( daily_loss_limit_usd: float | None = Field(default=None, alias="DAILY_LOSS_LIMIT_USD")
default=None, alias="DAILY_LOSS_LIMIT_USD") cumulative_loss_limit_usd: float | None = Field(default=None, alias="CUMULATIVE_LOSS_LIMIT_USD")
cumulative_loss_limit_usd: float | None = Field( max_source_latency_ms: float | None = Field(default=None, alias="MAX_SOURCE_LATENCY_MS")
default=None, alias="CUMULATIVE_LOSS_LIMIT_USD") max_apply_latency_ms: float | None = Field(default=None, alias="MAX_APPLY_LATENCY_MS")
max_source_latency_ms: float | None = Field( max_consecutive_failures: int | None = Field(default=None, alias="MAX_CONSECUTIVE_FAILURES")
default=None, alias="MAX_SOURCE_LATENCY_MS")
max_apply_latency_ms: float | None = Field(
default=None, alias="MAX_APPLY_LATENCY_MS")
max_consecutive_failures: int | None = Field(
default=None, alias="MAX_CONSECUTIVE_FAILURES")
fernet_key: str | None = Field(default=None, alias="FERNET_KEY") fernet_key: str | None = Field(default=None, alias="FERNET_KEY")
@@ -164,8 +133,7 @@ class Settings(BaseSettings):
def _validate_log_level(cls, value: str) -> str: def _validate_log_level(cls, value: str) -> str:
normalized = value.strip().upper() normalized = value.strip().upper()
if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}: if normalized not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}:
raise ValueError( raise ValueError("LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
"LOG_LEVEL must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
return normalized return normalized
@field_validator("alert_min_severity") @field_validator("alert_min_severity")
@@ -173,19 +141,16 @@ class Settings(BaseSettings):
def _validate_alert_severity(cls, value: str) -> str: def _validate_alert_severity(cls, value: str) -> str:
normalized = value.strip().lower() normalized = value.strip().lower()
if normalized not in {"info", "warning", "error", "critical"}: if normalized not in {"info", "warning", "error", "critical"}:
raise ValueError( raise ValueError("ALERT_MIN_SEVERITY must be one of: info, warning, error, critical")
"ALERT_MIN_SEVERITY must be one of: info, warning, error, critical")
return normalized return normalized
@model_validator(mode="after") @model_validator(mode="after")
def _validate_security_constraints(self) -> Settings: def _validate_security_constraints(self) -> Settings:
if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password): if bool(self.dashboard_auth_username) ^ bool(self.dashboard_auth_password):
raise ValueError( raise ValueError("dashboard auth requires both username and password")
"dashboard auth requires both username and password")
if bool(self.kraken_api_key) ^ bool(self.kraken_api_secret): if bool(self.kraken_api_key) ^ bool(self.kraken_api_secret):
raise ValueError( raise ValueError("Kraken API auth requires both API key and secret")
"Kraken API auth requires both API key and secret")
permissions = { permissions = {
token.strip().lower() token.strip().lower()
@@ -193,11 +158,9 @@ class Settings(BaseSettings):
if token.strip() if token.strip()
} }
if permissions and ("query" not in permissions or "trade" not in permissions): if permissions and ("query" not in permissions or "trade" not in permissions):
raise ValueError( raise ValueError("KRAKEN_API_KEY_PERMISSIONS must include query and trade")
"KRAKEN_API_KEY_PERMISSIONS must include query and trade")
if "withdraw" in permissions or "withdrawals" in permissions: if "withdraw" in permissions or "withdrawals" in permissions:
raise ValueError( raise ValueError("KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope")
"KRAKEN_API_KEY_PERMISSIONS must not include withdrawal scope")
if self.alert_dedup_seconds < 0.0: if self.alert_dedup_seconds < 0.0:
raise ValueError("ALERT_DEDUP_SECONDS must be >= 0") raise ValueError("ALERT_DEDUP_SECONDS must be >= 0")
@@ -213,8 +176,7 @@ class Settings(BaseSettings):
"STRATEGY_STAT_ARB_ENTRY_ZSCORE must be greater than STRATEGY_STAT_ARB_EXIT_ZSCORE" "STRATEGY_STAT_ARB_ENTRY_ZSCORE must be greater than STRATEGY_STAT_ARB_EXIT_ZSCORE"
) )
if self.strategy_stat_arb_max_holding_seconds <= 0.0: if self.strategy_stat_arb_max_holding_seconds <= 0.0:
raise ValueError( raise ValueError("STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0")
"STRATEGY_STAT_ARB_MAX_HOLDING_SECONDS must be > 0")
return self return self
+7 -14
View File
@@ -42,12 +42,9 @@ async def fetch_and_store_account_snapshot(
_LOG.exception("trade_balance_fetch_failed") _LOG.exception("trade_balance_fetch_failed")
return None return None
fee_tier = volume_data.get("fee_tier") if isinstance( fee_tier = volume_data.get("fee_tier") if isinstance(volume_data, dict) else None
volume_data, dict) else None fees_dict = volume_data.get("fees") if isinstance(volume_data, dict) else None
fees_dict = volume_data.get("fees") if isinstance( fees_maker = volume_data.get("fees_maker") if isinstance(volume_data, dict) else None
volume_data, dict) else None
fees_maker = volume_data.get("fees_maker") if isinstance(
volume_data, dict) else None
currency = volume_data.get("currency") currency = volume_data.get("currency")
thirty_day_volume_str = volume_data.get("volume") thirty_day_volume_str = volume_data.get("volume")
@@ -73,8 +70,7 @@ async def fetch_and_store_account_snapshot(
if currency is not None: if currency is not None:
fee_schedule["currency"] = currency fee_schedule["currency"] = currency
thirty_day_volume = float( thirty_day_volume = float(thirty_day_volume_str) if thirty_day_volume_str is not None else None
thirty_day_volume_str) if thirty_day_volume_str is not None else None
snapshot = KrakenAccountSnapshot( snapshot = KrakenAccountSnapshot(
snapshot_at=datetime.now(UTC), snapshot_at=datetime.now(UTC),
@@ -82,8 +78,7 @@ async def fetch_and_store_account_snapshot(
maker_fee=maker_fee, maker_fee=maker_fee,
taker_fee=taker_fee, taker_fee=taker_fee,
thirty_day_volume=thirty_day_volume, thirty_day_volume=thirty_day_volume,
trade_balance_raw=balance_data if isinstance( trade_balance_raw=balance_data if isinstance(balance_data, dict) else None,
balance_data, dict) else None,
fee_schedule_raw=fee_schedule if fee_schedule else None, fee_schedule_raw=fee_schedule if fee_schedule else None,
) )
@@ -107,8 +102,7 @@ async def fetch_and_store_account_snapshot(
"INSERT INTO portfolio_snapshots" "INSERT INTO portfolio_snapshots"
" (snapshot_at, balances, total_value_usd) VALUES ($1, $2, $3)", " (snapshot_at, balances, total_value_usd) VALUES ($1, $2, $3)",
datetime.now(UTC), datetime.now(UTC),
orjson.dumps(wallet_balances).decode( orjson.dumps(wallet_balances).decode("utf-8") if wallet_balances else None,
"utf-8") if wallet_balances else None,
total_value, total_value,
) )
_LOG.info("portfolio_snapshot_stored", total_value_usd=total_value) _LOG.info("portfolio_snapshot_stored", total_value_usd=total_value)
@@ -127,8 +121,7 @@ async def run_fee_sync_loop(
Runs until stop_event is set. Runs until stop_event is set.
""" """
_LOG.info("fee_sync_loop_started", _LOG.info("fee_sync_loop_started", interval_s=_FEE_REFRESH_INTERVAL_SECONDS)
interval_s=_FEE_REFRESH_INTERVAL_SECONDS)
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
+7 -14
View File
@@ -47,15 +47,13 @@ class TriangularExecutionSequencer:
rest_client: SupportsOrderPlacement, rest_client: SupportsOrderPlacement,
*, *,
available_pairs: Sequence[str], available_pairs: Sequence[str],
volume_for_leg: Callable[[OpportunityEvent, volume_for_leg: Callable[[OpportunityEvent, ExecutionLeg, int], float] | None = None,
ExecutionLeg, int], float] | None = None,
execution_writer: AsyncExecutionWriter | None = None, execution_writer: AsyncExecutionWriter | None = None,
alert_notifier: SupportsAlerts | None = None, alert_notifier: SupportsAlerts | None = None,
audit_repository: AuditRepository | None = None, audit_repository: AuditRepository | None = None,
) -> None: ) -> None:
self._rest_client = rest_client self._rest_client = rest_client
self._available_pairs = {self._normalize_pair( self._available_pairs = {self._normalize_pair(pair) for pair in available_pairs}
pair) for pair in available_pairs}
self._volume_for_leg = volume_for_leg or self._default_volume_for_leg self._volume_for_leg = volume_for_leg or self._default_volume_for_leg
self._execution_writer = execution_writer self._execution_writer = execution_writer
self._alert_notifier = alert_notifier self._alert_notifier = alert_notifier
@@ -102,15 +100,12 @@ class TriangularExecutionSequencer:
raise ValueError(f"No tradable pair for leg {from_cur}->{to_cur}") raise ValueError(f"No tradable pair for leg {from_cur}->{to_cur}")
def _build_legs(self, event: OpportunityEvent) -> tuple[ExecutionLeg, ...]: def _build_legs(self, event: OpportunityEvent) -> tuple[ExecutionLeg, ...]:
currencies = [part.strip().upper() currencies = [part.strip().upper() for part in event.cycle.split("->") if part.strip()]
for part in event.cycle.split("->") if part.strip()]
if len(currencies) < 4 or currencies[0] != currencies[-1]: if len(currencies) < 4 or currencies[0] != currencies[-1]:
raise ValueError( raise ValueError("cycle must be a closed triangular path like A->B->C->A")
"cycle must be a closed triangular path like A->B->C->A")
if len(currencies) != 4: if len(currencies) != 4:
raise ValueError( raise ValueError("cycle must contain exactly three unique currencies")
"cycle must contain exactly three unique currencies")
legs: list[ExecutionLeg] = [] legs: list[ExecutionLeg] = []
for idx in range(3): for idx in range(3):
@@ -125,8 +120,7 @@ class TriangularExecutionSequencer:
) )
volume = self._volume_for_leg(event, placeholder_leg, idx) volume = self._volume_for_leg(event, placeholder_leg, idx)
if volume <= 0.0: if volume <= 0.0:
raise ValueError( raise ValueError("volume_for_leg must return a positive volume")
"volume_for_leg must return a positive volume")
legs.append(self._resolve_leg(from_currency, to_currency, volume)) legs.append(self._resolve_leg(from_currency, to_currency, volume))
return tuple(legs) return tuple(legs)
@@ -215,8 +209,7 @@ class TriangularExecutionSequencer:
responses.append(response) responses.append(response)
if self._execution_writer is not None: if self._execution_writer is not None:
order_ref = self._order_ref_from_response( order_ref = self._order_ref_from_response(response, f"leg-{idx}")
response, f"leg-{idx}")
await self._execution_writer.enqueue( await self._execution_writer.enqueue(
OrderRecord( OrderRecord(
trade_ref=trade_ref, trade_ref=trade_ref,
+1 -1
View File
@@ -1 +1 @@
"""Logging package — DB sink, maintenance tasks.""" """Logging package — DB sink, maintenance tasks."""
+4 -10
View File
@@ -22,8 +22,7 @@ class DbSinkProcessor:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._queue: asyncio.Queue[dict[str, Any] self._queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=2000)
] = asyncio.Queue(maxsize=2000)
self._consumer_task: asyncio.Task[None] | None = None self._consumer_task: asyncio.Task[None] | None = None
def __call__(self, logger: Any, method_name: str, event_dict: dict[str, Any]) -> dict[str, Any]: def __call__(self, logger: Any, method_name: str, event_dict: dict[str, Any]) -> dict[str, Any]:
@@ -38,9 +37,7 @@ class DbSinkProcessor:
"""Start background consumer task.""" """Start background consumer task."""
if self._consumer_task is not None and not self._consumer_task.done(): if self._consumer_task is not None and not self._consumer_task.done():
return return
self._consumer_task = asyncio.create_task( self._consumer_task = asyncio.create_task(self._consume(store), name="log_db_sink")
self._consume(store), name="log_db_sink"
)
async def stop_consumer(self) -> None: async def stop_consumer(self) -> None:
"""Drain queue and cancel consumer.""" """Drain queue and cancel consumer."""
@@ -81,8 +78,7 @@ class DbSinkProcessor:
level = str(event.pop("level", "info")).upper() level = str(event.pop("level", "info")).upper()
logger = str(event.pop("logger", "root")) logger = str(event.pop("logger", "root"))
message = str(event.pop("event", event.pop("message", ""))) message = str(event.pop("event", event.pop("message", "")))
context = {k: v for k, v in event.items( context = {k: v for k, v in event.items() if not k.startswith("_")} if event else None
) if not k.startswith("_")} if event else None
record = LogRecord( record = LogRecord(
recorded_at=recorded_at, recorded_at=recorded_at,
@@ -118,8 +114,6 @@ def get_db_sink() -> DbSinkProcessor:
return _db_sink return _db_sink
def db_sink_processor( def db_sink_processor(logger: Any, method_name: str, event_dict: dict[str, Any]) -> dict[str, Any]:
logger: Any, method_name: str, event_dict: dict[str, Any]
) -> dict[str, Any]:
"""Standalone processor function wrapping the singleton.""" """Standalone processor function wrapping the singleton."""
return _db_sink(logger, method_name, event_dict) return _db_sink(logger, method_name, event_dict)
+1 -2
View File
@@ -36,8 +36,7 @@ async def run_log_archive(store: PgStore, retention_days: int = _RETENTION_DAYS)
repo = LogArchiveRepository(store) repo = LogArchiveRepository(store)
count = await repo.archive_before(cutoff) count = await repo.archive_before(cutoff)
if count > 0: if count > 0:
_LOG.info("log_archive_complete", _LOG.info("log_archive_complete", cutoff=cutoff.isoformat(), archived=count)
cutoff=cutoff.isoformat(), archived=count)
return count return count
+5 -10
View File
@@ -38,8 +38,7 @@ class MarketDataFeed:
opportunity_writer: AsyncOpportunityWriter | None = None, opportunity_writer: AsyncOpportunityWriter | None = None,
paper_trading_mode: bool = True, paper_trading_mode: bool = True,
opportunity_executor: ( opportunity_executor: (
Callable[[OpportunityEvent], Callable[[OpportunityEvent], Awaitable[ExecutionOutcome | float | None]] | None
Awaitable[ExecutionOutcome | float | None]] | None
) = None, ) = None,
trade_capital: float = 1.0, trade_capital: float = 1.0,
max_trade_capital: float | None = None, max_trade_capital: float | None = None,
@@ -93,8 +92,7 @@ class MarketDataFeed:
return {} return {}
start = currencies[0] start = currencies[0]
exposure_assets = { exposure_assets = {currency for currency in currencies[1:] if currency != start}
currency for currency in currencies[1:] if currency != start}
return {asset: event.allocated_capital for asset in exposure_assets} return {asset: event.allocated_capital for asset in exposure_assets}
async def run(self) -> None: async def run(self) -> None:
@@ -315,8 +313,7 @@ class MarketDataFeed:
continue continue
if self._pre_trade_validator is not None and self._balance_provider is not None: if self._pre_trade_validator is not None and self._balance_provider is not None:
required_balances = { required_balances = {self._quote_balance_asset: event.allocated_capital}
self._quote_balance_asset: event.allocated_capital}
balances = { balances = {
asset.upper(): amount asset.upper(): amount
for asset, amount in self._balance_provider().items() for asset, amount in self._balance_provider().items()
@@ -384,8 +381,7 @@ class MarketDataFeed:
outcome = await self._opportunity_executor(event) outcome = await self._opportunity_executor(event)
except Exception as exc: except Exception as exc:
if self._trade_limits_guard is not None: if self._trade_limits_guard is not None:
self._trade_limits_guard.close_trade( self._trade_limits_guard.close_trade(exposure_by_asset)
exposure_by_asset)
dispatch_alert_nowait( dispatch_alert_nowait(
self._alert_notifier, self._alert_notifier,
@@ -451,8 +447,7 @@ class MarketDataFeed:
realized_pnl = outcome realized_pnl = outcome
if realized_pnl is not None and self._loss_limit_guard is not None: if realized_pnl is not None and self._loss_limit_guard is not None:
self._loss_limit_guard.register_realized_pnl( self._loss_limit_guard.register_realized_pnl(realized_pnl)
realized_pnl)
if self._loss_limit_guard.is_halted: if self._loss_limit_guard.is_halted:
_LOG.warning( _LOG.warning(
"loss_limit_halt_triggered", "loss_limit_halt_triggered",
+1 -2
View File
@@ -86,8 +86,7 @@ class OrderBook:
BookLevel(price=price, volume=self._bids[price]) BookLevel(price=price, volume=self._bids[price])
for price in reversed(bid_keys[-depth:]) for price in reversed(bid_keys[-depth:])
] ]
asks = [BookLevel(price=price, volume=self._asks[price]) asks = [BookLevel(price=price, volume=self._asks[price]) for price in ask_keys[:depth]]
for price in ask_keys[:depth]]
return bids, asks return bids, asks
def compute_checksum(self, depth: int = 10) -> int: def compute_checksum(self, depth: int = 10) -> int:
+41 -22
View File
@@ -51,23 +51,34 @@ class MetricsCalculator:
WHERE volume > 0 AND filled_volume IS NOT NULL WHERE volume > 0 AND filled_volume IS NOT NULL
""") """)
r_pnl_usd = float( r_pnl_usd = (
tm["realized_pnl_usd"]) if tm and tm["realized_pnl_usd"] is not None else 0.0 float(tm["realized_pnl_usd"]) if tm and tm["realized_pnl_usd"] is not None else 0.0
tt = int(tm["total_trades"] )
) if tm and tm["total_trades"] is not None else 0 tt = int(tm["total_trades"]) if tm and tm["total_trades"] is not None else 0
wt = int(tm["winning_trades"] wt = int(tm["winning_trades"]) if tm and tm["winning_trades"] is not None else 0
) if tm and tm["winning_trades"] is not None else 0
wr = wt / tt if tt > 0 else None wr = wt / tt if tt > 0 else None
atd = float(tm["avg_trade_duration_seconds"] atd = (
) if tm and tm["avg_trade_duration_seconds"] is not None else None float(tm["avg_trade_duration_seconds"])
if tm and tm["avg_trade_duration_seconds"] is not None
else None
)
oc = int(om["opportunity_count"] oc = (
) if om is not None and om["opportunity_count"] is not None else 0 int(om["opportunity_count"])
fo = om["first_detected_at"] if om is not None and isinstance( if om is not None and om["opportunity_count"] is not None
om["first_detected_at"], datetime) else None else 0
lo = om["last_detected_at"] if om is not None and isinstance( )
om["last_detected_at"], datetime) else None fo = (
om["first_detected_at"]
if om is not None and isinstance(om["first_detected_at"], datetime)
else None
)
lo = (
om["last_detected_at"]
if om is not None and isinstance(om["last_detected_at"], datetime)
else None
)
opportunities_per_minute: float | None opportunities_per_minute: float | None
if oc >= 2 and fo is not None and lo is not None: if oc >= 2 and fo is not None and lo is not None:
@@ -80,15 +91,23 @@ class MetricsCalculator:
else: else:
opportunities_per_minute = None opportunities_per_minute = None
fill_rate = float( fill_rate = float(fm["fill_rate"]) if fm and fm["fill_rate"] is not None else None
fm["fill_rate"]) if fm and fm["fill_rate"] is not None else None
lp50 = float(tm["latency_p50_seconds"] lp50 = (
) if tm and tm["latency_p50_seconds"] is not None else None float(tm["latency_p50_seconds"])
lp95 = float(tm["latency_p95_seconds"] if tm and tm["latency_p50_seconds"] is not None
) if tm and tm["latency_p95_seconds"] is not None else None else None
lp99 = float(tm["latency_p99_seconds"] )
) if tm and tm["latency_p99_seconds"] is not None else None lp95 = (
float(tm["latency_p95_seconds"])
if tm and tm["latency_p95_seconds"] is not None
else None
)
lp99 = (
float(tm["latency_p99_seconds"])
if tm and tm["latency_p99_seconds"] is not None
else None
)
return PerformanceMetrics( return PerformanceMetrics(
realized_pnl_usd=r_pnl_usd, realized_pnl_usd=r_pnl_usd,
+3 -1
View File
@@ -106,7 +106,9 @@ async def _run_startup_reconciler(app: FastAPI) -> None:
await result await result
async def persist_runtime_snapshot(app: FastAPI, *, note: str | None = None) -> RuntimeStateRecord | None: async def persist_runtime_snapshot(
app: FastAPI, *, note: str | None = None
) -> RuntimeStateRecord | None:
repository = _runtime_repository(app) repository = _runtime_repository(app)
if repository is None: if repository is None:
return None return None
+1 -2
View File
@@ -36,8 +36,7 @@ class AsyncExecutionWriter:
async def start(self) -> None: async def start(self) -> None:
if self._task is None or self._task.done(): if self._task is None or self._task.done():
self._stop.clear() self._stop.clear()
self._task = asyncio.create_task( self._task = asyncio.create_task(self._run(), name="execution-writer")
self._run(), name="execution-writer")
async def stop(self) -> None: async def stop(self) -> None:
self._stop.set() self._stop.set()
+3 -6
View File
@@ -24,16 +24,14 @@ class MarketSnapshot:
class AsyncMarketSnapshotWriter: class AsyncMarketSnapshotWriter:
def __init__(self, repository: MarketSnapshotRepository, max_queue_size: int = 50_000) -> None: def __init__(self, repository: MarketSnapshotRepository, max_queue_size: int = 50_000) -> None:
self._repository = repository self._repository = repository
self._queue: asyncio.Queue[MarketSnapshot] = asyncio.Queue( self._queue: asyncio.Queue[MarketSnapshot] = asyncio.Queue(maxsize=max_queue_size)
maxsize=max_queue_size)
self._task: asyncio.Task[None] | None = None self._task: asyncio.Task[None] | None = None
self._stop = asyncio.Event() self._stop = asyncio.Event()
async def start(self) -> None: async def start(self) -> None:
if self._task is None or self._task.done(): if self._task is None or self._task.done():
self._stop.clear() self._stop.clear()
self._task = asyncio.create_task( self._task = asyncio.create_task(self._run(), name="market-snapshot-writer")
self._run(), name="market-snapshot-writer")
async def stop(self) -> None: async def stop(self) -> None:
self._stop.set() self._stop.set()
@@ -61,7 +59,6 @@ class AsyncMarketSnapshotWriter:
) )
) )
except Exception as exc: except Exception as exc:
_LOG.error("market_snapshot_write_failed", _LOG.error("market_snapshot_write_failed", error=str(exc), symbol=item.symbol)
error=str(exc), symbol=item.symbol)
finally: finally:
self._queue.task_done() self._queue.task_done()
+2 -4
View File
@@ -13,16 +13,14 @@ _LOG = structlog.get_logger(__name__)
class AsyncOpportunityWriter: class AsyncOpportunityWriter:
def __init__(self, repository: OpportunityRepository, max_queue_size: int = 50_000) -> None: def __init__(self, repository: OpportunityRepository, max_queue_size: int = 50_000) -> None:
self._repository = repository self._repository = repository
self._queue: asyncio.Queue[OpportunityEvent] = asyncio.Queue( self._queue: asyncio.Queue[OpportunityEvent] = asyncio.Queue(maxsize=max_queue_size)
maxsize=max_queue_size)
self._task: asyncio.Task[None] | None = None self._task: asyncio.Task[None] | None = None
self._stop = asyncio.Event() self._stop = asyncio.Event()
async def start(self) -> None: async def start(self) -> None:
if self._task is None or self._task.done(): if self._task is None or self._task.done():
self._stop.clear() self._stop.clear()
self._task = asyncio.create_task( self._task = asyncio.create_task(self._run(), name="opportunity-writer")
self._run(), name="opportunity-writer")
async def stop(self) -> None: async def stop(self) -> None:
self._stop.set() self._stop.set()
+1 -3
View File
@@ -128,7 +128,5 @@ class PgStore:
col_name = column_def.split()[0] col_name = column_def.split()[0]
if col_name not in existing: if col_name not in existing:
async with self.pool.acquire() as conn: async with self.pool.acquire() as conn:
await conn.execute( await conn.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_def}")
f"ALTER TABLE {table_name} ADD COLUMN {column_def}"
)
_LOG.info("pg_column_added", table=table_name, column=col_name) _LOG.info("pg_column_added", table=table_name, column=col_name)
+73 -39
View File
@@ -258,11 +258,7 @@ class AuditRepository:
record.actor, record.actor,
record.event_type, record.event_type,
record.decision, record.decision,
( (None if record.payload is None else orjson.dumps(record.payload).decode("utf-8")),
None
if record.payload is None
else orjson.dumps(record.payload).decode("utf-8")
),
record.correlation_id, record.correlation_id,
) )
@@ -294,8 +290,9 @@ class AuditRepository:
event_type=str(row["event_type"]), event_type=str(row["event_type"]),
decision=str(row["decision"]), decision=str(row["decision"]),
payload=payload, payload=payload,
correlation_id=str( correlation_id=(
row["correlation_id"]) if row["correlation_id"] is not None else None, str(row["correlation_id"]) if row["correlation_id"] is not None else None
),
) )
) )
@@ -364,8 +361,9 @@ class RuntimeStateRepository:
snapshot_at=row["snapshot_at"], snapshot_at=row["snapshot_at"],
is_running=bool(row["is_running"]), is_running=bool(row["is_running"]),
kill_switch_active=bool(row["kill_switch_active"]), kill_switch_active=bool(row["kill_switch_active"]),
kill_switch_reason=str( kill_switch_reason=(
row["kill_switch_reason"]) if row["kill_switch_reason"] is not None else None, str(row["kill_switch_reason"]) if row["kill_switch_reason"] is not None else None
),
open_trade_count=int(row["open_trade_count"]), open_trade_count=int(row["open_trade_count"]),
last_known_balances=balances, last_known_balances=balances,
note=str(row["note"]) if row["note"] is not None else None, note=str(row["note"]) if row["note"] is not None else None,
@@ -386,10 +384,16 @@ class ConfigSectionRepository:
VALUES ($1, $2) VALUES ($1, $2)
RETURNING id, name, description, updated_at RETURNING id, name, description, updated_at
""", """,
section.name, section.description, section.name,
section.description,
) )
if row: if row:
return ConfigSection(id=row["id"], name=row["name"], description=row["description"], updated_at=row["updated_at"]) return ConfigSection(
id=row["id"],
name=row["name"],
description=row["description"],
updated_at=row["updated_at"],
)
raise ValueError("Failed to create section") raise ValueError("Failed to create section")
async def get_section(self, name: str) -> ConfigSection | None: async def get_section(self, name: str) -> ConfigSection | None:
@@ -404,7 +408,12 @@ class ConfigSectionRepository:
name, name,
) )
if row: if row:
return ConfigSection(id=row["id"], name=row["name"], description=row["description"], updated_at=row["updated_at"]) return ConfigSection(
id=row["id"],
name=row["name"],
description=row["description"],
updated_at=row["updated_at"],
)
return None return None
async def list_sections(self) -> list[ConfigSection]: async def list_sections(self) -> list[ConfigSection]:
@@ -417,7 +426,11 @@ class ConfigSectionRepository:
""") """)
return [ return [
ConfigSection( ConfigSection(
id=r["id"], name=r["name"], description=r["description"], updated_at=r["updated_at"]) id=r["id"],
name=r["name"],
description=r["description"],
updated_at=r["updated_at"],
)
for r in rows for r in rows
] ]
@@ -571,7 +584,7 @@ class ConfigSettingRepository:
ts = row["latest_updated_at"] ts = row["latest_updated_at"]
if isinstance(ts, str): if isinstance(ts, str):
return datetime.fromisoformat(ts.replace("Z", "+00:00")) return datetime.fromisoformat(ts.replace("Z", "+00:00"))
return ts # type: ignore[no-any-return] return ts # type: ignore[no-any-return]
return None return None
@@ -614,7 +627,8 @@ class ConfigPairingRepository:
FROM config_pairings FROM config_pairings
WHERE base_asset = $1 AND quote_asset = $2 WHERE base_asset = $1 AND quote_asset = $2
""", """,
base_asset, quote_asset, base_asset,
quote_asset,
) )
if row: if row:
return ConfigPairing( return ConfigPairing(
@@ -665,7 +679,8 @@ class ConfigPairingRepository:
DELETE FROM config_pairings DELETE FROM config_pairings
WHERE base_asset = $1 AND quote_asset = $2 WHERE base_asset = $1 AND quote_asset = $2
""", """,
base_asset, quote_asset, base_asset,
quote_asset,
) )
if result is None: if result is None:
return False return False
@@ -732,7 +747,9 @@ class ConfigBacktestingDefaultsRepository:
def __init__(self, store: PgStore) -> None: def __init__(self, store: PgStore) -> None:
self._store = store self._store = store
async def create_defaults(self, defaults: ConfigBacktestingDefaults) -> ConfigBacktestingDefaults: async def create_defaults(
self, defaults: ConfigBacktestingDefaults
) -> ConfigBacktestingDefaults:
"""Create new backtesting defaults.""" """Create new backtesting defaults."""
async with self._store.pool.acquire() as conn: async with self._store.pool.acquire() as conn:
balances_json = ( balances_json = (
@@ -754,8 +771,9 @@ class ConfigBacktestingDefaultsRepository:
) )
if row: if row:
return ConfigBacktestingDefaults( return ConfigBacktestingDefaults(
starting_balances=orjson.loads( starting_balances=(
row["starting_balances"]) if row["starting_balances"] else None, orjson.loads(row["starting_balances"]) if row["starting_balances"] else None
),
trade_capital=row["trade_capital"], trade_capital=row["trade_capital"],
min_profit_threshold=row["min_profit_threshold"], min_profit_threshold=row["min_profit_threshold"],
slippage_bps=row["slippage_bps"], slippage_bps=row["slippage_bps"],
@@ -774,8 +792,9 @@ class ConfigBacktestingDefaultsRepository:
""") """)
if row: if row:
return ConfigBacktestingDefaults( return ConfigBacktestingDefaults(
starting_balances=orjson.loads( starting_balances=(
row["starting_balances"]) if row["starting_balances"] else None, orjson.loads(row["starting_balances"]) if row["starting_balances"] else None
),
trade_capital=row["trade_capital"], trade_capital=row["trade_capital"],
min_profit_threshold=row["min_profit_threshold"], min_profit_threshold=row["min_profit_threshold"],
slippage_bps=row["slippage_bps"], slippage_bps=row["slippage_bps"],
@@ -783,7 +802,9 @@ class ConfigBacktestingDefaultsRepository:
) )
return None return None
async def update_defaults(self, defaults: ConfigBacktestingDefaults) -> ConfigBacktestingDefaults: async def update_defaults(
self, defaults: ConfigBacktestingDefaults
) -> ConfigBacktestingDefaults:
"""Update the backtesting defaults.""" """Update the backtesting defaults."""
async with self._store.pool.acquire() as conn: async with self._store.pool.acquire() as conn:
starting_balances_json = ( starting_balances_json = (
@@ -808,8 +829,9 @@ class ConfigBacktestingDefaultsRepository:
) )
if row: if row:
return ConfigBacktestingDefaults( return ConfigBacktestingDefaults(
starting_balances=orjson.loads( starting_balances=(
row["starting_balances"]) if row["starting_balances"] else None, orjson.loads(row["starting_balances"]) if row["starting_balances"] else None
),
trade_capital=row["trade_capital"], trade_capital=row["trade_capital"],
min_profit_threshold=row["min_profit_threshold"], min_profit_threshold=row["min_profit_threshold"],
slippage_bps=row["slippage_bps"], slippage_bps=row["slippage_bps"],
@@ -878,10 +900,12 @@ class KrakenAccountSnapshotRepository:
maker_fee=row["maker_fee"], maker_fee=row["maker_fee"],
taker_fee=row["taker_fee"], taker_fee=row["taker_fee"],
thirty_day_volume=row["thirty_day_volume"], thirty_day_volume=row["thirty_day_volume"],
trade_balance_raw=orjson.loads( trade_balance_raw=(
row["trade_balance_raw"]) if row["trade_balance_raw"] else None, orjson.loads(row["trade_balance_raw"]) if row["trade_balance_raw"] else None
fee_schedule_raw=orjson.loads( ),
row["fee_schedule_raw"]) if row["fee_schedule_raw"] else None, fee_schedule_raw=(
orjson.loads(row["fee_schedule_raw"]) if row["fee_schedule_raw"] else None
),
) )
@@ -916,7 +940,8 @@ class BacktestJobRepository:
VALUES ($1, $2) VALUES ($1, $2)
RETURNING id, status, events_path, config, created_at RETURNING id, status, events_path, config, created_at
""", """,
events_path, job_config_json, events_path,
job_config_json,
) )
if row is None: if row is None:
raise ValueError("Failed to create backtest job") raise ValueError("Failed to create backtest job")
@@ -933,24 +958,30 @@ class BacktestJobRepository:
if status == "running": if status == "running":
await conn.execute( await conn.execute(
"UPDATE backtest_jobs SET status = $1, started_at = CURRENT_TIMESTAMP WHERE id = $2", "UPDATE backtest_jobs SET status = $1, started_at = CURRENT_TIMESTAMP WHERE id = $2",
status, job_id, status,
job_id,
) )
elif status in ("completed", "failed"): elif status in ("completed", "failed"):
await conn.execute( await conn.execute(
"UPDATE backtest_jobs SET status = $1, finished_at = CURRENT_TIMESTAMP, error = $2 WHERE id = $3", "UPDATE backtest_jobs SET status = $1, finished_at = CURRENT_TIMESTAMP, error = $2 WHERE id = $3",
status, error, job_id, status,
error,
job_id,
) )
else: else:
await conn.execute( await conn.execute(
"UPDATE backtest_jobs SET status = $1, error = $2 WHERE id = $3", "UPDATE backtest_jobs SET status = $1, error = $2 WHERE id = $3",
status, error, job_id, status,
error,
job_id,
) )
async def store_report(self, job_id: str, report: dict[str, Any]) -> None: async def store_report(self, job_id: str, report: dict[str, Any]) -> None:
async with self._store.pool.acquire() as conn: async with self._store.pool.acquire() as conn:
await conn.execute( await conn.execute(
"UPDATE backtest_jobs SET report = $1 WHERE id = $2", "UPDATE backtest_jobs SET report = $1 WHERE id = $2",
orjson.dumps(report).decode("utf-8"), job_id, orjson.dumps(report).decode("utf-8"),
job_id,
) )
async def get_job(self, job_id: str) -> BacktestJobRecord | None: async def get_job(self, job_id: str) -> BacktestJobRecord | None:
@@ -1082,7 +1113,9 @@ class LogRepository:
ORDER BY recorded_at DESC ORDER BY recorded_at DESC
LIMIT ${idx} OFFSET ${idx + 1} LIMIT ${idx} OFFSET ${idx + 1}
""", """,
*params, limit, offset, *params,
limit,
offset,
) )
return [ return [
LogRecord( LogRecord(
@@ -1153,9 +1186,7 @@ class LogArchiveRepository:
cutoff, cutoff,
) )
# Delete originals # Delete originals
await conn.execute( await conn.execute("DELETE FROM app_logs WHERE recorded_at < $1", cutoff)
"DELETE FROM app_logs WHERE recorded_at < $1", cutoff
)
if isinstance(result, str): if isinstance(result, str):
parts = result.split() parts = result.split()
if len(parts) == 2 and parts[0] == "INSERT": if len(parts) == 2 and parts[0] == "INSERT":
@@ -1221,7 +1252,9 @@ class LogAggregationRepository:
ORDER BY bucket_start DESC ORDER BY bucket_start DESC
LIMIT $3 LIMIT $3
""", """,
period, level.upper(), limit, period,
level.upper(),
limit,
) )
else: else:
rows = await conn.fetch( rows = await conn.fetch(
@@ -1232,7 +1265,8 @@ class LogAggregationRepository:
ORDER BY bucket_start DESC ORDER BY bucket_start DESC
LIMIT $2 LIMIT $2
""", """,
period, limit, period,
limit,
) )
return [ return [
LogAggregateRecord( LogAggregateRecord(
+1 -1
View File
@@ -1 +1 @@
"""Integration tests for PostgreSQL schema and connectivity.""" """Integration tests for PostgreSQL schema and connectivity."""
+2 -6
View File
@@ -11,13 +11,9 @@ import pathlib
import pytest import pytest
def pytest_ignore_collect( def pytest_ignore_collect(collection_path: pathlib.Path, config: pytest.Config) -> bool:
collection_path: pathlib.Path, config: pytest.Config
) -> bool:
"""Skip integration tests unless --integration is passed.""" """Skip integration tests unless --integration is passed."""
if "integration" in str(collection_path) and not config.getoption( if "integration" in str(collection_path) and not config.getoption("--integration", False):
"--integration", False
):
return True return True
return False return False
+60 -10
View File
@@ -42,9 +42,24 @@ async def test_metrics_calculator_summarizes_execution_data() -> None:
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9), ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9),
($10, $11, $12, $13, $14, $15, $16, $17, $18) ($10, $11, $12, $13, $14, $15, $16, $17, $18)
""", """,
"trade-1", started, finished, "filled", 12.5, 10.0, 100.0, "USD->BTC->ETH->USD", 3, "trade-1",
"trade-2", started_two, finished_two, "filled", - started,
4.5, -2.0, 200.0, "USD->ETH->BTC->USD", 3, finished,
"filled",
12.5,
10.0,
100.0,
"USD->BTC->ETH->USD",
3,
"trade-2",
started_two,
finished_two,
"filled",
-4.5,
-2.0,
200.0,
"USD->ETH->BTC->USD",
3,
) )
await conn.execute( await conn.execute(
""" """
@@ -53,11 +68,24 @@ async def test_metrics_calculator_summarizes_execution_data() -> None:
($7, $8, $9, $10, $11, $12), ($7, $8, $9, $10, $11, $12),
($13, $14, $15, $16, $17, $18) ($13, $14, $15, $16, $17, $18)
""", """,
started, "USD->BTC->ETH->USD", 4.0, 3.0, 0.03, True, started,
started_two, "USD->ETH->BTC->USD", 2.0, 1.0, 0.01, False, "USD->BTC->ETH->USD",
started_two + 4.0,
timedelta( 3.0,
seconds=30), "USD->BTC->ETH->USD", 5.0, 4.0, 0.04, True, 0.03,
True,
started_two,
"USD->ETH->BTC->USD",
2.0,
1.0,
0.01,
False,
started_two + timedelta(seconds=30),
"USD->BTC->ETH->USD",
5.0,
4.0,
0.04,
True,
) )
await conn.execute( await conn.execute(
""" """
@@ -67,8 +95,30 @@ async def test_metrics_calculator_summarizes_execution_data() -> None:
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12), ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12),
($13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24) ($13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
""", """,
"trade-1", "order-1", 0, "BTC/USD", "buy", 2.0, 101, "closed", 2.0, 100.0, "{}", started, "trade-1",
"trade-2", "order-2", 0, "ETH/USD", "sell", 4.0, 202, "closed", 3.0, 200.0, "{}", started_two, "order-1",
0,
"BTC/USD",
"buy",
2.0,
101,
"closed",
2.0,
100.0,
"{}",
started,
"trade-2",
"order-2",
0,
"ETH/USD",
"sell",
4.0,
202,
"closed",
3.0,
200.0,
"{}",
started_two,
) )
metrics = await MetricsCalculator(store).compute() metrics = await MetricsCalculator(store).compute()
+99 -35
View File
@@ -25,50 +25,116 @@ EXPECTED_TABLES: dict[str, list[str]] = {
"schema_migrations": ["version", "applied_at"], "schema_migrations": ["version", "applied_at"],
"config_sections": ["id", "name", "description", "updated_at"], "config_sections": ["id", "name", "description", "updated_at"],
"config_settings": [ "config_settings": [
"key", "section", "value_json", "value_type", "is_secret", "key",
"is_runtime_reloadable", "updated_at", "updated_by", "section",
"value_json",
"value_type",
"is_secret",
"is_runtime_reloadable",
"updated_at",
"updated_by",
], ],
"config_pairings": [ "config_pairings": [
"id", "base_asset", "quote_asset", "enabled", "source", "id",
"created_at", "updated_at", "base_asset",
"quote_asset",
"enabled",
"source",
"created_at",
"updated_at",
], ],
"config_backtesting_defaults": [ "config_backtesting_defaults": [
"id", "starting_balances", "trade_capital", "min_profit_threshold", "id",
"slippage_bps", "execution_latency_ms", "fee_source", "starting_balances",
"trade_capital",
"min_profit_threshold",
"slippage_bps",
"execution_latency_ms",
"fee_source",
], ],
"opportunities": [ "opportunities": [
"id", "detected_at", "cycle", "gross_pct", "net_pct", "id",
"est_profit", "executed", "detected_at",
"cycle",
"gross_pct",
"net_pct",
"est_profit",
"executed",
], ],
"trades": [ "trades": [
"id", "trade_ref", "started_at", "finished_at", "status", "id",
"realized_pnl", "estimated_pnl", "capital_used", "cycle", "leg_count", "trade_ref",
"started_at",
"finished_at",
"status",
"realized_pnl",
"estimated_pnl",
"capital_used",
"cycle",
"leg_count",
], ],
"orders": [ "orders": [
"id", "trade_ref", "order_ref", "leg_index", "pair", "side", "id",
"volume", "user_ref", "status", "filled_volume", "avg_price", "trade_ref",
"raw_response", "recorded_at", "order_ref",
"leg_index",
"pair",
"side",
"volume",
"user_ref",
"status",
"filled_volume",
"avg_price",
"raw_response",
"recorded_at",
], ],
"pnl_events": [ "pnl_events": [
"id", "trade_ref", "recorded_at", "kind", "pnl_usd", "source", "id",
"trade_ref",
"recorded_at",
"kind",
"pnl_usd",
"source",
], ],
"portfolio_snapshots": ["snapshot_at", "balances", "total_value_usd"], "portfolio_snapshots": ["snapshot_at", "balances", "total_value_usd"],
"market_snapshots": ["snapshot_at", "symbol", "source", "payload", "latency_ms"], "market_snapshots": ["snapshot_at", "symbol", "source", "payload", "latency_ms"],
"audit_events": [ "audit_events": [
"id", "occurred_at", "actor", "event_type", "decision", "id",
"payload", "correlation_id", "occurred_at",
"actor",
"event_type",
"decision",
"payload",
"correlation_id",
], ],
"runtime_state_snapshots": [ "runtime_state_snapshots": [
"snapshot_at", "is_running", "kill_switch_active", "kill_switch_reason", "snapshot_at",
"open_trade_count", "last_known_balances", "note", "is_running",
"kill_switch_active",
"kill_switch_reason",
"open_trade_count",
"last_known_balances",
"note",
], ],
"kraken_account_snapshots": [ "kraken_account_snapshots": [
"snapshot_at", "fee_tier", "maker_fee", "taker_fee", "snapshot_at",
"thirty_day_volume", "trade_balance_raw", "fee_schedule_raw", "fee_tier",
"maker_fee",
"taker_fee",
"thirty_day_volume",
"trade_balance_raw",
"fee_schedule_raw",
], ],
"backtest_jobs": [ "backtest_jobs": [
"id", "status", "events_path", "config", "report", "error", "id",
"created_at", "started_at", "finished_at", "status",
"events_path",
"config",
"report",
"error",
"created_at",
"started_at",
"finished_at",
], ],
} }
@@ -96,6 +162,7 @@ TABLES_WITH_UNIQUE_CONSTRAINTS: dict[str, list[str]] = {
# ── fixtures ──────────────────────────────────────────────────────────────── # ── fixtures ────────────────────────────────────────────────────────────────
@asynccontextmanager @asynccontextmanager
async def _pg_lifecycle() -> AsyncIterator[PgStore]: async def _pg_lifecycle() -> AsyncIterator[PgStore]:
"""Connect, yield store, then disconnect.""" """Connect, yield store, then disconnect."""
@@ -116,6 +183,7 @@ async def pg_fixture() -> AsyncIterator[PgStore]:
# ── helpers ───────────────────────────────────────────────────────────────── # ── helpers ─────────────────────────────────────────────────────────────────
async def _get_actual_tables(store: PgStore) -> dict[str, list[str]]: async def _get_actual_tables(store: PgStore) -> dict[str, list[str]]:
"""Return {table_name: [column_name, ...]} for the public schema.""" """Return {table_name: [column_name, ...]} for the public schema."""
actual: dict[str, list[str]] = {} actual: dict[str, list[str]] = {}
@@ -139,6 +207,7 @@ async def _table_row_count(store: PgStore, table: str) -> int:
# ── tests ─────────────────────────────────────────────────────────────────── # ── tests ───────────────────────────────────────────────────────────────────
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pg_connect(pg: PgStore) -> None: async def test_pg_connect(pg: PgStore) -> None:
"""Can connect to PostgreSQL and ping the server.""" """Can connect to PostgreSQL and ping the server."""
@@ -165,8 +234,7 @@ async def test_schema_migration_applies(pg: PgStore) -> None:
for table in EXPECTED_TABLES: for table in EXPECTED_TABLES:
assert table in actual, ( assert table in actual, (
f"Table '{table}' missing after migration. " f"Table '{table}' missing after migration. " f"Found tables: {sorted(actual)}"
f"Found tables: {sorted(actual)}"
) )
@@ -190,8 +258,7 @@ async def test_table_columns(pg: PgStore) -> None:
actual_cols = actual.get(table, []) actual_cols = actual.get(table, [])
for col in expected_cols: for col in expected_cols:
assert col in actual_cols, ( assert col in actual_cols, (
f"Column '{col}' missing from table '{table}'. " f"Column '{col}' missing from table '{table}'. " f"Actual columns: {actual_cols}"
f"Actual columns: {actual_cols}"
) )
@@ -250,8 +317,7 @@ async def test_table_row_count_is_zero(pg: PgStore) -> None:
for table in EXPECTED_TABLES: for table in EXPECTED_TABLES:
count = await _table_row_count(pg, table) count = await _table_row_count(pg, table)
assert count == 0, ( assert count == 0, (
f"Table '{table}' should be empty after migration, " f"Table '{table}' should be empty after migration, " f"but has {count} rows"
f"but has {count} rows"
) )
@@ -262,13 +328,10 @@ async def test_schema_migration_version_recorded(pg: PgStore) -> None:
await pg.migrate() await pg.migrate()
async with pg.pool.acquire() as conn: async with pg.pool.acquire() as conn:
row = await conn.fetchrow( row = await conn.fetchrow("SELECT MAX(version) AS v FROM schema_migrations")
"SELECT MAX(version) AS v FROM schema_migrations"
)
assert row is not None assert row is not None
assert row["v"] == SCHEMA_VERSION, ( assert row["v"] == SCHEMA_VERSION, (
f"Expected schema version {SCHEMA_VERSION}, " f"Expected schema version {SCHEMA_VERSION}, " f"got {row['v']}"
f"got {row['v']}"
) )
@@ -280,7 +343,8 @@ async def test_create_and_query_row(pg: PgStore) -> None:
# ConfigSections round-trip # ConfigSections round-trip
await conn.execute( await conn.execute(
"INSERT INTO config_sections (name, description) VALUES ($1, $2)", "INSERT INTO config_sections (name, description) VALUES ($1, $2)",
"test_section", "A test section for integration test", "test_section",
"A test section for integration test",
) )
row = await conn.fetchrow( row = await conn.fetchrow(
"SELECT name, description FROM config_sections WHERE name = $1", "SELECT name, description FROM config_sections WHERE name = $1",
@@ -357,4 +421,4 @@ async def test_audit_list_recent(pg: PgStore) -> None:
# Verify payload serialization worked # Verify payload serialization worked
first = recent[0] first = recent[0]
if first.payload: if first.payload:
assert "index" in first.payload assert "index" in first.payload
+1 -2
View File
@@ -39,8 +39,7 @@ async def test_end_to_end_config_workflow():
# Mock the setting creation # Mock the setting creation
mock_created_setting = Mock() mock_created_setting = Mock()
mock_created_setting.updated_at = "2023-01-01T00:00:00" mock_created_setting.updated_at = "2023-01-01T00:00:00"
mock_repo_instance.create_setting = AsyncMock( mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting)
return_value=mock_created_setting)
mock_repo_instance.get_setting = AsyncMock(return_value=None) mock_repo_instance.get_setting = AsyncMock(return_value=None)
mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=None) mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=None)
mock_repo_instance.list_settings = AsyncMock(return_value=[]) mock_repo_instance.list_settings = AsyncMock(return_value=[])
+3 -7
View File
@@ -136,10 +136,8 @@ async def test_config_setting_repository_list_settings(mock_store):
repo = ConfigSettingRepository(mock_store) repo = ConfigSettingRepository(mock_store)
conn = await mock_store.pool.acquire().__aenter__() conn = await mock_store.pool.acquire().__aenter__()
row1 = _make_row({**SETTING_ROW, "key": "test_key1", row1 = _make_row({**SETTING_ROW, "key": "test_key1", "value_json": "test_value1"})
"value_json": "test_value1"}) row2 = _make_row({**SETTING_ROW, "key": "test_key2", "value_json": "test_value2"})
row2 = _make_row({**SETTING_ROW, "key": "test_key2",
"value_json": "test_value2"})
conn.fetch = AsyncMock(return_value=[row1, row2]) conn.fetch = AsyncMock(return_value=[row1, row2])
result = await repo.list_settings() result = await repo.list_settings()
@@ -176,9 +174,7 @@ async def test_config_pairing_repository_create_pairing(mock_store):
conn = await mock_store.pool.acquire().__aenter__() conn = await mock_store.pool.acquire().__aenter__()
conn.fetchrow = AsyncMock(return_value=_make_row(PAIRING_ROW)) conn.fetchrow = AsyncMock(return_value=_make_row(PAIRING_ROW))
pairing = ConfigPairing( pairing = ConfigPairing(base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken")
base_asset="BTC", quote_asset="USD", enabled=True, source="Kraken"
)
result = await repo.create_pairing(pairing) result = await repo.create_pairing(pairing)
+11 -12
View File
@@ -63,8 +63,7 @@ async def test_configuration_service_set_setting(mock_settings, mock_store, mock
mock_created_setting = Mock() mock_created_setting = Mock()
mock_created_setting.updated_at = "2023-01-01T00:00:00" mock_created_setting.updated_at = "2023-01-01T00:00:00"
mock_repo_instance.create_setting = AsyncMock( mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting)
return_value=mock_created_setting)
mock_repo_instance.get_setting = AsyncMock(return_value=None) mock_repo_instance.get_setting = AsyncMock(return_value=None)
await service.set_setting("test_key", "test_value", "test_user") await service.set_setting("test_key", "test_value", "test_user")
@@ -73,7 +72,9 @@ async def test_configuration_service_set_setting(mock_settings, mock_store, mock
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_configuration_service_hot_reload_detection(mock_settings, mock_store, mock_audit_repo): async def test_configuration_service_hot_reload_detection(
mock_settings, mock_store, mock_audit_repo
):
"""Test hot-reload detection functionality.""" """Test hot-reload detection functionality."""
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
@@ -86,8 +87,7 @@ async def test_configuration_service_hot_reload_detection(mock_settings, mock_st
from datetime import datetime from datetime import datetime
mock_repo_instance.get_latest_updated_at = AsyncMock( mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=datetime.now())
return_value=datetime.now())
assert await service.is_config_outdated() is True assert await service.is_config_outdated() is True
@@ -105,8 +105,7 @@ async def test_configuration_service_reload_if_changed(mock_settings, mock_store
from datetime import datetime from datetime import datetime
mock_repo_instance.get_latest_updated_at = AsyncMock( mock_repo_instance.get_latest_updated_at = AsyncMock(return_value=datetime.now())
return_value=datetime.now())
result = await service.reload_if_changed() result = await service.reload_if_changed()
assert result is True assert result is True
@@ -125,8 +124,7 @@ async def test_configuration_service_get_config_version(mock_settings, mock_stor
mock_created_setting = Mock() mock_created_setting = Mock()
mock_created_setting.updated_at = "2023-01-01T00:00:00" mock_created_setting.updated_at = "2023-01-01T00:00:00"
mock_repo_instance.create_setting = AsyncMock( mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting)
return_value=mock_created_setting)
mock_repo_instance.get_setting = AsyncMock(return_value=None) mock_repo_instance.get_setting = AsyncMock(return_value=None)
await service.set_setting("test_key", "test_value", "test_user") await service.set_setting("test_key", "test_value", "test_user")
@@ -134,7 +132,9 @@ async def test_configuration_service_get_config_version(mock_settings, mock_stor
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_configuration_service_get_last_updated_at(mock_settings, mock_store, mock_audit_repo): async def test_configuration_service_get_last_updated_at(
mock_settings, mock_store, mock_audit_repo
):
"""Test getting last updated timestamp.""" """Test getting last updated timestamp."""
service = ConfigurationService(mock_settings, mock_store, mock_audit_repo) service = ConfigurationService(mock_settings, mock_store, mock_audit_repo)
assert service.get_last_updated_at() is None assert service.get_last_updated_at() is None
@@ -145,8 +145,7 @@ async def test_configuration_service_get_last_updated_at(mock_settings, mock_sto
mock_created_setting = Mock() mock_created_setting = Mock()
mock_created_setting.updated_at = "2023-01-01T00:00:00" mock_created_setting.updated_at = "2023-01-01T00:00:00"
mock_repo_instance.create_setting = AsyncMock( mock_repo_instance.create_setting = AsyncMock(return_value=mock_created_setting)
return_value=mock_created_setting)
mock_repo_instance.get_setting = AsyncMock(return_value=None) mock_repo_instance.get_setting = AsyncMock(return_value=None)
await service.set_setting("test_key", "test_value", "test_user") await service.set_setting("test_key", "test_value", "test_user")
+3 -7
View File
@@ -49,9 +49,7 @@ def _mock_pg_store():
@pytest.fixture @pytest.fixture
def app(): def app():
"""Create a test app with a mocked PgStore and audit repository.""" """Create a test app with a mocked PgStore and audit repository."""
a = create_app( a = create_app(Settings(_env_file=None, APP_MODE="paper", paper_trading_mode=True))
Settings(_env_file=None, APP_MODE="paper", paper_trading_mode=True)
)
a.state.store = _mock_pg_store() a.state.store = _mock_pg_store()
a.state.runtime_state_repository.insert = AsyncMock() a.state.runtime_state_repository.insert = AsyncMock()
a.state.runtime_state_repository.latest = AsyncMock(return_value=None) a.state.runtime_state_repository.latest = AsyncMock(return_value=None)
@@ -69,16 +67,14 @@ async def test_persist_runtime_snapshot_writes_record(app) -> None:
# Mock _open_trade_count → 0, _latest_balances → None # Mock _open_trade_count → 0, _latest_balances → None
conn = await app.state.store.pool.acquire().__aenter__() conn = await app.state.store.pool.acquire().__aenter__()
conn.fetchrow = AsyncMock(return_value=MagicMock( conn.fetchrow = AsyncMock(return_value=MagicMock(**{"__getitem__": lambda s, k: 0}))
**{"__getitem__": lambda s, k: 0}))
snapshot = await persist_runtime_snapshot(app, note="unit-test") snapshot = await persist_runtime_snapshot(app, note="unit-test")
assert snapshot is not None assert snapshot is not None
assert snapshot.note == "unit-test" assert snapshot.note == "unit-test"
app.state.runtime_state_repository.latest = AsyncMock( app.state.runtime_state_repository.latest = AsyncMock(return_value=snapshot)
return_value=snapshot)
latest = await app.state.runtime_state_repository.latest() latest = await app.state.runtime_state_repository.latest()
assert latest is not None assert latest is not None
assert latest.note == "unit-test" assert latest.note == "unit-test"