cli/query: interactive fallback, improve guess_stats (#163)

This commit is contained in:
Sean Breckenridge 2021-04-19 10:57:42 -07:00 committed by GitHub
parent 91eed15a75
commit 277f0e3988
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 119 additions and 8 deletions

View file

@ -1,9 +1,10 @@
import functools import functools
import importlib import importlib
import inspect
import os import os
import sys import sys
import traceback import traceback
from typing import Optional, Sequence, Iterable, List, Type, Any from typing import Optional, Sequence, Iterable, List, Type, Any, Callable
from pathlib import Path from pathlib import Path
from subprocess import check_call, run, PIPE, CompletedProcess from subprocess import check_call, run, PIPE, CompletedProcess
@ -329,6 +330,80 @@ def module_install(*, user: bool, module: str) -> None:
check_call(cmd) check_call(cmd)
def _ui_getchar_pick(choices: Sequence[str], prompt: str = 'Select from: ') -> int:
'''
Basic menu allowing the user to select one of the choices
returns the index the user chose
'''
assert len(choices) > 0, 'Didnt recieve any choices to prompt!'
eprint(prompt + '\n')
# prompts like 1,2,3,4,5,6,7,8,9,a,b,c,d,e,f...
chr_offset = ord('a') - 10
# dict from key user can press -> resulting index
result_map = {}
for i, opt in enumerate(choices, 1):
char: str = str(i) if i < 10 else chr(i + chr_offset)
result_map[char] = i - 1
eprint(f'\t{char}. {opt}')
eprint('')
while True:
ch = click.getchar()
if ch not in result_map:
eprint(f'{ch} not in {list(result_map.keys())}')
continue
return result_map[ch]
def _locate_functions_or_prompt(qualified_names: List[str], prompt: bool = True) -> Iterable[Callable[..., Any]]:
from .query import locate_qualified_function, QueryException
from .stats import is_data_provider
# if not connected to a terminal, cant prompt
if not sys.stdout.isatty():
prompt = False
for qualname in qualified_names:
try:
# common-case
yield locate_qualified_function(qualname)
except QueryException as qr_err:
# can't prompt, raise error
if prompt is False:
# hmm, should we yield here instead and ignore the error if one iterator succeeds?
# this is likely a query running in the background, so probably bad for it
# to fail silently
raise qr_err
# maybe the user specified a module name instead of a function name?
# try importing the name the user specified as a module and prompt the
# user to select a 'data provider' like function
try:
mod = importlib.import_module(qualname)
except Exception:
eprint(f"During fallback, importing '{qualname}' as module failed")
raise qr_err
# find data providers in this module
data_providers = [f for _, f in inspect.getmembers(mod, inspect.isfunction) if is_data_provider(f)]
if len(data_providers) == 0:
eprint(f"During fallback, could not find any data providers in '{qualname}'")
raise qr_err
else:
# was only one data provider-like function, use that
if len(data_providers) == 1:
yield data_providers[0]
else:
# prompt the user to pick the function to use
choices = [f.__name__ for f in data_providers]
chosen_index = _ui_getchar_pick(choices, f"Which function should be used from '{qualname}'?")
# respond to the user, so they know something has been picked
eprint(f"Selected '{choices[chosen_index]}'")
yield data_providers[chosen_index]
# handle the 'hpi query' call # handle the 'hpi query' call
# can raise a QueryException, caught in the click command # can raise a QueryException, caught in the click command
def query_hpi_functions( def query_hpi_functions(
@ -350,11 +425,10 @@ def query_hpi_functions(
from itertools import chain from itertools import chain
from .query import locate_qualified_function
from .query_range import select_range, RangeTuple from .query_range import select_range, RangeTuple
# chain list of functions from user, in the order they wrote them on the CLI # chain list of functions from user, in the order they wrote them on the CLI
input_src = chain(*(locate_qualified_function(f)() for f in qualified_names)) input_src = chain(*(f() for f in _locate_functions_or_prompt(qualified_names)))
res = list(select_range( res = list(select_range(
input_src, input_src,

View file

@ -56,7 +56,7 @@ def locate_function(module_name: str, function_name: str) -> Callable[[], Iterab
return func return func
except Exception as e: except Exception as e:
raise QueryException(str(e)) raise QueryException(str(e))
raise QueryException(f"Could not find function {function_name} in {module_name}") raise QueryException(f"Could not find function '{function_name}' in '{module_name}'")
def locate_qualified_function(qualified_name: str) -> Callable[[], Iterable[ET]]: def locate_qualified_function(qualified_name: str) -> Callable[[], Iterable[ET]]:

View file

@ -6,7 +6,7 @@ import importlib
import inspect import inspect
import sys import sys
import typing import typing
from typing import Optional from typing import Optional, Callable, Any, Iterator
from .common import StatsFun, Stats, stat from .common import StatsFun, Stats, stat
@ -24,10 +24,11 @@ def guess_stats(module_name: str) -> Optional[StatsFun]:
return auto_stats return auto_stats
def is_data_provider(fun) -> bool: def is_data_provider(fun: Any) -> bool:
""" """
1. returns iterable or something like that 1. returns iterable or something like that
2. takes no arguments? (otherwise not callable by stats anyway?) 2. takes no arguments? (otherwise not callable by stats anyway?)
3. doesn't start with an underscore (those are probably helper functions?)
""" """
# todo maybe for 2 allow default arguments? not sure # todo maybe for 2 allow default arguments? not sure
# one example which could benefit is my.pdfs # one example which could benefit is my.pdfs
@ -36,11 +37,17 @@ def is_data_provider(fun) -> bool:
# todo. uh.. very similar to what cachew is trying to do? # todo. uh.. very similar to what cachew is trying to do?
try: try:
sig = inspect.signature(fun) sig = inspect.signature(fun)
except ValueError: # not a function? except (ValueError, TypeError): # not a function?
return False return False
if len(sig.parameters) > 0: # has at least one argument without default values
if len(list(sig_required_params(sig))) > 0:
return False return False
# probably a helper function?
if hasattr(fun, '__name__') and fun.__name__.startswith('_'):
return False
return_type = sig.return_annotation return_type = sig.return_annotation
return type_is_iterable(return_type) return type_is_iterable(return_type)
@ -49,6 +56,7 @@ def test_is_data_provider() -> None:
idp = is_data_provider idp = is_data_provider
assert not idp(None) assert not idp(None)
assert not idp(int) assert not idp(int)
assert not idp("x")
def no_return_type(): def no_return_type():
return [1, 2 ,3] return [1, 2 ,3]
@ -65,6 +73,35 @@ def test_is_data_provider() -> None:
return ['a', 'b', 'c'] return ['a', 'b', 'c']
assert idp(has_return_type) assert idp(has_return_type)
def _helper_func() -> Iterator[Any]:
yield 5
assert not idp(_helper_func)
# return any parameters the user is required to provide - those which don't have default values
def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]:
for param in sig.parameters.values():
if param.default == inspect.Parameter.empty:
yield param
def test_sig_required_params() -> None:
def x() -> int:
return 5
assert len(list(sig_required_params(inspect.signature(x)))) == 0
def y(arg: int) -> int:
return arg
assert len(list(sig_required_params(inspect.signature(y)))) == 1
# from stats perspective, this should be treated as a data provider as well
# could be that the default value to the data provider is the 'default'
# path to use for inputs/a function to provide input data
def z(arg: int = 5) -> int:
return arg
assert len(list(sig_required_params(inspect.signature(z)))) == 0
def type_is_iterable(type_spec) -> bool: def type_is_iterable(type_spec) -> bool:
if sys.version_info[1] < 8: if sys.version_info[1] < 8: