Files
ai.allucanget.biz/backend/tests/test_db.py
T

195 lines
6.2 KiB
Python

"""Tests for DuckDB initialization and schema."""
import asyncio
import pytest
import duckdb
from backend.app import db as db_module
@pytest.fixture(autouse=True)
def fresh_db():
    """Reset the module-level DuckDB connection around every test.

    Before the test body runs, the cached ``db_module._conn`` is cleared so
    each test starts uninitialised. Tests open their own ``":memory:"``
    database via ``init_db``. Afterwards, whatever connection the test opened
    is closed through ``close_db`` and the global is cleared again.
    """
    db_module._conn = None
    try:
        yield
    finally:
        db_module.close_db()
        # Clear explicitly as well, so the next test never sees a stale handle.
        db_module._conn = None
# ---------------------------------------------------------------------------
# Lifecycle
# ---------------------------------------------------------------------------
def test_init_creates_users_table():
    """init_db must create the ``users`` table."""
    connection = db_module.init_db(":memory:")
    query = "SELECT table_name FROM information_schema.tables WHERE table_name = 'users'"
    found = connection.execute(query).fetchone()
    assert found is not None
def test_init_creates_refresh_tokens_table():
    """init_db must create the ``refresh_tokens`` table."""
    connection = db_module.init_db(":memory:")
    query = "SELECT table_name FROM information_schema.tables WHERE table_name = 'refresh_tokens'"
    found = connection.execute(query).fetchone()
    assert found is not None
def test_init_is_idempotent():
    """Calling init_db twice hands back the very same connection object."""
    first = db_module.init_db(":memory:")
    second = db_module.init_db(":memory:")
    assert first is second
def test_get_conn_raises_before_init():
    """get_conn refuses to hand out a connection before init_db has run."""
    expectation = pytest.raises(RuntimeError, match="not initialised")
    with expectation:
        db_module.get_conn()
def test_get_conn_returns_connection_after_init():
    """Once init_db has run, get_conn returns a non-None connection."""
    db_module.init_db(":memory:")
    assert db_module.get_conn() is not None
def test_close_db_resets_connection():
    """After close_db, get_conn behaves as if init_db had never been called.

    The error is matched against the same ``"not initialised"`` message that
    test_get_conn_raises_before_init expects. Without the match, an unrelated
    RuntimeError (for example from operating on a closed handle) would
    silently satisfy the test.
    """
    db_module.init_db(":memory:")
    db_module.close_db()
    with pytest.raises(RuntimeError, match="not initialised"):
        db_module.get_conn()
# ---------------------------------------------------------------------------
# Schema: users table columns and defaults
# ---------------------------------------------------------------------------
def test_users_columns():
    """The users table contains at least the expected set of columns."""
    conn = db_module.init_db(":memory:")
    rows = conn.execute(
        "SELECT column_name FROM information_schema.columns WHERE table_name = 'users'"
    ).fetchall()
    present = {name for (name,) in rows}
    required = {"id", "email", "password_hash", "role", "created_at", "updated_at"}
    # Subset check: extra columns are allowed, missing ones are not.
    assert required.issubset(present)
def test_users_default_role_is_user():
    """Inserting a user without ``role`` stores the default value ``'user'``."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('a@example.com', 'hash')"
    )
    (role,) = conn.execute(
        "SELECT role FROM users WHERE email = 'a@example.com'"
    ).fetchone()
    assert role == "user"
def test_users_id_auto_generated():
    """Inserting a user without ``id`` still yields a non-NULL id."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('b@example.com', 'hash')"
    )
    (generated_id,) = conn.execute(
        "SELECT id FROM users WHERE email = 'b@example.com'"
    ).fetchone()
    assert generated_id is not None
def test_users_email_unique_constraint():
    """A second user with an already-registered email must be rejected.

    DuckDB signals unique-constraint violations with
    ``duckdb.ConstraintException``. Asserting that specific type, rather than
    bare ``Exception``, stops the test from passing on unrelated failures such
    as an SQL typo or a missing column.
    """
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('c@example.com', 'h')")
    with pytest.raises(duckdb.ConstraintException):
        conn.execute(
            "INSERT INTO users (email, password_hash) VALUES ('c@example.com', 'h2')")
