general: make time.tz.via_location user config lazy, move tests to my.tests package

also gets rid of the problematic reset_modules thingie
This commit is contained in:
Dima Gerasimov 2024-08-26 02:00:51 +01:00 committed by karlicoss
parent 270080bd56
commit a5643206a0
15 changed files with 269 additions and 233 deletions

View file

@ -19,7 +19,7 @@ def _calendar():
# todo switch to using time.tz.main once _get_tz stabilizes?
from ..time.tz import via_location as LTZ
# TODO would be nice to do it dynamically depending on the past timezones...
tz = LTZ._get_tz(datetime.now())
tz = LTZ.get_tz(datetime.now())
assert tz is not None
zone = tz.zone; assert zone is not None
code = zone_to_countrycode(zone)

View file

@ -125,8 +125,10 @@ def test_fromisoformat() -> None:
if sys.version_info[:2] >= (3, 10):
from types import NoneType
from typing import TypeAlias
else:
NoneType = type(None)
from typing_extensions import TypeAlias
if sys.version_info[:2] >= (3, 11):

View file

@ -29,7 +29,6 @@ from typing import Iterator, List
from my.core import make_logger
from my.core.compat import bisect_left
from my.ip.all import ips
from my.location.common import Location
from my.location.fallback.common import FallbackLocation, DateExact, _datetime_timestamp
@ -37,6 +36,9 @@ logger = make_logger(__name__, level="warning")
def fallback_locations() -> Iterator[FallbackLocation]:
# prefer late import since ips get overridden in tests
from my.ip.all import ips
dur = config.for_duration.total_seconds()
for ip in ips():
lat, lon = ip.latlon

9
my/tests/calendar.py Normal file
View file

@ -0,0 +1,9 @@
from my.calendar.holidays import is_holiday
from .shared_tz_config import config # autoused fixture
def test_is_holiday() -> None:
    """Spot-check is_holiday against a few fixed YYYYMMDD date strings."""
    expectations = (
        ('20190101', True),
        ('20180601', False),
        ('20200906', True),  # national holiday in Bulgaria
    )
    for day, expected in expectations:
        verdict = is_holiday(day)
        if expected:
            assert verdict
        else:
            assert not verdict

View file

@ -1,7 +1,5 @@
import os
from pathlib import Path
import re
import sys
import pytest
@ -13,20 +11,6 @@ skip_if_not_karlicoss = pytest.mark.skipif(
)
def reset_modules() -> None:
    '''
    A hack to 'unload' HPI modules, otherwise some modules might cache the config
    TODO: a bit crap, need a better way..

    Only the ``my`` package itself and its ``my.*`` submodules are removed from
    ``sys.modules``; unrelated modules whose name merely starts with "my"
    (e.g. ``mypy``) are left alone.
    '''
    # snapshot the names first: sys.modules is mutated in the loop below
    to_unload = [m for m in sys.modules if m == 'my' or m.startswith('my.')]
    for m in to_unload:
        if 'my.pdfs' in m:
            # temporary hack -- since my.pdfs migrated to a 'lazy' config, this isn't necessary anymore
            # but if we reset module anyway, it confuses the ProcessPool inside my.pdfs
            continue
        del sys.modules[m]
def testdata() -> Path:
d = Path(__file__).absolute().parent.parent.parent / 'testdata'
assert d.exists(), d

View file

@ -1,8 +1,10 @@
import pytest
# I guess makes sense by default
@pytest.fixture(autouse=True)
def without_cachew():
    """Autouse fixture: run each test inside the disabled_cachew() context."""
    from my.core.cachew import disabled_cachew

    cachew_off = disabled_cachew()
    with cachew_off:
        yield

View file

