diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index 3a4d2d8d..7368b228 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -223,9 +223,9 @@ class Script(object): :rtype: list of :class:`classes.Definition` """ - def filter_follow_imports(names, follow_classes): + def filter_follow_imports(names, check): for name in names: - if isinstance(name, follow_classes): + if check(name): for context in name.infer(): yield context.name else: @@ -233,13 +233,12 @@ class Script(object): names = self._goto() if follow_imports: - # TODO really, sure? TreeNameDefinition? Should probably not follow - # that. - follow_classes = (imports.ImportName, TreeNameDefinition) + names = filter_follow_imports(names, lambda name: name.api_type == 'module') else: - follow_classes = (imports.SubModuleName,) - - names = filter_follow_imports(names, follow_classes) + names = filter_follow_imports( + names, + lambda name: isinstance(name, imports.SubModuleName) + ) defs = [classes.Definition(self._evaluator, d) for d in set(names)] return helpers.sorted_definitions(defs) diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 5674fd87..a8064d7c 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -207,6 +207,9 @@ def test_goto_assignments_follow_imports(): definition, = script.goto_assignments() assert (definition.line, definition.column) == start_pos + d, = api.Script('a = 1\na').goto_assignments(follow_imports=True) + assert d.name == 'a' + def test_goto_module(): def check(line, expected):