From c5fe2e94125deda779b0d0088d2bb44732fcbbe1 Mon Sep 17 00:00:00 2001 From: karlicoss Date: Sat, 21 Oct 2023 23:08:40 +0100 Subject: [PATCH] core.stats: fix is_data_provider when from __future__ import annotations is used --- my/core/stats.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/my/core/stats.py b/my/core/stats.py index 8923996..42e8cd9 100644 --- a/my/core/stats.py +++ b/my/core/stats.py @@ -4,7 +4,6 @@ Helpers for hpi doctor/stats functionality. import collections import importlib import inspect -import sys import typing from typing import Optional, Callable, Any, Iterator, Sequence, Dict, List @@ -59,7 +58,14 @@ def is_data_provider(fun: Any) -> bool: if fun.__name__ == 'inputs' or fun.__name__.endswith('_inputs'): return False - return_type = sig.return_annotation + # inspect.signature might return str instead of a proper type object + # if from __future__ import annotations is used + # so best to rely on get_type_hints (which evals the annotations) + type_hints = typing.get_type_hints(fun) + return_type = type_hints.get('return') + if return_type is None: + return False + return type_is_iterable(return_type) @@ -123,9 +129,6 @@ def test_sig_required_params() -> None: def type_is_iterable(type_spec) -> bool: - if sys.version_info[1] < 8: - # there is no get_origin before 3.8, and retrofitting gonna be a lot of pain - return any(x in str(type_spec) for x in ['List', 'Sequence', 'Iterable', 'Iterator']) origin = typing.get_origin(type_spec) if origin is None: return False