@ -0,0 +1,135 @@
"""
To test my.location.fallback_location.all
"""
from datetime import datetime, timedelta, timezone
from typing import Iterator
import pytest
from more_itertools import ilen
import my.ip.all as ip_module
from my.ip.common import IP
from my.location.fallback import via_ip
from ..shared_tz_config import config # autoused fixture
# these are all tests for the bisect algorithm defined in via_ip.py
# to make sure we can correctly find IPs that are within the 'for_duration' of a given datetime
def test_ip_fallback() -> None:
    """
    Check via_ip.estimate_location against the five IP records yielded by data(),
    then check that my.location.fallback.all picks between via_ip and via_home.
    """
    # precondition, make sure that the data override works
    assert ilen(ip_module.ips()) == ilen(data())
    assert ilen(ip_module.ips()) == ilen(via_ip.fallback_locations())
    assert ilen(via_ip.fallback_locations()) == 5
    assert ilen(via_ip._sorted_fallback_locations()) == 5

    # confirm duration from via_ip since that is used for bisect
    assert via_ip.config.for_duration == timedelta(hours=24)

    # basic tests

    # try estimating slightly before the first IP
    est = list(via_ip.estimate_location(datetime(2020, 1, 1, 11, 59, 59, tzinfo=timezone.utc)))
    assert len(est) == 0

    # during the duration for the first IP
    est = list(via_ip.estimate_location(datetime(2020, 1, 1, 12, 30, 0, tzinfo=timezone.utc)))
    assert len(est) == 1

    # right after the 'for_duration' for an IP
    est = list(
        via_ip.estimate_location(datetime(2020, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + via_ip.config.for_duration + timedelta(seconds=1))
    )
    assert len(est) == 0

    # on 2020-02-01, there's one IP if before 16:00
    est = list(via_ip.estimate_location(datetime(2020, 2, 1, 12, 30, 0, tzinfo=timezone.utc)))
    assert len(est) == 1

    # and two if after 16:00
    est = list(via_ip.estimate_location(datetime(2020, 2, 1, 17, 00, 0, tzinfo=timezone.utc)))
    assert len(est) == 2

    # the 12:00 IP should 'expire' before the 16:00 IP, use 3:30PM on the next day
    est = list(via_ip.estimate_location(datetime(2020, 2, 2, 15, 30, 0, tzinfo=timezone.utc)))
    assert len(est) == 1

    use_dt = datetime(2020, 3, 1, 12, 15, 0, tzinfo=timezone.utc)

    # test last IP
    est = list(via_ip.estimate_location(use_dt))
    assert len(est) == 1

    # datetime should be the IP's, not the passed dt (if via_home, it uses the passed dt)
    assert est[0].dt != use_dt

    # test interop with other fallback estimators/all.py
    #
    # redefine fallback_estimators to prevent possible namespace packages the user
    # may have installed from having side effects testing this
    # note: from here on the name `all` refers to the imported module, not the builtin
    from my.location.fallback import all, via_home

    def _fe() -> Iterator[all.LocationEstimator]:
        yield via_ip.estimate_location
        yield via_home.estimate_location

    # NOTE(review): this override is not undone when the test finishes -- confirm no later test depends on the original
    all.fallback_estimators = _fe
    assert ilen(all.fallback_estimators()) == 2

    # test that all.estimate_location has access to both IPs
    #
    # just passing via_ip should give one IP
    from my.location.fallback.common import _iter_estimate_from

    raw_est = list(_iter_estimate_from(use_dt, (via_ip.estimate_location,)))
    assert len(raw_est) == 1
    assert raw_est[0].datasource == "via_ip"
    assert raw_est[0].accuracy == 15_000

    # passing home should give one
    home_est = list(_iter_estimate_from(use_dt, (via_home.estimate_location,)))
    assert len(home_est) == 1
    assert home_est[0].accuracy == 30_000

    # make sure ip accuracy is more accurate
    assert raw_est[0].accuracy < home_est[0].accuracy

    # passing both should give two
    raw_est = list(_iter_estimate_from(use_dt, (via_ip.estimate_location, via_home.estimate_location)))
    assert len(raw_est) == 2

    # shouldn't raise value error
    all_est = all.estimate_location(use_dt)
    # should have used the IP from via_ip since it was more accurate
    assert all_est.datasource == "via_ip"

    # test that a home defined in shared_tz_config.py is used if no IP is found
    loc = all.estimate_location(datetime(2021, 1, 1, 12, 30, 0, tzinfo=timezone.utc))
    assert loc.datasource == "via_home"

    # test a different home using location.fallback.all
    bulgaria = all.estimate_location(datetime(2006, 1, 1, 12, 30, 0, tzinfo=timezone.utc))
    assert bulgaria.datasource == "via_home"
    assert (bulgaria.lat, bulgaria.lon) == (42.697842, 23.325973)
    assert (loc.lat, loc.lon) != (bulgaria.lat, bulgaria.lon)
def data() -> Iterator[IP]:
    """Yield five fixed IP records with UTC timestamps, in chronological order."""
    # random IP addresses
    samples = (
        ("67.98.113.0", (2020, 1, 1, 12)),
        ("67.98.112.0", (2020, 1, 15, 12)),
        ("59.40.113.87", (2020, 2, 1, 12)),
        ("59.40.139.87", (2020, 2, 1, 16)),
        ("161.235.192.228", (2020, 3, 1, 12)),
    )
    for addr, (year, month, day, hour) in samples:
        yield IP(addr=addr, dt=datetime(year, month, day, hour, 0, 0, tzinfo=timezone.utc))
@pytest.fixture(autouse=True)
def prepare(config):
    """Swap my.ip.all.ips for data() while a test runs, then put the original back."""
    original_ips = ip_module.ips
    # redefine the my.ip.all function using data for testing
    ip_module.ips = data
    try:
        yield
    finally:
        # restore even if the test body raised
        ip_module.ips = original_ips

View file

@ -0,0 +1,60 @@
"""
Helper to test various timezone/location dependent things
"""
from datetime import date, datetime, timezone
from pathlib import Path
import pytest
from more_itertools import one
from my.core.cfg import tmp_config
@pytest.fixture(autouse=True)
def config(tmp_path: Path):
    """
    Autouse fixture: installs `google`, `location` and `time` sections into a temporary
    config (via tmp_config) and yields that config object to the test.
    """
    # TODO could just pick a part of shared config? not sure
    _takeout_path = _prepare_takeouts_dir(tmp_path)

    class google:
        takeout_path = _takeout_path

    class location:
        # fmt: off
        home = (
            # supports ISO strings
            ('2005-12-04' , (42.697842, 23.325973)), # Bulgaria, Sofia
            # supports date/datetime objects
            (date(year=1980, month=2, day=15) , (40.7128 , -74.0060 )), # NY
            # check tz handling..
            (datetime.fromtimestamp(1600000000, tz=timezone.utc), (55.7558 , 37.6173 )), # Moscow, Russia
        )
        # fmt: on
        # note: order doesn't matter, will be sorted in the data provider

    class time:
        class tz:
            class via_location:
                fast = True # some tests rely on it

    # the sections are attached only inside the tmp_config() context
    with tmp_config() as cfg:
        cfg.google = google
        cfg.location = location
        cfg.time = time
        yield cfg
def _prepare_takeouts_dir(tmp_path: Path) -> Path:
    """
    Pack the sample location history from the testdata dir into ``tmp_path / 'takeout.zip'``.

    Returns ``tmp_path`` itself (the directory that now contains the zip).
    Raises RuntimeError (chained to the underlying ValueError) when ``one`` rejects the
    search result for the sample json file.
    """
    from .common import testdata

    try:
        track = one(testdata().rglob('italy-slovenia-2017-07-29.json'))
    except ValueError as err:
        # chain explicitly so the traceback shows why the lookup failed
        raise RuntimeError('testdata not found, setup git submodules?') from err

    # todo ugh. unnecessary zipping, but at the moment takeout provider doesn't support plain dirs
    import zipfile

    with zipfile.ZipFile(tmp_path / 'takeout.zip', 'w') as zf:
        zf.writestr('Takeout/Location History/Location History.json', track.read_bytes())
    return tmp_path

107
my/tests/tz.py Normal file
View file

@ -0,0 +1,107 @@
import sys
from datetime import datetime, timedelta
import pytest
import pytz
import my.time.tz.main as tz_main
import my.time.tz.via_location as tz_via_location
from my.core import notnone
from my.core.compat import fromisoformat
from .shared_tz_config import config # autoused fixture
def getzone(dt: datetime) -> str:
    """Return the `zone` attribute of dt's tzinfo (the tzinfo is passed through notnone first)."""
    tzinfo = notnone(dt.tzinfo)
    # looked up dynamically: `zone` is not declared on the generic datetime.tzinfo type
    zone_name = getattr(tzinfo, 'zone')
    return zone_name
@pytest.mark.parametrize('fast', [False, True])
def test_iter_tzs(fast: bool, config) -> None:
    """Check the zones produced by _iter_tzs in both the fast and the precise mode."""
    # TODO hmm.. maybe need to make sure we start with empty config?
    config.time.tz.via_location.fast = fast

    # the first two entries are the same in both modes, the last three differ
    tail_zone = 'Europe/Vienna' if fast else 'Europe/Ljubljana'
    expected = ['Europe/Rome'] * 2 + [tail_zone] * 3

    zones = [entry.zone for entry in tz_via_location._iter_tzs()]
    assert zones == expected
def test_past() -> None:
    """
    Should fallback to the 'home' location provider
    """
    localized = tz_main.localize(fromisoformat('2000-01-01 12:34:45'))
    assert getzone(localized) == 'America/New_York'
def test_future() -> None:
    """
    For locations in the future should rely on 'home' location
    """
    in_100_days = datetime.now() + timedelta(days=100)
    localized = tz_main.localize(in_100_days)
    assert getzone(localized) == 'Europe/Moscow'
def test_get_tz(config) -> None:
    """Check get_tz for a few fixed timestamps and for datetime.min."""
    # todo hmm, the way it's implemented at the moment, never returns None?
    get_tz = tz_via_location.get_tz

    expected_zones = (
        ('2020-01-01 10:00:00', 'Europe/Sofia'),  # not present in the test data
        ('2017-08-01 11:00:00', 'Europe/Vienna'),
        ('2017-07-30 10:00:00', 'Europe/Rome'),
    )
    for stamp, zone in expected_zones:
        assert notnone(get_tz(fromisoformat(stamp))).zone == zone

    assert get_tz(fromisoformat('2020-10-01 14:15:16')) is not None

    if sys.platform == 'win32':
        # seems this fails because windows doesn't support same date ranges
        # https://stackoverflow.com/a/41400321/
        with pytest.raises(OSError):
            get_tz(datetime.min)
    else:
        assert get_tz(datetime.min) is not None
def test_policies() -> None:
    """Check how tz_main.localize treats naive and aware datetimes under each policy."""
    naive = fromisoformat('2017-07-30 10:00:00')
    assert naive.tzinfo is None  # just in case

    def localized_zone(dt, **kwargs) -> str:
        return getzone(tz_main.localize(dt, **kwargs))

    # actual timezone at the time
    assert localized_zone(naive) == 'Europe/Rome'

    aware = pytz.timezone('America/New_York').localize(naive)

    assert localized_zone(aware) == 'America/New_York'
    assert localized_zone(aware, policy='convert') == 'Europe/Rome'

    with pytest.raises(RuntimeError):
        assert tz_main.localize(aware, policy='throw')

View file

@ -1,52 +1,43 @@
'''
Timezone data provider, guesses timezone based on location data (e.g. GPS)
'''
REQUIRES = [
# for determining timezone by coordinate
'timezonefinder',
]
import heapq
import os
from collections import Counter
from dataclasses import dataclass
from datetime import date, datetime
from functools import lru_cache
import heapq
from itertools import groupby
import os
from typing import Iterator, Optional, Tuple, Any, List, Iterable, Set, Dict
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Protocol,
Set,
Tuple,
)
import pytz
from my.core import Stats, datetime_aware, make_logger, stat
from my.core.cachew import mcachew
from my.core import make_logger, stat, Stats, datetime_aware
from my.core.compat import TypeAlias
from my.core.source import import_source
from my.core.warnings import high
from my.location.common import LatLon
## user might not have tz config section, so makes sense to be more defensive about it
# todo might be useful to extract a helper for this
try:
    from my.config import time
except ImportError as ie:
    # ImportError.name holds the module the import was attempted from, not the missing
    # attribute, so it can't be compared against 'time' -- inspect the message instead
    if "'time'" not in str(ie):
        raise ie
else:
    try:
        user_config = time.tz.via_location
    except AttributeError as ae:
        # both membership tests need `in str(ae)`; a bare non-empty string is always truthy,
        # which would swallow every AttributeError
        if not ("'tz'" in str(ae) or "'via_location'" in str(ae)):
            raise ae

# deliberately dynamic to prevent confusing mypy
if 'user_config' not in globals():
    # nothing usable was found above: fall back to a plain `object` base
    globals()['user_config'] = object
##
@dataclass
class config(user_config):
class config(Protocol):
# less precise, but faster
fast: bool = True
@ -62,6 +53,43 @@ class config(user_config):
_iter_tz_refresh_time: int = 6
def _get_user_config():
    """
    Look up the user's ``my.config.time.tz.via_location`` section.

    Returns a local empty placeholder class when the ``time`` section can't be imported or
    when ``tz`` / ``via_location`` are missing; any other ImportError / AttributeError is
    re-raised unchanged.
    """
    ## user might not have tz config section, so makes sense to be more defensive about it

    class empty_config: ...

    try:
        from my.config import time
    except ImportError as ie:
        if "'time'" in str(ie):
            return empty_config
        raise ie

    try:
        return time.tz.via_location
    except AttributeError as ae:
        message = str(ae)
        if "'tz'" in message or "'via_location'" in message:
            return empty_config
        raise ae
