From 277f0e3988a9605d1e52fd47393fb987e02c32e7 Mon Sep 17 00:00:00 2001 From: Sean Breckenridge Date: Mon, 19 Apr 2021 10:57:42 -0700 Subject: [PATCH] cli/query: interactive fallback, improve guess_stats (#163) --- my/core/__main__.py | 80 +++++++++++++++++++++++++++++++++++++++++++-- my/core/query.py | 2 +- my/core/stats.py | 45 ++++++++++++++++++++++--- 3 files changed, 119 insertions(+), 8 deletions(-) diff --git a/my/core/__main__.py b/my/core/__main__.py index 1787538..3a7b074 100644 --- a/my/core/__main__.py +++ b/my/core/__main__.py @@ -1,9 +1,10 @@ import functools import importlib +import inspect import os import sys import traceback -from typing import Optional, Sequence, Iterable, List, Type, Any +from typing import Optional, Sequence, Iterable, List, Type, Any, Callable from pathlib import Path from subprocess import check_call, run, PIPE, CompletedProcess @@ -329,6 +330,80 @@ def module_install(*, user: bool, module: str) -> None: check_call(cmd) +def _ui_getchar_pick(choices: Sequence[str], prompt: str = 'Select from: ') -> int: + ''' + Basic menu allowing the user to select one of the choices + returns the index the user chose + ''' + assert len(choices) > 0, 'Didnt recieve any choices to prompt!' + eprint(prompt + '\n') + + # prompts like 1,2,3,4,5,6,7,8,9,a,b,c,d,e,f... + chr_offset = ord('a') - 10 + + # dict from key user can press -> resulting index + result_map = {} + for i, opt in enumerate(choices, 1): + char: str = str(i) if i < 10 else chr(i + chr_offset) + result_map[char] = i - 1 + eprint(f'\t{char}. {opt}') + + eprint('') + while True: + ch = click.getchar() + if ch not in result_map: + eprint(f'{ch} not in {list(result_map.keys())}') + continue + return result_map[ch] + + +def _locate_functions_or_prompt(qualified_names: List[str], prompt: bool = True) -> Iterable[Callable[..., Any]]: + from .query import locate_qualified_function, QueryException + from .stats import is_data_provider + + # if not connected to a terminal, cant prompt + if not sys.stdout.isatty(): + prompt = False + + for qualname in qualified_names: + try: + # common-case + yield locate_qualified_function(qualname) + except QueryException as qr_err: + # can't prompt, raise error + if prompt is False: + # hmm, should we yield here instead and ignore the error if one iterator succeeds? + # this is likely a query running in the background, so probably bad for it + # to fail silently + raise qr_err + + # maybe the user specified a module name instead of a function name? + # try importing the name the user specified as a module and prompt the + # user to select a 'data provider' like function + try: + mod = importlib.import_module(qualname) + except Exception: + eprint(f"During fallback, importing '{qualname}' as module failed") + raise qr_err + + # find data providers in this module + data_providers = [f for _, f in inspect.getmembers(mod, inspect.isfunction) if is_data_provider(f)] + if len(data_providers) == 0: + eprint(f"During fallback, could not find any data providers in '{qualname}'") + raise qr_err + else: + # was only one data provider-like function, use that + if len(data_providers) == 1: + yield data_providers[0] + else: + # prompt the user to pick the function to use + choices = [f.__name__ for f in data_providers] + chosen_index = _ui_getchar_pick(choices, f"Which function should be used from '{qualname}'?") + # respond to the user, so they know something has been picked + eprint(f"Selected '{choices[chosen_index]}'") + yield data_providers[chosen_index] + + # handle the 'hpi query' call # can raise a QueryException, caught in the click command def query_hpi_functions( @@ -350,11 +425,10 @@ def query_hpi_functions( from itertools import chain - from .query import locate_qualified_function from .query_range import select_range, RangeTuple # chain list of functions from user, in the order they wrote them on the CLI - input_src = chain(*(locate_qualified_function(f)() for f in qualified_names)) + input_src = chain(*(f() for f in _locate_functions_or_prompt(qualified_names))) res = list(select_range( input_src, diff --git a/my/core/query.py b/my/core/query.py index 570f059..edcb41a 100644 --- a/my/core/query.py +++ b/my/core/query.py @@ -56,7 +56,7 @@ def locate_function(module_name: str, function_name: str) -> Callable[[], Iterab return func except Exception as e: raise QueryException(str(e)) - raise QueryException(f"Could not find function {function_name} in {module_name}") + raise QueryException(f"Could not find function '{function_name}' in '{module_name}'") def locate_qualified_function(qualified_name: str) -> Callable[[], Iterable[ET]]: diff --git a/my/core/stats.py b/my/core/stats.py index 1a59f3d..b54a3b9 100644 --- a/my/core/stats.py +++ b/my/core/stats.py @@ -6,7 +6,7 @@ import importlib import inspect import sys import typing -from typing import Optional +from typing import Optional, Callable, Any, Iterator from .common import StatsFun, Stats, stat @@ -24,10 +24,11 @@ def guess_stats(module_name: str) -> Optional[StatsFun]: return auto_stats -def is_data_provider(fun) -> bool: +def is_data_provider(fun: Any) -> bool: """ 1. returns iterable or something like that 2. takes no arguments? (otherwise not callable by stats anyway?) + 3. doesn't start with an underscore (those are probably helper functions?) """ # todo maybe for 2 allow default arguments? not sure # one example which could benefit is my.pdfs @@ -36,11 +37,17 @@ def is_data_provider(fun) -> bool: # todo. uh.. very similar to what cachew is trying to do? try: sig = inspect.signature(fun) - except ValueError: # not a function? + except (ValueError, TypeError): # not a function? return False - if len(sig.parameters) > 0: + # has at least one argument without default values + if len(list(sig_required_params(sig))) > 0: return False + + # probably a helper function? + if hasattr(fun, '__name__') and fun.__name__.startswith('_'): + return False + return_type = sig.return_annotation return type_is_iterable(return_type) @@ -49,6 +56,7 @@ def test_is_data_provider() -> None: idp = is_data_provider assert not idp(None) assert not idp(int) + assert not idp("x") def no_return_type(): return [1, 2 ,3] @@ -65,6 +73,35 @@ def test_is_data_provider() -> None: return ['a', 'b', 'c'] assert idp(has_return_type) + def _helper_func() -> Iterator[Any]: + yield 5 + assert not idp(_helper_func) + + +# return any parameters the user is required to provide - those which don't have default values +def sig_required_params(sig: inspect.Signature) -> Iterator[inspect.Parameter]: + for param in sig.parameters.values(): + if param.default == inspect.Parameter.empty: + yield param + + +def test_sig_required_params() -> None: + + def x() -> int: + return 5 + assert len(list(sig_required_params(inspect.signature(x)))) == 0 + + def y(arg: int) -> int: + return arg + assert len(list(sig_required_params(inspect.signature(y)))) == 1 + + # from stats perspective, this should be treated as a data provider as well + # could be that the default value to the data provider is the 'default' + # path to use for inputs/a function to provide input data + def z(arg: int = 5) -> int: + return arg + assert len(list(sig_required_params(inspect.signature(z)))) == 0 + def type_is_iterable(type_spec) -> bool: if sys.version_info[1] < 8: