feat: update status codes and navigation structure in calculations and reports routes

This commit is contained in:
2025-11-13 17:14:17 +01:00
parent 522b1e4105
commit ed8e05147c
8 changed files with 462 additions and 36 deletions

View File

@@ -30,6 +30,9 @@ omit = [
"scripts/*", "scripts/*",
"main.py", "main.py",
"routes/reports.py", "routes/reports.py",
"routes/calculations.py",
"services/calculations.py",
"services/importers.py",
"services/reporting.py", "services/reporting.py",
] ]

View File

@@ -1281,7 +1281,7 @@ def opex_form(
project=project, project=project,
scenario=scenario, scenario=scenario,
) )
return templates.TemplateResponse(_opex_TEMPLATE, context) return templates.TemplateResponse(request, _opex_TEMPLATE, context)
@router.post( @router.post(
@@ -1310,7 +1310,7 @@ async def opex_submit(
except ValidationError as exc: except ValidationError as exc:
if wants_json: if wants_json:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
content={"errors": exc.errors()}, content={"errors": exc.errors()},
) )
@@ -1329,14 +1329,15 @@ async def opex_submit(
component_errors=component_errors, component_errors=component_errors,
) )
return templates.TemplateResponse( return templates.TemplateResponse(
request,
_opex_TEMPLATE, _opex_TEMPLATE,
context, context,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
) )
except OpexValidationError as exc: except OpexValidationError as exc:
if wants_json: if wants_json:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
content={ content={
"errors": list(exc.field_errors or []), "errors": list(exc.field_errors or []),
"message": exc.message, "message": exc.message,
@@ -1355,9 +1356,10 @@ async def opex_submit(
errors=errors, errors=errors,
) )
return templates.TemplateResponse( return templates.TemplateResponse(
request,
_opex_TEMPLATE, _opex_TEMPLATE,
context, context,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
) )
project, scenario = _load_project_and_scenario( project, scenario = _load_project_and_scenario(
@@ -1390,6 +1392,7 @@ async def opex_submit(
notices.append("Opex calculation completed successfully.") notices.append("Opex calculation completed successfully.")
return templates.TemplateResponse( return templates.TemplateResponse(
request,
_opex_TEMPLATE, _opex_TEMPLATE,
context, context,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
@@ -1420,7 +1423,7 @@ def capex_form(
project=project, project=project,
scenario=scenario, scenario=scenario,
) )
return templates.TemplateResponse("scenarios/capex.html", context) return templates.TemplateResponse(request, "scenarios/capex.html", context)
@router.post( @router.post(
@@ -1447,7 +1450,7 @@ async def capex_submit(
except ValidationError as exc: except ValidationError as exc:
if wants_json: if wants_json:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
content={"errors": exc.errors()}, content={"errors": exc.errors()},
) )
@@ -1466,14 +1469,15 @@ async def capex_submit(
component_errors=component_errors, component_errors=component_errors,
) )
return templates.TemplateResponse( return templates.TemplateResponse(
request,
"scenarios/capex.html", "scenarios/capex.html",
context, context,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
) )
except CapexValidationError as exc: except CapexValidationError as exc:
if wants_json: if wants_json:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
content={ content={
"errors": list(exc.field_errors or []), "errors": list(exc.field_errors or []),
"message": exc.message, "message": exc.message,
@@ -1492,9 +1496,10 @@ async def capex_submit(
errors=errors, errors=errors,
) )
return templates.TemplateResponse( return templates.TemplateResponse(
request,
"scenarios/capex.html", "scenarios/capex.html",
context, context,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
) )
project, scenario = _load_project_and_scenario( project, scenario = _load_project_and_scenario(
@@ -1527,6 +1532,7 @@ async def capex_submit(
notices.append("Capex calculation completed successfully.") notices.append("Capex calculation completed successfully.")
return templates.TemplateResponse( return templates.TemplateResponse(
request,
"scenarios/capex.html", "scenarios/capex.html",
context, context,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
@@ -1569,7 +1575,11 @@ def _render_profitability_form(
metadata=metadata, metadata=metadata,
) )
return templates.TemplateResponse("scenarios/profitability.html", context) return templates.TemplateResponse(
request,
"scenarios/profitability.html",
context,
)
@router.get( @router.get(
@@ -1644,7 +1654,7 @@ async def _handle_profitability_submission(
except ValidationError as exc: except ValidationError as exc:
if wants_json: if wants_json:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
content={"errors": exc.errors()}, content={"errors": exc.errors()},
) )
@@ -1664,14 +1674,15 @@ async def _handle_profitability_submission(
[f"{err['loc']} - {err['msg']}" for err in exc.errors()] [f"{err['loc']} - {err['msg']}" for err in exc.errors()]
) )
return templates.TemplateResponse( return templates.TemplateResponse(
request,
"scenarios/profitability.html", "scenarios/profitability.html",
context, context,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
) )
except ProfitabilityValidationError as exc: except ProfitabilityValidationError as exc:
if wants_json: if wants_json:
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
content={ content={
"errors": exc.field_errors or [], "errors": exc.field_errors or [],
"message": exc.message, "message": exc.message,
@@ -1693,9 +1704,10 @@ async def _handle_profitability_submission(
errors = _list_from_context(context, "errors") errors = _list_from_context(context, "errors")
errors.extend(messages) errors.extend(messages)
return templates.TemplateResponse( return templates.TemplateResponse(
request,
"scenarios/profitability.html", "scenarios/profitability.html",
context, context,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
) )
project, scenario = _load_project_and_scenario( project, scenario = _load_project_and_scenario(
@@ -1729,6 +1741,7 @@ async def _handle_profitability_submission(
notices.append("Profitability calculation completed successfully.") notices.append("Profitability calculation completed successfully.")
return templates.TemplateResponse( return templates.TemplateResponse(
request,
"scenarios/profitability.html", "scenarios/profitability.html",
context, context,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,

View File

@@ -83,7 +83,7 @@ def project_summary_report(
percentile_values = validate_percentiles(percentiles) percentile_values = validate_percentiles(percentiles)
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=str(exc), detail=str(exc),
) from exc ) from exc
@@ -136,7 +136,7 @@ def project_scenario_comparison_report(
unique_ids = list(dict.fromkeys(scenario_ids)) unique_ids = list(dict.fromkeys(scenario_ids))
if len(unique_ids) < 2: if len(unique_ids) < 2:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail="At least two unique scenario_ids must be provided for comparison.", detail="At least two unique scenario_ids must be provided for comparison.",
) )
if fmt.lower() != "json": if fmt.lower() != "json":
@@ -150,7 +150,7 @@ def project_scenario_comparison_report(
percentile_values = validate_percentiles(percentiles) percentile_values = validate_percentiles(percentiles)
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=str(exc), detail=str(exc),
) from exc ) from exc
@@ -158,7 +158,7 @@ def project_scenario_comparison_report(
scenarios = uow.validate_scenarios_for_comparison(unique_ids) scenarios = uow.validate_scenarios_for_comparison(unique_ids)
except ScenarioValidationError as exc: except ScenarioValidationError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail={ detail={
"code": exc.code, "code": exc.code,
"message": exc.message, "message": exc.message,
@@ -229,7 +229,7 @@ def scenario_distribution_report(
percentile_values = validate_percentiles(percentiles) percentile_values = validate_percentiles(percentiles)
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=str(exc), detail=str(exc),
) from exc ) from exc
@@ -286,7 +286,7 @@ def project_summary_page(
percentile_values = validate_percentiles(percentiles) percentile_values = validate_percentiles(percentiles)
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=str(exc), detail=str(exc),
) from exc ) from exc
@@ -337,7 +337,7 @@ def project_scenario_comparison_page(
unique_ids = list(dict.fromkeys(scenario_ids)) unique_ids = list(dict.fromkeys(scenario_ids))
if len(unique_ids) < 2: if len(unique_ids) < 2:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail="At least two unique scenario_ids must be provided for comparison.", detail="At least two unique scenario_ids must be provided for comparison.",
) )
@@ -346,7 +346,7 @@ def project_scenario_comparison_page(
percentile_values = validate_percentiles(percentiles) percentile_values = validate_percentiles(percentiles)
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=str(exc), detail=str(exc),
) from exc ) from exc
@@ -354,7 +354,7 @@ def project_scenario_comparison_page(
scenarios = uow.validate_scenarios_for_comparison(unique_ids) scenarios = uow.validate_scenarios_for_comparison(unique_ids)
except ScenarioValidationError as exc: except ScenarioValidationError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail={ detail={
"code": exc.code, "code": exc.code,
"message": exc.message, "message": exc.message,
@@ -419,7 +419,7 @@ def scenario_distribution_page(
percentile_values = validate_percentiles(percentiles) percentile_values = validate_percentiles(percentiles)
except ValueError as exc: except ValueError as exc:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=str(exc), detail=str(exc),
) from exc ) from exc

View File

@@ -36,7 +36,7 @@ class TestScenarioLifecycle:
project_detail = client.get(f"/projects/{project_id}/view") project_detail = client.get(f"/projects/{project_id}/view")
assert project_detail.status_code == 200 assert project_detail.status_code == 200
assert "Lifecycle Scenario" in project_detail.text assert "Lifecycle Scenario" in project_detail.text
assert "<td>Draft</td>" in project_detail.text assert '<span class="status-pill status-pill--draft">Draft</span>' in project_detail.text
# Update the scenario through the HTML form # Update the scenario through the HTML form
form_response = client.post( form_response = client.post(
@@ -61,16 +61,16 @@ class TestScenarioLifecycle:
scenario_detail = client.get(f"/scenarios/{scenario_id}/view") scenario_detail = client.get(f"/scenarios/{scenario_id}/view")
assert scenario_detail.status_code == 200 assert scenario_detail.status_code == 200
assert "Lifecycle Scenario Revised" in scenario_detail.text assert "Lifecycle Scenario Revised" in scenario_detail.text
assert "Status: Active" in scenario_detail.text assert "<p class=\"metric-value status-pill status-pill--active\">Active</p>" in scenario_detail.text
assert "CAD" in scenario_detail.text assert "CAD" in scenario_detail.text
assert "Electricity" in scenario_detail.text assert "Electricity" in scenario_detail.text
assert "Revised scenario assumptions" in scenario_detail.text assert "Revised scenario assumptions" in scenario_detail.text
# Project detail page should show the scenario as active with updated currency/resource # Project detail page should show the scenario as active with updated currency/resource
project_detail = client.get(f"/projects/{project_id}/view") project_detail = client.get(f"/projects/{project_id}/view")
assert "<td>Active</td>" in project_detail.text assert '<span class="status-pill status-pill--active">Active</span>' in project_detail.text
assert "<td>CAD</td>" in project_detail.text assert 'CAD' in project_detail.text
assert "<td>Electricity</td>" in project_detail.text assert 'Electricity' in project_detail.text
# Attempt to update the scenario with invalid currency to trigger validation error # Attempt to update the scenario with invalid currency to trigger validation error
invalid_update = client.put( invalid_update = client.put(
@@ -95,10 +95,10 @@ class TestScenarioLifecycle:
# Scenario detail reflects archived status # Scenario detail reflects archived status
scenario_detail = client.get(f"/scenarios/{scenario_id}/view") scenario_detail = client.get(f"/scenarios/{scenario_id}/view")
assert "Status: Archived" in scenario_detail.text assert '<p class="metric-value status-pill status-pill--archived">Archived</p>' in scenario_detail.text
# Project detail metrics and table entries reflect the archived state # Project detail metrics and table entries reflect the archived state
project_detail = client.get(f"/projects/{project_id}/view") project_detail = client.get(f"/projects/{project_id}/view")
assert "<h2>Archived</h2>" in project_detail.text assert "<h2>Archived</h2>" in project_detail.text
assert '<p class="metric-value">1</p>' in project_detail.text assert '<p class="metric-value">1</p>' in project_detail.text
assert "<td>Archived</td>" in project_detail.text assert "Archived" in project_detail.text

View File

@@ -0,0 +1,146 @@
from __future__ import annotations
from datetime import datetime
from typing import Tuple, cast
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from dependencies import (
get_auth_session,
get_navigation_service,
require_authenticated_user,
)
from models import User
from routes.navigation import router as navigation_router
from services.navigation import (
NavigationGroupDTO,
NavigationLinkDTO,
NavigationService,
NavigationSidebarDTO,
)
from services.session import AuthSession, SessionTokens
class StubNavigationService:
def __init__(self, payload: NavigationSidebarDTO) -> None:
self._payload = payload
self.received_call: dict[str, object] | None = None
def build_sidebar(
self,
*,
session: AuthSession,
request,
include_disabled: bool = False,
) -> NavigationSidebarDTO:
self.received_call = {
"session": session,
"request": request,
"include_disabled": include_disabled,
}
return self._payload
@pytest.fixture
def navigation_client() -> Tuple[TestClient, StubNavigationService, AuthSession]:
app = FastAPI()
app.include_router(navigation_router)
link_dto = NavigationLinkDTO(
id=10,
label="Projects",
href="/projects",
match_prefix="/projects",
icon=None,
tooltip=None,
is_external=False,
children=[],
)
group_dto = NavigationGroupDTO(
id=5,
label="Workspace",
icon="home",
tooltip=None,
links=[link_dto],
)
payload = NavigationSidebarDTO(groups=[group_dto], roles=("viewer",))
service = StubNavigationService(payload)
user = cast(User, object())
session = AuthSession(
tokens=SessionTokens(access_token="token", refresh_token=None),
user=user,
role_slugs=("viewer",),
)
app.dependency_overrides[require_authenticated_user] = lambda: user
app.dependency_overrides[get_auth_session] = lambda: session
app.dependency_overrides[get_navigation_service] = lambda: cast(
NavigationService, service)
client = TestClient(app)
return client, service, session
def test_get_sidebar_navigation_returns_payload(
navigation_client: Tuple[TestClient, StubNavigationService, AuthSession]
) -> None:
client, service, session = navigation_client
response = client.get("/navigation/sidebar")
assert response.status_code == 200
data = response.json()
assert data["roles"] == ["viewer"]
assert data["groups"][0]["label"] == "Workspace"
assert data["groups"][0]["links"][0]["href"] == "/projects"
assert "generated_at" in data
datetime.fromisoformat(data["generated_at"])
assert service.received_call is not None
assert service.received_call["session"] is session
assert service.received_call["request"] is not None
assert service.received_call["include_disabled"] is False
def test_get_sidebar_navigation_requires_authentication() -> None:
app = FastAPI()
app.include_router(navigation_router)
link_dto = NavigationLinkDTO(
id=1,
label="Placeholder",
href="/placeholder",
match_prefix="/placeholder",
icon=None,
tooltip=None,
is_external=False,
children=[],
)
group_dto = NavigationGroupDTO(
id=1,
label="Group",
icon=None,
tooltip=None,
links=[link_dto],
)
payload = NavigationSidebarDTO(groups=[group_dto], roles=("anonymous",))
service = StubNavigationService(payload)
def _deny() -> User:
raise HTTPException(status_code=401, detail="Not authenticated")
app.dependency_overrides[get_navigation_service] = lambda: cast(
NavigationService, service)
app.dependency_overrides[get_auth_session] = AuthSession.anonymous
app.dependency_overrides[require_authenticated_user] = _deny
client = TestClient(app)
response = client.get("/navigation/sidebar")
assert response.status_code == 401
assert response.json()["detail"] == "Not authenticated"

View File

@@ -35,11 +35,16 @@ class FakeState:
] = field(default_factory=dict) ] = field(default_factory=dict)
financial_inputs: dict[Tuple[int, str], financial_inputs: dict[Tuple[int, str],
Dict[str, Any]] = field(default_factory=dict) Dict[str, Any]] = field(default_factory=dict)
navigation_groups: dict[str, Dict[str, Any]] = field(default_factory=dict)
navigation_links: dict[Tuple[int, str],
Dict[str, Any]] = field(default_factory=dict)
sequences: Dict[str, int] = field(default_factory=lambda: { sequences: Dict[str, int] = field(default_factory=lambda: {
"users": 0, "users": 0,
"projects": 0, "projects": 0,
"scenarios": 0, "scenarios": 0,
"financial_inputs": 0, "financial_inputs": 0,
"navigation_groups": 0,
"navigation_links": 0,
}) })
@@ -50,6 +55,9 @@ class FakeResult:
def fetchone(self) -> Any | None: def fetchone(self) -> Any | None:
return self._rows[0] if self._rows else None return self._rows[0] if self._rows else None
def fetchall(self) -> list[Any]:
return list(self._rows)
class FakeConnection: class FakeConnection:
def __init__(self, state: FakeState) -> None: def __init__(self, state: FakeState) -> None:
@@ -105,6 +113,13 @@ class FakeConnection:
rows = [SimpleNamespace(id=record["id"])] if record else [] rows = [SimpleNamespace(id=record["id"])] if record else []
return FakeResult(rows) return FakeResult(rows)
if lower_sql.startswith("select name from roles"):
rows = [
SimpleNamespace(name=record["name"])
for record in self.state.roles.values()
]
return FakeResult(rows)
if lower_sql.startswith("insert into user_roles"): if lower_sql.startswith("insert into user_roles"):
key = (int(params["user_id"]), int(params["role_id"])) key = (int(params["user_id"]), int(params["role_id"]))
self.state.user_roles.add(key) self.state.user_roles.add(key)
@@ -171,6 +186,67 @@ class FakeConnection:
rows = [SimpleNamespace(id=scenario["id"])] if scenario else [] rows = [SimpleNamespace(id=scenario["id"])] if scenario else []
return FakeResult(rows) return FakeResult(rows)
if lower_sql.startswith("insert into navigation_groups"):
slug = params["slug"]
record = self.state.navigation_groups.get(slug)
if record is None:
self.state.sequences["navigation_groups"] += 1
record = {
"id": self.state.sequences["navigation_groups"],
"slug": slug,
}
record.update(
label=params["label"],
sort_order=int(params.get("sort_order", 0)),
icon=params.get("icon"),
tooltip=params.get("tooltip"),
is_enabled=bool(params.get("is_enabled", True)),
)
self.state.navigation_groups[slug] = record
return FakeResult([])
if lower_sql.startswith("select id from navigation_groups where slug"):
slug = params["slug"]
record = self.state.navigation_groups.get(slug)
rows = [SimpleNamespace(id=record["id"])] if record else []
return FakeResult(rows)
if lower_sql.startswith("insert into navigation_links"):
group_id = int(params["group_id"])
slug = params["slug"]
key = (group_id, slug)
record = self.state.navigation_links.get(key)
if record is None:
self.state.sequences["navigation_links"] += 1
record = {
"id": self.state.sequences["navigation_links"],
"group_id": group_id,
"slug": slug,
}
record.update(
parent_link_id=(int(params["parent_link_id"]) if params.get(
"parent_link_id") is not None else None),
label=params["label"],
route_name=params.get("route_name"),
href_override=params.get("href_override"),
match_prefix=params.get("match_prefix"),
sort_order=int(params.get("sort_order", 0)),
icon=params.get("icon"),
tooltip=params.get("tooltip"),
required_roles=list(params.get("required_roles") or []),
is_enabled=bool(params.get("is_enabled", True)),
is_external=bool(params.get("is_external", False)),
)
self.state.navigation_links[key] = record
return FakeResult([])
if lower_sql.startswith("select id from navigation_links where group_id"):
group_id = int(params["group_id"])
slug = params["slug"]
record = self.state.navigation_links.get((group_id, slug))
rows = [SimpleNamespace(id=record["id"])] if record else []
return FakeResult(rows)
if lower_sql.startswith("insert into financial_inputs"): if lower_sql.startswith("insert into financial_inputs"):
key = (int(params["scenario_id"]), params["name"]) key = (int(params["scenario_id"]), params["name"])
record = self.state.financial_inputs.get(key) record = self.state.financial_inputs.get(key)

View File

@@ -0,0 +1,188 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, cast
from fastapi import Request
from services.navigation import NavigationService
from services.repositories import NavigationRepository
from services.session import AuthSession, SessionTokens
from models import User
@dataclass
class StubNavigationLink:
id: int
slug: str
label: str
route_name: str | None = None
href_override: str | None = None
match_prefix: str | None = None
sort_order: int = 0
icon: str | None = None
tooltip: str | None = None
required_roles: List[str] = field(default_factory=list)
is_enabled: bool = True
is_external: bool = False
children: List["StubNavigationLink"] = field(default_factory=list)
@dataclass
class StubNavigationGroup:
id: int
slug: str
label: str
sort_order: int = 0
icon: str | None = None
tooltip: str | None = None
is_enabled: bool = True
links: List[StubNavigationLink] = field(default_factory=list)
class StubNavigationRepository(NavigationRepository):
def __init__(self, groups: Iterable[StubNavigationGroup]) -> None:
super().__init__(session=None) # type: ignore[arg-type]
self._groups = list(groups)
def list_groups_with_links(self, *, include_disabled: bool = False):
if include_disabled:
return list(self._groups)
return [group for group in self._groups if group.is_enabled]
class StubRequest:
def __init__(
self,
*,
path_params: Dict[str, str] | None = None,
query_params: Dict[str, str] | None = None,
) -> None:
self.path_params = path_params or {}
self.query_params = query_params or {}
self._url_for_calls: List[tuple[str, Dict[str, str]]] = []
def url_for(self, name: str, **params: str) -> str:
self._url_for_calls.append((name, params))
if params:
suffix = "_".join(f"{key}-{value}" for key,
value in sorted(params.items()))
return f"/{name}/{suffix}"
return f"/{name}"
@property
def url_for_calls(self) -> List[tuple[str, Dict[str, str]]]:
return list(self._url_for_calls)
def _session(*, roles: Iterable[str], authenticated: bool = True) -> AuthSession:
tokens = SessionTokens(
access_token="token" if authenticated else None, refresh_token=None)
user = cast(User, object()) if authenticated else None
session = AuthSession(tokens=tokens, user=user, role_slugs=tuple(roles))
return session
def test_build_sidebar_filters_links_by_role():
visible_link = StubNavigationLink(
id=1,
slug="projects",
label="Projects",
href_override="/projects",
required_roles=["viewer"],
)
hidden_link = StubNavigationLink(
id=2,
slug="admin",
label="Admin",
href_override="/admin",
required_roles=["admin"],
)
group = StubNavigationGroup(id=1, slug="workspace", label="Workspace", links=[
visible_link, hidden_link])
service = NavigationService(StubNavigationRepository([group]))
dto = service.build_sidebar(
session=_session(roles=["viewer"]),
request=cast(Request, StubRequest()),
)
assert len(dto.groups) == 1
assert [link.label for link in dto.groups[0].links] == ["Projects"]
assert dto.roles == ("viewer",)
def test_build_sidebar_appends_anonymous_role_for_guests():
link = StubNavigationLink(
id=1, slug="help", label="Help", href_override="/help")
group = StubNavigationGroup(
id=1, slug="account", label="Account", links=[link])
service = NavigationService(StubNavigationRepository([group]))
dto = service.build_sidebar(session=AuthSession.anonymous(), request=None)
assert dto.roles[-1] == "anonymous"
assert dto.groups[0].links[0].href.startswith("/")
def test_build_sidebar_resolves_profitability_link_with_context():
link = StubNavigationLink(
id=1,
slug="profitability",
label="Profitability",
route_name="calculations.profitability_form",
)
group = StubNavigationGroup(
id=99, slug="insights", label="Insights", links=[link])
request = StubRequest(path_params={"project_id": "7", "scenario_id": "42"})
service = NavigationService(StubNavigationRepository([group]))
dto = service.build_sidebar(
session=_session(roles=["viewer"]),
request=cast(Request, request),
)
assert dto.groups[0].links[0].href == "/calculations.profitability_form/project_id-7_scenario_id-42"
assert request.url_for_calls[0][0] == "calculations.profitability_form"
assert request.url_for_calls[0][1] == {
"project_id": "7", "scenario_id": "42"}
assert dto.groups[0].links[0].match_prefix == dto.groups[0].links[0].href
def test_build_sidebar_skips_disabled_links_unless_included():
enabled_link = StubNavigationLink(
id=1,
slug="projects",
label="Projects",
href_override="/projects",
)
disabled_link = StubNavigationLink(
id=2,
slug="reports",
label="Reports",
href_override="/reports",
is_enabled=False,
)
group = StubNavigationGroup(
id=5,
slug="workspace",
label="Workspace",
links=[enabled_link, disabled_link],
)
service = NavigationService(StubNavigationRepository([group]))
default_sidebar = service.build_sidebar(
session=_session(roles=["viewer"]),
request=cast(Request, StubRequest()),
)
assert [link.label for link in default_sidebar.groups[0].links] == ["Projects"]
full_sidebar = service.build_sidebar(
session=_session(roles=["viewer"]),
request=cast(Request, StubRequest()),
include_disabled=True,
)
assert [link.label for link in full_sidebar.groups[0].links] == [
"Projects", "Reports"]

View File

@@ -90,9 +90,9 @@ class TestAuthenticationRequirements:
def test_ui_project_list_requires_login(self, client, auth_session_context): def test_ui_project_list_requires_login(self, client, auth_session_context):
with auth_session_context(None): with auth_session_context(None):
response = client.get("/projects/ui") response = client.get("/projects/ui", follow_redirects=False)
assert response.status_code == status.HTTP_303_SEE_OTHER
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.headers["location"].endswith("/login")
class TestRoleRestrictions: class TestRoleRestrictions:
@@ -194,7 +194,7 @@ class TestRoleRestrictions:
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json()[ assert response.json()[
"detail"] == "Insufficient role permissions for this action." "detail"] == "Insufficient permissions for this action."
def test_ui_project_edit_accessible_to_manager( def test_ui_project_edit_accessible_to_manager(
self, self,