diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index 160f57eb..f09050bc 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -505,21 +505,23 @@ class Script(object): quite hard to do for Jedi, if it is too complicated, Jedi will stop searching. - :param include_builtins: Default True, checks if a reference is a - builtin (e.g. ``sys``) and in that case does not return it. + :param include_builtins: Default ``True``. If ``False``, checks if a reference + is a builtin (e.g. ``sys``) and in that case does not return it. + :param scope: Default ``'project'``. If ``'file'``, include references in + the current module only. :rtype: list of :class:`.Name` """ - def _references(include_builtins=True): + def _references(include_builtins=True, scope='project'): tree_name = self._module_node.get_name_of_position((line, column)) if tree_name is None: # Must be syntax return [] - names = find_references(self._get_module_context(), tree_name) + names = find_references(self._get_module_context(), tree_name, scope == 'file') definitions = [classes.Name(self._inference_state, n) for n in names] - if not include_builtins: + if not include_builtins or scope == 'file': definitions = [d for d in definitions if not d.in_builtin_module()] return helpers.sorted_definitions(definitions) return _references(**kwargs) diff --git a/jedi/inference/references.py b/jedi/inference/references.py index 4b90dd74..4a1321ed 100644 --- a/jedi/inference/references.py +++ b/jedi/inference/references.py @@ -114,7 +114,7 @@ def _find_global_variables(names, search_name): yield n -def find_references(module_context, tree_name): +def find_references(module_context, tree_name, only_in_module=False): inf = module_context.inference_state search_name = tree_name.value @@ -128,11 +128,14 @@ def find_references(module_context, tree_name): found_names_dct = _dictionarize(found_names) - module_contexts = set(d.get_root_context() for d in found_names) - module_contexts = [module_context] \ - + [m for m in module_contexts if m != module_context and m.tree_node is not None] + module_contexts = [module_context] + if not only_in_module: + module_contexts.extend( + m for m in set(d.get_root_context() for d in found_names) + if m != module_context and m.tree_node is not None + ) # For param no search for other modules is necessary. - if any(n.api_type == 'param' for n in found_names): + if only_in_module or any(n.api_type == 'param' for n in found_names): potential_modules = module_contexts else: potential_modules = get_module_contexts_containing_name( @@ -159,7 +162,10 @@ def find_references(module_context, tree_name): else: for name in new: non_matching_reference_maps.setdefault(name, []).append(new) - return found_names_dct.values() + result = found_names_dct.values() + if only_in_module: + return [n for n in result if n.get_root_context() == module_context] + return result def _check_fs(inference_state, file_io, regex): diff --git a/setup.cfg b/setup.cfg index dfc0c6d0..39633dc1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,3 +14,6 @@ ignore = # Line break before binary operator W503, exclude = jedi/third_party/* .tox/* + +[pycodestyle] +max-line-length = 100 diff --git a/test/test_api/test_usages.py b/test/test_api/test_usages.py index 188f893d..e3868333 100644 --- a/test/test_api/test_usages.py +++ b/test/test_api/test_usages.py @@ -1,3 +1,6 @@ +import pytest + + def test_import_references(Script): s = Script("from .. import foo", path="foo.py") assert [usage.line for usage in s.get_references(line=1, column=18)] == [1] @@ -15,3 +18,46 @@ def test_exclude_builtin_modules(Script): places = get(include=False) assert places == [(1, 7), (2, 6)] + + +@pytest.mark.parametrize('code, places', [ + ('', [(1, 7), (4, 6)]), + ('', [(2, 5)]), + ('', [(2, 24), (7, 10), (11, 10)]), + ('', [(6, 4), (14, 0)]), + ('', [(7, 4), (8, 11)]), + ('', [(7, 22), (11, 22)]), + ('', [(11, 4), (12, 11)]), + ('from datetime', [(1, 5)]), + ('''from datetime import datetime +d1 = datetime.now() +d2 = datetime.now() +''', [(2, 14), (3, 14)]), + ('''from datetime import timedelta +t1 = timedelta(seconds=1) +t2 = timedelta(seconds=2) +''', [(2, 15), (3, 15)]) +]) +def test_references_scope(Script, code, places): + if not code: + code = '''import sys +from collections import defaultdict + +print(sys.path) + +def foo(bar): + baz = defaultdict(int) + return baz + +def bar(foo): + baz = defaultdict(int) + return baz + +foo() +''' + from jedi.api.project import Project + project = Project('', sys_path=[], smart_sys_path=False) + script = Script(code, project=project) + + for place in places: + assert places == [(n.line, n.column) for n in script.get_references(scope='file', *place)]