"""Tests for the ``backend.scripts.osm_refresh`` orchestration script.

Each stage module's ``main`` (``stations_import``, ``tracks_import``,
``stations_load``, ``tracks_load``) is replaced with a fake via
``monkeypatch`` so these tests exercise only the stage planning and
sequencing performed by ``osm_refresh``.
"""

from __future__ import annotations

from argparse import Namespace
from collections.abc import Callable
from pathlib import Path

import pytest

from backend.scripts import osm_refresh


def _namespace(output_dir: Path, **overrides: object) -> Namespace:
    """Return a parsed-arguments stand-in with the default refresh options.

    Args:
        output_dir: Value used for the ``output_dir`` option.
        **overrides: Options that replace the defaults below by name.
    """
    defaults: dict[str, object] = {
        "region": "all",
        "output_dir": output_dir,
        "stations_json": None,
        "tracks_json": None,
        "skip_station_import": False,
        "skip_station_load": False,
        "skip_track_import": False,
        "skip_track_load": False,
        "dry_run": False,
        "commit": True,
    }
    defaults.update(overrides)
    return Namespace(**defaults)


def _patch_runners(
    monkeypatch: pytest.MonkeyPatch,
    *,
    stations_import: Callable[[list[str] | None], int],
    tracks_import: Callable[[list[str] | None], int],
    stations_load: Callable[[list[str] | None], int],
    tracks_load: Callable[[list[str] | None], int],
) -> None:
    """Replace the ``main`` entry point of every stage module with a fake."""
    monkeypatch.setattr(osm_refresh.stations_import, "main", stations_import)
    monkeypatch.setattr(osm_refresh.tracks_import, "main", tracks_import)
    monkeypatch.setattr(osm_refresh.stations_load, "main", stations_load)
    monkeypatch.setattr(osm_refresh.tracks_load, "main", tracks_load)


def _mentions_path(args: list[str], expected: Path) -> bool:
    """Return True if any CLI argument refers to the same file as *expected*.

    Both sides are resolved so the check does not depend on whether the
    script normalises paths (temporary directories can sit behind symlinks).
    """
    target = expected.resolve()
    return any(Path(arg).resolve() == target for arg in args)


def test_build_stage_plan_default_sequence(tmp_path: Path) -> None:
    """Without skip flags, all four stages run and share the default JSON paths."""
    stages = osm_refresh._build_stage_plan(_namespace(tmp_path))

    labels = [stage.label for stage in stages]
    assert labels == [
        "Import stations",
        "Load stations",
        "Import tracks",
        "Load tracks",
    ]

    expected_station_path = tmp_path / "osm_stations.json"
    expected_track_path = tmp_path / "osm_tracks.json"
    # Each load stage must consume exactly what its import stage produces.
    assert stages[0].output_path == expected_station_path
    assert stages[1].input_path == expected_station_path
    assert stages[2].output_path == expected_track_path
    assert stages[3].input_path == expected_track_path


def test_build_stage_plan_respects_skip_flags(tmp_path: Path) -> None:
    """Skipped import stages are dropped from the plan."""
    stages = osm_refresh._build_stage_plan(
        _namespace(
            tmp_path,
            skip_station_import=True,
            skip_track_import=True,
        )
    )

    labels = [stage.label for stage in stages]
    assert labels == ["Load stations", "Load tracks"]


def test_main_dry_run_lists_plan(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """``--dry-run`` prints the plan and invokes no stage runner."""

    def fail(_args: list[str] | None) -> int:  # pragma: no cover - defensive
        raise AssertionError("runner should not be invoked during dry run")

    _patch_runners(
        monkeypatch,
        stations_import=fail,
        tracks_import=fail,
        stations_load=fail,
        tracks_load=fail,
    )

    exit_code = osm_refresh.main(["--dry-run", "--output-dir", str(tmp_path)])

    assert exit_code == 0
    captured = capsys.readouterr().out
    assert "Dry run" in captured
    assert "Import stations" in captured
    assert "Load tracks" in captured


def test_main_executes_stages_in_order(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Stages run import -> load per dataset, and loaders get the imported files."""
    calls: list[str] = []
    load_args: dict[str, list[str]] = {}

    def make_import(name: str) -> Callable[[list[str] | None], int]:
        def runner(args: list[str] | None) -> int:
            assert args is not None
            calls.append(name)
            # Simulate the importer producing its JSON artefact so the
            # following load stage has a real file to consume.
            output_path = Path(args[args.index("--output") + 1])
            output_path.write_text("{}", encoding="utf-8")
            return 0

        return runner

    def make_load(name: str) -> Callable[[list[str] | None], int]:
        def runner(args: list[str] | None) -> int:
            assert args is not None
            calls.append(name)
            load_args[name] = list(args)
            return 0

        return runner

    _patch_runners(
        monkeypatch,
        stations_import=make_import("stations_import"),
        tracks_import=make_import("tracks_import"),
        stations_load=make_load("stations_load"),
        tracks_load=make_load("tracks_load"),
    )

    exit_code = osm_refresh.main(["--output-dir", str(tmp_path)])

    assert exit_code == 0
    assert calls == [
        "stations_import",
        "stations_load",
        "tracks_import",
        "tracks_load",
    ]
    # Each loader must be pointed at the file its import stage wrote.
    assert _mentions_path(load_args["stations_load"], tmp_path / "osm_stations.json")
    assert _mentions_path(load_args["tracks_load"], tmp_path / "osm_tracks.json")


def test_main_skip_import_flags(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Skipping imports feeds the user-supplied JSON files straight to the loaders."""
    station_json = tmp_path / "stations.json"
    station_json.write_text("{}", encoding="utf-8")
    track_json = tmp_path / "tracks.json"
    track_json.write_text("{}", encoding="utf-8")

    def fail(_args: list[str] | None) -> int:  # pragma: no cover - defensive
        raise AssertionError("import stage should be skipped")

    calls: list[str] = []
    load_args: dict[str, list[str]] = {}

    def record(name: str) -> Callable[[list[str] | None], int]:
        def runner(args: list[str] | None) -> int:
            assert args is not None
            calls.append(name)
            load_args[name] = list(args)
            return 0

        return runner

    _patch_runners(
        monkeypatch,
        stations_import=fail,
        tracks_import=fail,
        stations_load=record("stations_load"),
        tracks_load=record("tracks_load"),
    )

    exit_code = osm_refresh.main(
        [
            # Keep any output-directory handling inside the test sandbox
            # instead of the script's default location.
            "--output-dir",
            str(tmp_path),
            "--skip-station-import",
            "--skip-track-import",
            "--stations-json",
            str(station_json),
            "--tracks-json",
            str(track_json),
        ]
    )

    assert exit_code == 0
    assert calls == ["stations_load", "tracks_load"]
    # The explicitly supplied JSON files, not the defaults, must reach the loaders.
    assert _mentions_path(load_args["stations_load"], station_json)
    assert _mentions_path(load_args["tracks_load"], track_json)