core: add helper for more_itertools to check that all types involved are hashable

Otherwise unique_everseen performance may degrade to quadratic rather than linear

For now hidden behind HPI_CHECK_UNIQUE_EVERSEEN flag

also switch some modules to use it
This commit is contained in:
karlicoss 2023-10-31 00:42:17 +00:00
parent d6786084ca
commit 71cb66df5f
8 changed files with 90 additions and 23 deletions

View file

@ -6,7 +6,25 @@ from contextlib import contextmanager
import os import os
import sys import sys
import types import types
from typing import Union, Callable, Dict, Iterable, TypeVar, Sequence, List, Optional, Any, cast, Tuple, TYPE_CHECKING, NoReturn from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
NoReturn,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Union,
cast,
get_args,
get_type_hints,
get_origin,
)
import warnings import warnings
from . import warnings as core_warnings from . import warnings as core_warnings
@ -628,6 +646,59 @@ def assert_never(value: NoReturn) -> NoReturn:
assert False, f'Unhandled value: {value} ({type(value).__name__})' assert False, f'Unhandled value: {value} ({type(value).__name__})'
def _check_all_hashable(fun):
# TODO ok, take callable?
hints = get_type_hints(fun)
# TODO needs to be defensive like in cachew?
return_type = hints.get('return')
# TODO check if None
origin = get_origin(return_type) # Iterator etc?
(arg,) = get_args(return_type)
# options we wanna handle are simple type on the top level or union
arg_origin = get_origin(arg)
if sys.version_info[:2] >= (3, 10):
is_uniontype = arg_origin is types.UnionType
else:
is_uniontype = False
is_union = arg_origin is Union or is_uniontype
if is_union:
to_check = get_args(arg)
else:
to_check = (arg,)
no_hash = [
t
for t in to_check
# seems that objects that have not overridden hash have the attribute but it's set to None
if getattr(t, '__hash__', None) is None
]
assert len(no_hash) == 0, f'Types {no_hash} are not hashable, this will result in significant performance downgrade for unique_everseen'
_UET = TypeVar('_UET')
_UEU = TypeVar('_UEU')


def unique_everseen(
    fun: Callable[[], Iterable[_UET]],
    key: Optional[Callable[[_UET], _UEU]] = None,
) -> Iterator[_UET]:
    """
    Wrapper around ``more_itertools.unique_everseen`` that takes a zero-argument callable
    producing the items, rather than the iterable itself.

    When no ``key`` is given and the ``HPI_CHECK_UNIQUE_EVERSEEN`` environment variable is set
    (to any value), the return annotation of ``fun`` is checked via ``_check_all_hashable``
    before deduplication is set up.
    """
    # TODO support normal iterable as well?
    from more_itertools import unique_everseen as _dedupe

    # NOTE: it has to take original callable, because otherwise we don't have access to generator type annotations
    items = fun()

    check_requested = 'HPI_CHECK_UNIQUE_EVERSEEN' in os.environ
    # todo check key return type as well? but it's more likely to be hashable
    if key is None and check_requested:
        _check_all_hashable(fun)

    return _dedupe(iterable=items, key=key)
## legacy imports, keeping them here for backwards compatibility ## legacy imports, keeping them here for backwards compatibility
from functools import cached_property as cproperty from functools import cached_property as cproperty
from typing import Literal from typing import Literal

View file

@ -9,9 +9,8 @@ from pathlib import Path
import sqlite3 import sqlite3
from typing import Iterator, Sequence, Optional, Dict, Union, List from typing import Iterator, Sequence, Optional, Dict, Union, List
from more_itertools import unique_everseen
from my.core import get_files, Paths, datetime_aware, Res, assert_never, LazyLogger, make_config from my.core import get_files, Paths, datetime_aware, Res, assert_never, LazyLogger, make_config
from my.core.common import unique_everseen
from my.core.error import echain from my.core.error import echain
from my.core.sqlite import sqlite_connection from my.core.sqlite import sqlite_connection
@ -242,7 +241,7 @@ def messages() -> Iterator[Res[Message]]:
senders: Dict[str, Sender] = {} senders: Dict[str, Sender] = {}
msgs: Dict[str, Message] = {} msgs: Dict[str, Message] = {}
threads: Dict[str, Thread] = {} threads: Dict[str, Thread] = {}
for x in unique_everseen(_entities()): for x in unique_everseen(_entities):
if isinstance(x, Exception): if isinstance(x, Exception):
yield x yield x
continue continue

View file

