diff --git a/jedi/plugins/pytest.py b/jedi/plugins/pytest.py index 7ae2194d..910f540c 100644 --- a/jedi/plugins/pytest.py +++ b/jedi/plugins/pytest.py @@ -1,3 +1,4 @@ +from parso.python.tree import search_ancestor from jedi._compatibility import FileNotFoundError from jedi.inference.cache import inference_state_method_cache from jedi.inference.imports import load_module_from_path @@ -32,23 +33,25 @@ def infer_anonymous_param(func): return function_context.get_return_values() def wrapper(param_name): - module = param_name.get_root_context() - fixtures = _goto_pytest_fixture(module, param_name.string_name) - if fixtures: - return ValueSet.from_sets( - get_returns(value) - for fixture in fixtures - for value in fixture.infer() - ) + if _is_a_pytest_param(param_name): + module = param_name.get_root_context() + fixtures = _goto_pytest_fixture(module, param_name.string_name) + if fixtures: + return ValueSet.from_sets( + get_returns(value) + for fixture in fixtures + for value in fixture.infer() + ) return func(param_name) return wrapper def goto_anonymous_param(func): def wrapper(param_name): - names = _goto_pytest_fixture(param_name.get_root_context(), param_name.string_name) - if names: - return names + if _is_a_pytest_param(param_name): + names = _goto_pytest_fixture(param_name.get_root_context(), param_name.string_name) + if names: + return names return func(param_name) return wrapper @@ -56,8 +59,7 @@ def goto_anonymous_param(func): def complete_param_names(func): def wrapper(context, func_name, decorator_nodes): module_context = context.get_root_context() - if func_name.startswith('test_') \ - or any('fixture' in n.get_code() for n in decorator_nodes): + if _is_pytest_func(func_name, decorator_nodes): names = [] for module_context in _iter_pytest_modules(module_context): names += FixtureFilter(module_context).values() @@ -74,6 +76,23 @@ def _goto_pytest_fixture(module_context, name): return names +def _is_a_pytest_param(param_name): + """ + Pytest params are either in a `test_*` function or have a pytest fixture + with the decorator @pytest.fixture. + + This is a heuristic and will work in most cases. + """ + funcdef = search_ancestor(param_name.tree_name, 'funcdef') + decorators = funcdef.get_decorators() + return _is_pytest_func(funcdef.name.value, decorators) + + +def _is_pytest_func(func_name, decorator_nodes): + return func_name.startswith('test_') \ + or any('fixture' in n.get_code() for n in decorator_nodes) + + @inference_state_method_cache() def _iter_pytest_modules(module_context): yield module_context diff --git a/test/completion/pytest.py b/test/completion/pytest.py index 345fa34d..eed0217e 100644 --- a/test/completion/pytest.py +++ b/test/completion/pytest.py @@ -47,6 +47,24 @@ def test_x(MyClassFixture): #? MyClassFixture +#? 15 +def lala(my_fixture): + pass + +@pytest.fixture +#? 15 str() +def lala(my_fixture): + pass + +#! 15 ['param my_fixture'] +def lala(my_fixture): + pass + +@pytest.fixture +#! 15 ['def my_fixture'] +def lala(my_fixture): + pass + # ----------------- # completion # -----------------