stats/is_data_provider: ignore 'inputs' func
This commit is contained in:
parent
68019c80db
commit
d71383ddee
1 changed files with 18 additions and 5 deletions
|
@ -29,12 +29,12 @@ def guess_data_providers(module_name: str) -> Dict[str, Callable]:
|
||||||
|
|
||||||
|
|
||||||
# todo how to exclude deprecated stuff?
|
# todo how to exclude deprecated stuff?
|
||||||
# todo also exclude def inputs()?
|
|
||||||
def is_data_provider(fun: Any) -> bool:
|
def is_data_provider(fun: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
1. returns iterable or something like that
|
1. returns iterable or something like that
|
||||||
2. takes no arguments? (otherwise not callable by stats anyway?)
|
2. takes no arguments? (otherwise not callable by stats anyway?)
|
||||||
3. doesn't start with an underscore (those are probably helper functions?)
|
3. doesn't start with an underscore (those are probably helper functions?)
|
||||||
|
4. functions isnt the 'inputs' function (or ends with '_inputs')
|
||||||
"""
|
"""
|
||||||
# todo maybe for 2 allow default arguments? not sure
|
# todo maybe for 2 allow default arguments? not sure
|
||||||
# one example which could benefit is my.pdfs
|
# one example which could benefit is my.pdfs
|
||||||
|
@ -50,9 +50,13 @@ def is_data_provider(fun: Any) -> bool:
|
||||||
if len(list(sig_required_params(sig))) > 0:
|
if len(list(sig_required_params(sig))) > 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# probably a helper function?
|
if hasattr(fun, '__name__'):
|
||||||
if hasattr(fun, '__name__') and fun.__name__.startswith('_'):
|
# probably a helper function?
|
||||||
return False
|
if fun.__name__.startswith('_'):
|
||||||
|
return False
|
||||||
|
# ignore def inputs; something like comment_inputs or backup_inputs should also be ignored
|
||||||
|
if fun.__name__ == 'inputs' or fun.__name__.endswith('_inputs'):
|
||||||
|
return False
|
||||||
|
|
||||||
return_type = sig.return_annotation
|
return_type = sig.return_annotation
|
||||||
return type_is_iterable(return_type)
|
return type_is_iterable(return_type)
|
||||||
|
@ -80,9 +84,18 @@ def test_is_data_provider() -> None:
|
||||||
assert idp(has_return_type)
|
assert idp(has_return_type)
|
||||||
|
|
||||||
def _helper_func() -> Iterator[Any]:
|
def _helper_func() -> Iterator[Any]:
|
||||||
yield 5
|
yield 1
|
||||||
assert not idp(_helper_func)
|
assert not idp(_helper_func)
|
||||||
|
|
||||||
|
def inputs() -> Iterator[Any]:
|
||||||
|
yield 1
|
||||||
|
assert not idp(inputs)
|
||||||
|
|
||||||
|
def producer_inputs() -> Iterator[Any]:
|
||||||
|
yield 1
|
||||||
|
assert not idp(producer_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# return any parameters the user is required to provide - those which don't have default values
|
# return any parameters the user is required to provide - those which don't have default values
|
||||||
def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]:
|
def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]:
|
||||||
|
|
Loading…
Add table
Reference in a new issue