query interactive fallback, improve guess_stats

This commit is contained in:
Sean Breckenridge 2021-04-19 02:08:18 -07:00
parent 91eed15a75
commit 91f3e81573
3 changed files with 116 additions and 7 deletions

View file

@ -6,7 +6,7 @@ import importlib
import inspect
import sys
import typing
from typing import Optional
from typing import Optional, Callable, Any, List, Iterator
from .common import StatsFun, Stats, stat
@ -24,23 +24,29 @@ def guess_stats(module_name: str) -> Optional[StatsFun]:
return auto_stats
def is_data_provider(fun) -> bool:
def is_data_provider(fun: Any) -> bool:
"""
1. returns iterable or something like that
2. takes no arguments? (otherwise not callable by stats anyway?)
3. doesn't start with an underscore (those are probably helper functions?)
"""
# todo maybe for 2 allow default arguments? not sure
# one example which could benefit is my.pdfs
if fun is None:
return False
# probably a helper function?
if fun.__name__.startswith("_"):
return False
# todo. uh.. very similar to what cachew is trying to do?
try:
sig = inspect.signature(fun)
except ValueError: # not a function?
return False
if len(sig.parameters) > 0:
# has at least one argument without default values
if len(list(sig_required_params(sig))) > 0:
return False
return_type = sig.return_annotation
return type_is_iterable(return_type)
@ -65,6 +71,35 @@ def test_is_data_provider() -> None:
return ['a', 'b', 'c']
assert idp(has_return_type)
def _helper_func() -> Iterator[Any]:
yield 5
assert not idp(_helper_func)
# return any parameters the user is required to provide - those which don't have default values
def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]:
for param in sig.parameters.values():
if param.default == inspect.Parameter.empty:
yield param
def test_sig_required_args() -> None:
def x() -> int:
return 5
assert len(list(sig_required_params(inspect.signature(x)))) == 0
def y(arg: int) -> int:
return arg
assert len(list(sig_required_params(inspect.signature(y)))) == 1
# from stats perspective, this should be treated as a data provider as well
# could be that the default value to the data provider is the 'default'
# path to use for inputs/a function to provide input data
def z(arg: int = 5) -> int:
return arg
assert len(list(sig_required_params(inspect.signature(z)))) == 0
def type_is_iterable(type_spec) -> bool:
if sys.version_info[1] < 8: