my.core: fix list constructor in always_support_sequence and add some tests

This commit is contained in:
Dima Gerasimov 2024-09-22 04:27:32 +01:00 committed by karlicoss
parent 02dabe9f2b
commit 3166109f15

View file

@ -2,6 +2,7 @@
Contains various backwards compatibility/deprecation helpers relevant to HPI itself. Contains various backwards compatibility/deprecation helpers relevant to HPI itself.
(as opposed to .compat module which implements compatibility between python versions) (as opposed to .compat module which implements compatibility between python versions)
""" """
import inspect import inspect
import os import os
import re import re
@ -116,32 +117,141 @@ V = TypeVar('V')
# named to be kinda consistent with more_itertools, e.g. more_itertools.always_iterable # named to be kinda consistent with more_itertools, e.g. more_itertools.always_iterable
class always_supports_sequence(Iterator[V]): class always_supports_sequence(Iterator[V]):
""" """
Helper to make migration from Sequence/List to Iterable/Iterator type backwards compatible Helper to make migration from Sequence/List to Iterable/Iterator type backwards compatible in runtime
""" """
def __init__(self, it: Iterator[V]) -> None: def __init__(self, it: Iterator[V]) -> None:
self.it = it self._it = it
self._list: Optional[List] = None self._list: Optional[List[V]] = None
self._lit: Optional[Iterator[V]] = None
def __iter__(self) -> Iterator[V]: # noqa: PYI034 def __iter__(self) -> Iterator[V]: # noqa: PYI034
return self.it.__iter__() if self._list is not None:
self._lit = iter(self._list)
return self
def __next__(self) -> V: def __next__(self) -> V:
return self.it.__next__() if self._list is not None:
assert self._lit is not None
delegate = self._lit
else:
delegate = self._it
return next(delegate)
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.it, name) return getattr(self._it, name)
@property @property
def aslist(self) -> List[V]: def _aslist(self) -> List[V]:
if self._list is None: if self._list is None:
qualname = getattr(self.it, '__qualname__', '<no qualname>') # defensive just in case qualname = getattr(self._it, '__qualname__', '<no qualname>') # defensive just in case
warnings.medium(f'Using {qualname} as list is deprecated. Migrate to iterative processing or call list() explicitly.') warnings.medium(f'Using {qualname} as list is deprecated. Migrate to iterative processing or call list() explicitly.')
self._list = list(self.it) self._list = list(self._it)
# this is necessary for list constructor to work correctly
# since it's __iter__ first, then tries to compute length and then starts iterating...
self._lit = iter(self._list)
return self._list return self._list
def __len__(self) -> int: def __len__(self) -> int:
return len(self.aslist) return len(self._aslist)
def __getitem__(self, i: int) -> V: def __getitem__(self, i: int) -> V:
return self.aslist[i] return self._aslist[i]
def test_always_supports_sequence_list_constructor() -> None:
exhausted = 0
def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1
sit = always_supports_sequence(it())
# list constructor is a bit special... it's trying to compute length if it's available to optimize memory allocation
# so, what's happening in this case is
# - sit.__iter__ is called
# - sit.__len__ is called
# - sit.__next__ is called
res = list(sit)
assert res == ['a', 'b', 'c']
assert exhausted == 1
res = list(sit)
assert res == ['a', 'b', 'c']
assert exhausted == 1 # this will iterate over 'cached' list now, so original generator is only exhausted once
def test_always_supports_sequence_indexing() -> None:
exhausted = 0
def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1
sit = always_supports_sequence(it())
assert len(sit) == 3
assert exhausted == 1
assert sit[2] == 'c'
assert sit[1] == 'b'
assert sit[0] == 'a'
assert exhausted == 1
# a few tests to make sure list-like operations are working..
assert list(sit) == ['a', 'b', 'c']
assert [x for x in sit] == ['a', 'b', 'c'] # noqa: C416
assert list(sit) == ['a', 'b', 'c']
assert [x for x in sit] == ['a', 'b', 'c'] # noqa: C416
assert exhausted == 1
def test_always_supports_sequence_next() -> None:
exhausted = 0
def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1
sit = always_supports_sequence(it())
x = next(sit)
assert x == 'a'
assert exhausted == 0
x = next(sit)
assert x == 'b'
assert exhausted == 0
def test_always_supports_sequence_iter() -> None:
exhausted = 0
def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1
sit = always_supports_sequence(it())
for x in sit:
assert x == 'a'
break
x = next(sit)
assert x == 'b'
assert exhausted == 0
x = next(sit)
assert x == 'c'
assert exhausted == 0
for _ in sit:
raise RuntimeError # shouldn't trigger, just exhaust the iterator
assert exhausted == 1