@ -10,8 +10,6 @@ from pathlib import Path
import sqlite3 import sqlite3
from typing import Iterator, Sequence, Optional, Dict, Union from typing import Iterator, Sequence, Optional, Dict, Union
from more_itertools import unique_everseen
from my.core import ( from my.core import (
get_files, get_files,
Paths, Paths,
@ -22,6 +20,7 @@ from my.core import (
Res, Res,
assert_never, assert_never,
) )
from my.core.common import unique_everseen
from my.core.cachew import mcachew from my.core.cachew import mcachew
from my.core.error import echain from my.core.error import echain
from my.core.sqlite import sqlite_connect_immutable, select from my.core.sqlite import sqlite_connect_immutable, select
@ -196,7 +195,7 @@ def _entities() -> Iterator[Res[Union[User, _Message]]]:
@mcachew(depends_on=inputs) @mcachew(depends_on=inputs)
def messages() -> Iterator[Res[Message]]: def messages() -> Iterator[Res[Message]]:
id2user: Dict[str, User] = {} id2user: Dict[str, User] = {}
for x in unique_everseen(_entities()): for x in unique_everseen(_entities):
if isinstance(x, Exception): if isinstance(x, Exception):
yield x yield x
continue continue

View file

@ -7,7 +7,7 @@ import json
from pathlib import Path from pathlib import Path
from typing import Iterator, Sequence, Dict, Union from typing import Iterator, Sequence, Dict, Union
from more_itertools import bucket, unique_everseen from more_itertools import bucket
from my.core import ( from my.core import (
get_files, get_files,
@ -17,6 +17,7 @@ from my.core import (
assert_never, assert_never,
make_logger, make_logger,
) )
from my.core.common import unique_everseen
from my.config import instagram as user_config from my.config import instagram as user_config
@ -196,7 +197,7 @@ def _entitites_from_path(path: Path) -> Iterator[Res[Union[User, _Message]]]:
# TODO basically copy pasted from android.py... hmm # TODO basically copy pasted from android.py... hmm
def messages() -> Iterator[Res[Message]]: def messages() -> Iterator[Res[Message]]:
id2user: Dict[str, User] = {} id2user: Dict[str, User] = {}
for x in unique_everseen(_entities()): for x in unique_everseen(_entities):
if isinstance(x, Exception): if isinstance(x, Exception):
yield x yield x
continue continue

View file

@ -11,9 +11,8 @@ from pathlib import Path
import sqlite3 import sqlite3
from typing import Sequence, Iterator, Union, Dict, List, Mapping from typing import Sequence, Iterator, Union, Dict, List, Mapping
from more_itertools import unique_everseen
from my.core import Paths, get_files, Res, assert_never, stat, Stats, datetime_aware, make_logger from my.core import Paths, get_files, Res, assert_never, stat, Stats, datetime_aware, make_logger
from my.core.common import unique_everseen
from my.core.error import echain from my.core.error import echain
from my.core.sqlite import sqlite_connection from my.core.sqlite import sqlite_connection
import my.config import my.config
@ -162,7 +161,7 @@ def _parse_msg(row: sqlite3.Row) -> _Message:
def entities() -> Iterator[Res[Entity]]: def entities() -> Iterator[Res[Entity]]:
id2person: Dict[str, Person] = {} id2person: Dict[str, Person] = {}
id2match: Dict[str, Match] = {} id2match: Dict[str, Match] = {}
for x in unique_everseen(_entities()): for x in unique_everseen(_entities):
if isinstance(x, Exception): if isinstance(x, Exception):
yield x yield x
continue continue

View file

@ -9,9 +9,8 @@ import re
import sqlite3 import sqlite3
from typing import Iterator, Sequence, Union from typing import Iterator, Sequence, Union
from more_itertools import unique_everseen
from my.core import Paths, Res, datetime_aware, get_files from my.core import Paths, Res, datetime_aware, get_files
from my.core.common import unique_everseen
from my.core.sqlite import sqlite_connection from my.core.sqlite import sqlite_connection
from .common import TweetId, permalink from .common import TweetId, permalink
@ -133,7 +132,7 @@ def _parse_tweet(row: sqlite3.Row) -> Tweet:
def tweets() -> Iterator[Res[Tweet]]: def tweets() -> Iterator[Res[Tweet]]:
for x in unique_everseen(_entities()): for x in unique_everseen(_entities):
if isinstance(x, Exception): if isinstance(x, Exception):
yield x yield x
elif isinstance(x, _IsTweet): elif isinstance(x, _IsTweet):
@ -141,7 +140,7 @@ def tweets() -> Iterator[Res[Tweet]]:
def likes() -> Iterator[Res[Tweet]]: def likes() -> Iterator[Res[Tweet]]:
for x in unique_everseen(_entities()): for x in unique_everseen(_entities):
if isinstance(x, Exception): if isinstance(x, Exception):
yield x yield x
elif isinstance(x, _IsFavorire): elif isinstance(x, _IsFavorire):

View file

@ -5,12 +5,12 @@ VK data (exported by [[https://github.com/Totktonada/vk_messages_backup][Totkton
from datetime import datetime from datetime import datetime
from dataclasses import dataclass from dataclasses import dataclass
import json import json
from typing import Dict, Iterator, NamedTuple from typing import Dict, Iterator
from more_itertools import unique_everseen
import pytz import pytz
from my.core import stat, Stats, Json, Res, datetime_aware from my.core import stat, Stats, Json, Res, datetime_aware, get_files
from my.core.common import unique_everseen
from my.config import vk_messages_backup as config from my.config import vk_messages_backup as config
@ -147,7 +147,7 @@ def _messages() -> Iterator[Res[Message]]:
def messages() -> Iterator[Res[Message]]: def messages() -> Iterator[Res[Message]]:
# seems that during backup messages were sometimes duplicated.. # seems that during backup messages were sometimes duplicated..
yield from unique_everseen(_messages()) yield from unique_everseen(_messages)
def stats() -> Stats: def stats() -> Stats:

View file

@ -9,9 +9,8 @@ from pathlib import Path
import sqlite3 import sqlite3
from typing import Sequence, Iterator, Optional from typing import Sequence, Iterator, Optional
from more_itertools import unique_everseen
from my.core import get_files, Paths, datetime_aware, Res, make_logger, make_config from my.core import get_files, Paths, datetime_aware, Res, make_logger, make_config
from my.core.common import unique_everseen
from my.core.error import echain, notnone from my.core.error import echain, notnone
from my.core.sqlite import sqlite_connection from my.core.sqlite import sqlite_connection
import my.config import my.config
@ -202,4 +201,4 @@ def _messages() -> Iterator[Res[Message]]:
def messages() -> Iterator[Res[Message]]: def messages() -> Iterator[Res[Message]]:
yield from unique_everseen(_messages()) yield from unique_everseen(_messages)