Fix call signatures.

This commit is contained in:
Dave Halter
2016-12-04 03:52:33 +01:00
parent 6940900c58
commit 439e394535
6 changed files with 60 additions and 43 deletions

View File

@@ -314,14 +314,6 @@ class BaseDefinition(object):
Follow both statements and imports, as far as possible.
"""
return self._name.infer()
if self._name.api_type == 'expr_stmt':
return self._evaluator.eval_statement(self._definition)
elif self._name.api_type == 'import':
raise DeprecationWarning
# TODO self._name.infer()?
return imports.ImportWrapper(self._evaluator, self._name).follow()
else:
return set([self._name.parent_context])
@property
@memoize_method
@@ -349,6 +341,8 @@ class BaseDefinition(object):
# there's no better solution.
inferred = names[0].infer()
return get_param_names(next(iter(inferred)))
elif isinstance(context, compiled.CompiledObject):
return context.get_param_names()
return param_names
followed = list(self._follow_statements_imports())
@@ -469,10 +463,8 @@ class Completion(BaseDefinition):
parses all libraries starting with ``a``.
"""
context = self._name.parent_context
if isinstance(self._name, imports.ImportName):
if fast:
return ''
else:
if self._name.api_type == 'module':
if not fast:
followed = self._name.infer()
if followed:
# TODO: Use all of the followed objects as input to Documentation.
@@ -489,18 +481,9 @@ class Completion(BaseDefinition):
The type of the completion objects. Follows imports. For a further
description, look at :attr:`jedi.api.classes.BaseDefinition.type`.
"""
if isinstance(self._definition, tree.Import):
raise DeprecationWarning
i = imports.ImportWrapper(self._evaluator, self._name)
if len(i.import_path) <= 1:
return 'module'
followed = self.follow_definition()
if followed:
# Caveat: Only follows the first one, ignore the other ones.
# This is ok, since people are almost never interested in
# variations.
return followed[0].type
if self._name.api_type == 'module':
for context in self._name.infer():
return context.name.api_type
return super(Completion, self).type
@memoize_method
@@ -567,8 +550,12 @@ class Definition(BaseDefinition):
"""
typ = self.type
if typ in ('statement', 'param'):
definition = self._name.tree_name.get_definition()
try:
tree_name = self._name.tree_name
except AttributeError:
pass
else:
definition = tree_name.get_definition()
try:
first_leaf = definition.first_leaf()
@@ -668,7 +655,12 @@ class Definition(BaseDefinition):
Returns True, if defined as a name in a statement, function or class.
Returns False, if it's a reference to such a definition.
"""
return self._name.is_definition()
try:
tree_name = self._name.tree_name
except AttributeError:
return True
else:
return tree_name.is_definition()
def __eq__(self, other):
return self._name.start_pos == other._name.start_pos \
@@ -705,17 +697,27 @@ class CallSignature(Definition):
for i, param in enumerate(self.params):
if self._key_name_str == param.name:
return i
if self.params and self.params[-1]._name.get_definition().stars == 2:
return i
else:
return None
if self.params:
param_name = self.params[-1]._name
try:
tree_name = param_name.tree_name
except AttributeError:
pass
else:
if tree_name.get_definition().stars == 2:
return i
return None
if self._index >= len(self.params):
for i, param in enumerate(self.params):
# *args case
if param._name.get_definition().stars == 1:
return i
try:
tree_name = param._name.tree_name
except AttributeError:
pass
else:
# *args case
if tree_name.get_definition().stars == 1:
return i
return None
return self._index

View File

@@ -121,7 +121,7 @@ class CompiledObject(Context):
parts = p.strip().split('=')
if len(parts) > 1:
parts.insert(1, Operator('=', (0, 0)))
yield UnresolvableParamName(self, p[0])
yield UnresolvableParamName(self, parts[0])
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, repr(self.obj))
@@ -373,8 +373,11 @@ class CompiledObjectFilter(AbstractFilter):
@memoize_method
def get(self, name):
name = str(name)
obj = self._compiled_obj.obj
try:
getattr(self._compiled_obj.obj, name)
getattr(obj, name)
if self._is_instance and name not in dir(obj):
return []
except AttributeError:
return []
except Exception:

View File

@@ -69,7 +69,8 @@ class TreeNameDefinition(ContextName):
def api_type(self):
definition = self.tree_name.get_definition()
return dict(
import_name='import',
import_name='module',
import_from='module',
funcdef='function',
param='param',
classdef='class',

View File

@@ -224,7 +224,10 @@ class CompiledInstanceName(compiled.CompiledName):
parent_context, result_context.funcdef
)
else:
yield result_context
if result_context.api_type == 'function':
yield CompiledBoundMethod(result_context)
else:
yield result_context
class CompiledInstanceClassFilter(compiled.CompiledObjectFilter):
@@ -257,6 +260,15 @@ class BoundMethod(er.FunctionContext):
)
class CompiledBoundMethod(compiled.CompiledObject):
def __init__(self, func):
super(CompiledBoundMethod, self).__init__(
func.evaluator, func.obj, func.parent_context, func.classdef)
def get_param_names(self):
return list(super(CompiledBoundMethod, self).get_param_names())[1:]
class InstanceNameDefinition(filters.TreeNameDefinition):
def infer(self):
contexts = super(InstanceNameDefinition, self).infer()

View File

@@ -1618,7 +1618,7 @@ class Param(BaseNode):
return self.get_parent_until(IsScope)
def __repr__(self):
default = '' if self.default is None else '=%s' % self.default
default = '' if self.default is None else '=%s' % self.default.get_code()
return '<%s: %s>' % (type(self).__name__, str(self._tfpdef()) + default)

View File

@@ -170,7 +170,7 @@ class TestCallSignatures(TestCase):
signatures = Script(s).call_signatures()
assert len(signatures) == 1
x = [p.description for p in signatures[0].params]
assert x == ['*args']
assert x == ['param *args']
def test_additional_brackets(self):
assert_signature('str((', 'str', 0)
@@ -265,8 +265,7 @@ def test_signature_is_definition():
# Now compare all the attributes that a CallSignature must also have.
for attr_name in dir(definition):
dont_scan = ['defined_names', 'line_nr', 'start_pos', 'documentation',
'doc', 'parent', 'goto_assignments']
dont_scan = ['defined_names', 'parent', 'goto_assignments', 'params']
if attr_name.startswith('_') or attr_name in dont_scan:
continue
attribute = getattr(definition, attr_name)