def test_users_timestamps_auto_set():
    """Both audit timestamps are populated when a user is inserted without them."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('d@example.com', 'hash')"
    )
    created_at, updated_at = conn.execute(
        "SELECT created_at, updated_at FROM users WHERE email = 'd@example.com'"
    ).fetchone()
    assert created_at is not None
    assert updated_at is not None
# ---------------------------------------------------------------------------
# Schema: refresh_tokens table columns and defaults
# ---------------------------------------------------------------------------
def test_refresh_tokens_columns():
    """The refresh_tokens table contains at least the expected set of columns."""
    conn = db_module.init_db(":memory:")
    rows = conn.execute(
        "SELECT column_name FROM information_schema.columns WHERE table_name = 'refresh_tokens'"
    ).fetchall()
    present = {name for (name,) in rows}
    required = {"jti", "user_id", "issued_at", "expires_at", "revoked"}
    # Subset check: extra columns are allowed, missing ones are not.
    assert required.issubset(present)
def test_refresh_tokens_default_revoked_false():
    """A refresh token inserted without ``revoked`` defaults to False."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('e@example.com', 'h')")
    (owner_id,) = conn.execute(
        "SELECT id FROM users WHERE email = 'e@example.com'"
    ).fetchone()
    conn.execute(
        "INSERT INTO refresh_tokens (user_id, expires_at) VALUES (?, now() + INTERVAL 7 DAY)",
        [owner_id],
    )
    (revoked,) = conn.execute(
        "SELECT revoked FROM refresh_tokens WHERE user_id = ?", [owner_id]
    ).fetchone()
    assert revoked is False
def test_refresh_tokens_jti_auto_generated():
    """A refresh token inserted without ``jti`` still receives a non-NULL jti."""
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('f@example.com', 'h')")
    (owner_id,) = conn.execute(
        "SELECT id FROM users WHERE email = 'f@example.com'"
    ).fetchone()
    conn.execute(
        "INSERT INTO refresh_tokens (user_id, expires_at) VALUES (?, now() + INTERVAL 7 DAY)",
        [owner_id],
    )
    (jti,) = conn.execute(
        "SELECT jti FROM refresh_tokens WHERE user_id = ?", [owner_id]
    ).fetchone()
    assert jti is not None
# ---------------------------------------------------------------------------
# Write lock
# ---------------------------------------------------------------------------
def test_get_write_lock_returns_asyncio_lock():
    """get_write_lock hands back an ``asyncio.Lock`` instance."""
    assert isinstance(db_module.get_write_lock(), asyncio.Lock)
def test_get_write_lock_returns_same_instance():
    """Repeated get_write_lock calls return one shared lock object."""
    first, second = db_module.get_write_lock(), db_module.get_write_lock()
    assert first is second
async def test_write_lock_serialises_concurrent_writes():
    """Two coroutines acquiring the write lock must not overlap.

    Each writer yields to the event loop while it holds the lock. If the lock
    failed to serialise them, the other writer would record its start in
    between. The recorded trace must therefore be one of the two fully serial
    orders. Checking adjacency for only one writer would also accept a nested
    trace such as A-start, B-start, B-end, A-end.
    """
    # NOTE(review): no asyncio marker is applied here, so this relies on
    # pytest-asyncio/anyio "auto" mode being configured — confirm.
    db_module.init_db(":memory:")
    lock = db_module.get_write_lock()
    order = []

    async def writer(label: str):
        async with lock:
            order.append(f"{label}-start")
            await asyncio.sleep(0)  # yield to event loop while holding the lock
            order.append(f"{label}-end")

    await asyncio.gather(writer("A"), writer("B"))
    assert order in (
        ["A-start", "A-end", "B-start", "B-end"],
        ["B-start", "B-end", "A-start", "A-end"],
    )