core/general: add assert_never + typing annotations for dataset

This commit is contained in:
Dima Gerasimov 2022-06-03 20:31:13 +01:00 committed by karlicoss
parent fd1a683d49
commit 7a1b7b1554
8 changed files with 54 additions and 19 deletions

View file

@ -54,7 +54,7 @@ class Message(_BaseMessage):
import json import json
from typing import Union from typing import Union
from ..core import Res from ..core import Res, assert_never
import sqlite3 import sqlite3
from ..core.sqlite import sqlite_connect_immutable from ..core.sqlite import sqlite_connect_immutable
@ -66,7 +66,12 @@ def _entities() -> Iterator[EntitiesRes]:
yield from _handle_db(db) yield from _handle_db(db)
def _handle_db(db) -> Iterator[EntitiesRes]: def _handle_db(db: sqlite3.Connection) -> Iterator[EntitiesRes]:
# todo hmm not sure
# on the one hand kinda nice to use dataset..
# on the other, it's somewhat of a complication, and
# would be nice to have something type-directed for sql queries though
# e.g. with typeddict or something, so the number of parameters to the sql query matches?
for row in db.execute(f'SELECT user_id, user_name FROM conversation_info'): for row in db.execute(f'SELECT user_id, user_name FROM conversation_info'):
(user_id, user_name) = row (user_id, user_name) = row
yield Person( yield Person(
@ -136,4 +141,4 @@ def messages() -> Iterator[Res[Message]]:
id2msg[m.id] = m id2msg[m.id] = m
yield m yield m
continue continue
assert False, type(x) # should be unreachable assert_never(x)

View file

@ -5,6 +5,7 @@ from .common import LazyLogger
from .common import warn_if_empty from .common import warn_if_empty
from .common import stat, Stats from .common import stat, Stats
from .common import datetime_naive, datetime_aware from .common import datetime_naive, datetime_aware
from .common import assert_never
from .cfg import make_config from .cfg import make_config
from .util import __NOT_HPI_MODULE__ from .util import __NOT_HPI_MODULE__

View file

@ -4,7 +4,7 @@ from datetime import datetime
import functools import functools
from contextlib import contextmanager from contextlib import contextmanager
import types import types
from typing import Union, Callable, Dict, Iterable, TypeVar, Sequence, List, Optional, Any, cast, Tuple, TYPE_CHECKING from typing import Union, Callable, Dict, Iterable, TypeVar, Sequence, List, Optional, Any, cast, Tuple, TYPE_CHECKING, NoReturn
import warnings import warnings
from . import warnings as core_warnings from . import warnings as core_warnings
@ -632,5 +632,11 @@ class DummyExecutor(Executor):
def shutdown(self, wait: bool=True) -> None: # type: ignore[override] def shutdown(self, wait: bool=True) -> None: # type: ignore[override]
self._shutdown = True self._shutdown = True
# see https://hakibenita.com/python-mypy-exhaustive-checking#exhaustiveness-checking
def assert_never(value: NoReturn) -> NoReturn:
    """Fail loudly when a supposedly unreachable branch is executed.

    Meant to be the last statement after an exhaustive chain of type checks:
    the ``NoReturn`` parameter type lets mypy complain if ``value`` could
    still be of some unhandled type at the call site.

    Raises:
        AssertionError: always; the message includes the value and the name
            of its type.
    """
    # Raise explicitly rather than via ``assert False``: assert statements are
    # removed under ``python -O``, which would make this function return
    # silently and let the caller fall through.
    raise AssertionError(f'Unhandled value: {value} ({type(value).__name__})')
# legacy deprecated import # legacy deprecated import
from .compat import cached_property as cproperty from .compat import cached_property as cproperty

View file

@ -1,11 +1,29 @@
from __future__ import annotations
from .common import assert_subpackage; assert_subpackage(__name__) from .common import assert_subpackage; assert_subpackage(__name__)
from .common import PathIsh from .common import PathIsh
from .compat import Protocol
from .sqlite import sqlite_connect_immutable from .sqlite import sqlite_connect_immutable
## sadly dataset doesn't have any type definitions
from typing import Iterable, Iterator, Dict, Optional, Any, ContextManager
from contextlib import AbstractContextManager
# NOTE: may not be true in general, but will be in the vast majority of cases
row_type_T = Dict[str, Any]


class TableT(Iterable, Protocol):
    """Structural type describing the table interface this code relies on.

    Only iteration (through the ``Iterable`` base) and a keyword-only
    ``find`` are declared; it exists because the ``dataset`` package ships
    no type definitions of its own.
    """
    # order_by is presumably a column name to sort the yielded rows by -- TODO confirm against dataset docs
    def find(self, *, order_by: Optional[str]=None) -> Iterator[row_type_T]: ...
class DatabaseT(ContextManager['DatabaseT'], Protocol):
    """Structural type describing the database handle interface used here.

    The handle is usable in a ``with`` statement (entering yields a
    ``DatabaseT``) and can be indexed by table name to get a ``TableT``.
    """
    # ``typing.ContextManager`` is subscripted here instead of
    # ``contextlib.AbstractContextManager``: the latter only supports ``[...]``
    # at runtime from Python 3.9 on, and base-class expressions are evaluated
    # eagerly even under ``from __future__ import annotations``.
    def __getitem__(self, table: str) -> TableT: ...
##
# TODO wonder if also need to open without WAL.. test this on read-only directory/db file # TODO wonder if also need to open without WAL.. test this on read-only directory/db file
def connect_readonly(db: PathIsh): def connect_readonly(db: PathIsh) -> DatabaseT:
import dataset # type: ignore import dataset # type: ignore
# see https://github.com/pudo/dataset/issues/136#issuecomment-128693122 # see https://github.com/pudo/dataset/issues/136#issuecomment-128693122
# todo not sure if mode=ro has any benefit, but it doesn't work on read-only filesystems # todo not sure if mode=ro has any benefit, but it doesn't work on read-only filesystems

View file

@ -3,6 +3,8 @@ Messenger data from Android app database (in =/data/data/com.facebook.orca/datab
""" """
from __future__ import annotations from __future__ import annotations
REQUIRES = ['dataset']
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Iterator, Sequence, Optional, Dict from typing import Iterator, Sequence, Optional, Dict
@ -61,8 +63,8 @@ class Message(_BaseMessage):
import json import json
from typing import Union from typing import Union
from ..core.error import Res from ..core import Res, assert_never
from ..core.dataset import connect_readonly from ..core.dataset import connect_readonly, DatabaseT
Entity = Union[Sender, Thread, _Message] Entity = Union[Sender, Thread, _Message]
def _entities() -> Iterator[Res[Entity]]: def _entities() -> Iterator[Res[Entity]]:
for f in inputs(): for f in inputs():
@ -70,11 +72,11 @@ def _entities() -> Iterator[Res[Entity]]:
yield from _process_db(db) yield from _process_db(db)
def _process_db(db) -> Iterator[Res[Entity]]: def _process_db(db: DatabaseT) -> Iterator[Res[Entity]]:
# works both for GROUP:group_id and ONE_TO_ONE:other_user:your_user # works both for GROUP:group_id and ONE_TO_ONE:other_user:your_user
threadkey2id = lambda key: key.split(':')[1] threadkey2id = lambda key: key.split(':')[1]
for r in db['threads']: for r in db['threads'].find():
try: try:
yield Thread( yield Thread(
id=threadkey2id(r['thread_key']), id=threadkey2id(r['thread_key']),
@ -84,8 +86,8 @@ def _process_db(db) -> Iterator[Res[Entity]]:
yield e yield e
continue continue
for r in db['messages'].all(order_by='timestamp_ms'): for r in db['messages'].find(order_by='timestamp_ms'):
mtype = r['msg_type'] mtype: int = r['msg_type']
if mtype == -1: if mtype == -1:
# likely immediately deleted or something? doesn't have any data at all # likely immediately deleted or something? doesn't have any data at all
continue continue
@ -94,7 +96,7 @@ def _process_db(db) -> Iterator[Res[Entity]]:
try: try:
# todo could use thread_users? # todo could use thread_users?
sj = json.loads(r['sender']) sj = json.loads(r['sender'])
ukey = sj['user_key'] ukey: str = sj['user_key']
prefix = 'FACEBOOK:' prefix = 'FACEBOOK:'
assert ukey.startswith(prefix), ukey assert ukey.startswith(prefix), ukey
user_id = ukey[len(prefix):] user_id = ukey[len(prefix):]
@ -167,4 +169,6 @@ def messages() -> Iterator[Res[Message]]:
msgs[m.id] = m msgs[m.id] = m
yield m yield m
continue continue
assert False, type(x) # should be unreachable # NOTE: for some reason mypy coverage highlights it as red?
# but it actually works as expected: i.e. if you omit one of the clauses above, mypy will complain
assert_never(x)

View file

@ -88,12 +88,13 @@ def _parse_message(j: Json) -> Optional[_Message]:
import json import json
from typing import Union from typing import Union
from ..core.error import Res from ..core import Res, assert_never
import sqlite3 import sqlite3
from ..core.sqlite import sqlite_connect_immutable from ..core.sqlite import sqlite_connect_immutable
def _entities() -> Iterator[Res[Union[User, _Message]]]: def _entities() -> Iterator[Res[Union[User, _Message]]]:
# NOTE: definitely need to merge multiple, app seems to recycle old messages # NOTE: definitely need to merge multiple, app seems to recycle old messages
# TODO: hmm hard to guarantee timestamp ordering when we use synthetic input data... # TODO: hmm hard to guarantee timestamp ordering when we use synthetic input data...
# todo use TypedDict?
for f in inputs(): for f in inputs():
with sqlite_connect_immutable(f) as db: with sqlite_connect_immutable(f) as db:
@ -149,4 +150,4 @@ def messages() -> Iterator[Res[Message]]:
user=user, user=user,
) )
continue continue
assert False, type(x) # should not happen assert_never(x)

View file

@ -56,7 +56,7 @@ def _decode(s: str) -> str:
import json import json
from typing import Union from typing import Union
from ..core.error import Res from ..core import Res, assert_never
def _entities() -> Iterator[Res[Union[User, _Message]]]: def _entities() -> Iterator[Res[Union[User, _Message]]]:
from ..core.kompress import ZipPath from ..core.kompress import ZipPath
last = ZipPath(max(inputs())) last = ZipPath(max(inputs()))
@ -165,4 +165,4 @@ def messages() -> Iterator[Res[Message]]:
user=user, user=user,
) )
continue continue
assert False, type(x) # should not happen assert_never(x)

View file

@ -79,7 +79,7 @@ class Message:
from typing import Union from typing import Union
from itertools import count from itertools import count
import json import json
from ..core import Res from ..core import Res, assert_never
# todo cache it # todo cache it
def _entities() -> Iterator[Res[Union[Server, Sender, _Message]]]: def _entities() -> Iterator[Res[Union[Server, Sender, _Message]]]:
# TODO hmm -- not sure if max lexicographically will actually be latest? # TODO hmm -- not sure if max lexicographically will actually be latest?
@ -169,4 +169,4 @@ def messages() -> Iterator[Res[Message]]:
content=x.content, content=x.content,
) )
continue continue
assert False # should be unreachable assert_never(x)