# thc-webhook/main.py — THC (Toke Hash Coordinated) Discord webhook scheduler.
import os
from datetime import datetime, timedelta
import time
import logging
import threading
import requests
import schedule
from dotenv import load_dotenv
from templates import load_templates
from dashboard import create_app
from thctime import (
get_country_info,
get_tz_info,
get_tzdb_cache,
init_tzdb_cache,
load_countries,
load_timezones,
split_tz_name,
where_is_it_420,
)
# Minute-of-hour marks at which hourly notifications fire, paired with the
# template key used for each message.
SCHEDULED_NOTIFICATIONS = [
    (15, "reminder"),
    (20, "420"),
    (45, "reminder_halftime"),
    (50, "halftime"),
]

# Guards every read/write of STATE — the dashboard thread reads it concurrently.
STATE_LOCK = threading.Lock()

# Process-wide status snapshot surfaced by the dashboard.
STATE: dict = {
    "running": True,
    "started_at": datetime.now(),
    "last_type": None,          # template key of the last attempted send
    "last_attempt_at": None,
    "last_success_at": None,
    "last_status_code": None,   # HTTP status of the last webhook POST
    "last_error": None,
    "last_locations": [],       # locations where it was 4:20 at the last send
}

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Load environment variables
load_dotenv()
WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL')        # webhook endpoint for posting embeds
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')    # bot token used for message cleanup
DISCORD_CHANNEL_ID = os.getenv('DISCORD_CHANNEL_ID')
GUILD_ID = os.getenv('DISCORD_GUILD_ID')
def get_state() -> dict:
    """Return a shallow copy of the shared status dict, taken under the lock."""
    with STATE_LOCK:
        snapshot = dict(STATE)
    return snapshot
def get_next_event() -> dict:
    """Dashboard hook: delegate to the scheduler's next-event computation."""
    event = get_next_scheduled_event()
    return event
def _update_state(**updates) -> None:
    """Merge the given keyword updates into the shared STATE dict under the lock."""
    with STATE_LOCK:
        for key, value in updates.items():
            STATE[key] = value
def get_state_snapshot() -> dict:
    """Return a point-in-time shallow copy of STATE (same contract as get_state)."""
    with STATE_LOCK:
        copied = {**STATE}
    return copied
def get_next_scheduled_event(now: datetime | None = None) -> dict:
"""Return the next scheduled notification time/type based on known minute marks."""
if now is None:
now = datetime.now()
candidates: list[tuple[datetime, str]] = []
for minute, msg_type in SCHEDULED_NOTIFICATIONS:
candidate = now.replace(minute=minute, second=0, microsecond=0)
if candidate > now:
candidates.append((candidate, msg_type))
if not candidates:
base = (now + timedelta(hours=1)).replace(minute=0,
second=0, microsecond=0)
next_dt = base.replace(minute=SCHEDULED_NOTIFICATIONS[0][0])
return {"at": next_dt, "type": SCHEDULED_NOTIFICATIONS[0][1]}
next_dt, next_type = min(candidates, key=lambda x: x[0])
return {"at": next_dt, "type": next_type}
def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
    """Build a Discord embed payload for the given notification type.

    Args:
        type: Template key (e.g. "420", "reminder"); an unknown key yields a
            red "Unknown notification type" embed.
        tz_list: Pre-computed list of locations where it is 4:20. When None
            and type == "420", the list is computed here.

    Returns:
        A dict matching Discord's embed object schema.
    """
    templates_path = os.getenv("TEMPLATES_PATH", "templates.json")
    messages = load_templates(templates_path)
    cache = get_tzdb_cache()
    # Only parse the CSV data when no warmed tzdb cache is available.
    timezones = load_timezones() if cache is None else []
    countries = load_countries() if cache is None else []
    if type in messages:
        # Shallow-copy the template so repeated calls never mutate the
        # loaded template dict in place (the original appended the 4:20
        # suffix to the shared dict, which corrupts cached templates).
        msg = dict(messages[type])
        image_url = msg.get("image_url")
        if isinstance(image_url, str) and image_url:
            msg["image"] = {"url": image_url}
        if type == "420":
            # Check where it's 4:20
            if tz_list is None:
                if cache is None:
                    tz_list = where_is_it_420(timezones, countries)
                else:
                    tz_list = where_is_it_420(
                        timezones,
                        countries,
                        tz_names=cache.get("tz_names"),
                        tz_to_country_code=cache.get("tz_to_country_code"),
                        country_code_to_name=cache.get("country_code_to_name"),
                    )
            if tz_list:
                tz_str = "\n".join(tz_list)
                msg["text"] = msg["text"] + f"\nIt's 4:20 in:\n{tz_str}"
    else:
        msg = {"text": "Unknown notification type", "color": 0xFF0000}
    embed = {
        "title": type.replace("_", " ").capitalize(),
        "description": msg["text"],
        "image": msg.get("image"),
        "color": msg["color"],
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "footer": {"text": "THC - Toke Hash Coordinated"}
    }
    return embed
def get_discord_headers() -> dict[str, str]:
    """HTTP headers for authenticated Discord REST API calls."""
    headers = {"Content-Type": "application/json"}
    headers["Authorization"] = f"Bot {DISCORD_BOT_TOKEN}"
    return headers
def parse_message_timestamp(message: dict) -> datetime:
    """Parse a Discord message's ISO-8601 "timestamp" field into a datetime.

    Discord may mark UTC with a trailing "Z"; map it to "+00:00" so the
    result stays timezone-aware. The previous code stripped the "Z",
    producing a naive datetime whose .timestamp() was then interpreted in
    local time, skewing the age comparison in should_delete_message.
    """
    return datetime.fromisoformat(message["timestamp"].replace("Z", "+00:00"))
def build_delete_entry(message: dict) -> dict:
    """Reduce a raw Discord message to the fields the delete loop needs."""
    entry = {"id": message.get("id")}
    entry["timestamp"] = parse_message_timestamp(message)
    return entry
def parse_float(value: str | int | float | None) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def should_delete_message(
    message: dict,
    webhook_id: str,
    author_id: str,
    cutoff: int,
) -> bool:
    """True when the message is at or past the cutoff age and was posted by
    our webhook author (matching both webhook and author ids)."""
    if int(parse_message_timestamp(message).timestamp()) > cutoff:
        return False
    if message.get("webhook_id") != webhook_id:
        return False
    return message.get("author", {}).get("id") == author_id
def get_rate_limit_retry_after(response: requests.Response) -> float | None:
    """Extract the retry delay from a 429 response: headers first, then body."""
    for header in ("Retry-After", "X-RateLimit-Reset-After"):
        delay = parse_float(response.headers.get(header))
        if delay is not None:
            return delay
    try:
        body = response.json()
    except ValueError:
        return None
    return parse_float(body.get("retry_after"))
def get_bucket_exhausted_delay(response: requests.Response) -> float | None:
    """When the rate-limit bucket is empty, return how long until it refills."""
    if response.headers.get("X-RateLimit-Remaining") == "0":
        return parse_float(response.headers.get("X-RateLimit-Reset-After"))
    return None
def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None:
    """Sleep for the given delay, logging why; no-op for non-positive values."""
    if delay_seconds > 0:
        logging.info(
            f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).")
        time.sleep(delay_seconds)
