from __future__ import annotations from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import UTC, datetime import pytest from arbitrade.config.settings import get_settings from arbitrade.detection.engine import OpportunityEvent from arbitrade.execution.sequencer import TriangularExecutionSequencer from arbitrade.storage.executions import AsyncExecutionWriter from arbitrade.storage.pg_store import PgStore from arbitrade.storage.repositories import OrderRepository, PnLRepository, TradeRepository pytestmark = pytest.mark.integration @dataclass(slots=True) class _FakeRestClient: calls: int = 0 async def place_market_order(self, *, pair: str, side: str, volume: float) -> dict[str, object]: self.calls += 1 return {"txid": [f"tx-{self.calls}"], "status": "submitted"} def _sample_event() -> OpportunityEvent: return OpportunityEvent( detected_at=datetime.now(UTC), cycle="USD->BTC->ETH->USD", updated_pair="BTC/USD", gross_rate=1.04, net_rate=1.03, gross_pct=4.0, net_pct=3.0, est_profit=0.03, ) @asynccontextmanager async def _pg() -> AsyncIterator[PgStore]: s = get_settings() store = PgStore(s) try: await store.start() await store.migrate() yield store finally: await store.stop() @pytest.mark.asyncio async def test_execution_writer_persists_trade_order_and_pnl() -> None: async with _pg() as store: writer = AsyncExecutionWriter( TradeRepository(store), OrderRepository(store), PnLRepository(store), max_queue_size=10, ) await writer.start() client = _FakeRestClient() sequencer = TriangularExecutionSequencer( client, available_pairs=["BTC/USD", "ETH/BTC", "ETH/USD"], execution_writer=writer, ) result = await sequencer.execute(_sample_event()) await writer.stop() assert result.success assert client.calls == 3 async with store.pool.acquire() as conn: trades = await conn.fetch( "SELECT trade_ref, status, estimated_pnl, capital_used, cycle, leg_count FROM trades" ) orders = await conn.fetch( "SELECT trade_ref, order_ref, leg_index, pair, side, volume, status " "FROM orders ORDER BY leg_index" ) pnls = await conn.fetch("SELECT trade_ref, kind, pnl_usd, source FROM pnl_events") assert len(trades) == 1 assert trades[0]["status"] == "filled" assert trades[0]["estimated_pnl"] == 0.03 assert trades[0]["capital_used"] == 1.0 assert trades[0]["cycle"] == "USD->BTC->ETH->USD" assert trades[0]["leg_count"] == 3 assert len(orders) == 3 assert orders[0]["leg_index"] == 0 assert orders[1]["leg_index"] == 1 assert orders[2]["leg_index"] == 2 assert orders[0]["status"] == "submitted" assert len(pnls) == 1 assert pnls[0]["kind"] == "estimated" assert pnls[0]["pnl_usd"] == 0.03