diff --git a/my/core/stats.py b/my/core/stats.py index dfa68e5..9750061 100644 --- a/my/core/stats.py +++ b/my/core/stats.py @@ -29,12 +29,12 @@ def guess_data_providers(module_name: str) -> Dict[str, Callable]: # todo how to exclude deprecated stuff? -# todo also exclude def inputs()? def is_data_provider(fun: Any) -> bool: """ 1. returns iterable or something like that 2. takes no arguments? (otherwise not callable by stats anyway?) 3. doesn't start with an underscore (those are probably helper functions?) + 4. functions isnt the 'inputs' function (or ends with '_inputs') """ # todo maybe for 2 allow default arguments? not sure # one example which could benefit is my.pdfs @@ -50,9 +50,13 @@ def is_data_provider(fun: Any) -> bool: if len(list(sig_required_params(sig))) > 0: return False - # probably a helper function? - if hasattr(fun, '__name__') and fun.__name__.startswith('_'): - return False + if hasattr(fun, '__name__'): + # probably a helper function? + if fun.__name__.startswith('_'): + return False + # ignore def inputs; something like comment_inputs or backup_inputs should also be ignored + if fun.__name__ == 'inputs' or fun.__name__.endswith('_inputs'): + return False return_type = sig.return_annotation return type_is_iterable(return_type) @@ -80,9 +84,18 @@ def test_is_data_provider() -> None: assert idp(has_return_type) def _helper_func() -> Iterator[Any]: - yield 5 + yield 1 assert not idp(_helper_func) + def inputs() -> Iterator[Any]: + yield 1 + assert not idp(inputs) + + def producer_inputs() -> Iterator[Any]: + yield 1 + assert not idp(producer_inputs) + + # return any parameters the user is required to provide - those which don't have default values def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]: