refactor: simplify stage plan return type and enhance test coverage for OSM refresh
This commit is contained in:
@@ -93,7 +93,7 @@ def build_argument_parser() -> argparse.ArgumentParser:
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def _build_stage_plan(args: argparse.Namespace) -> tuple[list[Stage], Path, Path]:
|
def _build_stage_plan(args: argparse.Namespace) -> list[Stage]:
|
||||||
station_json = args.stations_json or args.output_dir / "osm_stations.json"
|
station_json = args.stations_json or args.output_dir / "osm_stations.json"
|
||||||
track_json = args.tracks_json or args.output_dir / "osm_tracks.json"
|
track_json = args.tracks_json or args.output_dir / "osm_tracks.json"
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ def _build_stage_plan(args: argparse.Namespace) -> tuple[list[Stage], Path, Path
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return stages, station_json, track_json
|
return stages
|
||||||
|
|
||||||
|
|
||||||
def _describe_plan(stages: Sequence[Stage]) -> None:
|
def _describe_plan(stages: Sequence[Stage]) -> None:
|
||||||
@@ -183,19 +183,13 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
parser = build_argument_parser()
|
parser = build_argument_parser()
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
|
|
||||||
stages, station_json, track_json = _build_stage_plan(args)
|
stages = _build_stage_plan(args)
|
||||||
|
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
print("Dry run: the following stages would run in order.")
|
print("Dry run: the following stages would run in order.")
|
||||||
_describe_plan(stages)
|
_describe_plan(stages)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Ensure parent directories exist when we plan to write files.
|
|
||||||
if not args.skip_station_import:
|
|
||||||
station_json.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
if not args.skip_track_import:
|
|
||||||
track_json.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
for stage in stages:
|
for stage in stages:
|
||||||
_execute_stage(stage)
|
_execute_stage(stage)
|
||||||
|
|
||||||
|
|||||||
158
backend/tests/test_osm_refresh.py
Normal file
158
backend/tests/test_osm_refresh.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from argparse import Namespace
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.scripts import osm_refresh
|
||||||
|
|
||||||
|
|
||||||
|
def _namespace(output_dir: Path, **overrides: object) -> Namespace:
|
||||||
|
defaults: dict[str, object] = {
|
||||||
|
"region": "all",
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"stations_json": None,
|
||||||
|
"tracks_json": None,
|
||||||
|
"skip_station_import": False,
|
||||||
|
"skip_station_load": False,
|
||||||
|
"skip_track_import": False,
|
||||||
|
"skip_track_load": False,
|
||||||
|
"dry_run": False,
|
||||||
|
"commit": True,
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return Namespace(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_stage_plan_default_sequence(tmp_path: Path) -> None:
|
||||||
|
stages = osm_refresh._build_stage_plan(_namespace(tmp_path))
|
||||||
|
|
||||||
|
labels = [stage.label for stage in stages]
|
||||||
|
assert labels == [
|
||||||
|
"Import stations",
|
||||||
|
"Load stations",
|
||||||
|
"Import tracks",
|
||||||
|
"Load tracks",
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_station_path = tmp_path / "osm_stations.json"
|
||||||
|
expected_track_path = tmp_path / "osm_tracks.json"
|
||||||
|
|
||||||
|
assert stages[0].output_path == expected_station_path
|
||||||
|
assert stages[1].input_path == expected_station_path
|
||||||
|
assert stages[2].output_path == expected_track_path
|
||||||
|
assert stages[3].input_path == expected_track_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_stage_plan_respects_skip_flags(tmp_path: Path) -> None:
|
||||||
|
stages = osm_refresh._build_stage_plan(
|
||||||
|
_namespace(
|
||||||
|
tmp_path,
|
||||||
|
skip_station_import=True,
|
||||||
|
skip_track_import=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = [stage.label for stage in stages]
|
||||||
|
assert labels == ["Load stations", "Load tracks"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_dry_run_lists_plan(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
|
||||||
|
def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive
|
||||||
|
raise AssertionError("runner should not be invoked during dry run")
|
||||||
|
|
||||||
|
monkeypatch.setattr(osm_refresh.stations_import, "main", fail)
|
||||||
|
monkeypatch.setattr(osm_refresh.tracks_import, "main", fail)
|
||||||
|
monkeypatch.setattr(osm_refresh.stations_load, "main", fail)
|
||||||
|
monkeypatch.setattr(osm_refresh.tracks_load, "main", fail)
|
||||||
|
|
||||||
|
exit_code = osm_refresh.main(["--dry-run", "--output-dir", str(tmp_path)])
|
||||||
|
|
||||||
|
assert exit_code == 0
|
||||||
|
captured = capsys.readouterr().out
|
||||||
|
assert "Dry run" in captured
|
||||||
|
assert "Import stations" in captured
|
||||||
|
assert "Load tracks" in captured
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_executes_stages_in_order(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||||
|
calls: list[str] = []
|
||||||
|
|
||||||
|
def make_import(name: str):
|
||||||
|
def runner(args: list[str] | None) -> int:
|
||||||
|
assert args is not None
|
||||||
|
calls.append(name)
|
||||||
|
output_index = args.index("--output") + 1
|
||||||
|
output_path = Path(args[output_index])
|
||||||
|
output_path.write_text("{}", encoding="utf-8")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return runner
|
||||||
|
|
||||||
|
def make_load(name: str):
|
||||||
|
def runner(args: list[str] | None) -> int:
|
||||||
|
assert args is not None
|
||||||
|
calls.append(name)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return runner
|
||||||
|
|
||||||
|
monkeypatch.setattr(osm_refresh.stations_import, "main",
|
||||||
|
make_import("stations_import"))
|
||||||
|
monkeypatch.setattr(osm_refresh.tracks_import, "main",
|
||||||
|
make_import("tracks_import"))
|
||||||
|
monkeypatch.setattr(osm_refresh.stations_load, "main",
|
||||||
|
make_load("stations_load"))
|
||||||
|
monkeypatch.setattr(osm_refresh.tracks_load, "main",
|
||||||
|
make_load("tracks_load"))
|
||||||
|
|
||||||
|
exit_code = osm_refresh.main(["--output-dir", str(tmp_path)])
|
||||||
|
|
||||||
|
assert exit_code == 0
|
||||||
|
assert calls == [
|
||||||
|
"stations_import",
|
||||||
|
"stations_load",
|
||||||
|
"tracks_import",
|
||||||
|
"tracks_load",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_skip_import_flags(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||||
|
station_json = tmp_path / "stations.json"
|
||||||
|
station_json.write_text("{}", encoding="utf-8")
|
||||||
|
track_json = tmp_path / "tracks.json"
|
||||||
|
track_json.write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
def fail(_args: list[str] | None) -> int: # pragma: no cover - defensive
|
||||||
|
raise AssertionError("import stage should be skipped")
|
||||||
|
|
||||||
|
calls: list[str] = []
|
||||||
|
|
||||||
|
def record(name: str):
|
||||||
|
def runner(args: list[str] | None) -> int:
|
||||||
|
assert args is not None
|
||||||
|
calls.append(name)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return runner
|
||||||
|
|
||||||
|
monkeypatch.setattr(osm_refresh.stations_import, "main", fail)
|
||||||
|
monkeypatch.setattr(osm_refresh.tracks_import, "main", fail)
|
||||||
|
monkeypatch.setattr(osm_refresh.stations_load,
|
||||||
|
"main", record("stations_load"))
|
||||||
|
monkeypatch.setattr(osm_refresh.tracks_load, "main", record("tracks_load"))
|
||||||
|
|
||||||
|
exit_code = osm_refresh.main(
|
||||||
|
[
|
||||||
|
"--skip-station-import",
|
||||||
|
"--skip-track-import",
|
||||||
|
"--stations-json",
|
||||||
|
str(station_json),
|
||||||
|
"--tracks-json",
|
||||||
|
str(track_json),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exit_code == 0
|
||||||
|
assert calls == ["stations_load", "tracks_load"]
|
||||||
Reference in New Issue
Block a user