diff --git a/sql.py b/sql.py index c8d2118..b93475a 100755 --- a/sql.py +++ b/sql.py @@ -2,21 +2,40 @@ from pathlib import Path from datetime import datetime from itertools import islice -from typing import Type, NamedTuple, Union +from typing import Type, NamedTuple, Union, Optional import logging from location import _load_locations, Location, get_logger +import sqlalchemy # type: ignore import sqlalchemy as sa # type: ignore from kython import ichunks +from kython.py37 import fromisoformat + +# TODO move to some common thing? +class IsoDateTime(sqlalchemy.TypeDecorator): + # TODO can we use something more effecient? e.g. blob for encoded datetime and tz? not sure if worth it + impl = sqlalchemy.types.String + + # TODO optional? + def process_bind_param(self, value: Optional[datetime], dialect) -> Optional[str]: + if value is None: + return None + return value.isoformat() + + def process_result_value(self, value: Optional[str], dialect) -> Optional[datetime]: + if value is None: + return None + return fromisoformat(value) + + def _map_type(cls): tmap = { str: sa.String, float: sa.Float, - datetime: sa.types.TIMESTAMP(timezone=True), # TODO tz? - # TODO FIXME utc seems to be lost.. doesn't sqlite support it or what? + datetime: IsoDateTime, } r = tmap.get(cls, None) if r is not None: @@ -30,7 +49,8 @@ def _map_type(cls): return _map_type(elems[0]) # meh.. raise RuntimeError(f'Unexpected type {cls}') - +# TODO to strart with, just assert utc when serializing, deserializing +# TODO how to use timestamp as key? just round it? def make_schema(cls: Type[NamedTuple]): # TODO covariant? res = [] @@ -71,7 +91,7 @@ def test(tmp_path): tdir = Path(tmp_path) tdb = tdir / 'test.sqlite' test_limit = 100 - test_src = Path('/L/tmp/loc/LocationHistory.json') + test_src = Path('/L/tmp/LocationHistory.json') # TODO meh, double loading, but for now fine with test_src.open('r') as fo: @@ -80,10 +100,7 @@ def test(tmp_path): cache_locs(source=test_src, db_path=tdb, limit=test_limit) cached_locs = list(iter_db_locs(tdb)) assert len(cached_locs) == test_limit - def FIXME_tz(locs): - # TODO FIXME tzinfo... - return [x._replace(dt=x.dt.replace(tzinfo=None)) for x in locs] - assert FIXME_tz(real_locs) == FIXME_tz(cached_locs) + assert real_locs == cached_locs def main():