Refactor test cases for improved readability and consistency
- Updated test functions in various test files to enhance code clarity by formatting long lines and improving indentation.
- Adjusted assertions to use multi-line formatting for better readability.
- Added new test cases for theme settings API to ensure proper functionality.
- Ensured consistent use of line breaks and spacing across test files for uniformity.
@@ -9,6 +9,7 @@ This script is intentionally cautious: it defaults to dry-run mode and will refu
if database connection settings are missing. It supports creating missing currency rows when `--create-missing`
is provided. Always run against a development/staging database first.
"""

from __future__ import annotations
import argparse
import importlib
@@ -36,26 +37,43 @@ def load_database_url() -> str:
return getattr(db_module, "DATABASE_URL")


def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) -> None:
def backfill(
db_url: str, dry_run: bool = True, create_missing: bool = False
) -> None:
engine = create_engine(db_url)
with engine.begin() as conn:
# Ensure currency table exists
res = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name='currency';")) if db_url.startswith(
'sqlite:') else conn.execute(text("SELECT to_regclass('public.currency');"))
res = (
conn.execute(
text(
"SELECT name FROM sqlite_master WHERE type='table' AND name='currency';"
)
)
if db_url.startswith("sqlite:")
else conn.execute(text("SELECT to_regclass('public.currency');"))
)
# Note: we don't strictly depend on the above - we assume migration was already applied

# Helper: find or create currency by code
def find_currency_id(code: str):
r = conn.execute(text("SELECT id FROM currency WHERE code = :code"), {
"code": code}).fetchone()
r = conn.execute(
text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).fetchone()
if r:
return r[0]
if create_missing:
# insert and return id
conn.execute(text("INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)"), {
"c": code, "n": code})
r2 = conn.execute(text("SELECT id FROM currency WHERE code = :code"), {
"code": code}).fetchone()
conn.execute(
text(
"INSERT INTO currency (code, name, symbol, is_active) VALUES (:c, :n, NULL, TRUE)"
),
{"c": code, "n": code},
)
r2 = conn.execute(
text("SELECT id FROM currency WHERE code = :code"),
{"code": code},
).fetchone()
if not r2:
raise RuntimeError(
f"Unable to determine currency ID for '{code}' after insert"
@@ -67,8 +85,15 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
for table in ("capex", "opex"):
# Check if currency_id column exists
try:
cols = conn.execute(text(f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'")) if not db_url.startswith(
'sqlite:') else [(1,)]
cols = (
conn.execute(
text(
f"SELECT 1 FROM information_schema.columns WHERE table_name = '{table}' AND column_name = 'currency_id'"
)
)
if not db_url.startswith("sqlite:")
else [(1,)]
)
except Exception:
cols = [(1,)]

@@ -77,8 +102,11 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
continue

# Find rows where currency_id IS NULL but currency_code exists
rows = conn.execute(text(
f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''"))
rows = conn.execute(
text(
f"SELECT id, currency_code FROM {table} WHERE currency_id IS NULL OR currency_id = ''"
)
)
changed = 0
for r in rows:
rid = r[0]
@@ -86,14 +114,20 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->
cid = find_currency_id(code)
if cid is None:
print(
f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping")
f"Row {table}:{rid} has unknown currency code '{code}' and create_missing=False; skipping"
)
continue
if dry_run:
print(
f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})")
f"[DRY RUN] Would set {table}.currency_id = {cid} for row id={rid} (code={code})"
)
else:
conn.execute(text(f"UPDATE {table} SET currency_id = :cid WHERE id = :rid"), {
"cid": cid, "rid": rid})
conn.execute(
text(
f"UPDATE {table} SET currency_id = :cid WHERE id = :rid"
),
{"cid": cid, "rid": rid},
)
changed += 1

print(f"{table}: processed, changed={changed} (dry_run={dry_run})")
@@ -101,11 +135,19 @@ def backfill(db_url: str, dry_run: bool = True, create_missing: bool = False) ->

def main() -> None:
parser = argparse.ArgumentParser(
description="Backfill currency_id from currency_code for capex/opex tables")
parser.add_argument("--dry-run", action="store_true",
default=True, help="Show actions without writing")
parser.add_argument("--create-missing", action="store_true",
help="Create missing currency rows in the currency table")
description="Backfill currency_id from currency_code for capex/opex tables"
)
parser.add_argument(
"--dry-run",
action="store_true",
default=True,
help="Show actions without writing",
)
parser.add_argument(
"--create-missing",
action="store_true",
help="Create missing currency rows in the currency table",
)
args = parser.parse_args()

db = load_database_url()
@@ -4,25 +4,30 @@ Checks only local file links (relative paths) and reports missing targets.

Run from the repository root using the project's Python environment.
"""

import re
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
DOCS = ROOT / 'docs'
DOCS = ROOT / "docs"

MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")

errors = []

for md in DOCS.rglob('*.md'):
text = md.read_text(encoding='utf-8')
for md in DOCS.rglob("*.md"):
text = md.read_text(encoding="utf-8")
for m in MD_LINK_RE.finditer(text):
label, target = m.groups()
# skip URLs
if target.startswith('http://') or target.startswith('https://') or target.startswith('#'):
if (
target.startswith("http://")
or target.startswith("https://")
or target.startswith("#")
):
continue
# strip anchors
target_path = target.split('#')[0]
target_path = target.split("#")[0]
# if link is to a directory index, allow
candidate = (md.parent / target_path).resolve()
if candidate.exists():
@@ -30,14 +35,16 @@ for md in DOCS.rglob('*.md'):
# check common implicit index: target/ -> target/README.md or target/index.md
candidate_dir = md.parent / target_path
if candidate_dir.is_dir():
if (candidate_dir / 'README.md').exists() or (candidate_dir / 'index.md').exists():
if (candidate_dir / "README.md").exists() or (
candidate_dir / "index.md"
).exists():
continue
errors.append((str(md.relative_to(ROOT)), target, label))

if errors:
print('Broken local links found:')
print("Broken local links found:")
for src, tgt, label in errors:
print(f'- {src} -> {tgt} ({label})')
print(f"- {src} -> {tgt} ({label})")
exit(2)

print('No broken local links detected.')
print("No broken local links detected.")
@@ -2,16 +2,17 @@

