Follow error node imports properly in goto assignments as well

This commit is contained in:
Dave Halter
2019-03-27 00:53:35 +01:00
parent f4c17e578c
commit 1c105b5c68
2 changed files with 28 additions and 9 deletions

View File

@@ -290,18 +290,27 @@ class Evaluator(object):
if type_ in ('import_from', 'import_name'): if type_ in ('import_from', 'import_name'):
return imports.infer_import(context, name) return imports.infer_import(context, name)
error_node = tree.search_ancestor(name, 'error_node') result = self.follow_error_node_imports_if_possible(context, name)
if error_node is not None: if result is not None:
first_name = error_node.get_first_leaf().value return result
if first_name == 'from':
level, names = helpers.parse_dotted_names(
error_node.children,
is_import_from=True
)
return imports.Importer(self, names, context.get_root_context(), level).follow()
return helpers.evaluate_call_of_leaf(context, name) return helpers.evaluate_call_of_leaf(context, name)
def follow_error_node_imports_if_possible(self, context, name):
error_node = tree.search_ancestor(name, 'error_node')
if error_node is not None:
# Make it possible to infer stuff like `import foo.` or
# `from foo.bar`.
first_name = error_node.get_first_leaf().value
if first_name in ('from', 'import'):
is_import_from = first_name == 'from'
level, names = helpers.parse_dotted_names(
error_node.children,
is_import_from=is_import_from
)
return imports.Importer(self, names, context.get_root_context(), level).follow()
return None
def goto(self, context, name): def goto(self, context, name):
definition = name.get_definition(import_name_always=True) definition = name.get_definition(import_name_always=True)
if definition is not None: if definition is not None:
@@ -320,6 +329,10 @@ class Evaluator(object):
module_names = imports.infer_import(context, name, is_goto=True) module_names = imports.infer_import(context, name, is_goto=True)
return module_names return module_names
contexts = self.follow_error_node_imports_if_possible(context, name)
if contexts is not None:
return [context.name for context in contexts]
par = name.parent par = name.parent
node_type = par.type node_type = par.type
if node_type == 'argument' and par.children[1] == '=' and par.children[0] == name: if node_type == 'argument' and par.children[1] == '=' and par.children[0] == name:

View File

@@ -119,6 +119,12 @@ def broken_import():
#? import_tree.mod1 #? import_tree.mod1
from import_tree.mod1 from import_tree.mod1
#? 20 import_tree.mod1
import import_tree.mod1.
#! ['module json']
from json
def test_import_priorities(): def test_import_priorities():
""" """