def make_config() -> config:
    """
    Build a fresh config object on each call by combining the user's settings (looked up
    at call time via _get_user_config) with the attributes declared on `config`.
    """
    if TYPE_CHECKING:
        # static-analysis-only branch: TYPE_CHECKING is False at runtime
        import my.config

        user_config: TypeAlias = my.config.time.tz.via_location
    else:
        user_config = _get_user_config()

    # user_config is listed first among the bases, so its attributes take precedence over config's
    class combined_config(user_config, config): ...

    return combined_config()
logger = make_logger(__name__)
@ -78,6 +106,7 @@ def _timezone_finder(fast: bool) -> Any:
# for backwards compatibility
def _locations() -> Iterator[Tuple[LatLon, datetime_aware]]:
try:
raise RuntimeError
import my.location.all
for loc in my.location.all.locations():
@ -140,13 +169,14 @@ def _find_tz_for_locs(finder: Any, locs: Iterable[Tuple[LatLon, datetime]]) -> I
# Note: this takes a while: since the upstream _locations isn't sorted, this
# has to do an iterative sort of the entire my.locations.all list
def _iter_local_dates() -> Iterator[DayWithZone]:
    """
    Yield whatever _find_tz_for_locs produces for the location stream.

    Settings are read from make_config() at call time; locations come from
    _sorted_locations() when ``cfg.sort_locations`` is set, otherwise from _locations().
    """
    cfg = make_config()
    finder = _timezone_finder(fast=cfg.fast)  # rely on the default
    # pdt = None
    # TODO: warnings doesn't actually warn?
    # warnings = []

    locs: Iterable[Tuple[LatLon, datetime]]
    locs = _sorted_locations() if cfg.sort_locations else _locations()

    yield from _find_tz_for_locs(finder, locs)
@ -158,11 +188,13 @@ def _iter_local_dates() -> Iterator[DayWithZone]:
def _iter_local_dates_fallback() -> Iterator[DayWithZone]:
    """
    Same as the primary iterator, but fed from my.location.fallback.all.fallback_locations.

    Fallback locations are sorted by their ``dt`` before being handed to _find_tz_for_locs;
    the ``fast`` setting is read from make_config() at call time.
    """
    from my.location.fallback.all import fallback_locations as flocs

    cfg = make_config()

    def _fallback_locations() -> Iterator[Tuple[LatLon, datetime]]:
        for loc in sorted(flocs(), key=lambda x: x.dt):
            yield ((loc.lat, loc.lon), loc.dt)

    yield from _find_tz_for_locs(_timezone_finder(fast=cfg.fast), _fallback_locations())
def most_common(lst: Iterator[DayWithZone]) -> DayWithZone:
@ -180,7 +212,8 @@ def _iter_tz_depends_on() -> str:
2022-04-26_12
2022-04-26_18
"""
mod = config._iter_tz_refresh_time
cfg = make_config()
mod = cfg._iter_tz_refresh_time
assert mod >= 1
day = str(date.today())
hr = datetime.now().hour
@ -293,5 +326,13 @@ def stats(quick: bool = False) -> Stats:
return stat(localized_years)
## deprecated -- keeping for now as might be used in other modules?
if not TYPE_CHECKING:
    # defined only at runtime, so static type checkers do not see `_get_tz` at all
    from my.core.compat import deprecated

    @deprecated('use get_tz function instead')
    def _get_tz(*args, **kwargs):
        # thin forwarder: all arguments go to get_tz unchanged
        return get_tz(*args, **kwargs)
##