8e36f48527
Co-authored-by: Copilot <copilot@github.com>
128 lines
4.1 KiB
Python
128 lines
4.1 KiB
Python
"""Authentication service: password hashing, JWT creation/verification, token management."""
|
|
import os
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any
|
|
|
|
from jose import JWTError, jwt
|
|
from passlib.context import CryptContext
|
|
|
|
from ..db import get_conn, get_write_lock
|
|
|
|
# Single passlib context shared by the password helpers in this module;
# bcrypt is the only configured hashing scheme.
_pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
|
|
|
|
# Token lifetimes and the signing algorithm used by the JWT helpers.
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15  # access-token lifetime, in minutes
REFRESH_TOKEN_EXPIRE_DAYS: int = 7  # refresh-token lifetime, in days
ALGORITHM: str = "HS256"  # algorithm passed to jwt.encode / jwt.decode
|
|
|
|
|
|
def _secret() -> str:
    """Return the JWT signing key from the ``JWT_SECRET`` environment variable.

    The variable is read on every call rather than cached at import time.

    Raises:
        RuntimeError: If ``JWT_SECRET`` is unset or set to an empty string.
    """
    signing_key = os.environ.get("JWT_SECRET")
    if signing_key:
        return signing_key
    raise RuntimeError("JWT_SECRET environment variable is not set.")
|
|
|
|
|
|
# --- Password ---
|
|
|
|
def hash_password(plain: str) -> str:
    """Hash the plaintext password with the module's passlib context.

    Args:
        plain: The plaintext password.

    Returns:
        The encoded hash string produced by ``_pwd_context.hash``.
    """
    encoded = _pwd_context.hash(plain)
    return encoded
|
|
|
|
|
|
def verify_password(plain: str, hashed: str) -> bool:
    """Check a plaintext password against a stored hash.

    Args:
        plain: The plaintext password supplied by the caller.
        hashed: The stored hash to compare against.

    Returns:
        The result of ``_pwd_context.verify(plain, hashed)``.
    """
    matches = _pwd_context.verify(plain, hashed)
    return matches
|
|
|
|
|
|
# --- Tokens ---
|
|
|
|
def create_access_token(user_id: str, email: str, role: str) -> str:
    """Build a signed access JWT for a user.

    The claims are ``sub`` (the user id), ``email``, ``role``, ``exp`` and a
    ``type`` of ``"access"``. ``exp`` is the current UTC time plus
    ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Raises:
        RuntimeError: If the signing secret is not configured (from ``_secret``).
    """
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = dict(
        sub=user_id,
        email=email,
        role=role,
        exp=datetime.now(timezone.utc) + lifetime,
        type="access",
    )
    return jwt.encode(claims, _secret(), algorithm=ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(user_id: str, jti: str) -> str:
    """Build a signed refresh JWT for a user.

    The claims are ``sub`` (the user id), ``jti``, ``exp`` and a ``type`` of
    ``"refresh"``. ``exp`` is the current UTC time plus
    ``REFRESH_TOKEN_EXPIRE_DAYS``.

    Raises:
        RuntimeError: If the signing secret is not configured (from ``_secret``).
    """
    lifetime = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    claims = dict(
        sub=user_id,
        jti=jti,
        exp=datetime.now(timezone.utc) + lifetime,
        type="refresh",
    )
    return jwt.encode(claims, _secret(), algorithm=ALGORITHM)
|
|
|
|
|
|
def decode_token(token: str) -> dict[str, Any]:
    """Decode and validate a JWT, returning its claims.

    Only ``ALGORITHM`` is accepted when verifying the token.

    Raises:
        JWTError: If decoding or validation fails.
        RuntimeError: If the signing secret is not configured (from ``_secret``).
    """
    signing_key = _secret()
    claims = jwt.decode(token, signing_key, algorithms=[ALGORITHM])
    return claims
|
|
|
|
|
|
# --- Database operations ---
|
|
|
|
async def register_user(email: str, password: str) -> dict[str, Any]:
    """Create a user account and return the stored row as a dict.

    Returns:
        A dict with ``id`` (stringified), ``email`` and ``role``.

    Raises:
        ValueError: If a user with this email already exists.
        RuntimeError: If the new row cannot be read back after the insert.
    """
    db = get_conn()
    write_lock = get_write_lock()

    # The duplicate check, the insert and the read-back all run while the
    # write lock is held, so another writer using the same lock cannot
    # interleave between them.
    async with write_lock:
        duplicate = db.execute(
            "SELECT id FROM users WHERE email = ?", [email]
        ).fetchone()
        if duplicate:
            raise ValueError("Email already registered.")

        password_hash = hash_password(password)
        db.execute(
            "INSERT INTO users (email, password_hash) VALUES (?, ?)",
            [email, password_hash],
        )
        created = db.execute(
            "SELECT id, email, role FROM users WHERE email = ?", [email]
        ).fetchone()

    if created is None:
        raise RuntimeError("Failed to fetch user after registration.")

    return {"id": str(created[0]), "email": created[1], "role": created[2]}
|
|
|
|
|
|
async def authenticate_user(email: str, password: str) -> dict[str, Any] | None:
    """Look up a user by email and check the supplied password.

    Returns:
        A dict with ``id`` (stringified), ``email`` and ``role`` when the
        email exists and the password matches its stored hash; otherwise
        ``None``.
    """
    db = get_conn()
    record = db.execute(
        "SELECT id, email, password_hash, role FROM users WHERE email = ?",
        [email],
    ).fetchone()

    # Column order follows the SELECT: 0=id, 1=email, 2=password_hash, 3=role.
    if record is not None and verify_password(password, record[2]):
        return {"id": str(record[0]), "email": record[1], "role": record[3]}
    return None
|
|
|
|
|
|
async def store_refresh_token(user_id: str, jti: str) -> None:
    """Persist a refresh token JTI in the database.

    Inserts a row into ``refresh_tokens`` whose ``expires_at`` is the current
    UTC time plus ``REFRESH_TOKEN_EXPIRE_DAYS``. The insert runs while the
    write lock is held.

    Args:
        user_id: Identifier of the user the token belongs to.
        jti: Identifier of the refresh token being stored.
    """
    conn = get_conn()
    lock = get_write_lock()
    sql_insert = (
        "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?, ?, ?)"
    )
    # `timedelta` comes from the module-level datetime import; the former
    # function-local re-import was redundant and has been dropped.
    expires_at = datetime.now(timezone.utc) + timedelta(
        days=REFRESH_TOKEN_EXPIRE_DAYS
    )
    async with lock:
        conn.execute(sql_insert, [jti, user_id, expires_at])
|
|
|
|
|
|
async def revoke_refresh_token(jti: str) -> None:
    """Flag the refresh token with the given JTI as revoked.

    Runs an ``UPDATE`` on ``refresh_tokens`` while the write lock is held.

    Args:
        jti: Identifier of the refresh token to revoke.
    """
    db = get_conn()
    write_lock = get_write_lock()
    async with write_lock:
        db.execute(
            "UPDATE refresh_tokens SET revoked = true WHERE jti = ?",
            [jti],
        )
|
|
|
|
|
|
async def validate_refresh_token_jti(jti: str, user_id: str) -> bool:
    """Report whether a refresh token JTI is currently usable by ``user_id``.

    Returns:
        True when a ``refresh_tokens`` row matches the JTI and user, is not
        revoked, and has an ``expires_at`` later than the current UTC time;
        False otherwise.
    """
    db = get_conn()
    current_time = datetime.now(timezone.utc)
    query = """
        SELECT 1 FROM refresh_tokens
        WHERE jti = ? AND user_id = ? AND revoked = false AND expires_at > ?
    """
    match = db.execute(query, [jti, user_id, current_time]).fetchone()
    if match is None:
        return False
    return True
|