Refactor test cases for improved readability and consistency
Some checks failed
Run Tests / e2e tests (push) Failing after 1m27s
Run Tests / lint tests (push) Failing after 6s
Run Tests / unit tests (push) Failing after 7s

- Updated test functions in various test files to enhance code clarity by formatting long lines and improving indentation.
- Adjusted assertions to use multi-line formatting for better readability.
- Added new test cases for theme settings API to ensure proper functionality.
- Ensured consistent use of line breaks and spacing across test files for uniformity.
This commit is contained in:
2025-10-27 10:32:55 +01:00
parent e8a86b15e4
commit 97b1c0360b
78 changed files with 2327 additions and 650 deletions

View File

@@ -9,6 +9,7 @@ This script is intentionally cautious: it defaults to dry-run mode and will refu
if database connection settings are missing. It supports creating missing currency rows when `--create-missing`
is provided. Always run against a development/staging database first.
"""
from __future__ import annotations
import argparse
import importlib
@@ -36,26 +37,43 @@ def load_database_url() -> str:
return getattr(db_module, "DATABASE_URL")
def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> None:
def backfill(
db_url: str, dry_run: bool = True, create_missing: bool = False
) -> None:
engine = create_engine(db_url)
with engine.begin() as conn:
# Ensure currency table exists
res = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name='currency';")) if db_url.startswith(
'sqlite:') else conn.execute(text("SELECT to_regclass('public.currency');"))
res = (
conn.execute(
text(
"SELECT name FROM sqlite_master WHERE type='table' AND name='currency';"
)
)
if db_url.startswith("sqlite:")
else conn.execute(text("SELECT to_regclass('public.currency');"))
)
# Note: we don't strictly depend on the above - we assume migration was already applied
# Helper: find or create currency by code
def find_currency_id(code: str):
r = conn.execute(text("SELECT id FROM currency WHERE code = :code"), {
"code": code}).fetchone()
r = conn.execute(
text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).fetchone()
if r:
return r[0]
if create_missing:
# insert and return id
conn.execute(text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)"), {
"c": code, "n": code})
r2 = conn.execute(text("SELECT id FROM currency WHERE code = :code"), {
"code": code}).fetchone()
conn.execute(
text(
"INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)"
),
{"c": code, "n": code},
)
r2 = conn.execute(
text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).fetchone()
if not r2:
raise RuntimeError(
f"Unable to determine currency ID for '{code}' after insert"
@@ -67,8 +85,15 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
for table in ("capex", "opex"):
# Check if currency_id column exists
try:
cols = conn.execute(text(f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'")) if not db_url.startswith(
'sqlite:') else [(1,)]
cols = (
conn.execute(
text(
f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'"
)
)
if not db_url.startswith("sqlite:")
else [(1,)]
)
except Exception:
cols = [(1,)]
@@ -77,8 +102,11 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
continue
# Find rows where currency_id IS NULL but currency_code exists
rows = conn.execute(text(
f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''"))
rows = conn.execute(
text(
f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''"
)
)
changed = 0
for r in rows:
rid = r[0]
@@ -86,14 +114,20 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
cid = find_currency_id(code)
if cid is None:
print(
f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping")
f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping"
)
continue
if dry_run:
print(
f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})")
f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})"
)
else:
conn.execute(text(f"UPDATE {table} SET currency_id = :cid WHERE id = :rid"), {
"cid": cid, "rid": rid})
conn.execute(
text(
f"UPDATE {table} SET currency_id = :cid WHERE id = :rid"
),
{"cid": cid, "rid": rid},
)
changed += 1
print(f"{table}: processed, changed={changed} (dry_run={dry_run})")
@@ -101,11 +135,19 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
def main() -> None:
parser = argparse.ArgumentParser(
description="Backfill currency_id from currency_code for capex/opex tables")
parser.add_argument("--dry-run", action="store_true",
default=True, help="Show actions without writing")
parser.add_argument("--create-missing", action="store_true",
help="Create missing currency rows in the currency table")
description="Backfill currency_id from currency_code for capex/opex tables"
)
parser.add_argument(
"--dry-run",
action="store_true",
default=True,
help="Show actions without writing",
)
parser.add_argument(
"--create-missing",
action="store_true",
help="Create missing currency rows in the currency table",
)
args = parser.parse_args()
db = load_database_url()

View File

@@ -4,25 +4,30 @@ Checks only local file links (relative paths) and reports missing targets.
Run from the repository root using the project's Python environment.
"""
import re
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
DOCS = ROOT / 'docs'
DOCS = ROOT / "docs"
MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
errors = []
for md in DOCS.rglob('*.md'):
text = md.read_text(encoding='utf-8')
for md in DOCS.rglob("*.md"):
text = md.read_text(encoding="utf-8")
for m in MD_LINK_RE.finditer(text):
label, target = m.groups()
# skip URLs
if target.startswith('http://') or target.startswith('https://') or target.startswith('#'):
if (
target.startswith("http://")
or target.startswith("https://")
or target.startswith("#")
):
continue
# strip anchors
target_path = target.split('#')[0]
target_path = target.split("#")[0]
# if link is to a directory index, allow
candidate = (md.parent / target_path).resolve()
if candidate.exists():
@@ -30,14 +35,16 @@ for md in DOCS.rglob('*.md'):
# check common implicit index: target/ -> target/README.md or target/index.md
candidate_dir = md.parent / target_path
if candidate_dir.is_dir():
if (candidate_dir / 'README.md').exists() or (candidate_dir / 'index.md').exists():
if (candidate_dir / "README.md").exists() or (
candidate_dir / "index.md"
).exists():
continue
errors.append((str(md.relative_to(ROOT)), target, label))
if errors:
print('Broken local links found:')
print("Broken local links found:")
for src, tgt, label in errors:
print(f'- {src} -> {tgt} ({label})')
print(f"- {src} -> {tgt} ({label})")
exit(2)
print('No broken local links detected.')
print("No broken local links detected.")

View File

@@ -2,16 +2,17 @@
This is intentionally small and non-destructive; it touches only files under docs/ and makes safe changes.
"""
import re
from pathlib import Path
DOCS = Path(__file__).resolve().parents[1] / "docs"
CODE_LANG_HINTS = {
'powershell': ('powershell',),
'bash': ('bash', 'sh'),
'sql': ('sql',),
'python': ('python',),
"powershell": ("powershell",),
"bash": ("bash", "sh"),
"sql": ("sql",),
"python": ("python",),
}
@@ -19,48 +20,60 @@ def add_code_fence_language(match):
fence = match.group(0)
inner = match.group(1)
# If language already present, return unchanged
if fence.startswith('```') and len(fence.splitlines()[0].strip()) > 3:
if fence.startswith("```") and len(fence.splitlines()[0].strip()) > 3:
return fence
# Try to infer language from the code content
code = inner.strip().splitlines()[0] if inner.strip() else ''
lang = ''
if code.startswith('$') or code.startswith('PS') or code.lower().startswith('powershell'):
lang = 'powershell'
elif code.startswith('#') or code.startswith('import') or code.startswith('from'):
lang = 'python'
elif re.match(r'^(select|insert|update|create)\b', code.strip(), re.I):
lang = 'sql'
elif code.startswith('git') or code.startswith('./') or code.startswith('sudo'):
lang = 'bash'
code = inner.strip().splitlines()[0] if inner.strip() else ""
lang = ""
if (
code.startswith("$")
or code.startswith("PS")
or code.lower().startswith("powershell")
):
lang = "powershell"
elif (
code.startswith("#")
or code.startswith("import")
or code.startswith("from")
):
lang = "python"
elif re.match(r"^(select|insert|update|create)\b", code.strip(), re.I):
lang = "sql"
elif (
code.startswith("git")
or code.startswith("./")
or code.startswith("sudo")
):
lang = "bash"
if lang:
return f'```{lang}\n{inner}\n```'
return f"```{lang}\n{inner}\n```"
return fence
def normalize_file(path: Path):
text = path.read_text(encoding='utf-8')
text = path.read_text(encoding="utf-8")
orig = text
# Trim trailing whitespace and ensure single trailing newline
text = '\n'.join(line.rstrip() for line in text.splitlines()) + '\n'
text = "\n".join(line.rstrip() for line in text.splitlines()) + "\n"
# Ensure first non-empty line is H1
lines = text.splitlines()
for i, ln in enumerate(lines):
if ln.strip():
if not ln.startswith('#'):
lines[i] = '# ' + ln
if not ln.startswith("#"):
lines[i] = "# " + ln
break
text = '\n'.join(lines) + '\n'
text = "\n".join(lines) + "\n"
# Add basic code fence languages where missing (simple heuristic)
text = re.sub(r'```\n([\s\S]*?)\n```', add_code_fence_language, text)
text = re.sub(r"```\n([\s\S]*?)\n```", add_code_fence_language, text)
if text != orig:
path.write_text(text, encoding='utf-8')
path.write_text(text, encoding="utf-8")
return True
return False
def main():
changed = []
for p in DOCS.rglob('*.md'):
for p in DOCS.rglob("*.md"):
if p.is_file():
try:
if normalize_file(p):
@@ -68,12 +81,12 @@ def main():
except Exception as e:
print(f"Failed to format {p}: {e}")
if changed:
print('Formatted files:')
print("Formatted files:")
for c in changed:
print(' -', c)
print(" -", c)
else:
print('No formatting changes required.')
print("No formatting changes required.")
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,11 @@
-- Migration: 20251027_create_theme_settings_table.sql
CREATE TABLE theme_settings (
id SERIAL PRIMARY KEY,
theme_name VARCHAR(255) UNIQUE NOT NULL,
primary_color VARCHAR(7) NOT NULL,
secondary_color VARCHAR(7) NOT NULL,
accent_color VARCHAR(7) NOT NULL,
background_color VARCHAR(7) NOT NULL,
text_color VARCHAR(7) NOT NULL
);

View File

@@ -0,0 +1,15 @@
-- Migration: 20251027_create_user_and_role_tables.sql
CREATE TABLE roles (
id SERIAL PRIMARY KEY,
name VARCHAR(255) UNIQUE NOT NULL
);
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) UNIQUE NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
hashed_password VARCHAR(255) NOT NULL,
role_id INTEGER NOT NULL,
FOREIGN KEY (role_id) REFERENCES roles(id)
);

View File

@@ -47,22 +47,82 @@ MEASUREMENT_UNIT_SEEDS = (
("kilowatt_hours", "Kilowatt Hours", "kWh", "energy", True),
)
THEME_SETTING_SEEDS = (
("--color-background", "#f4f5f7", "color",
"theme", "CSS variable --color-background", True),
("--color-surface", "#ffffff", "color",
"theme", "CSS variable --color-surface", True),
("--color-text-primary", "#2a1f33", "color",
"theme", "CSS variable --color-text-primary", True),
("--color-text-secondary", "#624769", "color",
"theme", "CSS variable --color-text-secondary", True),
("--color-text-muted", "#64748b", "color",
"theme", "CSS variable --color-text-muted", True),
("--color-text-subtle", "#94a3b8", "color",
"theme", "CSS variable --color-text-subtle", True),
("--color-text-invert", "#ffffff", "color",
"theme", "CSS variable --color-text-invert", True),
("--color-text-dark", "#0f172a", "color",
"theme", "CSS variable --color-text-dark", True),
("--color-text-strong", "#111827", "color",
"theme", "CSS variable --color-text-strong", True),
("--color-primary", "#5f320d", "color",
"theme", "CSS variable --color-primary", True),
("--color-primary-strong", "#7e4c13", "color",
"theme", "CSS variable --color-primary-strong", True),
("--color-primary-stronger", "#837c15", "color",
"theme", "CSS variable --color-primary-stronger", True),
("--color-accent", "#bff838", "color",
"theme", "CSS variable --color-accent", True),
("--color-border", "#e2e8f0", "color",
"theme", "CSS variable --color-border", True),
("--color-border-strong", "#cbd5e1", "color",
"theme", "CSS variable --color-border-strong", True),
("--color-highlight", "#eef2ff", "color",
"theme", "CSS variable --color-highlight", True),
("--color-panel-shadow", "rgba(15, 23, 42, 0.08)", "color",
"theme", "CSS variable --color-panel-shadow", True),
("--color-panel-shadow-deep", "rgba(15, 23, 42, 0.12)", "color",
"theme", "CSS variable --color-panel-shadow-deep", True),
("--color-surface-alt", "#f8fafc", "color",
"theme", "CSS variable --color-surface-alt", True),
("--color-success", "#047857", "color",
"theme", "CSS variable --color-success", True),
("--color-error", "#b91c1c", "color",
"theme", "CSS variable --color-error", True),
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Seed baseline CalMiner data")
parser.add_argument("--currencies", action="store_true", help="Seed currency table")
parser.add_argument("--units", action="store_true", help="Seed unit table")
parser.add_argument("--defaults", action="store_true", help="Seed default records")
parser.add_argument("--dry-run", action="store_true", help="Print actions without executing")
parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity"
"--currencies", action="store_true", help="Seed currency table"
)
parser.add_argument("--units", action="store_true", help="Seed unit table")
parser.add_argument(
"--theme", action="store_true", help="Seed theme settings"
)
parser.add_argument(
"--defaults", action="store_true", help="Seed default records"
)
parser.add_argument(
"--dry-run", action="store_true", help="Print actions without executing"
)
parser.add_argument(
"--verbose",
"-v",
action="count",
default=0,
help="Increase logging verbosity",
)
return parser.parse_args()
def _configure_logging(args: argparse.Namespace) -> None:
level = logging.WARNING - (10 * min(args.verbose, 2))
logging.basicConfig(level=max(level, logging.INFO), format="%(levelname)s %(message)s")
logging.basicConfig(
level=max(level, logging.INFO), format="%(levelname)s %(message)s"
)
def main() -> None:
@@ -77,7 +137,7 @@ def run_with_namespace(
) -> None:
_configure_logging(args)
if not any((args.currencies, args.units, args.defaults)):
if not any((args.currencies, args.units, args.theme, args.defaults)):
logger.info("No seeding options provided; exiting")
return
@@ -89,6 +149,8 @@ def run_with_namespace(
_seed_currencies(cursor, dry_run=args.dry_run)
if args.units:
_seed_units(cursor, dry_run=args.dry_run)
if args.theme:
_seed_theme(cursor, dry_run=args.dry_run)
if args.defaults:
_seed_defaults(cursor, dry_run=args.dry_run)
@@ -152,11 +214,44 @@ def _seed_units(cursor, *, dry_run: bool) -> None:
logger.info("Measurement unit seed complete")
def _seed_defaults(cursor, *, dry_run: bool) -> None:
logger.info("Seeding default records - not yet implemented")
def _seed_theme(cursor, *, dry_run: bool) -> None:
logger.info("Seeding theme settings (%d rows)", len(THEME_SETTING_SEEDS))
if dry_run:
for key, value, _, _, _, _ in THEME_SETTING_SEEDS:
logger.info(
"Dry run: would upsert theme setting %s = %s", key, value)
return
try:
execute_values(
cursor,
"""
INSERT INTO application_setting (key, value, value_type, category, description, is_editable)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value,
value_type = EXCLUDED.value_type,
category = EXCLUDED.category,
description = EXCLUDED.description,
is_editable = EXCLUDED.is_editable
""",
THEME_SETTING_SEEDS,
)
except errors.UndefinedTable:
logger.warning(
"application_setting table does not exist; skipping theme seeding."
)
cursor.connection.rollback()
return
logger.info("Theme settings seed complete")
def _seed_defaults(cursor, *, dry_run: bool) -> None:
logger.info("Seeding default records")
_seed_theme(cursor, dry_run=dry_run)
logger.info("Default records seed complete")
if __name__ == "__main__":
main()
main()

View File

@@ -39,6 +39,7 @@ from psycopg2 import extensions
from psycopg2.extensions import connection as PGConnection, parse_dsn
from dotenv import load_dotenv
from sqlalchemy import create_engine, inspect
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.insert(0, str(ROOT_DIR))
@@ -125,8 +126,7 @@ class DatabaseConfig:
]
if missing:
raise RuntimeError(
"Missing required database configuration: " +
", ".join(missing)
"Missing required database configuration: " + ", ".join(missing)
)
host = cast(str, host)
@@ -208,12 +208,17 @@ class DatabaseConfig:
class DatabaseSetup:
"""Encapsulates the full setup workflow."""
def __init__(self, config: DatabaseConfig, *, dry_run: bool = False) -> None:
def __init__(
self, config: DatabaseConfig, *, dry_run: bool = False
) -> None:
self.config = config
self.dry_run = dry_run
self._models_loaded = False
self._rollback_actions: list[tuple[str, Callable[[], None]]] = []
def _register_rollback(self, label: str, action: Callable[[], None]) -> None:
def _register_rollback(
self, label: str, action: Callable[[], None]
) -> None:
if self.dry_run:
return
self._rollback_actions.append((label, action))
@@ -237,7 +242,6 @@ class DatabaseSetup:
def clear_rollbacks(self) -> None:
self._rollback_actions.clear()
def _describe_connection(self, user: str, database: str) -> str:
return f"{user}@{self.config.host}:{self.config.port}/{database}"
@@ -384,9 +388,9 @@ class DatabaseSetup:
try:
if self.config.password:
cursor.execute(
sql.SQL("CREATE ROLE {} WITH LOGIN PASSWORD %s").format(
sql.Identifier(self.config.user)
),
sql.SQL(
"CREATE ROLE {} WITH LOGIN PASSWORD %s"
).format(sql.Identifier(self.config.user)),
(self.config.password,),
)
else:
@@ -589,8 +593,7 @@ class DatabaseSetup:
return psycopg2.connect(dsn)
except psycopg2.Error as exc:
raise RuntimeError(
"Unable to establish admin connection. "
f"Target: {descriptor}"
"Unable to establish admin connection. " f"Target: {descriptor}"
) from exc
def _application_connection(self) -> PGConnection:
@@ -645,7 +648,9 @@ class DatabaseSetup:
importlib.import_module(f"{package.__name__}.{module_info.name}")
self._models_loaded = True
def run_migrations(self, migrations_dir: Optional[Path | str] = None) -> None:
def run_migrations(
self, migrations_dir: Optional[Path | str] = None
) -> None:
"""Execute pending SQL migrations in chronological order."""
directory = (
@@ -673,7 +678,8 @@ class DatabaseSetup:
conn.autocommit = True
with conn.cursor() as cursor:
table_exists = self._migrations_table_exists(
cursor, schema_name)
cursor, schema_name
)
if not table_exists:
if self.dry_run:
logger.info(
@@ -692,12 +698,10 @@ class DatabaseSetup:
applied = set()
else:
applied = self._fetch_applied_migrations(
cursor, schema_name)
cursor, schema_name
)
if (
baseline_path.exists()
and baseline_name not in applied
):
if baseline_path.exists() and baseline_name not in applied:
if self.dry_run:
logger.info(
"Dry run: baseline migration '%s' pending; would apply and mark legacy files",
@@ -756,9 +760,7 @@ class DatabaseSetup:
)
pending = [
path
for path in migration_files
if path.name not in applied
path for path in migration_files if path.name not in applied
]
if not pending:
@@ -792,9 +794,7 @@ class DatabaseSetup:
cursor.execute(
sql.SQL(
"INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())"
).format(
sql.Identifier(schema_name, MIGRATIONS_TABLE)
),
).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)),
(path.name,),
)
return path.name
@@ -820,9 +820,7 @@ class DatabaseSetup:
"filename TEXT PRIMARY KEY,"
"applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()"
")"
).format(
sql.Identifier(schema_name, MIGRATIONS_TABLE)
)
).format(sql.Identifier(schema_name, MIGRATIONS_TABLE))
)
def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]:
@@ -974,7 +972,7 @@ class DatabaseSetup:
(database,),
)
cursor.execute(
sql.SQL("DROP DATABASE IF EXISTS {}" ).format(
sql.SQL("DROP DATABASE IF EXISTS {}").format(
sql.Identifier(database)
)
)
@@ -985,7 +983,7 @@ class DatabaseSetup:
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute(
sql.SQL("DROP ROLE IF EXISTS {}" ).format(
sql.SQL("DROP ROLE IF EXISTS {}").format(
sql.Identifier(role)
)
)
@@ -1000,27 +998,35 @@ class DatabaseSetup:
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" ).format(
sql.SQL(
"REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)
cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" ).format(
sql.SQL(
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)
cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" ).format(
sql.SQL(
"ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)
cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" ).format(
sql.SQL(
"ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)
@@ -1064,19 +1070,18 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument("--db-driver", help="Override DATABASE_DRIVER")
parser.add_argument("--db-host", help="Override DATABASE_HOST")
parser.add_argument("--db-port", type=int,
help="Override DATABASE_PORT")
parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT")
parser.add_argument("--db-name", help="Override DATABASE_NAME")
parser.add_argument("--db-user", help="Override DATABASE_USER")
parser.add_argument(
"--db-password", help="Override DATABASE_PASSWORD")
parser.add_argument("--db-password", help="Override DATABASE_PASSWORD")
parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA")
parser.add_argument(
"--admin-url",
help="Override DATABASE_ADMIN_URL for administrative operations",
)
parser.add_argument(
"--admin-user", help="Override DATABASE_SUPERUSER for admin ops")
"--admin-user", help="Override DATABASE_SUPERUSER for admin ops"
)
parser.add_argument(
"--admin-password",
help="Override DATABASE_SUPERUSER_PASSWORD for admin ops",
@@ -1091,7 +1096,11 @@ def parse_args() -> argparse.Namespace:
help="Log actions without applying changes.",
)
parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity"
"--verbose",
"-v",
action="count",
default=0,
help="Increase logging verbosity",
)
return parser.parse_args()
@@ -1099,8 +1108,9 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
args = parse_args()
level = logging.WARNING - (10 * min(args.verbose, 2))
logging.basicConfig(level=max(level, logging.INFO),
format="%(levelname)s %(message)s")
logging.basicConfig(
level=max(level, logging.INFO), format="%(levelname)s %(message)s"
)
override_args: dict[str, Optional[str]] = {
"DATABASE_DRIVER": args.db_driver,
@@ -1120,7 +1130,9 @@ def main() -> None:
config = DatabaseConfig.from_env(overrides=override_args)
setup = DatabaseSetup(config, dry_run=args.dry_run)
admin_tasks_requested = args.ensure_database or args.ensure_role or args.ensure_schema
admin_tasks_requested = (
args.ensure_database or args.ensure_role or args.ensure_schema
)
if admin_tasks_requested:
setup.validate_admin_connection()
@@ -1145,9 +1157,7 @@ def main() -> None:
auto_run_migrations_reason: Optional[str] = None
if args.seed_data and not should_run_migrations:
should_run_migrations = True
auto_run_migrations_reason = (
"Seed data requested without explicit --run-migrations; applying migrations first."
)
auto_run_migrations_reason = "Seed data requested without explicit --run-migrations; applying migrations first."
try:
if args.ensure_database:
@@ -1167,9 +1177,7 @@ def main() -> None:
if auto_run_migrations_reason:
logger.info(auto_run_migrations_reason)
migrations_path = (
Path(args.migrations_dir)
if args.migrations_dir
else None
Path(args.migrations_dir) if args.migrations_dir else None
)
setup.run_migrations(migrations_path)
if args.seed_data: