Compare commits
12 Commits
1dcd80a8bb
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 3412a5ccaa | |||
| 915c55d7ed | |||
| 788f3ea6b7 | |||
| 0952c21c7b | |||
| f88f60a019 | |||
| 565c4078bb | |||
| 8f8c3655db | |||
| 584231b0df | |||
| a4c92470b6 | |||
| 01d94376d4 | |||
| 3f5630da2c | |||
| 6bae1e2f66 |
@@ -1,2 +1,11 @@
|
|||||||
# Replace with your actual Discord webhook URL
|
# Replace with your actual Discord webhook URL
|
||||||
DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/<your_webhook_id>/<your_webhook_token>
|
DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/<your_webhook_id>/<your_webhook_token>
|
||||||
|
|
||||||
|
# Replace with your Discord bot token
|
||||||
|
DISCORD_BOT_TOKEN=<your_bot_token>
|
||||||
|
|
||||||
|
# Replace with your Discord channel ID
|
||||||
|
DISCORD_CHANNEL_ID=<your_channel_id>
|
||||||
|
|
||||||
|
# Replace with your Discord guild/server ID
|
||||||
|
DISCORD_GUILD_ID=<your_guild_id>
|
||||||
|
|||||||
@@ -53,15 +53,3 @@ jobs:
|
|||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
tags: git.allucanget.biz/${{ secrets.REGISTRY_USERNAME }}/thc-webhook:latest
|
tags: git.allucanget.biz/${{ secrets.REGISTRY_USERNAME }}/thc-webhook:latest
|
||||||
|
|
||||||
- name: Deploy to Portainer
|
|
||||||
uses: appleboy/ssh-action@v0.1.7
|
|
||||||
with:
|
|
||||||
host: ${{ secrets.SERVER_HOST }}
|
|
||||||
username: ${{ secrets.SERVER_USER }}
|
|
||||||
key: ${{ secrets.SERVER_SSH_KEY }}
|
|
||||||
script: |
|
|
||||||
docker stop thc-webhook || true
|
|
||||||
docker rm thc-webhook || true
|
|
||||||
docker pull git.allucanget.biz/${{ secrets.REGISTRY_USERNAME }}/thc-webhook:latest
|
|
||||||
docker run -d --name thc-webhook -e DISCORD_WEBHOOK_URL=${{ secrets.DISCORD_WEBHOOK_URL }} -p 8080:8080 git.allucanget.biz/${{ secrets.REGISTRY_USERNAME }}/thc-webhook:latest
|
|
||||||
|
|||||||
+1
-1
@@ -12,7 +12,7 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Dashboard (Flask) port
|
# Dashboard (Flask) port
|
||||||
EXPOSE 8080
|
EXPOSE 8420
|
||||||
|
|
||||||
# Create a non-root user
|
# Create a non-root user
|
||||||
RUN addgroup -S appgroup && adduser -S appuser -G appgroup
|
RUN addgroup -S appgroup && adduser -S appuser -G appgroup
|
||||||
|
|||||||
@@ -38,13 +38,13 @@ The app will run continuously and send notifications at the scheduled times.
|
|||||||
|
|
||||||
### Dashboard
|
### Dashboard
|
||||||
|
|
||||||
By default, a minimal dashboard is available at `http://localhost:8080/`.
|
By default, a minimal dashboard is available at `http://localhost:8420/`.
|
||||||
|
|
||||||
You can disable it by setting `DASHBOARD_ENABLED=0`.
|
You can disable it by setting `DASHBOARD_ENABLED=0`.
|
||||||
|
|
||||||
### Admin
|
### Admin
|
||||||
|
|
||||||
You can edit the embed message templates at `http://localhost:8080/admin`.
|
You can edit the embed message templates at `http://localhost:8420/admin`.
|
||||||
|
|
||||||
- Templates are saved to `templates.json` by default.
|
- Templates are saved to `templates.json` by default.
|
||||||
- Override the location with `TEMPLATES_PATH=/path/to/templates.json`.
|
- Override the location with `TEMPLATES_PATH=/path/to/templates.json`.
|
||||||
|
|||||||
+38
-24
@@ -17,6 +17,32 @@ def _fmt_dt(dt: datetime | None) -> str:
|
|||||||
return str(dt)
|
return str(dt)
|
||||||
|
|
||||||
|
|
||||||
|
HTML_TEMPLATE = (
|
||||||
|
"<!doctype html><html><head><meta charset='utf-8'><title>thc-webhook admin</title>"
|
||||||
|
"<style>body{{font-family:sans-serif;max-width:900px;margin:24px;}}textarea,input{{font-family:inherit;}}</style>"
|
||||||
|
"</head><body>"
|
||||||
|
"<h1>Admin: templates</h1>"
|
||||||
|
"<p>Edits are saved to the templates JSON file on disk.</p>"
|
||||||
|
"{content}"
|
||||||
|
"</body></html>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_html_template(content) -> str:
|
||||||
|
return HTML_TEMPLATE.format(content=content)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_hex_color(value: int | str | None) -> str:
|
||||||
|
if isinstance(value, int):
|
||||||
|
return f"#{value:06X}"
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return f"#{parse_color(value):06X}"
|
||||||
|
except ValueError:
|
||||||
|
return "#000000"
|
||||||
|
return "#000000"
|
||||||
|
|
||||||
|
|
||||||
def create_app(
|
def create_app(
|
||||||
*,
|
*,
|
||||||
get_state: Callable[[], dict],
|
get_state: Callable[[], dict],
|
||||||
@@ -32,11 +58,7 @@ def create_app(
|
|||||||
locations = state.get("last_locations") or []
|
locations = state.get("last_locations") or []
|
||||||
locations_html = "".join(f"<li>{loc}</li>" for loc in locations)
|
locations_html = "".join(f"<li>{loc}</li>" for loc in locations)
|
||||||
|
|
||||||
return (
|
return get_html_template(
|
||||||
"<!doctype html>"
|
|
||||||
"<html><head><meta charset='utf-8'><title>thc-webhook</title>"
|
|
||||||
"<style>body{font-family:sans-serif;max-width:900px;margin:24px;}</style>"
|
|
||||||
"</head><body>"
|
|
||||||
"<h1>thc-webhook</h1>"
|
"<h1>thc-webhook</h1>"
|
||||||
"<h2>Status</h2>"
|
"<h2>Status</h2>"
|
||||||
"<ul>"
|
"<ul>"
|
||||||
@@ -56,7 +78,6 @@ def create_app(
|
|||||||
"<h2>Locations (latest)</h2>"
|
"<h2>Locations (latest)</h2>"
|
||||||
f"<p>Total: {len(locations)}</p>"
|
f"<p>Total: {len(locations)}</p>"
|
||||||
f"<ul>{locations_html}</ul>"
|
f"<ul>{locations_html}</ul>"
|
||||||
"</body></html>"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/admin")
|
@app.get("/admin")
|
||||||
@@ -67,30 +88,26 @@ def create_app(
|
|||||||
blocks = []
|
blocks = []
|
||||||
for key, tpl in templates.items():
|
for key, tpl in templates.items():
|
||||||
text = (tpl.get("text") or "").replace("'", "'")
|
text = (tpl.get("text") or "").replace("'", "'")
|
||||||
color = tpl.get("color")
|
color_hex = _as_hex_color(tpl.get("color"))
|
||||||
image_url = tpl.get("image_url") or ""
|
image_url = tpl.get("image_url") or ""
|
||||||
blocks.append(
|
blocks.append(
|
||||||
"<fieldset style='margin-bottom:16px;'>"
|
"<fieldset style='margin-bottom:16px;'>"
|
||||||
f"<legend><strong>{key}</strong></legend>"
|
f"<legend><strong>{key}</strong></legend>"
|
||||||
f"<label>Text<br><textarea name='{key}__text' rows='3' style='width:100%'>{text}</textarea></label><br>"
|
f"<label>Text<br><textarea name='{key}__text' rows='3' style='width:100%'>{text}</textarea></label><br>"
|
||||||
f"<label>Color<br><input name='{key}__color' value='{color}' style='width:200px'></label><br>"
|
f"<label>Color<br><input type='color' name='{key}__color_picker' value='{color_hex}' oninput=\"this.form['{key}__color'].value=this.value\"></label><br>"
|
||||||
|
f"<label>Color value<br><input name='{key}__color' value='{color_hex}' style='width:200px'></label><br>"
|
||||||
f"<label>Image URL (optional)<br><input name='{key}__image_url' value='{image_url}' style='width:100%'></label>"
|
f"<label>Image URL (optional)<br><input name='{key}__image_url' value='{image_url}' style='width:100%'></label>"
|
||||||
"</fieldset>"
|
"</fieldset>"
|
||||||
)
|
)
|
||||||
|
|
||||||
blocks_html = "".join(blocks)
|
blocks_html = "".join(blocks)
|
||||||
return (
|
return get_html_template(
|
||||||
"<!doctype html>"
|
|
||||||
"<html><head><meta charset='utf-8'><title>thc-webhook admin</title>"
|
|
||||||
"<style>body{font-family:sans-serif;max-width:900px;margin:24px;}textarea,input{font-family:inherit;}</style>"
|
|
||||||
"</head><body>"
|
|
||||||
"<h1>Admin: templates</h1>"
|
"<h1>Admin: templates</h1>"
|
||||||
"<p>Edits are saved to the templates JSON file on disk.</p>"
|
"<p>Edits are saved to the templates JSON file on disk.</p>"
|
||||||
"<form method='post'>"
|
"<form method='post'>"
|
||||||
f"{blocks_html}"
|
f"{blocks_html}"
|
||||||
"<button type='submit'>Save</button>"
|
"<button type='submit'>Save</button>"
|
||||||
"</form>"
|
"</form>"
|
||||||
"</body></html>"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.post("/admin")
|
@app.post("/admin")
|
||||||
@@ -103,6 +120,9 @@ def create_app(
|
|||||||
for key in current.keys():
|
for key in current.keys():
|
||||||
text = request.form.get(f"{key}__text", "").strip()
|
text = request.form.get(f"{key}__text", "").strip()
|
||||||
color_raw = request.form.get(f"{key}__color", "").strip()
|
color_raw = request.form.get(f"{key}__color", "").strip()
|
||||||
|
if not color_raw:
|
||||||
|
color_raw = request.form.get(
|
||||||
|
f"{key}__color_picker", "").strip()
|
||||||
image_url = request.form.get(f"{key}__image_url", "").strip()
|
image_url = request.form.get(f"{key}__image_url", "").strip()
|
||||||
|
|
||||||
if not text:
|
if not text:
|
||||||
@@ -121,23 +141,17 @@ def create_app(
|
|||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
msg = "<br>".join(errors)
|
msg = "<br>".join(errors)
|
||||||
return (
|
return get_html_template(
|
||||||
"<!doctype html><html><body>"
|
|
||||||
"<h1>Admin: templates</h1>"
|
"<h1>Admin: templates</h1>"
|
||||||
f"<p style='color:#b00;'>Errors:<br>{msg}</p>"
|
f"<p style='color:#b00;'>Errors:<br>{msg}</p>"
|
||||||
"<p><a href='/admin'>Back</a></p>"
|
"<p><a href='/admin'>Back</a></p>"
|
||||||
"</body></html>",
|
), 400
|
||||||
400,
|
|
||||||
)
|
|
||||||
|
|
||||||
save_templates(templates_path, updated)
|
save_templates(templates_path, updated)
|
||||||
return (
|
return get_html_template(
|
||||||
"<!doctype html><html><body>"
|
|
||||||
"<h1>Admin: templates</h1>"
|
"<h1>Admin: templates</h1>"
|
||||||
"<p>Saved.</p>"
|
"<p>Saved.</p>"
|
||||||
"<p><a href='/admin'>Back</a></p>"
|
"<p><a href='/admin'>Back</a></p>"
|
||||||
"</body></html>",
|
), 200
|
||||||
200,
|
|
||||||
)
|
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|||||||
+2
-2
@@ -4,10 +4,10 @@ services:
|
|||||||
container_name: thc-webhook-app
|
container_name: thc-webhook-app
|
||||||
environment:
|
environment:
|
||||||
- DISCORD_WEBHOOK_URL=${DISCORD_WEBHOOK_URL}
|
- DISCORD_WEBHOOK_URL=${DISCORD_WEBHOOK_URL}
|
||||||
- DASHBOARD_PORT=8080
|
- DASHBOARD_PORT=8420
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
ports:
|
ports:
|
||||||
- "8080:8080"
|
- "8420:8420"
|
||||||
logging:
|
logging:
|
||||||
driver: json-file
|
driver: json-file
|
||||||
options:
|
options:
|
||||||
|
|||||||
@@ -1,19 +1,32 @@
|
|||||||
import os
|
import os
|
||||||
import pytz
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
|
||||||
import requests
|
import requests
|
||||||
import schedule
|
import schedule
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from templates import load_templates
|
from templates import load_templates
|
||||||
|
from dashboard import create_app
|
||||||
|
from maintenance import delete_old_messages
|
||||||
|
from thctime import (
|
||||||
|
get_country_info,
|
||||||
|
get_tz_info,
|
||||||
|
get_tzdb_cache,
|
||||||
|
init_tzdb_cache,
|
||||||
|
load_countries,
|
||||||
|
load_timezones,
|
||||||
|
split_tz_name,
|
||||||
|
where_is_it_420,
|
||||||
|
)
|
||||||
|
|
||||||
|
SCHEDULED_NOTIFICATIONS = [
|
||||||
TZDB_CACHE: dict | None = None
|
(15, "reminder"),
|
||||||
|
(20, "420"),
|
||||||
|
(45, "reminder_halftime"),
|
||||||
|
(50, "halftime"),
|
||||||
|
]
|
||||||
|
|
||||||
STATE_LOCK = threading.Lock()
|
STATE_LOCK = threading.Lock()
|
||||||
STATE: dict = {
|
STATE: dict = {
|
||||||
@@ -37,91 +50,16 @@ logging.basicConfig(
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL')
|
WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL')
|
||||||
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')
|
TEST_MESSAGE_DELETE_PATTERN = r"test notification"
|
||||||
DISCORD_CHANNEL_ID = os.getenv('DISCORD_CHANNEL_ID')
|
|
||||||
GUILD_ID = os.getenv('DISCORD_GUILD_ID')
|
|
||||||
|
|
||||||
|
|
||||||
def init_tzdb_cache() -> dict:
|
def get_state() -> dict:
|
||||||
"""Initialize a cached lookup structure for tzdb data.
|
with STATE_LOCK:
|
||||||
|
return dict(STATE)
|
||||||
This keeps the hourly scheduler fast by:
|
|
||||||
- Building O(1) maps (zone_name -> country_code, country_code -> country_name)
|
|
||||||
- Precomputing a list of tz names that exist in both tzdb CSVs and `pytz`
|
|
||||||
|
|
||||||
Note: this is intentionally NOT run at import time so tests can monkeypatch
|
|
||||||
`load_timezones`/`load_countries` without needing to reset global state.
|
|
||||||
"""
|
|
||||||
global TZDB_CACHE
|
|
||||||
if TZDB_CACHE is not None:
|
|
||||||
return TZDB_CACHE
|
|
||||||
|
|
||||||
timezones = load_timezones()
|
|
||||||
countries = load_countries()
|
|
||||||
|
|
||||||
tz_to_country_code: dict[str, str] = {}
|
|
||||||
tz_meta: dict[str, dict] = {}
|
|
||||||
for tz in timezones:
|
|
||||||
zone_name = tz.get("zone_name")
|
|
||||||
country_code = tz.get("country_code")
|
|
||||||
if not isinstance(zone_name, str) or not zone_name:
|
|
||||||
continue
|
|
||||||
if not isinstance(country_code, str) or not country_code:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tz_to_country_code[zone_name] = country_code
|
|
||||||
region, city = split_tz_name(zone_name)
|
|
||||||
tz_meta[zone_name] = {
|
|
||||||
"zone_name": zone_name,
|
|
||||||
"country_code": country_code,
|
|
||||||
"region": region,
|
|
||||||
"city": city,
|
|
||||||
}
|
|
||||||
|
|
||||||
country_code_to_name: dict[str, str] = {}
|
|
||||||
for c in countries:
|
|
||||||
code = c.get("country_code")
|
|
||||||
name = c.get("country_name")
|
|
||||||
if code and name:
|
|
||||||
country_code_to_name[code] = str(name).strip().strip('"')
|
|
||||||
|
|
||||||
# Attach resolved country names onto tz_meta (storage-only for now).
|
|
||||||
for zone_name, meta in tz_meta.items():
|
|
||||||
code = meta.get("country_code")
|
|
||||||
if isinstance(code, str):
|
|
||||||
meta["country_name"] = country_code_to_name.get(code)
|
|
||||||
|
|
||||||
# Vetted tz list: only names that are present in tzdb and loadable by zoneinfo.
|
|
||||||
# Installing the `tzdata` package keeps this mapping up-to-date.
|
|
||||||
tz_names: list[str] = []
|
|
||||||
for zone_name in tz_to_country_code.keys():
|
|
||||||
try:
|
|
||||||
ZoneInfo(zone_name)
|
|
||||||
except ZoneInfoNotFoundError:
|
|
||||||
continue
|
|
||||||
tz_names.append(zone_name)
|
|
||||||
|
|
||||||
TZDB_CACHE = {
|
|
||||||
"tz_to_country_code": tz_to_country_code,
|
|
||||||
"country_code_to_name": country_code_to_name,
|
|
||||||
"tz_names": tz_names,
|
|
||||||
"tz_meta": tz_meta,
|
|
||||||
}
|
|
||||||
return TZDB_CACHE
|
|
||||||
|
|
||||||
|
|
||||||
def split_tz_name(zone_name: str) -> tuple[str, str]:
|
def get_next_event() -> dict:
|
||||||
"""Split an IANA timezone name into (region, city).
|
return get_next_scheduled_event()
|
||||||
|
|
||||||
Examples:
|
|
||||||
- "America/New_York" -> ("America", "New_York")
|
|
||||||
- "America/Argentina/Buenos_Aires" -> ("America", "Argentina/Buenos_Aires")
|
|
||||||
- "UTC" -> ("UTC", "")
|
|
||||||
"""
|
|
||||||
if "/" not in zone_name:
|
|
||||||
return zone_name, ""
|
|
||||||
region, rest = zone_name.split("/", 1)
|
|
||||||
return region, rest
|
|
||||||
|
|
||||||
|
|
||||||
def _update_state(**updates) -> None:
|
def _update_state(**updates) -> None:
|
||||||
@@ -134,20 +72,52 @@ def get_state_snapshot() -> dict:
|
|||||||
return dict(STATE)
|
return dict(STATE)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_420(tz_list: list[str] | None = None) -> str:
|
||||||
|
cache = get_tzdb_cache()
|
||||||
|
timezones = load_timezones() if cache is None else []
|
||||||
|
countries = load_countries() if cache is None else []
|
||||||
|
if tz_list is None:
|
||||||
|
if cache is None:
|
||||||
|
tz_list = where_is_it_420(timezones, countries)
|
||||||
|
else:
|
||||||
|
tz_list = where_is_it_420(
|
||||||
|
timezones,
|
||||||
|
countries,
|
||||||
|
tz_names=cache.get("tz_names"),
|
||||||
|
tz_to_country_code=cache.get("tz_to_country_code"),
|
||||||
|
country_code_to_name=cache.get("country_code_to_name"),
|
||||||
|
)
|
||||||
|
if tz_list:
|
||||||
|
tz_str = "\n".join(tz_list)
|
||||||
|
return f"\nIt's 4:20 in:\n{tz_str}"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _update_420_cache() -> None:
|
||||||
|
try:
|
||||||
|
cache = get_tzdb_cache()
|
||||||
|
if cache is None:
|
||||||
|
tz_list = where_is_it_420(load_timezones(), load_countries())
|
||||||
|
else:
|
||||||
|
tz_list = where_is_it_420(
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
tz_names=cache.get("tz_names"),
|
||||||
|
tz_to_country_code=cache.get("tz_to_country_code"),
|
||||||
|
country_code_to_name=cache.get("country_code_to_name"),
|
||||||
|
)
|
||||||
|
_update_state(last_locations=tz_list or [])
|
||||||
|
except Exception as e:
|
||||||
|
_update_state(last_locations=[], last_error=str(e))
|
||||||
|
|
||||||
|
|
||||||
def get_next_scheduled_event(now: datetime | None = None) -> dict:
|
def get_next_scheduled_event(now: datetime | None = None) -> dict:
|
||||||
"""Return the next scheduled notification time/type based on known minute marks."""
|
"""Return the next scheduled notification time/type based on known minute marks."""
|
||||||
if now is None:
|
if now is None:
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
schedule_map = [
|
|
||||||
(15, "reminder"),
|
|
||||||
(20, "420"),
|
|
||||||
(45, "reminder_halftime"),
|
|
||||||
(50, "halftime"),
|
|
||||||
]
|
|
||||||
|
|
||||||
candidates: list[tuple[datetime, str]] = []
|
candidates: list[tuple[datetime, str]] = []
|
||||||
for minute, msg_type in schedule_map:
|
for minute, msg_type in SCHEDULED_NOTIFICATIONS:
|
||||||
candidate = now.replace(minute=minute, second=0, microsecond=0)
|
candidate = now.replace(minute=minute, second=0, microsecond=0)
|
||||||
if candidate > now:
|
if candidate > now:
|
||||||
candidates.append((candidate, msg_type))
|
candidates.append((candidate, msg_type))
|
||||||
@@ -155,60 +125,19 @@ def get_next_scheduled_event(now: datetime | None = None) -> dict:
|
|||||||
if not candidates:
|
if not candidates:
|
||||||
base = (now + timedelta(hours=1)).replace(minute=0,
|
base = (now + timedelta(hours=1)).replace(minute=0,
|
||||||
second=0, microsecond=0)
|
second=0, microsecond=0)
|
||||||
next_dt = base.replace(minute=schedule_map[0][0])
|
next_dt = base.replace(minute=SCHEDULED_NOTIFICATIONS[0][0])
|
||||||
return {"at": next_dt, "type": schedule_map[0][1]}
|
return {"at": next_dt, "type": SCHEDULED_NOTIFICATIONS[0][1]}
|
||||||
|
|
||||||
next_dt, next_type = min(candidates, key=lambda x: x[0])
|
next_dt, next_type = min(candidates, key=lambda x: x[0])
|
||||||
return {"at": next_dt, "type": next_type}
|
return {"at": next_dt, "type": next_type}
|
||||||
|
|
||||||
|
|
||||||
def load_timezones() -> list[dict]:
|
|
||||||
"""Load timezones from csv file."""
|
|
||||||
# Read the CSV file and return a list of timezones
|
|
||||||
with open("tzdb/TimeZoneDB.csv/time_zone.csv", "r", encoding="utf-8") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
# Fields: zone_name,country_code,abbreviation,time_start,gmt_offset,dst
|
|
||||||
timezones = []
|
|
||||||
for line in lines:
|
|
||||||
fields = line.strip().split(",")
|
|
||||||
if len(fields) >= 5:
|
|
||||||
timezones.append({
|
|
||||||
"zone_name": fields[0],
|
|
||||||
"country_code": fields[1],
|
|
||||||
"abbreviation": fields[2],
|
|
||||||
"time_start": fields[3],
|
|
||||||
"gmt_offset": int(fields[4]),
|
|
||||||
"dst": fields[5] == '1'
|
|
||||||
})
|
|
||||||
return timezones
|
|
||||||
|
|
||||||
|
|
||||||
def load_countries() -> list[dict]:
|
|
||||||
"""Load countries from csv file."""
|
|
||||||
# Read the CSV file and return a list of countries
|
|
||||||
with open("tzdb/TimeZoneDB.csv/country.csv", "r", encoding="utf-8") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
# Fields: country_code,country_name
|
|
||||||
countries = []
|
|
||||||
for line in lines:
|
|
||||||
fields = line.strip().split(",")
|
|
||||||
if len(fields) >= 2:
|
|
||||||
countries.append({
|
|
||||||
"country_code": fields[0],
|
|
||||||
"country_name": fields[1]
|
|
||||||
})
|
|
||||||
return countries
|
|
||||||
|
|
||||||
|
|
||||||
def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
|
def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Create a Discord embed message.
|
Create a Discord embed message.
|
||||||
"""
|
"""
|
||||||
templates_path = os.getenv("TEMPLATES_PATH", "templates.json")
|
templates_path = os.getenv("TEMPLATES_PATH", "templates.json")
|
||||||
messages = load_templates(templates_path)
|
messages = load_templates(templates_path)
|
||||||
cache = TZDB_CACHE
|
|
||||||
timezones = load_timezones() if cache is None else []
|
|
||||||
countries = load_countries() if cache is None else []
|
|
||||||
if type in messages:
|
if type in messages:
|
||||||
msg = messages[type]
|
msg = messages[type]
|
||||||
image_url = msg.get("image_url")
|
image_url = msg.get("image_url")
|
||||||
@@ -216,20 +145,7 @@ def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
|
|||||||
msg["image"] = {"url": image_url}
|
msg["image"] = {"url": image_url}
|
||||||
if type == "420":
|
if type == "420":
|
||||||
# Check where it's 4:20
|
# Check where it's 4:20
|
||||||
if tz_list is None:
|
msg["text"] += _check_420(tz_list)
|
||||||
if cache is None:
|
|
||||||
tz_list = where_is_it_420(timezones, countries)
|
|
||||||
else:
|
|
||||||
tz_list = where_is_it_420(
|
|
||||||
timezones,
|
|
||||||
countries,
|
|
||||||
tz_names=cache.get("tz_names"),
|
|
||||||
tz_to_country_code=cache.get("tz_to_country_code"),
|
|
||||||
country_code_to_name=cache.get("country_code_to_name"),
|
|
||||||
)
|
|
||||||
if tz_list:
|
|
||||||
tz_str = "\n".join(tz_list)
|
|
||||||
msg["text"] += f"\nIt's 4:20 in:\n{tz_str}"
|
|
||||||
else:
|
else:
|
||||||
msg = {"text": "Unknown notification type", "color": 0xFF0000}
|
msg = {"text": "Unknown notification type", "color": 0xFF0000}
|
||||||
embed = {
|
embed = {
|
||||||
@@ -243,348 +159,16 @@ def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
|
|||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
def get_discord_headers() -> dict[str, str]:
|
def send_notification(message: str) -> bool:
|
||||||
return {
|
|
||||||
"Authorization": f"Bot {DISCORD_BOT_TOKEN}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_message_timestamp(message: dict) -> datetime:
|
|
||||||
return datetime.fromisoformat(message["timestamp"].replace("Z", ""))
|
|
||||||
|
|
||||||
|
|
||||||
def build_delete_entry(message: dict) -> dict:
|
|
||||||
return {
|
|
||||||
"id": message.get("id"),
|
|
||||||
"timestamp": parse_message_timestamp(message),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_float(value: str | int | float | None) -> float | None:
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def should_delete_message(
|
|
||||||
message: dict,
|
|
||||||
webhook_id: str,
|
|
||||||
author_id: str,
|
|
||||||
cutoff: int,
|
|
||||||
) -> bool:
|
|
||||||
message_timestamp = int(parse_message_timestamp(message).timestamp())
|
|
||||||
return (
|
|
||||||
message_timestamp <= cutoff
|
|
||||||
and message.get("webhook_id") == webhook_id
|
|
||||||
and message.get("author", {}).get("id") == author_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_rate_limit_retry_after(response: requests.Response) -> float | None:
|
|
||||||
header_retry_after = parse_float(response.headers.get("Retry-After"))
|
|
||||||
if header_retry_after is not None:
|
|
||||||
return header_retry_after
|
|
||||||
|
|
||||||
reset_after = parse_float(response.headers.get("X-RateLimit-Reset-After"))
|
|
||||||
if reset_after is not None:
|
|
||||||
return reset_after
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = response.json()
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return parse_float(payload.get("retry_after"))
|
|
||||||
|
|
||||||
|
|
||||||
def get_bucket_exhausted_delay(response: requests.Response) -> float | None:
|
|
||||||
remaining = response.headers.get("X-RateLimit-Remaining")
|
|
||||||
if remaining != "0":
|
|
||||||
return None
|
|
||||||
return parse_float(response.headers.get("X-RateLimit-Reset-After"))
|
|
||||||
|
|
||||||
|
|
||||||
def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None:
|
|
||||||
if delay_seconds <= 0:
|
|
||||||
return
|
|
||||||
logging.info(
|
|
||||||
f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).")
|
|
||||||
time.sleep(delay_seconds)
|
|
||||||
|
|
||||||
|
|
||||||
def find_last_message_by_author(
|
|
||||||
headers: dict[str, str],
|
|
||||||
guild_id: str,
|
|
||||||
channel_id: str,
|
|
||||||
author_id: str,
|
|
||||||
) -> dict | None:
|
|
||||||
"""Find the newest indexed message for the author in the target guild channel."""
|
|
||||||
url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search"
|
|
||||||
params = [
|
|
||||||
("author_id", author_id),
|
|
||||||
("author_type", "webhook"),
|
|
||||||
("channel_id", channel_id),
|
|
||||||
("sort_by", "timestamp"),
|
|
||||||
("sort_order", "desc"),
|
|
||||||
("limit", "10"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for _ in range(3):
|
|
||||||
try:
|
|
||||||
response = requests.get(
|
|
||||||
url, headers=headers, params=params, timeout=10)
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logging.error(f"Error searching guild messages: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if response.status_code == 202:
|
|
||||||
payload = response.json()
|
|
||||||
retry_after = float(payload.get("retry_after", 1) or 1)
|
|
||||||
logging.info(
|
|
||||||
f"Guild search index not ready. Retrying after {retry_after} seconds."
|
|
||||||
)
|
|
||||||
time.sleep(retry_after)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if response.status_code == 429:
|
|
||||||
retry_after = get_rate_limit_retry_after(response) or 1.0
|
|
||||||
sleep_for_rate_limit(retry_after, "guild search")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
logging.error(
|
|
||||||
f"Failed to search guild messages: {response.status_code} - {response.text}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
message_groups = response.json().get("messages", [])
|
|
||||||
if not message_groups or not message_groups[0]:
|
|
||||||
return None
|
|
||||||
return message_groups[0][0]
|
|
||||||
|
|
||||||
logging.error("Guild search index did not become available in time.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_messages_to_delete(headers: dict, channel_id: str, webhook_id: str, author_id: str, cutoff: int, last_message_id: str | None = None) -> tuple[list[dict], str | None]:
|
|
||||||
"""
|
|
||||||
Fetch messages from the channel that are older than the cutoff timestamp and sent by the webhook.
|
|
||||||
Uses pagination with the 'before' parameter to resume from the last processed message.
|
|
||||||
Returns a tuple of (list of messages to delete, last message ID for pagination).
|
|
||||||
"""
|
|
||||||
url = f"https://discord.com/api/v10/channels/{channel_id}/messages"
|
|
||||||
params: dict[str, str | int] = {
|
|
||||||
"limit": 100, # Maximum allowed by Discord API
|
|
||||||
}
|
|
||||||
|
|
||||||
# Use the 'before' parameter for pagination if last_message_id is provided
|
|
||||||
if last_message_id:
|
|
||||||
params["before"] = last_message_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
for _ in range(3):
|
|
||||||
response = requests.get(url, headers=headers,
|
|
||||||
params=params, timeout=10)
|
|
||||||
|
|
||||||
if response.status_code == 429:
|
|
||||||
retry_after = get_rate_limit_retry_after(response) or 1.0
|
|
||||||
sleep_for_rate_limit(retry_after, "channel message fetch")
|
|
||||||
continue
|
|
||||||
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
logging.error(
|
|
||||||
"Failed to fetch messages after repeated rate limits.")
|
|
||||||
return [], last_message_id
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
messages = response.json()
|
|
||||||
delete_list = []
|
|
||||||
new_last_message_id = None
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
new_last_message_id = message.get("id")
|
|
||||||
|
|
||||||
if should_delete_message(
|
|
||||||
message,
|
|
||||||
webhook_id,
|
|
||||||
author_id,
|
|
||||||
cutoff,
|
|
||||||
):
|
|
||||||
delete_list.append(build_delete_entry(message))
|
|
||||||
|
|
||||||
# Limit the list to 100 items
|
|
||||||
if len(delete_list) >= 100:
|
|
||||||
break
|
|
||||||
|
|
||||||
return delete_list, new_last_message_id
|
|
||||||
else:
|
|
||||||
logging.error(
|
|
||||||
f"Failed to fetch messages: {response.status_code} - {response.text}")
|
|
||||||
return [], last_message_id
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logging.error(f"Error fetching messages: {e}")
|
|
||||||
return [], last_message_id
|
|
||||||
|
|
||||||
|
|
||||||
def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]:
|
|
||||||
"""
|
|
||||||
Delete a single message from the channel.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[bool, float | None, bool]:
|
|
||||||
- whether the delete succeeded
|
|
||||||
- how long to wait before the next request, if any
|
|
||||||
- whether to abort the batch because further requests would be invalid
|
|
||||||
"""
|
|
||||||
delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}"
|
|
||||||
delete_response = requests.delete(delete_url, headers=headers, timeout=10)
|
|
||||||
|
|
||||||
if delete_response.status_code == 204:
|
|
||||||
return True, get_bucket_exhausted_delay(delete_response), False
|
|
||||||
|
|
||||||
if delete_response.status_code == 429:
|
|
||||||
retry_after = get_rate_limit_retry_after(delete_response) or 1.0
|
|
||||||
scope = delete_response.headers.get("X-RateLimit-Scope", "unknown")
|
|
||||||
is_global = delete_response.headers.get(
|
|
||||||
"X-RateLimit-Global", "false").lower() == "true"
|
|
||||||
logging.warning(
|
|
||||||
"Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs",
|
|
||||||
message_id,
|
|
||||||
scope,
|
|
||||||
is_global,
|
|
||||||
retry_after,
|
|
||||||
)
|
|
||||||
return False, retry_after, False
|
|
||||||
|
|
||||||
if delete_response.status_code in {401, 403}:
|
|
||||||
logging.error(
|
|
||||||
"Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.",
|
|
||||||
message_id,
|
|
||||||
delete_response.status_code,
|
|
||||||
delete_response.text,
|
|
||||||
)
|
|
||||||
return False, None, True
|
|
||||||
|
|
||||||
logging.error(
|
|
||||||
f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}")
|
|
||||||
return False, None, False
|
|
||||||
|
|
||||||
|
|
||||||
def delete_old_messages(minutes: int = 6) -> None:
|
|
||||||
"""
|
|
||||||
Delete all messages sent by the webhook in the last `minutes` minutes.
|
|
||||||
Uses a dynamic slowdown to avoid hitting Discord API rate limits and pagination to fetch all messages.
|
|
||||||
"""
|
|
||||||
if not DISCORD_BOT_TOKEN or not DISCORD_CHANNEL_ID or not GUILD_ID:
|
|
||||||
logging.error(
|
|
||||||
"DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set")
|
|
||||||
return
|
|
||||||
|
|
||||||
headers = get_discord_headers()
|
|
||||||
|
|
||||||
# Calculate the time `minutes` minutes ago
|
|
||||||
cutoff_timestamp = datetime.now() - timedelta(minutes=minutes)
|
|
||||||
cutoff = int(cutoff_timestamp.timestamp())
|
|
||||||
webhook_id = '1413817504194760766'
|
|
||||||
author_id = '1413817504194760766'
|
|
||||||
|
|
||||||
last_author_message = find_last_message_by_author(
|
|
||||||
headers,
|
|
||||||
GUILD_ID,
|
|
||||||
DISCORD_CHANNEL_ID,
|
|
||||||
author_id,
|
|
||||||
)
|
|
||||||
if last_author_message is None:
|
|
||||||
logging.info("No indexed messages found for the target author.")
|
|
||||||
return
|
|
||||||
|
|
||||||
last_message_id = last_author_message.get("id")
|
|
||||||
if not last_message_id:
|
|
||||||
logging.info("Search result did not contain a message id.")
|
|
||||||
return
|
|
||||||
|
|
||||||
deleted_count = 0
|
|
||||||
|
|
||||||
if should_delete_message(
|
|
||||||
last_author_message,
|
|
||||||
webhook_id,
|
|
||||||
author_id,
|
|
||||||
cutoff,
|
|
||||||
):
|
|
||||||
anchor_message = build_delete_entry(last_author_message)
|
|
||||||
deleted, wait_seconds, abort_batch = delete_message(
|
|
||||||
headers,
|
|
||||||
DISCORD_CHANNEL_ID,
|
|
||||||
anchor_message["id"],
|
|
||||||
)
|
|
||||||
if deleted:
|
|
||||||
deleted_count += 1
|
|
||||||
logging.info(
|
|
||||||
f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}"
|
|
||||||
)
|
|
||||||
elif abort_batch:
|
|
||||||
return
|
|
||||||
|
|
||||||
if wait_seconds is not None:
|
|
||||||
sleep_for_rate_limit(wait_seconds, "delete bucket")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Fetch messages to delete with pagination
|
|
||||||
delete_list, next_last_message_id = fetch_messages_to_delete(
|
|
||||||
headers, DISCORD_CHANNEL_ID, webhook_id, author_id, cutoff, last_message_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stop scanning when a page has no eligible messages to delete.
|
|
||||||
# Continuing pagination here can walk indefinitely through history.
|
|
||||||
if not delete_list:
|
|
||||||
if deleted_count == 0:
|
|
||||||
logging.info("No messages to delete.")
|
|
||||||
else:
|
|
||||||
logging.info("No more messages to delete.")
|
|
||||||
break
|
|
||||||
|
|
||||||
for message in delete_list:
|
|
||||||
message_id = message["id"]
|
|
||||||
message_time = message["timestamp"]
|
|
||||||
|
|
||||||
deleted, wait_seconds, abort_batch = delete_message(
|
|
||||||
headers,
|
|
||||||
DISCORD_CHANNEL_ID,
|
|
||||||
message_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if deleted:
|
|
||||||
deleted_count += 1
|
|
||||||
logging.info(
|
|
||||||
f"Deleted message {message_id} from {message_time.isoformat()}")
|
|
||||||
elif abort_batch:
|
|
||||||
logging.info(
|
|
||||||
"Stopping delete batch after an invalid Discord response.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if wait_seconds is not None:
|
|
||||||
sleep_for_rate_limit(wait_seconds, "delete bucket")
|
|
||||||
|
|
||||||
if next_last_message_id is None or next_last_message_id == last_message_id:
|
|
||||||
break
|
|
||||||
last_message_id = next_last_message_id
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.")
|
|
||||||
|
|
||||||
|
|
||||||
def send_notification(message: str) -> None:
|
|
||||||
"""
|
"""
|
||||||
Send a notification to the Discord webhook.
|
Send a notification to the Discord webhook.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True when the webhook accepted the notification, False otherwise.
|
||||||
"""
|
"""
|
||||||
if not WEBHOOK_URL:
|
if not WEBHOOK_URL:
|
||||||
logging.error("WEBHOOK_URL not set")
|
logging.error("WEBHOOK_URL not set")
|
||||||
return
|
return False
|
||||||
|
|
||||||
_update_state(
|
_update_state(
|
||||||
last_type=message,
|
last_type=message,
|
||||||
@@ -602,21 +186,7 @@ def send_notification(message: str) -> None:
|
|||||||
|
|
||||||
tz_list: list[str] | None = None
|
tz_list: list[str] | None = None
|
||||||
if message == "420":
|
if message == "420":
|
||||||
try:
|
_update_420_cache()
|
||||||
cache = TZDB_CACHE
|
|
||||||
if cache is None:
|
|
||||||
tz_list = where_is_it_420(load_timezones(), load_countries())
|
|
||||||
else:
|
|
||||||
tz_list = where_is_it_420(
|
|
||||||
[],
|
|
||||||
[],
|
|
||||||
tz_names=cache.get("tz_names"),
|
|
||||||
tz_to_country_code=cache.get("tz_to_country_code"),
|
|
||||||
country_code_to_name=cache.get("country_code_to_name"),
|
|
||||||
)
|
|
||||||
_update_state(last_locations=tz_list)
|
|
||||||
except Exception as e:
|
|
||||||
_update_state(last_locations=[], last_error=str(e))
|
|
||||||
|
|
||||||
embed = create_embed(message, tz_list=tz_list)
|
embed = create_embed(message, tz_list=tz_list)
|
||||||
data = {"embeds": [embed]}
|
data = {"embeds": [embed]}
|
||||||
@@ -626,6 +196,7 @@ def send_notification(message: str) -> None:
|
|||||||
logging.info(f"Notification sent: {message}")
|
logging.info(f"Notification sent: {message}")
|
||||||
_update_state(last_success_at=datetime.now(),
|
_update_state(last_success_at=datetime.now(),
|
||||||
last_status_code=response.status_code)
|
last_status_code=response.status_code)
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
logging.error(
|
logging.error(
|
||||||
f"Failed to send notification: {response.status_code} - "
|
f"Failed to send notification: {response.status_code} - "
|
||||||
@@ -633,96 +204,61 @@ def send_notification(message: str) -> None:
|
|||||||
)
|
)
|
||||||
_update_state(last_status_code=response.status_code,
|
_update_state(last_status_code=response.status_code,
|
||||||
last_error=response.text)
|
last_error=response.text)
|
||||||
|
return False
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logging.error(f"Error sending notification: {e}")
|
logging.error(f"Error sending notification: {e}")
|
||||||
_update_state(last_error=str(e))
|
_update_state(last_error=str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
def _schedule_startup_test_cleanup(test_sent: bool) -> None:
|
||||||
"""Get timezone info by name."""
|
"""Schedule one-time cleanup for the startup test notification."""
|
||||||
return next((tz for tz in timezones if tz["zone_name"] == tz_name), None)
|
if not test_sent:
|
||||||
|
return
|
||||||
|
|
||||||
|
def cleanup_startup_test_message() -> schedule.CancelJob:
|
||||||
|
delete_old_messages(1, content_pattern=TEST_MESSAGE_DELETE_PATTERN)
|
||||||
|
return schedule.CancelJob
|
||||||
|
|
||||||
|
schedule.every(1).minutes.do(cleanup_startup_test_message)
|
||||||
|
|
||||||
|
|
||||||
def get_country_info(country_code: str, countries: list[dict]) -> dict | None:
|
def schedule_notification(interval: str, at: str, type: str) -> None:
|
||||||
"""Get country info by country code."""
|
"""Example: schedule.every().hour.at(":15").do(send_notification, "reminder")"""
|
||||||
return next((c for c in countries if c["country_code"] == country_code), None)
|
if interval == "hour":
|
||||||
|
schedule.every().hour.at(at).do(send_notification, type)
|
||||||
|
elif interval == "day":
|
||||||
|
schedule.every().day.at(at).do(send_notification, type)
|
||||||
|
else:
|
||||||
|
logging.error(f"Unsupported interval: {interval}")
|
||||||
|
|
||||||
|
|
||||||
def where_is_it_420(
|
def start_dashboard() -> None:
|
||||||
timezones: list[dict],
|
"""Compatibility hook for tests and optional dashboard startup."""
|
||||||
countries: list[dict],
|
app = create_app(get_state=get_state, get_next_event=get_next_event)
|
||||||
tz_names: list[str] | None = None,
|
app.run(host="0.0.0.0", port=8420, debug=False, use_reloader=False)
|
||||||
tz_to_country_code: dict[str, str] | None = None,
|
|
||||||
country_code_to_name: dict[str, str] | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Get timezones where the current hour is 4 or 16, indicating it's 4:20 there.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: A list of timezones where it's currently 4:20 PM or AM.
|
|
||||||
"""
|
|
||||||
# Build fast lookup dicts if not provided.
|
|
||||||
if tz_to_country_code is None:
|
|
||||||
tz_to_country_code = {}
|
|
||||||
for tz in timezones:
|
|
||||||
zone_name = tz.get("zone_name")
|
|
||||||
country_code = tz.get("country_code")
|
|
||||||
if isinstance(zone_name, str) and isinstance(country_code, str):
|
|
||||||
tz_to_country_code[zone_name] = country_code
|
|
||||||
|
|
||||||
if country_code_to_name is None:
|
|
||||||
country_code_to_name = {}
|
|
||||||
for c in countries:
|
|
||||||
code = c.get("country_code")
|
|
||||||
name = c.get("country_name")
|
|
||||||
if isinstance(code, str) and name is not None:
|
|
||||||
country_code_to_name[code] = str(name).strip().strip('"')
|
|
||||||
|
|
||||||
names_to_check = tz_names if tz_names is not None else pytz.all_timezones
|
|
||||||
results: list[str] = []
|
|
||||||
seen: set[str] = set()
|
|
||||||
|
|
||||||
for tz_name in names_to_check:
|
|
||||||
try:
|
|
||||||
tz_obj = pytz.timezone(tz_name)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
now = datetime.now(tz_obj)
|
|
||||||
if now.hour != 4 and now.hour != 16:
|
|
||||||
continue
|
|
||||||
|
|
||||||
country_code = tz_to_country_code.get(tz_name)
|
|
||||||
if not country_code:
|
|
||||||
continue
|
|
||||||
country_name = country_code_to_name.get(country_code)
|
|
||||||
if not country_name:
|
|
||||||
continue
|
|
||||||
if country_name in seen:
|
|
||||||
continue
|
|
||||||
|
|
||||||
seen.add(country_name)
|
|
||||||
results.append(country_name)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""
|
"""
|
||||||
Main function to run the scheduler.
|
Main function to run the scheduler.
|
||||||
"""
|
"""
|
||||||
# Schedule notifications
|
# Start the dashboard in a separate thread
|
||||||
schedule.every().hour.at(":15").do(send_notification, "reminder")
|
dashboard_thread = threading.Thread(target=start_dashboard, daemon=True)
|
||||||
schedule.every().hour.at(":20").do(send_notification, "420")
|
dashboard_thread.start()
|
||||||
# schedule.every().hour.at(":45").do(send_notification, "reminder_halftime")
|
|
||||||
# schedule.every().hour.at(":50").do(send_notification, "halftime")
|
# Schedule notifications based on the defined SCHEDULED_NOTIFICATIONS
|
||||||
|
for minute, msg_type in SCHEDULED_NOTIFICATIONS:
|
||||||
|
schedule_notification("hour", f":{minute:02d}", msg_type)
|
||||||
|
|
||||||
# Schedule deletion of old messages every 5 minutes
|
# Schedule deletion of old messages every 5 minutes
|
||||||
schedule.every(5).minutes.do(delete_old_messages, 6)
|
schedule.every(5).minutes.do(delete_old_messages, 6)
|
||||||
|
|
||||||
logging.info("Scheduler started.")
|
logging.info("Scheduler started.")
|
||||||
|
|
||||||
# Test the notification on startup
|
# Send one startup test message and cleanup only if send succeeded.
|
||||||
# send_notification("420")
|
test_sent = send_notification("test")
|
||||||
|
_schedule_startup_test_cleanup(test_sent)
|
||||||
|
|
||||||
# delete old messages on startup to clean up any previous notifications
|
# delete old messages on startup to clean up any previous notifications
|
||||||
# delete_old_messages(6)
|
# delete_old_messages(6)
|
||||||
|
|||||||
+392
@@ -0,0 +1,392 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
WEBHOOK_AUTHOR_ID = "1413817504194760766"
|
||||||
|
|
||||||
|
|
||||||
|
def get_discord_headers() -> dict[str, str]:
|
||||||
|
token = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bot {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_message_timestamp(message: dict) -> datetime:
|
||||||
|
return datetime.fromisoformat(message["timestamp"].replace("Z", ""))
|
||||||
|
|
||||||
|
|
||||||
|
def build_delete_entry(message: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"id": message.get("id"),
|
||||||
|
"timestamp": parse_message_timestamp(message),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_float(value: str | int | float | None) -> float | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def should_delete_message(
|
||||||
|
message: dict,
|
||||||
|
webhook_id: str,
|
||||||
|
author_id: str,
|
||||||
|
cutoff: int,
|
||||||
|
content_pattern: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
message_timestamp = int(parse_message_timestamp(message).timestamp())
|
||||||
|
return (
|
||||||
|
message_timestamp <= cutoff
|
||||||
|
and message.get("webhook_id") == webhook_id
|
||||||
|
and message.get("author", {}).get("id") == author_id
|
||||||
|
and message_matches_pattern(message, content_pattern)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def message_matches_pattern(message: dict, content_pattern: str | None = None) -> bool:
|
||||||
|
"""Return True when message content/embed text matches the optional pattern."""
|
||||||
|
if not content_pattern:
|
||||||
|
return True
|
||||||
|
|
||||||
|
text_chunks: list[str] = []
|
||||||
|
content = message.get("content")
|
||||||
|
if isinstance(content, str) and content:
|
||||||
|
text_chunks.append(content)
|
||||||
|
|
||||||
|
embeds = message.get("embeds")
|
||||||
|
if isinstance(embeds, list):
|
||||||
|
for embed in embeds:
|
||||||
|
if not isinstance(embed, dict):
|
||||||
|
continue
|
||||||
|
title = embed.get("title")
|
||||||
|
description = embed.get("description")
|
||||||
|
if isinstance(title, str) and title:
|
||||||
|
text_chunks.append(title)
|
||||||
|
if isinstance(description, str) and description:
|
||||||
|
text_chunks.append(description)
|
||||||
|
|
||||||
|
footer = embed.get("footer")
|
||||||
|
if isinstance(footer, dict):
|
||||||
|
footer_text = footer.get("text")
|
||||||
|
if isinstance(footer_text, str) and footer_text:
|
||||||
|
text_chunks.append(footer_text)
|
||||||
|
|
||||||
|
if not text_chunks:
|
||||||
|
return False
|
||||||
|
|
||||||
|
searchable_text = "\n".join(text_chunks)
|
||||||
|
try:
|
||||||
|
return re.search(content_pattern, searchable_text, flags=re.IGNORECASE) is not None
|
||||||
|
except re.error:
|
||||||
|
return content_pattern.lower() in searchable_text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rate_limit_retry_after(response: requests.Response) -> float | None:
|
||||||
|
header_retry_after = parse_float(response.headers.get("Retry-After"))
|
||||||
|
if header_retry_after is not None:
|
||||||
|
return header_retry_after
|
||||||
|
|
||||||
|
reset_after = parse_float(response.headers.get("X-RateLimit-Reset-After"))
|
||||||
|
if reset_after is not None:
|
||||||
|
return reset_after
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = response.json()
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return parse_float(payload.get("retry_after"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_bucket_exhausted_delay(response: requests.Response) -> float | None:
|
||||||
|
remaining = response.headers.get("X-RateLimit-Remaining")
|
||||||
|
if remaining != "0":
|
||||||
|
return None
|
||||||
|
return parse_float(response.headers.get("X-RateLimit-Reset-After"))
|
||||||
|
|
||||||
|
|
||||||
|
def sleep_for_rate_limit(delay_seconds: float, reason: str) -> None:
|
||||||
|
if delay_seconds <= 0:
|
||||||
|
return
|
||||||
|
logging.info(
|
||||||
|
f"Waiting {delay_seconds:.3f}s for Discord rate limit reset ({reason}).")
|
||||||
|
time.sleep(delay_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
def find_last_message_by_author(
|
||||||
|
headers: dict[str, str],
|
||||||
|
guild_id: str,
|
||||||
|
channel_id: str,
|
||||||
|
author_id: str,
|
||||||
|
) -> dict | None:
|
||||||
|
"""Find the newest indexed message for the author in the target guild channel."""
|
||||||
|
url = f"https://discord.com/api/v10/guilds/{guild_id}/messages/search"
|
||||||
|
params = [
|
||||||
|
("author_id", author_id),
|
||||||
|
("author_type", "webhook"),
|
||||||
|
("channel_id", channel_id),
|
||||||
|
("sort_by", "timestamp"),
|
||||||
|
("sort_order", "desc"),
|
||||||
|
("limit", "10"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
url, headers=headers, params=params, timeout=10)
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logging.error(f"Error searching guild messages: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if response.status_code == 202:
|
||||||
|
payload = response.json()
|
||||||
|
retry_after = float(payload.get("retry_after", 1) or 1)
|
||||||
|
logging.info(
|
||||||
|
f"Guild search index not ready. Retrying after {retry_after} seconds."
|
||||||
|
)
|
||||||
|
time.sleep(retry_after)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
retry_after = get_rate_limit_retry_after(response) or 1.0
|
||||||
|
sleep_for_rate_limit(retry_after, "guild search")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logging.error(
|
||||||
|
f"Failed to search guild messages: {response.status_code} - {response.text}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_groups = response.json().get("messages", [])
|
||||||
|
if not message_groups or not message_groups[0]:
|
||||||
|
return None
|
||||||
|
return message_groups[0][0]
|
||||||
|
|
||||||
|
logging.error("Guild search index did not become available in time.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_messages_to_delete(headers: dict, channel_id: str, webhook_id: str, author_id: str, cutoff: int, last_message_id: str | None = None, content_pattern: str | None = None) -> tuple[list[dict], str | None]:
|
||||||
|
"""
|
||||||
|
Fetch messages from the channel that are older than the cutoff timestamp and sent by the webhook.
|
||||||
|
Uses pagination with the 'before' parameter to resume from the last processed message.
|
||||||
|
Returns a tuple of (list of messages to delete, last message ID for pagination).
|
||||||
|
"""
|
||||||
|
url = f"https://discord.com/api/v10/channels/{channel_id}/messages"
|
||||||
|
params: dict[str, str | int] = {
|
||||||
|
"limit": 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
if last_message_id:
|
||||||
|
params["before"] = last_message_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
for _ in range(3):
|
||||||
|
response = requests.get(url, headers=headers,
|
||||||
|
params=params, timeout=10)
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
retry_after = get_rate_limit_retry_after(response) or 1.0
|
||||||
|
sleep_for_rate_limit(retry_after, "channel message fetch")
|
||||||
|
continue
|
||||||
|
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logging.error(
|
||||||
|
"Failed to fetch messages after repeated rate limits.")
|
||||||
|
return [], last_message_id
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
messages = response.json()
|
||||||
|
delete_list = []
|
||||||
|
new_last_message_id = None
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
new_last_message_id = message.get("id")
|
||||||
|
|
||||||
|
if should_delete_message(
|
||||||
|
message,
|
||||||
|
webhook_id,
|
||||||
|
author_id,
|
||||||
|
cutoff,
|
||||||
|
content_pattern,
|
||||||
|
):
|
||||||
|
delete_list.append(build_delete_entry(message))
|
||||||
|
|
||||||
|
if len(delete_list) >= 100:
|
||||||
|
break
|
||||||
|
|
||||||
|
return delete_list, new_last_message_id
|
||||||
|
|
||||||
|
logging.error(
|
||||||
|
f"Failed to fetch messages: {response.status_code} - {response.text}")
|
||||||
|
return [], last_message_id
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logging.error(f"Error fetching messages: {e}")
|
||||||
|
return [], last_message_id
|
||||||
|
|
||||||
|
|
||||||
|
def delete_message(headers: dict, channel_id: str, message_id: str) -> tuple[bool, float | None, bool]:
|
||||||
|
"""
|
||||||
|
Delete a single message from the channel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, float | None, bool]:
|
||||||
|
- whether the delete succeeded
|
||||||
|
- how long to wait before the next request, if any
|
||||||
|
- whether to abort the batch because further requests would be invalid
|
||||||
|
"""
|
||||||
|
delete_url = f"https://discord.com/api/v10/channels/{channel_id}/messages/{message_id}"
|
||||||
|
delete_response = requests.delete(delete_url, headers=headers, timeout=10)
|
||||||
|
|
||||||
|
if delete_response.status_code == 204:
|
||||||
|
return True, get_bucket_exhausted_delay(delete_response), False
|
||||||
|
|
||||||
|
if delete_response.status_code == 429:
|
||||||
|
retry_after = get_rate_limit_retry_after(delete_response) or 1.0
|
||||||
|
scope = delete_response.headers.get("X-RateLimit-Scope", "unknown")
|
||||||
|
is_global = delete_response.headers.get(
|
||||||
|
"X-RateLimit-Global", "false").lower() == "true"
|
||||||
|
logging.warning(
|
||||||
|
"Discord rate limit hit while deleting message %s: scope=%s global=%s retry_after=%.3fs",
|
||||||
|
message_id,
|
||||||
|
scope,
|
||||||
|
is_global,
|
||||||
|
retry_after,
|
||||||
|
)
|
||||||
|
return False, retry_after, False
|
||||||
|
|
||||||
|
if delete_response.status_code in {401, 403}:
|
||||||
|
logging.error(
|
||||||
|
"Failed to delete message %s: %s - %s. Stopping deletes to avoid invalid request spam.",
|
||||||
|
message_id,
|
||||||
|
delete_response.status_code,
|
||||||
|
delete_response.text,
|
||||||
|
)
|
||||||
|
return False, None, True
|
||||||
|
|
||||||
|
logging.error(
|
||||||
|
f"Failed to delete message {message_id}: {delete_response.status_code} - {delete_response.text}")
|
||||||
|
return False, None, False
|
||||||
|
|
||||||
|
|
||||||
|
def delete_old_messages(minutes: int = 6, content_pattern: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Delete all messages sent by the webhook in the last `minutes` minutes.
|
||||||
|
Uses a dynamic slowdown to avoid hitting Discord API rate limits and pagination to fetch all messages.
|
||||||
|
"""
|
||||||
|
discord_bot_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
|
discord_channel_id = os.getenv("DISCORD_CHANNEL_ID")
|
||||||
|
guild_id = os.getenv("DISCORD_GUILD_ID")
|
||||||
|
|
||||||
|
if not discord_bot_token or not discord_channel_id or not guild_id:
|
||||||
|
logging.error(
|
||||||
|
"DISCORD_BOT_TOKEN, DISCORD_CHANNEL_ID, or DISCORD_GUILD_ID not set")
|
||||||
|
return
|
||||||
|
|
||||||
|
headers = get_discord_headers()
|
||||||
|
|
||||||
|
cutoff_timestamp = datetime.now() - timedelta(minutes=minutes)
|
||||||
|
cutoff = int(cutoff_timestamp.timestamp())
|
||||||
|
webhook_id = WEBHOOK_AUTHOR_ID
|
||||||
|
author_id = WEBHOOK_AUTHOR_ID
|
||||||
|
|
||||||
|
last_author_message = find_last_message_by_author(
|
||||||
|
headers,
|
||||||
|
guild_id,
|
||||||
|
discord_channel_id,
|
||||||
|
author_id,
|
||||||
|
)
|
||||||
|
if last_author_message is None:
|
||||||
|
logging.info("No indexed messages found for the target author.")
|
||||||
|
return
|
||||||
|
|
||||||
|
last_message_id = last_author_message.get("id")
|
||||||
|
if not last_message_id:
|
||||||
|
logging.info("Search result did not contain a message id.")
|
||||||
|
return
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
if should_delete_message(
|
||||||
|
last_author_message,
|
||||||
|
webhook_id,
|
||||||
|
author_id,
|
||||||
|
cutoff,
|
||||||
|
content_pattern,
|
||||||
|
):
|
||||||
|
anchor_message = build_delete_entry(last_author_message)
|
||||||
|
deleted, wait_seconds, abort_batch = delete_message(
|
||||||
|
headers,
|
||||||
|
discord_channel_id,
|
||||||
|
anchor_message["id"],
|
||||||
|
)
|
||||||
|
if deleted:
|
||||||
|
deleted_count += 1
|
||||||
|
logging.info(
|
||||||
|
f"Deleted message {anchor_message['id']} from {anchor_message['timestamp'].isoformat()}"
|
||||||
|
)
|
||||||
|
elif abort_batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
if wait_seconds is not None:
|
||||||
|
sleep_for_rate_limit(wait_seconds, "delete bucket")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
delete_list, next_last_message_id = fetch_messages_to_delete(
|
||||||
|
headers,
|
||||||
|
discord_channel_id,
|
||||||
|
webhook_id,
|
||||||
|
author_id,
|
||||||
|
cutoff,
|
||||||
|
last_message_id,
|
||||||
|
content_pattern,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not delete_list:
|
||||||
|
if deleted_count == 0:
|
||||||
|
logging.info("No messages to delete.")
|
||||||
|
else:
|
||||||
|
logging.info("No more messages to delete.")
|
||||||
|
break
|
||||||
|
|
||||||
|
for message in delete_list:
|
||||||
|
message_id = message["id"]
|
||||||
|
message_time = message["timestamp"]
|
||||||
|
|
||||||
|
deleted, wait_seconds, abort_batch = delete_message(
|
||||||
|
headers,
|
||||||
|
discord_channel_id,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if deleted:
|
||||||
|
deleted_count += 1
|
||||||
|
logging.info(
|
||||||
|
f"Deleted message {message_id} from {message_time.isoformat()}")
|
||||||
|
elif abort_batch:
|
||||||
|
logging.info(
|
||||||
|
"Stopping delete batch after an invalid Discord response.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if wait_seconds is not None:
|
||||||
|
sleep_for_rate_limit(wait_seconds, "delete bucket")
|
||||||
|
|
||||||
|
if next_last_message_id is None or next_last_message_id == last_message_id:
|
||||||
|
break
|
||||||
|
last_message_id = next_last_message_id
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Deleted {deleted_count} messages older than {minutes} minutes sent by the webhook.")
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
[phases.install]
|
||||||
|
|
||||||
|
[phases.test]
|
||||||
|
dependsOn = ["install"]
|
||||||
|
cmds = ["pytest --maxfail=1 --disable-warnings -q"]
|
||||||
|
|
||||||
|
[start]
|
||||||
|
cmd = "python main.py"
|
||||||
|
|
||||||
|
[variables]
|
||||||
|
PORT = "8420"
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
numpy
|
|
||||||
pandas
|
|
||||||
pytest
|
pytest
|
||||||
python-dotenv
|
python-dotenv
|
||||||
pytz
|
pytz
|
||||||
|
|||||||
+8
-4
@@ -1,20 +1,24 @@
|
|||||||
{
|
{
|
||||||
"420": {
|
"420": {
|
||||||
"color": 3066993,
|
"color": "#2ECC71",
|
||||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||||
"text": "Blaze it!"
|
"text": "Blaze it!"
|
||||||
},
|
},
|
||||||
"halftime": {
|
"halftime": {
|
||||||
"color": 3066993,
|
"color": "#2ECC71",
|
||||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||||
"text": "Half-time!"
|
"text": "Half-time!"
|
||||||
},
|
},
|
||||||
"reminder": {
|
"reminder": {
|
||||||
"color": 15105570,
|
"color": "#E67E22",
|
||||||
"text": "This is your 5 minute reminder to 420!"
|
"text": "This is your 5 minute reminder to 420!"
|
||||||
},
|
},
|
||||||
"reminder_halftime": {
|
"reminder_halftime": {
|
||||||
"color": 15105570,
|
"color": "#E67E22",
|
||||||
"text": "Half-time in 5 minutes!"
|
"text": "Half-time in 5 minutes!"
|
||||||
|
},
|
||||||
|
"test": {
|
||||||
|
"color": "#3498DB",
|
||||||
|
"text": "This is a test notification."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+25
-9
@@ -6,12 +6,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
|
|
||||||
DEFAULT_TEMPLATES: dict[str, dict] = {
|
DEFAULT_TEMPLATES: dict[str, dict] = {
|
||||||
"reminder_halftime": {
|
"420": {
|
||||||
"text": "Half-time in 5 minutes!",
|
"text": "Blaze it!",
|
||||||
"color": 0xE67E22,
|
|
||||||
},
|
|
||||||
"halftime": {
|
|
||||||
"text": "Half-time!",
|
|
||||||
"color": 0x2ECC71,
|
"color": 0x2ECC71,
|
||||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||||
},
|
},
|
||||||
@@ -19,11 +15,19 @@ DEFAULT_TEMPLATES: dict[str, dict] = {
|
|||||||
"text": "This is your 5 minute reminder to 420!",
|
"text": "This is your 5 minute reminder to 420!",
|
||||||
"color": 0xE67E22,
|
"color": 0xE67E22,
|
||||||
},
|
},
|
||||||
"420": {
|
"halftime": {
|
||||||
"text": "Blaze it!",
|
"text": "Half-time!",
|
||||||
"color": 0x2ECC71,
|
"color": 0x2ECC71,
|
||||||
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
"image_url": "https://copyparty.allucanget.biz/img/weed.png",
|
||||||
},
|
},
|
||||||
|
"reminder_halftime": {
|
||||||
|
"text": "Half-time in 5 minutes!",
|
||||||
|
"color": 0xE67E22,
|
||||||
|
},
|
||||||
|
"test": {
|
||||||
|
"text": "This is a test notification.",
|
||||||
|
"color": 0x3498DB,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -45,6 +49,11 @@ def _normalize_templates(raw: dict) -> dict[str, dict]:
|
|||||||
color = incoming.get("color")
|
color = incoming.get("color")
|
||||||
if isinstance(color, int):
|
if isinstance(color, int):
|
||||||
out[key]["color"] = color
|
out[key]["color"] = color
|
||||||
|
elif isinstance(color, str):
|
||||||
|
try:
|
||||||
|
out[key]["color"] = parse_color(color)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
image_url = incoming.get("image_url")
|
image_url = incoming.get("image_url")
|
||||||
if isinstance(image_url, str) and image_url.strip():
|
if isinstance(image_url, str) and image_url.strip():
|
||||||
@@ -70,10 +79,17 @@ def load_templates(path: str | Path) -> dict[str, dict]:
|
|||||||
def save_templates(path: str | Path, templates: dict) -> None:
|
def save_templates(path: str | Path, templates: dict) -> None:
|
||||||
p = Path(path)
|
p = Path(path)
|
||||||
normalized = _normalize_templates(templates)
|
normalized = _normalize_templates(templates)
|
||||||
|
serialized = deepcopy(normalized)
|
||||||
|
|
||||||
|
for tpl in serialized.values():
|
||||||
|
color = tpl.get("color")
|
||||||
|
if isinstance(color, int):
|
||||||
|
tpl["color"] = f"#{color:06X}"
|
||||||
|
|
||||||
p.parent.mkdir(parents=True, exist_ok=True)
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
tmp = p.with_suffix(p.suffix + ".tmp")
|
tmp = p.with_suffix(p.suffix + ".tmp")
|
||||||
tmp.write_text(json.dumps(normalized, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
tmp.write_text(json.dumps(serialized, indent=2,
|
||||||
|
sort_keys=True) + "\n", encoding="utf-8")
|
||||||
tmp.replace(p)
|
tmp.replace(p)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+52
-30
@@ -1,6 +1,6 @@
|
|||||||
import pytz
|
import pytz
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import pandas as pd
|
from csv import DictReader
|
||||||
|
|
||||||
|
|
||||||
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
||||||
@@ -41,52 +41,74 @@ def load_tz_file():
|
|||||||
"abbreviation", "time_start", "gmt_offset", "dst"]
|
"abbreviation", "time_start", "gmt_offset", "dst"]
|
||||||
# columns to load
|
# columns to load
|
||||||
load_columns = ["zone_name", "country_code"]
|
load_columns = ["zone_name", "country_code"]
|
||||||
# read csv with pandas
|
# read csv
|
||||||
df = pd.read_csv(timezone_file, names=timezone_names)
|
with open(timezone_file, newline='') as csvfile:
|
||||||
|
reader = DictReader(csvfile, fieldnames=timezone_names)
|
||||||
|
csv = [row for row in reader]
|
||||||
# drop all columns except load_columns
|
# drop all columns except load_columns
|
||||||
df = df[load_columns]
|
csv = [{k: v for k, v in row.items() if k in load_columns} for row in csv]
|
||||||
# distinct zone_names
|
# distinct zone_names
|
||||||
df = df.drop_duplicates(subset=["zone_name"])
|
seen = set()
|
||||||
|
unique_csv = []
|
||||||
|
for row in csv:
|
||||||
|
if row["zone_name"] not in seen:
|
||||||
|
seen.add(row["zone_name"])
|
||||||
|
unique_csv.append(row)
|
||||||
|
csv = unique_csv
|
||||||
|
|
||||||
# reset index
|
return csv
|
||||||
df = df.reset_index(drop=True)
|
|
||||||
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
# read csv with pandas
|
# read csv file and load timezones and countries
|
||||||
df_file = load_tz_file()
|
csv = load_tz_file()
|
||||||
|
|
||||||
# split zone_name into components by "/"
|
# split zone_name into components by "/"
|
||||||
df_file[['region', 'city']] = df_file['zone_name'].str.split(
|
for row in csv:
|
||||||
'/', expand=True, n=1)
|
parts = row["zone_name"].split("/", 1)
|
||||||
|
row["region"] = parts[0]
|
||||||
|
row["city"] = parts[1] if len(parts) > 1 else None
|
||||||
# drop regions with no country_code (like Etc, GMT, etc)
|
# drop regions with no country_code (like Etc, GMT, etc)
|
||||||
df_file = df_file[df_file['country_code'].notna()]
|
csv = [row for row in csv if row["country_code"]]
|
||||||
|
|
||||||
|
# get all timezones from pytz and split into region and city
|
||||||
|
|
||||||
|
tz = [{"zone_name": tz} for tz in pytz.all_timezones]
|
||||||
|
|
||||||
df_tz = pd.DataFrame(pytz.all_timezones)
|
|
||||||
# rename column to zone_name
|
|
||||||
df_tz = df_tz.rename(columns={0: 'zone_name'})
|
|
||||||
# split zone_name into components by "/"
|
# split zone_name into components by "/"
|
||||||
df_tz[['region', 'city']] = df_tz['zone_name'].str.split(
|
for row in tz:
|
||||||
'/', expand=True, n=1)
|
parts = row["zone_name"].split("/", 1)
|
||||||
|
row["region"] = parts[0]
|
||||||
|
row["city"] = parts[1] if len(parts) > 1 else None
|
||||||
# drop regions with no city (like UTC, GMT, etc)
|
# drop regions with no city (like UTC, GMT, etc)
|
||||||
df_tz = df_tz[df_tz['city'].notna()]
|
tz = [row for row in tz if row["city"]]
|
||||||
# drop rows where region is 'Etc'
|
# drop rows where region is 'Etc'
|
||||||
df_tz = df_tz[df_tz['region'] != 'Etc']
|
tz = [row for row in tz if row["region"] != "Etc"]
|
||||||
|
|
||||||
|
# join data on region and city
|
||||||
|
timezones = []
|
||||||
|
for tz_row in tz:
|
||||||
|
for csv_row in csv:
|
||||||
|
if tz_row["region"] == csv_row["region"] and tz_row["city"] == csv_row["city"]:
|
||||||
|
timezones.append({
|
||||||
|
"zone_name": tz_row["zone_name"],
|
||||||
|
"country_code": csv_row["country_code"],
|
||||||
|
"region": tz_row["region"],
|
||||||
|
"city": tz_row["city"],
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
# join dataframes on region and city
|
|
||||||
df_merged = pd.merge(df_file, df_tz, on=[
|
|
||||||
'region', 'city'], how='inner', indicator=True)
|
|
||||||
# reorder columns
|
# reorder columns
|
||||||
df_merged = df_merged[['region', 'city', 'country_code']]
|
timezones = [{k: row[k] for k in ['region', 'city', 'country_code']}
|
||||||
# print merged dataframe
|
for row in timezones]
|
||||||
print(f"Merged timezones: {len(df_merged)}")
|
|
||||||
print(df_merged.sample(20).to_string(index=False))
|
# print merged data
|
||||||
regions = df_merged['region'].unique()
|
print(f"Merged timezones: {len(timezones)}")
|
||||||
|
print(timezones[:20])
|
||||||
|
regions = set(row['region'] for row in timezones)
|
||||||
for region in regions:
|
for region in regions:
|
||||||
df_region = df_merged[df_merged['region'] == region]
|
df_region = [row for row in timezones if row['region'] == region]
|
||||||
print(f"{len(df_region)} merged in {region}")
|
print(f"{len(df_region)} merged in {region}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,145 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import dashboard
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app():
|
||||||
|
return dashboard.create_app(
|
||||||
|
get_state=lambda: {
|
||||||
|
"running": True,
|
||||||
|
"started_at": datetime(2026, 1, 1, 10, 0, 0),
|
||||||
|
"last_type": "420",
|
||||||
|
"last_attempt_at": datetime(2026, 1, 1, 10, 15, 0),
|
||||||
|
"last_success_at": datetime(2026, 1, 1, 10, 20, 0),
|
||||||
|
"last_status_code": 204,
|
||||||
|
"last_error": None,
|
||||||
|
"last_locations": ["Nowhere"],
|
||||||
|
},
|
||||||
|
get_next_event=lambda: {
|
||||||
|
"type": "reminder",
|
||||||
|
"at": datetime(2026, 1, 1, 11, 15, 0),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fmt_dt_none_and_datetime():
|
||||||
|
assert dashboard._fmt_dt(None) == "—"
|
||||||
|
assert dashboard._fmt_dt(
|
||||||
|
datetime(2026, 1, 1, 10, 0, 0)) == "2026-01-01T10:00:00"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_html_template_wraps_content():
|
||||||
|
html = dashboard.get_html_template("<p>hello</p>")
|
||||||
|
assert "<title>thc-webhook admin</title>" in html
|
||||||
|
assert "<p>hello</p>" in html
|
||||||
|
|
||||||
|
|
||||||
|
def test_index_route_renders_status_page():
|
||||||
|
app = _make_app()
|
||||||
|
client = app.test_client()
|
||||||
|
|
||||||
|
response = client.get("/")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.get_data(as_text=True)
|
||||||
|
assert "thc-webhook" in body
|
||||||
|
assert "last_type: 420" in body
|
||||||
|
assert "type: reminder" in body
|
||||||
|
assert "Nowhere" in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_get_renders_template_form(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
dashboard,
|
||||||
|
"load_templates",
|
||||||
|
lambda path: {
|
||||||
|
"420": {
|
||||||
|
"text": "Blaze",
|
||||||
|
"color": 3066993,
|
||||||
|
"image_url": "https://example.com/img.png",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
app = _make_app()
|
||||||
|
app.config["TEMPLATES_PATH"] = "templates.json"
|
||||||
|
client = app.test_client()
|
||||||
|
|
||||||
|
response = client.get("/admin")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.get_data(as_text=True)
|
||||||
|
assert "Admin: templates" in body
|
||||||
|
assert "name='420__text'" in body
|
||||||
|
assert "name='420__color'" in body
|
||||||
|
assert "name='420__color_picker'" in body
|
||||||
|
assert "type='color'" in body
|
||||||
|
assert "name='420__image_url'" in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_post_validation_error(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
dashboard,
|
||||||
|
"load_templates",
|
||||||
|
lambda path: {"420": {"text": "x", "color": 1}},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(dashboard, "parse_color", lambda raw: (
|
||||||
|
_ for _ in ()).throw(ValueError("bad color")))
|
||||||
|
|
||||||
|
save_called = {"value": False}
|
||||||
|
|
||||||
|
def _save_templates(path, updated):
|
||||||
|
save_called["value"] = True
|
||||||
|
|
||||||
|
monkeypatch.setattr(dashboard, "save_templates", _save_templates)
|
||||||
|
|
||||||
|
app = _make_app()
|
||||||
|
client = app.test_client()
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin",
|
||||||
|
data={
|
||||||
|
"420__text": "Updated",
|
||||||
|
"420__color": "bad",
|
||||||
|
"420__image_url": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "invalid color" in response.get_data(as_text=True)
|
||||||
|
assert save_called["value"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_post_success(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
dashboard,
|
||||||
|
"load_templates",
|
||||||
|
lambda path: {"420": {"text": "x", "color": 1}},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(dashboard, "parse_color", lambda raw: 123)
|
||||||
|
|
||||||
|
saved = {"path": None, "payload": None}
|
||||||
|
|
||||||
|
def _save_templates(path, updated):
|
||||||
|
saved["path"] = path
|
||||||
|
saved["payload"] = updated
|
||||||
|
|
||||||
|
monkeypatch.setattr(dashboard, "save_templates", _save_templates)
|
||||||
|
|
||||||
|
app = _make_app()
|
||||||
|
app.config["TEMPLATES_PATH"] = "custom_templates.json"
|
||||||
|
client = app.test_client()
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin",
|
||||||
|
data={
|
||||||
|
"420__text": "Updated",
|
||||||
|
"420__color": "123",
|
||||||
|
"420__image_url": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Saved." in response.get_data(as_text=True)
|
||||||
|
assert saved["path"] == "custom_templates.json"
|
||||||
|
assert saved["payload"] == {"420": {"text": "Updated", "color": 123}}
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
import io
|
|
||||||
import time
|
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import main
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_TIMEZONE_CSV = """Etc/UTC,ZZ,UTC,0,0,0
|
|
||||||
America/New_York,US,EST,0,-18000,0
|
|
||||||
Europe/London,GB,BST,0,0,1
|
|
||||||
"""
|
|
||||||
|
|
||||||
SAMPLE_COUNTRY_CSV = """ZZ,Unknown
|
|
||||||
US,United States
|
|
||||||
GB,United Kingdom
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_timezones_and_countries(monkeypatch):
|
|
||||||
tzs = main.load_timezones()
|
|
||||||
countries = main.load_countries()
|
|
||||||
assert any(t['zone_name'] == 'America/New_York' for t in tzs)
|
|
||||||
assert any(c['country_code'] == 'US' for c in countries)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_tz_and_country_info():
|
|
||||||
timezones = [{'zone_name': 'A/B', 'country_code': 'US'}]
|
|
||||||
countries = [{'country_code': 'US', 'country_name': 'United States'}]
|
|
||||||
assert main.get_tz_info('A/B', timezones)['zone_name'] == 'A/B'
|
|
||||||
assert main.get_country_info('US', countries)[
|
|
||||||
'country_name'] == 'United States'
|
|
||||||
assert main.get_tz_info('X/Y', timezones) is None
|
|
||||||
assert main.get_country_info('XX', countries) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_embed_all_types(monkeypatch):
|
|
||||||
# Prevent create_embed from trying to read actual CSV files by patching loaders
|
|
||||||
monkeypatch.setattr(main, 'load_timezones', lambda: [
|
|
||||||
{'zone_name': 'Etc/UTC', 'country_code': 'ZZ'}])
|
|
||||||
monkeypatch.setattr(main, 'load_countries', lambda: [
|
|
||||||
{'country_code': 'ZZ', 'country_name': 'Nowhere'}])
|
|
||||||
|
|
||||||
# reminder
|
|
||||||
emb = main.create_embed('reminder')
|
|
||||||
assert emb['title'] == 'Reminder'
|
|
||||||
assert '5 minute' in emb['description']
|
|
||||||
assert emb['color'] == 0xe67e22
|
|
||||||
|
|
||||||
# reminder_halftime
|
|
||||||
emb = main.create_embed('reminder_halftime')
|
|
||||||
assert emb['title'] == 'Reminder halftime'
|
|
||||||
assert 'Half-time in 5 minutes' in emb['description']
|
|
||||||
|
|
||||||
# halftime (should include image)
|
|
||||||
monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs, **kwargs: [])
|
|
||||||
emb = main.create_embed('halftime')
|
|
||||||
assert emb['title'] == 'Halftime'
|
|
||||||
assert emb['image'] is not None
|
|
||||||
|
|
||||||
# 420 (should include image and appended tz info string when list empty)
|
|
||||||
monkeypatch.setattr(main, 'where_is_it_420', lambda tzs, cs, **kwargs: [])
|
|
||||||
emb = main.create_embed('420')
|
|
||||||
assert emb['title'] == '420'
|
|
||||||
assert emb['image'] is not None
|
|
||||||
|
|
||||||
# unknown
|
|
||||||
emb = main.create_embed('nope')
|
|
||||||
assert emb['description'] == 'Unknown notification type'
|
|
||||||
|
|
||||||
|
|
||||||
def test_where_is_it_420(monkeypatch):
|
|
||||||
# Limit timezones to a predictable set
|
|
||||||
monkeypatch.setattr(main.pytz, 'all_timezones', ['Etc/UTC'])
|
|
||||||
|
|
||||||
tzs = [{'zone_name': 'Etc/UTC', 'country_code': 'ZZ'}]
|
|
||||||
countries = [{'country_code': 'ZZ', 'country_name': 'Nowhere'}]
|
|
||||||
|
|
||||||
class FakeDatetime:
|
|
||||||
@staticmethod
|
|
||||||
def now(tz):
|
|
||||||
class R:
|
|
||||||
hour = 4
|
|
||||||
return R()
|
|
||||||
|
|
||||||
monkeypatch.setattr(main, 'datetime', FakeDatetime)
|
|
||||||
monkeypatch.setattr(main, 'get_tz_info', lambda name,
|
|
||||||
t: tzs[0] if name == 'Etc/UTC' else None)
|
|
||||||
monkeypatch.setattr(main, 'get_country_info', lambda code,
|
|
||||||
c: countries[0] if code == 'ZZ' else None)
|
|
||||||
|
|
||||||
res = main.where_is_it_420(tzs, countries)
|
|
||||||
assert res == ['Nowhere']
|
|
||||||
|
|
||||||
|
|
||||||
def test_main_exits_quickly(monkeypatch):
|
|
||||||
# Patch send_notification so it doesn't perform network
|
|
||||||
monkeypatch.setattr(main, 'send_notification', lambda x: None)
|
|
||||||
# Don't start dashboard during this test
|
|
||||||
monkeypatch.setattr(main, 'start_dashboard', lambda: None)
|
|
||||||
# Make schedule.run_pending raise KeyboardInterrupt to exit loop
|
|
||||||
monkeypatch.setattr(main.schedule, 'run_pending', lambda: (
|
|
||||||
_ for _ in ()).throw(KeyboardInterrupt()))
|
|
||||||
# Patch time.sleep to no-op
|
|
||||||
monkeypatch.setattr(main.time, 'sleep', lambda s: None)
|
|
||||||
# Ensure WEBHOOK_URL present to avoid early return
|
|
||||||
monkeypatch.setenv('DISCORD_WEBHOOK_URL', 'http://example.com/webhook')
|
|
||||||
main.WEBHOOK_URL = 'http://example.com/webhook'
|
|
||||||
|
|
||||||
# Should exit quickly due to KeyboardInterrupt from run_pending
|
|
||||||
main.main()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_next_scheduled_event():
|
|
||||||
# 10:14 -> next is 10:15 reminder
|
|
||||||
now = main.datetime(2025, 1, 1, 10, 14, 30)
|
|
||||||
nxt = main.get_next_scheduled_event(now)
|
|
||||||
assert nxt["type"] == "reminder"
|
|
||||||
assert nxt["at"].hour == 10 and nxt["at"].minute == 15
|
|
||||||
|
|
||||||
# 10:50:01 -> next is 11:15 reminder
|
|
||||||
now = main.datetime(2025, 1, 1, 10, 50, 1)
|
|
||||||
nxt = main.get_next_scheduled_event(now)
|
|
||||||
assert nxt["type"] == "reminder"
|
|
||||||
assert nxt["at"].hour == 11 and nxt["at"].minute == 15
|
|
||||||
|
|
||||||
|
|
||||||
def test_split_tz_name():
|
|
||||||
assert main.split_tz_name("America/New_York") == ("America", "New_York")
|
|
||||||
assert main.split_tz_name("America/Argentina/Buenos_Aires") == (
|
|
||||||
"America",
|
|
||||||
"Argentina/Buenos_Aires",
|
|
||||||
)
|
|
||||||
assert main.split_tz_name("UTC") == ("UTC", "")
|
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
import main
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_embed_all_types(monkeypatch):
|
||||||
|
monkeypatch.setattr(main, "load_timezones", lambda: [
|
||||||
|
{"zone_name": "Etc/UTC", "country_code": "ZZ"}
|
||||||
|
])
|
||||||
|
monkeypatch.setattr(main, "load_countries", lambda: [
|
||||||
|
{"country_code": "ZZ", "country_name": "Nowhere"}
|
||||||
|
])
|
||||||
|
|
||||||
|
emb = main.create_embed("reminder")
|
||||||
|
assert emb["title"] == "Reminder"
|
||||||
|
assert "5 minute" in emb["description"]
|
||||||
|
assert emb["color"] == 0xE67E22
|
||||||
|
|
||||||
|
emb = main.create_embed("reminder_halftime")
|
||||||
|
assert emb["title"] == "Reminder halftime"
|
||||||
|
assert "Half-time in 5 minutes" in emb["description"]
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "where_is_it_420", lambda tzs, cs, **kwargs: [])
|
||||||
|
emb = main.create_embed("halftime")
|
||||||
|
assert emb["title"] == "Halftime"
|
||||||
|
assert emb["image"] is not None
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "where_is_it_420", lambda tzs, cs, **kwargs: [])
|
||||||
|
emb = main.create_embed("420")
|
||||||
|
assert emb["title"] == "420"
|
||||||
|
assert emb["image"] is not None
|
||||||
|
|
||||||
|
emb = main.create_embed("nope")
|
||||||
|
assert emb["description"] == "Unknown notification type"
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
import main
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_exits_quickly(monkeypatch):
|
||||||
|
monkeypatch.setattr(main, "send_notification", lambda x: None)
|
||||||
|
monkeypatch.setattr(main, "start_dashboard", lambda: None)
|
||||||
|
monkeypatch.setattr(main.schedule, "run_pending", lambda: (
|
||||||
|
_ for _ in ()).throw(KeyboardInterrupt()))
|
||||||
|
monkeypatch.setattr(main.time, "sleep", lambda s: None)
|
||||||
|
monkeypatch.setenv("DISCORD_WEBHOOK_URL", "http://example.com/webhook")
|
||||||
|
main.WEBHOOK_URL = "http://example.com/webhook"
|
||||||
|
|
||||||
|
main.main()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_next_scheduled_event():
|
||||||
|
now = main.datetime(2025, 1, 1, 10, 14, 30)
|
||||||
|
nxt = main.get_next_scheduled_event(now)
|
||||||
|
assert nxt["type"] == "reminder"
|
||||||
|
assert nxt["at"].hour == 10 and nxt["at"].minute == 15
|
||||||
|
|
||||||
|
now = main.datetime(2025, 1, 1, 10, 50, 1)
|
||||||
|
nxt = main.get_next_scheduled_event(now)
|
||||||
|
assert nxt["type"] == "reminder"
|
||||||
|
assert nxt["at"].hour == 11 and nxt["at"].minute == 15
|
||||||
|
|
||||||
|
|
||||||
|
def test_schedule_startup_test_cleanup_when_sent(monkeypatch):
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
delete_calls: list[tuple[int, str | None]] = []
|
||||||
|
|
||||||
|
class FakeEvery:
|
||||||
|
@property
|
||||||
|
def minutes(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def do(self, fn, *args, **kwargs):
|
||||||
|
captured["job"] = lambda: fn(*args, **kwargs)
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(main.schedule, "every", lambda n: FakeEvery())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
main,
|
||||||
|
"delete_old_messages",
|
||||||
|
lambda minutes, content_pattern=None: delete_calls.append(
|
||||||
|
(minutes, content_pattern)),
|
||||||
|
)
|
||||||
|
|
||||||
|
main._schedule_startup_test_cleanup(True)
|
||||||
|
|
||||||
|
assert "job" in captured
|
||||||
|
result = captured["job"]()
|
||||||
|
assert result == main.schedule.CancelJob
|
||||||
|
assert delete_calls == [(1, main.TEST_MESSAGE_DELETE_PATTERN)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_schedule_startup_test_cleanup_skips_when_not_sent(monkeypatch):
|
||||||
|
called = {"value": False}
|
||||||
|
|
||||||
|
class FakeEvery:
|
||||||
|
@property
|
||||||
|
def minutes(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def do(self, fn, *args, **kwargs):
|
||||||
|
called["value"] = True
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(main.schedule, "every", lambda n: FakeEvery())
|
||||||
|
|
||||||
|
main._schedule_startup_test_cleanup(False)
|
||||||
|
|
||||||
|
assert called["value"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_sends_startup_test_and_deletes_it(monkeypatch):
|
||||||
|
send_calls: list[str] = []
|
||||||
|
delete_calls: list[tuple[int, str | None]] = []
|
||||||
|
scheduled_jobs: dict[int, list[object]] = {1: [], 5: []}
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "start_dashboard", lambda: None)
|
||||||
|
monkeypatch.setattr(main, "schedule_notification",
|
||||||
|
lambda interval, at, type: None)
|
||||||
|
|
||||||
|
class FakeEvery:
|
||||||
|
def __init__(self, minutes_value: int):
|
||||||
|
self.minutes_value = minutes_value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def minutes(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def do(self, fn, *args, **kwargs):
|
||||||
|
scheduled_jobs.setdefault(self.minutes_value, []).append(
|
||||||
|
lambda: fn(*args, **kwargs)
|
||||||
|
)
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(main.schedule, "every", lambda n: FakeEvery(n))
|
||||||
|
|
||||||
|
def fake_send_notification(message: str) -> bool:
|
||||||
|
send_calls.append(message)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def fake_delete_old_messages(minutes: int = 6, content_pattern: str | None = None):
|
||||||
|
delete_calls.append((minutes, content_pattern))
|
||||||
|
|
||||||
|
def fake_run_pending():
|
||||||
|
for job in scheduled_jobs.get(1, []):
|
||||||
|
job()
|
||||||
|
raise KeyboardInterrupt()
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "send_notification", fake_send_notification)
|
||||||
|
monkeypatch.setattr(main, "delete_old_messages", fake_delete_old_messages)
|
||||||
|
monkeypatch.setattr(main.schedule, "run_pending", fake_run_pending)
|
||||||
|
monkeypatch.setattr(main.time, "sleep", lambda s: None)
|
||||||
|
monkeypatch.setenv("DISCORD_WEBHOOK_URL", "http://example.com/webhook")
|
||||||
|
main.WEBHOOK_URL = "http://example.com/webhook"
|
||||||
|
|
||||||
|
main.main()
|
||||||
|
|
||||||
|
assert send_calls == ["test"]
|
||||||
|
assert delete_calls == [(1, main.TEST_MESSAGE_DELETE_PATTERN)]
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import maintenance
|
||||||
|
|
||||||
|
|
||||||
|
class DummyResponse:
|
||||||
|
def __init__(self, headers=None, payload=None):
|
||||||
|
self.headers = headers or {}
|
||||||
|
self._payload = payload
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
if self._payload is None:
|
||||||
|
raise ValueError("no json")
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_float():
|
||||||
|
assert maintenance.parse_float("1.5") == 1.5
|
||||||
|
assert maintenance.parse_float(2) == 2.0
|
||||||
|
assert maintenance.parse_float(None) is None
|
||||||
|
assert maintenance.parse_float("nope") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_message_timestamp_and_build_delete_entry():
|
||||||
|
msg = {"id": "42", "timestamp": "2026-01-01T10:00:00Z"}
|
||||||
|
parsed = maintenance.parse_message_timestamp(msg)
|
||||||
|
assert parsed == datetime(2026, 1, 1, 10, 0, 0)
|
||||||
|
|
||||||
|
entry = maintenance.build_delete_entry(msg)
|
||||||
|
assert entry["id"] == "42"
|
||||||
|
assert entry["timestamp"] == parsed
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_delete_message():
|
||||||
|
ts = int(datetime(2026, 1, 1, 10, 0, 0, tzinfo=timezone.utc).timestamp())
|
||||||
|
message = {
|
||||||
|
"timestamp": "2026-01-01T10:00:00Z",
|
||||||
|
"webhook_id": "w",
|
||||||
|
"author": {"id": "a"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert maintenance.should_delete_message(
|
||||||
|
message,
|
||||||
|
webhook_id="w",
|
||||||
|
author_id="a",
|
||||||
|
cutoff=ts,
|
||||||
|
)
|
||||||
|
assert not maintenance.should_delete_message(
|
||||||
|
message,
|
||||||
|
webhook_id="x",
|
||||||
|
author_id="a",
|
||||||
|
cutoff=ts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_matches_pattern_content_and_embeds():
|
||||||
|
message = {
|
||||||
|
"content": "This is a smoke test payload",
|
||||||
|
"embeds": [
|
||||||
|
{"title": "Reminder", "description": "Half-time in 5 minutes"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert maintenance.message_matches_pattern(message, r"smoke test")
|
||||||
|
assert maintenance.message_matches_pattern(message, r"half-time")
|
||||||
|
assert not maintenance.message_matches_pattern(message, r"does-not-match")
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_delete_message_with_content_pattern():
|
||||||
|
ts = int(datetime(2026, 1, 1, 10, 0, 0, tzinfo=timezone.utc).timestamp())
|
||||||
|
message = {
|
||||||
|
"timestamp": "2026-01-01T10:00:00Z",
|
||||||
|
"webhook_id": "w",
|
||||||
|
"author": {"id": "a"},
|
||||||
|
"embeds": [{"description": "This is a test notification."}],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert maintenance.should_delete_message(
|
||||||
|
message,
|
||||||
|
webhook_id="w",
|
||||||
|
author_id="a",
|
||||||
|
cutoff=ts,
|
||||||
|
content_pattern=r"test notification",
|
||||||
|
)
|
||||||
|
assert not maintenance.should_delete_message(
|
||||||
|
message,
|
||||||
|
webhook_id="w",
|
||||||
|
author_id="a",
|
||||||
|
cutoff=ts,
|
||||||
|
content_pattern=r"production-only",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rate_limit_retry_after_header_priority():
|
||||||
|
response = DummyResponse(
|
||||||
|
headers={
|
||||||
|
"Retry-After": "2.5",
|
||||||
|
"X-RateLimit-Reset-After": "10",
|
||||||
|
},
|
||||||
|
payload={"retry_after": "20"},
|
||||||
|
)
|
||||||
|
assert maintenance.get_rate_limit_retry_after(response) == 2.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rate_limit_retry_after_json_fallback():
|
||||||
|
response = DummyResponse(payload={"retry_after": "3"})
|
||||||
|
assert maintenance.get_rate_limit_retry_after(response) == 3.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bucket_exhausted_delay():
|
||||||
|
response = DummyResponse(
|
||||||
|
headers={
|
||||||
|
"X-RateLimit-Remaining": "0",
|
||||||
|
"X-RateLimit-Reset-After": "1.25",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert maintenance.get_bucket_exhausted_delay(response) == 1.25
|
||||||
|
|
||||||
|
response_not_exhausted = DummyResponse(
|
||||||
|
headers={
|
||||||
|
"X-RateLimit-Remaining": "1",
|
||||||
|
"X-RateLimit-Reset-After": "1.25",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert maintenance.get_bucket_exhausted_delay(
|
||||||
|
response_not_exhausted) is None
|
||||||
@@ -20,6 +20,9 @@ def test_save_and_load_templates_roundtrip(tmp_path):
|
|||||||
}
|
}
|
||||||
save_templates(path, data)
|
save_templates(path, data)
|
||||||
|
|
||||||
|
raw = path.read_text(encoding="utf-8")
|
||||||
|
assert '"color": "#00007B"' in raw
|
||||||
|
|
||||||
loaded = load_templates(path)
|
loaded = load_templates(path)
|
||||||
assert loaded["420"]["text"] == "Custom"
|
assert loaded["420"]["text"] == "Custom"
|
||||||
assert loaded["420"]["color"] == 123
|
assert loaded["420"]["color"] == 123
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
import thctime
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_timezones_and_countries():
|
||||||
|
tzs = thctime.load_timezones()
|
||||||
|
countries = thctime.load_countries()
|
||||||
|
assert any(t["zone_name"] == "America/New_York" for t in tzs)
|
||||||
|
assert any(c["country_code"] == "US" for c in countries)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tz_and_country_info():
|
||||||
|
timezones = [{"zone_name": "A/B", "country_code": "US"}]
|
||||||
|
countries = [{"country_code": "US", "country_name": "United States"}]
|
||||||
|
assert thctime.get_tz_info("A/B", timezones)["zone_name"] == "A/B"
|
||||||
|
assert thctime.get_country_info("US", countries)[
|
||||||
|
"country_name"] == "United States"
|
||||||
|
assert thctime.get_tz_info("X/Y", timezones) is None
|
||||||
|
assert thctime.get_country_info("XX", countries) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_where_is_it_420(monkeypatch):
|
||||||
|
monkeypatch.setattr(thctime.pytz, "all_timezones", ["Etc/UTC"])
|
||||||
|
|
||||||
|
tzs = [{"zone_name": "Etc/UTC", "country_code": "ZZ"}]
|
||||||
|
countries = [{"country_code": "ZZ", "country_name": "Nowhere"}]
|
||||||
|
|
||||||
|
class FakeDatetime:
|
||||||
|
@staticmethod
|
||||||
|
def now(tz):
|
||||||
|
class Result:
|
||||||
|
hour = 4
|
||||||
|
|
||||||
|
return Result()
|
||||||
|
|
||||||
|
monkeypatch.setattr(thctime, "datetime", FakeDatetime)
|
||||||
|
|
||||||
|
res = thctime.where_is_it_420(tzs, countries)
|
||||||
|
assert res == ["Nowhere"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_tz_name():
|
||||||
|
assert thctime.split_tz_name("America/New_York") == ("America", "New_York")
|
||||||
|
assert thctime.split_tz_name("America/Argentina/Buenos_Aires") == (
|
||||||
|
"America",
|
||||||
|
"Argentina/Buenos_Aires",
|
||||||
|
)
|
||||||
|
assert thctime.split_tz_name("UTC") == ("UTC", "")
|
||||||
+175
@@ -0,0 +1,175 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
|
||||||
|
TZDB_CACHE: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tzdb_cache() -> dict | None:
|
||||||
|
return TZDB_CACHE
|
||||||
|
|
||||||
|
|
||||||
|
def init_tzdb_cache() -> dict:
|
||||||
|
"""Initialize a cached lookup structure for tzdb data."""
|
||||||
|
global TZDB_CACHE
|
||||||
|
if TZDB_CACHE is not None:
|
||||||
|
return TZDB_CACHE
|
||||||
|
|
||||||
|
timezones = load_timezones()
|
||||||
|
countries = load_countries()
|
||||||
|
|
||||||
|
tz_to_country_code: dict[str, str] = {}
|
||||||
|
tz_meta: dict[str, dict] = {}
|
||||||
|
for tz in timezones:
|
||||||
|
zone_name = tz.get("zone_name")
|
||||||
|
country_code = tz.get("country_code")
|
||||||
|
if not isinstance(zone_name, str) or not zone_name:
|
||||||
|
continue
|
||||||
|
if not isinstance(country_code, str) or not country_code:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tz_to_country_code[zone_name] = country_code
|
||||||
|
region, city = split_tz_name(zone_name)
|
||||||
|
tz_meta[zone_name] = {
|
||||||
|
"zone_name": zone_name,
|
||||||
|
"country_code": country_code,
|
||||||
|
"region": region,
|
||||||
|
"city": city,
|
||||||
|
}
|
||||||
|
|
||||||
|
country_code_to_name: dict[str, str] = {}
|
||||||
|
for c in countries:
|
||||||
|
code = c.get("country_code")
|
||||||
|
name = c.get("country_name")
|
||||||
|
if code and name:
|
||||||
|
country_code_to_name[code] = str(name).strip().strip('"')
|
||||||
|
|
||||||
|
for zone_name, meta in tz_meta.items():
|
||||||
|
code = meta.get("country_code")
|
||||||
|
if isinstance(code, str):
|
||||||
|
meta["country_name"] = country_code_to_name.get(code)
|
||||||
|
|
||||||
|
tz_names: list[str] = []
|
||||||
|
for zone_name in tz_to_country_code.keys():
|
||||||
|
try:
|
||||||
|
ZoneInfo(zone_name)
|
||||||
|
except ZoneInfoNotFoundError:
|
||||||
|
continue
|
||||||
|
tz_names.append(zone_name)
|
||||||
|
|
||||||
|
TZDB_CACHE = {
|
||||||
|
"tz_to_country_code": tz_to_country_code,
|
||||||
|
"country_code_to_name": country_code_to_name,
|
||||||
|
"tz_names": tz_names,
|
||||||
|
"tz_meta": tz_meta,
|
||||||
|
}
|
||||||
|
return TZDB_CACHE
|
||||||
|
|
||||||
|
|
||||||
|
def split_tz_name(zone_name: str) -> tuple[str, str]:
|
||||||
|
"""Split an IANA timezone name into (region, city)."""
|
||||||
|
if "/" not in zone_name:
|
||||||
|
return zone_name, ""
|
||||||
|
region, rest = zone_name.split("/", 1)
|
||||||
|
return region, rest
|
||||||
|
|
||||||
|
|
||||||
|
def load_timezones() -> list[dict]:
|
||||||
|
"""Load timezones from csv file."""
|
||||||
|
with open("tzdb/TimeZoneDB.csv/time_zone.csv", "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
timezones = []
|
||||||
|
for line in lines:
|
||||||
|
fields = line.strip().split(",")
|
||||||
|
if len(fields) >= 5:
|
||||||
|
timezones.append({
|
||||||
|
"zone_name": fields[0],
|
||||||
|
"country_code": fields[1],
|
||||||
|
"abbreviation": fields[2],
|
||||||
|
"time_start": fields[3],
|
||||||
|
"gmt_offset": int(fields[4]),
|
||||||
|
"dst": fields[5] == "1",
|
||||||
|
})
|
||||||
|
return timezones
|
||||||
|
|
||||||
|
|
||||||
|
def load_countries() -> list[dict]:
|
||||||
|
"""Load countries from csv file."""
|
||||||
|
with open("tzdb/TimeZoneDB.csv/country.csv", "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
countries = []
|
||||||
|
for line in lines:
|
||||||
|
fields = line.strip().split(",")
|
||||||
|
if len(fields) >= 2:
|
||||||
|
countries.append({
|
||||||
|
"country_code": fields[0],
|
||||||
|
"country_name": fields[1],
|
||||||
|
})
|
||||||
|
return countries
|
||||||
|
|
||||||
|
|
||||||
|
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
||||||
|
"""Get timezone info by name."""
|
||||||
|
return next((tz for tz in timezones if tz["zone_name"] == tz_name), None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_country_info(country_code: str, countries: list[dict]) -> dict | None:
|
||||||
|
"""Get country info by country code."""
|
||||||
|
return next((c for c in countries if c["country_code"] == country_code), None)
|
||||||
|
|
||||||
|
|
||||||
|
def where_is_it_420(
|
||||||
|
timezones: list[dict],
|
||||||
|
countries: list[dict],
|
||||||
|
tz_names: list[str] | None = None,
|
||||||
|
tz_to_country_code: dict[str, str] | None = None,
|
||||||
|
country_code_to_name: dict[str, str] | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Get timezones where the current hour is 4 or 16."""
|
||||||
|
if tz_to_country_code is None:
|
||||||
|
tz_to_country_code = {}
|
||||||
|
for tz in timezones:
|
||||||
|
zone_name = tz.get("zone_name")
|
||||||
|
country_code = tz.get("country_code")
|
||||||
|
if isinstance(zone_name, str) and isinstance(country_code, str):
|
||||||
|
tz_to_country_code[zone_name] = country_code
|
||||||
|
|
||||||
|
if country_code_to_name is None:
|
||||||
|
country_code_to_name = {}
|
||||||
|
for c in countries:
|
||||||
|
code = c.get("country_code")
|
||||||
|
name = c.get("country_name")
|
||||||
|
if isinstance(code, str) and name is not None:
|
||||||
|
country_code_to_name[code] = str(name).strip().strip('"')
|
||||||
|
|
||||||
|
names_to_check = tz_names if tz_names is not None else pytz.all_timezones
|
||||||
|
results: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
for tz_name in names_to_check:
|
||||||
|
try:
|
||||||
|
tz_obj = pytz.timezone(tz_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
now = datetime.now(tz_obj)
|
||||||
|
if now.hour != 4 and now.hour != 16:
|
||||||
|
continue
|
||||||
|
|
||||||
|
country_code = tz_to_country_code.get(tz_name)
|
||||||
|
if not country_code:
|
||||||
|
continue
|
||||||
|
country_name = country_code_to_name.get(country_code)
|
||||||
|
if not country_name:
|
||||||
|
continue
|
||||||
|
if country_name in seen:
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen.add(country_name)
|
||||||
|
results.append(country_name)
|
||||||
|
|
||||||
|
return results
|
||||||
Reference in New Issue
Block a user