Some more refactoring for relative imports

This commit is contained in:
Dave Halter
2019-03-08 10:54:28 +01:00
parent c1d65ff144
commit 6b579d53ec
2 changed files with 87 additions and 31 deletions

View File

@@ -193,6 +193,33 @@ class OsPathName(ImportName):
return self.parent_context.evaluator.import_module(('os', 'path')) return self.parent_context.evaluator.import_module(('os', 'path'))
def _level_to_base_import_path(project_path, directory, level):
"""
In case the level is outside of the currently known package (something like
import .....foo), we can still try our best to help the user for
completions.
"""
for i in range(level - 1):
old = directory
directory = os.path.dirname(directory)
if old == directory:
return None, None
d = directory
level_import_paths = []
# Now that we are on the level that the user wants to be, calculate the
# import path for it.
while True:
if d == project_path:
return level_import_paths, None
dir_name = os.path.basename(d)
if dir_name:
level_import_paths.insert(0, dir_name)
d = os.path.dirname(d)
else:
return None, d
class Importer(object): class Importer(object):
def __init__(self, evaluator, import_path, module_context, level=0): def __init__(self, evaluator, import_path, module_context, level=0):
""" """
@@ -217,6 +244,8 @@ class Importer(object):
# Can be None for certain compiled modules like 'builtins'. # Can be None for certain compiled modules like 'builtins'.
self.file_path = None self.file_path = None
self._fixed_sys_path = None
self._inference_possible = True
if level: if level:
base = module_context.py__package__().split('.') base = module_context.py__package__().split('.')
if base == [''] or base == ['__main__']: if base == [''] or base == ['__main__']:
@@ -247,36 +276,37 @@ class Importer(object):
directory = os.getcwd() directory = os.getcwd()
else: else:
directory = os.path.dirname(path) directory = os.path.dirname(path)
level_import_paths = []
for i in range(level - 1): base_import_path, base_directory = _level_to_base_import_path(
directory = os.path.dirname(directory) self._evaluator.project._path, directory, level,
while directory != os.path.dirname(directory): )
if directory == self._evaluator.project._path: if base_import_path is None:
break if import_path:
dir_name = os.path.basename(directory)
if dir_name:
level_import_paths.insert(0, dir_name)
directory = os.path.dirname(directory)
else:
_add_error( _add_error(
module_context, import_path[-1], module_context, import_path[0],
message='Attempted relative import beyond top-level package.' message='Attempted relative import beyond top-level package.'
) )
self.import_path = [] if base_directory is None:
return # Everything is lost, the relative import does point
import_path = level_import_paths + import_path # somewhere out of the filesystem.
self._inference_possible = False
else:
self._fixed_sys_path = [base_directory]
else:
import_path = base_import_path + import_path
self.import_path = import_path self.import_path = import_path
@property @property
def str_import_path(self): def _str_import_path(self):
"""Returns the import path as pure strings instead of `Name`.""" """Returns the import path as pure strings instead of `Name`."""
return tuple( return tuple(
name.value if isinstance(name, tree.Name) else name name.value if isinstance(name, tree.Name) else name
for name in self.import_path for name in self.import_path
) )
def sys_path_with_modifications(self): def _sys_path_with_modifications(self):
if self._fixed_sys_path is not None:
return self._fixed_sys_path
sys_path_mod = ( sys_path_mod = (
self._evaluator.get_sys_path() self._evaluator.get_sys_path()
@@ -291,7 +321,8 @@ class Importer(object):
return sys_path_mod return sys_path_mod
def follow(self): def follow(self):
if not self.import_path or not self._evaluator.infer_enabled: if not self.import_path or not self._evaluator.infer_enabled \
or not self._inference_possible:
return NO_CONTEXTS return NO_CONTEXTS
import_names = tuple( import_names = tuple(
@@ -306,7 +337,7 @@ class Importer(object):
self._evaluator.import_module( self._evaluator.import_module(
import_names[:i+1], import_names[:i+1],
parent_module_context, parent_module_context,
self.sys_path_with_modifications(), self._sys_path_with_modifications(),
) )
for parent_module_context in context_set for parent_module_context in context_set
]) ])
@@ -329,7 +360,7 @@ class Importer(object):
for name in sub.get_builtin_module_names()] for name in sub.get_builtin_module_names()]
if search_path is None: if search_path is None:
search_path = self.sys_path_with_modifications() search_path = self._sys_path_with_modifications()
for name in sub.list_module_names(search_path): for name in sub.list_module_names(search_path):
if in_module is None: if in_module is None:
@@ -347,7 +378,7 @@ class Importer(object):
names = [] names = []
if self.import_path: if self.import_path:
# flask # flask
if self.str_import_path == ('flask', 'ext'): if self._str_import_path == ('flask', 'ext'):
# List Flask extensions like ``flask_foo`` # List Flask extensions like ``flask_foo``
for mod in self._get_module_names(): for mod in self._get_module_names():
modname = mod.string_name modname = mod.string_name
@@ -355,7 +386,7 @@ class Importer(object):
extname = modname[len('flask_'):] extname = modname[len('flask_'):]
names.append(ImportName(self.module_context, extname)) names.append(ImportName(self.module_context, extname))
# Now the old style: ``flaskext.foo`` # Now the old style: ``flaskext.foo``
for dir in self.sys_path_with_modifications(): for dir in self._sys_path_with_modifications():
flaskext = os.path.join(dir, 'flaskext') flaskext = os.path.join(dir, 'flaskext')
if os.path.isdir(flaskext): if os.path.isdir(flaskext):
names += self._get_module_names([flaskext]) names += self._get_module_names([flaskext])
@@ -376,7 +407,7 @@ class Importer(object):
if only_modules: if only_modules:
# In the case of an import like `from x.` we don't need to # In the case of an import like `from x.` we don't need to
# add all the variables. # add all the variables.
if ('os',) == self.str_import_path and not self.level: if ('os',) == self._str_import_path and not self.level:
# os.path is a hardcoded exception, because it's a # os.path is a hardcoded exception, because it's a
# ``sys.modules`` modification. # ``sys.modules`` modification.
names.append(OsPathName(context, 'path')) names.append(OsPathName(context, 'path'))
@@ -386,15 +417,15 @@ class Importer(object):
names += filter.values() names += filter.values()
else: else:
# Empty import path=completion after import # Empty import path=completion after import
if not self.level: if self.level:
names += self._get_module_names()
if self.file_path is not None: if self.file_path is not None:
path = os.path.abspath(self.file_path) path = self.file_path
for i in range(self.level - 1): for i in range(self.level):
path = os.path.dirname(path) path = os.path.dirname(path)
raise 1
names += self._get_module_names([path]) names += self._get_module_names([path])
else:
names += self._get_module_names()
return names return names

View File

@@ -308,3 +308,28 @@ def test_relative_imports_without_path(Script):
script = Script("from ... ") script = Script("from ... ")
assert [c.name for c in script.completions()] == ['api', 'import', 'whatever'] assert [c.name for c in script.completions()] == ['api', 'import', 'whatever']
def test_relative_import_out_of_file_system(Script):
script = Script("from " + '.' * 100)
import_, = script.completions()
assert import_.name == 'import'
script = Script("from " + '.' * 100 + 'abc import ABCMeta')
assert not script.goto_definitions()
assert not script.completions()
@pytest.mark.parametrize(
'level, directory, project_path, result', [
(1, '/a/b/c', '/a', (['b', 'c'], None)),
(2, '/a/b/c', '/a', (['b'], None)),
(3, '/a/b/c', '/a', ([], None)),
(4, '/a/b/c', '/a', (None, '/')),
(5, '/a/b/c', '/a', (None, None)),
(1, '/', '/', ([], None)),
(2, '/', '/', (None, None)),
]
)
def test_level_to_import_path(level, directory, project_path, result):
assert imports._level_to_base_import_path(project_path, directory, level) == result