From eb26cf863345f764a8eb279aff0589d67e397bae Mon Sep 17 00:00:00 2001
From: Sean Breckenridge
Date: Fri, 19 Mar 2021 17:48:03 -0700
Subject: [PATCH] my.core.serialize: orjson with additional default and _serialize hook (#140)

basic orjson serialize, json.dumps fallback

Lots of surrounding changes from this discussion:
https://github.com/seanbreckenridge/HPI-to-master/commit/0593c690566ccbbf54aea24f3e0ce3613e6603da
---
 my/core/common.py    |  22 +++--
 my/core/error.py     |  10 +--
 my/core/pandas.py    |  23 +++++-
 my/core/serialize.py | 189 +++++++++++++++++++++++++++++++++++++++++++
 setup.py             |   1 +
 tests/core.py        |   1 +
 tests/serialize.py   |   1 +
 tox.ini              |   1 +
 8 files changed, 224 insertions(+), 24 deletions(-)
 create mode 100644 my/core/serialize.py
 create mode 100644 tests/serialize.py

diff --git a/my/core/common.py b/my/core/common.py
index 6e2251b..7c60632 100644
--- a/my/core/common.py
+++ b/my/core/common.py
@@ -531,7 +531,13 @@ def test_guess_datetime() -> None:
     # TODO test @property?
 
 
-def asdict(thing) -> Json:
+def is_namedtuple(thing: Any) -> bool:
+    # basic check to see if this is namedtuple-like
+    _asdict = getattr(thing, '_asdict', None)
+    return callable(_asdict)
+
+
+def asdict(thing: Any) -> Json:
     # todo primitive?
     # todo exception?
     if isinstance(thing, dict):
@@ -539,19 +545,11 @@ def asdict(thing) -> Json:
     import dataclasses as D
     if D.is_dataclass(thing):
         return D.asdict(thing)
-    # must be a NT otherwise?
-    # todo add a proper check.. ()
-    return thing._asdict()
+    if is_namedtuple(thing):
+        return thing._asdict()
+    raise TypeError(f'Could not convert object {thing} to dict')
 
 
-# todo not sure about naming
-def to_jsons(it) -> Iterable[Json]:
-    from .error import error_to_json # prevent circular import
-    for r in it:
-        if isinstance(r, Exception):
-            yield error_to_json(r)
-        else:
-            yield asdict(r)
 
 
 datetime_naive = datetime
diff --git a/my/core/error.py b/my/core/error.py
index 1d55d4a..33ba96a 100644
--- a/my/core/error.py
+++ b/my/core/error.py
@@ -145,15 +145,9 @@ def extract_error_datetime(e: Exception) -> Optional[datetime]:
 
 import traceback
 from .common import Json
-def error_to_json(e: Exception, *, dt_col: str='dt', tz=None) -> Json:
-    edt = extract_error_datetime(e)
-    if edt is not None and edt.tzinfo is None and tz is not None:
-        edt = edt.replace(tzinfo=tz)
+def error_to_json(e: Exception) -> Json:
     estr = ''.join(traceback.format_exception(Exception, e, e.__traceback__))
-    return {
-        'error': estr,
-        dt_col : edt,
-    }
+    return {'error': estr}
 
 
 def test_datetime_errors() -> None:
diff --git a/my/core/pandas.py b/my/core/pandas.py
index 90b49ce..03450f2 100644
--- a/my/core/pandas.py
+++ b/my/core/pandas.py
@@ -7,7 +7,7 @@ from datetime import datetime
 from pprint import pformat
 from typing import Optional, TYPE_CHECKING, Any, Iterable, Type, List, Dict
 from . import warnings, Res
-from .common import LazyLogger
+from .common import LazyLogger, Json, asdict
 
 logger = LazyLogger(__name__)
 
@@ -97,8 +97,23 @@ def check_dataframe(f: FuncT, error_col_policy: ErrorColPolicy='add_if_missing',
 # todo doctor: could have a suggesion to wrap dataframes with it?? discover by return type?
 
 
-from .error import error_to_json
-error_to_row = error_to_json # todo deprecate?
+def error_to_row(e: Exception, *, dt_col: str='dt', tz=None) -> Json:
+    from .error import error_to_json, extract_error_datetime
+    edt = extract_error_datetime(e)
+    if edt is not None and edt.tzinfo is None and tz is not None:
+        edt = edt.replace(tzinfo=tz)
+    err_dict: Json = error_to_json(e)
+    err_dict[dt_col] = edt
+    return err_dict
+
+
+# todo not sure about naming
+def to_jsons(it: Iterable[Res[Any]]) -> Iterable[Json]:
+    for r in it:
+        if isinstance(r, Exception):
+            yield error_to_row(r)
+        else:
+            yield asdict(r)
 
 
 # mm. https://github.com/python/mypy/issues/8564
@@ -111,6 +126,7 @@ def _as_columns(s: Schema) -> Dict[str, Type]:
     if D.is_dataclass(s):
         return {f.name: f.type for f in D.fields(s)}
     # else must be NamedTuple??
+    # todo assert my.core.common.is_namedtuple?
     return getattr(s, '_field_types')
 
 
@@ -124,7 +140,6 @@ def as_dataframe(it: Iterable[Res[Any]], schema: Optional[Schema]=None) -> DataF
     # https://github.com/pandas-dev/pandas/blob/fc9fdba6592bdb5d0d1147ce4d65639acd897565/pandas/core/frame.py#L562
     # same for NamedTuple -- seems that it takes whatever schema the first NT has
     # so we need to convert each individually... sigh
-    from .common import to_jsons
     import pandas as pd
     columns = None if schema is None else list(_as_columns(schema).keys())
     return pd.DataFrame(to_jsons(it), columns=columns)
diff --git a/my/core/serialize.py b/my/core/serialize.py
new file mode 100644
index 0000000..e910ecc
--- /dev/null
+++ b/my/core/serialize.py
@@ -0,0 +1,189 @@
+import datetime
+from typing import Any, Optional, Callable
+from functools import lru_cache
+
+from .common import is_namedtuple
+from .error import error_to_json
+
+# note: it would be nice to combine the 'asdict' and _default_encode to some function
+# that takes a complex python object and returns JSON-compatible fields, while still
+# being a dictionary.
+# a workaround is to encode with dumps below and then json.loads it immediately
+
+
+DefaultEncoder = Callable[[Any], Any]
+
+
+def _default_encode(obj: Any) -> Any:
+    """
+    Encodes complex python datatypes to simpler representations,
+    before they're serialized to JSON string
+    """
+    # orjson doesn't serialize namedtuples to avoid serializing
+    # them as tuples (arrays), since they're technically a subclass
+    if is_namedtuple(obj):
+        return obj._asdict()
+    if isinstance(obj, datetime.timedelta):
+        return obj.total_seconds()
+    if isinstance(obj, Exception):
+        return error_to_json(obj)
+    # note: _serialize would only be called for items which aren't already
+    # serialized as a dataclass or namedtuple
+    # discussion: https://github.com/karlicoss/HPI/issues/138#issuecomment-801704929
+    if hasattr(obj, '_serialize') and callable(obj._serialize):
+        return obj._serialize()
+    raise TypeError(f"Could not serialize object of type {type(obj).__name__}")
+
+
+# could possibly run multiple times/raise warning if you provide different 'default'
+# functions or change the kwargs? The alternative is to maintain all of this at the module
+# level, which is just as annoying
+@lru_cache(maxsize=None)
+def _dumps_factory(**kwargs) -> Callable[[Any], str]:
+    use_default: DefaultEncoder = _default_encode
+    # if the user passed an additional 'default' parameter,
+    # try using that to serialize before _default_encode
+    _additional_default: Optional[DefaultEncoder] = kwargs.get("default")
+    if _additional_default is not None and callable(_additional_default):
+
+        def wrapped_default(obj: Any) -> Any:
+            try:
+                # hmm... shouldn't mypy know that _additional_default is not None here?
+                # assert _additional_default is not None
+                return _additional_default(obj)  # type: ignore[misc]
+            except TypeError:
+                # expected TypeError, signifies couldn't be encoded by custom
+                # serializer function. Try _default_encode from here
+                return _default_encode(obj)
+
+        use_default = wrapped_default
+
+    kwargs["default"] = use_default
+
+    try:
+        import orjson
+
+        # todo: add orjson.OPT_NON_STR_KEYS? would require some bitwise ops
+        # most keys are typically attributes from a NT/Dataclass,
+        # so most seem to work: https://github.com/ijl/orjson#opt_non_str_keys
+        def _orjson_dumps(obj: Any) -> str:
+            # orjson returns json as bytes, encode to string
+            return orjson.dumps(obj, **kwargs).decode('utf-8')
+
+        return _orjson_dumps
+    except ModuleNotFoundError:
+        import json
+        import warnings
+
+        warnings.warn("You might want to install 'orjson' to support serialization for lots more types!")
+
+        def _stdlib_dumps(obj: Any) -> str:
+            return json.dumps(obj, **kwargs)
+
+        return _stdlib_dumps
+
+
+def dumps(
+    obj: Any,
+    default: Optional[DefaultEncoder] = None,
+    **kwargs,
+) -> str:
+    """
+    Any additional arguments are forwarded -- either to orjson.dumps
+    or json.dumps if orjson is not installed
+
+    You can pass the 'option' kwarg to orjson, see here for possible options:
+    https://github.com/ijl/orjson#option
+
+    Any class/instance can implement a `_serialize` function, which is used
+    to convert it to a JSON-compatible representation.
+    If present, it is called during _default_encode
+
+    'default' is called before _default_encode, and should raise a TypeError if
+    it's not able to serialize the type. As an example:
+
+    from my.core.serialize import dumps
+
+    class MyClass:
+        def __init__(self, x):
+            self.x = x
+
+    def serialize_default(o: Any) -> Any:
+        if isinstance(o, MyClass):
+            return {"x": o.x}
+        raise TypeError("Could not serialize...")
+
+    dumps({"info": MyClass(5)}, default=serialize_default)
+    """
+    return _dumps_factory(default=default, **kwargs)(obj)
+
+
+def test_serialize_fallback() -> None:
+    import json as jsn  # dont cause possible conflicts with module code
+
+    import pytest
+
+    # cant use a namedtuple here, since the default json.dump serializer
+    # serializes namedtuples as tuples, which become arrays
+    # just test with an array of mixed objects
+    X = [5, datetime.timedelta(seconds=5.0)]
+
+    # ignore warnings. depending on test order,
+    # the lru_cache'd warning may have already been sent,
+    # so checking may be nondeterministic?
+    with pytest.warns(None):
+        res = jsn.loads(dumps(X))
+        assert res == [5, 5.0]
+
+
+
+def test_nt_serialize() -> None:
+    import json as jsn  # dont cause possible conflicts with module code
+    import orjson  # import to make sure this is installed
+
+    from typing import NamedTuple
+
+    class A(NamedTuple):
+        x: int
+        y: float
+
+    res: str = dumps(A(x=1, y=2.0))
+    assert res == '{"x":1,"y":2.0}'
+
+    # test orjson option kwarg
+    data = {datetime.date(year=1970, month=1, day=1): 5}
+    res = jsn.loads(dumps(data, option=orjson.OPT_NON_STR_KEYS))
+    assert res == {'1970-01-01': 5}
+
+
+def test_default_serializer() -> None:
+    import pytest
+    import json as jsn  # dont cause possible conflicts with module code
+
+    class Unserializable:
+        def __init__(self, x: int):
+            self.x = x
+            # add something handled by the _default_encode function
+            self.y = datetime.timedelta(seconds=float(x))
+
+    with pytest.raises(TypeError):
+        dumps(Unserializable(5))
+
+    class WithUnderscoreSerialize(Unserializable):
+        def _serialize(self) -> Any:
+            return {"x": self.x, "y": self.y}
+
+    res = jsn.loads(dumps(WithUnderscoreSerialize(6)))
+    assert res == {"x": 6, "y": 6.0}
+
+    # test passing additional 'default' func
+    def _serialize_with_default(o: Any) -> Any:
+        if isinstance(o, Unserializable):
+            return {"x": o.x, "y": o.y}
+        raise TypeError("Couldnt serialize")
+
+    # this serializes both Unserializable, which is a custom type otherwise
+    # not handled, and timedelta, which is handled by the '_default_encode'
+    # in the 'wrapped_default' function
+    res2 = jsn.loads(dumps(Unserializable(10), default=_serialize_with_default))
+    assert res2 == {"x": 10, "y": 10.0}
diff --git a/setup.py b/setup.py
index 360f975..9e4e891 100644
--- a/setup.py
+++ b/setup.py
@@ -55,6 +55,7 @@ def main():
             'optional': [
                 # todo document these?
                 'logzero',
+                'orjson',
                 'cachew>=0.8.0',
                 'mypy', # used for config checks
             ],
diff --git a/tests/core.py b/tests/core.py
index 52ff688..95d30b7 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -18,3 +18,4 @@ from my.core.util import *
 from my.core.discovery_pure import *
 from my.core.types import *
 from my.core.stats import *
+from my.core.serialize import test_serialize_fallback
diff --git a/tests/serialize.py b/tests/serialize.py
new file mode 100644
index 0000000..d9ee9a3
--- /dev/null
+++ b/tests/serialize.py
@@ -0,0 +1 @@
+from my.core.serialize import *
diff --git a/tox.ini b/tox.ini
index c96d84d..7294726 100644
--- a/tox.ini
+++ b/tox.ini
@@ -23,6 +23,7 @@ setenv = MY_CONFIG = nonexistent
 commands =
     pip install -e .[testing]
     pip install cachew
+    pip install orjson
     hpi module install my.location.google
     pip install ijson # optional dependency
 