feat: implement timezone and country data handling in thctime module
This commit is contained in:
@@ -1,16 +1,24 @@
|
|||||||
import os
|
import os
|
||||||
import pytz
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
|
||||||
import requests
|
import requests
|
||||||
import schedule
|
import schedule
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from templates import load_templates
|
from templates import load_templates
|
||||||
from dashboard import create_app
|
from dashboard import create_app
|
||||||
|
from thctime import (
|
||||||
|
get_country_info,
|
||||||
|
get_tz_info,
|
||||||
|
get_tzdb_cache,
|
||||||
|
init_tzdb_cache,
|
||||||
|
load_countries,
|
||||||
|
load_timezones,
|
||||||
|
split_tz_name,
|
||||||
|
where_is_it_420,
|
||||||
|
)
|
||||||
|
|
||||||
SCHEDULED_NOTIFICATIONS = [
|
SCHEDULED_NOTIFICATIONS = [
|
||||||
(15, "reminder"),
|
(15, "reminder"),
|
||||||
@@ -19,10 +27,6 @@ SCHEDULED_NOTIFICATIONS = [
|
|||||||
(50, "halftime"),
|
(50, "halftime"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
TZDB_CACHE: dict | None = None
|
|
||||||
|
|
||||||
|
|
||||||
STATE_LOCK = threading.Lock()
|
STATE_LOCK = threading.Lock()
|
||||||
STATE: dict = {
|
STATE: dict = {
|
||||||
"running": True,
|
"running": True,
|
||||||
@@ -59,88 +63,6 @@ def get_next_event() -> dict:
|
|||||||
return get_next_scheduled_event()
|
return get_next_scheduled_event()
|
||||||
|
|
||||||
|
|
||||||
def init_tzdb_cache() -> dict:
|
|
||||||
"""Initialize a cached lookup structure for tzdb data.
|
|
||||||
|
|
||||||
This keeps the hourly scheduler fast by:
|
|
||||||
- Building O(1) maps (zone_name -> country_code, country_code -> country_name)
|
|
||||||
- Precomputing a list of tz names that exist in both tzdb CSVs and `pytz`
|
|
||||||
|
|
||||||
Note: this is intentionally NOT run at import time so tests can monkeypatch
|
|
||||||
`load_timezones`/`load_countries` without needing to reset global state.
|
|
||||||
"""
|
|
||||||
global TZDB_CACHE
|
|
||||||
if TZDB_CACHE is not None:
|
|
||||||
return TZDB_CACHE
|
|
||||||
|
|
||||||
timezones = load_timezones()
|
|
||||||
countries = load_countries()
|
|
||||||
|
|
||||||
tz_to_country_code: dict[str, str] = {}
|
|
||||||
tz_meta: dict[str, dict] = {}
|
|
||||||
for tz in timezones:
|
|
||||||
zone_name = tz.get("zone_name")
|
|
||||||
country_code = tz.get("country_code")
|
|
||||||
if not isinstance(zone_name, str) or not zone_name:
|
|
||||||
continue
|
|
||||||
if not isinstance(country_code, str) or not country_code:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tz_to_country_code[zone_name] = country_code
|
|
||||||
region, city = split_tz_name(zone_name)
|
|
||||||
tz_meta[zone_name] = {
|
|
||||||
"zone_name": zone_name,
|
|
||||||
"country_code": country_code,
|
|
||||||
"region": region,
|
|
||||||
"city": city,
|
|
||||||
}
|
|
||||||
|
|
||||||
country_code_to_name: dict[str, str] = {}
|
|
||||||
for c in countries:
|
|
||||||
code = c.get("country_code")
|
|
||||||
name = c.get("country_name")
|
|
||||||
if code and name:
|
|
||||||
country_code_to_name[code] = str(name).strip().strip('"')
|
|
||||||
|
|
||||||
# Attach resolved country names onto tz_meta (storage-only for now).
|
|
||||||
for zone_name, meta in tz_meta.items():
|
|
||||||
code = meta.get("country_code")
|
|
||||||
if isinstance(code, str):
|
|
||||||
meta["country_name"] = country_code_to_name.get(code)
|
|
||||||
|
|
||||||
# Vetted tz list: only names that are present in tzdb and loadable by zoneinfo.
|
|
||||||
# Installing the `tzdata` package keeps this mapping up-to-date.
|
|
||||||
tz_names: list[str] = []
|
|
||||||
for zone_name in tz_to_country_code.keys():
|
|
||||||
try:
|
|
||||||
ZoneInfo(zone_name)
|
|
||||||
except ZoneInfoNotFoundError:
|
|
||||||
continue
|
|
||||||
tz_names.append(zone_name)
|
|
||||||
|
|
||||||
TZDB_CACHE = {
|
|
||||||
"tz_to_country_code": tz_to_country_code,
|
|
||||||
"country_code_to_name": country_code_to_name,
|
|
||||||
"tz_names": tz_names,
|
|
||||||
"tz_meta": tz_meta,
|
|
||||||
}
|
|
||||||
return TZDB_CACHE
|
|
||||||
|
|
||||||
|
|
||||||
def split_tz_name(zone_name: str) -> tuple[str, str]:
|
|
||||||
"""Split an IANA timezone name into (region, city).
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
- "America/New_York" -> ("America", "New_York")
|
|
||||||
- "America/Argentina/Buenos_Aires" -> ("America", "Argentina/Buenos_Aires")
|
|
||||||
- "UTC" -> ("UTC", "")
|
|
||||||
"""
|
|
||||||
if "/" not in zone_name:
|
|
||||||
return zone_name, ""
|
|
||||||
region, rest = zone_name.split("/", 1)
|
|
||||||
return region, rest
|
|
||||||
|
|
||||||
|
|
||||||
def _update_state(**updates) -> None:
|
def _update_state(**updates) -> None:
|
||||||
with STATE_LOCK:
|
with STATE_LOCK:
|
||||||
STATE.update(updates)
|
STATE.update(updates)
|
||||||
@@ -172,51 +94,13 @@ def get_next_scheduled_event(now: datetime | None = None) -> dict:
|
|||||||
return {"at": next_dt, "type": next_type}
|
return {"at": next_dt, "type": next_type}
|
||||||
|
|
||||||
|
|
||||||
def load_timezones() -> list[dict]:
|
|
||||||
"""Load timezones from csv file."""
|
|
||||||
# Read the CSV file and return a list of timezones
|
|
||||||
with open("tzdb/TimeZoneDB.csv/time_zone.csv", "r", encoding="utf-8") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
# Fields: zone_name,country_code,abbreviation,time_start,gmt_offset,dst
|
|
||||||
timezones = []
|
|
||||||
for line in lines:
|
|
||||||
fields = line.strip().split(",")
|
|
||||||
if len(fields) >= 5:
|
|
||||||
timezones.append({
|
|
||||||
"zone_name": fields[0],
|
|
||||||
"country_code": fields[1],
|
|
||||||
"abbreviation": fields[2],
|
|
||||||
"time_start": fields[3],
|
|
||||||
"gmt_offset": int(fields[4]),
|
|
||||||
"dst": fields[5] == '1'
|
|
||||||
})
|
|
||||||
return timezones
|
|
||||||
|
|
||||||
|
|
||||||
def load_countries() -> list[dict]:
|
|
||||||
"""Load countries from csv file."""
|
|
||||||
# Read the CSV file and return a list of countries
|
|
||||||
with open("tzdb/TimeZoneDB.csv/country.csv", "r", encoding="utf-8") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
# Fields: country_code,country_name
|
|
||||||
countries = []
|
|
||||||
for line in lines:
|
|
||||||
fields = line.strip().split(",")
|
|
||||||
if len(fields) >= 2:
|
|
||||||
countries.append({
|
|
||||||
"country_code": fields[0],
|
|
||||||
"country_name": fields[1]
|
|
||||||
})
|
|
||||||
return countries
|
|
||||||
|
|
||||||
|
|
||||||
def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
|
def create_embed(type: str, tz_list: list[str] | None = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Create a Discord embed message.
|
Create a Discord embed message.
|
||||||
"""
|
"""
|
||||||
templates_path = os.getenv("TEMPLATES_PATH", "templates.json")
|
templates_path = os.getenv("TEMPLATES_PATH", "templates.json")
|
||||||
messages = load_templates(templates_path)
|
messages = load_templates(templates_path)
|
||||||
cache = TZDB_CACHE
|
cache = get_tzdb_cache()
|
||||||
timezones = load_timezones() if cache is None else []
|
timezones = load_timezones() if cache is None else []
|
||||||
countries = load_countries() if cache is None else []
|
countries = load_countries() if cache is None else []
|
||||||
if type in messages:
|
if type in messages:
|
||||||
@@ -613,7 +497,7 @@ def send_notification(message: str) -> None:
|
|||||||
tz_list: list[str] | None = None
|
tz_list: list[str] | None = None
|
||||||
if message == "420":
|
if message == "420":
|
||||||
try:
|
try:
|
||||||
cache = TZDB_CACHE
|
cache = get_tzdb_cache()
|
||||||
if cache is None:
|
if cache is None:
|
||||||
tz_list = where_is_it_420(load_timezones(), load_countries())
|
tz_list = where_is_it_420(load_timezones(), load_countries())
|
||||||
else:
|
else:
|
||||||
@@ -624,7 +508,7 @@ def send_notification(message: str) -> None:
|
|||||||
tz_to_country_code=cache.get("tz_to_country_code"),
|
tz_to_country_code=cache.get("tz_to_country_code"),
|
||||||
country_code_to_name=cache.get("country_code_to_name"),
|
country_code_to_name=cache.get("country_code_to_name"),
|
||||||
)
|
)
|
||||||
_update_state(last_locations=tz_list)
|
_update_state(last_locations=tz_list or [])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_update_state(last_locations=[], last_error=str(e))
|
_update_state(last_locations=[], last_error=str(e))
|
||||||
|
|
||||||
@@ -648,74 +532,6 @@ def send_notification(message: str) -> None:
|
|||||||
_update_state(last_error=str(e))
|
_update_state(last_error=str(e))
|
||||||
|
|
||||||
|
|
||||||
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
|
||||||
"""Get timezone info by name."""
|
|
||||||
return next((tz for tz in timezones if tz["zone_name"] == tz_name), None)
|
|
||||||
|
|
||||||
|
|
||||||
def get_country_info(country_code: str, countries: list[dict]) -> dict | None:
|
|
||||||
"""Get country info by country code."""
|
|
||||||
return next((c for c in countries if c["country_code"] == country_code), None)
|
|
||||||
|
|
||||||
|
|
||||||
def where_is_it_420(
|
|
||||||
timezones: list[dict],
|
|
||||||
countries: list[dict],
|
|
||||||
tz_names: list[str] | None = None,
|
|
||||||
tz_to_country_code: dict[str, str] | None = None,
|
|
||||||
country_code_to_name: dict[str, str] | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Get timezones where the current hour is 4 or 16, indicating it's 4:20 there.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: A list of timezones where it's currently 4:20 PM or AM.
|
|
||||||
"""
|
|
||||||
# Build fast lookup dicts if not provided.
|
|
||||||
if tz_to_country_code is None:
|
|
||||||
tz_to_country_code = {}
|
|
||||||
for tz in timezones:
|
|
||||||
zone_name = tz.get("zone_name")
|
|
||||||
country_code = tz.get("country_code")
|
|
||||||
if isinstance(zone_name, str) and isinstance(country_code, str):
|
|
||||||
tz_to_country_code[zone_name] = country_code
|
|
||||||
|
|
||||||
if country_code_to_name is None:
|
|
||||||
country_code_to_name = {}
|
|
||||||
for c in countries:
|
|
||||||
code = c.get("country_code")
|
|
||||||
name = c.get("country_name")
|
|
||||||
if isinstance(code, str) and name is not None:
|
|
||||||
country_code_to_name[code] = str(name).strip().strip('"')
|
|
||||||
|
|
||||||
names_to_check = tz_names if tz_names is not None else pytz.all_timezones
|
|
||||||
results: list[str] = []
|
|
||||||
seen: set[str] = set()
|
|
||||||
|
|
||||||
for tz_name in names_to_check:
|
|
||||||
try:
|
|
||||||
tz_obj = pytz.timezone(tz_name)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
now = datetime.now(tz_obj)
|
|
||||||
if now.hour != 4 and now.hour != 16:
|
|
||||||
continue
|
|
||||||
|
|
||||||
country_code = tz_to_country_code.get(tz_name)
|
|
||||||
if not country_code:
|
|
||||||
continue
|
|
||||||
country_name = country_code_to_name.get(country_code)
|
|
||||||
if not country_name:
|
|
||||||
continue
|
|
||||||
if country_name in seen:
|
|
||||||
continue
|
|
||||||
|
|
||||||
seen.add(country_name)
|
|
||||||
results.append(country_name)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def schedule_notification(interval: str, at: str, type: str) -> None:
|
def schedule_notification(interval: str, at: str, type: str) -> None:
|
||||||
"""Example: schedule.every().hour.at(":15").do(send_notification, "reminder")"""
|
"""Example: schedule.every().hour.at(":15").do(send_notification, "reminder")"""
|
||||||
if interval == "hour":
|
if interval == "hour":
|
||||||
|
|||||||
+3
-6
@@ -5,6 +5,7 @@ from unittest import mock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import main
|
import main
|
||||||
|
import thctime
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_TIMEZONE_CSV = """Etc/UTC,ZZ,UTC,0,0,0
|
SAMPLE_TIMEZONE_CSV = """Etc/UTC,ZZ,UTC,0,0,0
|
||||||
@@ -72,7 +73,7 @@ def test_create_embed_all_types(monkeypatch):
|
|||||||
|
|
||||||
def test_where_is_it_420(monkeypatch):
|
def test_where_is_it_420(monkeypatch):
|
||||||
# Limit timezones to a predictable set
|
# Limit timezones to a predictable set
|
||||||
monkeypatch.setattr(main.pytz, 'all_timezones', ['Etc/UTC'])
|
monkeypatch.setattr(thctime.pytz, 'all_timezones', ['Etc/UTC'])
|
||||||
|
|
||||||
tzs = [{'zone_name': 'Etc/UTC', 'country_code': 'ZZ'}]
|
tzs = [{'zone_name': 'Etc/UTC', 'country_code': 'ZZ'}]
|
||||||
countries = [{'country_code': 'ZZ', 'country_name': 'Nowhere'}]
|
countries = [{'country_code': 'ZZ', 'country_name': 'Nowhere'}]
|
||||||
@@ -84,11 +85,7 @@ def test_where_is_it_420(monkeypatch):
|
|||||||
hour = 4
|
hour = 4
|
||||||
return R()
|
return R()
|
||||||
|
|
||||||
monkeypatch.setattr(main, 'datetime', FakeDatetime)
|
monkeypatch.setattr(thctime, 'datetime', FakeDatetime)
|
||||||
monkeypatch.setattr(main, 'get_tz_info', lambda name,
|
|
||||||
t: tzs[0] if name == 'Etc/UTC' else None)
|
|
||||||
monkeypatch.setattr(main, 'get_country_info', lambda code,
|
|
||||||
c: countries[0] if code == 'ZZ' else None)
|
|
||||||
|
|
||||||
res = main.where_is_it_420(tzs, countries)
|
res = main.where_is_it_420(tzs, countries)
|
||||||
assert res == ['Nowhere']
|
assert res == ['Nowhere']
|
||||||
|
|||||||
+175
@@ -0,0 +1,175 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
|
||||||
|
TZDB_CACHE: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tzdb_cache() -> dict | None:
|
||||||
|
return TZDB_CACHE
|
||||||
|
|
||||||
|
|
||||||
|
def init_tzdb_cache() -> dict:
|
||||||
|
"""Initialize a cached lookup structure for tzdb data."""
|
||||||
|
global TZDB_CACHE
|
||||||
|
if TZDB_CACHE is not None:
|
||||||
|
return TZDB_CACHE
|
||||||
|
|
||||||
|
timezones = load_timezones()
|
||||||
|
countries = load_countries()
|
||||||
|
|
||||||
|
tz_to_country_code: dict[str, str] = {}
|
||||||
|
tz_meta: dict[str, dict] = {}
|
||||||
|
for tz in timezones:
|
||||||
|
zone_name = tz.get("zone_name")
|
||||||
|
country_code = tz.get("country_code")
|
||||||
|
if not isinstance(zone_name, str) or not zone_name:
|
||||||
|
continue
|
||||||
|
if not isinstance(country_code, str) or not country_code:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tz_to_country_code[zone_name] = country_code
|
||||||
|
region, city = split_tz_name(zone_name)
|
||||||
|
tz_meta[zone_name] = {
|
||||||
|
"zone_name": zone_name,
|
||||||
|
"country_code": country_code,
|
||||||
|
"region": region,
|
||||||
|
"city": city,
|
||||||
|
}
|
||||||
|
|
||||||
|
country_code_to_name: dict[str, str] = {}
|
||||||
|
for c in countries:
|
||||||
|
code = c.get("country_code")
|
||||||
|
name = c.get("country_name")
|
||||||
|
if code and name:
|
||||||
|
country_code_to_name[code] = str(name).strip().strip('"')
|
||||||
|
|
||||||
|
for zone_name, meta in tz_meta.items():
|
||||||
|
code = meta.get("country_code")
|
||||||
|
if isinstance(code, str):
|
||||||
|
meta["country_name"] = country_code_to_name.get(code)
|
||||||
|
|
||||||
|
tz_names: list[str] = []
|
||||||
|
for zone_name in tz_to_country_code.keys():
|
||||||
|
try:
|
||||||
|
ZoneInfo(zone_name)
|
||||||
|
except ZoneInfoNotFoundError:
|
||||||
|
continue
|
||||||
|
tz_names.append(zone_name)
|
||||||
|
|
||||||
|
TZDB_CACHE = {
|
||||||
|
"tz_to_country_code": tz_to_country_code,
|
||||||
|
"country_code_to_name": country_code_to_name,
|
||||||
|
"tz_names": tz_names,
|
||||||
|
"tz_meta": tz_meta,
|
||||||
|
}
|
||||||
|
return TZDB_CACHE
|
||||||
|
|
||||||
|
|
||||||
|
def split_tz_name(zone_name: str) -> tuple[str, str]:
|
||||||
|
"""Split an IANA timezone name into (region, city)."""
|
||||||
|
if "/" not in zone_name:
|
||||||
|
return zone_name, ""
|
||||||
|
region, rest = zone_name.split("/", 1)
|
||||||
|
return region, rest
|
||||||
|
|
||||||
|
|
||||||
|
def load_timezones() -> list[dict]:
|
||||||
|
"""Load timezones from csv file."""
|
||||||
|
with open("tzdb/TimeZoneDB.csv/time_zone.csv", "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
timezones = []
|
||||||
|
for line in lines:
|
||||||
|
fields = line.strip().split(",")
|
||||||
|
if len(fields) >= 5:
|
||||||
|
timezones.append({
|
||||||
|
"zone_name": fields[0],
|
||||||
|
"country_code": fields[1],
|
||||||
|
"abbreviation": fields[2],
|
||||||
|
"time_start": fields[3],
|
||||||
|
"gmt_offset": int(fields[4]),
|
||||||
|
"dst": fields[5] == "1",
|
||||||
|
})
|
||||||
|
return timezones
|
||||||
|
|
||||||
|
|
||||||
|
def load_countries() -> list[dict]:
|
||||||
|
"""Load countries from csv file."""
|
||||||
|
with open("tzdb/TimeZoneDB.csv/country.csv", "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
countries = []
|
||||||
|
for line in lines:
|
||||||
|
fields = line.strip().split(",")
|
||||||
|
if len(fields) >= 2:
|
||||||
|
countries.append({
|
||||||
|
"country_code": fields[0],
|
||||||
|
"country_name": fields[1],
|
||||||
|
})
|
||||||
|
return countries
|
||||||
|
|
||||||
|
|
||||||
|
def get_tz_info(tz_name: str, timezones: list[dict]) -> dict | None:
|
||||||
|
"""Get timezone info by name."""
|
||||||
|
return next((tz for tz in timezones if tz["zone_name"] == tz_name), None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_country_info(country_code: str, countries: list[dict]) -> dict | None:
|
||||||
|
"""Get country info by country code."""
|
||||||
|
return next((c for c in countries if c["country_code"] == country_code), None)
|
||||||
|
|
||||||
|
|
||||||
|
def where_is_it_420(
|
||||||
|
timezones: list[dict],
|
||||||
|
countries: list[dict],
|
||||||
|
tz_names: list[str] | None = None,
|
||||||
|
tz_to_country_code: dict[str, str] | None = None,
|
||||||
|
country_code_to_name: dict[str, str] | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Get timezones where the current hour is 4 or 16."""
|
||||||
|
if tz_to_country_code is None:
|
||||||
|
tz_to_country_code = {}
|
||||||
|
for tz in timezones:
|
||||||
|
zone_name = tz.get("zone_name")
|
||||||
|
country_code = tz.get("country_code")
|
||||||
|
if isinstance(zone_name, str) and isinstance(country_code, str):
|
||||||
|
tz_to_country_code[zone_name] = country_code
|
||||||
|
|
||||||
|
if country_code_to_name is None:
|
||||||
|
country_code_to_name = {}
|
||||||
|
for c in countries:
|
||||||
|
code = c.get("country_code")
|
||||||
|
name = c.get("country_name")
|
||||||
|
if isinstance(code, str) and name is not None:
|
||||||
|
country_code_to_name[code] = str(name).strip().strip('"')
|
||||||
|
|
||||||
|
names_to_check = tz_names if tz_names is not None else pytz.all_timezones
|
||||||
|
results: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
for tz_name in names_to_check:
|
||||||
|
try:
|
||||||
|
tz_obj = pytz.timezone(tz_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
now = datetime.now(tz_obj)
|
||||||
|
if now.hour != 4 and now.hour != 16:
|
||||||
|
continue
|
||||||
|
|
||||||
|
country_code = tz_to_country_code.get(tz_name)
|
||||||
|
if not country_code:
|
||||||
|
continue
|
||||||
|
country_name = country_code_to_name.get(country_code)
|
||||||
|
if not country_name:
|
||||||
|
continue
|
||||||
|
if country_name in seen:
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen.add(country_name)
|
||||||
|
results.append(country_name)
|
||||||
|
|
||||||
|
return results
|
||||||
Reference in New Issue
Block a user