231 lines
7.4 KiB
Python
231 lines
7.4 KiB
Python
"""Tests for DuckDB initialization and schema."""
|
|
import asyncio
|
|
import pytest
|
|
import duckdb
|
|
|
|
from app import db as db_module
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def fresh_db():
    """Reset the module-level DuckDB connection around every test.

    Tests open their own database via ``init_db(":memory:")``; this fixture
    only makes sure no connection carries over from a previous test and that
    whatever the test opened is closed afterwards.

    NOTE(review): the lock returned by ``get_write_lock()`` is not touched
    here, so it is shared across tests.
    """
    # Setup: forget any connection left behind so init_db() starts fresh.
    db_module._conn = None
    yield
    # Teardown: close what the test opened, then clear the global reference
    # explicitly as well.
    db_module.close_db()
    db_module._conn = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_init_creates_users_table():
    """init_db() must create the ``users`` table."""
    connection = db_module.init_db(":memory:")
    cursor = connection.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_name = 'users'"
    )
    assert cursor.fetchone() is not None
|
|
|
|
|
|
def test_init_creates_refresh_tokens_table():
    """init_db() must create the ``refresh_tokens`` table."""
    connection = db_module.init_db(":memory:")
    cursor = connection.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_name = 'refresh_tokens'"
    )
    assert cursor.fetchone() is not None
|
|
|
|
|
|
def test_init_is_idempotent():
    """A second init_db() call hands back the very same connection object."""
    connections = [db_module.init_db(":memory:") for _ in range(2)]
    assert connections[0] is connections[1]
|
|
|
|
|
|
def test_get_conn_raises_before_init():
    """get_conn() raises RuntimeError when init_db() has not been called."""
    expectation = pytest.raises(RuntimeError, match="not initialised")
    with expectation:
        db_module.get_conn()
|
|
|
|
|
|
def test_get_conn_returns_connection_after_init():
    """Once init_db() has run, get_conn() yields a connection."""
    db_module.init_db(":memory:")
    assert db_module.get_conn() is not None
|
|
|
|
|
|
def test_close_db_resets_connection():
    """After close_db(), get_conn() raises RuntimeError again."""
    db_module.init_db(":memory:")
    db_module.close_db()
    expectation = pytest.raises(RuntimeError)
    with expectation:
        db_module.get_conn()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schema: users table columns and defaults
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_users_columns():
    """The ``users`` table exposes at least the columns the app relies on."""
    conn = db_module.init_db(":memory:")
    rows = conn.execute(
        "SELECT column_name FROM information_schema.columns WHERE table_name = 'users'"
    ).fetchall()
    present = {name for (name,) in rows}
    required = {"id", "email", "password_hash", "role", "created_at", "updated_at"}
    assert required.issubset(present)
|
|
|
|
|
|
def test_users_default_role_is_user():
    """A user inserted without an explicit role gets the role ``user``."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('a@example.com', 'hash')"
    )
    (role,) = conn.execute(
        "SELECT role FROM users WHERE email = 'a@example.com'"
    ).fetchone()
    assert role == "user"
|
|
|
|
|
|
def test_users_id_auto_generated():
    """Inserting a user without an id still produces a non-NULL id."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('b@example.com', 'hash')"
    )
    (user_id,) = conn.execute(
        "SELECT id FROM users WHERE email = 'b@example.com'"
    ).fetchone()
    assert user_id is not None
|
|
|
|
|
|
def test_users_email_unique_constraint():
    """Inserting a second user with an already-used email must be rejected.

    The expected error is narrowed to DuckDB's constraint violation. A bare
    ``Exception`` would also accept unrelated failures, such as a typo in
    the SQL raising a parser error, and let the test pass by accident.
    """
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('c@example.com', 'h')")
    with pytest.raises(duckdb.ConstraintException):
        conn.execute(
            "INSERT INTO users (email, password_hash) VALUES ('c@example.com', 'h2')")
