diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index 7531598a..a4ec5363 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -505,21 +505,23 @@ class Script(object): quite hard to do for Jedi, if it is too complicated, Jedi will stop searching. - :param include_builtins: Default True, checks if a reference is a - builtin (e.g. ``sys``) and in that case does not return it. + :param include_builtins: Default ``True``. If ``False`` checks if a reference + is a builtin (e.g. ``sys``) and in that case does not return it. + :param all_scopes: Default ``True``. If ``False`` include references in + the current module only. :rtype: list of :class:`.Name` """ - def _references(include_builtins=True): + def _references(include_builtins=True, all_scopes=True): tree_name = self._module_node.get_name_of_position((line, column)) if tree_name is None: # Must be syntax return [] - names = find_references(self._get_module_context(), tree_name) + names = find_references(self._get_module_context(), tree_name, all_scopes) definitions = [classes.Name(self._inference_state, n) for n in names] - if not include_builtins: + if not (include_builtins and all_scopes): definitions = [d for d in definitions if not d.in_builtin_module()] return helpers.sorted_definitions(definitions) return _references(**kwargs) diff --git a/jedi/inference/names.py b/jedi/inference/names.py index ccf42d4c..bc6d494f 100644 --- a/jedi/inference/names.py +++ b/jedi/inference/names.py @@ -138,7 +138,7 @@ class AbstractTreeName(AbstractNameDefinition): return self.parent_context.get_value() # Might be None return None - def goto(self): + def goto(self, all_scopes=True): context = self.parent_context name = self.tree_name definition = name.get_definition(import_name_always=True) @@ -150,13 +150,13 @@ class AbstractTreeName(AbstractNameDefinition): is_simple_name = name.parent.type not in ('power', 'trailer') if is_simple_name: return [self] - elif type_ in ('import_from', 'import_name'): + elif all_scopes and type_ in ('import_from', 'import_name'): from jedi.inference.imports import goto_import module_names = goto_import(context, name) return module_names else: return [self] - else: + elif all_scopes: from jedi.inference.imports import follow_error_node_imports_if_possible values = follow_error_node_imports_if_possible(context, name) if values is not None: @@ -498,8 +498,8 @@ class _ActualTreeParamName(BaseTreeParamName): class AnonymousParamName(_ActualTreeParamName): @plugin_manager.decorate(name='goto_anonymous_param') - def goto(self): - return super(AnonymousParamName, self).goto() + def goto(self, **kwargs): + return super(AnonymousParamName, self).goto(**kwargs) @plugin_manager.decorate(name='infer_anonymous_param') def infer(self): diff --git a/jedi/inference/references.py b/jedi/inference/references.py index 5534db30..19a843f9 100644 --- a/jedi/inference/references.py +++ b/jedi/inference/references.py @@ -25,7 +25,7 @@ easily 100ms for bigger files. """ -def _resolve_names(definition_names, avoid_names=()): +def _resolve_names(definition_names, avoid_names=(), all_scopes=True): for name in definition_names: if name in avoid_names: # Avoiding recursions here, because goto on a module name lands @@ -37,7 +37,7 @@ def _resolve_names(definition_names, avoid_names=()): # names when importing something like `import foo.bar.baz`. yield name - if name.api_type == 'module': + if all_scopes and name.api_type == 'module': for n in _resolve_names(name.goto(), definition_names): yield n @@ -49,16 +49,17 @@ def _dictionarize(names): ) -def _find_defining_names(module_context, tree_name): - found_names = _find_names(module_context, tree_name) +def _find_defining_names(module_context, tree_name, all_scopes=True): + found_names = _find_names(module_context, tree_name, all_scopes=all_scopes) - for name in list(found_names): - # Convert from/to stubs, because those might also be usages. - found_names |= set(convert_names( - [name], - only_stubs=not name.get_root_context().is_stub(), - prefer_stub_to_compiled=False - )) + if all_scopes: + for name in list(found_names): + # Convert from/to stubs, because those might also be usages. + found_names |= set(convert_names( + [name], + only_stubs=not name.get_root_context().is_stub(), + prefer_stub_to_compiled=False + )) found_names |= set(_find_global_variables(found_names, tree_name.value)) for name in list(found_names): @@ -66,15 +67,15 @@ def _find_defining_names(module_context, tree_name): or name.tree_name.parent.type == 'trailer': continue found_names |= set(_add_names_in_same_context(name.parent_context, name.string_name)) - return set(_resolve_names(found_names)) + return set(_resolve_names(found_names, all_scopes=all_scopes)) -def _find_names(module_context, tree_name): +def _find_names(module_context, tree_name, all_scopes=True): name = module_context.create_name(tree_name) - found_names = set(name.goto()) + found_names = set(name.goto(all_scopes=all_scopes)) found_names.add(name) - return set(_resolve_names(found_names)) + return set(_resolve_names(found_names, all_scopes=all_scopes)) def _add_names_in_same_context(context, string_name): @@ -113,7 +114,7 @@ def _find_global_variables(names, search_name): yield n -def find_references(module_context, tree_name): +def find_references(module_context, tree_name, all_scopes=True): inf = module_context.inference_state search_name = tree_name.value @@ -121,17 +122,20 @@ def find_references(module_context, tree_name): # certain cases, we want both sides. try: inf.flow_analysis_enabled = False - found_names = _find_defining_names(module_context, tree_name) + found_names = _find_defining_names(module_context, tree_name, all_scopes=all_scopes) finally: inf.flow_analysis_enabled = True found_names_dct = _dictionarize(found_names) - module_contexts = set(d.get_root_context() for d in found_names) - module_contexts = [module_context] \ - + [m for m in module_contexts if m != module_context and m.tree_node is not None] + module_contexts = [module_context] + if all_scopes: + module_contexts.extend( + m for m in set(d.get_root_context() for d in found_names) + if m != module_context and m.tree_node is not None + ) # For param no search for other modules is necessary. - if any(n.api_type == 'param' for n in found_names): + if not all_scopes or any(n.api_type == 'param' for n in found_names): potential_modules = module_contexts else: potential_modules = get_module_contexts_containing_name( @@ -143,7 +147,7 @@ def find_references(module_context, tree_name): non_matching_reference_maps = {} for module_context in potential_modules: for name_leaf in module_context.tree_node.get_used_names().get(search_name, []): - new = _dictionarize(_find_names(module_context, name_leaf)) + new = _dictionarize(_find_names(module_context, name_leaf, all_scopes=all_scopes)) if any(tree_name in found_names_dct for tree_name in new): found_names_dct.update(new) for tree_name in new: diff --git a/jedi/plugins/pytest.py b/jedi/plugins/pytest.py index f9b04284..76754394 100644 --- a/jedi/plugins/pytest.py +++ b/jedi/plugins/pytest.py @@ -63,7 +63,7 @@ def infer_anonymous_param(func): def goto_anonymous_param(func): - def wrapper(param_name): + def wrapper(param_name, **kwargs): is_pytest_param, param_name_is_function_name = \ _is_a_pytest_param_and_inherited(param_name) if is_pytest_param: @@ -74,7 +74,7 @@ def goto_anonymous_param(func): ) if names: return names - return func(param_name) + return func(param_name, **kwargs) return wrapper diff --git a/test/test_api/test_usages.py b/test/test_api/test_usages.py index 188f893d..f7ef9c57 100644 --- a/test/test_api/test_usages.py +++ b/test/test_api/test_usages.py @@ -15,3 +15,59 @@ def test_exclude_builtin_modules(Script): places = get(include=False) assert places == [(1, 7), (2, 6)] + + +def test_references_scope(Script): + from jedi.api.project import Project + project = Project('', sys_path=[], smart_sys_path=False) + script = Script( + '''import sys +from collections import defaultdict + +print(sys.path) + +def foo(bar): + baz = defaultdict(int) + return baz + +def bar(foo): + baz = defaultdict(int) + return baz + +foo() +''', project=project) + + def r(*args): + return script.get_references(all_scopes=False, *args) + + print(script._code_lines) + sys_places = r(1, 7) + assert len(sys_places) == 2 + assert sys_places == r(4, 6) + + assert len(r(2, 5)) == 1 + + dd_places = r(2, 24) + assert len(dd_places) == 3 + assert dd_places == r(7, 10) + assert dd_places == r(11, 10) + + foo_places = r(6, 4) + assert len(foo_places) == 2 + assert foo_places == r(14, 0) + + baz_places = r(7, 4) + assert len(baz_places) == 2 + assert baz_places == r(8, 11) + + int_places = r(7, 22) + assert len(int_places) == 2 + assert int_places == r(11, 22) + + baz_places = r(11, 4) + assert len(baz_places) == 2 + assert baz_places == r(12, 11) + + script = Script('from datetime', project=project) + places = r(1, 5) + assert len(places) == 1