cli/query: interactive fallback, improve guess_stats (#163)
This commit is contained in:
parent
91eed15a75
commit
277f0e3988
3 changed files with 119 additions and 8 deletions
|
@ -1,9 +1,10 @@
|
|||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Optional, Sequence, Iterable, List, Type, Any
|
||||
from typing import Optional, Sequence, Iterable, List, Type, Any, Callable
|
||||
from pathlib import Path
|
||||
from subprocess import check_call, run, PIPE, CompletedProcess
|
||||
|
||||
|
@ -329,6 +330,80 @@ def module_install(*, user: bool, module: str) -> None:
|
|||
check_call(cmd)
|
||||
|
||||
|
||||
def _ui_getchar_pick(choices: Sequence[str], prompt: str = 'Select from: ') -> int:
|
||||
'''
|
||||
Basic menu allowing the user to select one of the choices
|
||||
returns the index the user chose
|
||||
'''
|
||||
assert len(choices) > 0, 'Didnt recieve any choices to prompt!'
|
||||
eprint(prompt + '\n')
|
||||
|
||||
# prompts like 1,2,3,4,5,6,7,8,9,a,b,c,d,e,f...
|
||||
chr_offset = ord('a') - 10
|
||||
|
||||
# dict from key user can press -> resulting index
|
||||
result_map = {}
|
||||
for i, opt in enumerate(choices, 1):
|
||||
char: str = str(i) if i < 10 else chr(i + chr_offset)
|
||||
result_map[char] = i - 1
|
||||
eprint(f'\t{char}. {opt}')
|
||||
|
||||
eprint('')
|
||||
while True:
|
||||
ch = click.getchar()
|
||||
if ch not in result_map:
|
||||
eprint(f'{ch} not in {list(result_map.keys())}')
|
||||
continue
|
||||
return result_map[ch]
|
||||
|
||||
|
||||
def _locate_functions_or_prompt(qualified_names: List[str], prompt: bool = True) -> Iterable[Callable[..., Any]]:
|
||||
from .query import locate_qualified_function, QueryException
|
||||
from .stats import is_data_provider
|
||||
|
||||
# if not connected to a terminal, cant prompt
|
||||
if not sys.stdout.isatty():
|
||||
prompt = False
|
||||
|
||||
for qualname in qualified_names:
|
||||
try:
|
||||
# common-case
|
||||
yield locate_qualified_function(qualname)
|
||||
except QueryException as qr_err:
|
||||
# can't prompt, raise error
|
||||
if prompt is False:
|
||||
# hmm, should we yield here instead and ignore the error if one iterator succeeds?
|
||||
# this is likely a query running in the background, so probably bad for it
|
||||
# to fail silently
|
||||
raise qr_err
|
||||
|
||||
# maybe the user specified a module name instead of a function name?
|
||||
# try importing the name the user specified as a module and prompt the
|
||||
# user to select a 'data provider' like function
|
||||
try:
|
||||
mod = importlib.import_module(qualname)
|
||||
except Exception:
|
||||
eprint(f"During fallback, importing '{qualname}' as module failed")
|
||||
raise qr_err
|
||||
|
||||
# find data providers in this module
|
||||
data_providers = [f for _, f in inspect.getmembers(mod, inspect.isfunction) if is_data_provider(f)]
|
||||
if len(data_providers) == 0:
|
||||
eprint(f"During fallback, could not find any data providers in '{qualname}'")
|
||||
raise qr_err
|
||||
else:
|
||||
# was only one data provider-like function, use that
|
||||
if len(data_providers) == 1:
|
||||
yield data_providers[0]
|
||||
else:
|
||||
# prompt the user to pick the function to use
|
||||
choices = [f.__name__ for f in data_providers]
|
||||
chosen_index = _ui_getchar_pick(choices, f"Which function should be used from '{qualname}'?")
|
||||
# respond to the user, so they know something has been picked
|
||||
eprint(f"Selected '{choices[chosen_index]}'")
|
||||
yield data_providers[chosen_index]
|
||||
|
||||
|
||||
# handle the 'hpi query' call
|
||||
# can raise a QueryException, caught in the click command
|
||||
def query_hpi_functions(
|
||||
|
@ -350,11 +425,10 @@ def query_hpi_functions(
|
|||
|
||||
from itertools import chain
|
||||
|
||||
from .query import locate_qualified_function
|
||||
from .query_range import select_range, RangeTuple
|
||||
|
||||
# chain list of functions from user, in the order they wrote them on the CLI
|
||||
input_src = chain(*(locate_qualified_function(f)() for f in qualified_names))
|
||||
input_src = chain(*(f() for f in _locate_functions_or_prompt(qualified_names)))
|
||||
|
||||
res = list(select_range(
|
||||
input_src,
|
||||
|
|
|
@ -56,7 +56,7 @@ def locate_function(module_name: str, function_name: str) -> Callable[[], Iterab
|
|||
return func
|
||||
except Exception as e:
|
||||
raise QueryException(str(e))
|
||||
raise QueryException(f"Could not find function {function_name} in {module_name}")
|
||||
raise QueryException(f"Could not find function '{function_name}' in '{module_name}'")
|
||||
|
||||
|
||||
def locate_qualified_function(qualified_name: str) -> Callable[[], Iterable[ET]]:
|
||||
|
|
|
@ -6,7 +6,7 @@ import importlib
|
|||
import inspect
|
||||
import sys
|
||||
import typing
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable, Any, Iterator
|
||||
|
||||
from .common import StatsFun, Stats, stat
|
||||
|
||||
|
@ -24,10 +24,11 @@ def guess_stats(module_name: str) -> Optional[StatsFun]:
|
|||
return auto_stats
|
||||
|
||||
|
||||
def is_data_provider(fun) -> bool:
|
||||
def is_data_provider(fun: Any) -> bool:
|
||||
"""
|
||||
1. returns iterable or something like that
|
||||
2. takes no arguments? (otherwise not callable by stats anyway?)
|
||||
3. doesn't start with an underscore (those are probably helper functions?)
|
||||
"""
|
||||
# todo maybe for 2 allow default arguments? not sure
|
||||
# one example which could benefit is my.pdfs
|
||||
|
@ -36,11 +37,17 @@ def is_data_provider(fun) -> bool:
|
|||
# todo. uh.. very similar to what cachew is trying to do?
|
||||
try:
|
||||
sig = inspect.signature(fun)
|
||||
except ValueError: # not a function?
|
||||
except (ValueError, TypeError): # not a function?
|
||||
return False
|
||||
|
||||
if len(sig.parameters) > 0:
|
||||
# has at least one argument without default values
|
||||
if len(list(sig_required_params(sig))) > 0:
|
||||
return False
|
||||
|
||||
# probably a helper function?
|
||||
if hasattr(fun, '__name__') and fun.__name__.startswith('_'):
|
||||
return False
|
||||
|
||||
return_type = sig.return_annotation
|
||||
return type_is_iterable(return_type)
|
||||
|
||||
|
@ -49,6 +56,7 @@ def test_is_data_provider() -> None:
|
|||
idp = is_data_provider
|
||||
assert not idp(None)
|
||||
assert not idp(int)
|
||||
assert not idp("x")
|
||||
|
||||
def no_return_type():
|
||||
return [1, 2 ,3]
|
||||
|
@ -65,6 +73,35 @@ def test_is_data_provider() -> None:
|
|||
return ['a', 'b', 'c']
|
||||
assert idp(has_return_type)
|
||||
|
||||
def _helper_func() -> Iterator[Any]:
|
||||
yield 5
|
||||
assert not idp(_helper_func)
|
||||
|
||||
|
||||
# return any parameters the user is required to provide - those which don't have default values
|
||||
def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]:
|
||||
for param in sig.parameters.values():
|
||||
if param.default == inspect.Parameter.empty:
|
||||
yield param
|
||||
|
||||
|
||||
def test_sig_required_params() -> None:
|
||||
|
||||
def x() -> int:
|
||||
return 5
|
||||
assert len(list(sig_required_params(inspect.signature(x)))) == 0
|
||||
|
||||
def y(arg: int) -> int:
|
||||
return arg
|
||||
assert len(list(sig_required_params(inspect.signature(y)))) == 1
|
||||
|
||||
# from stats perspective, this should be treated as a data provider as well
|
||||
# could be that the default value to the data provider is the 'default'
|
||||
# path to use for inputs/a function to provide input data
|
||||
def z(arg: int = 5) -> int:
|
||||
return arg
|
||||
assert len(list(sig_required_params(inspect.signature(z)))) == 0
|
||||
|
||||
|
||||
def type_is_iterable(type_spec) -> bool:
|
||||
if sys.version_info[1] < 8:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue