1
0
forked from VimPlug/jedi

Get references in the current module only

This commit is contained in:
muffinmad
2020-06-03 16:35:28 +03:00
parent f9bbccbc13
commit 782dedd439
5 changed files with 40 additions and 59 deletions

View File

@@ -505,23 +505,23 @@ class Script(object):
quite hard to do for Jedi, if it is too complicated, Jedi will stop quite hard to do for Jedi, if it is too complicated, Jedi will stop
searching. searching.
:param include_builtins: Default ``True``. If ``False`` checks if a reference :param include_builtins: Default ``True``. If ``False``, checks if a reference
is a builtin (e.g. ``sys``) and in that case does not return it. is a builtin (e.g. ``sys``) and in that case does not return it.
:param all_scopes: Default ``True``. If ``False`` include references in :param scope: Default ``'project'``. If ``'file'``, include references in
the current module only. the current module only.
:rtype: list of :class:`.Name` :rtype: list of :class:`.Name`
""" """
def _references(include_builtins=True, all_scopes=True): def _references(include_builtins=True, scope='project'):
tree_name = self._module_node.get_name_of_position((line, column)) tree_name = self._module_node.get_name_of_position((line, column))
if tree_name is None: if tree_name is None:
# Must be syntax # Must be syntax
return [] return []
names = find_references(self._get_module_context(), tree_name, all_scopes) names = find_references(self._get_module_context(), tree_name, scope == 'file')
definitions = [classes.Name(self._inference_state, n) for n in names] definitions = [classes.Name(self._inference_state, n) for n in names]
if not (include_builtins and all_scopes): if not include_builtins or scope == 'file':
definitions = [d for d in definitions if not d.in_builtin_module()] definitions = [d for d in definitions if not d.in_builtin_module()]
return helpers.sorted_definitions(definitions) return helpers.sorted_definitions(definitions)
return _references(**kwargs) return _references(**kwargs)

View File

@@ -138,7 +138,7 @@ class AbstractTreeName(AbstractNameDefinition):
return self.parent_context.get_value() # Might be None return self.parent_context.get_value() # Might be None
return None return None
def goto(self, all_scopes=True): def goto(self):
context = self.parent_context context = self.parent_context
name = self.tree_name name = self.tree_name
definition = name.get_definition(import_name_always=True) definition = name.get_definition(import_name_always=True)
@@ -150,13 +150,13 @@ class AbstractTreeName(AbstractNameDefinition):
is_simple_name = name.parent.type not in ('power', 'trailer') is_simple_name = name.parent.type not in ('power', 'trailer')
if is_simple_name: if is_simple_name:
return [self] return [self]
elif all_scopes and type_ in ('import_from', 'import_name'): elif type_ in ('import_from', 'import_name'):
from jedi.inference.imports import goto_import from jedi.inference.imports import goto_import
module_names = goto_import(context, name) module_names = goto_import(context, name)
return module_names return module_names
else: else:
return [self] return [self]
elif all_scopes: else:
from jedi.inference.imports import follow_error_node_imports_if_possible from jedi.inference.imports import follow_error_node_imports_if_possible
values = follow_error_node_imports_if_possible(context, name) values = follow_error_node_imports_if_possible(context, name)
if values is not None: if values is not None:
@@ -166,8 +166,6 @@ class AbstractTreeName(AbstractNameDefinition):
node_type = par.type node_type = par.type
if node_type == 'argument' and par.children[1] == '=' and par.children[0] == name: if node_type == 'argument' and par.children[1] == '=' and par.children[0] == name:
# Named param goto. # Named param goto.
if not all_scopes:
return [self]
trailer = par.parent trailer = par.parent
if trailer.type == 'arglist': if trailer.type == 'arglist':
trailer = trailer.parent trailer = trailer.parent
@@ -202,8 +200,6 @@ class AbstractTreeName(AbstractNameDefinition):
) )
if node_type == 'trailer' and par.children[0] == '.': if node_type == 'trailer' and par.children[0] == '.':
if not all_scopes:
return [self]
values = infer_call_of_leaf(context, name, cut_own_trailer=True) values = infer_call_of_leaf(context, name, cut_own_trailer=True)
return values.goto(name, name_context=context) return values.goto(name, name_context=context)
else: else:
@@ -502,8 +498,8 @@ class _ActualTreeParamName(BaseTreeParamName):
class AnonymousParamName(_ActualTreeParamName): class AnonymousParamName(_ActualTreeParamName):
@plugin_manager.decorate(name='goto_anonymous_param') @plugin_manager.decorate(name='goto_anonymous_param')
def goto(self, **kwargs): def goto(self):
return super(AnonymousParamName, self).goto(**kwargs) return super(AnonymousParamName, self).goto()
@plugin_manager.decorate(name='infer_anonymous_param') @plugin_manager.decorate(name='infer_anonymous_param')
def infer(self): def infer(self):

