diff --git a/my/core/hpi_compat.py b/my/core/hpi_compat.py index 6261c23..949046d 100644 --- a/my/core/hpi_compat.py +++ b/my/core/hpi_compat.py @@ -2,6 +2,7 @@ Contains various backwards compatibility/deprecation helpers relevant to HPI itself. (as opposed to .compat module which implements compatibility between python versions) """ + import inspect import os import re @@ -116,32 +117,141 @@ V = TypeVar('V') # named to be kinda consistent with more_itertools, e.g. more_itertools.always_iterable class always_supports_sequence(Iterator[V]): """ - Helper to make migration from Sequence/List to Iterable/Iterator type backwards compatible + Helper to make migration from Sequence/List to Iterable/Iterator type backwards compatible in runtime """ def __init__(self, it: Iterator[V]) -> None: - self.it = it - self._list: Optional[List] = None + self._it = it + self._list: Optional[List[V]] = None + self._lit: Optional[Iterator[V]] = None def __iter__(self) -> Iterator[V]: # noqa: PYI034 - return self.it.__iter__() + if self._list is not None: + self._lit = iter(self._list) + return self def __next__(self) -> V: - return self.it.__next__() + if self._list is not None: + assert self._lit is not None + delegate = self._lit + else: + delegate = self._it + return next(delegate) def __getattr__(self, name): - return getattr(self.it, name) + return getattr(self._it, name) @property - def aslist(self) -> List[V]: + def _aslist(self) -> List[V]: if self._list is None: - qualname = getattr(self.it, '__qualname__', '') # defensive just in case + qualname = getattr(self._it, '__qualname__', '') # defensive just in case warnings.medium(f'Using {qualname} as list is deprecated. Migrate to iterative processing or call list() explicitly.') - self._list = list(self.it) + self._list = list(self._it) + + # this is necessary for list constructor to work correctly + # since it's __iter__ first, then tries to compute length and then starts iterating... + self._lit = iter(self._list) return self._list def __len__(self) -> int: - return len(self.aslist) + return len(self._aslist) def __getitem__(self, i: int) -> V: - return self.aslist[i] + return self._aslist[i] + + +def test_always_supports_sequence_list_constructor() -> None: + exhausted = 0 + + def it() -> Iterator[str]: + nonlocal exhausted + yield from ['a', 'b', 'c'] + exhausted += 1 + + sit = always_supports_sequence(it()) + + # list constructor is a bit special... it's trying to compute length if it's available to optimize memory allocation + # so, what's happening in this case is + # - sit.__iter__ is called + # - sit.__len__ is called + # - sit.__next__ is called + res = list(sit) + assert res == ['a', 'b', 'c'] + assert exhausted == 1 + + res = list(sit) + assert res == ['a', 'b', 'c'] + assert exhausted == 1 # this will iterate over 'cached' list now, so original generator is only exhausted once + + +def test_always_supports_sequence_indexing() -> None: + exhausted = 0 + + def it() -> Iterator[str]: + nonlocal exhausted + yield from ['a', 'b', 'c'] + exhausted += 1 + + sit = always_supports_sequence(it()) + + assert len(sit) == 3 + assert exhausted == 1 + + assert sit[2] == 'c' + assert sit[1] == 'b' + assert sit[0] == 'a' + assert exhausted == 1 + + # a few tests to make sure list-like operations are working.. + assert list(sit) == ['a', 'b', 'c'] + assert [x for x in sit] == ['a', 'b', 'c'] # noqa: C416 + assert list(sit) == ['a', 'b', 'c'] + assert [x for x in sit] == ['a', 'b', 'c'] # noqa: C416 + assert exhausted == 1 + + +def test_always_supports_sequence_next() -> None: + exhausted = 0 + + def it() -> Iterator[str]: + nonlocal exhausted + yield from ['a', 'b', 'c'] + exhausted += 1 + + sit = always_supports_sequence(it()) + + x = next(sit) + assert x == 'a' + assert exhausted == 0 + + x = next(sit) + assert x == 'b' + assert exhausted == 0 + + +def test_always_supports_sequence_iter() -> None: + exhausted = 0 + + def it() -> Iterator[str]: + nonlocal exhausted + yield from ['a', 'b', 'c'] + exhausted += 1 + + sit = always_supports_sequence(it()) + + for x in sit: + assert x == 'a' + break + + x = next(sit) + assert x == 'b' + + assert exhausted == 0 + + x = next(sit) + assert x == 'c' + assert exhausted == 0 + + for _ in sit: + raise RuntimeError # shouldn't trigger, just exhaust the iterator + assert exhausted == 1