Some more refactoring for relative imports

This commit is contained in:
Dave Halter
2019-03-08 10:54:28 +01:00
parent c1d65ff144
commit 6b579d53ec
2 changed files with 87 additions and 31 deletions

View File

@@ -193,6 +193,33 @@ class OsPathName(ImportName):
return self.parent_context.evaluator.import_module(('os', 'path'))
def _level_to_base_import_path(project_path, directory, level):
"""
In case the level is outside of the currently known package (something like
import .....foo), we can still try our best to help the user for
completions.
"""
for i in range(level - 1):
old = directory
directory = os.path.dirname(directory)
if old == directory:
return None, None
d = directory
level_import_paths = []
# Now that we are on the level that the user wants to be, calculate the
# import path for it.
while True:
if d == project_path:
return level_import_paths, None
dir_name = os.path.basename(d)
if dir_name:
level_import_paths.insert(0, dir_name)
d = os.path.dirname(d)
else:
return None, d
class Importer(object):
def __init__(self, evaluator, import_path, module_context, level=0):
"""
@@ -217,6 +244,8 @@ class Importer(object):
# Can be None for certain compiled modules like 'builtins'.
self.file_path = None
self._fixed_sys_path = None
self._inference_possible = True
if level:
base = module_context.py__package__().split('.')
if base == [''] or base == ['__main__']:
@@ -247,36 +276,37 @@ class Importer(object):
directory = os.getcwd()
else:
directory = os.path.dirname(path)
level_import_paths = []
for i in range(level - 1):
directory = os.path.dirname(directory)
while directory != os.path.dirname(directory):
if directory == self._evaluator.project._path:
break
dir_name = os.path.basename(directory)
if dir_name:
level_import_paths.insert(0, dir_name)
directory = os.path.dirname(directory)
else:
base_import_path, base_directory = _level_to_base_import_path(
self._evaluator.project._path, directory, level,
)
if base_import_path is None:
if import_path:
_add_error(
module_context, import_path[-1],
module_context, import_path[0],
message='Attempted relative import beyond top-level package.'
)
self.import_path = []
return
import_path = level_import_paths + import_path
if base_directory is None:
# Everything is lost, the relative import does point
# somewhere out of the filesystem.
self._inference_possible = False
else:
self._fixed_sys_path = [base_directory]
else:
import_path = base_import_path + import_path
self.import_path = import_path
@property
def str_import_path(self):
def _str_import_path(self):
"""Returns the import path as pure strings instead of `Name`."""
return tuple(
name.value if isinstance(name, tree.Name) else name
for name in self.import_path
)
def sys_path_with_modifications(self):
def _sys_path_with_modifications(self):
if self._fixed_sys_path is not None:
return self._fixed_sys_path
sys_path_mod = (
self._evaluator.get_sys_path()
@@ -291,7 +321,8 @@ class Importer(object):
return sys_path_mod
def follow(self):
if not self.import_path or not self._evaluator.infer_enabled:
if not self.import_path or not self._evaluator.infer_enabled \
or not self._inference_possible:
return NO_CONTEXTS
import_names = tuple(
@@ -306,7 +337,7 @@ class Importer(object):
self._evaluator.import_module(
import_names[:i+1],
parent_module_context,
self.sys_path_with_modifications(),
self._sys_path_with_modifications(),
)
for parent_module_context in context_set
])
@@ -329,7 +360,7 @@ class Importer(object):
for name in sub.get_builtin_module_names()]
if search_path is None:
search_path = self.sys_path_with_modifications()
search_path = self._sys_path_with_modifications()
for name in sub.list_module_names(search_path):
if in_module is None:
@@ -347,7 +378,7 @@ class Importer(object):
names = []
if self.import_path:
# flask
if self.str_import_path == ('flask', 'ext'):
if self._str_import_path == ('flask', 'ext'):
# List Flask extensions like ``flask_foo``
for mod in self._get_module_names():
modname = mod.string_name
@@ -355,7 +386,7 @@ class Importer(object):
extname = modname[len('flask_'):]
names.append(ImportName(self.module_context, extname))
# Now the old style: ``flaskext.foo``
for dir in self.sys_path_with_modifications():
for dir in self._sys_path_with_modifications():
flaskext = os.path.join(dir, 'flaskext')
if os.path.isdir(flaskext):
names += self._get_module_names([flaskext])
@@ -376,7 +407,7 @@ class Importer(object):
if only_modules:
# In the case of an import like `from x.` we don't need to
# add all the variables.
if ('os',) == self.str_import_path and not self.level:
if ('os',) == self._str_import_path and not self.level:
# os.path is a hardcoded exception, because it's a
# ``sys.modules`` modification.
names.append(OsPathName(context, 'path'))
@@ -386,15 +417,15 @@ class Importer(object):
names += filter.values()
else:
# Empty import path=completion after import
if not self.level:
names += self._get_module_names()
if self.level:
if self.file_path is not None:
path = os.path.abspath(self.file_path)
for i in range(self.level - 1):
path = self.file_path
for i in range(self.level):
path = os.path.dirname(path)
raise 1
names += self._get_module_names([path])
else:
names += self._get_module_names()
return names

View File

@@ -308,3 +308,28 @@ def test_relative_imports_without_path(Script):
script = Script("from ... ")
assert [c.name for c in script.completions()] == ['api', 'import', 'whatever']
def test_relative_import_out_of_file_system(Script):
script = Script("from " + '.' * 100)
import_, = script.completions()
assert import_.name == 'import'
script = Script("from " + '.' * 100 + 'abc import ABCMeta')
assert not script.goto_definitions()
assert not script.completions()
@pytest.mark.parametrize(
'level, directory, project_path, result', [
(1, '/a/b/c', '/a', (['b', 'c'], None)),
(2, '/a/b/c', '/a', (['b'], None)),
(3, '/a/b/c', '/a', ([], None)),
(4, '/a/b/c', '/a', (None, '/')),
(5, '/a/b/c', '/a', (None, None)),
(1, '/', '/', ([], None)),
(2, '/', '/', (None, None)),
]
)
def test_level_to_import_path(level, directory, project_path, result):
assert imports._level_to_base_import_path(project_path, directory, level) == result