core.stats: fix is_data_provider when from __future__ import annotations is used
This commit is contained in:
parent
872053a3c3
commit
c5fe2e9412
1 changed files with 8 additions and 5 deletions
|
@ -4,7 +4,6 @@ Helpers for hpi doctor/stats functionality.
|
|||
import collections
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
import typing
|
||||
from typing import Optional, Callable, Any, Iterator, Sequence, Dict, List
|
||||
|
||||
|
@ -59,7 +58,14 @@ def is_data_provider(fun: Any) -> bool:
|
|||
if fun.__name__ == 'inputs' or fun.__name__.endswith('_inputs'):
|
||||
return False
|
||||
|
||||
return_type = sig.return_annotation
|
||||
# inspect.signature might return str instead of a proper type object
|
||||
# if from __future__ import annotations is used
|
||||
# so best to rely on get_type_hints (which evals the annotations)
|
||||
type_hints = typing.get_type_hints(fun)
|
||||
return_type = type_hints.get('return')
|
||||
if return_type is None:
|
||||
return False
|
||||
|
||||
return type_is_iterable(return_type)
|
||||
|
||||
|
||||
|
@ -123,9 +129,6 @@ def test_sig_required_params() -> None:
|
|||
|
||||
|
||||
def type_is_iterable(type_spec) -> bool:
|
||||
if sys.version_info[1] < 8:
|
||||
# there is no get_origin before 3.8, and retrofitting gonna be a lot of pain
|
||||
return any(x in str(type_spec) for x in ['List', 'Sequence', 'Iterable', 'Iterator'])
|
||||
origin = typing.get_origin(type_spec)
|
||||
if origin is None:
|
||||
return False
|
||||
|
|
Loading…
Add table
Reference in a new issue