|
|
|
|
|
|
def test_users_timestamps_auto_set():
    """created_at and updated_at are filled in when a user is inserted."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('d@example.com', 'hash')"
    )
    created_at, updated_at = conn.execute(
        "SELECT created_at, updated_at FROM users WHERE email = 'd@example.com'"
    ).fetchone()
    assert created_at is not None
    assert updated_at is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schema: refresh_tokens table columns and defaults
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_refresh_tokens_columns():
    """The ``refresh_tokens`` table exposes at least the expected columns."""
    conn = db_module.init_db(":memory:")
    rows = conn.execute(
        "SELECT column_name FROM information_schema.columns WHERE table_name = 'refresh_tokens'"
    ).fetchall()
    present = {name for (name,) in rows}
    required = {"jti", "user_id", "issued_at", "expires_at", "revoked"}
    assert required.issubset(present)
|
|
|
|
|
|
def test_refresh_tokens_default_revoked_false():
    """A refresh token inserted without ``revoked`` defaults to False."""
    conn = db_module.init_db(":memory:")

    # A token row needs an owning user, so create one first.
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('e@example.com', 'h')")
    (user_id,) = conn.execute(
        "SELECT id FROM users WHERE email = 'e@example.com'").fetchone()

    conn.execute(
        "INSERT INTO refresh_tokens (user_id, expires_at) VALUES (?, now() + INTERVAL 7 DAY)",
        [user_id],
    )
    (revoked,) = conn.execute(
        "SELECT revoked FROM refresh_tokens WHERE user_id = ?", [user_id]
    ).fetchone()
    # Identity check: the value must be the bool False, not merely falsy.
    assert revoked is False
|
|
|
|
|
|
def test_refresh_tokens_jti_auto_generated():
    """Inserting a refresh token without a jti still yields a non-NULL jti."""
    conn = db_module.init_db(":memory:")

    # A token row needs an owning user, so create one first.
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('f@example.com', 'h')")
    (user_id,) = conn.execute(
        "SELECT id FROM users WHERE email = 'f@example.com'").fetchone()

    conn.execute(
        "INSERT INTO refresh_tokens (user_id, expires_at) VALUES (?, now() + INTERVAL 7 DAY)",
        [user_id],
    )
    (jti,) = conn.execute(
        "SELECT jti FROM refresh_tokens WHERE user_id = ?", [user_id]
    ).fetchone()
    assert jti is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Write lock
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_get_write_lock_returns_asyncio_lock():
    """get_write_lock() hands out an ``asyncio.Lock``."""
    assert isinstance(db_module.get_write_lock(), asyncio.Lock)
|
|
|
|
|
|
def test_get_write_lock_returns_same_instance():
    """Repeated get_write_lock() calls return one shared lock object."""
    first_lock, second_lock = db_module.get_write_lock(), db_module.get_write_lock()
    assert first_lock is second_lock
|
|
|
|
|
|
async def test_write_lock_serialises_concurrent_writes():
    """Two coroutines acquiring the write lock must not overlap.

    Each writer yields to the event loop while holding the lock. Without
    serialisation, the other writer would record its start in between. The
    trace must therefore be one writer's start/end pair followed by the
    other's.

    NOTE(review): this is a bare ``async def`` test with no marker; it relies
    on the async pytest plugin being configured to run such tests (e.g.
    pytest-asyncio ``asyncio_mode = auto``). Confirm, otherwise it is skipped.
    """
    db_module.init_db(":memory:")
    lock = db_module.get_write_lock()
    order = []

    async def writer(label: str):
        async with lock:
            order.append(f"{label}-start")
            await asyncio.sleep(0)  # yield to the event loop while holding the lock
            order.append(f"{label}-end")

    await asyncio.gather(writer("A"), writer("B"))

    # Require BOTH writers to be contiguous. Checking only one of them would
    # accept an overlapping trace such as A-start, B-start, B-end, A-end.
    assert order in (
        ["A-start", "A-end", "B-start", "B-end"],
        ["B-start", "B-end", "A-start", "A-end"],
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Admin seed user
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_seed_admin_user_created_on_init(monkeypatch):
    """init_db() seeds the default admin account with the ``admin`` role.

    ADMIN_EMAIL overrides the seeded email (see the env-override test), so it
    is cleared here. This keeps the default-email assertion from depending on
    whatever the developer's or CI's environment happens to export.
    """
    monkeypatch.delenv("ADMIN_EMAIL", raising=False)
    conn = db_module.init_db(":memory:")
    row = conn.execute(
        "SELECT email, role FROM users WHERE email = 'ai@allucanget.biz'"
    ).fetchone()
    assert row is not None
    assert row[0] == "ai@allucanget.biz"
    assert row[1] == "admin"
|
|
|
|
|
|
def test_seed_admin_is_idempotent(monkeypatch):
    """Running the admin seed a second time does not duplicate the admin row.

    ADMIN_EMAIL is cleared so the default admin email is the one seeded,
    regardless of the ambient environment.
    """
    monkeypatch.delenv("ADMIN_EMAIL", raising=False)
    conn = db_module.init_db(":memory:")
    # A second init_db() call reuses the existing connection (see
    # test_init_is_idempotent), so invoke _seed_admin directly to re-run it.
    db_module._seed_admin(conn)
    count = conn.execute(
        "SELECT COUNT(*) FROM users WHERE email = 'ai@allucanget.biz'"
    ).fetchone()[0]
    assert count == 1
|
|
|
|
|
|
def test_seed_admin_email_env_override(monkeypatch):
    """With ADMIN_EMAIL set, the seeded admin uses that email and role ``admin``."""
    overrides = (
        ("ADMIN_EMAIL", "custom@example.com"),
        ("ADMIN_PASSWORD", "custompass"),
    )
    for name, value in overrides:
        monkeypatch.setenv(name, value)

    conn = db_module.init_db(":memory:")
    seeded = conn.execute(
        "SELECT email, role FROM users WHERE email = 'custom@example.com'"
    ).fetchone()
    assert seeded is not None
    assert seeded[1] == "admin"
|