Refactor test cases for improved readability and consistency
- Updated test functions in various test files to enhance code clarity by formatting long lines and improving indentation. - Adjusted assertions to use multi-line formatting for better readability. - Added new test cases for theme settings API to ensure proper functionality. - Ensured consistent use of line breaks and spacing across test files for uniformity.
This commit is contained in:
@@ -27,7 +27,8 @@ engine = create_engine(
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine)
|
||||
autocommit=False, autoflush=False, bind=engine
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -37,19 +38,24 @@ def setup_database() -> Generator[None, None, None]:
|
||||
application_setting,
|
||||
capex,
|
||||
consumption,
|
||||
currency,
|
||||
distribution,
|
||||
equipment,
|
||||
maintenance,
|
||||
opex,
|
||||
parameters,
|
||||
production_output,
|
||||
role,
|
||||
scenario,
|
||||
simulation_result,
|
||||
theme_setting,
|
||||
user,
|
||||
) # noqa: F401 - imported for side effects
|
||||
|
||||
_ = (
|
||||
capex,
|
||||
consumption,
|
||||
currency,
|
||||
distribution,
|
||||
equipment,
|
||||
maintenance,
|
||||
@@ -57,8 +63,11 @@ def setup_database() -> Generator[None, None, None]:
|
||||
opex,
|
||||
parameters,
|
||||
production_output,
|
||||
role,
|
||||
scenario,
|
||||
simulation_result,
|
||||
theme_setting,
|
||||
user,
|
||||
)
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@@ -86,22 +95,23 @@ def api_client(db_session: Session) -> Generator[TestClient, None, None]:
|
||||
finally:
|
||||
pass
|
||||
|
||||
from routes import dependencies as route_dependencies
|
||||
from routes.dependencies import get_db
|
||||
|
||||
app.dependency_overrides[route_dependencies.get_db] = override_get_db
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.pop(route_dependencies.get_db, None)
|
||||
app.dependency_overrides.pop(get_db, None)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]:
|
||||
def seeded_ui_data(
|
||||
db_session: Session,
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
"""Populate a scenario with representative related records for UI tests."""
|
||||
scenario_name = f"Scenario Alpha {uuid4()}"
|
||||
scenario = Scenario(name=scenario_name,
|
||||
description="Seeded UI scenario")
|
||||
scenario = Scenario(name=scenario_name, description="Seeded UI scenario")
|
||||
db_session.add(scenario)
|
||||
db_session.flush()
|
||||
|
||||
@@ -161,7 +171,9 @@ def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]
|
||||
iteration=index,
|
||||
result=value,
|
||||
)
|
||||
for index, value in enumerate((950_000.0, 975_000.0, 990_000.0), start=1)
|
||||
for index, value in enumerate(
|
||||
(950_000.0, 975_000.0, 990_000.0), start=1
|
||||
)
|
||||
]
|
||||
|
||||
db_session.add(maintenance)
|
||||
@@ -196,11 +208,15 @@ def seeded_ui_data(db_session: Session) -> Generator[Dict[str, Any], None, None]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def invalid_request_payloads(db_session: Session) -> Generator[Dict[str, Any], None, None]:
|
||||
def invalid_request_payloads(
|
||||
db_session: Session,
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
"""Provide reusable invalid request bodies for exercising validation branches."""
|
||||
duplicate_name = f"Scenario Duplicate {uuid4()}"
|
||||
existing = Scenario(name=duplicate_name,
|
||||
description="Existing scenario for duplicate checks")
|
||||
existing = Scenario(
|
||||
name=duplicate_name,
|
||||
description="Existing scenario for duplicate checks",
|
||||
)
|
||||
db_session.add(existing)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
231
tests/unit/test_auth.py
Normal file
231
tests/unit/test_auth.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from services.security import get_password_hash, verify_password
|
||||
|
||||
|
||||
def test_password_hashing():
|
||||
password = "testpassword"
|
||||
hashed_password = get_password_hash(password)
|
||||
assert verify_password(password, hashed_password)
|
||||
assert not verify_password("wrongpassword", hashed_password)
|
||||
|
||||
|
||||
def test_register_user(api_client):
|
||||
response = api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": "testpassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["username"] == "testuser"
|
||||
assert data["email"] == "test@example.com"
|
||||
assert "id" in data
|
||||
assert "role_id" in data
|
||||
|
||||
response = api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "testuser",
|
||||
"email": "another@example.com",
|
||||
"password": "testpassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Username already registered"}
|
||||
|
||||
response = api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "anotheruser",
|
||||
"email": "test@example.com",
|
||||
"password": "testpassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Email already registered"}
|
||||
|
||||
|
||||
def test_login_user(api_client):
|
||||
# Register a user first
|
||||
api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "loginuser",
|
||||
"email": "login@example.com",
|
||||
"password": "loginpassword",
|
||||
},
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "loginuser", "password": "loginpassword"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "loginuser", "password": "wrongpassword"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"detail": "Incorrect username or password"}
|
||||
|
||||
response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "nonexistent", "password": "password"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"detail": "Incorrect username or password"}
|
||||
|
||||
|
||||
def test_read_users_me(api_client):
|
||||
# Register a user first
|
||||
api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "profileuser",
|
||||
"email": "profile@example.com",
|
||||
"password": "profilepassword",
|
||||
},
|
||||
)
|
||||
# Login to get a token
|
||||
login_response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "profileuser", "password": "profilepassword"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
response = api_client.get(
|
||||
"/users/me", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["username"] == "profileuser"
|
||||
assert data["email"] == "profile@example.com"
|
||||
|
||||
|
||||
def test_update_users_me(api_client):
|
||||
# Register a user first
|
||||
api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "updateuser",
|
||||
"email": "update@example.com",
|
||||
"password": "updatepassword",
|
||||
},
|
||||
)
|
||||
# Login to get a token
|
||||
login_response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "updateuser", "password": "updatepassword"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
response = api_client.put(
|
||||
"/users/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={
|
||||
"username": "updateduser",
|
||||
"email": "updated@example.com",
|
||||
"password": "newpassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["username"] == "updateduser"
|
||||
assert data["email"] == "updated@example.com"
|
||||
|
||||
# Verify password change
|
||||
response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "updateduser", "password": "newpassword"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
token = response.json()["access_token"]
|
||||
|
||||
# Test username already taken
|
||||
api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "anotherupdateuser",
|
||||
"email": "anotherupdate@example.com",
|
||||
"password": "password",
|
||||
},
|
||||
)
|
||||
response = api_client.put(
|
||||
"/users/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={
|
||||
"username": "anotherupdateuser",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Username already taken"}
|
||||
|
||||
# Test email already registered
|
||||
api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "yetanotheruser",
|
||||
"email": "yetanother@example.com",
|
||||
"password": "password",
|
||||
},
|
||||
)
|
||||
response = api_client.put(
|
||||
"/users/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={
|
||||
"email": "yetanother@example.com",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Email already registered"}
|
||||
|
||||
|
||||
def test_forgot_password(api_client):
|
||||
response = api_client.post(
|
||||
"/users/forgot-password", json={"email": "nonexistent@example.com"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"message": "Password reset email sent (not really)"}
|
||||
|
||||
|
||||
def test_reset_password(api_client):
|
||||
# Register a user first
|
||||
api_client.post(
|
||||
"/users/register",
|
||||
json={
|
||||
"username": "resetuser",
|
||||
"email": "reset@example.com",
|
||||
"password": "oldpassword",
|
||||
},
|
||||
)
|
||||
|
||||
response = api_client.post(
|
||||
"/users/reset-password",
|
||||
json={
|
||||
"token": "resetuser", # Use username as token for test
|
||||
"new_password": "newpassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"message": "Password has been reset successfully"}
|
||||
|
||||
# Verify password change
|
||||
response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "resetuser", "password": "newpassword"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = api_client.post(
|
||||
"/users/login",
|
||||
json={"username": "resetuser", "password": "oldpassword"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -57,8 +57,11 @@ def test_list_consumption_returns_created_items(client: TestClient) -> None:
|
||||
|
||||
list_response = client.get("/api/consumption/")
|
||||
assert list_response.status_code == 200
|
||||
items = [item for item in list_response.json(
|
||||
) if item["scenario_id"] == scenario_id]
|
||||
items = [
|
||||
item
|
||||
for item in list_response.json()
|
||||
if item["scenario_id"] == scenario_id
|
||||
]
|
||||
assert {item["amount"] for item in items} == set(values)
|
||||
|
||||
|
||||
|
||||
@@ -47,8 +47,9 @@ def test_create_and_list_capex_and_opex():
|
||||
resp3 = client.get("/api/costs/capex")
|
||||
assert resp3.status_code == 200
|
||||
data = resp3.json()
|
||||
assert any(item["amount"] == 1000.0 and item["scenario_id"]
|
||||
== sid for item in data)
|
||||
assert any(
|
||||
item["amount"] == 1000.0 and item["scenario_id"] == sid for item in data
|
||||
)
|
||||
|
||||
opex_payload = {
|
||||
"scenario_id": sid,
|
||||
@@ -66,8 +67,10 @@ def test_create_and_list_capex_and_opex():
|
||||
resp5 = client.get("/api/costs/opex")
|
||||
assert resp5.status_code == 200
|
||||
data_o = resp5.json()
|
||||
assert any(item["amount"] == 500.0 and item["scenario_id"]
|
||||
== sid for item in data_o)
|
||||
assert any(
|
||||
item["amount"] == 500.0 and item["scenario_id"] == sid
|
||||
for item in data_o
|
||||
)
|
||||
|
||||
|
||||
def test_multiple_capex_entries():
|
||||
@@ -88,8 +91,9 @@ def test_multiple_capex_entries():
|
||||
resp = client.get("/api/costs/capex")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
retrieved_amounts = [item["amount"]
|
||||
for item in data if item["scenario_id"] == sid]
|
||||
retrieved_amounts = [
|
||||
item["amount"] for item in data if item["scenario_id"] == sid
|
||||
]
|
||||
for amount in amounts:
|
||||
assert amount in retrieved_amounts
|
||||
|
||||
@@ -112,7 +116,8 @@ def test_multiple_opex_entries():
|
||||
resp = client.get("/api/costs/opex")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
retrieved_amounts = [item["amount"]
|
||||
for item in data if item["scenario_id"] == sid]
|
||||
retrieved_amounts = [
|
||||
item["amount"] for item in data if item["scenario_id"] == sid
|
||||
]
|
||||
for amount in amounts:
|
||||
assert amount in retrieved_amounts
|
||||
|
||||
@@ -14,7 +14,13 @@ def _cleanup_currencies(db_session):
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _assert_currency(payload: Dict[str, object], code: str, name: str, symbol: str | None, is_active: bool) -> None:
|
||||
def _assert_currency(
|
||||
payload: Dict[str, object],
|
||||
code: str,
|
||||
name: str,
|
||||
symbol: str | None,
|
||||
is_active: bool,
|
||||
) -> None:
|
||||
assert payload["code"] == code
|
||||
assert payload["name"] == name
|
||||
assert payload["is_active"] is is_active
|
||||
@@ -47,13 +53,21 @@ def test_create_currency_success(api_client, db_session):
|
||||
def test_create_currency_conflict(api_client, db_session):
|
||||
api_client.post(
|
||||
"/api/currencies/",
|
||||
json={"code": "CAD", "name": "Canadian Dollar",
|
||||
"symbol": "$", "is_active": True},
|
||||
json={
|
||||
"code": "CAD",
|
||||
"name": "Canadian Dollar",
|
||||
"symbol": "$",
|
||||
"is_active": True,
|
||||
},
|
||||
)
|
||||
duplicate = api_client.post(
|
||||
"/api/currencies/",
|
||||
json={"code": "CAD", "name": "Canadian Dollar",
|
||||
"symbol": "$", "is_active": True},
|
||||
json={
|
||||
"code": "CAD",
|
||||
"name": "Canadian Dollar",
|
||||
"symbol": "$",
|
||||
"is_active": True,
|
||||
},
|
||||
)
|
||||
assert duplicate.status_code == 409
|
||||
|
||||
@@ -61,8 +75,12 @@ def test_create_currency_conflict(api_client, db_session):
|
||||
def test_update_currency_fields(api_client, db_session):
|
||||
api_client.post(
|
||||
"/api/currencies/",
|
||||
json={"code": "GBP", "name": "British Pound",
|
||||
"symbol": "£", "is_active": True},
|
||||
json={
|
||||
"code": "GBP",
|
||||
"name": "British Pound",
|
||||
"symbol": "£",
|
||||
"is_active": True,
|
||||
},
|
||||
)
|
||||
|
||||
response = api_client.put(
|
||||
@@ -77,8 +95,12 @@ def test_update_currency_fields(api_client, db_session):
|
||||
def test_toggle_currency_activation(api_client, db_session):
|
||||
api_client.post(
|
||||
"/api/currencies/",
|
||||
json={"code": "AUD", "name": "Australian Dollar",
|
||||
"symbol": "A$", "is_active": True},
|
||||
json={
|
||||
"code": "AUD",
|
||||
"name": "Australian Dollar",
|
||||
"symbol": "A$",
|
||||
"is_active": True,
|
||||
},
|
||||
)
|
||||
|
||||
response = api_client.patch(
|
||||
@@ -97,5 +119,7 @@ def test_default_currency_cannot_be_deactivated(api_client, db_session):
|
||||
json={"is_active": False},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.json()[
|
||||
"detail"] == "The default currency cannot be deactivated."
|
||||
assert (
|
||||
response.json()["detail"]
|
||||
== "The default currency cannot be deactivated."
|
||||
)
|
||||
|
||||
@@ -41,9 +41,10 @@ def test_create_capex_with_currency_code_and_list(api_client, seeded_currency):
|
||||
resp = api_client.post("/api/costs/capex", json=payload)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data.get("currency_code") == seeded_currency.code or data.get(
|
||||
"currency", {}
|
||||
).get("code") == seeded_currency.code
|
||||
assert (
|
||||
data.get("currency_code") == seeded_currency.code
|
||||
or data.get("currency", {}).get("code") == seeded_currency.code
|
||||
)
|
||||
|
||||
|
||||
def test_create_opex_with_currency_id(api_client, seeded_currency):
|
||||
|
||||
@@ -30,7 +30,9 @@ def _create_scenario_and_equipment(client: TestClient):
|
||||
return scenario_id, equipment_id
|
||||
|
||||
|
||||
def _create_maintenance_payload(equipment_id: int, scenario_id: int, description: str):
|
||||
def _create_maintenance_payload(
|
||||
equipment_id: int, scenario_id: int, description: str
|
||||
):
|
||||
return {
|
||||
"equipment_id": equipment_id,
|
||||
"scenario_id": scenario_id,
|
||||
@@ -43,7 +45,8 @@ def _create_maintenance_payload(equipment_id: int, scenario_id: int, description
|
||||
def test_create_and_list_maintenance(client: TestClient):
|
||||
scenario_id, equipment_id = _create_scenario_and_equipment(client)
|
||||
payload = _create_maintenance_payload(
|
||||
equipment_id, scenario_id, "Create maintenance")
|
||||
equipment_id, scenario_id, "Create maintenance"
|
||||
)
|
||||
|
||||
response = client.post("/api/maintenance/", json=payload)
|
||||
assert response.status_code == 201
|
||||
@@ -95,7 +98,8 @@ def test_update_maintenance(client: TestClient):
|
||||
}
|
||||
|
||||
response = client.put(
|
||||
f"/api/maintenance/{maintenance_id}", json=update_payload)
|
||||
f"/api/maintenance/{maintenance_id}", json=update_payload
|
||||
)
|
||||
assert response.status_code == 200
|
||||
updated = response.json()
|
||||
assert updated["maintenance_date"] == "2025-11-01"
|
||||
@@ -108,7 +112,8 @@ def test_delete_maintenance(client: TestClient):
|
||||
create_response = client.post(
|
||||
"/api/maintenance/",
|
||||
json=_create_maintenance_payload(
|
||||
equipment_id, scenario_id, "Delete maintenance"),
|
||||
equipment_id, scenario_id, "Delete maintenance"
|
||||
),
|
||||
)
|
||||
assert create_response.status_code == 201
|
||||
maintenance_id = create_response.json()["id"]
|
||||
|
||||
@@ -67,7 +67,10 @@ def test_create_and_list_parameter():
|
||||
|
||||
def test_create_parameter_for_missing_scenario():
|
||||
payload: Dict[str, Any] = {
|
||||
"scenario_id": 0, "name": "invalid", "value": 1.0}
|
||||
"scenario_id": 0,
|
||||
"name": "invalid",
|
||||
"value": 1.0,
|
||||
}
|
||||
response = client.post("/api/parameters/", json=payload)
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Scenario not found"
|
||||
|
||||
@@ -42,7 +42,11 @@ def test_list_production_filters_by_scenario(client: TestClient) -> None:
|
||||
target_scenario = _create_scenario(client)
|
||||
other_scenario = _create_scenario(client)
|
||||
|
||||
for scenario_id, amount in [(target_scenario, 100.0), (target_scenario, 150.0), (other_scenario, 200.0)]:
|
||||
for scenario_id, amount in [
|
||||
(target_scenario, 100.0),
|
||||
(target_scenario, 150.0),
|
||||
(other_scenario, 200.0),
|
||||
]:
|
||||
response = client.post(
|
||||
"/api/production/",
|
||||
json={
|
||||
@@ -57,8 +61,11 @@ def test_list_production_filters_by_scenario(client: TestClient) -> None:
|
||||
|
||||
list_response = client.get("/api/production/")
|
||||
assert list_response.status_code == 200
|
||||
items = [item for item in list_response.json()
|
||||
if item["scenario_id"] == target_scenario]
|
||||
items = [
|
||||
item
|
||||
for item in list_response.json()
|
||||
if item["scenario_id"] == target_scenario
|
||||
]
|
||||
assert {item["amount"] for item in items} == {100.0, 150.0}
|
||||
|
||||
|
||||
|
||||
@@ -50,9 +50,11 @@ def test_generate_report_with_values():
|
||||
|
||||
|
||||
def test_generate_report_single_value():
|
||||
report = generate_report([
|
||||
{"iteration": 1, "result": 42.0},
|
||||
])
|
||||
report = generate_report(
|
||||
[
|
||||
{"iteration": 1, "result": 42.0},
|
||||
]
|
||||
)
|
||||
assert report["count"] == 1
|
||||
assert report["std_dev"] == 0.0
|
||||
assert report["variance"] == 0.0
|
||||
@@ -105,8 +107,10 @@ def test_reporting_endpoint_success(client: TestClient):
|
||||
validation_error_cases: List[tuple[List[Any], str]] = [
|
||||
(["not-a-dict"], "Entry at index 0 must be an object"),
|
||||
([{"iteration": 1}], "Entry at index 0 must include numeric 'result'"),
|
||||
([{"iteration": 1, "result": "bad"}],
|
||||
"Entry at index 0 must include numeric 'result'"),
|
||||
(
|
||||
[{"iteration": 1, "result": "bad"}],
|
||||
"Entry at index 0 must include numeric 'result'",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_parameter_create_missing_scenario_returns_404(
|
||||
|
||||
@pytest.mark.usefixtures("invalid_request_payloads")
|
||||
def test_parameter_create_invalid_distribution_is_422(
|
||||
api_client: TestClient
|
||||
api_client: TestClient,
|
||||
) -> None:
|
||||
response = api_client.post(
|
||||
"/api/parameters/",
|
||||
@@ -90,6 +90,5 @@ def test_maintenance_negative_cost_rejected_by_schema(
|
||||
payload = invalid_request_payloads["maintenance_negative_cost"]
|
||||
response = api_client.post("/api/maintenance/", json=payload)
|
||||
assert response.status_code == 422
|
||||
error_locations = [tuple(item["loc"])
|
||||
for item in response.json()["detail"]]
|
||||
error_locations = [tuple(item["loc"]) for item in response.json()["detail"]]
|
||||
assert ("body", "cost") in error_locations
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_update_css_settings_persists_changes(
|
||||
|
||||
@pytest.mark.usefixtures("db_session")
|
||||
def test_update_css_settings_invalid_value_returns_422(
|
||||
api_client: TestClient
|
||||
api_client: TestClient,
|
||||
) -> None:
|
||||
response = api_client.put(
|
||||
"/api/settings/css",
|
||||
|
||||
@@ -20,8 +20,14 @@ def fixture_clean_env(monkeypatch: pytest.MonkeyPatch) -> Dict[str, str]:
|
||||
|
||||
|
||||
def test_css_key_to_env_var_formatting():
|
||||
assert settings_service.css_key_to_env_var("--color-background") == "CALMINER_THEME_COLOR_BACKGROUND"
|
||||
assert settings_service.css_key_to_env_var("--color-primary-stronger") == "CALMINER_THEME_COLOR_PRIMARY_STRONGER"
|
||||
assert (
|
||||
settings_service.css_key_to_env_var("--color-background")
|
||||
== "CALMINER_THEME_COLOR_BACKGROUND"
|
||||
)
|
||||
assert (
|
||||
settings_service.css_key_to_env_var("--color-primary-stronger")
|
||||
== "CALMINER_THEME_COLOR_PRIMARY_STRONGER"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -33,7 +39,9 @@ def test_css_key_to_env_var_formatting():
|
||||
("--color-text-secondary", "hsla(210, 40%, 40%, 1)"),
|
||||
],
|
||||
)
|
||||
def test_read_css_color_env_overrides_valid_values(clean_env, env_key, env_value):
|
||||
def test_read_css_color_env_overrides_valid_values(
|
||||
clean_env, env_key, env_value
|
||||
):
|
||||
env_var = settings_service.css_key_to_env_var(env_key)
|
||||
clean_env[env_var] = env_value
|
||||
|
||||
@@ -50,7 +58,9 @@ def test_read_css_color_env_overrides_valid_values(clean_env, env_key, env_value
|
||||
"rgb(1,2)", # malformed rgb
|
||||
],
|
||||
)
|
||||
def test_read_css_color_env_overrides_invalid_values_raise(clean_env, invalid_value):
|
||||
def test_read_css_color_env_overrides_invalid_values_raise(
|
||||
clean_env, invalid_value
|
||||
):
|
||||
env_var = settings_service.css_key_to_env_var("--color-background")
|
||||
clean_env[env_var] = invalid_value
|
||||
|
||||
@@ -64,7 +74,9 @@ def test_read_css_color_env_overrides_ignores_missing(clean_env):
|
||||
|
||||
|
||||
def test_list_css_env_override_rows_returns_structured_data(clean_env):
|
||||
clean_env[settings_service.css_key_to_env_var("--color-primary")] = "#123456"
|
||||
clean_env[settings_service.css_key_to_env_var("--color-primary")] = (
|
||||
"#123456"
|
||||
)
|
||||
rows = settings_service.list_css_env_override_rows(clean_env)
|
||||
assert rows == [
|
||||
{
|
||||
|
||||
@@ -31,10 +31,13 @@ def setup_instance(mock_config: DatabaseConfig) -> DatabaseSetup:
|
||||
return DatabaseSetup(mock_config, dry_run=True)
|
||||
|
||||
|
||||
def test_seed_baseline_data_dry_run_skips_verification(setup_instance: DatabaseSetup) -> None:
|
||||
with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object(
|
||||
setup_instance, "_verify_seeded_data"
|
||||
) as verify_mock:
|
||||
def test_seed_baseline_data_dry_run_skips_verification(
|
||||
setup_instance: DatabaseSetup,
|
||||
) -> None:
|
||||
with (
|
||||
mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
|
||||
mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock,
|
||||
):
|
||||
setup_instance.seed_baseline_data(dry_run=True)
|
||||
|
||||
seed_run.assert_called_once()
|
||||
@@ -47,13 +50,16 @@ def test_seed_baseline_data_dry_run_skips_verification(setup_instance: DatabaseS
|
||||
verify_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_seed_baseline_data_invokes_verification(setup_instance: DatabaseSetup) -> None:
|
||||
def test_seed_baseline_data_invokes_verification(
|
||||
setup_instance: DatabaseSetup,
|
||||
) -> None:
|
||||
expected_currencies = {code for code, *_ in seed_data.CURRENCY_SEEDS}
|
||||
expected_units = {code for code, *_ in seed_data.MEASUREMENT_UNIT_SEEDS}
|
||||
|
||||
with mock.patch("scripts.seed_data.run_with_namespace") as seed_run, mock.patch.object(
|
||||
setup_instance, "_verify_seeded_data"
|
||||
) as verify_mock:
|
||||
with (
|
||||
mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
|
||||
mock.patch.object(setup_instance, "_verify_seeded_data") as verify_mock,
|
||||
):
|
||||
setup_instance.seed_baseline_data(dry_run=False)
|
||||
|
||||
seed_run.assert_called_once()
|
||||
@@ -67,7 +73,9 @@ def test_seed_baseline_data_invokes_verification(setup_instance: DatabaseSetup)
|
||||
)
|
||||
|
||||
|
||||
def test_run_migrations_applies_baseline_when_missing(mock_config: DatabaseConfig, tmp_path) -> None:
|
||||
def test_run_migrations_applies_baseline_when_missing(
|
||||
mock_config: DatabaseConfig, tmp_path
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
|
||||
baseline = tmp_path / "000_base.sql"
|
||||
@@ -88,15 +96,24 @@ def test_run_migrations_applies_baseline_when_missing(mock_config: DatabaseConfi
|
||||
cursor_context.__enter__.return_value = cursor_mock
|
||||
connection_mock.cursor.return_value = cursor_context
|
||||
|
||||
with mock.patch.object(
|
||||
setup_instance, "_application_connection", return_value=connection_mock
|
||||
), mock.patch.object(
|
||||
setup_instance, "_migrations_table_exists", return_value=True
|
||||
), mock.patch.object(
|
||||
setup_instance, "_fetch_applied_migrations", return_value=set()
|
||||
), mock.patch.object(
|
||||
setup_instance, "_apply_migration_file", side_effect=capture_migration
|
||||
) as apply_mock:
|
||||
with (
|
||||
mock.patch.object(
|
||||
setup_instance,
|
||||
"_application_connection",
|
||||
return_value=connection_mock,
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance, "_migrations_table_exists", return_value=True
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance, "_fetch_applied_migrations", return_value=set()
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance,
|
||||
"_apply_migration_file",
|
||||
side_effect=capture_migration,
|
||||
) as apply_mock,
|
||||
):
|
||||
setup_instance.run_migrations(tmp_path)
|
||||
|
||||
assert apply_mock.call_count == 1
|
||||
@@ -121,17 +138,24 @@ def test_run_migrations_noop_when_all_files_already_applied(
|
||||
|
||||
connection_mock, cursor_mock = _connection_with_cursor()
|
||||
|
||||
with mock.patch.object(
|
||||
setup_instance, "_application_connection", return_value=connection_mock
|
||||
), mock.patch.object(
|
||||
setup_instance, "_migrations_table_exists", return_value=True
|
||||
), mock.patch.object(
|
||||
setup_instance,
|
||||
"_fetch_applied_migrations",
|
||||
return_value={"000_base.sql", "20251022_add_other.sql"},
|
||||
), mock.patch.object(
|
||||
setup_instance, "_apply_migration_file"
|
||||
) as apply_mock:
|
||||
with (
|
||||
mock.patch.object(
|
||||
setup_instance,
|
||||
"_application_connection",
|
||||
return_value=connection_mock,
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance, "_migrations_table_exists", return_value=True
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance,
|
||||
"_fetch_applied_migrations",
|
||||
return_value={"000_base.sql", "20251022_add_other.sql"},
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance, "_apply_migration_file"
|
||||
) as apply_mock,
|
||||
):
|
||||
setup_instance.run_migrations(tmp_path)
|
||||
|
||||
apply_mock.assert_not_called()
|
||||
@@ -148,12 +172,16 @@ def _connection_with_cursor() -> tuple[mock.MagicMock, mock.MagicMock]:
|
||||
return connection_mock, cursor_mock
|
||||
|
||||
|
||||
def test_verify_seeded_data_raises_when_currency_missing(mock_config: DatabaseConfig) -> None:
|
||||
def test_verify_seeded_data_raises_when_currency_missing(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
connection_mock, cursor_mock = _connection_with_cursor()
|
||||
cursor_mock.fetchall.return_value = [("USD", True)]
|
||||
|
||||
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
|
||||
with mock.patch.object(
|
||||
setup_instance, "_application_connection", return_value=connection_mock
|
||||
):
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
setup_instance._verify_seeded_data(
|
||||
expected_currency_codes={"USD", "EUR"},
|
||||
@@ -163,12 +191,16 @@ def test_verify_seeded_data_raises_when_currency_missing(mock_config: DatabaseCo
|
||||
assert "EUR" in str(exc.value)
|
||||
|
||||
|
||||
def test_verify_seeded_data_raises_when_default_currency_inactive(mock_config: DatabaseConfig) -> None:
|
||||
def test_verify_seeded_data_raises_when_default_currency_inactive(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
connection_mock, cursor_mock = _connection_with_cursor()
|
||||
cursor_mock.fetchall.return_value = [("USD", False)]
|
||||
|
||||
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
|
||||
with mock.patch.object(
|
||||
setup_instance, "_application_connection", return_value=connection_mock
|
||||
):
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
setup_instance._verify_seeded_data(
|
||||
expected_currency_codes={"USD"},
|
||||
@@ -178,12 +210,16 @@ def test_verify_seeded_data_raises_when_default_currency_inactive(mock_config: D
|
||||
assert "inactive" in str(exc.value)
|
||||
|
||||
|
||||
def test_verify_seeded_data_raises_when_units_missing(mock_config: DatabaseConfig) -> None:
|
||||
def test_verify_seeded_data_raises_when_units_missing(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
connection_mock, cursor_mock = _connection_with_cursor()
|
||||
cursor_mock.fetchall.return_value = [("tonnes", True)]
|
||||
|
||||
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
|
||||
with mock.patch.object(
|
||||
setup_instance, "_application_connection", return_value=connection_mock
|
||||
):
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
setup_instance._verify_seeded_data(
|
||||
expected_currency_codes=set(),
|
||||
@@ -193,12 +229,18 @@ def test_verify_seeded_data_raises_when_units_missing(mock_config: DatabaseConfi
|
||||
assert "liters" in str(exc.value)
|
||||
|
||||
|
||||
def test_verify_seeded_data_raises_when_measurement_table_missing(mock_config: DatabaseConfig) -> None:
|
||||
def test_verify_seeded_data_raises_when_measurement_table_missing(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
connection_mock, cursor_mock = _connection_with_cursor()
|
||||
cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable("relation does not exist")
|
||||
cursor_mock.execute.side_effect = psycopg_errors.UndefinedTable(
|
||||
"relation does not exist"
|
||||
)
|
||||
|
||||
with mock.patch.object(setup_instance, "_application_connection", return_value=connection_mock):
|
||||
with mock.patch.object(
|
||||
setup_instance, "_application_connection", return_value=connection_mock
|
||||
):
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
setup_instance._verify_seeded_data(
|
||||
expected_currency_codes=set(),
|
||||
@@ -226,9 +268,14 @@ def test_seed_baseline_data_rerun_uses_existing_records(
|
||||
unit_rows,
|
||||
]
|
||||
|
||||
with mock.patch.object(
|
||||
setup_instance, "_application_connection", return_value=connection_mock
|
||||
), mock.patch("scripts.seed_data.run_with_namespace") as seed_run:
|
||||
with (
|
||||
mock.patch.object(
|
||||
setup_instance,
|
||||
"_application_connection",
|
||||
return_value=connection_mock,
|
||||
),
|
||||
mock.patch("scripts.seed_data.run_with_namespace") as seed_run,
|
||||
):
|
||||
setup_instance.seed_baseline_data(dry_run=False)
|
||||
setup_instance.seed_baseline_data(dry_run=False)
|
||||
|
||||
@@ -240,7 +287,9 @@ def test_seed_baseline_data_rerun_uses_existing_records(
|
||||
assert cursor_mock.execute.call_count == 4
|
||||
|
||||
|
||||
def test_ensure_database_raises_with_context(mock_config: DatabaseConfig) -> None:
|
||||
def test_ensure_database_raises_with_context(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
connection_mock = mock.MagicMock()
|
||||
cursor_mock = mock.MagicMock()
|
||||
@@ -248,14 +297,18 @@ def test_ensure_database_raises_with_context(mock_config: DatabaseConfig) -> Non
|
||||
cursor_mock.execute.side_effect = [None, psycopg2.Error("create_fail")]
|
||||
connection_mock.cursor.return_value = cursor_mock
|
||||
|
||||
with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock):
|
||||
with mock.patch.object(
|
||||
setup_instance, "_admin_connection", return_value=connection_mock
|
||||
):
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
setup_instance.ensure_database()
|
||||
|
||||
assert "Failed to create database" in str(exc.value)
|
||||
|
||||
|
||||
def test_ensure_role_raises_with_context_during_creation(mock_config: DatabaseConfig) -> None:
|
||||
def test_ensure_role_raises_with_context_during_creation(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
|
||||
admin_conn, admin_cursor = _connection_with_cursor()
|
||||
@@ -295,7 +348,9 @@ def test_ensure_role_raises_with_context_during_privilege_grants(
|
||||
assert "Failed to grant privileges" in str(exc.value)
|
||||
|
||||
|
||||
def test_ensure_database_dry_run_skips_creation(mock_config: DatabaseConfig) -> None:
|
||||
def test_ensure_database_dry_run_skips_creation(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=True)
|
||||
|
||||
connection_mock = mock.MagicMock()
|
||||
@@ -303,45 +358,59 @@ def test_ensure_database_dry_run_skips_creation(mock_config: DatabaseConfig) ->
|
||||
cursor_mock.fetchone.return_value = None
|
||||
connection_mock.cursor.return_value = cursor_mock
|
||||
|
||||
with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch(
|
||||
"scripts.setup_database.logger"
|
||||
) as logger_mock:
|
||||
with (
|
||||
mock.patch.object(
|
||||
setup_instance, "_admin_connection", return_value=connection_mock
|
||||
),
|
||||
mock.patch("scripts.setup_database.logger") as logger_mock,
|
||||
):
|
||||
setup_instance.ensure_database()
|
||||
|
||||
# expect only existence check, no create attempt
|
||||
cursor_mock.execute.assert_called_once()
|
||||
logger_mock.info.assert_any_call(
|
||||
"Dry run: would create database '%s'. Run without --dry-run to proceed.", mock_config.database
|
||||
"Dry run: would create database '%s'. Run without --dry-run to proceed.",
|
||||
mock_config.database,
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_role_dry_run_skips_creation_and_grants(mock_config: DatabaseConfig) -> None:
|
||||
def test_ensure_role_dry_run_skips_creation_and_grants(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=True)
|
||||
|
||||
admin_conn, admin_cursor = _connection_with_cursor()
|
||||
admin_cursor.fetchone.return_value = None
|
||||
|
||||
with mock.patch.object(
|
||||
setup_instance,
|
||||
"_admin_connection",
|
||||
side_effect=[admin_conn],
|
||||
) as conn_mock, mock.patch("scripts.setup_database.logger") as logger_mock:
|
||||
with (
|
||||
mock.patch.object(
|
||||
setup_instance,
|
||||
"_admin_connection",
|
||||
side_effect=[admin_conn],
|
||||
) as conn_mock,
|
||||
mock.patch("scripts.setup_database.logger") as logger_mock,
|
||||
):
|
||||
setup_instance.ensure_role()
|
||||
|
||||
assert conn_mock.call_count == 1
|
||||
admin_cursor.execute.assert_called_once()
|
||||
logger_mock.info.assert_any_call(
|
||||
"Dry run: would create role '%s'. Run without --dry-run to apply.", mock_config.user
|
||||
"Dry run: would create role '%s'. Run without --dry-run to apply.",
|
||||
mock_config.user,
|
||||
)
|
||||
|
||||
|
||||
def test_register_rollback_skipped_when_dry_run(mock_config: DatabaseConfig) -> None:
|
||||
def test_register_rollback_skipped_when_dry_run(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=True)
|
||||
setup_instance._register_rollback("noop", lambda: None)
|
||||
assert setup_instance._rollback_actions == []
|
||||
|
||||
|
||||
def test_execute_rollbacks_runs_in_reverse_order(mock_config: DatabaseConfig) -> None:
|
||||
def test_execute_rollbacks_runs_in_reverse_order(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
|
||||
calls: list[str] = []
|
||||
@@ -362,16 +431,24 @@ def test_execute_rollbacks_runs_in_reverse_order(mock_config: DatabaseConfig) ->
|
||||
assert setup_instance._rollback_actions == []
|
||||
|
||||
|
||||
def test_ensure_database_registers_rollback_action(mock_config: DatabaseConfig) -> None:
|
||||
def test_ensure_database_registers_rollback_action(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
connection_mock = mock.MagicMock()
|
||||
cursor_mock = mock.MagicMock()
|
||||
cursor_mock.fetchone.return_value = None
|
||||
connection_mock.cursor.return_value = cursor_mock
|
||||
|
||||
with mock.patch.object(setup_instance, "_admin_connection", return_value=connection_mock), mock.patch.object(
|
||||
setup_instance, "_register_rollback"
|
||||
) as register_mock, mock.patch.object(setup_instance, "_drop_database") as drop_mock:
|
||||
with (
|
||||
mock.patch.object(
|
||||
setup_instance, "_admin_connection", return_value=connection_mock
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance, "_register_rollback"
|
||||
) as register_mock,
|
||||
mock.patch.object(setup_instance, "_drop_database") as drop_mock,
|
||||
):
|
||||
setup_instance.ensure_database()
|
||||
register_mock.assert_called_once()
|
||||
label, action = register_mock.call_args[0]
|
||||
@@ -380,24 +457,29 @@ def test_ensure_database_registers_rollback_action(mock_config: DatabaseConfig)
|
||||
drop_mock.assert_called_once_with(mock_config.database)
|
||||
|
||||
|
||||
def test_ensure_role_registers_rollback_actions(mock_config: DatabaseConfig) -> None:
|
||||
def test_ensure_role_registers_rollback_actions(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
setup_instance = DatabaseSetup(mock_config, dry_run=False)
|
||||
|
||||
admin_conn, admin_cursor = _connection_with_cursor()
|
||||
admin_cursor.fetchone.return_value = None
|
||||
privilege_conn, privilege_cursor = _connection_with_cursor()
|
||||
|
||||
with mock.patch.object(
|
||||
setup_instance,
|
||||
"_admin_connection",
|
||||
side_effect=[admin_conn, privilege_conn],
|
||||
), mock.patch.object(
|
||||
setup_instance, "_register_rollback"
|
||||
) as register_mock, mock.patch.object(
|
||||
setup_instance, "_drop_role"
|
||||
) as drop_mock, mock.patch.object(
|
||||
setup_instance, "_revoke_role_privileges"
|
||||
) as revoke_mock:
|
||||
with (
|
||||
mock.patch.object(
|
||||
setup_instance,
|
||||
"_admin_connection",
|
||||
side_effect=[admin_conn, privilege_conn],
|
||||
),
|
||||
mock.patch.object(
|
||||
setup_instance, "_register_rollback"
|
||||
) as register_mock,
|
||||
mock.patch.object(setup_instance, "_drop_role") as drop_mock,
|
||||
mock.patch.object(
|
||||
setup_instance, "_revoke_role_privileges"
|
||||
) as revoke_mock,
|
||||
):
|
||||
setup_instance.ensure_role()
|
||||
assert register_mock.call_count == 2
|
||||
drop_label, drop_action = register_mock.call_args_list[0][0]
|
||||
@@ -413,7 +495,9 @@ def test_ensure_role_registers_rollback_actions(mock_config: DatabaseConfig) ->
|
||||
revoke_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_main_triggers_rollbacks_on_failure(mock_config: DatabaseConfig) -> None:
|
||||
def test_main_triggers_rollbacks_on_failure(
|
||||
mock_config: DatabaseConfig,
|
||||
) -> None:
|
||||
args = argparse.Namespace(
|
||||
ensure_database=True,
|
||||
ensure_role=True,
|
||||
@@ -437,11 +521,13 @@ def test_main_triggers_rollbacks_on_failure(mock_config: DatabaseConfig) -> None
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
with mock.patch.object(setup_db_module, "parse_args", return_value=args), mock.patch.object(
|
||||
setup_db_module.DatabaseConfig, "from_env", return_value=mock_config
|
||||
), mock.patch.object(
|
||||
setup_db_module, "DatabaseSetup"
|
||||
) as setup_cls:
|
||||
with (
|
||||
mock.patch.object(setup_db_module, "parse_args", return_value=args),
|
||||
mock.patch.object(
|
||||
setup_db_module.DatabaseConfig, "from_env", return_value=mock_config
|
||||
),
|
||||
mock.patch.object(setup_db_module, "DatabaseSetup") as setup_cls,
|
||||
):
|
||||
setup_instance = mock.MagicMock()
|
||||
setup_instance.dry_run = False
|
||||
setup_instance._rollback_actions = [
|
||||
|
||||
@@ -19,7 +19,12 @@ def client(api_client: TestClient) -> TestClient:
|
||||
|
||||
def test_run_simulation_function_generates_samples():
|
||||
params: List[Dict[str, Any]] = [
|
||||
{"name": "grade", "value": 1.8, "distribution": "normal", "std_dev": 0.2},
|
||||
{
|
||||
"name": "grade",
|
||||
"value": 1.8,
|
||||
"distribution": "normal",
|
||||
"std_dev": 0.2,
|
||||
},
|
||||
{
|
||||
"name": "recovery",
|
||||
"value": 0.9,
|
||||
@@ -45,7 +50,10 @@ def test_run_simulation_with_zero_iterations_returns_empty():
|
||||
@pytest.mark.parametrize(
|
||||
"parameter_payload,error_message",
|
||||
[
|
||||
({"name": "missing-value"}, "Parameter at index 0 must include 'value'"),
|
||||
(
|
||||
{"name": "missing-value"},
|
||||
"Parameter at index 0 must include 'value'",
|
||||
),
|
||||
(
|
||||
{
|
||||
"name": "bad-dist",
|
||||
@@ -110,7 +118,8 @@ def test_run_simulation_triangular_sampling_path():
|
||||
span = 10.0 * DEFAULT_UNIFORM_SPAN_RATIO
|
||||
rng = Random(seed)
|
||||
expected_samples = [
|
||||
rng.triangular(10.0 - span, 10.0 + span, 10.0) for _ in range(iterations)
|
||||
rng.triangular(10.0 - span, 10.0 + span, 10.0)
|
||||
for _ in range(iterations)
|
||||
]
|
||||
actual_samples = [entry["result"] for entry in results]
|
||||
for actual, expected in zip(actual_samples, expected_samples):
|
||||
@@ -156,9 +165,7 @@ def test_simulation_endpoint_no_params(client: TestClient):
|
||||
assert resp.json()["detail"] == "No parameters provided"
|
||||
|
||||
|
||||
def test_simulation_endpoint_success(
|
||||
client: TestClient, db_session: Session
|
||||
):
|
||||
def test_simulation_endpoint_success(client: TestClient, db_session: Session):
|
||||
scenario_payload: Dict[str, Any] = {
|
||||
"name": f"SimScenario-{uuid4()}",
|
||||
"description": "Simulation test",
|
||||
@@ -168,7 +175,12 @@ def test_simulation_endpoint_success(
|
||||
scenario_id = scenario_resp.json()["id"]
|
||||
|
||||
params: List[Dict[str, Any]] = [
|
||||
{"name": "param1", "value": 2.5, "distribution": "normal", "std_dev": 0.5}
|
||||
{
|
||||
"name": "param1",
|
||||
"value": 2.5,
|
||||
"distribution": "normal",
|
||||
"std_dev": 0.5,
|
||||
}
|
||||
]
|
||||
payload: Dict[str, Any] = {
|
||||
"scenario_id": scenario_id,
|
||||
|
||||
63
tests/unit/test_theme_settings.py
Normal file
63
tests/unit/test_theme_settings.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from main import app
|
||||
from models.theme_setting import ThemeSetting
|
||||
from services.settings import save_theme_settings, get_theme_settings
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_save_theme_settings(db_session: Session):
|
||||
theme_data = {
|
||||
"theme_name": "dark",
|
||||
"primary_color": "#000000",
|
||||
"secondary_color": "#333333",
|
||||
"accent_color": "#ff0000",
|
||||
"background_color": "#1a1a1a",
|
||||
"text_color": "#ffffff"
|
||||
}
|
||||
|
||||
saved_setting = save_theme_settings(db_session, theme_data)
|
||||
assert str(saved_setting.theme_name) == "dark"
|
||||
assert str(saved_setting.primary_color) == "#000000"
|
||||
|
||||
|
||||
def test_get_theme_settings(db_session: Session):
|
||||
# Create a theme setting first
|
||||
theme_data = {
|
||||
"theme_name": "light",
|
||||
"primary_color": "#ffffff",
|
||||
"secondary_color": "#cccccc",
|
||||
"accent_color": "#0000ff",
|
||||
"background_color": "#f0f0f0",
|
||||
"text_color": "#000000"
|
||||
}
|
||||
save_theme_settings(db_session, theme_data)
|
||||
|
||||
settings = get_theme_settings(db_session)
|
||||
assert settings["theme_name"] == "light"
|
||||
assert settings["primary_color"] == "#ffffff"
|
||||
|
||||
|
||||
def test_theme_settings_api(api_client):
|
||||
# Test API endpoint for saving theme settings
|
||||
theme_data = {
|
||||
"theme_name": "test_theme",
|
||||
"primary_color": "#123456",
|
||||
"secondary_color": "#789abc",
|
||||
"accent_color": "#def012",
|
||||
"background_color": "#345678",
|
||||
"text_color": "#9abcde"
|
||||
}
|
||||
|
||||
response = api_client.post("/api/settings/theme", json=theme_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["theme"]["theme_name"] == "test_theme"
|
||||
|
||||
# Test API endpoint for getting theme settings
|
||||
response = api_client.get("/api/settings/theme")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["theme_name"] == "test_theme"
|
||||
@@ -21,11 +21,18 @@ def test_dashboard_route_provides_summary(
|
||||
assert context.get("report_available") is True
|
||||
|
||||
metric_labels = {item["label"] for item in context["summary_metrics"]}
|
||||
assert {"CAPEX Total", "OPEX Total", "Production", "Simulation Iterations"}.issubset(metric_labels)
|
||||
assert {
|
||||
"CAPEX Total",
|
||||
"OPEX Total",
|
||||
"Production",
|
||||
"Simulation Iterations",
|
||||
}.issubset(metric_labels)
|
||||
|
||||
scenario = cast(Scenario, seeded_ui_data["scenario"])
|
||||
scenario_row = next(
|
||||
row for row in context["scenario_rows"] if row["scenario_name"] == scenario.name
|
||||
row
|
||||
for row in context["scenario_rows"]
|
||||
if row["scenario_name"] == scenario.name
|
||||
)
|
||||
assert scenario_row["iterations"] == 3
|
||||
assert scenario_row["simulation_mean_display"] == "971,666.67"
|
||||
@@ -81,7 +88,9 @@ def test_dashboard_data_endpoint_returns_aggregates(
|
||||
payload = response.json()
|
||||
assert payload["report_available"] is True
|
||||
|
||||
metric_map = {item["label"]: item["value"] for item in payload["summary_metrics"]}
|
||||
metric_map = {
|
||||
item["label"]: item["value"] for item in payload["summary_metrics"]
|
||||
}
|
||||
assert metric_map["CAPEX Total"].startswith("$")
|
||||
assert metric_map["Maintenance Cost"].startswith("$")
|
||||
|
||||
@@ -99,7 +108,9 @@ def test_dashboard_data_endpoint_returns_aggregates(
|
||||
|
||||
activity_labels = payload["scenario_activity_chart"]["labels"]
|
||||
activity_idx = activity_labels.index(scenario.name)
|
||||
assert payload["scenario_activity_chart"]["production"][activity_idx] == 800.0
|
||||
assert (
|
||||
payload["scenario_activity_chart"]["production"][activity_idx] == 800.0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -154,7 +165,10 @@ def test_settings_route_provides_css_context(
|
||||
assert "css_env_override_meta" in context
|
||||
|
||||
assert context["css_variables"]["--color-accent"] == "#abcdef"
|
||||
assert context["css_defaults"]["--color-accent"] == settings_service.CSS_COLOR_DEFAULTS["--color-accent"]
|
||||
assert (
|
||||
context["css_defaults"]["--color-accent"]
|
||||
== settings_service.CSS_COLOR_DEFAULTS["--color-accent"]
|
||||
)
|
||||
assert context["css_env_overrides"]["--color-accent"] == "#abcdef"
|
||||
|
||||
override_rows = context["css_env_override_rows"]
|
||||
|
||||
Reference in New Issue
Block a user