Fix partial signatures, fixes #1371

This commit is contained in:
Dave Halter
2019-07-24 02:28:12 +02:00
parent 25bbecc269
commit e76120da06
3 changed files with 85 additions and 22 deletions

View File

@@ -1,19 +1,7 @@
from jedi._compatibility import Parameter from jedi._compatibility import Parameter
class AbstractSignature(object): class _SignatureMixin(object):
def __init__(self, context, is_bound=False):
self.context = context
self.is_bound = is_bound
@property
def name(self):
return self.context.name
@property
def annotation_string(self):
return ''
def to_string(self): def to_string(self):
def param_strings(): def param_strings():
is_positional = False is_positional = False
@@ -42,9 +30,6 @@ class AbstractSignature(object):
s += ' -> ' + annotation s += ' -> ' + annotation
return s return s
def bind(self, context):
raise NotImplementedError
def get_param_names(self): def get_param_names(self):
param_names = self._function_context.get_param_names() param_names = self._function_context.get_param_names()
if self.is_bound: if self.is_bound:
@@ -52,6 +37,23 @@ class AbstractSignature(object):
return param_names return param_names
class AbstractSignature(_SignatureMixin):
def __init__(self, context, is_bound=False):
self.context = context
self.is_bound = is_bound
@property
def name(self):
return self.context.name
@property
def annotation_string(self):
return ''
def bind(self, context):
raise NotImplementedError
class TreeSignature(AbstractSignature): class TreeSignature(AbstractSignature):
def __init__(self, context, function_context=None, is_bound=False): def __init__(self, context, function_context=None, is_bound=False):
super(TreeSignature, self).__init__(context, is_bound) super(TreeSignature, self).__init__(context, is_bound)
@@ -92,3 +94,11 @@ class BuiltinSignature(AbstractSignature):
def bind(self, context): def bind(self, context):
assert not self.is_bound assert not self.is_bound
return BuiltinSignature(context, self._return_string, is_bound=True) return BuiltinSignature(context, self._return_string, is_bound=True)
class SignatureWrapper(_SignatureMixin):
def __init__(self, wrapped_signature):
self._wrapped_signature = wrapped_signature
def __getattr__(self, name):
return getattr(self._wrapped_signature, name)

View File

@@ -32,7 +32,7 @@ from jedi.evaluate.names import ContextName, BaseTreeParamName
from jedi.evaluate.syntax_tree import is_string from jedi.evaluate.syntax_tree import is_string
from jedi.evaluate.filters import AttributeOverwrite, publish_method, \ from jedi.evaluate.filters import AttributeOverwrite, publish_method, \
ParserTreeFilter, DictFilter ParserTreeFilter, DictFilter
from jedi.evaluate.signature import AbstractSignature from jedi.evaluate.signature import AbstractSignature, SignatureWrapper
# Copied from Python 3.6's stdlib. # Copied from Python 3.6's stdlib.
@@ -477,17 +477,49 @@ class PartialObject(object):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._actual_context, name) return getattr(self._actual_context, name)
def py__call__(self, arguments): def _get_function(self, unpacked_arguments):
key, lazy_context = next(self._arguments.unpack(), (None, None)) key, lazy_context = next(unpacked_arguments, (None, None))
if key is not None or lazy_context is None: if key is not None or lazy_context is None:
debug.warning("Partial should have a proper function %s", self._arguments) debug.warning("Partial should have a proper function %s", self._arguments)
return None
return lazy_context.infer()
def get_signatures(self):
unpacked_arguments = self._arguments.unpack()
func = self._get_function(unpacked_arguments)
if func is None:
return []
arg_count = 0
keys = set()
for key, _ in unpacked_arguments:
if key is None:
arg_count += 1
else:
keys.add(key)
return [PartialSignature(s, arg_count, keys) for s in func.get_signatures()]
def py__call__(self, arguments):
func = self._get_function(self._arguments.unpack())
if func is None:
return NO_CONTEXTS return NO_CONTEXTS
return lazy_context.infer().execute( return func.execute(
MergedPartialArguments(self._arguments, arguments) MergedPartialArguments(self._arguments, arguments)
) )
class PartialSignature(SignatureWrapper):
def __init__(self, wrapped_signature, skipped_arg_count, skipped_arg_set):
super(PartialSignature, self).__init__(wrapped_signature)
self._skipped_arg_count = skipped_arg_count
self._skipped_arg_set = skipped_arg_set
def get_param_names(self):
names = self._wrapped_signature.get_param_names()[self._skipped_arg_count:]
return [n for n in names if n.string_name not in self._skipped_arg_set]
class MergedPartialArguments(AbstractArguments): class MergedPartialArguments(AbstractArguments):
def __init__(self, partial_arguments, call_arguments): def __init__(self, partial_arguments, call_arguments):
self._partial_arguments = partial_arguments self._partial_arguments = partial_arguments

View File

@@ -47,6 +47,19 @@ class X:
pass pass
''' '''
partial_code = '''
import functools
def func(a, b, c):
pass
a = functools.partial(func)
b = functools.partial(func, 1)
c = functools.partial(func, 1, c=2)
d = functools.partial()
'''
@pytest.mark.parametrize( @pytest.mark.parametrize(
'code, expected', [ 'code, expected', [
('def f(a, * args, x): pass\n f(', 'f(a, *args, x)'), ('def f(a, * args, x): pass\n f(', 'f(a, *args, x)'),
@@ -59,6 +72,11 @@ class X:
(classmethod_code + 'X().x(', 'x(cls, a, b)'), (classmethod_code + 'X().x(', 'x(cls, a, b)'),
(classmethod_code + 'X.static(', 'static(a, b)'), (classmethod_code + 'X.static(', 'static(a, b)'),
(classmethod_code + 'X().static(', 'static(a, b)'), (classmethod_code + 'X().static(', 'static(a, b)'),
(partial_code + 'a(', 'func(a, b, c)'),
(partial_code + 'b(', 'func(b, c)'),
(partial_code + 'c(', 'func(b)'),
(partial_code + 'd(', None),
] ]
) )
def test_tree_signature(Script, environment, code, expected): def test_tree_signature(Script, environment, code, expected):
@@ -66,8 +84,11 @@ def test_tree_signature(Script, environment, code, expected):
if environment.version_info < (3, 8): if environment.version_info < (3, 8):
pytest.skip() pytest.skip()
sig, = Script(code).call_signatures() if expected is None:
assert expected == sig._signature.to_string() assert not Script(code).call_signatures()
else:
sig, = Script(code).call_signatures()
assert expected == sig._signature.to_string()
def test_pow_signature(Script): def test_pow_signature(Script):