def find_last_message_by_author(
    headers: dict[str, str],
    guild_id: str,
    channel_id: str,
    author_id: str,
) -> dict | None:
    """Find the newest indexed message for the author in the target guild channel."""
    # Guild-level message search; results come from Discord's search index,
    # which may lag behind the live channel (hence the 202 retry handling).
    url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search"
    params = [
        ("author_id", author_id),
        ("author_type", "webhook"),
        ("channel_id", channel_id),
        ("sort_by", "timestamp"),
        ("sort_order", "desc"),
        ("limit", "10"),
    ]
    # Up to three attempts, covering both "index not ready" (202) and 429s.
    for _ in range(3):
        try:
            response = requests.get(
                url, headers=headers, params=params, timeout=10)
        except requests.RequestException as e:
            logging.error(f"Error searching guild messages: {e}")
            return None
        if response.status_code == 202:
            # Index still building; the body tells us when to retry.
            payload = response.json()
            retry_after = float(payload.get("retry_after", 1) or 1)
            logging.info(
                f"Guild search index not ready. Retrying after {retry_after} seconds."
            )
            time.sleep(retry_after)
            continue
        if response.status_code == 429:
            retry_after = get_rate_limit_retry_after(response) or 1.0
            sleep_for_rate_limit(retry_after, "guild search")
            continue
        if response.status_code != 200:
            logging.error(
                f"Failed to search guild messages: {response.status_code} - {response.text}"
            )
            return None
        # Results are grouped; with newest-first sorting the first message of
        # the first group is the most recent hit.
        message_groups = response.json().get("messages", [])
        if not message_groups or not message_groups[0]:
            return None
        return message_groups[0][0]
    logging.error("Guild search index did not become available in time.")
    return None
def fetch_messages_to_delete(headers: dict, channel_id: str, webhook_id: str, author_id: str, cutoff: int, last_message_id: str | None = None) -> tuple[list[dict], str | None]:
    """
    Fetch messages from the channel that are older than the cutoff timestamp and sent by the webhook.
    Uses pagination with the 'before' parameter to resume from the last processed message.
    Returns a tuple of (list of messages to delete, last message ID for pagination).
    """
    url = f"https://discord.com/api/v10/channels/{channel_id}/messages"
    params: dict[str, str | int] = {
        "limit": 100,  # Maximum allowed by Discord API
    }
    # Use the 'before' parameter for pagination if last_message_id is provided
    if last_message_id:
        params["before"] = last_message_id
    try:
        # Retry up to three times on 429; the for/else fires only when every
        # attempt was rate limited (i.e. the loop never hit `break`).
        for _ in range(3):
            response = requests.get(url, headers=headers,
                                    params=params, timeout=10)
            if response.status_code == 429:
                retry_after = get_rate_limit_retry_after(response) or 1.0
                sleep_for_rate_limit(retry_after, "channel message fetch")
                continue
            break
        else:
            logging.error(
                "Failed to fetch messages after repeated rate limits.")
            return [], last_message_id
        if response.status_code == 200:
            messages = response.json()
            delete_list = []
            # Track the last id seen (not last deleted) so the caller can
            # paginate further back with `before=` without skipping messages.
            new_last_message_id = None
            for message in messages:
                new_last_message_id = message.get("id")
                if should_delete_message(
                    message,
                    webhook_id,
                    author_id,
                    cutoff,
                ):
                    delete_list.append(build_delete_entry(message))
                # Limit the list to 100 items
                if len(delete_list) >= 100:
                    break
            return delete_list, new_last_message_id
        else:
            logging.error(
                f"Failed to fetch messages: {response.status_code} - {response.text}")
            return [], last_message_id
    except requests.RequestException as e:
        logging.error(f"Error fetching messages: {e}")
        return [], last_message_id
def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]:
    """
    Delete a single message from the channel.
    Returns:
        tuple[bool, float | None, bool]:
            - whether the delete succeeded
            - how long to wait before the next request, if any
            - whether to abort the batch because further requests would be invalid
    """
    delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}"
    delete_response = requests.delete(delete_url, headers=headers, timeout=10)
    if delete_response.status_code == 204:
        # Success; still surface a proactive delay when the rate-limit bucket
        # just hit zero, so the caller can pace the next delete.
        return True, get_bucket_exhausted_delay(delete_response), False
    if delete_response.status_code == 429:
        # Rate limited: report the wait and whether it was the global limit
        # (useful when auditing repeated 429s in the logs).
        retry_after = get_rate_limit_retry_after(delete_response) or 1.0
        scope = delete_response.headers.get("X-RateLimit-Scope", "unknown")
        is_global = delete_response.headers.get(
            "X-RateLimit-Global", "false").lower() == "true"
        logging.warning(
            "Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs",
            message_id,
            scope,
            is_global,
            retry_after,
        )
        return False, retry_after, False
    if delete_response.status_code in {401, 403}:
        # Auth failures won't resolve by retrying; tell the caller to abort
        # so we don't spam Discord with requests that will keep failing.
        logging.error(
            "Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.",
            message_id,
            delete_response.status_code,
            delete_response.text,
        )
        return False, None, True
    # Any other status: log it and let the caller move on to the next message.
    logging.error(
        f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}")
    return False, None, False
def delete_old_messages(minutes: int = 6) -> None:
    """
    Delete webhook-authored messages OLDER than `minutes` minutes.

    (The previous docstring said "in the last `minutes` minutes", but the
    cutoff comparison deletes messages at least that old — matching the final
    log line.) Anchors on the newest indexed message from the webhook author
    (via guild search), then paginates backwards through the channel history,
    deleting eligible messages while respecting Discord rate limits.

    Args:
        minutes: Age threshold in minutes; messages at least this old go.
    """
    if not DISCORD_BOT_TOKEN or not DISCORD_CHANNEL_ID or not GUILD_ID:
        logging.error(
            "DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set")
        return
    headers = get_discord_headers()
    # Calculate the time `minutes` minutes ago
    cutoff_timestamp = datetime.now() - timedelta(minutes=minutes)
    cutoff = int(cutoff_timestamp.timestamp())
    # Webhook/author id of our notification webhook. Env-overridable so the
    # previously hard-coded id no longer ties cleanup to a single webhook;
    # the defaults preserve the original behavior.
    webhook_id = os.getenv('DISCORD_WEBHOOK_ID', '1413817504194760766')
    author_id = os.getenv('DISCORD_WEBHOOK_AUTHOR_ID', webhook_id)
    last_author_message = find_last_message_by_author(
        headers,
        GUILD_ID,
        DISCORD_CHANNEL_ID,
        author_id,
    )
    if last_author_message is None:
        logging.info("No indexed messages found for the target author.")
        return
    last_message_id = last_author_message.get("id")
    if not last_message_id:
        logging.info("Search result did not contain a message id.")
        return
    deleted_count = 0
    # Delete the anchor message itself first if it is old enough.
    if should_delete_message(
        last_author_message,
        webhook_id,
        author_id,
        cutoff,
    ):
        anchor_message = build_delete_entry(last_author_message)
        deleted, wait_seconds, abort_batch = delete_message(
            headers,
            DISCORD_CHANNEL_ID,
            anchor_message["id"],
        )
        if deleted:
            deleted_count += 1
            logging.info(
                f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}"
            )
        elif abort_batch:
            return
        if wait_seconds is not None:
            sleep_for_rate_limit(wait_seconds, "delete bucket")
    while True:
        # Fetch messages to delete with pagination
        delete_list, next_last_message_id = fetch_messages_to_delete(
            headers, DISCORD_CHANNEL_ID, webhook_id, author_id, cutoff, last_message_id
        )
        # Stop scanning when a page has no eligible messages to delete.
        # Continuing pagination here can walk indefinitely through history.
        if not delete_list:
            if deleted_count == 0:
                logging.info("No messages to delete.")
            else:
                logging.info("No more messages to delete.")
            break
        for message in delete_list:
            message_id = message["id"]
            message_time = message["timestamp"]
            deleted, wait_seconds, abort_batch = delete_message(
                headers,
                DISCORD_CHANNEL_ID,
                message_id,
            )
            if deleted:
                deleted_count += 1
                logging.info(
                    f"Deleted message {message_id} from {message_time.isoformat()}")
            elif abort_batch:
                logging.info(
                    "Stopping delete batch after an invalid Discord response.")
                return
            if wait_seconds is not None:
                sleep_for_rate_limit(wait_seconds, "delete bucket")
        # Stop when pagination cannot advance any further.
        if next_last_message_id is None or next_last_message_id == last_message_id:
            break
        last_message_id = next_last_message_id
    logging.info(
        f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.")
def send_notification(message: str) -> None:
    """
    Send a notification to the Discord webhook.

    `message` is the notification type key (e.g. "420", "reminder", "test");
    it selects the embed template and is recorded in STATE for the dashboard.
    Failures are logged and recorded in STATE rather than raised.
    """
    if not WEBHOOK_URL:
        logging.error("WEBHOOK_URL not set")
        return
    # Record the attempt before doing any work so the dashboard reflects it
    # even if the send fails part-way through.
    _update_state(
        last_type=message,
        last_attempt_at=datetime.now(),
        last_status_code=None,
        last_error=None,
    )
    # Warm the tzdb cache once per process to avoid repeated CSV parsing
    # and to avoid scanning all pytz timezones every hour.
    try:
        init_tzdb_cache()
    except Exception as e:
        logging.error(f"Failed to initialize tzdb cache: {e}")
    tz_list: list[str] | None = None
    if message == "420":
        # Compute the 4:20 locations up front so they can be stored in STATE
        # and reused by create_embed without a second lookup.
        try:
            cache = get_tzdb_cache()
            if cache is None:
                tz_list = where_is_it_420(load_timezones(), load_countries())
            else:
                tz_list = where_is_it_420(
                    [],
                    [],
                    tz_names=cache.get("tz_names"),
                    tz_to_country_code=cache.get("tz_to_country_code"),
                    country_code_to_name=cache.get("country_code_to_name"),
                )
            _update_state(last_locations=tz_list or [])
        except Exception as e:
            # Best-effort: a timezone lookup failure must not block the send.
            _update_state(last_locations=[], last_error=str(e))
    embed = create_embed(message, tz_list=tz_list)
    data = {"embeds": [embed]}
    try:
        response = requests.post(WEBHOOK_URL, json=data, timeout=10)
        # Discord webhooks return 204 No Content on success.
        if response.status_code == 204:
            logging.info(f"Notification sent: {message}")
            _update_state(last_success_at=datetime.now(),
                          last_status_code=response.status_code)
        else:
            logging.error(
                f"Failed to send notification: {response.status_code} - "
                f"{response.text}"
            )
            _update_state(last_status_code=response.status_code,
                          last_error=response.text)
    except requests.RequestException as e:
        logging.error(f"Error sending notification: {e}")
        _update_state(last_error=str(e))
def schedule_notification(interval: str, at: str, type: str) -> None:
    """Example: schedule.every().hour.at(":15").do(send_notification, "reminder")"""
    if interval == "hour":
        job = schedule.every().hour
    elif interval == "day":
        job = schedule.every().day
    else:
        logging.error(f"Unsupported interval: {interval}")
        return
    job.at(at).do(send_notification, type)
def start_dashboard() -> None:
    """Compatibility hook for tests and optional dashboard startup."""
    dashboard_app = create_app(
        get_state=get_state, get_next_event=get_next_event)
    dashboard_app.run(host="0.0.0.0", port=8080, debug=False, use_reloader=False)
def main() -> None:
    """
    Main function to run the scheduler.
    """
    # Start the dashboard in a separate thread
    dashboard_thread = threading.Thread(target=start_dashboard, daemon=True)
    dashboard_thread.start()
    # Schedule notifications based on the defined SCHEDULED_NOTIFICATIONS
    for minute, msg_type in SCHEDULED_NOTIFICATIONS:
        schedule_notification("hour", f":{minute:02d}", msg_type)
    # Schedule deletion of old messages every 5 minutes
    schedule.every(5).minutes.do(delete_old_messages, 6)
    logging.info("Scheduler started.")
    # Test the notification on startup
    send_notification("test")
    # delete the test message after a short delay to keep the channel clean
    # NOTE(review): this job repeats every minute (schedule has no built-in
    # one-shot) — presumably intentional alongside the 5-minute job; confirm.
    schedule.every(1).minutes.do(delete_old_messages, 1)
    # delete old messages on startup to clean up any previous notifications
    # delete_old_messages(6)
    try:
        while True:
            schedule.run_pending()
            time.sleep(1)  # Check every second
    except KeyboardInterrupt:
        logging.info("Scheduler stopped by user")
    except Exception as e:
        logging.error(f"Unexpected error: {e}")
# Run the scheduler only when executed as a script (not on import).
if __name__ == "__main__":
    main()