"""Authentication service: password hashing, JWT creation/verification, token management."""

import os
from datetime import datetime, timedelta, timezone
from typing import Any

from jose import JWTError, jwt
from passlib.context import CryptContext

from ..db import get_conn, get_write_lock

# Single active scheme (bcrypt); deprecated="auto" lets passlib treat any
# non-default scheme in this context as deprecated.
_pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

ACCESS_TOKEN_EXPIRE_MINUTES = 15
REFRESH_TOKEN_EXPIRE_DAYS = 7
ALGORITHM = "HS256"


def _secret() -> str:
    """Return the HMAC signing key taken from the ``JWT_SECRET`` env variable.

    The variable is read on every call rather than at import time, so it only
    needs to be set before the first token is created or decoded.

    Raises:
        RuntimeError: if ``JWT_SECRET`` is unset or empty.
    """
    secret = os.getenv("JWT_SECRET")
    if not secret:
        raise RuntimeError("JWT_SECRET environment variable is not set.")
    return secret


# --- Password ---


def hash_password(plain: str) -> str:
    """Return a bcrypt hash of *plain* for storage in ``users.password_hash``."""
    return _pwd_context.hash(plain)


def verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the stored hash *hashed*."""
    return _pwd_context.verify(plain, hashed)


# --- Tokens ---


def create_access_token(user_id: str, email: str, role: str) -> str:
    """Return a signed access JWT for the given user.

    Claims: ``sub`` (user id), ``email``, ``role``, ``type="access"`` and an
    ``exp`` ACCESS_TOKEN_EXPIRE_MINUTES from now (UTC).
    """
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    payload = {
        "sub": user_id,
        "email": email,
        "role": role,
        "exp": expire,
        "type": "access",
    }
    return jwt.encode(payload, _secret(), algorithm=ALGORITHM)


def create_refresh_token(user_id: str, jti: str) -> str:
    """Return a signed refresh JWT for the given user.

    Claims: ``sub`` (user id), ``jti``, ``type="refresh"`` and an ``exp``
    REFRESH_TOKEN_EXPIRE_DAYS from now (UTC). ``jti`` is the identifier that
    store_refresh_token / validate_refresh_token_jti / revoke_refresh_token
    operate on.
    """
    expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    payload = {
        "sub": user_id,
        "jti": jti,
        "exp": expire,
        "type": "refresh",
    }
    return jwt.encode(payload, _secret(), algorithm=ALGORITHM)


def decode_token(token: str, *, expected_type: str | None = None) -> dict[str, Any]:
    """Decode and validate a JWT. Raises JWTError on failure.

    Args:
        token: The encoded JWT.
        expected_type: If given (e.g. ``"access"`` or ``"refresh"``), the
            token's ``type`` claim must equal it. Access and refresh tokens
            are signed with the same key and algorithm, so without this check
            (or an equivalent one in the caller) one kind is accepted where
            the other is expected.

    Returns:
        The decoded claims dict.

    Raises:
        JWTError: on a bad signature, an expired token, a malformed token, or
            a ``type`` mismatch when *expected_type* is set.
    """
    claims = jwt.decode(token, _secret(), algorithms=[ALGORITHM])
    if expected_type is not None and claims.get("type") != expected_type:
        raise JWTError(f"Expected a {expected_type!r} token.")
    return claims


# --- Database operations ---
# NOTE(review): bcrypt (hash/verify) and conn.execute are synchronous calls made
# inside these async functions, so they block the event loop while they run.
# Consider asyncio.to_thread if this shows up as latency.


async def register_user(email: str, password: str) -> dict[str, Any]:
    """Insert a new user.

    Returns the created user row as ``{"id": str, "email": ..., "role": ...}``.

    Raises:
        ValueError: if *email* is already registered.
    """
    conn = get_conn()
    lock = get_write_lock()
    # bcrypt is deliberately slow; hash before taking the shared write lock so
    # other writers are not serialized behind password hashing.
    password_hash = hash_password(password)
    async with lock:
        # Existence check and insert happen under the same lock, so they are
        # atomic relative to other code paths that also take this lock.
        existing = conn.execute(
            "SELECT id FROM users WHERE email = ?", [email]
        ).fetchone()
        if existing:
            raise ValueError("Email already registered.")
        conn.execute(
            "INSERT INTO users (email, password_hash) VALUES (?, ?)",
            [email, password_hash],
        )
        row = conn.execute(
            "SELECT id, email, role FROM users WHERE email = ?", [email]
        ).fetchone()
    return {"id": str(row[0]), "email": row[1], "role": row[2]}


async def authenticate_user(email: str, password: str) -> dict[str, Any] | None:
    """Return user dict if credentials are valid, else None.

    The returned dict has keys ``id`` (as str), ``email`` and ``role``.
    """
    conn = get_conn()
    row = conn.execute(
        "SELECT id, email, password_hash, role FROM users WHERE email = ?", [email]
    ).fetchone()
    if row is None:
        # Spend comparable bcrypt time for unknown emails so response latency
        # does not reveal whether an account exists.
        _pwd_context.dummy_verify()
        return None
    if not verify_password(password, row[2]):
        return None
    return {"id": str(row[0]), "email": row[1], "role": row[3]}


async def store_refresh_token(user_id: str, jti: str) -> None:
    """Persist a refresh token JTI in the database.

    The stored ``expires_at`` is REFRESH_TOKEN_EXPIRE_DAYS from now (UTC). It is
    computed independently of the token's own ``exp`` claim, so the two can
    differ by the time between token creation and this call.
    """
    conn = get_conn()
    lock = get_write_lock()
    expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    async with lock:
        conn.execute(
            "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)",
            [jti, user_id, expires_at],
        )


async def revoke_refresh_token(jti: str) -> None:
    """Mark a refresh token as revoked.

    An unknown *jti* matches no rows and is silently ignored.
    """
    conn = get_conn()
    lock = get_write_lock()
    async with lock:
        conn.execute(
            "UPDATE refresh_tokens SET revoked = true WHERE jti = ?", [jti]
        )


async def validate_refresh_token_jti(jti: str, user_id: str) -> bool:
    """Return True if the JTI exists, is not revoked, and belongs to user_id.

    The stored ``expires_at`` must also still be in the future.
    """
    conn = get_conn()
    now = datetime.now(timezone.utc)
    row = conn.execute(
        """
        SELECT 1 FROM refresh_tokens
        WHERE jti = ? AND user_id = ? AND revoked = false AND expires_at > ?
        """,
        [jti, user_id, now],
    ).fetchone()
    return row is not None