cli/query: interactive fallback, improve guess_stats (#163)
This commit is contained in:
parent
91eed15a75
commit
277f0e3988
3 changed files with 119 additions and 8 deletions
|
@ -1,9 +1,10 @@
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, Sequence, Iterable, List, Type, Any
|
from typing import Optional, Sequence, Iterable, List, Type, Any, Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from subprocess import check_call, run, PIPE, CompletedProcess
|
from subprocess import check_call, run, PIPE, CompletedProcess
|
||||||
|
|
||||||
|
@ -329,6 +330,80 @@ def module_install(*, user: bool, module: str) -> None:
|
||||||
check_call(cmd)
|
check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def _ui_getchar_pick(choices: Sequence[str], prompt: str = 'Select from: ') -> int:
|
||||||
|
'''
|
||||||
|
Basic menu allowing the user to select one of the choices
|
||||||
|
returns the index the user chose
|
||||||
|
'''
|
||||||
|
assert len(choices) > 0, 'Didnt recieve any choices to prompt!'
|
||||||
|
eprint(prompt + '\n')
|
||||||
|
|
||||||
|
# prompts like 1,2,3,4,5,6,7,8,9,a,b,c,d,e,f...
|
||||||
|
chr_offset = ord('a') - 10
|
||||||
|
|
||||||
|
# dict from key user can press -> resulting index
|
||||||
|
result_map = {}
|
||||||
|
for i, opt in enumerate(choices, 1):
|
||||||
|
char: str = str(i) if i < 10 else chr(i + chr_offset)
|
||||||
|
result_map[char] = i - 1
|
||||||
|
eprint(f'\t{char}. {opt}')
|
||||||
|
|
||||||
|
eprint('')
|
||||||
|
while True:
|
||||||
|
ch = click.getchar()
|
||||||
|
if ch not in result_map:
|
||||||
|
eprint(f'{ch} not in {list(result_map.keys())}')
|
||||||
|
continue
|
||||||
|
return result_map[ch]
|
||||||
|
|
||||||
|
|
||||||
|
def _locate_functions_or_prompt(qualified_names: List[str], prompt: bool = True) -> Iterable[Callable[..., Any]]:
|
||||||
|
from .query import locate_qualified_function, QueryException
|
||||||
|
from .stats import is_data_provider
|
||||||
|
|
||||||
|
# if not connected to a terminal, cant prompt
|
||||||
|
if not sys.stdout.isatty():
|
||||||
|
prompt = False
|
||||||
|
|
||||||
|
for qualname in qualified_names:
|
||||||
|
try:
|
||||||
|
# common-case
|
||||||
|
yield locate_qualified_function(qualname)
|
||||||
|
except QueryException as qr_err:
|
||||||
|
# can't prompt, raise error
|
||||||
|
if prompt is False:
|
||||||
|
# hmm, should we yield here instead and ignore the error if one iterator succeeds?
|
||||||
|
# this is likely a query running in the background, so probably bad for it
|
||||||
|
# to fail silently
|
||||||
|
raise qr_err
|
||||||
|
|
||||||
|
# maybe the user specified a module name instead of a function name?
|
||||||
|
# try importing the name the user specified as a module and prompt the
|
||||||
|
# user to select a 'data provider' like function
|
||||||
|
try:
|
||||||
|
mod = importlib.import_module(qualname)
|
||||||
|
except Exception:
|
||||||
|
eprint(f"During fallback, importing '{qualname}' as module failed")
|
||||||
|
raise qr_err
|
||||||
|
|
||||||
|
# find data providers in this module
|
||||||
|
data_providers = [f for _, f in inspect.getmembers(mod, inspect.isfunction) if is_data_provider(f)]
|
||||||
|
if len(data_providers) == 0:
|
||||||
|
eprint(f"During fallback, could not find any data providers in '{qualname}'")
|
||||||
|
raise qr_err
|
||||||
|
else:
|
||||||
|
# was only one data provider-like function, use that
|
||||||
|
if len(data_providers) == 1:
|
||||||
|
yield data_providers[0]
|
||||||
|
else:
|
||||||
|
# prompt the user to pick the function to use
|
||||||
|
choices = [f.__name__ for f in data_providers]
|
||||||
|
chosen_index = _ui_getchar_pick(choices, f"Which function should be used from '{qualname}'?")
|
||||||
|
# respond to the user, so they know something has been picked
|
||||||
|
eprint(f"Selected '{choices[chosen_index]}'")
|
||||||
|
yield data_providers[chosen_index]
|
||||||
|
|
||||||
|
|
||||||
# handle the 'hpi query' call
|
# handle the 'hpi query' call
|
||||||
# can raise a QueryException, caught in the click command
|
# can raise a QueryException, caught in the click command
|
||||||
def query_hpi_functions(
|
def query_hpi_functions(
|
||||||
|
@ -350,11 +425,10 @@ def query_hpi_functions(
|
||||||
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
from .query import locate_qualified_function
|
|
||||||
from .query_range import select_range, RangeTuple
|
from .query_range import select_range, RangeTuple
|
||||||
|
|
||||||
# chain list of functions from user, in the order they wrote them on the CLI
|
# chain list of functions from user, in the order they wrote them on the CLI
|
||||||
input_src = chain(*(locate_qualified_function(f)() for f in qualified_names))
|
input_src = chain(*(f() for f in _locate_functions_or_prompt(qualified_names)))
|
||||||
|
|
||||||
res = list(select_range(
|
res = list(select_range(
|
||||||
input_src,
|
input_src,
|
||||||
|
|
|
@ -56,7 +56,7 @@ def locate_function(module_name: str, function_name: str) -> Callable[[], Iterab
|
||||||
return func
|
return func
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise QueryException(str(e))
|
raise QueryException(str(e))
|
||||||
raise QueryException(f"Could not find function {function_name} in {module_name}")
|
raise QueryException(f"Could not find function '{function_name}' in '{module_name}'")
|
||||||
|
|
||||||
|
|
||||||
def locate_qualified_function(qualified_name: str) -> Callable[[], Iterable[ET]]:
|
def locate_qualified_function(qualified_name: str) -> Callable[[], Iterable[ET]]:
|
||||||
|
|
|
@ -6,7 +6,7 @@ import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
from typing import Optional
|
from typing import Optional, Callable, Any, Iterator
|
||||||
|
|
||||||
from .common import StatsFun, Stats, stat
|
from .common import StatsFun, Stats, stat
|
||||||
|
|
||||||
|
@ -24,10 +24,11 @@ def guess_stats(module_name: str) -> Optional[StatsFun]:
|
||||||
return auto_stats
|
return auto_stats
|
||||||
|
|
||||||
|
|
||||||
def is_data_provider(fun) -> bool:
|
def is_data_provider(fun: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
1. returns iterable or something like that
|
1. returns iterable or something like that
|
||||||
2. takes no arguments? (otherwise not callable by stats anyway?)
|
2. takes no arguments? (otherwise not callable by stats anyway?)
|
||||||
|
3. doesn't start with an underscore (those are probably helper functions?)
|
||||||
"""
|
"""
|
||||||
# todo maybe for 2 allow default arguments? not sure
|
# todo maybe for 2 allow default arguments? not sure
|
||||||
# one example which could benefit is my.pdfs
|
# one example which could benefit is my.pdfs
|
||||||
|
@ -36,11 +37,17 @@ def is_data_provider(fun) -> bool:
|
||||||
# todo. uh.. very similar to what cachew is trying to do?
|
# todo. uh.. very similar to what cachew is trying to do?
|
||||||
try:
|
try:
|
||||||
sig = inspect.signature(fun)
|
sig = inspect.signature(fun)
|
||||||
except ValueError: # not a function?
|
except (ValueError, TypeError): # not a function?
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if len(sig.parameters) > 0:
|
# has at least one argument without default values
|
||||||
|
if len(list(sig_required_params(sig))) > 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# probably a helper function?
|
||||||
|
if hasattr(fun, '__name__') and fun.__name__.startswith('_'):
|
||||||
|
return False
|
||||||
|
|
||||||
return_type = sig.return_annotation
|
return_type = sig.return_annotation
|
||||||
return type_is_iterable(return_type)
|
return type_is_iterable(return_type)
|
||||||
|
|
||||||
|
@ -49,6 +56,7 @@ def test_is_data_provider() -> None:
|
||||||
idp = is_data_provider
|
idp = is_data_provider
|
||||||
assert not idp(None)
|
assert not idp(None)
|
||||||
assert not idp(int)
|
assert not idp(int)
|
||||||
|
assert not idp("x")
|
||||||
|
|
||||||
def no_return_type():
|
def no_return_type():
|
||||||
return [1, 2 ,3]
|
return [1, 2 ,3]
|
||||||
|
@ -65,6 +73,35 @@ def test_is_data_provider() -> None:
|
||||||
return ['a', 'b', 'c']
|
return ['a', 'b', 'c']
|
||||||
assert idp(has_return_type)
|
assert idp(has_return_type)
|
||||||
|
|
||||||
|
def _helper_func() -> Iterator[Any]:
|
||||||
|
yield 5
|
||||||
|
assert not idp(_helper_func)
|
||||||
|
|
||||||
|
|
||||||
|
# return any parameters the user is required to provide - those which don't have default values
|
||||||
|
def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]:
|
||||||
|
for param in sig.parameters.values():
|
||||||
|
if param.default == inspect.Parameter.empty:
|
||||||
|
yield param
|
||||||
|
|
||||||
|
|
||||||
|
def test_sig_required_params() -> None:
|
||||||
|
|
||||||
|
def x() -> int:
|
||||||
|
return 5
|
||||||
|
assert len(list(sig_required_params(inspect.signature(x)))) == 0
|
||||||
|
|
||||||
|
def y(arg: int) -> int:
|
||||||
|
return arg
|
||||||
|
assert len(list(sig_required_params(inspect.signature(y)))) == 1
|
||||||
|
|
||||||
|
# from stats perspective, this should be treated as a data provider as well
|
||||||
|
# could be that the default value to the data provider is the 'default'
|
||||||
|
# path to use for inputs/a function to provide input data
|
||||||
|
def z(arg: int = 5) -> int:
|
||||||
|
return arg
|
||||||
|
assert len(list(sig_required_params(inspect.signature(z)))) == 0
|
||||||
|
|
||||||
|
|
||||||
def type_is_iterable(type_spec) -> bool:
|
def type_is_iterable(type_spec) -> bool:
|
||||||
if sys.version_info[1] < 8:
|
if sys.version_info[1] < 8:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue