diff --git a/backend/scripts/osm_refresh.py b/backend/scripts/osm_refresh.py index 7d6c0d3..b389fff 100644 --- a/backend/scripts/osm_refresh.py +++ b/backend/scripts/osm_refresh.py @@ -93,7 +93,7 @@ def build_argument_parser() -> argparse.ArgumentParser: return parser -def _build_stage_plan(args: argparse.Namespace) -> tuple[list[Stage], Path, Path]: +def _build_stage_plan(args: argparse.Namespace) -> list[Stage]: station_json = args.stations_json or args.output_dir / "osm_stations.json" track_json = args.tracks_json or args.output_dir / "osm_tracks.json" @@ -145,7 +145,7 @@ def _build_stage_plan(args: argparse.Namespace) -> tuple[list[Stage], Path, Path ) ) - return stages, station_json, track_json + return stages def _describe_plan(stages: Sequence[Stage]) -> None: @@ -183,19 +183,13 @@ def main(argv: list[str] | None = None) -> int: parser = build_argument_parser() args = parser.parse_args(argv) - stages, station_json, track_json = _build_stage_plan(args) + stages = _build_stage_plan(args) if args.dry_run: print("Dry run: the following stages would run in order.") _describe_plan(stages) return 0 - # Ensure parent directories exist when we plan to write files. - if not args.skip_station_import: - station_json.parent.mkdir(parents=True, exist_ok=True) - if not args.skip_track_import: - track_json.parent.mkdir(parents=True, exist_ok=True) - for stage in stages: _execute_stage(stage) diff --git a/backend/tests/test_osm_refresh.py b/backend/tests/test_osm_refresh.py new file mode 100644 index 0000000..675acb0 --- /dev/null +++ b/backend/tests/test_osm_refresh.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from argparse import Namespace +from pathlib import Path + +import pytest + +from backend.scripts import osm_refresh + + +def _namespace(output_dir: Path, **overrides: object) -> Namespace: + defaults: dict[str, object] = { + "region": "all", + "output_dir": output_dir, + "stations_json": None, + "tracks_json": None, + "skip_station_import": False, + "skip_station_load": False, + "skip_track_import": False, + "skip_track_load": False, + "dry_run": False, + "commit": True, + } + defaults.update(overrides) + return Namespace(**defaults) + + +def test_build_stage_plan_default_sequence(tmp_path: Path) -> None: + stages = osm_refresh._build_stage_plan(_namespace(tmp_path)) + + labels = [stage.label for stage in stages] + assert labels == [ + "Import stations", + "Load stations", + "Import tracks", + "Load tracks", + ] + + expected_station_path = tmp_path / "osm_stations.json" + expected_track_path = tmp_path / "osm_tracks.json" + + assert stages[0].output_path == expected_station_path + assert stages[1].input_path == expected_station_path + assert stages[2].output_path == expected_track_path + assert stages[3].input_path == expected_track_path + + +def test_build_stage_plan_respects_skip_flags(tmp_path: Path) -> None: + stages = osm_refresh._build_stage_plan( + _namespace( + tmp_path, + skip_station_import=True, + skip_track_import=True, + ) + ) + + labels = [stage.label for stage in stages] + assert labels == ["Load stations", "Load tracks"] + + +def test_main_dry_run_lists_plan(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: + def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive + raise AssertionError("runner should not be invoked during dry run") + + monkeypatch.setattr(osm_refresh.stations_import, "main", fail) + monkeypatch.setattr(osm_refresh.tracks_import, "main", fail) + monkeypatch.setattr(osm_refresh.stations_load, "main", fail) + monkeypatch.setattr(osm_refresh.tracks_load, "main", fail) + + exit_code = osm_refresh.main(["--dry-run", "--output-dir", str(tmp_path)]) + + assert exit_code == 0 + captured = capsys.readouterr().out + assert "Dry run" in captured + assert "Import stations" in captured + assert "Load tracks" in captured + + +def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + calls: list[str] = [] + + def make_import(name: str): + def runner(args: list[str] | None) -> int: + assert args is not None + calls.append(name) + output_index = args.index("--output") + 1 + output_path = Path(args[output_index]) + output_path.write_text("{}", encoding="utf-8") + return 0 + + return runner + + def make_load(name: str): + def runner(args: list[str] | None) -> int: + assert args is not None + calls.append(name) + return 0 + + return runner + + monkeypatch.setattr(osm_refresh.stations_import, "main", + make_import("stations_import")) + monkeypatch.setattr(osm_refresh.tracks_import, "main", + make_import("tracks_import")) + monkeypatch.setattr(osm_refresh.stations_load, "main", + make_load("stations_load")) + monkeypatch.setattr(osm_refresh.tracks_load, "main", + make_load("tracks_load")) + + exit_code = osm_refresh.main(["--output-dir", str(tmp_path)]) + + assert exit_code == 0 + assert calls == [ + "stations_import", + "stations_load", + "tracks_import", + "tracks_load", + ] + + +def test_main_skip_import_flags(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + station_json = tmp_path / "stations.json" + station_json.write_text("{}", encoding="utf-8") + track_json = tmp_path / "tracks.json" + track_json.write_text("{}", encoding="utf-8") + + def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive + raise AssertionError("import stage should be skipped") + + calls: list[str] = [] + + def record(name: str): + def runner(args: list[str] | None) -> int: + assert args is not None + calls.append(name) + return 0 + + return runner + + monkeypatch.setattr(osm_refresh.stations_import, "main", fail) + monkeypatch.setattr(osm_refresh.tracks_import, "main", fail) + monkeypatch.setattr(osm_refresh.stations_load, + "main", record("stations_load")) + monkeypatch.setattr(osm_refresh.tracks_load, "main", record("tracks_load")) + + exit_code = osm_refresh.main( + [ + "--skip-station-import", + "--skip-track-import", + "--stations-json", + str(station_json), + "--tracks-json", + str(track_json), + ] + ) + + assert exit_code == 0 + assert calls == ["stations_load", "tracks_load"]