core/sqlite: experiment with typing SELECT queries (to some extent)
Ideally it would be cool to use TypedDict here somehow, but that will perhaps only be possible once variadic generics (PEP 646) are available: https://peps.python.org/pep-0646
This commit is contained in:
parent
7a1b7b1554
commit
bf3dd6e931
3 changed files with 57 additions and 17 deletions
|
@ -56,7 +56,7 @@ import json
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from ..core import Res, assert_never
|
from ..core import Res, assert_never
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from ..core.sqlite import sqlite_connect_immutable
|
from ..core.sqlite import sqlite_connect_immutable, select
|
||||||
|
|
||||||
EntitiesRes = Res[Union[Person, _Message]]
|
EntitiesRes = Res[Union[Person, _Message]]
|
||||||
|
|
||||||
|
@ -72,20 +72,22 @@ def _handle_db(db: sqlite3.Connection) -> Iterator[EntitiesRes]:
|
||||||
# on the other, it's somewhat of a complication, and
|
# on the other, it's somewhat of a complication, and
|
||||||
# would be nice to have something type-directed for sql queries though
|
# would be nice to have something type-directed for sql queries though
|
||||||
# e.g. with typeddict or something, so the number of parameters to the sql query matches?
|
# e.g. with typeddict or something, so the number of parameters to the sql query matches?
|
||||||
for row in db.execute(f'SELECT user_id, user_name FROM conversation_info'):
|
for (user_id, user_name) in select(
|
||||||
(user_id, user_name) = row
|
('user_id', 'user_name'),
|
||||||
|
'FROM conversation_info',
|
||||||
|
db=db,
|
||||||
|
):
|
||||||
yield Person(
|
yield Person(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# has sender_name, but it's always None
|
# note: has sender_name, but it's always None
|
||||||
for row in db.execute(f'''
|
for ( id, conversation_id , created , is_incoming , payload_type , payload , reply_to_id) in select(
|
||||||
SELECT id, conversation_id, created_timestamp, is_incoming, payload_type, payload, reply_to_id
|
('id', 'conversation_id', 'created_timestamp', 'is_incoming', 'payload_type', 'payload', 'reply_to_id'),
|
||||||
FROM message
|
'FROM message ORDER BY created_timestamp',
|
||||||
ORDER BY created_timestamp
|
db=db
|
||||||
'''):
|
):
|
||||||
(id, conversation_id, created, is_incoming, payload_type, payload, reply_to_id) = row
|
|
||||||
try:
|
try:
|
||||||
key = {'TEXT': 'text', 'QUESTION_GAME': 'text', 'IMAGE': 'url', 'GIF': 'url'}[payload_type]
|
key = {'TEXT': 'text', 'QUESTION_GAME': 'text', 'IMAGE': 'url', 'GIF': 'url'}[payload_type]
|
||||||
text = json.loads(payload)[key]
|
text = json.loads(payload)[key]
|
||||||
|
|
|
@ -50,3 +50,43 @@ def sqlite_copy_and_open(db: PathIsh) -> sqlite3.Connection:
|
||||||
sqlite_backup(source=conn, dest=dest)
|
sqlite_backup(source=conn, dest=dest)
|
||||||
conn.close()
|
conn.close()
|
||||||
return dest
|
return dest
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Tuple, Any, Iterator
|
||||||
|
|
||||||
|
# NOTE hmm, so this kinda works
|
||||||
|
# V = TypeVar('V', bound=Tuple[Any, ...])
|
||||||
|
# def select(cols: V, rest: str, *, db: sqlite3.Connection) -> Iterator[V]:
|
||||||
|
# but sadly when we pass columns (Tuple[str, ...]), it seems to bind this type to V?
|
||||||
|
# and then the return type ends up as Iterator[Tuple[str, ...]], which isn't desirable :(
|
||||||
|
# a bit annoying to have this copy-pasting, but hopefully not a big issue
|
||||||
|
|
||||||
|
from typing import overload
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str ], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any ]]: ...
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str, str ], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any, Any ]]: ...
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str, str, str ], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any, Any, Any ]]: ...
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str, str, str, str ], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any, Any, Any, Any ]]: ...
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str, str, str, str, str ], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any, Any, Any, Any, Any ]]: ...
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str, str, str, str, str, str ], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any, Any, Any, Any, Any, Any ]]: ...
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str, str, str, str, str, str, str ], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any, Any, Any, Any, Any, Any, Any ]]: ...
|
||||||
|
@overload
|
||||||
|
def select(cols: Tuple[str, str, str, str, str, str, str, str], rest: str, *, db: sqlite3.Connection) -> \
|
||||||
|
Iterator[Tuple[Any, Any, Any, Any, Any, Any, Any, Any]]: ...
|
||||||
|
|
||||||
|
def select(cols, rest, *, db):
|
||||||
|
# db arg is last cause that results in nicer code formatting..
|
||||||
|
return db.execute('SELECT ' + ','.join(cols) + ' ' + rest)
|
||||||
|
|
|
@ -90,7 +90,7 @@ import json
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from ..core import Res, assert_never
|
from ..core import Res, assert_never
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from ..core.sqlite import sqlite_connect_immutable
|
from ..core.sqlite import sqlite_connect_immutable, select
|
||||||
def _entities() -> Iterator[Res[Union[User, _Message]]]:
|
def _entities() -> Iterator[Res[Union[User, _Message]]]:
|
||||||
# NOTE: definitely need to merge multiple, app seems to recycle old messages
|
# NOTE: definitely need to merge multiple, app seems to recycle old messages
|
||||||
# TODO: hmm hard to guarantee timestamp ordering when we use synthetic input data...
|
# TODO: hmm hard to guarantee timestamp ordering when we use synthetic input data...
|
||||||
|
@ -98,15 +98,14 @@ def _entities() -> Iterator[Res[Union[User, _Message]]]:
|
||||||
for f in inputs():
|
for f in inputs():
|
||||||
with sqlite_connect_immutable(f) as db:
|
with sqlite_connect_immutable(f) as db:
|
||||||
|
|
||||||
for row in db.execute(f'SELECT user_id, thread_info FROM threads'):
|
for (self_uid, thread_json) in select(('user_id', 'thread_info'), 'FROM threads', db=db):
|
||||||
(self_uid, js,) = row
|
|
||||||
# ugh wtf?? no easier way to extract your own user id/name??
|
# ugh wtf?? no easier way to extract your own user id/name??
|
||||||
yield User(
|
yield User(
|
||||||
id=str(self_uid),
|
id=str(self_uid),
|
||||||
full_name='You',
|
full_name='You',
|
||||||
username='you',
|
username='you',
|
||||||
)
|
)
|
||||||
j = json.loads(js)
|
j = json.loads(thread_json)
|
||||||
for r in j['recipients']:
|
for r in j['recipients']:
|
||||||
yield User(
|
yield User(
|
||||||
id=str(r['id']), # for some reason it's int in the db
|
id=str(r['id']), # for some reason it's int in the db
|
||||||
|
@ -114,10 +113,9 @@ def _entities() -> Iterator[Res[Union[User, _Message]]]:
|
||||||
username=r['username'],
|
username=r['username'],
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in db.execute(f'SELECT message FROM messages ORDER BY timestamp'):
|
for (msg_json,) in select(('message',), 'FROM messages ORDER BY timestamp', db=db):
|
||||||
# eh, seems to contain everything in json?
|
# eh, seems to contain everything in json?
|
||||||
(js,) = row
|
j = json.loads(msg_json)
|
||||||
j = json.loads(js)
|
|
||||||
try:
|
try:
|
||||||
m = _parse_message(j)
|
m = _parse_message(j)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
|
|
Loading…
Add table
Reference in a new issue