core.stats: fix is_data_provider when from __future__ import annotations is used

This commit is contained in:
karlicoss 2023-10-21 23:08:40 +01:00
parent 872053a3c3
commit c5fe2e9412

View file

@ -4,7 +4,6 @@ Helpers for hpi doctor/stats functionality.
import collections import collections
import importlib import importlib
import inspect import inspect
import sys
import typing import typing
from typing import Optional, Callable, Any, Iterator, Sequence, Dict, List from typing import Optional, Callable, Any, Iterator, Sequence, Dict, List
@ -59,7 +58,14 @@ def is_data_provider(fun: Any) -> bool:
if fun.__name__ == 'inputs' or fun.__name__.endswith('_inputs'): if fun.__name__ == 'inputs' or fun.__name__.endswith('_inputs'):
return False return False
return_type = sig.return_annotation # inspect.signature might return str instead of a proper type object
# if from __future__ import annotations is used
# so best to rely on get_type_hints (which evals the annotations)
type_hints = typing.get_type_hints(fun)
return_type = type_hints.get('return')
if return_type is None:
return False
return type_is_iterable(return_type) return type_is_iterable(return_type)
@ -123,9 +129,6 @@ def test_sig_required_params() -> None:
def type_is_iterable(type_spec) -> bool: def type_is_iterable(type_spec) -> bool:
if sys.version_info[1] < 8:
# there is no get_origin before 3.8, and retrofitting gonna be a lot of pain
return any(x in str(type_spec) for x in ['List', 'Sequence', 'Iterable', 'Iterator'])
origin = typing.get_origin(type_spec) origin = typing.get_origin(type_spec)
if origin is None: if origin is None:
return False return False