feat: Add dashboard support and enhance Discord webhook functionality
- Updated docker-compose.yml to expose the dashboard on port 8080.
- Enhanced main.py with timezone-database caching and improved state management.
- Introduced a minimal Flask dashboard to display webhook status and notifications.
- Added templates.json for customizable embed messages in Discord notifications.
- Created templates.py for loading and saving notification templates.
- Implemented tests for dashboard rendering and main functionality.
- Added Flask and tzdata to requirements to support the new features.
- Included test cases for timezone handling and template management.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
.github/copilot-instructions.md
|
||||
.github/instructions/*
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
@@ -11,6 +11,9 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Dashboard (Flask) port
|
||||
EXPOSE 8080
|
||||
|
||||
# Create a non-root user
|
||||
RUN addgroup -S appgroup && adduser -S appuser -G appgroup
|
||||
USER appuser
|
||||
|
||||
@@ -11,16 +11,21 @@ This Python application sends notifications to a Discord channel via webhook eve
|
||||
```
|
||||
|
||||
2. Set up your Discord webhook:
|
||||
|
||||
- Go to your Discord server settings > Integrations > Webhooks
|
||||
- Create a new webhook and copy the URL
|
||||
|
||||
3. Update the `.env` file with your webhook URL:
|
||||
3. Update the `.env` file with your webhook URL, bot token, channel ID, and guild ID:
|
||||
|
||||
```text
|
||||
DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/your_webhook_id/your_webhook_token
|
||||
DISCORD_BOT_TOKEN=your_bot_token
|
||||
DISCORD_CHANNEL_ID=your_channel_id
|
||||
DISCORD_GUILD_ID=your_guild_id
|
||||
```
|
||||
|
||||
- To get a bot token, create a Discord bot in the [Discord Developer Portal](https://discord.com/developers/applications) and invite it to your server with the `Manage Messages` permission.
|
||||
- The channel ID can be found by enabling Developer Mode in Discord and right-clicking the channel name.
|
||||
|
||||
## Running the App
|
||||
|
||||
Run the application:
|
||||
@@ -31,6 +36,19 @@ python main.py
|
||||
|
||||
The app will run continuously and send notifications at the scheduled times.
|
||||
|
||||
### Dashboard
|
||||
|
||||
By default, a minimal dashboard is available at `http://localhost:8080/`.
|
||||
|
||||
You can disable it by setting `DASHBOARD_ENABLED=0`.
|
||||
|
||||
### Admin
|
||||
|
||||
You can edit the embed message templates at `http://localhost:8080/admin`.
|
||||
|
||||
- Templates are saved to `templates.json` by default.
|
||||
- Override the location with `TEMPLATES_PATH=/path/to/templates.json`.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.10+ (the code uses `X | None` union syntax and the `zoneinfo` module)
|
||||
|
||||
+143
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Flask, request
|
||||
|
||||
from templates import load_templates, parse_color, save_templates
|
||||
|
||||
|
||||
def _fmt_dt(dt: datetime | None) -> str:
|
||||
if dt is None:
|
||||
return "—"
|
||||
try:
|
||||
return dt.isoformat(timespec="seconds")
|
||||
except Exception:
|
||||
return str(dt)
|
||||
|
||||
|
||||
def create_app(
    *,
    get_state: Callable[[], dict],
    get_next_event: Callable[[], dict],
) -> Flask:
    """Build the minimal status/admin dashboard.

    Args:
        get_state: Callable returning the latest status snapshot dict
            (keys like ``running``, ``started_at``, ``last_error``, ...).
        get_next_event: Callable returning the next scheduled notification
            as ``{"type": ..., "at": ...}``.

    Returns:
        A Flask app exposing ``/`` (read-only status) and ``/admin``
        (template editor). The admin routes read/write the JSON file named
        by ``app.config["TEMPLATES_PATH"]`` (default ``templates.json``).
    """
    app = Flask(__name__)

    def _esc(value: object) -> str:
        """Minimal HTML escaping for values interpolated into the markup.

        ``&`` is replaced first so the entities produced by the later
        replacements are not double-escaped. ``&#39;`` covers single quotes
        because the attributes below are single-quoted.
        """
        return (
            str(value)
            .replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace("'", "&#39;")
        )

    @app.get("/")
    def index() -> str:
        """Render the read-only status page."""
        state = get_state() or {}
        next_event = get_next_event() or {}

        locations = state.get("last_locations") or []
        locations_html = "".join(f"<li>{loc}</li>" for loc in locations)

        return (
            "<!doctype html>"
            "<html><head><meta charset='utf-8'><title>thc-webhook</title>"
            "<style>body{font-family:sans-serif;max-width:900px;margin:24px;}</style>"
            "</head><body>"
            "<h1>thc-webhook</h1>"
            "<h2>Status</h2>"
            "<ul>"
            f"<li>running: {state.get('running', True)}</li>"
            f"<li>started_at: {_fmt_dt(state.get('started_at'))}</li>"
            f"<li>last_type: {state.get('last_type') or '—'}</li>"
            f"<li>last_attempt_at: {_fmt_dt(state.get('last_attempt_at'))}</li>"
            f"<li>last_success_at: {_fmt_dt(state.get('last_success_at'))}</li>"
            f"<li>last_status_code: {state.get('last_status_code') or '—'}</li>"
            f"<li>last_error: {state.get('last_error') or '—'}</li>"
            "</ul>"
            "<h2>Next scheduled</h2>"
            "<ul>"
            f"<li>type: {next_event.get('type') or '—'}</li>"
            f"<li>at: {_fmt_dt(next_event.get('at'))}</li>"
            "</ul>"
            "<h2>Locations (latest)</h2>"
            f"<p>Total: {len(locations)}</p>"
            f"<ul>{locations_html}</ul>"
            "</body></html>"
        )

    @app.get("/admin")
    def admin_get() -> str:
        """Render the template editor form, one fieldset per template key."""
        templates_path = app.config.get("TEMPLATES_PATH", "templates.json")
        templates = load_templates(templates_path)

        blocks = []
        for key, tpl in templates.items():
            # Bug fix: the previous code did text.replace("'", "'") — a
            # no-op — so a single quote in a template value terminated the
            # single-quoted HTML attribute below. Escape the user-editable
            # values before interpolation. `key` is kept raw because it is
            # also used as the form field name that admin_post looks up.
            text = _esc(tpl.get("text") or "")
            color = tpl.get("color")  # produced by parse_color; rendered as-is
            image_url = _esc(tpl.get("image_url") or "")
            blocks.append(
                "<fieldset style='margin-bottom:16px;'>"
                f"<legend><strong>{key}</strong></legend>"
                f"<label>Text<br><textarea name='{key}__text' rows='3' style='width:100%'>{text}</textarea></label><br>"
                f"<label>Color<br><input name='{key}__color' value='{color}' style='width:200px'></label><br>"
                f"<label>Image URL (optional)<br><input name='{key}__image_url' value='{image_url}' style='width:100%'></label>"
                "</fieldset>"
            )

        blocks_html = "".join(blocks)
        return (
            "<!doctype html>"
            "<html><head><meta charset='utf-8'><title>thc-webhook admin</title>"
            "<style>body{font-family:sans-serif;max-width:900px;margin:24px;}textarea,input{font-family:inherit;}</style>"
            "</head><body>"
            "<h1>Admin: templates</h1>"
            "<p>Edits are saved to the templates JSON file on disk.</p>"
            "<form method='post'>"
            f"{blocks_html}"
            "<button type='submit'>Save</button>"
            "</form>"
            "</body></html>"
        )

    @app.post("/admin")
    def admin_post() -> tuple[str, int]:
        """Validate the submitted form and persist the edited templates.

        Only keys already present in the stored templates are accepted, so
        a crafted POST cannot add new template entries.
        """
        templates_path = app.config.get("TEMPLATES_PATH", "templates.json")
        current = load_templates(templates_path)

        errors: list[str] = []
        updated: dict[str, dict] = {}
        for key in current.keys():
            text = request.form.get(f"{key}__text", "").strip()
            color_raw = request.form.get(f"{key}__color", "").strip()
            image_url = request.form.get(f"{key}__image_url", "").strip()

            if not text:
                errors.append(f"{key}: text is required")
                continue
            try:
                color = parse_color(color_raw)
            except Exception as e:
                errors.append(f"{key}: invalid color ({e})")
                continue

            tpl: dict = {"text": text, "color": color}
            if image_url:
                tpl["image_url"] = image_url
            updated[key] = tpl

        # Reject the whole submission if any field failed validation so a
        # partial save never clobbers templates.json.
        if errors:
            msg = "<br>".join(errors)
            return (
                "<!doctype html><html><body>"
                "<h1>Admin: templates</h1>"
                f"<p style='color:#b00;'>Errors:<br>{msg}</p>"
                "<p><a href='/admin'>Back</a></p>"
                "</body></html>",
                400,
            )

        save_templates(templates_path, updated)
        return (
            "<!doctype html><html><body>"
            "<h1>Admin: templates</h1>"
            "<p>Saved.</p>"
            "<p><a href='/admin'>Back</a></p>"
            "</body></html>",
            200,
        )

    return app
|
||||
@@ -4,7 +4,10 @@ services:
|
||||
container_name: thc-webhook-app
|
||||
environment:
|
||||
- DISCORD_WEBHOOK_URL=${DISCORD_WEBHOOK_URL}
|
||||
- DASHBOARD_PORT=8080
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8080:8080"
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
|
||||
@@ -1,12 +1,32 @@
|
||||
import os
|
||||
import pytz
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
import requests
|
||||
import schedule
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from templates import load_templates
|
||||
|
||||
|
||||
TZDB_CACHE: dict | None = None
|
||||
|
||||
|
||||
STATE_LOCK = threading.Lock()
|
||||
STATE: dict = {
|
||||
"running": True,
|
||||
"started_at": datetime.now(),
|
||||
"last_type": None,
|
||||
"last_attempt_at": None,
|
||||
"last_success_at": None,
|
||||
"last_status_code": None,
|
||||
"last_error": None,
|
||||
"last_locations": [],
|
||||
}
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -17,6 +37,157 @@ logging.basicConfig(
|
||||
load_dotenv()
|
||||
|
||||
WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL')
|
||||
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')
|
||||
DISCORD_CHANNEL_ID = os.getenv('DISCORD_CHANNEL_ID')
|
||||
GUILD_ID = os.getenv('DISCORD_GUILD_ID')
|
||||
|
||||
|
||||
def init_tzdb_cache() -> dict:
    """Build (once per process) the cached tzdb lookup structures.

    The cache holds O(1) maps (zone name -> country code, country code ->
    country name), per-zone metadata, and the subset of tz names that are
    loadable via ``zoneinfo`` — this keeps the hourly scheduler fast.

    Deliberately NOT run at import time so tests can monkeypatch
    ``load_timezones``/``load_countries`` without resetting global state.
    """
    global TZDB_CACHE
    if TZDB_CACHE is not None:
        # Already warmed; the cache is immutable for the process lifetime.
        return TZDB_CACHE

    zone_to_code: dict[str, str] = {}
    meta_by_zone: dict[str, dict] = {}
    for row in load_timezones():
        zone = row.get("zone_name")
        code = row.get("country_code")
        # Skip malformed CSV rows: both fields must be non-empty strings.
        if not (isinstance(zone, str) and zone):
            continue
        if not (isinstance(code, str) and code):
            continue

        zone_to_code[zone] = code
        region, city = split_tz_name(zone)
        meta_by_zone[zone] = {
            "zone_name": zone,
            "country_code": code,
            "region": region,
            "city": city,
        }

    code_to_name: dict[str, str] = {}
    for entry in load_countries():
        code = entry.get("country_code")
        name = entry.get("country_name")
        if code and name:
            # Country names in the CSV may be quoted; normalize them.
            code_to_name[code] = str(name).strip().strip('"')

    # Attach resolved country names onto the per-zone metadata
    # (storage-only for now).
    for meta in meta_by_zone.values():
        code = meta.get("country_code")
        if isinstance(code, str):
            meta["country_name"] = code_to_name.get(code)

    # Vetted tz list: only names present in tzdb AND loadable by zoneinfo.
    # Installing the `tzdata` package keeps this mapping up-to-date.
    loadable_zones: list[str] = []
    for zone in zone_to_code.keys():
        try:
            ZoneInfo(zone)
        except ZoneInfoNotFoundError:
            continue
        loadable_zones.append(zone)

    TZDB_CACHE = {
        "tz_to_country_code": zone_to_code,
        "country_code_to_name": code_to_name,
        "tz_names": loadable_zones,
        "tz_meta": meta_by_zone,
    }
    return TZDB_CACHE
|
||||
|
||||
|
||||
def split_tz_name(zone_name: str) -> tuple[str, str]:
    """Split an IANA timezone name into (region, city).

    Examples:
        - "America/New_York" -> ("America", "New_York")
        - "America/Argentina/Buenos_Aires" -> ("America", "Argentina/Buenos_Aires")
        - "UTC" -> ("UTC", "")
    """
    # partition() splits on the first "/" only, matching split("/", 1).
    region, sep, city = zone_name.partition("/")
    if not sep:
        # No separator: the whole name is the region, city is empty.
        return zone_name, ""
    return region, city
|
||||
|
||||
|
||||
def _update_state(**updates) -> None:
    """Apply keyword *updates* to the shared STATE dict under its lock."""
    with STATE_LOCK:
        for key, value in updates.items():
            STATE[key] = value
|
||||
|
||||
|
||||
def get_state_snapshot() -> dict:
    """Return a shallow copy of STATE, taken while holding the lock."""
    with STATE_LOCK:
        return {**STATE}
|
||||
|
||||
|
||||
def get_next_scheduled_event(now: datetime | None = None) -> dict:
|
||||
"""Return the next scheduled notification time/type based on known minute marks."""
|
||||
if now is None:
|
||||
now = datetime.now()
|
||||
|
||||
schedule_map = [
|
||||
(15, "reminder"),
|
||||
(20, "420"),
|
||||
(45, "reminder_halftime"),
|
||||
(50, "halftime"),
|
||||
]
|
||||
|
||||
candidates: list[tuple[datetime, str]] = []
|
||||
for minute, msg_type in schedule_map:
|
||||
candidate = now.replace(minute=minute, second=0, microsecond=0)
|
||||
if candidate > now:
|
||||
candidates.append((candidate, msg_type))
|
||||
|
||||
if not candidates:
|
||||
base = (now + timedelta(hours=1)).replace(minute=0,
|
||||
second=0, microsecond=0)
|
||||
next_dt = base.replace(minute=schedule_map[0][0])
|
||||
return {"at": next_dt, "type": schedule_map[0][1]}
|
||||
|
||||
next_dt, next_type = min(candidates, key=lambda x: x[0])
|
||||
return {"at": next_dt, "type": next_type}
|
||||
|
||||
|
||||
def start_dashboard() -> None:
    """Start the minimal Flask dashboard in a daemon thread, unless disabled.

    Honors DASHBOARD_ENABLED ("0"/"false"/"no" disables it), plus
    DASHBOARD_HOST / DASHBOARD_PORT / TEMPLATES_PATH overrides.
    """
    flag = os.getenv("DASHBOARD_ENABLED", "1").strip().lower()
    if flag in {"0", "false", "no"}:
        return

    host = os.getenv("DASHBOARD_HOST", "0.0.0.0")
    port = int(os.getenv("DASHBOARD_PORT", "8080"))

    def _serve() -> None:
        # Import inside the thread so a missing/broken dashboard module is
        # logged instead of crashing the scheduler process.
        try:
            from dashboard import create_app

            app = create_app(
                get_state=get_state_snapshot,
                get_next_event=lambda: get_next_scheduled_event(),
            )
            app.config["TEMPLATES_PATH"] = os.getenv(
                "TEMPLATES_PATH", "templates.json")
            app.run(host=host, port=port, debug=False, use_reloader=False)
        except Exception as e:
            logging.error(f"Dashboard failed to start: {e}")

    threading.Thread(target=_serve, name="dashboard", daemon=True).start()
|
||||
|
||||
|
||||
def load_timezones() -> list[dict]:
|
||||
@@ -57,36 +228,38 @@ def load_countries() -> list[dict]:
|
||||
return countries
|
||||
|
||||
|
||||
def create_embed(type: str) -> dict:
|
||||
def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
|
||||
"""
|
||||
Create a Discord embed message.
|
||||
"""
|
||||
color_orange = 0xe67e22
|
||||
color_green = 0x2ecc71
|
||||
messages = {
|
||||
"reminder_halftime": {"text": "Half-time in 5 minutes!", "color": color_orange},
|
||||
"halftime": {"text": "Half-time!", "color": color_green},
|
||||
"reminder": {"text": "This is your 5 minute reminder to 420!", "color": color_orange},
|
||||
"420": {"text": "Blaze it!", "color": color_green},
|
||||
}
|
||||
timezones = load_timezones()
|
||||
countries = load_countries()
|
||||
templates_path = os.getenv("TEMPLATES_PATH", "templates.json")
|
||||
messages = load_templates(templates_path)
|
||||
cache = TZDB_CACHE
|
||||
timezones = load_timezones() if cache is None else []
|
||||
countries = load_countries() if cache is None else []
|
||||
if type in messages:
|
||||
msg = messages[type]
|
||||
if type in ["halftime", "420"]:
|
||||
# Add an image for halftime and 420
|
||||
msg["image"] = {
|
||||
"url": "https://copyparty.allucanget.biz/img/weed.png"}
|
||||
image_url = msg.get("image_url")
|
||||
if isinstance(image_url, str) and image_url:
|
||||
msg["image"] = {"url": image_url}
|
||||
if type == "420":
|
||||
# Check where it's 4:20
|
||||
if tz_list is None:
|
||||
if cache is None:
|
||||
tz_list = where_is_it_420(timezones, countries)
|
||||
else:
|
||||
tz_list = where_is_it_420(
|
||||
timezones,
|
||||
countries,
|
||||
tz_names=cache.get("tz_names"),
|
||||
tz_to_country_code=cache.get("tz_to_country_code"),
|
||||
country_code_to_name=cache.get("country_code_to_name"),
|
||||
)
|
||||
if tz_list:
|
||||
tz_str = "\n".join(tz_list)
|
||||
msg["text"] += f"\nIt's 4:20 in:\n{tz_str}"
|
||||
else:
|
||||
msg["text"] += " It's not 4:20 anywhere right now."
|
||||
else:
|
||||
msg = {"text": "Unknown notification type", "color": 0xFF0000}
|
||||
msg = {"text": "", "color": 0xFF0000}
|
||||
embed = {
|
||||
"title": type.replace("_", " ").capitalize(),
|
||||
"description": msg["text"],
|
||||
@@ -98,6 +271,347 @@ def create_embed(type: str) -> dict:
|
||||
return embed
|
||||
|
||||
|
||||
def get_discord_headers() -> dict[str, str]:
    """Return the auth + content-type headers for Discord REST calls."""
    token = DISCORD_BOT_TOKEN
    return {
        "Authorization": f"Bot {token}",
        "Content-Type": "application/json",
    }
|
||||
|
||||
|
||||
def parse_message_timestamp(message: dict) -> datetime:
    """Parse a Discord message's ISO-8601 ``timestamp`` field.

    A trailing ``Z`` is stripped first (older ``fromisoformat`` cannot
    parse it); explicit offsets like ``+00:00`` are kept, so the result
    may be naive or aware depending on the input format.
    """
    raw = message["timestamp"]
    return datetime.fromisoformat(raw.replace("Z", ""))
|
||||
|
||||
|
||||
def build_delete_entry(message: dict) -> dict:
    """Reduce *message* to the id/timestamp pair the delete loop needs."""
    entry = {"id": message.get("id")}
    entry["timestamp"] = parse_message_timestamp(message)
    return entry
|
||||
|
||||
|
||||
def parse_float(value: str | int | float | None) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def should_delete_message(
    message: dict,
    webhook_id: str,
    author_id: str,
    five_minutes_ago_timestamp: int,
) -> bool:
    """True when *message* is old enough and was authored by our webhook.

    A message qualifies when its timestamp is at or before
    *five_minutes_ago_timestamp* AND both its webhook id and author id
    match the expected values.
    """
    message_ts = int(parse_message_timestamp(message).timestamp())
    if message_ts > five_minutes_ago_timestamp:
        return False
    if message.get("webhook_id") != webhook_id:
        return False
    return message.get("author", {}).get("id") == author_id
|
||||
|
||||
|
||||
def get_rate_limit_retry_after(response: requests.Response) -> float | None:
    """Extract the retry delay advertised by a rate-limited Discord response.

    Checks, in order: the ``Retry-After`` header, the
    ``X-RateLimit-Reset-After`` header, then ``retry_after`` in the JSON
    body. Returns None when none of them parse as a float.
    """
    for header in ("Retry-After", "X-RateLimit-Reset-After"):
        delay = parse_float(response.headers.get(header))
        if delay is not None:
            return delay

    try:
        payload = response.json()
    except ValueError:
        # Body is not JSON; no more places to look.
        return None

    return parse_float(payload.get("retry_after"))
|
||||
|
||||
|
||||
def get_bucket_exhausted_delay(response: requests.Response) -> float | None:
    """Return the bucket reset delay when no requests remain, else None.

    Only relevant when ``X-RateLimit-Remaining`` is exactly "0" — i.e. the
    next request would be rate limited.
    """
    if response.headers.get("X-RateLimit-Remaining") == "0":
        return parse_float(response.headers.get("X-RateLimit-Reset-After"))
    return None
|
||||
|
||||
|
||||
def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None:
    """Sleep for *delay_seconds*, logging why; no-op for non-positive delays."""
    if delay_seconds > 0:
        logging.info(
            f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).")
        time.sleep(delay_seconds)
|
||||
|
||||
|
||||
def find_last_message_by_author(
    headers: dict[str, str],
    guild_id: str,
    channel_id: str,
    author_id: str,
) -> dict | None:
    """Find the newest indexed message for the author in the target guild channel.

    Uses the guild message-search endpoint (sorted newest-first) and makes
    up to 3 attempts, retrying on 202 (search index still building) and
    429 (rate limited).

    Args:
        headers: Auth/content headers (see get_discord_headers).
        guild_id: Guild to search in.
        channel_id: Channel filter for the search.
        author_id: Webhook author id to match.

    Returns:
        The newest matching message dict, or None on request errors,
        non-200 responses, empty results, or exhausted retries.
    """
    url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search"
    params = [
        ("author_id", author_id),
        ("author_type", "webhook"),
        ("channel_id", channel_id),
        ("sort_by", "timestamp"),
        ("sort_order", "desc"),
        ("limit", "10"),
    ]

    for _ in range(3):
        try:
            response = requests.get(
                url, headers=headers, params=params, timeout=10)
        except requests.RequestException as e:
            # Network-level failure: give up immediately, no retry.
            logging.error(f"Error searching guild messages: {e}")
            return None

        if response.status_code == 202:
            # 202 means Discord is still building the search index;
            # wait the suggested delay (default 1s) and retry.
            payload = response.json()
            retry_after = float(payload.get("retry_after", 1) or 1)
            logging.info(
                f"Guild search index not ready. Retrying after {retry_after} seconds."
            )
            time.sleep(retry_after)
            continue

        if response.status_code == 429:
            # Rate limited: back off for the advertised delay, then retry.
            retry_after = get_rate_limit_retry_after(response) or 1.0
            sleep_for_rate_limit(retry_after, "guild search")
            continue

        if response.status_code != 200:
            logging.error(
                f"Failed to search guild messages: {response.status_code} - {response.text}"
            )
            return None

        # Search results come back grouped; the newest match is the first
        # message of the first group (results are sorted desc by timestamp).
        message_groups = response.json().get("messages", [])
        if not message_groups or not message_groups[0]:
            return None
        return message_groups[0][0]

    # All 3 attempts ended in 202/429.
    logging.error("Guild search index did not become available in time.")
    return None
|
||||
|
||||
|
||||
def fetch_messages_to_delete(headers: dict, channel_id: str, webhook_id: str, author_id: str, cutoff: int, last_message_id: str | None = None) -> tuple[list[dict], str | None]:
    """
    Fetch messages from the channel that are older than the cutoff timestamp and sent by the webhook.
    Uses pagination with the 'before' parameter to resume from the last processed message.
    Returns a tuple of (list of messages to delete, last message ID for pagination).

    On request errors, non-200 responses, or 3 consecutive 429s the
    function returns ([], last_message_id) so the caller's pagination
    cursor is left unchanged.
    """
    url = f"https://discord.com/api/v10/channels/{channel_id}/messages"
    params: dict[str, str | int] = {
        "limit": 100,  # Maximum allowed by Discord API
    }

    # Use the 'before' parameter for pagination if last_message_id is provided
    if last_message_id:
        params["before"] = last_message_id

    try:
        # Up to 3 attempts, retrying only on 429 rate limits.
        for _ in range(3):
            response = requests.get(url, headers=headers,
                                    params=params, timeout=10)

            if response.status_code == 429:
                retry_after = get_rate_limit_retry_after(response) or 1.0
                sleep_for_rate_limit(retry_after, "channel message fetch")
                continue

            break
        else:
            # for/else: runs only when all 3 attempts were rate limited.
            logging.error(
                "Failed to fetch messages after repeated rate limits.")
            return [], last_message_id

        if response.status_code == 200:
            messages = response.json()
            delete_list = []
            new_last_message_id = None

            for message in messages:
                # Track the last message id seen so the caller can paginate
                # past this page even if nothing here qualified for deletion.
                new_last_message_id = message.get("id")

                if should_delete_message(
                    message,
                    webhook_id,
                    author_id,
                    cutoff,
                ):
                    delete_list.append(build_delete_entry(message))

                # Limit the list to 100 items
                if len(delete_list) >= 100:
                    break

            return delete_list, new_last_message_id
        else:
            logging.error(
                f"Failed to fetch messages: {response.status_code} - {response.text}")
            return [], last_message_id
    except requests.RequestException as e:
        logging.error(f"Error fetching messages: {e}")
        return [], last_message_id
|
||||
|
||||
|
||||
def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]:
    """
    Delete a single message from the channel.

    Returns:
        tuple[bool, float | None, bool]:
            - whether the delete succeeded
            - how long to wait before the next request, if any
            - whether to abort the batch because further requests would be invalid
    """
    delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}"
    delete_response = requests.delete(delete_url, headers=headers, timeout=10)

    # 204 No Content: success. Still surface a delay if the rate-limit
    # bucket is now exhausted so the caller can pace itself.
    if delete_response.status_code == 204:
        return True, get_bucket_exhausted_delay(delete_response), False

    # 429: rate limited — report the retry delay; the caller may retry.
    if delete_response.status_code == 429:
        retry_after = get_rate_limit_retry_after(delete_response) or 1.0
        scope = delete_response.headers.get("X-RateLimit-Scope", "unknown")
        is_global = delete_response.headers.get(
            "X-RateLimit-Global", "false").lower() == "true"
        logging.warning(
            "Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs",
            message_id,
            scope,
            is_global,
            retry_after,
        )
        return False, retry_after, False

    # 401/403: auth/permission problem — further deletes would also fail,
    # so tell the caller to abort the whole batch.
    if delete_response.status_code in {401, 403}:
        logging.error(
            "Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.",
            message_id,
            delete_response.status_code,
            delete_response.text,
        )
        return False, None, True

    # Any other status (e.g. 404): log and continue with the batch.
    logging.error(
        f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}")
    return False, None, False
|
||||
|
||||
|
||||
def delete_old_messages(minutes: int = 6) -> None:
    """
    Delete all messages sent by the webhook in the last `minutes` minutes.
    Uses a dynamic slowdown to avoid hitting Discord API rate limits and pagination to fetch all messages.

    Requires DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, and DISCORD_GUILD_ID to
    be set; logs and returns early otherwise.
    """
    if not DISCORD_BOT_TOKEN or not DISCORD_CHANNEL_ID or not GUILD_ID:
        logging.error(
            "DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set")
        return

    headers = get_discord_headers()

    # Calculate the time `minutes` minutes ago
    cutoff_timestamp = datetime.now() - timedelta(minutes=minutes)
    cutoff = int(cutoff_timestamp.timestamp())
    # NOTE(review): hard-coded webhook/author id for this deployment's
    # webhook — consider moving to configuration; TODO confirm.
    webhook_id = '1413817504194760766'
    author_id = '1413817504194760766'

    # Anchor pagination on the newest indexed message from our author:
    # the channel-history endpoint paginates with "before=<id>".
    last_author_message = find_last_message_by_author(
        headers,
        GUILD_ID,
        DISCORD_CHANNEL_ID,
        author_id,
    )
    if last_author_message is None:
        logging.info("No indexed messages found for the target author.")
        return

    last_message_id = last_author_message.get("id")
    if not last_message_id:
        logging.info("Search result did not contain a message id.")
        return

    deleted_count = 0

    # The anchor message itself is not returned by "before=" pagination,
    # so handle it separately first.
    if should_delete_message(
        last_author_message,
        webhook_id,
        author_id,
        cutoff,
    ):
        anchor_message = build_delete_entry(last_author_message)
        deleted, wait_seconds, abort_batch = delete_message(
            headers,
            DISCORD_CHANNEL_ID,
            anchor_message["id"],
        )
        if deleted:
            deleted_count += 1
            logging.info(
                f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}"
            )
        elif abort_batch:
            # Auth/permission failure: further requests would also fail.
            return

        if wait_seconds is not None:
            sleep_for_rate_limit(wait_seconds, "delete bucket")

    while True:
        # Fetch messages to delete with pagination
        delete_list, next_last_message_id = fetch_messages_to_delete(
            headers, DISCORD_CHANNEL_ID, webhook_id, author_id, cutoff, last_message_id
        )

        # Exit loop if no more messages to delete
        if not delete_list:
            if deleted_count == 0:
                # Nothing deleted yet: keep paginating while the cursor
                # still advances, in case qualifying messages are further
                # back in the channel history.
                if next_last_message_id is None or next_last_message_id == last_message_id:
                    logging.info("No messages to delete.")
                    break
                else:
                    logging.info(
                        "No messages deleted in this batch, but more messages may exist... Continuing pagination.")
                    last_message_id = next_last_message_id
                    continue
            else:
                logging.info("No more messages to delete.")
                break

        for message in delete_list:
            message_id = message["id"]
            message_time = message["timestamp"]

            deleted, wait_seconds, abort_batch = delete_message(
                headers,
                DISCORD_CHANNEL_ID,
                message_id,
            )

            if deleted:
                deleted_count += 1
                logging.info(
                    f"Deleted message {message_id} from {message_time.isoformat()}")
            elif abort_batch:
                logging.info(
                    "Stopping delete batch after an invalid Discord response.")
                return

            # Pace the loop when the API signalled an exhausted bucket or
            # a retry delay.
            if wait_seconds is not None:
                sleep_for_rate_limit(wait_seconds, "delete bucket")

        # Stop when the pagination cursor no longer advances.
        if next_last_message_id is None or next_last_message_id == last_message_id:
            break
        last_message_id = next_last_message_id

    logging.info(
        f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.")
|
||||
|
||||
|
||||
def send_notification(message: str) -> None:
|
||||
"""
|
||||
Send a notification to the Discord webhook.
|
||||
@@ -105,19 +619,57 @@ def send_notification(message: str) -> None:
|
||||
if not WEBHOOK_URL:
|
||||
logging.error("WEBHOOK_URL not set")
|
||||
return
|
||||
embed = create_embed(message)
|
||||
|
||||
_update_state(
|
||||
last_type=message,
|
||||
last_attempt_at=datetime.now(),
|
||||
last_status_code=None,
|
||||
last_error=None,
|
||||
)
|
||||
|
||||
# Warm the tzdb cache once per process to avoid repeated CSV parsing
|
||||
# and to avoid scanning all pytz timezones every hour.
|
||||
try:
|
||||
init_tzdb_cache()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize tzdb cache: {e}")
|
||||
|
||||
tz_list: list[str] | None = None
|
||||
if message == "420":
|
||||
try:
|
||||
cache = TZDB_CACHE
|
||||
if cache is None:
|
||||
tz_list = where_is_it_420(load_timezones(), load_countries())
|
||||
else:
|
||||
tz_list = where_is_it_420(
|
||||
[],
|
||||
[],
|
||||
tz_names=cache.get("tz_names"),
|
||||
tz_to_country_code=cache.get("tz_to_country_code"),
|
||||
country_code_to_name=cache.get("country_code_to_name"),
|
||||
)
|
||||
_update_state(last_locations=tz_list)
|
||||
except Exception as e:
|
||||
_update_state(last_locations=[], last_error=str(e))
|
||||
|
||||
embed = create_embed(message, tz_list=tz_list)
|
||||
data = {"embeds": [embed]}
|
||||
try:
|
||||
response = requests.post(WEBHOOK_URL, json=data, timeout=10)
|
||||
if response.status_code == 204:
|
||||
logging.info(f"Notification sent: {message}")
|
||||
_update_state(last_success_at=datetime.now(),
|
||||
last_status_code=response.status_code)
|
||||
else:
|
||||
logging.error(
|
||||
f"Failed to send notification: {response.status_code} - "
|
||||
f"{response.text}"
|
||||
)
|
||||
_update_state(last_status_code=response.status_code,
|
||||
last_error=response.text)
|
||||
except requests.RequestException as e:
|
||||
logging.error(f"Error sending notification: {e}")
|
||||
_update_state(last_error=str(e))
|
||||
|
||||
|
||||
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
||||
@@ -130,40 +682,86 @@ def get_country_info(country_code: str, countries: list[dict]) -> dict | None:
|
||||
return next((c for c in countries if c["country_code"] == country_code), None)
|
||||
|
||||
|
||||
def where_is_it_420(timezones: list[dict], countries: list[dict]) -> list[str]:
|
||||
def where_is_it_420(
|
||||
timezones: list[dict],
|
||||
countries: list[dict],
|
||||
tz_names: list[str] | None = None,
|
||||
tz_to_country_code: dict[str, str] | None = None,
|
||||
country_code_to_name: dict[str, str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Get timezones where the current hour is 4 or 16, indicating it's 4:20 there.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of timezones where it's currently 4:20 PM or AM.
|
||||
"""
|
||||
tz_list = []
|
||||
for tz in pytz.all_timezones:
|
||||
now = datetime.now(pytz.timezone(tz))
|
||||
if now.hour == 4 or now.hour == 16:
|
||||
# Find the timezone in the loaded timezones
|
||||
tz_info = get_tz_info(tz, timezones)
|
||||
if tz_info:
|
||||
country = get_country_info(tz_info["country_code"], countries)
|
||||
if country:
|
||||
country_name = country["country_name"].strip().strip('"')
|
||||
if country_name not in tz_list:
|
||||
tz_list.append(country_name)
|
||||
return tz_list
|
||||
# Build fast lookup dicts if not provided.
|
||||
if tz_to_country_code is None:
|
||||
tz_to_country_code = {}
|
||||
for tz in timezones:
|
||||
zone_name = tz.get("zone_name")
|
||||
country_code = tz.get("country_code")
|
||||
if isinstance(zone_name, str) and isinstance(country_code, str):
|
||||
tz_to_country_code[zone_name] = country_code
|
||||
|
||||
if country_code_to_name is None:
|
||||
country_code_to_name = {}
|
||||
for c in countries:
|
||||
code = c.get("country_code")
|
||||
name = c.get("country_name")
|
||||
if isinstance(code, str) and name is not None:
|
||||
country_code_to_name[code] = str(name).strip().strip('"')
|
||||
|
||||
names_to_check = tz_names if tz_names is not None else pytz.all_timezones
|
||||
results: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for tz_name in names_to_check:
|
||||
try:
|
||||
tz_obj = pytz.timezone(tz_name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
now = datetime.now(tz_obj)
|
||||
if now.hour != 4 and now.hour != 16:
|
||||
continue
|
||||
|
||||
country_code = tz_to_country_code.get(tz_name)
|
||||
if not country_code:
|
||||
continue
|
||||
country_name = country_code_to_name.get(country_code)
|
||||
if not country_name:
|
||||
continue
|
||||
if country_name in seen:
|
||||
continue
|
||||
|
||||
seen.add(country_name)
|
||||
results.append(country_name)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""
|
||||
Main function to run the scheduler.
|
||||
"""
|
||||
# start_dashboard()
|
||||
|
||||
# Schedule notifications
|
||||
schedule.every().hour.at(":15").do(send_notification, "reminder")
|
||||
schedule.every().hour.at(":20").do(send_notification, "420")
|
||||
schedule.every().hour.at(":45").do(send_notification, "reminder_halftime")
|
||||
schedule.every().hour.at(":50").do(send_notification, "halftime")
|
||||
# schedule.every().hour.at(":15").do(send_notification, "reminder")
|
||||
# schedule.every().hour.at(":20").do(send_notification, "420")
|
||||
# schedule.every().hour.at(":45").do(send_notification, "reminder_halftime")
|
||||
# schedule.every().hour.at(":50").do(send_notification, "halftime")
|
||||
|
||||
# Schedule deletion of old messages every 3 minutes
|
||||
schedule.every(3).minutes.do(delete_old_messages, 6)
|
||||
|
||||
logging.info("Scheduler started.")
|
||||
|
||||
# Test the notification on startup
|
||||
send_notification("420")
|
||||
# send_notification("420")
|
||||
|
||||
# delete old messages on startup to clean up any previous notifications
|
||||
delete_old_messages(60)
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
flask
numpy
pandas
pytest
python-dotenv
pytz
requests
schedule
tzdata
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"420": {
|
||||
"color": 3066993,
|
||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||
"text": "Blaze it!"
|
||||
},
|
||||
"halftime": {
|
||||
"color": 3066993,
|
||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||
"text": "Half-time!"
|
||||
},
|
||||
"reminder": {
|
||||
"color": 15105570,
|
||||
"text": "This is your 5 minute reminder to 420!"
|
||||
},
|
||||
"reminder_halftime": {
|
||||
"color": 15105570,
|
||||
"text": "Half-time in 5 minutes!"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
DEFAULT_TEMPLATES: dict[str, dict] = {
|
||||
"reminder_halftime": {
|
||||
"text": "Half-time in 5 minutes!",
|
||||
"color": 0xE67E22,
|
||||
},
|
||||
"halftime": {
|
||||
"text": "Half-time!",
|
||||
"color": 0x2ECC71,
|
||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||
},
|
||||
"reminder": {
|
||||
"text": "This is your 5 minute reminder to 420!",
|
||||
"color": 0xE67E22,
|
||||
},
|
||||
"420": {
|
||||
"text": "Blaze it!",
|
||||
"color": 0x2ECC71,
|
||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_templates(raw: dict) -> dict[str, dict]:
|
||||
out: dict[str, dict] = deepcopy(DEFAULT_TEMPLATES)
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
return out
|
||||
|
||||
for key, default in DEFAULT_TEMPLATES.items():
|
||||
incoming = raw.get(key)
|
||||
if not isinstance(incoming, dict):
|
||||
continue
|
||||
|
||||
text = incoming.get("text")
|
||||
if isinstance(text, str):
|
||||
out[key]["text"] = text
|
||||
|
||||
color = incoming.get("color")
|
||||
if isinstance(color, int):
|
||||
out[key]["color"] = color
|
||||
|
||||
image_url = incoming.get("image_url")
|
||||
if isinstance(image_url, str) and image_url.strip():
|
||||
out[key]["image_url"] = image_url.strip()
|
||||
elif "image_url" in default and image_url in (None, ""):
|
||||
# Allow clearing image_url only if explicitly set to empty.
|
||||
out[key].pop("image_url", None)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def load_templates(path: str | Path) -> dict[str, dict]:
    """Load notification templates from *path*.

    Any failure (missing file, bad JSON, unreadable path) falls back to a
    fresh copy of DEFAULT_TEMPLATES so the caller always gets usable data.
    """
    source = Path(path)
    try:
        if source.exists():
            parsed = json.loads(source.read_text(encoding="utf-8"))
            return _normalize_templates(parsed)
        return deepcopy(DEFAULT_TEMPLATES)
    except Exception:
        return deepcopy(DEFAULT_TEMPLATES)
|
||||
|
||||
|
||||
def save_templates(path: str | Path, templates: dict) -> None:
    """Persist *templates* (normalized) to *path* as pretty-printed JSON.

    Writes to a sibling ``.tmp`` file first and renames it into place so a
    crash mid-write never leaves a truncated templates file behind.
    """
    target = Path(path)
    payload = json.dumps(_normalize_templates(templates), indent=2, sort_keys=True) + "\n"

    target.parent.mkdir(parents=True, exist_ok=True)
    scratch = target.with_suffix(target.suffix + ".tmp")
    scratch.write_text(payload, encoding="utf-8")
    scratch.replace(target)
|
||||
|
||||
|
||||
def parse_color(value: str) -> int:
    """Parse color from '#RRGGBB', 'RRGGBB', '0xRRGGBB', or decimal."""
    text = (value or "").strip().lower()
    if not text:
        raise ValueError("color is required")

    # A leading '#' is purely cosmetic; the digits behind it are hex.
    text = text.removeprefix("#")

    if text.startswith("0x"):
        color = int(text[2:], 16)
    elif text.isdigit():
        # All-decimal input is treated as a plain base-10 number.
        color = int(text, 10)
    else:
        color = int(text, 16)

    if color < 0 or color > 0xFFFFFF:
        raise ValueError("color must be between 0 and 0xFFFFFF")
    return color
|
||||
@@ -0,0 +1,94 @@
|
||||
import pytz
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
||||
"""Get timezone info by name."""
|
||||
return next((tz for tz in timezones if tz["zone_name"] == tz_name), None)
|
||||
|
||||
|
||||
def get_country_info(country_code: str, countries: list[dict]) -> dict | None:
|
||||
"""Get country info by country code."""
|
||||
return next((c for c in countries if c["country_code"] == country_code), None)
|
||||
|
||||
|
||||
def where_is_it_420(timezones: list[dict], countries: list[dict]) -> list[str]:
    """Get timezones where the current hour is 4 or 16, indicating it's 4:20 there.

    Returns:
        list[str]: A list of timezones where it's currently 4:20 PM or AM.
    """
    matches: list[str] = []
    for zone in pytz.all_timezones:
        local_now = datetime.now(pytz.timezone(zone))
        if local_now.hour not in (4, 16):
            continue
        # Find the timezone in the loaded timezones
        zone_row = get_tz_info(zone, timezones)
        if not zone_row:
            continue
        country = get_country_info(zone_row["country_code"], countries)
        if not country:
            continue
        # CSV country names may carry quotes/whitespace; de-duplicate.
        name = country["country_name"].strip().strip('"')
        if name not in matches:
            matches.append(name)
    return matches
|
||||
|
||||
|
||||
def load_tz_file(timezone_file: str = "./tzdb/TimeZoneDB.csv/time_zone.csv"):
    """Load the TimeZoneDB zone CSV into a (zone_name, country_code) frame.

    Args:
        timezone_file: Path to the TimeZoneDB ``time_zone.csv`` export.
            Defaults to the bundled location so existing callers keep
            working unchanged.

    Returns:
        pd.DataFrame: One row per distinct ``zone_name`` with its
        ``country_code``, index reset to 0..n-1.
    """
    # Column layout of the raw CSV (the file ships without a header row).
    timezone_names = ["zone_name", "country_code",
                      "abbreviation", "time_start", "gmt_offset", "dst"]
    # Only these columns are needed downstream.
    load_columns = ["zone_name", "country_code"]

    df = pd.read_csv(timezone_file, names=timezone_names)
    df = df[load_columns]
    # A zone appears once per historical offset change; keep the first row.
    df = df.drop_duplicates(subset=["zone_name"])

    return df.reset_index(drop=True)
|
||||
|
||||
|
||||
def main():
    """Compare TimeZoneDB zones against pytz and report per-region overlap."""
    # read csv with pandas
    csv_zones = load_tz_file()

    # Split "Region/City" into two columns; n=1 keeps multi-part cities
    # like "Argentina/Buenos_Aires" intact.
    csv_zones[['region', 'city']] = csv_zones['zone_name'].str.split(
        '/', expand=True, n=1)
    # drop regions with no country_code (like Etc, GMT, etc)
    csv_zones = csv_zones[csv_zones['country_code'].notna()]

    pytz_zones = pd.DataFrame(pytz.all_timezones).rename(columns={0: 'zone_name'})
    pytz_zones[['region', 'city']] = pytz_zones['zone_name'].str.split(
        '/', expand=True, n=1)
    # drop regions with no city (like UTC, GMT, etc)
    pytz_zones = pytz_zones[pytz_zones['city'].notna()]
    # drop rows where region is 'Etc'
    pytz_zones = pytz_zones[pytz_zones['region'] != 'Etc']

    # join dataframes on region and city
    merged = pd.merge(csv_zones, pytz_zones, on=['region', 'city'],
                      how='inner', indicator=True)
    # reorder columns (also drops the merge indicator and zone_name copies)
    merged = merged[['region', 'city', 'country_code']]

    print(f"Merged timezones: {len(merged)}")
    print(merged.sample(20).to_string(index=False))
    for region in merged['region'].unique():
        in_region = merged[merged['region'] == region]
        print(f"{len(in_region)} merged in {region}")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,65 @@
|
||||
from datetime import datetime
|
||||
|
||||
from dashboard import create_app
|
||||
|
||||
|
||||
def test_dashboard_renders_locations():
    """Index page renders the app title and one <li> per location."""
    snapshot = {
        "running": True,
        "started_at": datetime(2025, 1, 1, 0, 0, 0),
        "last_type": "420",
        "last_attempt_at": datetime(2025, 1, 1, 0, 1, 2),
        "last_success_at": datetime(2025, 1, 1, 0, 1, 3),
        "last_status_code": 204,
        "last_error": None,
        "last_locations": ["Nowhere", "Somewhere"],
    }
    upcoming = {"type": "reminder", "at": datetime(2025, 1, 1, 0, 15, 0)}

    app = create_app(get_state=lambda: snapshot, get_next_event=lambda: upcoming)

    page = app.test_client().get("/")
    assert page.status_code == 200
    html = page.data.decode("utf-8")
    for fragment in ("thc-webhook", "Locations",
                     "<li>Nowhere</li>", "<li>Somewhere</li>"):
        assert fragment in html
|
||||
|
||||
|
||||
def test_admin_roundtrip(tmp_path, monkeypatch):
    """Admin page renders, and a valid POST persists templates to disk."""
    templates_path = tmp_path / "templates.json"
    monkeypatch.setenv("TEMPLATES_PATH", str(templates_path))

    app = create_app(
        get_state=lambda: {},
        get_next_event=lambda: {"type": "reminder", "at": datetime(2025, 1, 1, 0, 15, 0)},
    )
    app.config["TEMPLATES_PATH"] = str(templates_path)
    client = app.test_client()

    # GET admin should render
    assert client.get("/admin").status_code == 200

    # POST should save; build <key>__<field> form entries per template.
    values = {
        "reminder": ("R", "#e67e22", ""),
        "420": ("B", "0x2ecc71", "http://example.com/img.png"),
        "reminder_halftime": ("H", "15158306", ""),
        "halftime": ("HT", "3066993", ""),
    }
    form = {}
    for key, (text, color, image_url) in values.items():
        form[f"{key}__text"] = text
        form[f"{key}__color"] = color
        form[f"{key}__image_url"] = image_url

    resp = client.post("/admin", data=form)
    assert resp.status_code == 200
    assert templates_path.exists()
|
||||
@@ -54,13 +54,13 @@ def test_create_embed_all_types(monkeypatch):
|
||||
assert 'Half-time in 5 minutes' in emb['description']
|
||||
|
||||
# halftime (should include image)
|
||||
monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs: [])
|
||||
monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs, **kwargs: [])
|
||||
emb = main.create_embed('halftime')
|
||||
assert emb['title'] == 'Halftime'
|
||||
assert emb['image'] is not None
|
||||
|
||||
# 420 (should include image and appended tz info string when list empty)
|
||||
monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs: [])
|
||||
monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs, **kwargs: [])
|
||||
emb = main.create_embed('420')
|
||||
assert emb['title'] == '420'
|
||||
assert emb['image'] is not None
|
||||
@@ -98,6 +98,8 @@ def test_where_is_it_420(monkeypatch):
|
||||
def test_main_exits_quickly(monkeypatch):
|
||||
# Patch send_notification so it doesn't perform network
|
||||
monkeypatch.setattr(main, 'send_notification', lambda x: None)
|
||||
# Don't start dashboard during this test
|
||||
monkeypatch.setattr(main, 'start_dashboard', lambda: None)
|
||||
# Make schedule.run_pending raise KeyboardInterrupt to exit loop
|
||||
monkeypatch.setattr(main.schedule, 'run_pending', lambda: (
|
||||
_ for _ in ()).throw(KeyboardInterrupt()))
|
||||
@@ -109,3 +111,26 @@ def test_main_exits_quickly(monkeypatch):
|
||||
|
||||
# Should exit quickly due to KeyboardInterrupt from run_pending
|
||||
main.main()
|
||||
|
||||
|
||||
def test_get_next_scheduled_event():
    """Next event after :14 is :15; after :50 it wraps to next hour's :15."""
    cases = [
        (main.datetime(2025, 1, 1, 10, 14, 30), 10, 15),
        (main.datetime(2025, 1, 1, 10, 50, 1), 11, 15),
    ]
    for now, hour, minute in cases:
        event = main.get_next_scheduled_event(now)
        assert event["type"] == "reminder"
        assert (event["at"].hour, event["at"].minute) == (hour, minute)
|
||||
|
||||
|
||||
def test_split_tz_name():
    """Zones split on the first '/' only; bare zones get an empty city."""
    expected = {
        "America/New_York": ("America", "New_York"),
        "America/Argentina/Buenos_Aires": ("America", "Argentina/Buenos_Aires"),
        "UTC": ("UTC", ""),
    }
    for zone, parts in expected.items():
        assert main.split_tz_name(zone) == parts
|
||||
@@ -0,0 +1,27 @@
|
||||
from templates import DEFAULT_TEMPLATES, load_templates, parse_color, save_templates
|
||||
|
||||
|
||||
def test_parse_color():
    """All supported hex spellings resolve to the same int; decimal passes through."""
    for spelling in ("#e67e22", "e67e22", "0xe67e22"):
        assert parse_color(spelling) == 0xE67E22
    assert parse_color("15158306") == 15158306
|
||||
|
||||
|
||||
def test_load_templates_missing_file(tmp_path):
    """A nonexistent file yields the built-in defaults."""
    loaded = load_templates(tmp_path / "missing.json")
    assert loaded["420"]["text"] == DEFAULT_TEMPLATES["420"]["text"]
|
||||
|
||||
|
||||
def test_save_and_load_templates_roundtrip(tmp_path):
    """Saved overrides survive a reload; untouched keys fall back to defaults."""
    target = tmp_path / "templates.json"
    save_templates(target, {"420": {"text": "Custom", "color": 123}})

    reloaded = load_templates(target)
    assert reloaded["420"]["text"] == "Custom"
    assert reloaded["420"]["color"] == 123
    # defaults should still exist for other keys
    assert "reminder" in reloaded
|
||||
Reference in New Issue
Block a user