diff --git a/.gitignore b/.gitignore index 1957995..02f2aed 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .github/copilot-instructions.md +.github/instructions/* # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/Dockerfile b/Dockerfile index 13d75dd..3135939 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,9 @@ RUN pip install --no-cache-dir -r requirements.txt # Copy application code COPY . . +# Dashboard (Flask) port +EXPOSE 8080 + # Create a non-root user RUN addgroup -S appgroup && adduser -S appuser -G appgroup USER appuser diff --git a/README.md b/README.md index c85de87..7ccaa17 100644 --- a/README.md +++ b/README.md @@ -11,16 +11,21 @@ This Python application sends notifications to a Discord channel via webhook eve ``` 2. Set up your Discord webhook: - - Go to your Discord server settings > Integrations > Webhooks - Create a new webhook and copy the URL -3. Update the `.env` file with your webhook URL: +3. Update the `.env` file with your webhook URL, bot token, channel ID, and guild ID: ```text DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/your_webhook_id/your_webhook_token + DISCORD_BOT_TOKEN=your_bot_token + DISCORD_CHANNEL_ID=your_channel_id + DISCORD_GUILD_ID=your_guild_id ``` + - To get a bot token, create a Discord bot in the [Discord Developer Portal](https://discord.com/developers/applications) and invite it to your server with the `Manage Messages` permission. + - The channel ID can be found by enabling Developer Mode in Discord and right-clicking the channel name. + ## Running the App Run the application: @@ -31,6 +36,19 @@ python main.py The app will run continuously and send notifications at the scheduled times. +### Dashboard + +By default, a minimal dashboard is available at `http://localhost:8080/`. + +You can disable it by setting `DASHBOARD_ENABLED=0`. + +### Admin + +You can edit the embed message templates at `http://localhost:8080/admin`. + +- Templates are saved to `templates.json` by default. +- Override the location with `TEMPLATES_PATH=/path/to/templates.json`. + ## Requirements - Python 3.6+ diff --git a/dashboard.py b/dashboard.py new file mode 100644 index 0000000..75871d5 --- /dev/null +++ b/dashboard.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime + +from flask import Flask, request + +from templates import load_templates, parse_color, save_templates + + +def _fmt_dt(dt: datetime | None) -> str: + if dt is None: + return "—" + try: + return dt.isoformat(timespec="seconds") + except Exception: + return str(dt) + + +def create_app( + *, + get_state: Callable[[], dict], + get_next_event: Callable[[], dict], +) -> Flask: + app = Flask(__name__) + + @app.get("/") + def index() -> str: + state = get_state() or {} + next_event = get_next_event() or {} + + locations = state.get("last_locations") or [] + locations_html = "".join(f"
Total: {len(locations)}
" + f"Edits are saved to the templates JSON file on disk.
" + "" + "" + ) + + @app.post("/admin") + def admin_post() -> tuple[str, int]: + templates_path = app.config.get("TEMPLATES_PATH", "templates.json") + current = load_templates(templates_path) + + errors: list[str] = [] + updated: dict[str, dict] = {} + for key in current.keys(): + text = request.form.get(f"{key}__text", "").strip() + color_raw = request.form.get(f"{key}__color", "").strip() + image_url = request.form.get(f"{key}__image_url", "").strip() + + if not text: + errors.append(f"{key}: text is required") + continue + try: + color = parse_color(color_raw) + except Exception as e: + errors.append(f"{key}: invalid color ({e})") + continue + + tpl: dict = {"text": text, "color": color} + if image_url: + tpl["image_url"] = image_url + updated[key] = tpl + + if errors: + msg = "Errors:
{msg}
Saved.
" + "" + "", + 200, + ) + + return app diff --git a/docker-compose.yml b/docker-compose.yml index 0f9be66..59f3d1b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,10 @@ services: container_name: thc-webhook-app environment: - DISCORD_WEBHOOK_URL=${DISCORD_WEBHOOK_URL} + - DASHBOARD_PORT=8080 restart: unless-stopped + ports: + - "8080:8080" logging: driver: json-file options: diff --git a/main.py b/main.py index 0f65163..6438d38 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,32 @@ import os import pytz -from datetime import datetime +from datetime import datetime, timedelta import time import logging +import threading +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError import requests import schedule from dotenv import load_dotenv +from templates import load_templates + + +TZDB_CACHE: dict | None = None + + +STATE_LOCK = threading.Lock() +STATE: dict = { + "running": True, + "started_at": datetime.now(), + "last_type": None, + "last_attempt_at": None, + "last_success_at": None, + "last_status_code": None, + "last_error": None, + "last_locations": [], +} + # Configure logging logging.basicConfig( level=logging.INFO, @@ -17,6 +37,157 @@ logging.basicConfig( load_dotenv() WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL') +DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN') +DISCORD_CHANNEL_ID = os.getenv('DISCORD_CHANNEL_ID') +GUILD_ID = os.getenv('DISCORD_GUILD_ID') + + +def init_tzdb_cache() -> dict: + """Initialize a cached lookup structure for tzdb data. + + This keeps the hourly scheduler fast by: + - Building O(1) maps (zone_name -> country_code, country_code -> country_name) + - Precomputing a list of tz names that exist in both tzdb CSVs and `pytz` + + Note: this is intentionally NOT run at import time so tests can monkeypatch + `load_timezones`/`load_countries` without needing to reset global state. + """ + global TZDB_CACHE + if TZDB_CACHE is not None: + return TZDB_CACHE + + timezones = load_timezones() + countries = load_countries() + + tz_to_country_code: dict[str, str] = {} + tz_meta: dict[str, dict] = {} + for tz in timezones: + zone_name = tz.get("zone_name") + country_code = tz.get("country_code") + if not isinstance(zone_name, str) or not zone_name: + continue + if not isinstance(country_code, str) or not country_code: + continue + + tz_to_country_code[zone_name] = country_code + region, city = split_tz_name(zone_name) + tz_meta[zone_name] = { + "zone_name": zone_name, + "country_code": country_code, + "region": region, + "city": city, + } + + country_code_to_name: dict[str, str] = {} + for c in countries: + code = c.get("country_code") + name = c.get("country_name") + if code and name: + country_code_to_name[code] = str(name).strip().strip('"') + + # Attach resolved country names onto tz_meta (storage-only for now). + for zone_name, meta in tz_meta.items(): + code = meta.get("country_code") + if isinstance(code, str): + meta["country_name"] = country_code_to_name.get(code) + + # Vetted tz list: only names that are present in tzdb and loadable by zoneinfo. + # Installing the `tzdata` package keeps this mapping up-to-date. + tz_names: list[str] = [] + for zone_name in tz_to_country_code.keys(): + try: + ZoneInfo(zone_name) + except ZoneInfoNotFoundError: + continue + tz_names.append(zone_name) + + TZDB_CACHE = { + "tz_to_country_code": tz_to_country_code, + "country_code_to_name": country_code_to_name, + "tz_names": tz_names, + "tz_meta": tz_meta, + } + return TZDB_CACHE + + +def split_tz_name(zone_name: str) -> tuple[str, str]: + """Split an IANA timezone name into (region, city). + + Examples: + - "America/New_York" -> ("America", "New_York") + - "America/Argentina/Buenos_Aires" -> ("America", "Argentina/Buenos_Aires") + - "UTC" -> ("UTC", "") + """ + if "/" not in zone_name: + return zone_name, "" + region, rest = zone_name.split("/", 1) + return region, rest + + +def _update_state(**updates) -> None: + with STATE_LOCK: + STATE.update(updates) + + +def get_state_snapshot() -> dict: + with STATE_LOCK: + return dict(STATE) + + +def get_next_scheduled_event(now: datetime | None = None) -> dict: + """Return the next scheduled notification time/type based on known minute marks.""" + if now is None: + now = datetime.now() + + schedule_map = [ + (15, "reminder"), + (20, "420"), + (45, "reminder_halftime"), + (50, "halftime"), + ] + + candidates: list[tuple[datetime, str]] = [] + for minute, msg_type in schedule_map: + candidate = now.replace(minute=minute, second=0, microsecond=0) + if candidate > now: + candidates.append((candidate, msg_type)) + + if not candidates: + base = (now + timedelta(hours=1)).replace(minute=0, + second=0, microsecond=0) + next_dt = base.replace(minute=schedule_map[0][0]) + return {"at": next_dt, "type": schedule_map[0][1]} + + next_dt, next_type = min(candidates, key=lambda x: x[0]) + return {"at": next_dt, "type": next_type} + + +def start_dashboard() -> None: + """Start the minimal dashboard in a background thread.""" + enabled = os.getenv("DASHBOARD_ENABLED", "1").strip().lower() not in { + "0", "false", "no"} + if not enabled: + return + + host = os.getenv("DASHBOARD_HOST", "0.0.0.0") + port = int(os.getenv("DASHBOARD_PORT", "8080")) + + def _run() -> None: + try: + from dashboard import create_app + + app = create_app( + get_state=get_state_snapshot, + get_next_event=lambda: get_next_scheduled_event(), + ) + app.config["TEMPLATES_PATH"] = os.getenv( + "TEMPLATES_PATH", "templates.json") + app.run(host=host, port=port, debug=False, use_reloader=False) + except Exception as e: + logging.error(f"Dashboard failed to start: {e}") + + thread = threading.Thread(target=_run, name="dashboard", daemon=True) + thread.start() def load_timezones() -> list[dict]: @@ -57,36 +228,38 @@ def load_countries() -> list[dict]: return countries -def create_embed(type: str) -> dict: +def create_embed(type: str, tz_list: list[str] | None = None) -> dict: """ Create a Discord embed message. """ - color_orange = 0xe67e22 - color_green = 0x2ecc71 - messages = { - "reminder_halftime": {"text": "Half-time in 5 minutes!", "color": color_orange}, - "halftime": {"text": "Half-time!", "color": color_green}, - "reminder": {"text": "This is your 5 minute reminder to 420!", "color": color_orange}, - "420": {"text": "Blaze it!", "color": color_green}, - } - timezones = load_timezones() - countries = load_countries() + templates_path = os.getenv("TEMPLATES_PATH", "templates.json") + messages = load_templates(templates_path) + cache = TZDB_CACHE + timezones = load_timezones() if cache is None else [] + countries = load_countries() if cache is None else [] if type in messages: msg = messages[type] - if type in ["halftime", "420"]: - # Add an image for halftime and 420 - msg["image"] = { - "url": "https://copyparty.allucanget.biz/img/weed.png"} + image_url = msg.get("image_url") + if isinstance(image_url, str) and image_url: + msg["image"] = {"url": image_url} if type == "420": # Check where it's 4:20 - tz_list = where_is_it_420(timezones, countries) + if tz_list is None: + if cache is None: + tz_list = where_is_it_420(timezones, countries) + else: + tz_list = where_is_it_420( + timezones, + countries, + tz_names=cache.get("tz_names"), + tz_to_country_code=cache.get("tz_to_country_code"), + country_code_to_name=cache.get("country_code_to_name"), + ) if tz_list: tz_str = "\n".join(tz_list) msg["text"] += f"\nIt's 4:20 in:\n{tz_str}" - else: - msg["text"] += " It's not 4:20 anywhere right now." else: - msg = {"text": "Unknown notification type", "color": 0xFF0000} + msg = {"text": "", "color": 0xFF0000} embed = { "title": type.replace("_", " ").capitalize(), "description": msg["text"], @@ -98,6 +271,347 @@ def create_embed(type: str) -> dict: return embed +def get_discord_headers() -> dict[str, str]: + return { + "Authorization": f"Bot {DISCORD_BOT_TOKEN}", + "Content-Type": "application/json", + } + + +def parse_message_timestamp(message: dict) -> datetime: + return datetime.fromisoformat(message["timestamp"].replace("Z", "")) + + +def build_delete_entry(message: dict) -> dict: + return { + "id": message.get("id"), + "timestamp": parse_message_timestamp(message), + } + + +def parse_float(value: str | int | float | None) -> float | None: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def should_delete_message( + message: dict, + webhook_id: str, + author_id: str, + five_minutes_ago_timestamp: int, +) -> bool: + message_timestamp = int(parse_message_timestamp(message).timestamp()) + return ( + message_timestamp <= five_minutes_ago_timestamp + and message.get("webhook_id") == webhook_id + and message.get("author", {}).get("id") == author_id + ) + + +def get_rate_limit_retry_after(response: requests.Response) -> float | None: + header_retry_after = parse_float(response.headers.get("Retry-After")) + if header_retry_after is not None: + return header_retry_after + + reset_after = parse_float(response.headers.get("X-RateLimit-Reset-After")) + if reset_after is not None: + return reset_after + + try: + payload = response.json() + except ValueError: + return None + + return parse_float(payload.get("retry_after")) + + +def get_bucket_exhausted_delay(response: requests.Response) -> float | None: + remaining = response.headers.get("X-RateLimit-Remaining") + if remaining != "0": + return None + return parse_float(response.headers.get("X-RateLimit-Reset-After")) + + +def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None: + if delay_seconds <= 0: + return + logging.info( + f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).") + time.sleep(delay_seconds) + + +def find_last_message_by_author( + headers: dict[str, str], + guild_id: str, + channel_id: str, + author_id: str, +) -> dict | None: + """Find the newest indexed message for the author in the target guild channel.""" + url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search" + params = [ + ("author_id", author_id), + ("author_type", "webhook"), + ("channel_id", channel_id), + ("sort_by", "timestamp"), + ("sort_order", "desc"), + ("limit", "10"), + ] + + for _ in range(3): + try: + response = requests.get( + url, headers=headers, params=params, timeout=10) + except requests.RequestException as e: + logging.error(f"Error searching guild messages: {e}") + return None + + if response.status_code == 202: + payload = response.json() + retry_after = float(payload.get("retry_after", 1) or 1) + logging.info( + f"Guild search index not ready. Retrying after {retry_after} seconds." + ) + time.sleep(retry_after) + continue + + if response.status_code == 429: + retry_after = get_rate_limit_retry_after(response) or 1.0 + sleep_for_rate_limit(retry_after, "guild search") + continue + + if response.status_code != 200: + logging.error( + f"Failed to search guild messages: {response.status_code} - {response.text}" + ) + return None + + message_groups = response.json().get("messages", []) + if not message_groups or not message_groups[0]: + return None + return message_groups[0][0] + + logging.error("Guild search index did not become available in time.") + return None + + +def fetch_messages_to_delete(headers: dict, channel_id: str, webhook_id: str, author_id: str, cutoff: int, last_message_id: str | None = None) -> tuple[list[dict], str | None]: + """ + Fetch messages from the channel that are older than the cutoff timestamp and sent by the webhook. + Uses pagination with the 'before' parameter to resume from the last processed message. + Returns a tuple of (list of messages to delete, last message ID for pagination). + """ + url = f"https://discord.com/api/v10/channels/{channel_id}/messages" + params: dict[str, str | int] = { + "limit": 100, # Maximum allowed by Discord API + } + + # Use the 'before' parameter for pagination if last_message_id is provided + if last_message_id: + params["before"] = last_message_id + + try: + for _ in range(3): + response = requests.get(url, headers=headers, + params=params, timeout=10) + + if response.status_code == 429: + retry_after = get_rate_limit_retry_after(response) or 1.0 + sleep_for_rate_limit(retry_after, "channel message fetch") + continue + + break + else: + logging.error( + "Failed to fetch messages after repeated rate limits.") + return [], last_message_id + + if response.status_code == 200: + messages = response.json() + delete_list = [] + new_last_message_id = None + + for message in messages: + new_last_message_id = message.get("id") + + if should_delete_message( + message, + webhook_id, + author_id, + cutoff, + ): + delete_list.append(build_delete_entry(message)) + + # Limit the list to 100 items + if len(delete_list) >= 100: + break + + return delete_list, new_last_message_id + else: + logging.error( + f"Failed to fetch messages: {response.status_code} - {response.text}") + return [], last_message_id + except requests.RequestException as e: + logging.error(f"Error fetching messages: {e}") + return [], last_message_id + + +def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]: + """ + Delete a single message from the channel. + + Returns: + tuple[bool, float | None, bool]: + - whether the delete succeeded + - how long to wait before the next request, if any + - whether to abort the batch because further requests would be invalid + """ + delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}" + delete_response = requests.delete(delete_url, headers=headers, timeout=10) + + if delete_response.status_code == 204: + return True, get_bucket_exhausted_delay(delete_response), False + + if delete_response.status_code == 429: + retry_after = get_rate_limit_retry_after(delete_response) or 1.0 + scope = delete_response.headers.get("X-RateLimit-Scope", "unknown") + is_global = delete_response.headers.get( + "X-RateLimit-Global", "false").lower() == "true" + logging.warning( + "Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs", + message_id, + scope, + is_global, + retry_after, + ) + return False, retry_after, False + + if delete_response.status_code in {401, 403}: + logging.error( + "Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.", + message_id, + delete_response.status_code, + delete_response.text, + ) + return False, None, True + + logging.error( + f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}") + return False, None, False + + +def delete_old_messages(minutes: int = 6) -> None: + """ + Delete all messages sent by the webhook in the last `minutes` minutes. + Uses a dynamic slowdown to avoid hitting Discord API rate limits and pagination to fetch all messages. + """ + if not DISCORD_BOT_TOKEN or not DISCORD_CHANNEL_ID or not GUILD_ID: + logging.error( + "DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set") + return + + headers = get_discord_headers() + + # Calculate the time `minutes` minutes ago + cutoff_timestamp = datetime.now() - timedelta(minutes=minutes) + cutoff = int(cutoff_timestamp.timestamp()) + webhook_id = '1413817504194760766' + author_id = '1413817504194760766' + + last_author_message = find_last_message_by_author( + headers, + GUILD_ID, + DISCORD_CHANNEL_ID, + author_id, + ) + if last_author_message is None: + logging.info("No indexed messages found for the target author.") + return + + last_message_id = last_author_message.get("id") + if not last_message_id: + logging.info("Search result did not contain a message id.") + return + + deleted_count = 0 + + if should_delete_message( + last_author_message, + webhook_id, + author_id, + cutoff, + ): + anchor_message = build_delete_entry(last_author_message) + deleted, wait_seconds, abort_batch = delete_message( + headers, + DISCORD_CHANNEL_ID, + anchor_message["id"], + ) + if deleted: + deleted_count += 1 + logging.info( + f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}" + ) + elif abort_batch: + return + + if wait_seconds is not None: + sleep_for_rate_limit(wait_seconds, "delete bucket") + + while True: + # Fetch messages to delete with pagination + delete_list, next_last_message_id = fetch_messages_to_delete( + headers, DISCORD_CHANNEL_ID, webhook_id, author_id, cutoff, last_message_id + ) + + # Exit loop if no more messages to delete + if not delete_list: + if deleted_count == 0: + if next_last_message_id is None or next_last_message_id == last_message_id: + logging.info("No messages to delete.") + break + else: + logging.info( + "No messages deleted in this batch, but more messages may exist... Continuing pagination.") + last_message_id = next_last_message_id + continue + else: + logging.info("No more messages to delete.") + break + + for message in delete_list: + message_id = message["id"] + message_time = message["timestamp"] + + deleted, wait_seconds, abort_batch = delete_message( + headers, + DISCORD_CHANNEL_ID, + message_id, + ) + + if deleted: + deleted_count += 1 + logging.info( + f"Deleted message {message_id} from {message_time.isoformat()}") + elif abort_batch: + logging.info( + "Stopping delete batch after an invalid Discord response.") + return + + if wait_seconds is not None: + sleep_for_rate_limit(wait_seconds, "delete bucket") + + if next_last_message_id is None or next_last_message_id == last_message_id: + break + last_message_id = next_last_message_id + + logging.info( + f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.") + + def send_notification(message: str) -> None: """ Send a notification to the Discord webhook. @@ -105,19 +619,57 @@ def send_notification(message: str) -> None: if not WEBHOOK_URL: logging.error("WEBHOOK_URL not set") return - embed = create_embed(message) + + _update_state( + last_type=message, + last_attempt_at=datetime.now(), + last_status_code=None, + last_error=None, + ) + + # Warm the tzdb cache once per process to avoid repeated CSV parsing + # and to avoid scanning all pytz timezones every hour. + try: + init_tzdb_cache() + except Exception as e: + logging.error(f"Failed to initialize tzdb cache: {e}") + + tz_list: list[str] | None = None + if message == "420": + try: + cache = TZDB_CACHE + if cache is None: + tz_list = where_is_it_420(load_timezones(), load_countries()) + else: + tz_list = where_is_it_420( + [], + [], + tz_names=cache.get("tz_names"), + tz_to_country_code=cache.get("tz_to_country_code"), + country_code_to_name=cache.get("country_code_to_name"), + ) + _update_state(last_locations=tz_list) + except Exception as e: + _update_state(last_locations=[], last_error=str(e)) + + embed = create_embed(message, tz_list=tz_list) data = {"embeds": [embed]} try: response = requests.post(WEBHOOK_URL, json=data, timeout=10) if response.status_code == 204: logging.info(f"Notification sent: {message}") + _update_state(last_success_at=datetime.now(), + last_status_code=response.status_code) else: logging.error( f"Failed to send notification: {response.status_code} - " f"{response.text}" ) + _update_state(last_status_code=response.status_code, + last_error=response.text) except requests.RequestException as e: logging.error(f"Error sending notification: {e}") + _update_state(last_error=str(e)) def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None: @@ -130,40 +682,86 @@ def get_country_info(country_code: str, countries: list[dict]) -> dict | None: return next((c for c in countries if c["country_code"] == country_code), None) -def where_is_it_420(timezones: list[dict], countries: list[dict]) -> list[str]: +def where_is_it_420( + timezones: list[dict], + countries: list[dict], + tz_names: list[str] | None = None, + tz_to_country_code: dict[str, str] | None = None, + country_code_to_name: dict[str, str] | None = None, +) -> list[str]: """Get timezones where the current hour is 4 or 16, indicating it's 4:20 there. Returns: list[str]: A list of timezones where it's currently 4:20 PM or AM. """ - tz_list = [] - for tz in pytz.all_timezones: - now = datetime.now(pytz.timezone(tz)) - if now.hour == 4 or now.hour == 16: - # Find the timezone in the loaded timezones - tz_info = get_tz_info(tz, timezones) - if tz_info: - country = get_country_info(tz_info["country_code"], countries) - if country: - country_name = country["country_name"].strip().strip('"') - if country_name not in tz_list: - tz_list.append(country_name) - return tz_list + # Build fast lookup dicts if not provided. + if tz_to_country_code is None: + tz_to_country_code = {} + for tz in timezones: + zone_name = tz.get("zone_name") + country_code = tz.get("country_code") + if isinstance(zone_name, str) and isinstance(country_code, str): + tz_to_country_code[zone_name] = country_code + + if country_code_to_name is None: + country_code_to_name = {} + for c in countries: + code = c.get("country_code") + name = c.get("country_name") + if isinstance(code, str) and name is not None: + country_code_to_name[code] = str(name).strip().strip('"') + + names_to_check = tz_names if tz_names is not None else pytz.all_timezones + results: list[str] = [] + seen: set[str] = set() + + for tz_name in names_to_check: + try: + tz_obj = pytz.timezone(tz_name) + except Exception: + continue + + now = datetime.now(tz_obj) + if now.hour != 4 and now.hour != 16: + continue + + country_code = tz_to_country_code.get(tz_name) + if not country_code: + continue + country_name = country_code_to_name.get(country_code) + if not country_name: + continue + if country_name in seen: + continue + + seen.add(country_name) + results.append(country_name) + + return results def main() -> None: """ Main function to run the scheduler. """ + # start_dashboard() + # Schedule notifications - schedule.every().hour.at(":15").do(send_notification, "reminder") - schedule.every().hour.at(":20").do(send_notification, "420") - schedule.every().hour.at(":45").do(send_notification, "reminder_halftime") - schedule.every().hour.at(":50").do(send_notification, "halftime") + # schedule.every().hour.at(":15").do(send_notification, "reminder") + # schedule.every().hour.at(":20").do(send_notification, "420") + # schedule.every().hour.at(":45").do(send_notification, "reminder_halftime") + # schedule.every().hour.at(":50").do(send_notification, "halftime") + + # Schedule deletion of old messages every 3 minutes + schedule.every(3).minutes.do(delete_old_messages, 6) + logging.info("Scheduler started.") # Test the notification on startup - send_notification("420") + # send_notification("420") + + # delete old messages on startup to clean up any previous notifications + delete_old_messages(60) try: while True: diff --git a/requirements.txt b/requirements.txt index ce9097e..8887562 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,9 @@ +numpy +pandas pytest python-dotenv pytz requests schedule +flask +tzdata diff --git a/templates.json b/templates.json new file mode 100644 index 0000000..21b12ac --- /dev/null +++ b/templates.json @@ -0,0 +1,20 @@ +{ + "420": { + "color": 3066993, + "image_url": "https://copyparty.allucanget.biz/img/weed.png", + "text": "Blaze it!" + }, + "halftime": { + "color": 3066993, + "image_url": "https://copyparty.allucanget.biz/img/weed.png", + "text": "Half-time!" + }, + "reminder": { + "color": 15105570, + "text": "This is your 5 minute reminder to 420!" + }, + "reminder_halftime": { + "color": 15105570, + "text": "Half-time in 5 minutes!" + } +} diff --git a/templates.py b/templates.py new file mode 100644 index 0000000..dfc4a76 --- /dev/null +++ b/templates.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import json +from copy import deepcopy +from pathlib import Path + + +DEFAULT_TEMPLATES: dict[str, dict] = { + "reminder_halftime": { + "text": "Half-time in 5 minutes!", + "color": 0xE67E22, + }, + "halftime": { + "text": "Half-time!", + "color": 0x2ECC71, + "image_url": "https://copyparty.allucanget.biz/img/weed.png", + }, + "reminder": { + "text": "This is your 5 minute reminder to 420!", + "color": 0xE67E22, + }, + "420": { + "text": "Blaze it!", + "color": 0x2ECC71, + "image_url": "https://copyparty.allucanget.biz/img/weed.png", + }, +} + + +def _normalize_templates(raw: dict) -> dict[str, dict]: + out: dict[str, dict] = deepcopy(DEFAULT_TEMPLATES) + + if not isinstance(raw, dict): + return out + + for key, default in DEFAULT_TEMPLATES.items(): + incoming = raw.get(key) + if not isinstance(incoming, dict): + continue + + text = incoming.get("text") + if isinstance(text, str): + out[key]["text"] = text + + color = incoming.get("color") + if isinstance(color, int): + out[key]["color"] = color + + image_url = incoming.get("image_url") + if isinstance(image_url, str) and image_url.strip(): + out[key]["image_url"] = image_url.strip() + elif "image_url" in default and image_url in (None, ""): + # Allow clearing image_url only if explicitly set to empty. + out[key].pop("image_url", None) + + return out + + +def load_templates(path: str | Path) -> dict[str, dict]: + p = Path(path) + try: + if not p.exists(): + return deepcopy(DEFAULT_TEMPLATES) + raw = json.loads(p.read_text(encoding="utf-8")) + return _normalize_templates(raw) + except Exception: + return deepcopy(DEFAULT_TEMPLATES) + + +def save_templates(path: str | Path, templates: dict) -> None: + p = Path(path) + normalized = _normalize_templates(templates) + + p.parent.mkdir(parents=True, exist_ok=True) + tmp = p.with_suffix(p.suffix + ".tmp") + tmp.write_text(json.dumps(normalized, indent=2, sort_keys=True) + "\n", encoding="utf-8") + tmp.replace(p) + + +def parse_color(value: str) -> int: + """Parse color from '#RRGGBB', 'RRGGBB', '0xRRGGBB', or decimal.""" + s = (value or "").strip().lower() + if not s: + raise ValueError("color is required") + + if s.startswith("#"): + s = s[1:] + + base = 16 + if s.startswith("0x"): + s = s[2:] + elif all(c.isdigit() for c in s): + base = 10 + + color = int(s, base) + if color < 0 or color > 0xFFFFFF: + raise ValueError("color must be between 0 and 0xFFFFFF") + return color diff --git a/test_print_timezones.py b/test_print_timezones.py new file mode 100644 index 0000000..a3b81c8 --- /dev/null +++ b/test_print_timezones.py @@ -0,0 +1,94 @@ +import pytz +from datetime import datetime +import pandas as pd + + +def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None: + """Get timezone info by name.""" + return next((tz for tz in timezones if tz["zone_name"] == tz_name), None) + + +def get_country_info(country_code: str, countries: list[dict]) -> dict | None: + """Get country info by country code.""" + return next((c for c in countries if c["country_code"] == country_code), None) + + +def where_is_it_420(timezones: list[dict], countries: list[dict]) -> list[str]: + """Get timezones where the current hour is 4 or 16, indicating it's 4:20 there. + + Returns: + list[str]: A list of timezones where it's currently 4:20 PM or AM. + """ + tz_list = [] + for tz in pytz.all_timezones: + now = datetime.now(pytz.timezone(tz)) + if now.hour == 4 or now.hour == 16: + # Find the timezone in the loaded timezones + tz_info = get_tz_info(tz, timezones) + if tz_info: + country = get_country_info(tz_info["country_code"], countries) + if country: + country_name = country["country_name"].strip().strip('"') + if country_name not in tz_list: + tz_list.append(country_name) + return tz_list + + +def load_tz_file(): + timezone_file = "./tzdb/TimeZoneDB.csv/time_zone.csv" + # column names in the csv + timezone_names = ["zone_name", "country_code", + "abbreviation", "time_start", "gmt_offset", "dst"] + # columns to load + load_columns = ["zone_name", "country_code"] + # read csv with pandas + df = pd.read_csv(timezone_file, names=timezone_names) + # drop all columns except load_columns + df = df[load_columns] + # distinct zone_names + df = df.drop_duplicates(subset=["zone_name"]) + + # reset index + df = df.reset_index(drop=True) + + return df + + +def main(): + + # read csv with pandas + df_file = load_tz_file() + + # split zone_name into components by "/" + df_file[['region', 'city']] = df_file['zone_name'].str.split( + '/', expand=True, n=1) + # drop regions with no country_code (like Etc, GMT, etc) + df_file = df_file[df_file['country_code'].notna()] + + df_tz = pd.DataFrame(pytz.all_timezones) + # rename column to zone_name + df_tz = df_tz.rename(columns={0: 'zone_name'}) + # split zone_name into components by "/" + df_tz[['region', 'city']] = df_tz['zone_name'].str.split( + '/', expand=True, n=1) + # drop regions with no city (like UTC, GMT, etc) + df_tz = df_tz[df_tz['city'].notna()] + # drop rows where region is 'Etc' + df_tz = df_tz[df_tz['region'] != 'Etc'] + + # join dataframes on region and city + df_merged = pd.merge(df_file, df_tz, on=[ + 'region', 'city'], how='inner', indicator=True) + # reorder columns + df_merged = df_merged[['region', 'city', 'country_code']] + # print merged dataframe + print(f"Merged timezones: {len(df_merged)}") + print(df_merged.sample(20).to_string(index=False)) + regions = df_merged['region'].unique() + for region in regions: + df_region = df_merged[df_merged['region'] == region] + print(f"{len(df_region)} merged in {region}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py new file mode 100644 index 0000000..72b98e3 --- /dev/null +++ b/tests/test_dashboard.py @@ -0,0 +1,65 @@ +from datetime import datetime + +from dashboard import create_app + + +def test_dashboard_renders_locations(): + state = { + "running": True, + "started_at": datetime(2025, 1, 1, 0, 0, 0), + "last_type": "420", + "last_attempt_at": datetime(2025, 1, 1, 0, 1, 2), + "last_success_at": datetime(2025, 1, 1, 0, 1, 3), + "last_status_code": 204, + "last_error": None, + "last_locations": ["Nowhere", "Somewhere"], + } + + app = create_app( + get_state=lambda: state, + get_next_event=lambda: {"type": "reminder", "at": datetime(2025, 1, 1, 0, 15, 0)}, + ) + + client = app.test_client() + resp = client.get("/") + assert resp.status_code == 200 + body = resp.data.decode("utf-8") + assert "thc-webhook" in body + assert "Locations" in body + assert "