View File

@@ -25,7 +25,7 @@ easily 100ms for bigger files.
""" """
def _resolve_names(definition_names, avoid_names=(), all_scopes=True): def _resolve_names(definition_names, avoid_names=()):
for name in definition_names: for name in definition_names:
if name in avoid_names: if name in avoid_names:
# Avoiding recursions here, because goto on a module name lands # Avoiding recursions here, because goto on a module name lands
@@ -37,7 +37,7 @@ def _resolve_names(definition_names, avoid_names=(), all_scopes=True):
# names when importing something like `import foo.bar.baz`. # names when importing something like `import foo.bar.baz`.
yield name yield name
if all_scopes and name.api_type == 'module': if name.api_type == 'module':
for n in _resolve_names(name.goto(), definition_names): for n in _resolve_names(name.goto(), definition_names):
yield n yield n
@@ -49,10 +49,9 @@ def _dictionarize(names):
) )
def _find_defining_names(module_context, tree_name, all_scopes=True): def _find_defining_names(module_context, tree_name):
found_names = _find_names(module_context, tree_name, all_scopes=all_scopes) found_names = _find_names(module_context, tree_name)
if all_scopes:
for name in list(found_names): for name in list(found_names):
# Convert from/to stubs, because those might also be usages. # Convert from/to stubs, because those might also be usages.
found_names |= set(convert_names( found_names |= set(convert_names(
@@ -67,15 +66,15 @@ def _find_defining_names(module_context, tree_name, all_scopes=True):
or name.tree_name.parent.type == 'trailer': or name.tree_name.parent.type == 'trailer':
continue continue
found_names |= set(_add_names_in_same_context(name.parent_context, name.string_name)) found_names |= set(_add_names_in_same_context(name.parent_context, name.string_name))
return set(_resolve_names(found_names, all_scopes=all_scopes)) return set(_resolve_names(found_names))
def _find_names(module_context, tree_name, all_scopes=True): def _find_names(module_context, tree_name):
name = module_context.create_name(tree_name) name = module_context.create_name(tree_name)
found_names = set(name.goto(all_scopes=all_scopes)) found_names = set(name.goto())
found_names.add(name) found_names.add(name)
return set(_resolve_names(found_names, all_scopes=all_scopes)) return set(_resolve_names(found_names))
def _add_names_in_same_context(context, string_name): def _add_names_in_same_context(context, string_name):
@@ -114,7 +113,7 @@ def _find_global_variables(names, search_name):
yield n yield n
def find_references(module_context, tree_name, all_scopes=True): def find_references(module_context, tree_name, only_in_module=False):
inf = module_context.inference_state inf = module_context.inference_state
search_name = tree_name.value search_name = tree_name.value
@@ -122,20 +121,20 @@ def find_references(module_context, tree_name, all_scopes=True):
# certain cases, we want both sides. # certain cases, we want both sides.
try: try:
inf.flow_analysis_enabled = False inf.flow_analysis_enabled = False
found_names = _find_defining_names(module_context, tree_name, all_scopes=all_scopes) found_names = _find_defining_names(module_context, tree_name)
finally: finally:
inf.flow_analysis_enabled = True inf.flow_analysis_enabled = True
found_names_dct = _dictionarize(found_names) found_names_dct = _dictionarize(found_names)
module_contexts = [module_context] module_contexts = [module_context]
if all_scopes: if not only_in_module:
module_contexts.extend( module_contexts.extend(
m for m in set(d.get_root_context() for d in found_names) m for m in set(d.get_root_context() for d in found_names)
if m != module_context and m.tree_node is not None if m != module_context and m.tree_node is not None
) )
# For param no search for other modules is necessary. # For param no search for other modules is necessary.
if not all_scopes or any(n.api_type == 'param' for n in found_names): if only_in_module or any(n.api_type == 'param' for n in found_names):
potential_modules = module_contexts potential_modules = module_contexts
else: else:
potential_modules = get_module_contexts_containing_name( potential_modules = get_module_contexts_containing_name(
@@ -144,15 +143,10 @@ def find_references(module_context, tree_name, all_scopes=True):
search_name, search_name,
) )
gotos = set()
if not all_scopes:
for name in found_names_dct.values():
gotos |= set(name.goto())
non_matching_reference_maps = {} non_matching_reference_maps = {}
for module_context in potential_modules: for module_context in potential_modules:
for name_leaf in module_context.tree_node.get_used_names().get(search_name, []): for name_leaf in module_context.tree_node.get_used_names().get(search_name, []):
new = _dictionarize(_find_names(module_context, name_leaf, all_scopes=all_scopes)) new = _dictionarize(_find_names(module_context, name_leaf))
if any(tree_name in found_names_dct for tree_name in new): if any(tree_name in found_names_dct for tree_name in new):
found_names_dct.update(new) found_names_dct.update(new)
for tree_name in new: for tree_name in new:
@@ -167,19 +161,10 @@ def find_references(module_context, tree_name, all_scopes=True):
else: else:
for name in new: for name in new:
non_matching_reference_maps.setdefault(name, []).append(new) non_matching_reference_maps.setdefault(name, []).append(new)
result = found_names_dct.values()
if not all_scopes: if only_in_module:
def in_gotos(g): return [n for n in result if n.get_root_context() == module_context]
for g_ in gotos: return result
if g.start_pos == g_.start_pos and g.get_root_context() == g_.get_root_context():
return True
for dct_list in non_matching_reference_maps.values():
for dct in dct_list:
for k, v in dct.items():
if any(in_gotos(g) for g in v.goto()):
found_names_dct[k] = v
return found_names_dct.values()
def _check_fs(inference_state, file_io, regex): def _check_fs(inference_state, file_io, regex):

View File

@@ -63,7 +63,7 @@ def infer_anonymous_param(func):
def goto_anonymous_param(func): def goto_anonymous_param(func):
def wrapper(param_name, **kwargs): def wrapper(param_name):
is_pytest_param, param_name_is_function_name = \ is_pytest_param, param_name_is_function_name = \
_is_a_pytest_param_and_inherited(param_name) _is_a_pytest_param_and_inherited(param_name)
if is_pytest_param: if is_pytest_param:
@@ -74,7 +74,7 @@ def goto_anonymous_param(func):
) )
if names: if names:
return names return names
return func(param_name, **kwargs) return func(param_name)
return wrapper return wrapper

View File

@@ -38,7 +38,7 @@ foo()
''', project=project) ''', project=project)
def r(*args): def r(*args):
return script.get_references(all_scopes=False, *args) return script.get_references(scope='file', *args)
print(script._code_lines) print(script._code_lines)
sys_places = r(1, 7) sys_places = r(1, 7)
@@ -79,9 +79,9 @@ def test_local_references_method_other_file(Script):
d1 = datetime.now() d1 = datetime.now()
d2 = datetime.now() d2 = datetime.now()
''', project=Project('', sys_path=[], smart_sys_path=False)) ''', project=Project('', sys_path=[], smart_sys_path=False))
now_places = script.get_references(2, 14, all_scopes=False) now_places = script.get_references(2, 14, scope='file')
assert len(now_places) == 2 assert len(now_places) == 2
assert now_places == script.get_references(3, 14, all_scopes=False) assert now_places == script.get_references(3, 14, scope='file')
def test_local_references_kwarg(Script): def test_local_references_kwarg(Script):
@@ -89,4 +89,4 @@ def test_local_references_kwarg(Script):
script = Script('''from jedi import Script script = Script('''from jedi import Script
Script(code='') Script(code='')
''', project=Project('', sys_path=[], smart_sys_path=False)) ''', project=Project('', sys_path=[], smart_sys_path=False))
assert len(script.get_references(2, 7, all_scopes=False)) == 1 assert len(script.get_references(2, 7, scope='file')) == 1