This is intentionally small and non-destructive; it touches only files under docs/ and makes safe changes.
"""

import re
from pathlib import Path

DOCS = Path(__file__).resolve().parents[1] / "docs"

CODE_LANG_HINTS = {
'powershell': ('powershell',),
'bash': ('bash', 'sh'),
'sql': ('sql',),
'python': ('python',),
"powershell": ("powershell",),
"bash": ("bash", "sh"),
"sql": ("sql",),
"python": ("python",),
}


@@ -19,48 +20,60 @@ def add_code_fence_language(match):
fence = match.group(0)
inner = match.group(1)
# If language already present, return unchanged
if fence.startswith('```') and len(fence.splitlines()[0].strip()) > 3:
if fence.startswith("```") and len(fence.splitlines()[0].strip()) > 3:
return fence
# Try to infer language from the code content
code = inner.strip().splitlines()[0] if inner.strip() else ''
lang = ''
if code.startswith('$') or code.startswith('PS') or code.lower().startswith('powershell'):
lang = 'powershell'
elif code.startswith('#') or code.startswith('import') or code.startswith('from'):
lang = 'python'
elif re.match(r'^(select|insert|update|create)\b', code.strip(), re.I):
lang = 'sql'
elif code.startswith('git') or code.startswith('./') or code.startswith('sudo'):
lang = 'bash'
code = inner.strip().splitlines()[0] if inner.strip() else ""
lang = ""
if (
code.startswith("$")
or code.startswith("PS")
or code.lower().startswith("powershell")
):
lang = "powershell"
elif (
code.startswith("#")
or code.startswith("import")
or code.startswith("from")
):
lang = "python"
elif re.match(r"^(select|insert|update|create)\b", code.strip(), re.I):
lang = "sql"
elif (
code.startswith("git")
or code.startswith("./")
or code.startswith("sudo")
):
lang = "bash"
if lang:
return f'```{lang}\n{inner}\n```'
return f"```{lang}\n{inner}\n```"
return fence


def normalize_file(path: Path):
text = path.read_text(encoding='utf-8')
text = path.read_text(encoding="utf-8")
orig = text
# Trim trailing whitespace and ensure single trailing newline
text = '\n'.join(line.rstrip() for line in text.splitlines()) + '\n'
text = "\n".join(line.rstrip() for line in text.splitlines()) + "\n"
# Ensure first non-empty line is H1
lines = text.splitlines()
for i, ln in enumerate(lines):
if ln.strip():
if not ln.startswith('#'):
lines[i] = '# ' + ln
if not ln.startswith("#"):
lines[i] = "# " + ln
break
text = '\n'.join(lines) + '\n'
text = "\n".join(lines) + "\n"
# Add basic code fence languages where missing (simple heuristic)
text = re.sub(r'```\n([\s\S]*?)\n```', add_code_fence_language, text)
text = re.sub(r"```\n([\s\S]*?)\n```", add_code_fence_language, text)
if text != orig:
path.write_text(text, encoding='utf-8')
path.write_text(text, encoding="utf-8")
return True
return False


def main():
changed = []
for p in DOCS.rglob('*.md'):
for p in DOCS.rglob("*.md"):
if p.is_file():
try:
if normalize_file(p):
@@ -68,12 +81,12 @@ def main():
except Exception as e:
print(f"Failed to format {p}: {e}")
if changed:
print('Formatted files:')
print("Formatted files:")
for c in changed:
print(' -', c)
print(" -", c)
else:
print('No formatting changes required.')
print("No formatting changes required.")


if __name__ == '__main__':
if __name__ == "__main__":
main()
scripts/migrations/20251027_create_theme_settings_table.sql (new file)
@@ -0,0 +1,11 @@
-- Migration: 20251027_create_theme_settings_table.sql

CREATE TABLE theme_settings (
id SERIAL PRIMARY KEY,
theme_name VARCHAR(255) UNIQUE NOT NULL,
primary_color VARCHAR(7) NOT NULL,
secondary_color VARCHAR(7) NOT NULL,
accent_color VARCHAR(7) NOT NULL,
background_color VARCHAR(7) NOT NULL,
text_color VARCHAR(7) NOT NULL
);
scripts/migrations/20251027_create_user_and_role_tables.sql (new file)
@@ -0,0 +1,15 @@
-- Migration: 20251027_create_user_and_role_tables.sql

CREATE TABLE roles (
id SERIAL PRIMARY KEY,
name VARCHAR(255) UNIQUE NOT NULL
);

CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) UNIQUE NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
hashed_password VARCHAR(255) NOT NULL,
role_id INTEGER NOT NULL,
FOREIGN KEY (role_id) REFERENCES roles(id)
);
@@ -47,22 +47,82 @@ MEASUREMENT_UNIT_SEEDS = (
("kilowatt_hours", "Kilowatt Hours", "kWh", "energy", True),
)

THEME_SETTING_SEEDS = (
("--color-background", "#f4f5f7", "color",
"theme", "CSS variable --color-background", True),
("--color-surface", "#ffffff", "color",
"theme", "CSS variable --color-surface", True),
("--color-text-primary", "#2a1f33", "color",
"theme", "CSS variable --color-text-primary", True),
("--color-text-secondary", "#624769", "color",
"theme", "CSS variable --color-text-secondary", True),
("--color-text-muted", "#64748b", "color",
"theme", "CSS variable --color-text-muted", True),
("--color-text-subtle", "#94a3b8", "color",
"theme", "CSS variable --color-text-subtle", True),
("--color-text-invert", "#ffffff", "color",
"theme", "CSS variable --color-text-invert", True),
("--color-text-dark", "#0f172a", "color",
"theme", "CSS variable --color-text-dark", True),
("--color-text-strong", "#111827", "color",
"theme", "CSS variable --color-text-strong", True),
("--color-primary", "#5f320d", "color",
"theme", "CSS variable --color-primary", True),
("--color-primary-strong", "#7e4c13", "color",
"theme", "CSS variable --color-primary-strong", True),
("--color-primary-stronger", "#837c15", "color",
"theme", "CSS variable --color-primary-stronger", True),
("--color-accent", "#bff838", "color",
"theme", "CSS variable --color-accent", True),
("--color-border", "#e2e8f0", "color",
"theme", "CSS variable --color-border", True),
("--color-border-strong", "#cbd5e1", "color",
"theme", "CSS variable --color-border-strong", True),
("--color-highlight", "#eef2ff", "color",
"theme", "CSS variable --color-highlight", True),
("--color-panel-shadow", "rgba(15, 23, 42, 0.08)", "color",
"theme", "CSS variable --color-panel-shadow", True),
("--color-panel-shadow-deep", "rgba(15, 23, 42, 0.12)", "color",
"theme", "CSS variable --color-panel-shadow-deep", True),
("--color-surface-alt", "#f8fafc", "color",
"theme", "CSS variable --color-surface-alt", True),
("--color-success", "#047857", "color",
"theme", "CSS variable --color-success", True),
("--color-error", "#b91c1c", "color",
"theme", "CSS variable --color-error", True),
)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Seed baseline CalMiner data")
parser.add_argument("--currencies", action="store_true", help="Seed currency table")
parser.add_argument("--units", action="store_true", help="Seed unit table")
parser.add_argument("--defaults", action="store_true", help="Seed default records")
parser.add_argument("--dry-run", action="store_true", help="Print actions without executing")
parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity"
"--currencies", action="store_true", help="Seed currency table"
)
parser.add_argument("--units", action="store_true", help="Seed unit table")
parser.add_argument(
"--theme", action="store_true", help="Seed theme settings"
)
parser.add_argument(
"--defaults", action="store_true", help="Seed default records"
)
parser.add_argument(
"--dry-run", action="store_true", help="Print actions without executing"
)
parser.add_argument(
"--verbose",
"-v",
action="count",
default=0,
help="Increase logging verbosity",
)
return parser.parse_args()


def _configure_logging(args: argparse.Namespace) -> None:
level = logging.WARNING - (10 * min(args.verbose, 2))
logging.basicConfig(level=max(level, logging.INFO), format="%(levelname)s %(message)s")
logging.basicConfig(
level=max(level, logging.INFO), format="%(levelname)s %(message)s"
)


def main() -> None:
@@ -77,7 +137,7 @@ def run_with_namespace(
) -> None:
_configure_logging(args)

if not any((args.currencies, args.units, args.defaults)):
if not any((args.currencies, args.units, args.theme, args.defaults)):
logger.info("No seeding options provided; exiting")
return

@@ -89,6 +149,8 @@ def run_with_namespace(
_seed_currencies(cursor, dry_run=args.dry_run)
if args.units:
_seed_units(cursor, dry_run=args.dry_run)
if args.theme:
_seed_theme(cursor, dry_run=args.dry_run)
if args.defaults:
_seed_defaults(cursor, dry_run=args.dry_run)

@@ -152,11 +214,44 @@ def _seed_units(cursor, *, dry_run: bool) -> None:
logger.info("Measurement unit seed complete")


def _seed_defaults(cursor, *, dry_run: bool) -> None:
logger.info("Seeding default records - not yet implemented")
def _seed_theme(cursor, *, dry_run: bool) -> None:
logger.info("Seeding theme settings (%d rows)", len(THEME_SETTING_SEEDS))
if dry_run:
for key, value, _, _, _, _ in THEME_SETTING_SEEDS:
logger.info(
"Dry run: would upsert theme setting %s = %s", key, value)
return

try:
execute_values(
cursor,
"""
INSERT INTO application_setting (key, value, value_type, category, description, is_editable)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value,
value_type = EXCLUDED.value_type,
category = EXCLUDED.category,
description = EXCLUDED.description,
is_editable = EXCLUDED.is_editable
""",
THEME_SETTING_SEEDS,
)
except errors.UndefinedTable:
logger.warning(
"application_setting table does not exist; skipping theme seeding."
)
cursor.connection.rollback()
return

logger.info("Theme settings seed complete")


def _seed_defaults(cursor, *, dry_run: bool) -> None:
logger.info("Seeding default records")
_seed_theme(cursor, dry_run=dry_run)
logger.info("Default records seed complete")


if __name__ == "__main__":
main()
main()
@@ -39,6 +39,7 @@ from psycopg2 import extensions
from psycopg2.extensions import connection as PGConnection, parse_dsn
from dotenv import load_dotenv
from sqlalchemy import create_engine, inspect

ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.insert(0, str(ROOT_DIR))
@@ -125,8 +126,7 @@ class DatabaseConfig:
]
if missing:
raise RuntimeError(
"Missing required database configuration: " +
", ".join(missing)
"Missing required database configuration: " + ", ".join(missing)
)

host = cast(str, host)
@@ -208,12 +208,17 @@ class DatabaseConfig:
class DatabaseSetup:
"""Encapsulates the full setup workflow."""

def __init__(self, config: DatabaseConfig, *, dry_run: bool = False) -> None:
def __init__(
self, config: DatabaseConfig, *, dry_run: bool = False
) -> None:
self.config = config
self.dry_run = dry_run
self._models_loaded = False
self._rollback_actions: list[tuple[str, Callable[[], None]]] = []
def _register_rollback(self, label: str, action: Callable[[], None]) -> None:

def _register_rollback(
self, label: str, action: Callable[[], None]
) -> None:
if self.dry_run:
return
self._rollback_actions.append((label, action))
@@ -237,7 +242,6 @@ class DatabaseSetup:
def clear_rollbacks(self) -> None:
self._rollback_actions.clear()


def _describe_connection(self, user: str, database: str) -> str:
return f"{user}@{self.config.host}:{self.config.port}/{database}"

@@ -384,9 +388,9 @@ class DatabaseSetup:
try:
if self.config.password:
cursor.execute(
sql.SQL("CREATE ROLE {} WITH LOGIN PASSWORD %s").format(
sql.Identifier(self.config.user)
),
sql.SQL(
"CREATE ROLE {} WITH LOGIN PASSWORD %s"
).format(sql.Identifier(self.config.user)),
(self.config.password,),
)
else:
@@ -589,8 +593,7 @@ class DatabaseSetup:
return psycopg2.connect(dsn)
except psycopg2.Error as exc:
raise RuntimeError(
"Unable to establish admin connection. "
f"Target: {descriptor}"
"Unable to establish admin connection. " f"Target: {descriptor}"
) from exc

def _application_connection(self) -> PGConnection:
@@ -645,7 +648,9 @@ class DatabaseSetup:
importlib.import_module(f"{package.__name__}.{module_info.name}")
self._models_loaded = True

def run_migrations(self, migrations_dir: Optional[Path | str] = None) -> None:
def run_migrations(
self, migrations_dir: Optional[Path | str] = None
) -> None:
"""Execute pending SQL migrations in chronological order."""

directory = (
@@ -673,7 +678,8 @@ class DatabaseSetup:
conn.autocommit = True
with conn.cursor() as cursor:
table_exists = self._migrations_table_exists(
cursor, schema_name)
cursor, schema_name
)
if not table_exists:
if self.dry_run:
logger.info(
@@ -692,12 +698,10 @@ class DatabaseSetup:
applied = set()
else:
applied = self._fetch_applied_migrations(
cursor, schema_name)
cursor, schema_name
)

if (
baseline_path.exists()
and baseline_name not in applied
):
if baseline_path.exists() and baseline_name not in applied:
if self.dry_run:
logger.info(
"Dry run: baseline migration '%s' pending; would apply and mark legacy files",
@@ -756,9 +760,7 @@ class DatabaseSetup:
)

pending = [
path
for path in migration_files
if path.name not in applied
path for path in migration_files if path.name not in applied
]

if not pending:
@@ -792,9 +794,7 @@ class DatabaseSetup:
cursor.execute(
sql.SQL(
"INSERT INTO {} (filename, applied_at) VALUES (%s, NOW())"
).format(
sql.Identifier(schema_name, MIGRATIONS_TABLE)
),
).format(sql.Identifier(schema_name, MIGRATIONS_TABLE)),
(path.name,),
)
return path.name
@@ -820,9 +820,7 @@ class DatabaseSetup:
"filename TEXT PRIMARY KEY,"
"applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()"
")"
).format(
sql.Identifier(schema_name, MIGRATIONS_TABLE)
)
).format(sql.Identifier(schema_name, MIGRATIONS_TABLE))
)

def _fetch_applied_migrations(self, cursor, schema_name: str) -> set[str]:
@@ -974,7 +972,7 @@ class DatabaseSetup:
(database,),
)
cursor.execute(
sql.SQL("DROP DATABASE IF EXISTS {}" ).format(
sql.SQL("DROP DATABASE IF EXISTS {}").format(
sql.Identifier(database)
)
)
@@ -985,7 +983,7 @@ class DatabaseSetup:
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute(
sql.SQL("DROP ROLE IF EXISTS {}" ).format(
sql.SQL("DROP ROLE IF EXISTS {}").format(
sql.Identifier(role)
)
)
@@ -1000,27 +998,35 @@ class DatabaseSetup:
conn.autocommit = True
with conn.cursor() as cursor:
cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}" ).format(
sql.SQL(
"REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)
cursor.execute(
sql.SQL("REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}" ).format(
sql.SQL(
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)
cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}" ).format(
sql.SQL(
"ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)
cursor.execute(
sql.SQL("ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}" ).format(
sql.SQL(
"ALTER DEFAULT PRIVILEGES IN SCHEMA {} REVOKE USAGE, SELECT ON SEQUENCES FROM {}"
).format(
sql.Identifier(schema_name),
sql.Identifier(self.config.user)
sql.Identifier(self.config.user),
)
)

@@ -1064,19 +1070,18 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument("--db-driver", help="Override DATABASE_DRIVER")
parser.add_argument("--db-host", help="Override DATABASE_HOST")
parser.add_argument("--db-port", type=int,
help="Override DATABASE_PORT")
parser.add_argument("--db-port", type=int, help="Override DATABASE_PORT")
parser.add_argument("--db-name", help="Override DATABASE_NAME")
parser.add_argument("--db-user", help="Override DATABASE_USER")
parser.add_argument(
"--db-password", help="Override DATABASE_PASSWORD")
parser.add_argument("--db-password", help="Override DATABASE_PASSWORD")
parser.add_argument("--db-schema", help="Override DATABASE_SCHEMA")
parser.add_argument(
"--admin-url",
help="Override DATABASE_ADMIN_URL for administrative operations",
)
parser.add_argument(
"--admin-user", help="Override DATABASE_SUPERUSER for admin ops")
"--admin-user", help="Override DATABASE_SUPERUSER for admin ops"
)
parser.add_argument(
"--admin-password",
help="Override DATABASE_SUPERUSER_PASSWORD for admin ops",
@@ -1091,7 +1096,11 @@ def parse_args() -> argparse.Namespace:
help="Log actions without applying changes.",
)
parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity"
"--verbose",
"-v",
action="count",
default=0,
help="Increase logging verbosity",
)
return parser.parse_args()

@@ -1099,8 +1108,9 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
args = parse_args()
level = logging.WARNING - (10 * min(args.verbose, 2))
logging.basicConfig(level=max(level, logging.INFO),
format="%(levelname)s %(message)s")
logging.basicConfig(
level=max(level, logging.INFO), format="%(levelname)s %(message)s"
)

override_args: dict[str, Optional[str]] = {
"DATABASE_DRIVER": args.db_driver,
@@ -1120,7 +1130,9 @@ def main() -> None:
config = DatabaseConfig.from_env(overrides=override_args)
setup = DatabaseSetup(config, dry_run=args.dry_run)

admin_tasks_requested = args.ensure_database or args.ensure_role or args.ensure_schema
admin_tasks_requested = (
args.ensure_database or args.ensure_role or args.ensure_schema
)
if admin_tasks_requested:
setup.validate_admin_connection()

@@ -1145,9 +1157,7 @@ def main() -> None:
auto_run_migrations_reason: Optional[str] = None
if args.seed_data and not should_run_migrations:
should_run_migrations = True
auto_run_migrations_reason = (
"Seed data requested without explicit --run-migrations; applying migrations first."
)
auto_run_migrations_reason = "Seed data requested without explicit --run-migrations; applying migrations first."

try:
if args.ensure_database:
@@ -1167,9 +1177,7 @@ def main() -> None:
if auto_run_migrations_reason:
logger.info(auto_run_migrations_reason)
migrations_path = (
Path(args.migrations_dir)
if args.migrations_dir
else None
Path(args.migrations_dir) if args.migrations_dir else None
)
setup.run_migrations(migrations_path)
if args.seed_data: