diff --git a/main.py b/main.py index 0fe8955..a5f673c 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ from dotenv import load_dotenv from templates import load_templates from dashboard import create_app +from maintenance import delete_old_messages from thctime import ( get_country_info, get_tz_info, @@ -49,9 +50,6 @@ logging.basicConfig( load_dotenv() WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL') -DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN') -DISCORD_CHANNEL_ID = os.getenv('DISCORD_CHANNEL_ID') -GUILD_ID = os.getenv('DISCORD_GUILD_ID') def get_state() -> dict: @@ -137,341 +135,6 @@ def create_embed(type: str, tz_list: list[str] | None = None) -> dict: return embed -def get_discord_headers() -> dict[str, str]: - return { - "Authorization": f"Bot {DISCORD_BOT_TOKEN}", - "Content-Type": "application/json", - } - - -def parse_message_timestamp(message: dict) -> datetime: - return datetime.fromisoformat(message["timestamp"].replace("Z", "")) - - -def build_delete_entry(message: dict) -> dict: - return { - "id": message.get("id"), - "timestamp": parse_message_timestamp(message), - } - - -def parse_float(value: str | int | float | None) -> float | None: - if value is None: - return None - try: - return float(value) - except (TypeError, ValueError): - return None - - -def should_delete_message( - message: dict, - webhook_id: str, - author_id: str, - cutoff: int, -) -> bool: - message_timestamp = int(parse_message_timestamp(message).timestamp()) - return ( - message_timestamp <= cutoff - and message.get("webhook_id") == webhook_id - and message.get("author", {}).get("id") == author_id - ) - - -def get_rate_limit_retry_after(response: requests.Response) -> float | None: - header_retry_after = parse_float(response.headers.get("Retry-After")) - if header_retry_after is not None: - return header_retry_after - - reset_after = parse_float(response.headers.get("X-RateLimit-Reset-After")) - if reset_after is not None: - return reset_after - - try: - payload = response.json() - except ValueError: - return None - - return parse_float(payload.get("retry_after")) - - -def get_bucket_exhausted_delay(response: requests.Response) -> float | None: - remaining = response.headers.get("X-RateLimit-Remaining") - if remaining != "0": - return None - return parse_float(response.headers.get("X-RateLimit-Reset-After")) - - -def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None: - if delay_seconds <= 0: - return - logging.info( - f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).") - time.sleep(delay_seconds) - - -def find_last_message_by_author( - headers: dict[str, str], - guild_id: str, - channel_id: str, - author_id: str, -) -> dict | None: - """Find the newest indexed message for the author in the target guild channel.""" - url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search" - params = [ - ("author_id", author_id), - ("author_type", "webhook"), - ("channel_id", channel_id), - ("sort_by", "timestamp"), - ("sort_order", "desc"), - ("limit", "10"), - ] - - for _ in range(3): - try: - response = requests.get( - url, headers=headers, params=params, timeout=10) - except requests.RequestException as e: - logging.error(f"Error searching guild messages: {e}") - return None - - if response.status_code == 202: - payload = response.json() - retry_after = float(payload.get("retry_after", 1) or 1) - logging.info( - f"Guild search index not ready. Retrying after {retry_after} seconds." - ) - time.sleep(retry_after) - continue - - if response.status_code == 429: - retry_after = get_rate_limit_retry_after(response) or 1.0 - sleep_for_rate_limit(retry_after, "guild search") - continue - - if response.status_code != 200: - logging.error( - f"Failed to search guild messages: {response.status_code} - {response.text}" - ) - return None - - message_groups = response.json().get("messages", []) - if not message_groups or not message_groups[0]: - return None - return message_groups[0][0] - - logging.error("Guild search index did not become available in time.") - return None - - -def fetch_messages_to_delete(headers: dict, channel_id: str, webhook_id: str, author_id: str, cutoff: int, last_message_id: str | None = None) -> tuple[list[dict], str | None]: - """ - Fetch messages from the channel that are older than the cutoff timestamp and sent by the webhook. - Uses pagination with the 'before' parameter to resume from the last processed message. - Returns a tuple of (list of messages to delete, last message ID for pagination). - """ - url = f"https://discord.com/api/v10/channels/{channel_id}/messages" - params: dict[str, str | int] = { - "limit": 100, # Maximum allowed by Discord API - } - - # Use the 'before' parameter for pagination if last_message_id is provided - if last_message_id: - params["before"] = last_message_id - - try: - for _ in range(3): - response = requests.get(url, headers=headers, - params=params, timeout=10) - - if response.status_code == 429: - retry_after = get_rate_limit_retry_after(response) or 1.0 - sleep_for_rate_limit(retry_after, "channel message fetch") - continue - - break - else: - logging.error( - "Failed to fetch messages after repeated rate limits.") - return [], last_message_id - - if response.status_code == 200: - messages = response.json() - delete_list = [] - new_last_message_id = None - - for message in messages: - new_last_message_id = message.get("id") - - if should_delete_message( - message, - webhook_id, - author_id, - cutoff, - ): - delete_list.append(build_delete_entry(message)) - - # Limit the list to 100 items - if len(delete_list) >= 100: - break - - return delete_list, new_last_message_id - else: - logging.error( - f"Failed to fetch messages: {response.status_code} - {response.text}") - return [], last_message_id - except requests.RequestException as e: - logging.error(f"Error fetching messages: {e}") - return [], last_message_id - - -def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]: - """ - Delete a single message from the channel. - - Returns: - tuple[bool, float | None, bool]: - - whether the delete succeeded - - how long to wait before the next request, if any - - whether to abort the batch because further requests would be invalid - """ - delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}" - delete_response = requests.delete(delete_url, headers=headers, timeout=10) - - if delete_response.status_code == 204: - return True, get_bucket_exhausted_delay(delete_response), False - - if delete_response.status_code == 429: - retry_after = get_rate_limit_retry_after(delete_response) or 1.0 - scope = delete_response.headers.get("X-RateLimit-Scope", "unknown") - is_global = delete_response.headers.get( - "X-RateLimit-Global", "false").lower() == "true" - logging.warning( - "Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs", - message_id, - scope, - is_global, - retry_after, - ) - return False, retry_after, False - - if delete_response.status_code in {401, 403}: - logging.error( - "Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.", - message_id, - delete_response.status_code, - delete_response.text, - ) - return False, None, True - - logging.error( - f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}") - return False, None, False - - -def delete_old_messages(minutes: int = 6) -> None: - """ - Delete all messages sent by the webhook in the last `minutes` minutes. - Uses a dynamic slowdown to avoid hitting Discord API rate limits and pagination to fetch all messages. - """ - if not DISCORD_BOT_TOKEN or not DISCORD_CHANNEL_ID or not GUILD_ID: - logging.error( - "DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set") - return - - headers = get_discord_headers() - - # Calculate the time `minutes` minutes ago - cutoff_timestamp = datetime.now() - timedelta(minutes=minutes) - cutoff = int(cutoff_timestamp.timestamp()) - webhook_id = '1413817504194760766' - author_id = '1413817504194760766' - - last_author_message = find_last_message_by_author( - headers, - GUILD_ID, - DISCORD_CHANNEL_ID, - author_id, - ) - if last_author_message is None: - logging.info("No indexed messages found for the target author.") - return - - last_message_id = last_author_message.get("id") - if not last_message_id: - logging.info("Search result did not contain a message id.") - return - - deleted_count = 0 - - if should_delete_message( - last_author_message, - webhook_id, - author_id, - cutoff, - ): - anchor_message = build_delete_entry(last_author_message) - deleted, wait_seconds, abort_batch = delete_message( - headers, - DISCORD_CHANNEL_ID, - anchor_message["id"], - ) - if deleted: - deleted_count += 1 - logging.info( - f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}" - ) - elif abort_batch: - return - - if wait_seconds is not None: - sleep_for_rate_limit(wait_seconds, "delete bucket") - - while True: - # Fetch messages to delete with pagination - delete_list, next_last_message_id = fetch_messages_to_delete( - headers, DISCORD_CHANNEL_ID, webhook_id, author_id, cutoff, last_message_id - ) - - # Stop scanning when a page has no eligible messages to delete. - # Continuing pagination here can walk indefinitely through history. - if not delete_list: - if deleted_count == 0: - logging.info("No messages to delete.") - else: - logging.info("No more messages to delete.") - break - - for message in delete_list: - message_id = message["id"] - message_time = message["timestamp"] - - deleted, wait_seconds, abort_batch = delete_message( - headers, - DISCORD_CHANNEL_ID, - message_id, - ) - - if deleted: - deleted_count += 1 - logging.info( - f"Deleted message {message_id} from {message_time.isoformat()}") - elif abort_batch: - logging.info( - "Stopping delete batch after an invalid Discord response.") - return - - if wait_seconds is not None: - sleep_for_rate_limit(wait_seconds, "delete bucket") - - if next_last_message_id is None or next_last_message_id == last_message_id: - break - last_message_id = next_last_message_id - - logging.info( - f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.") - - def send_notification(message: str) -> None: """ Send a notification to the Discord webhook. diff --git a/maintenance.py b/maintenance.py new file mode 100644 index 0000000..863f7d2 --- /dev/null +++ b/maintenance.py @@ -0,0 +1,343 @@ +import logging +import os +import time +from datetime import datetime, timedelta + +import requests + + +WEBHOOK_AUTHOR_ID = "1413817504194760766" + + +def get_discord_headers() -> dict[str, str]: + token = os.getenv("DISCORD_BOT_TOKEN") + return { + "Authorization": f"Bot {token}", + "Content-Type": "application/json", + } + + +def parse_message_timestamp(message: dict) -> datetime: + return datetime.fromisoformat(message["timestamp"].replace("Z", "")) + + +def build_delete_entry(message: dict) -> dict: + return { + "id": message.get("id"), + "timestamp": parse_message_timestamp(message), + } + + +def parse_float(value: str | int | float | None) -> float | None: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def should_delete_message( + message: dict, + webhook_id: str, + author_id: str, + cutoff: int, +) -> bool: + message_timestamp = int(parse_message_timestamp(message).timestamp()) + return ( + message_timestamp <= cutoff + and message.get("webhook_id") == webhook_id + and message.get("author", {}).get("id") == author_id + ) + + +def get_rate_limit_retry_after(response: requests.Response) -> float | None: + header_retry_after = parse_float(response.headers.get("Retry-After")) + if header_retry_after is not None: + return header_retry_after + + reset_after = parse_float(response.headers.get("X-RateLimit-Reset-After")) + if reset_after is not None: + return reset_after + + try: + payload = response.json() + except ValueError: + return None + + return parse_float(payload.get("retry_after")) + + +def get_bucket_exhausted_delay(response: requests.Response) -> float | None: + remaining = response.headers.get("X-RateLimit-Remaining") + if remaining != "0": + return None + return parse_float(response.headers.get("X-RateLimit-Reset-After")) + + +def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None: + if delay_seconds <= 0: + return + logging.info( + f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).") + time.sleep(delay_seconds) + + +def find_last_message_by_author( + headers: dict[str, str], + guild_id: str, + channel_id: str, + author_id: str, +) -> dict | None: + """Find the newest indexed message for the author in the target guild channel.""" + url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search" + params = [ + ("author_id", author_id), + ("author_type", "webhook"), + ("channel_id", channel_id), + ("sort_by", "timestamp"), + ("sort_order", "desc"), + ("limit", "10"), + ] + + for _ in range(3): + try: + response = requests.get( + url, headers=headers, params=params, timeout=10) + except requests.RequestException as e: + logging.error(f"Error searching guild messages: {e}") + return None + + if response.status_code == 202: + payload = response.json() + retry_after = float(payload.get("retry_after", 1) or 1) + logging.info( + f"Guild search index not ready. Retrying after {retry_after} seconds." + ) + time.sleep(retry_after) + continue + + if response.status_code == 429: + retry_after = get_rate_limit_retry_after(response) or 1.0 + sleep_for_rate_limit(retry_after, "guild search") + continue + + if response.status_code != 200: + logging.error( + f"Failed to search guild messages: {response.status_code} - {response.text}" + ) + return None + + message_groups = response.json().get("messages", []) + if not message_groups or not message_groups[0]: + return None + return message_groups[0][0] + + logging.error("Guild search index did not become available in time.") + return None + + +def fetch_messages_to_delete(headers: dict, channel_id: str, webhook_id: str, author_id: str, cutoff: int, last_message_id: str | None = None) -> tuple[list[dict], str | None]: + """ + Fetch messages from the channel that are older than the cutoff timestamp and sent by the webhook. + Uses pagination with the 'before' parameter to resume from the last processed message. + Returns a tuple of (list of messages to delete, last message ID for pagination). + """ + url = f"https://discord.com/api/v10/channels/{channel_id}/messages" + params: dict[str, str | int] = { + "limit": 100, + } + + if last_message_id: + params["before"] = last_message_id + + try: + for _ in range(3): + response = requests.get(url, headers=headers, + params=params, timeout=10) + + if response.status_code == 429: + retry_after = get_rate_limit_retry_after(response) or 1.0 + sleep_for_rate_limit(retry_after, "channel message fetch") + continue + + break + else: + logging.error( + "Failed to fetch messages after repeated rate limits.") + return [], last_message_id + + if response.status_code == 200: + messages = response.json() + delete_list = [] + new_last_message_id = None + + for message in messages: + new_last_message_id = message.get("id") + + if should_delete_message( + message, + webhook_id, + author_id, + cutoff, + ): + delete_list.append(build_delete_entry(message)) + + if len(delete_list) >= 100: + break + + return delete_list, new_last_message_id + + logging.error( + f"Failed to fetch messages: {response.status_code} - {response.text}") + return [], last_message_id + except requests.RequestException as e: + logging.error(f"Error fetching messages: {e}") + return [], last_message_id + + +def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]: + """ + Delete a single message from the channel. + + Returns: + tuple[bool, float | None, bool]: + - whether the delete succeeded + - how long to wait before the next request, if any + - whether to abort the batch because further requests would be invalid + """ + delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}" + delete_response = requests.delete(delete_url, headers=headers, timeout=10) + + if delete_response.status_code == 204: + return True, get_bucket_exhausted_delay(delete_response), False + + if delete_response.status_code == 429: + retry_after = get_rate_limit_retry_after(delete_response) or 1.0 + scope = delete_response.headers.get("X-RateLimit-Scope", "unknown") + is_global = delete_response.headers.get( + "X-RateLimit-Global", "false").lower() == "true" + logging.warning( + "Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs", + message_id, + scope, + is_global, + retry_after, + ) + return False, retry_after, False + + if delete_response.status_code in {401, 403}: + logging.error( + "Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.", + message_id, + delete_response.status_code, + delete_response.text, + ) + return False, None, True + + logging.error( + f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}") + return False, None, False + + +def delete_old_messages(minutes: int = 6) -> None: + """ + Delete all messages sent by the webhook in the last `minutes` minutes. + Uses a dynamic slowdown to avoid hitting Discord API rate limits and pagination to fetch all messages. + """ + discord_bot_token = os.getenv("DISCORD_BOT_TOKEN") + discord_channel_id = os.getenv("DISCORD_CHANNEL_ID") + guild_id = os.getenv("DISCORD_GUILD_ID") + + if not discord_bot_token or not discord_channel_id or not guild_id: + logging.error( + "DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set") + return + + headers = get_discord_headers() + + cutoff_timestamp = datetime.now() - timedelta(minutes=minutes) + cutoff = int(cutoff_timestamp.timestamp()) + webhook_id = WEBHOOK_AUTHOR_ID + author_id = WEBHOOK_AUTHOR_ID + + last_author_message = find_last_message_by_author( + headers, + guild_id, + discord_channel_id, + author_id, + ) + if last_author_message is None: + logging.info("No indexed messages found for the target author.") + return + + last_message_id = last_author_message.get("id") + if not last_message_id: + logging.info("Search result did not contain a message id.") + return + + deleted_count = 0 + + if should_delete_message( + last_author_message, + webhook_id, + author_id, + cutoff, + ): + anchor_message = build_delete_entry(last_author_message) + deleted, wait_seconds, abort_batch = delete_message( + headers, + discord_channel_id, + anchor_message["id"], + ) + if deleted: + deleted_count += 1 + logging.info( + f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}" + ) + elif abort_batch: + return + + if wait_seconds is not None: + sleep_for_rate_limit(wait_seconds, "delete bucket") + + while True: + delete_list, next_last_message_id = fetch_messages_to_delete( + headers, discord_channel_id, webhook_id, author_id, cutoff, last_message_id + ) + + if not delete_list: + if deleted_count == 0: + logging.info("No messages to delete.") + else: + logging.info("No more messages to delete.") + break + + for message in delete_list: + message_id = message["id"] + message_time = message["timestamp"] + + deleted, wait_seconds, abort_batch = delete_message( + headers, + discord_channel_id, + message_id, + ) + + if deleted: + deleted_count += 1 + logging.info( + f"Deleted message {message_id} from {message_time.isoformat()}") + elif abort_batch: + logging.info( + "Stopping delete batch after an invalid Discord response.") + return + + if wait_seconds is not None: + sleep_for_rate_limit(wait_seconds, "delete bucket") + + if next_last_message_id is None or next_last_message_id == last_message_id: + break + last_message_id = next_last_message_id + + logging.info( + f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.") diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py new file mode 100644 index 0000000..83db381 --- /dev/null +++ b/tests/test_dashboard.py @@ -0,0 +1,143 @@ +from datetime import datetime + +import dashboard + + +def _make_app(): + return dashboard.create_app( + get_state=lambda: { + "running": True, + "started_at": datetime(2026, 1, 1, 10, 0, 0), + "last_type": "420", + "last_attempt_at": datetime(2026, 1, 1, 10, 15, 0), + "last_success_at": datetime(2026, 1, 1, 10, 20, 0), + "last_status_code": 204, + "last_error": None, + "last_locations": ["Nowhere"], + }, + get_next_event=lambda: { + "type": "reminder", + "at": datetime(2026, 1, 1, 11, 15, 0), + }, + ) + + +def test_fmt_dt_none_and_datetime(): + assert dashboard._fmt_dt(None) == "—" + assert dashboard._fmt_dt( + datetime(2026, 1, 1, 10, 0, 0)) == "2026-01-01T10:00:00" + + +def test_get_html_template_wraps_content(): + html = dashboard.get_html_template("

hello

") + assert "thc-webhook admin" in html + assert "

hello

" in html + + +def test_index_route_renders_status_page(): + app = _make_app() + client = app.test_client() + + response = client.get("/") + + assert response.status_code == 200 + body = response.get_data(as_text=True) + assert "thc-webhook" in body + assert "last_type: 420" in body + assert "type: reminder" in body + assert "Nowhere" in body + + +def test_admin_get_renders_template_form(monkeypatch): + monkeypatch.setattr( + dashboard, + "load_templates", + lambda path: { + "420": { + "text": "Blaze", + "color": 3066993, + "image_url": "https://example.com/img.png", + } + }, + ) + + app = _make_app() + app.config["TEMPLATES_PATH"] = "templates.json" + client = app.test_client() + + response = client.get("/admin") + + assert response.status_code == 200 + body = response.get_data(as_text=True) + assert "Admin: templates" in body + assert "name='420__text'" in body + assert "name='420__color'" in body + assert "name='420__image_url'" in body + + +def test_admin_post_validation_error(monkeypatch): + monkeypatch.setattr( + dashboard, + "load_templates", + lambda path: {"420": {"text": "x", "color": 1}}, + ) + monkeypatch.setattr(dashboard, "parse_color", lambda raw: ( + _ for _ in ()).throw(ValueError("bad color"))) + + save_called = {"value": False} + + def _save_templates(path, updated): + save_called["value"] = True + + monkeypatch.setattr(dashboard, "save_templates", _save_templates) + + app = _make_app() + client = app.test_client() + + response = client.post( + "/admin", + data={ + "420__text": "Updated", + "420__color": "bad", + "420__image_url": "", + }, + ) + + assert response.status_code == 400 + assert "invalid color" in response.get_data(as_text=True) + assert save_called["value"] is False + + +def test_admin_post_success(monkeypatch): + monkeypatch.setattr( + dashboard, + "load_templates", + lambda path: {"420": {"text": "x", "color": 1}}, + ) + monkeypatch.setattr(dashboard, "parse_color", lambda raw: 123) + + saved = {"path": None, "payload": None} + + def _save_templates(path, updated): + saved["path"] = path + saved["payload"] = updated + + monkeypatch.setattr(dashboard, "save_templates", _save_templates) + + app = _make_app() + app.config["TEMPLATES_PATH"] = "custom_templates.json" + client = app.test_client() + + response = client.post( + "/admin", + data={ + "420__text": "Updated", + "420__color": "123", + "420__image_url": "", + }, + ) + + assert response.status_code == 200 + assert "Saved." in response.get_data(as_text=True) + assert saved["path"] == "custom_templates.json" + assert saved["payload"] == {"420": {"text": "Updated", "color": 123}} diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 8c71def..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,132 +0,0 @@ -import io -import time -from unittest import mock - -import pytest - -import main -import thctime - - -SAMPLE_TIMEZONE_CSV = """Etc/UTC,ZZ,UTC,0,0,0 -America/New_York,US,EST,0,-18000,0 -Europe/London,GB,BST,0,0,1 -""" - -SAMPLE_COUNTRY_CSV = """ZZ,Unknown -US,United States -GB,United Kingdom -""" - - -def test_load_timezones_and_countries(monkeypatch): - tzs = main.load_timezones() - countries = main.load_countries() - assert any(t['zone_name'] == 'America/New_York' for t in tzs) - assert any(c['country_code'] == 'US' for c in countries) - - -def test_get_tz_and_country_info(): - timezones = [{'zone_name': 'A/B', 'country_code': 'US'}] - countries = [{'country_code': 'US', 'country_name': 'United States'}] - assert main.get_tz_info('A/B', timezones)['zone_name'] == 'A/B' - assert main.get_country_info('US', countries)[ - 'country_name'] == 'United States' - assert main.get_tz_info('X/Y', timezones) is None - assert main.get_country_info('XX', countries) is None - - -def test_create_embed_all_types(monkeypatch): - # Prevent create_embed from trying to read actual CSV files by patching loaders - monkeypatch.setattr(main, 'load_timezones', lambda: [ - {'zone_name': 'Etc/UTC', 'country_code': 'ZZ'}]) - monkeypatch.setattr(main, 'load_countries', lambda: [ - {'country_code': 'ZZ', 'country_name': 'Nowhere'}]) - - # reminder - emb = main.create_embed('reminder') - assert emb['title'] == 'Reminder' - assert '5 minute' in emb['description'] - assert emb['color'] == 0xe67e22 - - # reminder_halftime - emb = main.create_embed('reminder_halftime') - assert emb['title'] == 'Reminder halftime' - assert 'Half-time in 5 minutes' in emb['description'] - - # halftime (should include image) - monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs, **kwargs: []) - emb = main.create_embed('halftime') - assert emb['title'] == 'Halftime' - assert emb['image'] is not None - - # 420 (should include image and appended tz info string when list empty) - monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs, **kwargs: []) - emb = main.create_embed('420') - assert emb['title'] == '420' - assert emb['image'] is not None - - # unknown - emb = main.create_embed('nope') - assert emb['description'] == 'Unknown notification type' - - -def test_where_is_it_420(monkeypatch): - # Limit timezones to a predictable set - monkeypatch.setattr(thctime.pytz, 'all_timezones', ['Etc/UTC']) - - tzs = [{'zone_name': 'Etc/UTC', 'country_code': 'ZZ'}] - countries = [{'country_code': 'ZZ', 'country_name': 'Nowhere'}] - - class FakeDatetime: - @staticmethod - def now(tz): - class R: - hour = 4 - return R() - - monkeypatch.setattr(thctime, 'datetime', FakeDatetime) - - res = main.where_is_it_420(tzs, countries) - assert res == ['Nowhere'] - - -def test_main_exits_quickly(monkeypatch): - # Patch send_notification so it doesn't perform network - monkeypatch.setattr(main, 'send_notification', lambda x: None) - # Don't start dashboard during this test - monkeypatch.setattr(main, 'start_dashboard', lambda: None) - # Make schedule.run_pending raise KeyboardInterrupt to exit loop - monkeypatch.setattr(main.schedule, 'run_pending', lambda: ( - _ for _ in ()).throw(KeyboardInterrupt())) - # Patch time.sleep to no-op - monkeypatch.setattr(main.time, 'sleep', lambda s: None) - # Ensure WEBHOOK_URL present to avoid early return - monkeypatch.setenv('DISCORD_WEBHOOK_URL', 'http://example.com/webhook') - main.WEBHOOK_URL = 'http://example.com/webhook' - - # Should exit quickly due to KeyboardInterrupt from run_pending - main.main() - - -def test_get_next_scheduled_event(): - # 10:14 -> next is 10:15 reminder - now = main.datetime(2025, 1, 1, 10, 14, 30) - nxt = main.get_next_scheduled_event(now) - assert nxt["type"] == "reminder" - assert nxt["at"].hour == 10 and nxt["at"].minute == 15 - - # 10:50:01 -> next is 11:15 reminder - now = main.datetime(2025, 1, 1, 10, 50, 1) - nxt = main.get_next_scheduled_event(now) - assert nxt["type"] == "reminder" - assert nxt["at"].hour == 11 and nxt["at"].minute == 15 - - -def test_split_tz_name(): - assert main.split_tz_name("America/New_York") == ("America", "New_York") - assert main.split_tz_name("America/Argentina/Buenos_Aires") == ( - "America", - "Argentina/Buenos_Aires", - ) - assert main.split_tz_name("UTC") == ("UTC", "") diff --git a/tests/test_main_embed.py b/tests/test_main_embed.py new file mode 100644 index 0000000..3e6fd46 --- /dev/null +++ b/tests/test_main_embed.py @@ -0,0 +1,32 @@ +import main + + +def test_create_embed_all_types(monkeypatch): + monkeypatch.setattr(main, "load_timezones", lambda: [ + {"zone_name": "Etc/UTC", "country_code": "ZZ"} + ]) + monkeypatch.setattr(main, "load_countries", lambda: [ + {"country_code": "ZZ", "country_name": "Nowhere"} + ]) + + emb = main.create_embed("reminder") + assert emb["title"] == "Reminder" + assert "5 minute" in emb["description"] + assert emb["color"] == 0xE67E22 + + emb = main.create_embed("reminder_halftime") + assert emb["title"] == "Reminder halftime" + assert "Half-time in 5 minutes" in emb["description"] + + monkeypatch.setattr(main, "where_is_it_420", lambda tzs, cs, **kwargs: []) + emb = main.create_embed("halftime") + assert emb["title"] == "Halftime" + assert emb["image"] is not None + + monkeypatch.setattr(main, "where_is_it_420", lambda tzs, cs, **kwargs: []) + emb = main.create_embed("420") + assert emb["title"] == "420" + assert emb["image"] is not None + + emb = main.create_embed("nope") + assert emb["description"] == "Unknown notification type" diff --git a/tests/test_main_scheduler.py b/tests/test_main_scheduler.py new file mode 100644 index 0000000..8c67ba0 --- /dev/null +++ b/tests/test_main_scheduler.py @@ -0,0 +1,25 @@ +import main + + +def test_main_exits_quickly(monkeypatch): + monkeypatch.setattr(main, "send_notification", lambda x: None) + monkeypatch.setattr(main, "start_dashboard", lambda: None) + monkeypatch.setattr(main.schedule, "run_pending", lambda: ( + _ for _ in ()).throw(KeyboardInterrupt())) + monkeypatch.setattr(main.time, "sleep", lambda s: None) + monkeypatch.setenv("DISCORD_WEBHOOK_URL", "http://example.com/webhook") + main.WEBHOOK_URL = "http://example.com/webhook" + + main.main() + + +def test_get_next_scheduled_event(): + now = main.datetime(2025, 1, 1, 10, 14, 30) + nxt = main.get_next_scheduled_event(now) + assert nxt["type"] == "reminder" + assert nxt["at"].hour == 10 and nxt["at"].minute == 15 + + now = main.datetime(2025, 1, 1, 10, 50, 1) + nxt = main.get_next_scheduled_event(now) + assert nxt["type"] == "reminder" + assert nxt["at"].hour == 11 and nxt["at"].minute == 15 diff --git a/tests/test_maintenance.py b/tests/test_maintenance.py new file mode 100644 index 0000000..8bcce99 --- /dev/null +++ b/tests/test_maintenance.py @@ -0,0 +1,88 @@ +from datetime import datetime, timezone + +import maintenance + + +class DummyResponse: + def __init__(self, headers=None, payload=None): + self.headers = headers or {} + self._payload = payload + + def json(self): + if self._payload is None: + raise ValueError("no json") + return self._payload + + +def test_parse_float(): + assert maintenance.parse_float("1.5") == 1.5 + assert maintenance.parse_float(2) == 2.0 + assert maintenance.parse_float(None) is None + assert maintenance.parse_float("nope") is None + + +def test_parse_message_timestamp_and_build_delete_entry(): + msg = {"id": "42", "timestamp": "2026-01-01T10:00:00Z"} + parsed = maintenance.parse_message_timestamp(msg) + assert parsed == datetime(2026, 1, 1, 10, 0, 0) + + entry = maintenance.build_delete_entry(msg) + assert entry["id"] == "42" + assert entry["timestamp"] == parsed + + +def test_should_delete_message(): + ts = int(datetime(2026, 1, 1, 10, 0, 0, tzinfo=timezone.utc).timestamp()) + message = { + "timestamp": "2026-01-01T10:00:00Z", + "webhook_id": "w", + "author": {"id": "a"}, + } + + assert maintenance.should_delete_message( + message, + webhook_id="w", + author_id="a", + cutoff=ts, + ) + assert not maintenance.should_delete_message( + message, + webhook_id="x", + author_id="a", + cutoff=ts, + ) + + +def test_get_rate_limit_retry_after_header_priority(): + response = DummyResponse( + headers={ + "Retry-After": "2.5", + "X-RateLimit-Reset-After": "10", + }, + payload={"retry_after": "20"}, + ) + assert maintenance.get_rate_limit_retry_after(response) == 2.5 + + +def test_get_rate_limit_retry_after_json_fallback(): + response = DummyResponse(payload={"retry_after": "3"}) + assert maintenance.get_rate_limit_retry_after(response) == 3.0 + + +def test_get_bucket_exhausted_delay(): + response = DummyResponse( + headers={ + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset-After": "1.25", + } + ) + assert maintenance.get_bucket_exhausted_delay(response) == 1.25 + + response_not_exhausted = DummyResponse( + headers={ + "X-RateLimit-Remaining": "1", + "X-RateLimit-Reset-After": "1.25", + } + ) + assert maintenance.get_bucket_exhausted_delay( + response_not_exhausted) is None diff --git a/tests/test_thctime.py b/tests/test_thctime.py new file mode 100644 index 0000000..a5ca78b --- /dev/null +++ b/tests/test_thctime.py @@ -0,0 +1,47 @@ +import thctime + + +def test_load_timezones_and_countries(): + tzs = thctime.load_timezones() + countries = thctime.load_countries() + assert any(t["zone_name"] == "America/New_York" for t in tzs) + assert any(c["country_code"] == "US" for c in countries) + + +def test_get_tz_and_country_info(): + timezones = [{"zone_name": "A/B", "country_code": "US"}] + countries = [{"country_code": "US", "country_name": "United States"}] + assert thctime.get_tz_info("A/B", timezones)["zone_name"] == "A/B" + assert thctime.get_country_info("US", countries)[ + "country_name"] == "United States" + assert thctime.get_tz_info("X/Y", timezones) is None + assert thctime.get_country_info("XX", countries) is None + + +def test_where_is_it_420(monkeypatch): + monkeypatch.setattr(thctime.pytz, "all_timezones", ["Etc/UTC"]) + + tzs = [{"zone_name": "Etc/UTC", "country_code": "ZZ"}] + countries = [{"country_code": "ZZ", "country_name": "Nowhere"}] + + class FakeDatetime: + @staticmethod + def now(tz): + class Result: + hour = 4 + + return Result() + + monkeypatch.setattr(thctime, "datetime", FakeDatetime) + + res = thctime.where_is_it_420(tzs, countries) + assert res == ["Nowhere"] + + +def test_split_tz_name(): + assert thctime.split_tz_name("America/New_York") == ("America", "New_York") + assert thctime.split_tz_name("America/Argentina/Buenos_Aires") == ( + "America", + "Argentina/Buenos_Aires", + ) + assert thctime.split_tz_name("UTC") == ("UTC", "")