From e76120da062d07d8c699d033769b634472388c6f Mon Sep 17 00:00:00 2001 From: Dave Halter Date: Wed, 24 Jul 2019 02:28:12 +0200 Subject: [PATCH] Fix partial signatures, fixes #1371 --- jedi/evaluate/signature.py | 42 +++++++++++++++++----------- jedi/plugins/stdlib.py | 40 +++++++++++++++++++++++--- test/test_evaluate/test_signature.py | 25 +++++++++++++++-- 3 files changed, 85 insertions(+), 22 deletions(-) diff --git a/jedi/evaluate/signature.py b/jedi/evaluate/signature.py index e0efab09..1f4cefbe 100644 --- a/jedi/evaluate/signature.py +++ b/jedi/evaluate/signature.py @@ -1,19 +1,7 @@ from jedi._compatibility import Parameter -class AbstractSignature(object): - def __init__(self, context, is_bound=False): - self.context = context - self.is_bound = is_bound - - @property - def name(self): - return self.context.name - - @property - def annotation_string(self): - return '' - +class _SignatureMixin(object): def to_string(self): def param_strings(): is_positional = False @@ -42,9 +30,6 @@ class AbstractSignature(object): s += ' -> ' + annotation return s - def bind(self, context): - raise NotImplementedError - def get_param_names(self): param_names = self._function_context.get_param_names() if self.is_bound: @@ -52,6 +37,23 @@ class AbstractSignature(object): return param_names +class AbstractSignature(_SignatureMixin): + def __init__(self, context, is_bound=False): + self.context = context + self.is_bound = is_bound + + @property + def name(self): + return self.context.name + + @property + def annotation_string(self): + return '' + + def bind(self, context): + raise NotImplementedError + + class TreeSignature(AbstractSignature): def __init__(self, context, function_context=None, is_bound=False): super(TreeSignature, self).__init__(context, is_bound) @@ -92,3 +94,11 @@ class BuiltinSignature(AbstractSignature): def bind(self, context): assert not self.is_bound return BuiltinSignature(context, self._return_string, is_bound=True) + + +class SignatureWrapper(_SignatureMixin): + def __init__(self, wrapped_signature): + self._wrapped_signature = wrapped_signature + + def __getattr__(self, name): + return getattr(self._wrapped_signature, name) diff --git a/jedi/plugins/stdlib.py b/jedi/plugins/stdlib.py index 384bf219..fc138cd2 100644 --- a/jedi/plugins/stdlib.py +++ b/jedi/plugins/stdlib.py @@ -32,7 +32,7 @@ from jedi.evaluate.names import ContextName, BaseTreeParamName from jedi.evaluate.syntax_tree import is_string from jedi.evaluate.filters import AttributeOverwrite, publish_method, \ ParserTreeFilter, DictFilter -from jedi.evaluate.signature import AbstractSignature +from jedi.evaluate.signature import AbstractSignature, SignatureWrapper # Copied from Python 3.6's stdlib. @@ -477,17 +477,49 @@ class PartialObject(object): def __getattr__(self, name): return getattr(self._actual_context, name) - def py__call__(self, arguments): - key, lazy_context = next(self._arguments.unpack(), (None, None)) + def _get_function(self, unpacked_arguments): + key, lazy_context = next(unpacked_arguments, (None, None)) if key is not None or lazy_context is None: debug.warning("Partial should have a proper function %s", self._arguments) + return None + return lazy_context.infer() + + def get_signatures(self): + unpacked_arguments = self._arguments.unpack() + func = self._get_function(unpacked_arguments) + if func is None: + return [] + + arg_count = 0 + keys = set() + for key, _ in unpacked_arguments: + if key is None: + arg_count += 1 + else: + keys.add(key) + return [PartialSignature(s, arg_count, keys) for s in func.get_signatures()] + + def py__call__(self, arguments): + func = self._get_function(self._arguments.unpack()) + if func is None: return NO_CONTEXTS - return lazy_context.infer().execute( + return func.execute( MergedPartialArguments(self._arguments, arguments) ) +class PartialSignature(SignatureWrapper): + def __init__(self, wrapped_signature, skipped_arg_count, skipped_arg_set): + super(PartialSignature, self).__init__(wrapped_signature) + self._skipped_arg_count = skipped_arg_count + self._skipped_arg_set = skipped_arg_set + + def get_param_names(self): + names = self._wrapped_signature.get_param_names()[self._skipped_arg_count:] + return [n for n in names if n.string_name not in self._skipped_arg_set] + + class MergedPartialArguments(AbstractArguments): def __init__(self, partial_arguments, call_arguments): self._partial_arguments = partial_arguments diff --git a/test/test_evaluate/test_signature.py b/test/test_evaluate/test_signature.py index 3f5107bd..8cc05109 100644 --- a/test/test_evaluate/test_signature.py +++ b/test/test_evaluate/test_signature.py @@ -47,6 +47,19 @@ class X: pass ''' + +partial_code = ''' +import functools + +def func(a, b, c): + pass + +a = functools.partial(func) +b = functools.partial(func, 1) +c = functools.partial(func, 1, c=2) +d = functools.partial() +''' + @pytest.mark.parametrize( 'code, expected', [ ('def f(a, * args, x): pass\n f(', 'f(a, *args, x)'), @@ -59,6 +72,11 @@ class X: (classmethod_code + 'X().x(', 'x(cls, a, b)'), (classmethod_code + 'X.static(', 'static(a, b)'), (classmethod_code + 'X().static(', 'static(a, b)'), + + (partial_code + 'a(', 'func(a, b, c)'), + (partial_code + 'b(', 'func(b, c)'), + (partial_code + 'c(', 'func(b)'), + (partial_code + 'd(', None), ] ) def test_tree_signature(Script, environment, code, expected): @@ -66,8 +84,11 @@ def test_tree_signature(Script, environment, code, expected): if environment.version_info < (3, 8): pytest.skip() - sig, = Script(code).call_signatures() - assert expected == sig._signature.to_string() + if expected is None: + assert not Script(code).call_signatures() + else: + sig, = Script(code).call_signatures() + assert expected == sig._signature.to_string() def test_pow_signature(Script):