Fix completion for non-pytest params

This commit is contained in:
Dave Halter
2019-12-27 13:02:16 +01:00
parent 31936776a5
commit 4c22f4dbb1
3 changed files with 59 additions and 16 deletions

View File

@@ -76,7 +76,7 @@ def get_flow_scope_node(module_node, position):
@plugin_manager.decorate() @plugin_manager.decorate()
def complete_param_names(context, function_name): def complete_param_names(context, function_name, decorator_nodes):
# Basically there's no way to do param completion. The plugins are # Basically there's no way to do param completion. The plugins are
# responsible for this. # responsible for this.
return [] return []
@@ -225,14 +225,7 @@ class Completion:
dot = self._module_node.get_leaf_for_position(self._position) dot = self._module_node.get_leaf_for_position(self._position)
completion_names += self._complete_trailer(dot.get_previous_leaf()) completion_names += self._complete_trailer(dot.get_previous_leaf())
elif self._is_parameter_completion(): elif self._is_parameter_completion():
stack_node = self.stack[-2] completion_names += self._complete_params(leaf)
if stack_node.nonterminal == 'parameters':
stack_node = self.stack[-3]
if stack_node.nonterminal == 'funcdef':
context = get_user_context(self._module_context, self._position)
function_name = stack_node.nodes[1]
completion_names += complete_param_names(context, function_name)
else: else:
completion_names += self._complete_global_scope() completion_names += self._complete_global_scope()
completion_names += self._complete_inherited(is_function=False) completion_names += self._complete_inherited(is_function=False)
@@ -264,6 +257,28 @@ class Completion:
# var args is for lambdas and typed args for normal functions # var args is for lambdas and typed args for normal functions
return tos.nonterminal in ('typedargslist', 'varargslist') and tos.nodes[-1] == ',' return tos.nonterminal in ('typedargslist', 'varargslist') and tos.nodes[-1] == ','
def _complete_params(self, leaf):
stack_node = self.stack[-2]
if stack_node.nonterminal == 'parameters':
stack_node = self.stack[-3]
if stack_node.nonterminal == 'funcdef':
context = get_user_context(self._module_context, self._position)
node = search_ancestor(leaf, 'error_node', 'funcdef')
if node.type == 'error_node':
n = node.children[0]
if n.type == 'decorators':
decorators = n.children
elif n.type == 'decorator':
decorators = [n]
else:
decorators = []
else:
decorators = node.get_decorators()
function_name = stack_node.nodes[1]
return complete_param_names(context, function_name.value, decorators)
return []
def _complete_keywords(self, allowed_transitions): def _complete_keywords(self, allowed_transitions):
for k in allowed_transitions: for k in allowed_transitions:
if isinstance(k, str) and k.isalpha(): if isinstance(k, str) and k.isalpha():

View File

@@ -54,14 +54,16 @@ def goto_anonymous_param(func):
def complete_param_names(func): def complete_param_names(func):
def wrapper(context, func_name): def wrapper(context, func_name, decorator_nodes):
module_context = context.get_root_context() module_context = context.get_root_context()
if func_name.startswith('test_') \
or any('fixture' in n.get_code() for n in decorator_nodes):
names = [] names = []
for module_context in _iter_pytest_modules(module_context): for module_context in _iter_pytest_modules(module_context):
names += FixtureFilter(module_context).values() names += FixtureFilter(module_context).values()
if names: if names:
return names return names
return func(context, func_name) return func(context, func_name, decorator_nodes)
return wrapper return wrapper

View File

@@ -72,3 +72,29 @@ def test_x(my_con
#? 18 ['my_conftest_fixture'] #? 18 ['my_conftest_fixture']
def test_x(my_conftest_fixture): def test_x(my_conftest_fixture):
return return
#? []
def lala(my_con
return
@pytest.fixture
#? ['my_conftest_fixture']
def lala(my_con
return
@pytest.fixture
#? 15 ['my_conftest_fixture']
def lala(my_con):
return
@pytest.fixture
@some_decorator
#? ['my_conftest_fixture']
def lala(my_con
return
@pytest.fixture
@some_decorator
#? 15 ['my_conftest_fixture']
def lala(my_con):
return