"""Tests for ``backend.scripts.stations_load`` parsing and loading helpers."""

from __future__ import annotations

from dataclasses import dataclass, field

import pytest

from backend.scripts import stations_load


def _central_entry() -> dict[str, object]:
    """Return a fresh minimal station entry used across the tests.

    A new dict is built on every call so one test can never observe
    mutations made by another.
    """
    return {
        "name": "Central",
        "latitude": 52.52,
        "longitude": 13.405,
        "osm_id": "123",
        "is_active": True,
    }


def test_parse_station_entries_returns_models() -> None:
    entries = [{**_central_entry(), "code": "BER", "elevation_m": 34.5}]

    parsed = stations_load._parse_station_entries(entries)

    assert parsed[0].name == "Central"
    assert parsed[0].latitude == 52.52
    assert parsed[0].osm_id == "123"


def test_parse_station_entries_invalid_raises_value_error() -> None:
    # The entry deliberately omits "name" (and "osm_id").
    entries = [
        {
            "latitude": 52.52,
            "longitude": 13.405,
            "is_active": True,
        }
    ]

    with pytest.raises(ValueError):
        stations_load._parse_station_entries(entries)


@dataclass
class DummySession:
    """Stand-in session that records commit/rollback/close calls."""

    committed: bool = False
    rolled_back: bool = False
    closed: bool = False

    def __enter__(self) -> DummySession:
        return self

    def __exit__(self, exc_type: object, exc: object, traceback: object) -> None:
        self.closed = True

    def commit(self) -> None:
        self.committed = True

    def rollback(self) -> None:
        self.rolled_back = True


@dataclass
class DummyRepository:
    """Stand-in repository that remembers every payload passed to ``create``."""

    session: DummySession
    created: list[object] = field(default_factory=list)

    def create(self, data: object) -> None:  # pragma: no cover - simple delegation
        self.created.append(data)


class DummySessionFactory:
    """Callable replacement for ``SessionLocal``; each call yields a new session."""

    def __call__(self) -> DummySession:
        return DummySession()


def _patch_loader(monkeypatch: pytest.MonkeyPatch) -> list[DummyRepository]:
    """Replace the loader's ``SessionLocal`` and ``StationRepository`` with dummies.

    Returns the list that collects every repository the loader constructs,
    so a test can inspect the session each repository was handed.
    """
    repo_instances: list[DummyRepository] = []

    def fake_repo(session: DummySession) -> DummyRepository:
        repo = DummyRepository(session)
        repo_instances.append(repo)
        return repo

    monkeypatch.setattr(stations_load, "SessionLocal", DummySessionFactory())
    monkeypatch.setattr(stations_load, "StationRepository", fake_repo)
    return repo_instances


def test_load_stations_commits_when_requested(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    repo_instances = _patch_loader(monkeypatch)
    stations = stations_load._parse_station_entries([_central_entry()])

    created = stations_load.load_stations(stations, commit=True)

    assert created == 1
    session = repo_instances[0].session
    assert session.committed is True
    assert session.rolled_back is False
    assert len(repo_instances[0].created) == 1


def test_load_stations_rolls_back_when_no_commit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    repo_instances = _patch_loader(monkeypatch)
    stations = stations_load._parse_station_entries([_central_entry()])

    created = stations_load.load_stations(stations, commit=False)

    assert created == 1
    session = repo_instances[0].session
    assert session.committed is False
    assert session.rolled_back is True