"""Tests for DuckDB initialization and schema."""
import asyncio

import pytest
import duckdb

from backend.app import db as db_module


@pytest.fixture(autouse=True)
def fresh_db():
    """Use an in-memory DB for each test and reset global state."""
    db_module._conn = None
    yield
    db_module.close_db()
    db_module._conn = None


def _create_user(conn, email):
    """Insert a user with a dummy password hash and return its generated id."""
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES (?, 'h')", [email]
    )
    return conn.execute(
        "SELECT id FROM users WHERE email = ?", [email]
    ).fetchone()[0]


# ---------------------------------------------------------------------------
# Lifecycle
# ---------------------------------------------------------------------------

def test_init_creates_users_table():
    conn = db_module.init_db(":memory:")
    result = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_name = 'users'"
    ).fetchone()
    assert result is not None


def test_init_creates_refresh_tokens_table():
    conn = db_module.init_db(":memory:")
    result = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_name = 'refresh_tokens'"
    ).fetchone()
    assert result is not None


def test_init_is_idempotent():
    conn1 = db_module.init_db(":memory:")
    conn2 = db_module.init_db(":memory:")
    assert conn1 is conn2


def test_get_conn_raises_before_init():
    with pytest.raises(RuntimeError, match="not initialised"):
        db_module.get_conn()


def test_get_conn_returns_connection_after_init():
    db_module.init_db(":memory:")
    conn = db_module.get_conn()
    assert conn is not None


def test_close_db_resets_connection():
    db_module.init_db(":memory:")
    db_module.close_db()
    with pytest.raises(RuntimeError):
        db_module.get_conn()


# ---------------------------------------------------------------------------
# Schema: users table columns and defaults
# ---------------------------------------------------------------------------

def test_users_columns():
    conn = db_module.init_db(":memory:")
    cols = {
        row[0]
        for row in conn.execute(
            "SELECT column_name FROM information_schema.columns WHERE table_name = 'users'"
        ).fetchall()
    }
    assert {"id", "email", "password_hash", "role", "created_at", "updated_at"} <= cols


def test_users_default_role_is_user():
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('a@example.com', 'hash')"
    )
    row = conn.execute(
        "SELECT role FROM users WHERE email = 'a@example.com'"
    ).fetchone()
    assert row[0] == "user"


def test_users_id_auto_generated():
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('b@example.com', 'hash')"
    )
    row = conn.execute(
        "SELECT id FROM users WHERE email = 'b@example.com'"
    ).fetchone()
    assert row[0] is not None


def test_users_email_unique_constraint():
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('c@example.com', 'h')"
    )
    # Expect specifically a constraint violation, so that an unrelated
    # failure (e.g. a malformed statement) does not make this test pass.
    with pytest.raises(duckdb.ConstraintException):
        conn.execute(
            "INSERT INTO users (email, password_hash) VALUES ('c@example.com', 'h2')"
        )


def test_users_timestamps_auto_set():
    conn = db_module.init_db(":memory:")
    conn.execute(
        "INSERT INTO users (email, password_hash) VALUES ('d@example.com', 'hash')"
    )
    row = conn.execute(
        "SELECT created_at, updated_at FROM users WHERE email = 'd@example.com'"
    ).fetchone()
    assert row[0] is not None
    assert row[1] is not None


# ---------------------------------------------------------------------------
# Schema: refresh_tokens table columns and defaults
# ---------------------------------------------------------------------------

def test_refresh_tokens_columns():
    conn = db_module.init_db(":memory:")
    cols = {
        row[0]
        for row in conn.execute(
            "SELECT column_name FROM information_schema.columns WHERE table_name = 'refresh_tokens'"
        ).fetchall()
    }
    assert {"jti", "user_id", "issued_at", "expires_at", "revoked"} <= cols


def test_refresh_tokens_default_revoked_false():
    conn = db_module.init_db(":memory:")
    user_id = _create_user(conn, "e@example.com")
    conn.execute(
        "INSERT INTO refresh_tokens (user_id, expires_at) VALUES (?, now() + INTERVAL 7 DAY)",
        [user_id],
    )
    row = conn.execute(
        "SELECT revoked FROM refresh_tokens WHERE user_id = ?", [user_id]
    ).fetchone()
    assert row[0] is False


def test_refresh_tokens_jti_auto_generated():
    conn = db_module.init_db(":memory:")
    user_id = _create_user(conn, "f@example.com")
    conn.execute(
        "INSERT INTO refresh_tokens (user_id, expires_at) VALUES (?, now() + INTERVAL 7 DAY)",
        [user_id],
    )
    row = conn.execute(
        "SELECT jti FROM refresh_tokens WHERE user_id = ?", [user_id]
    ).fetchone()
    assert row[0] is not None


# ---------------------------------------------------------------------------
# Write lock
# ---------------------------------------------------------------------------

def test_get_write_lock_returns_asyncio_lock():
    lock = db_module.get_write_lock()
    assert isinstance(lock, asyncio.Lock)


def test_get_write_lock_returns_same_instance():
    lock1 = db_module.get_write_lock()
    lock2 = db_module.get_write_lock()
    assert lock1 is lock2


# NOTE(review): a bare ``async def`` test assumes an async pytest plugin
# (e.g. pytest-asyncio) running in auto mode — confirm in the pytest config.
async def test_write_lock_serialises_concurrent_writes():
    """Two coroutines acquiring the lock must not overlap."""
    lock = db_module.get_write_lock()
    order = []

    async def writer(label: str):
        async with lock:
            order.append(f"{label}-start")
            await asyncio.sleep(0)  # yield to event loop while holding the lock
            order.append(f"{label}-end")

    await asyncio.gather(writer("A"), writer("B"))

    # Serialisation means BOTH critical sections are contiguous: whichever
    # writer acquires the lock first finishes before the other one starts.
    assert order in (
        ["A-start", "A-end", "B-start", "B-end"],
        ["B-start", "B-end", "A-start", "A-end"],
    )