"""Hourly Discord "4:20" notifier.

Posts scheduled embeds to a Discord webhook, lists the countries where it is
currently 4:20 (from TimeZoneDB CSV exports plus ``pytz``), and periodically
deletes the webhook's older messages through the Discord bot API. Runtime
status is kept in ``STATE`` for an optional dashboard.
"""
import csv
import logging
import os
import threading
import time
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

import pytz
import requests
import schedule
from dotenv import load_dotenv

from templates import load_templates

# Lazily built tzdb lookup tables; see init_tzdb_cache().
TZDB_CACHE: dict | None = None

# Shared status snapshot (read by the dashboard thread); guard with STATE_LOCK.
STATE_LOCK = threading.Lock()
STATE: dict = {
    "running": True,
    "started_at": datetime.now(),
    "last_type": None,
    "last_attempt_at": None,
    "last_success_at": None,
    "last_status_code": None,
    "last_error": None,
    "last_locations": [],
}

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Load environment variables
load_dotenv()

WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL')
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')
DISCORD_CHANNEL_ID = os.getenv('DISCORD_CHANNEL_ID')
GUILD_ID = os.getenv('DISCORD_GUILD_ID')

# Webhook (and message author) ID whose messages delete_old_messages() removes.
# Can be overridden with the DISCORD_WEBHOOK_ID environment variable.
DEFAULT_WEBHOOK_ID = '1413817504194760766'


def init_tzdb_cache() -> dict:
    """Initialize a cached lookup structure for tzdb data.

    This keeps the hourly scheduler fast by:
    - Building O(1) maps (zone_name -> country_code, country_code -> country_name)
    - Precomputing a list of tz names that exist in the tzdb CSVs and are
      loadable by ``zoneinfo`` (installing the `tzdata` package keeps that
      database up to date)

    The result is stored in ``TZDB_CACHE``; later calls return it unchanged.

    Note: this is intentionally NOT run at import time so tests can monkeypatch
    `load_timezones`/`load_countries` without needing to reset global state.

    Returns:
        dict with keys ``tz_to_country_code``, ``country_code_to_name``,
        ``tz_names`` and ``tz_meta``.
    """
    global TZDB_CACHE
    if TZDB_CACHE is not None:
        return TZDB_CACHE

    timezones = load_timezones()
    countries = load_countries()

    tz_to_country_code: dict[str, str] = {}
    tz_meta: dict[str, dict] = {}
    for tz in timezones:
        zone_name = tz.get("zone_name")
        country_code = tz.get("country_code")
        if not isinstance(zone_name, str) or not zone_name:
            continue
        if not isinstance(country_code, str) or not country_code:
            continue
        tz_to_country_code[zone_name] = country_code
        region, city = split_tz_name(zone_name)
        tz_meta[zone_name] = {
            "zone_name": zone_name,
            "country_code": country_code,
            "region": region,
            "city": city,
        }

    country_code_to_name: dict[str, str] = {}
    for c in countries:
        code = c.get("country_code")
        name = c.get("country_name")
        if code and name:
            # Defensive: drop any surrounding quotes/whitespace for display.
            country_code_to_name[code] = str(name).strip().strip('"')

    # Attach resolved country names onto tz_meta (storage-only for now).
    for meta in tz_meta.values():
        code = meta.get("country_code")
        if isinstance(code, str):
            meta["country_name"] = country_code_to_name.get(code)

    # Vetted tz list: only names that are present in tzdb and loadable by zoneinfo.
    tz_names: list[str] = []
    for zone_name in tz_to_country_code:
        try:
            ZoneInfo(zone_name)
        except (ZoneInfoNotFoundError, ValueError):
            # ValueError: malformed key (e.g. a path-like name) in the CSV.
            continue
        tz_names.append(zone_name)

    TZDB_CACHE = {
        "tz_to_country_code": tz_to_country_code,
        "country_code_to_name": country_code_to_name,
        "tz_names": tz_names,
        "tz_meta": tz_meta,
    }
    return TZDB_CACHE


def split_tz_name(zone_name: str) -> tuple[str, str]:
    """Split an IANA timezone name into (region, city).

    Examples:
    - "America/New_York" -> ("America", "New_York")
    - "America/Argentina/Buenos_Aires" -> ("America", "Argentina/Buenos_Aires")
    - "UTC" -> ("UTC", "")
    """
    region, _, city = zone_name.partition("/")
    return region, city


def _update_state(**updates) -> None:
    """Merge ``updates`` into the shared ``STATE`` under its lock."""
    with STATE_LOCK:
        STATE.update(updates)


def get_state_snapshot() -> dict:
    """Return a shallow copy of ``STATE`` taken under its lock."""
    with STATE_LOCK:
        return dict(STATE)


def get_next_scheduled_event(now: datetime | None = None) -> dict:
    """Return the next notification slot after ``now`` (default: local now).

    Slots are the fixed minute marks listed below. NOTE: main() currently only
    schedules :15 and :20, so :45/:50 may be reported although disabled.

    Returns:
        ``{"at": datetime, "type": str}``
    """
    if now is None:
        now = datetime.now()

    schedule_map = [
        (15, "reminder"),
        (20, "420"),
        (45, "reminder_halftime"),
        (50, "halftime"),
    ]

    candidates: list[tuple[datetime, str]] = []
    for minute, msg_type in schedule_map:
        candidate = now.replace(minute=minute, second=0, microsecond=0)
        if candidate > now:
            candidates.append((candidate, msg_type))

    if not candidates:
        # Past the last mark of this hour: roll over to the next hour's first mark.
        base = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
        next_dt = base.replace(minute=schedule_map[0][0])
        return {"at": next_dt, "type": schedule_map[0][1]}

    next_dt, next_type = min(candidates, key=lambda x: x[0])
    return {"at": next_dt, "type": next_type}


def start_dashboard() -> None:
    """Start the minimal dashboard in a background daemon thread.

    Controlled by DASHBOARD_ENABLED ("0"/"false"/"no" disables it),
    DASHBOARD_HOST and DASHBOARD_PORT. Startup errors are logged, not raised.
    """
    enabled = os.getenv("DASHBOARD_ENABLED", "1").strip().lower() not in {
        "0", "false", "no"}
    if not enabled:
        return

    host = os.getenv("DASHBOARD_HOST", "0.0.0.0")
    try:
        port = int(os.getenv("DASHBOARD_PORT", "8080"))
    except ValueError:
        logging.error(
            f"Dashboard failed to start: invalid DASHBOARD_PORT {os.getenv('DASHBOARD_PORT')!r}")
        return

    def _run() -> None:
        try:
            from dashboard import create_app

            app = create_app(
                get_state=get_state_snapshot,
                get_next_event=lambda: get_next_scheduled_event(),
            )
            app.config["TEMPLATES_PATH"] = os.getenv(
                "TEMPLATES_PATH", "templates.json")
            app.run(host=host, port=port, debug=False, use_reloader=False)
        except Exception as e:
            logging.error(f"Dashboard failed to start: {e}")

    thread = threading.Thread(target=_run, name="dashboard", daemon=True)
    thread.start()


def load_timezones() -> list[dict]:
    """Load timezone rows from the TimeZoneDB ``time_zone.csv`` export.

    Columns: zone_name, country_code, abbreviation, time_start, gmt_offset, dst.
    Rows with fewer than five columns are skipped; a missing ``dst`` column is
    treated as False (previously such a row raised IndexError).
    """
    timezones = []
    with open("tzdb/TimeZoneDB.csv/time_zone.csv", "r", encoding="utf-8", newline="") as f:
        # csv.reader honours quoting, unlike a plain split(",").
        for fields in csv.reader(f):
            if len(fields) < 5:
                continue
            timezones.append({
                "zone_name": fields[0],
                "country_code": fields[1],
                "abbreviation": fields[2],
                "time_start": fields[3],
                "gmt_offset": int(fields[4]),
                "dst": len(fields) > 5 and fields[5] == '1',
            })
    return timezones


def load_countries() -> list[dict]:
    """Load ``country_code``/``country_name`` rows from ``country.csv``.

    Uses csv.reader so quoted names containing commas (e.g.
    "Korea, Republic of") are kept whole instead of being truncated.
    """
    countries = []
    with open("tzdb/TimeZoneDB.csv/country.csv", "r", encoding="utf-8", newline="") as f:
        for fields in csv.reader(f):
            if len(fields) >= 2:
                countries.append({
                    "country_code": fields[0],
                    "country_name": fields[1],
                })
    return countries


def _compute_420_locations() -> list[str]:
    """Return countries where it is 4:20, using the tzdb cache when it is warm."""
    cache = TZDB_CACHE
    if cache is None:
        return where_is_it_420(load_timezones(), load_countries())
    return where_is_it_420(
        [],
        [],
        tz_names=cache.get("tz_names"),
        tz_to_country_code=cache.get("tz_to_country_code"),
        country_code_to_name=cache.get("country_code_to_name"),
    )


def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
    """Create a Discord embed message for notification ``type``.

    Templates come from the JSON file named by TEMPLATES_PATH. For the "420"
    type the matching countries are appended to the text; they are computed
    here only when ``tz_list`` is None. Unknown types yield a red
    "Unknown notification type" embed.
    """
    templates_path = os.getenv("TEMPLATES_PATH", "templates.json")
    messages = load_templates(templates_path)

    if type in messages:
        # Work on a copy: load_templates may hand back shared data, and the
        # in-place edits below would otherwise leak into later calls.
        msg = dict(messages[type])
        image_url = msg.get("image_url")
        if isinstance(image_url, str) and image_url:
            msg["image"] = {"url": image_url}

        if type == "420":
            # Check where it's 4:20
            if tz_list is None:
                tz_list = _compute_420_locations()
            if tz_list:
                tz_str = "\n".join(tz_list)
                msg["text"] += f"\nIt's 4:20 in:\n{tz_str}"
    else:
        msg = {"text": "Unknown notification type", "color": 0xFF0000}

    embed = {
        "title": type.replace("_", " ").capitalize(),
        "description": msg["text"],
        "image": msg.get("image"),
        "color": msg["color"],
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "footer": {"text": "THC - Toke Hash Coordinated"}
    }
    return embed


def get_discord_headers() -> dict[str, str]:
    """Return bot-authenticated JSON headers for the Discord REST API."""
    return {
        "Authorization": f"Bot {DISCORD_BOT_TOKEN}",
        "Content-Type": "application/json",
    }


def parse_message_timestamp(message: dict) -> datetime:
    """Parse a message's ISO-8601 ``timestamp`` field.

    A trailing "Z" is mapped to "+00:00" so the result stays UTC-aware
    (stripping it would yield a naive datetime that .timestamp() reads as
    local time).
    """
    raw = message["timestamp"]
    if raw.endswith("Z"):
        raw = raw[:-1] + "+00:00"
    return datetime.fromisoformat(raw)


def build_delete_entry(message: dict) -> dict:
    """Reduce a message payload to the ``id`` and parsed ``timestamp`` needed for deletion."""
    return {
        "id": message.get("id"),
        "timestamp": parse_message_timestamp(message),
    }


def parse_float(value: str | int | float | None) -> float | None:
    """Convert ``value`` to float, returning None when it is missing or invalid."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def should_delete_message(
    message: dict,
    webhook_id: str,
    author_id: str,
    five_minutes_ago_timestamp: int,
) -> bool:
    """Return True if ``message`` is at/before the cutoff and was sent by the webhook/author."""
    message_timestamp = int(parse_message_timestamp(message).timestamp())
    return (
        message_timestamp <= five_minutes_ago_timestamp
        and message.get("webhook_id") == webhook_id
        and message.get("author", {}).get("id") == author_id
    )


def get_rate_limit_retry_after(response: requests.Response) -> float | None:
    """Return seconds to wait after a 429.

    Checks, in order, the Retry-After header, X-RateLimit-Reset-After, then
    the JSON body's ``retry_after``. Returns None if none is usable.
    """
    header_retry_after = parse_float(response.headers.get("Retry-After"))
    if header_retry_after is not None:
        return header_retry_after

    reset_after = parse_float(response.headers.get("X-RateLimit-Reset-After"))
    if reset_after is not None:
        return reset_after

    try:
        payload = response.json()
    except ValueError:
        return None
    if not isinstance(payload, dict):
        return None

    return parse_float(payload.get("retry_after"))


def get_bucket_exhausted_delay(response: requests.Response) -> float | None:
    """Return the bucket reset delay when X-RateLimit-Remaining is "0", else None."""
    remaining = response.headers.get("X-RateLimit-Remaining")
    if remaining != "0":
        return None
    return parse_float(response.headers.get("X-RateLimit-Reset-After"))


def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None:
    """Log and sleep for ``delay_seconds``; non-positive delays are ignored."""
    if delay_seconds <= 0:
        return
    logging.info(
        f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).")
    time.sleep(delay_seconds)


def find_last_message_by_author(
    headers: dict[str, str],
    guild_id: str,
    channel_id: str,
    author_id: str,
) -> dict | None:
    """Find the newest indexed message for the author in the target guild channel.

    Retries up to three times on 202 (index not ready) and 429 (rate limited).
    Returns None on any other failure or when nothing is found.
    """
    url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search"
    params = [
        ("author_id", author_id),
        ("author_type", "webhook"),
        ("channel_id", channel_id),
        ("sort_by", "timestamp"),
        ("sort_order", "desc"),
        ("limit", "10"),
    ]

    for _ in range(3):
        try:
            response = requests.get(
                url, headers=headers, params=params, timeout=10)
        except requests.RequestException as e:
            logging.error(f"Error searching guild messages: {e}")
            return None

        if response.status_code == 202:
            # Index still building; tolerate a malformed body and default to 1s.
            try:
                payload = response.json()
            except ValueError:
                payload = {}
            retry_after = (
                parse_float(payload.get("retry_after"))
                if isinstance(payload, dict) else None
            ) or 1.0
            logging.info(
                f"Guild search index not ready. Retrying after {retry_after} seconds."
            )
            time.sleep(retry_after)
            continue

        if response.status_code == 429:
            retry_after = get_rate_limit_retry_after(response) or 1.0
            sleep_for_rate_limit(retry_after, "guild search")
            continue

        if response.status_code != 200:
            logging.error(
                f"Failed to search guild messages: {response.status_code} - {response.text}"
            )
            return None

        # Results are grouped; the first entry of the first group is the newest hit.
        message_groups = response.json().get("messages", [])
        if not message_groups or not message_groups[0]:
            return None
        return message_groups[0][0]

    logging.error("Guild search index did not become available in time.")
    return None


def fetch_messages_to_delete(
    headers: dict,
    channel_id: str,
    webhook_id: str,
    author_id: str,
    cutoff: int,
    last_message_id: str | None = None,
) -> tuple[list[dict], str | None]:
    """Fetch one page of channel messages and pick those eligible for deletion.

    A message qualifies when it is at/before ``cutoff`` and was sent by the
    webhook/author. Pagination uses ``before=last_message_id``.

    Returns:
        (entries to delete, ID of the last message on the page). On failure
        the list is empty and ``last_message_id`` is returned unchanged.
    """
    url = f"https://discord.com/api/v10/channels/{channel_id}/messages"
    params: dict[str, str | int] = {
        "limit": 100,  # Maximum allowed by Discord API
    }
    if last_message_id:
        params["before"] = last_message_id

    try:
        for _ in range(3):
            response = requests.get(url, headers=headers, params=params, timeout=10)
            if response.status_code == 429:
                retry_after = get_rate_limit_retry_after(response) or 1.0
                sleep_for_rate_limit(retry_after, "channel message fetch")
                continue
            break
        else:
            logging.error(
                "Failed to fetch messages after repeated rate limits.")
            return [], last_message_id

        if response.status_code != 200:
            logging.error(
                f"Failed to fetch messages: {response.status_code} - {response.text}")
            return [], last_message_id

        delete_list: list[dict] = []
        new_last_message_id = None
        for message in response.json():
            new_last_message_id = message.get("id")
            if should_delete_message(message, webhook_id, author_id, cutoff):
                delete_list.append(build_delete_entry(message))
                # Limit the list to 100 items
                if len(delete_list) >= 100:
                    break
        return delete_list, new_last_message_id
    except requests.RequestException as e:
        logging.error(f"Error fetching messages: {e}")
        return [], last_message_id


def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]:
    """
    Delete a single message from the channel.

    Returns:
        tuple[bool, float | None, bool]:
            - whether the delete succeeded
            - how long to wait before the next request, if any
            - whether to abort the batch (auth failure, where further requests
              would be invalid, or a network error)
    """
    delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}"
    try:
        delete_response = requests.delete(delete_url, headers=headers, timeout=10)
    except requests.RequestException as e:
        # Previously uncaught: it escaped into the scheduler and stopped it.
        logging.error(f"Error deleting message {message_id}: {e}")
        return False, None, True

    if delete_response.status_code == 204:
        return True, get_bucket_exhausted_delay(delete_response), False

    if delete_response.status_code == 429:
        retry_after = get_rate_limit_retry_after(delete_response) or 1.0
        scope = delete_response.headers.get("X-RateLimit-Scope", "unknown")
        is_global = delete_response.headers.get(
            "X-RateLimit-Global", "false").lower() == "true"
        logging.warning(
            "Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs",
            message_id,
            scope,
            is_global,
            retry_after,
        )
        return False, retry_after, False

    if delete_response.status_code in {401, 403}:
        logging.error(
            "Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.",
            message_id,
            delete_response.status_code,
            delete_response.text,
        )
        return False, None, True

    logging.error(
        f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}")
    return False, None, False


def delete_old_messages(minutes: int = 6) -> None:
    """Delete the webhook's messages that are at least ``minutes`` minutes old.

    The newest indexed message by the webhook author (found via guild search)
    is the pagination anchor. It is deleted if eligible, then older history is
    paged with ``before=`` and every eligible message is deleted one by one.
    The wait between requests follows Discord's rate-limit headers.

    The webhook/author ID defaults to ``DEFAULT_WEBHOOK_ID`` and can be
    overridden with the DISCORD_WEBHOOK_ID environment variable.
    """
    if not DISCORD_BOT_TOKEN or not DISCORD_CHANNEL_ID or not GUILD_ID:
        logging.error(
            "DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set")
        return

    headers = get_discord_headers()

    # Messages at or before this Unix timestamp are eligible for deletion.
    cutoff_timestamp = datetime.now() - timedelta(minutes=minutes)
    cutoff = int(cutoff_timestamp.timestamp())

    webhook_id = os.getenv("DISCORD_WEBHOOK_ID") or DEFAULT_WEBHOOK_ID
    # Both filters use the same ID: the webhook is reported as the message author.
    author_id = webhook_id

    last_author_message = find_last_message_by_author(
        headers,
        GUILD_ID,
        DISCORD_CHANNEL_ID,
        author_id,
    )
    if last_author_message is None:
        logging.info("No indexed messages found for the target author.")
        return

    last_message_id = last_author_message.get("id")
    if not last_message_id:
        logging.info("Search result did not contain a message id.")
        return

    deleted_count = 0
    # The anchor itself is excluded by the `before=` paging below, so handle it here.
    if should_delete_message(last_author_message, webhook_id, author_id, cutoff):
        anchor_message = build_delete_entry(last_author_message)
        deleted, wait_seconds, abort_batch = delete_message(
            headers,
            DISCORD_CHANNEL_ID,
            anchor_message["id"],
        )
        if deleted:
            deleted_count += 1
            logging.info(
                f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}"
            )
        elif abort_batch:
            return
        if wait_seconds is not None:
            sleep_for_rate_limit(wait_seconds, "delete bucket")

    while True:
        # Fetch messages to delete with pagination
        delete_list, next_last_message_id = fetch_messages_to_delete(
            headers, DISCORD_CHANNEL_ID, webhook_id, author_id, cutoff, last_message_id
        )

        if not delete_list:
            if deleted_count == 0:
                if next_last_message_id is None or next_last_message_id == last_message_id:
                    logging.info("No messages to delete.")
                    break
                # Nothing deleted yet and the page advanced: keep walking back.
                logging.info(
                    "No messages deleted in this batch, but more messages may exist... Continuing pagination.")
                last_message_id = next_last_message_id
                continue
            logging.info("No more messages to delete.")
            break

        for message in delete_list:
            message_id = message["id"]
            message_time = message["timestamp"]
            deleted, wait_seconds, abort_batch = delete_message(
                headers,
                DISCORD_CHANNEL_ID,
                message_id,
            )
            if deleted:
                deleted_count += 1
                logging.info(
                    f"Deleted message {message_id} from {message_time.isoformat()}")
            elif abort_batch:
                logging.info(
                    "Stopping delete batch after an invalid Discord response.")
                return

            if wait_seconds is not None:
                sleep_for_rate_limit(wait_seconds, "delete bucket")

        # Stop when the page did not advance (end of history).
        if next_last_message_id is None or next_last_message_id == last_message_id:
            break
        last_message_id = next_last_message_id

    logging.info(
        f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.")


def send_notification(message: str) -> None:
    """Post the embed for notification type ``message`` to the Discord webhook.

    Records the attempt, outcome, HTTP status and any error in ``STATE``. For
    "420" the matching countries are computed first. If that lookup fails, the
    notification is still sent, just without the list.
    """
    if not WEBHOOK_URL:
        logging.error("WEBHOOK_URL not set")
        return

    _update_state(
        last_type=message,
        last_attempt_at=datetime.now(),
        last_status_code=None,
        last_error=None,
    )

    # Warm the tzdb cache once per process to avoid repeated CSV parsing
    # and to avoid scanning all pytz timezones every hour.
    try:
        init_tzdb_cache()
    except Exception as e:
        logging.error(f"Failed to initialize tzdb cache: {e}")

    tz_list: list[str] | None = None
    if message == "420":
        try:
            tz_list = _compute_420_locations()
            _update_state(last_locations=tz_list)
        except Exception as e:
            logging.error(f"Failed to determine 4:20 locations: {e}")
            _update_state(last_locations=[], last_error=str(e))
            # An empty list (not None) stops create_embed retrying the same lookup.
            tz_list = []

    try:
        embed = create_embed(message, tz_list=tz_list)
    except Exception as e:
        logging.error(f"Failed to build notification embed: {e}")
        _update_state(last_error=str(e))
        return

    data = {"embeds": [embed]}

    try:
        response = requests.post(WEBHOOK_URL, json=data, timeout=10)
        if response.status_code == 204:
            logging.info(f"Notification sent: {message}")
            _update_state(last_success_at=datetime.now(),
                          last_status_code=response.status_code)
        else:
            logging.error(
                f"Failed to send notification: {response.status_code} - "
                f"{response.text}"
            )
            _update_state(last_status_code=response.status_code,
                          last_error=response.text)
    except requests.RequestException as e:
        logging.error(f"Error sending notification: {e}")
        _update_state(last_error=str(e))


def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
    """Get timezone info by name (linear scan; first match or None)."""
    return next((tz for tz in timezones if tz["zone_name"] == tz_name), None)


def get_country_info(country_code: str, countries: list[dict]) -> dict | None:
    """Get country info by country code (linear scan; first match or None)."""
    return next((c for c in countries if c["country_code"] == country_code), None)


def where_is_it_420(
    timezones: list[dict],
    countries: list[dict],
    tz_names: list[str] | None = None,
    tz_to_country_code: dict[str, str] | None = None,
    country_code_to_name: dict[str, str] | None = None,
) -> list[str]:
    """Return the countries whose local time is currently 4:MM or 16:MM.

    MM is the host's current local minute. All zones are evaluated at one
    instant, and a zone matches only if its local minute equals the host's.
    So when this runs at HH:20, as the scheduler does, every match really is
    4:20 or 16:20. Zones offset from the host by a non-whole hour (e.g.
    +05:30, +05:45) are excluded rather than reported at 4:50 or 16:50.

    Args:
        timezones: TimeZoneDB rows; used only when ``tz_to_country_code`` is None.
        countries: country rows; used only when ``country_code_to_name`` is None.
        tz_names: zone names to check; defaults to ``pytz.all_timezones``.
        tz_to_country_code: prebuilt zone -> country code map.
        country_code_to_name: prebuilt country code -> name map.

    Returns:
        list[str]: country names in first-seen order, without duplicates.
        Zones with no known country are skipped.
    """
    # Build fast lookup dicts if not provided.
    if tz_to_country_code is None:
        tz_to_country_code = {}
        for tz in timezones:
            zone_name = tz.get("zone_name")
            country_code = tz.get("country_code")
            if isinstance(zone_name, str) and isinstance(country_code, str):
                tz_to_country_code[zone_name] = country_code

    if country_code_to_name is None:
        country_code_to_name = {}
        for c in countries:
            code = c.get("country_code")
            name = c.get("country_name")
            if isinstance(code, str) and name is not None:
                country_code_to_name[code] = str(name).strip().strip('"')

    names_to_check = tz_names if tz_names is not None else pytz.all_timezones

    reference_utc = datetime.now(timezone.utc)
    host_minute = reference_utc.astimezone().minute

    results: list[str] = []
    seen: set[str] = set()

    for tz_name in names_to_check:
        try:
            tz_obj = pytz.timezone(tz_name)
        except Exception:
            continue

        local_now = reference_utc.astimezone(tz_obj)
        if local_now.hour not in (4, 16) or local_now.minute != host_minute:
            continue

        country_code = tz_to_country_code.get(tz_name)
        if not country_code:
            continue

        country_name = country_code_to_name.get(country_code)
        if not country_name:
            continue

        if country_name in seen:
            continue
        seen.add(country_name)
        results.append(country_name)

    return results


def _run_job_safely(job, *args) -> None:
    """Run a scheduled job, logging instead of raising any exception.

    ``schedule`` does not catch job exceptions; they propagate out of
    ``run_pending()`` and previously ended the loop in main().
    """
    try:
        job(*args)
    except Exception:
        logging.exception(
            f"Scheduled job {getattr(job, '__name__', job)!r} failed")


def main() -> None:
    """Register the scheduled jobs and run the scheduler loop until interrupted."""
    # start_dashboard()

    # Schedule notifications
    schedule.every().hour.at(":15").do(_run_job_safely, send_notification, "reminder")
    schedule.every().hour.at(":20").do(_run_job_safely, send_notification, "420")
    # schedule.every().hour.at(":45").do(_run_job_safely, send_notification, "reminder_halftime")
    # schedule.every().hour.at(":50").do(_run_job_safely, send_notification, "halftime")

    # Schedule deletion of old messages every 3 minutes
    schedule.every(3).minutes.do(_run_job_safely, delete_old_messages, 6)

    logging.info("Scheduler started.")

    # Test the notification on startup
    # send_notification("420")

    # delete old messages on startup to clean up any previous notifications
    _run_job_safely(delete_old_messages, 60)

    try:
        while True:
            schedule.run_pending()
            time.sleep(1)  # Check every second
    except KeyboardInterrupt:
        logging.info("Scheduler stopped by user")
    except Exception as e:
        logging.error(f"Unexpected error: {e}")


if __name__ == "__main__":
    main()