Fix the issues with added equals after params in the wrong places. Fixes #643.

This commit is contained in:
Dave Halter
2016-07-20 23:19:05 +02:00
parent a2d66579d7
commit f20df95074
3 changed files with 21 additions and 11 deletions

View File

@@ -366,10 +366,11 @@ class Completion(BaseDefinition):
`Completion` objects are returned from :meth:`api.Script.completions`. They `Completion` objects are returned from :meth:`api.Script.completions`. They
provide additional information about a completion. provide additional information about a completion.
""" """
def __init__(self, evaluator, name, like_name_length): def __init__(self, evaluator, name, stack, like_name_length):
super(Completion, self).__init__(evaluator, name) super(Completion, self).__init__(evaluator, name)
self._like_name_length = like_name_length self._like_name_length = like_name_length
self.stack = stack
# Completion objects with the same Completion name (which means # Completion objects with the same Completion name (which means
# duplicate items in the completion) # duplicate items in the completion)
@@ -381,7 +382,10 @@ class Completion(BaseDefinition):
and self.type == 'Function': and self.type == 'Function':
append = '(' append = '('
if isinstance(self._definition, tree.Param): if self.stack is not None:
node_names = list(self.stack.get_node_names(self._evaluator.grammar))
print(node_names)
if 'trailer' in node_names and 'argument' not in node_names:
append += '=' append += '='
name = str(self._name) name = str(self._name)

View File

@@ -28,7 +28,7 @@ def get_call_signature_param_names(call_signatures):
yield p._name yield p._name
def filter_names(evaluator, completion_names, like_name): def filter_names(evaluator, completion_names, stack, like_name):
comp_dct = {} comp_dct = {}
for name in set(completion_names): for name in set(completion_names):
if settings.case_insensitive_completion \ if settings.case_insensitive_completion \
@@ -42,6 +42,7 @@ def filter_names(evaluator, completion_names, like_name):
new = classes.Completion( new = classes.Completion(
evaluator, evaluator,
name, name,
stack,
len(like_name) len(like_name)
) )
k = (new.name, new.complete) # key k = (new.name, new.complete) # key
@@ -89,7 +90,7 @@ class Completion:
completion_names = self._get_context_completions() completion_names = self._get_context_completions()
completions = filter_names(self._evaluator, completion_names, completions = filter_names(self._evaluator, completion_names,
self._like_name) self.stack, self._like_name)
return sorted(completions, key=lambda x: (x.name.startswith('__'), return sorted(completions, key=lambda x: (x.name.startswith('__'),
x.name.startswith('_'), x.name.startswith('_'),
@@ -113,7 +114,7 @@ class Completion:
grammar = self._evaluator.grammar grammar = self._evaluator.grammar
try: try:
stack = helpers.get_stack_at_position( self.stack = helpers.get_stack_at_position(
grammar, self._code_lines, self._module, self._position grammar, self._code_lines, self._module, self._position
) )
except helpers.OnErrorLeaf as e: except helpers.OnErrorLeaf as e:
@@ -122,19 +123,21 @@ class Completion:
# completions since this probably just confuses the user. # completions since this probably just confuses the user.
return [] return []
# If we don't have a context, just use global completion. # If we don't have a context, just use global completion.
self.stack = None
return self._global_completions() return self._global_completions()
allowed_keywords, allowed_tokens = \ allowed_keywords, allowed_tokens = \
helpers.get_possible_completion_types(grammar, stack) helpers.get_possible_completion_types(grammar, self.stack)
completion_names = list(self._get_keyword_completion_names(allowed_keywords)) completion_names = list(self._get_keyword_completion_names(allowed_keywords))
if token.NAME in allowed_tokens: if token.NAME in allowed_tokens:
# This means that we actually have to do type inference. # This means that we actually have to do type inference.
symbol_names = list(stack.get_node_names(grammar)) symbol_names = list(self.stack.get_node_names(grammar))
nodes = list(stack.get_nodes()) nodes = list(self.stack.get_nodes())
if "import_stmt" in symbol_names: if "import_stmt" in symbol_names:
level = 0 level = 0

View File

@@ -338,7 +338,7 @@ class TestGotoAssignments(TestCase):
def test_added_equals_to_params(): def test_added_equals_to_params():
def run(rest_source): def run(rest_source):
source = dedent(""" source = dedent("""
def foo(bar): def foo(bar, baz):
pass pass
""") """)
results = Script(source + rest_source).completions() results = Script(source + rest_source).completions()
@@ -347,5 +347,8 @@ def test_added_equals_to_params():
assert run('foo(bar').name_with_symbols == 'bar=' assert run('foo(bar').name_with_symbols == 'bar='
assert run('foo(bar').complete == '=' assert run('foo(bar').complete == '='
assert run('foo(bar').name_with_symbols == 'bar' assert run('foo(bar, baz').complete == '='
assert run(' bar').name_with_symbols == 'bar'
assert run(' bar').complete == '' assert run(' bar').complete == ''
x = run('foo(bar=isins').name_with_symbols
assert x == 'isinstance'