From b306ccc83995291eeec0073f07c99f8e85c94b8b Mon Sep 17 00:00:00 2001 From: Dima Gerasimov Date: Fri, 2 Apr 2021 19:41:32 +0100 Subject: [PATCH] core: add ensure_unique iterator transfromation --- my/core/common.py | 54 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/my/core/common.py b/my/core/common.py index 565ef5b..a891a18 100644 --- a/my/core/common.py +++ b/my/core/common.py @@ -70,18 +70,62 @@ def group_by_key(l: Iterable[T], key: Callable[[T], K]) -> Dict[K, List[T]]: def _identity(v: T) -> V: return cast(V, v) -def make_dict(l: Iterable[T], key: Callable[[T], K], value: Callable[[T], V]=_identity) -> Dict[K, V]: - res: Dict[K, V] = {} - for i in l: + +# ugh. nothing in more_itertools? +def ensure_unique( + it: Iterable[T], + *, + key: Callable[[T], K], + value: Callable[[T], V]=_identity, + key2value: Optional[Dict[K, V]]=None +) -> Iterable[T]: + if key2value is None: + key2value = {} + for i in it: k = key(i) v = value(i) - pv = res.get(k, None) # type: ignore + pv = key2value.get(k, None) # type: ignore if pv is not None: raise RuntimeError(f"Duplicate key: {k}. Previous value: {pv}, new value: {v}") - res[k] = v + key2value[k] = v + yield i + + +def test_ensure_unique() -> None: + import pytest # type: ignore + assert list(ensure_unique([1, 2, 3], key=lambda i: i)) == [1, 2, 3] + + dups = [1, 2, 1, 4] + # this works because it's lazy + it = ensure_unique(dups, key=lambda i: i) + + # but forcing throws + with pytest.raises(RuntimeError, match='Duplicate key'): + list(it) + + # hacky way to force distinct objects? + list(ensure_unique(dups, key=lambda i: object())) + + +def make_dict( + it: Iterable[T], + *, + key: Callable[[T], K], + value: Callable[[T], V]=_identity +) -> Dict[K, V]: + res: Dict[K, V] = {} + uniques = ensure_unique(it, key=key, value=value, key2value=res) + for _ in uniques: + pass # force the iterator return res +def test_make_dict() -> None: + it = range(5) + d = make_dict(it, key=lambda i: i, value=lambda i: i % 2) + assert d == {0: 0, 1: 1, 2: 0, 3: 1, 4: 0} + + Cl = TypeVar('Cl') R = TypeVar('R')