diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index a7b0fbfe..bfb2951b 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -550,17 +550,18 @@ class Script(object): return parso_to_jedi_errors(self._inference_state.grammar, self._module_node) def _names(self, all_scopes=False, definitions=True, references=False): - def def_ref_filter(name): - is_def = name.tree_name.is_definition() - return definitions and is_def or references and not is_def - # Set line/column to a random position, because they don't matter. module_context = self._get_module_context() defs = [ module_context.create_name(name) - for name in get_module_names(self._module_node, all_scopes) + for name in get_module_names( + self._module_node, + all_scopes=all_scopes, + definitions=definitions, + references=references, + ) ] - return sorted(filter(def_ref_filter, defs), key=lambda x: x.start_pos) + return sorted(defs, key=lambda x: x.start_pos) @no_py2_support def rename(self, line=None, column=None, **kwargs): diff --git a/jedi/inference/helpers.py b/jedi/inference/helpers.py index 3b0b33dd..1e1f6914 100644 --- a/jedi/inference/helpers.py +++ b/jedi/inference/helpers.py @@ -122,11 +122,15 @@ def get_names_of_node(node): return list(chain.from_iterable(get_names_of_node(c) for c in children)) -def get_module_names(module, all_scopes): +def get_module_names(module, all_scopes, definitions=True, references=False): """ Returns a dictionary with name parts as keys and their call paths as values. """ + def def_ref_filter(name): + is_def = name.is_definition() + return definitions and is_def or references and not is_def + names = list(chain.from_iterable(module.get_used_names().values())) if not all_scopes: # We have to filter all the names that don't have the module as a @@ -142,7 +146,7 @@ def get_module_names(module, all_scopes): return parent_scope in (module, None) names = [n for n in names if is_module_scope_name(n)] - return names + return filter(def_ref_filter, names) def is_string(value): diff --git a/test/test_api/test_project.py b/test/test_api/test_project.py index b34fcacc..0885c3e5 100644 --- a/test/test_api/test_project.py +++ b/test/test_api/test_project.py @@ -51,10 +51,21 @@ def test_load_save_project(tmpdir): dict(complete=True)), ('test_load_save_p', ['test_api.test_project.test_load_save_project'], dict(complete=True, all_scopes=True)), + + ('some_search_test_var', [], {}), + ('some_search_test_var', ['test_api.test_project.test_search.some_search_test_var'], + dict(all_scopes=True)), + ('some_search_test_var', ['test_api.test_project.test_search.some_search_test_var'], + dict(complete=True, all_scopes=True)), + + ('sample_int', ['helpers.sample_int'], {}), + ('sample_int', ['helpers.sample_int'], dict(all_scopes=True)), + ('sample_int.real', ['builtins.int.real'], {}), ] ) @pytest.mark.skipif(sys.version_info < (3, 6), reason="Ignore Python 2, because EOL") def test_search(string, full_names, kwargs, skip_pre_python36): + some_search_test_var = 1.0 project = Project(test_dir) defs = project.search(string, **kwargs) assert [d.full_name for d in defs] == full_names