Make sure to not load unsafe modules anymore if they are not on the sys path, fixes #760

This commit is contained in:
Dave Halter
2020-01-31 13:09:28 +01:00
parent e7a77e438d
commit 8ff2ea4b38
6 changed files with 56 additions and 23 deletions

View File

@@ -60,6 +60,11 @@ class Project(object):
:param path: The base path for this project. :param path: The base path for this project.
:param python_path: The Python executable path, typically the path of a :param python_path: The Python executable path, typically the path of a
virtual environment. virtual environment.
:param load_unsafe_extensions: Loads extensions that are not in the
sys path and in the local directories. With this option enabled,
this is potentially unsafe if you clone a git repository and
analyze it's code, because those compiled extensions will be
important and therefore have execution privileges.
:param sys_path: list of str. You can override the sys path if you :param sys_path: list of str. You can override the sys path if you
want. By default the ``sys.path.`` is generated from the want. By default the ``sys.path.`` is generated from the
environment (virtualenvs, etc). environment (virtualenvs, etc).
@@ -69,13 +74,14 @@ class Project(object):
local directories. Otherwise you will have to rely on your packages local directories. Otherwise you will have to rely on your packages
being properly configured on the ``sys.path``. being properly configured on the ``sys.path``.
""" """
def py2_comp(path, python_path=None, sys_path=None, def py2_comp(path, python_path=None, load_unsafe_extensions=False,
added_sys_path=(), smart_sys_path=True): sys_path=None, added_sys_path=(), smart_sys_path=True):
self._path = os.path.abspath(path) self._path = os.path.abspath(path)
self._python_path = python_path self._python_path = python_path
self._sys_path = sys_path self._sys_path = sys_path
self._smart_sys_path = smart_sys_path self._smart_sys_path = smart_sys_path
self._load_unsafe_extensions = load_unsafe_extensions
self._django = False self._django = False
self.added_sys_path = list(added_sys_path) self.added_sys_path = list(added_sys_path)
"""The sys path that is going to be added at the end of the """ """The sys path that is going to be added at the end of the """
@@ -83,20 +89,14 @@ class Project(object):
py2_comp(path, **kwargs) py2_comp(path, **kwargs)
@inference_state_as_method_param_cache() @inference_state_as_method_param_cache()
def _get_base_sys_path(self, inference_state, environment=None): def _get_base_sys_path(self, inference_state):
if self._sys_path is not None:
return self._sys_path
# The sys path has not been set explicitly. # The sys path has not been set explicitly.
if environment is None: sys_path = list(self.get_environment().get_sys_path())
environment = self.get_environment()
sys_path = list(environment.get_sys_path())
try: try:
sys_path.remove('') sys_path.remove('')
except ValueError: except ValueError:
pass pass
return sys_path + self.added_sys_path return sys_path
@inference_state_as_method_param_cache() @inference_state_as_method_param_cache()
def _get_sys_path(self, inference_state, environment=None, def _get_sys_path(self, inference_state, environment=None,
@@ -105,10 +105,14 @@ class Project(object):
Keep this method private for all users of jedi. However internally this Keep this method private for all users of jedi. However internally this
one is used like a public method. one is used like a public method.
""" """
suffixed = [] suffixed = list(self.added_sys_path)
prefixed = [] prefixed = []
sys_path = list(self._get_base_sys_path(inference_state, environment)) if self._sys_path is None:
sys_path = list(self._get_base_sys_path(inference_state))
else:
sys_path = list(self._sys_path)
if self._smart_sys_path: if self._smart_sys_path:
prefixed.append(self._path) prefixed.append(self._path)

View File

@@ -142,7 +142,7 @@ class InferenceState(object):
def get_sys_path(self, **kwargs): def get_sys_path(self, **kwargs):
"""Convenience function""" """Convenience function"""
return self.project._get_sys_path(self, environment=self.environment, **kwargs) return self.project._get_sys_path(self, **kwargs)
def infer(self, context, name): def infer(self, context, name):
def_ = name.get_definition(import_name_always=True) def_ = name.get_definition(import_name_always=True)

View File

@@ -454,8 +454,12 @@ def _load_python_module(inference_state, file_io,
def _load_builtin_module(inference_state, import_names=None, sys_path=None): def _load_builtin_module(inference_state, import_names=None, sys_path=None):
project = inference_state.project
if sys_path is None: if sys_path is None:
sys_path = inference_state.get_sys_path() sys_path = inference_state.get_sys_path()
if not project._load_unsafe_extensions:
safe_paths = project._get_base_sys_path(inference_state)
sys_path = [p for p in sys_path if p in safe_paths]
dotted_name = '.'.join(import_names) dotted_name = '.'.join(import_names)
assert dotted_name is not None assert dotted_name is not None

View File

@@ -29,7 +29,7 @@ def test_added_sys_path(inference_state):
project = get_default_project() project = get_default_project()
p = '/some_random_path' p = '/some_random_path'
project.added_sys_path = [p] project.added_sys_path = [p]
assert p in project._get_base_sys_path(inference_state) assert p in project._get_sys_path(inference_state)
def test_load_save_project(tmpdir): def test_load_save_project(tmpdir):

View File

@@ -35,8 +35,9 @@ def test_get_signatures_stdlib(Script):
# Check only on linux 64 bit platform and Python3.4. # Check only on linux 64 bit platform and Python3.4.
@pytest.mark.skipif('sys.platform != "linux" or sys.maxsize <= 2**32 or sys.version_info[:2] != (3, 4)') @pytest.mark.skipif('sys.platform != "linux" or sys.maxsize <= 2**32 or sys.version_info[:2] != (3, 4)')
@pytest.mark.parametrize('load_unsafe_extensions', [False, True])
@cwd_at('test/examples') @cwd_at('test/examples')
def test_init_extension_module(Script): def test_init_extension_module(Script, load_unsafe_extensions):
""" """
``__init__`` extension modules are also packages and Jedi should understand ``__init__`` extension modules are also packages and Jedi should understand
that. that.
@@ -50,8 +51,25 @@ def test_init_extension_module(Script):
This is also why this test only runs on certain systems (and Python 3.4). This is also why this test only runs on certain systems (and Python 3.4).
""" """
s = jedi.Script('import init_extension_module as i\ni.', path='not_existing.py') project = jedi.Project('.', load_unsafe_extensions=load_unsafe_extensions)
assert 'foo' in [c.name for c in s.complete()] s = jedi.Script(
'import init_extension_module as i\ni.',
path='not_existing.py',
project=project,
)
if load_unsafe_extensions:
assert 'foo' in [c.name for c in s.complete()]
else:
assert 'foo' not in [c.name for c in s.complete()]
s = jedi.Script('from init_extension_module import foo\nfoo', path='not_existing.py') s = jedi.Script(
assert ['foo'] == [c.name for c in s.complete()] 'from init_extension_module import foo\nfoo',
path='not_existing.py',
project=project,
)
c, = s.complete()
assert c.name == 'foo'
if load_unsafe_extensions:
assert c.infer()
else:
assert not c.infer()

View File

@@ -55,7 +55,8 @@ def pyc_project_path(tmpdir):
shutil.rmtree(path) shutil.rmtree(path)
def test_pyc(pyc_project_path, environment): @pytest.mark.parametrize('load_unsafe_extensions', [False, True])
def test_pyc(pyc_project_path, environment, load_unsafe_extensions):
""" """
The list of completion must be greater than 2. The list of completion must be greater than 2.
""" """
@@ -66,8 +67,14 @@ def test_pyc(pyc_project_path, environment):
# we also have the same version and it's easier to debug. # we also have the same version and it's easier to debug.
environment = SameEnvironment() environment = SameEnvironment()
environment = environment environment = environment
project = jedi.Project(pyc_project_path, load_unsafe_extensions=load_unsafe_extensions)
s = jedi.Script( s = jedi.Script(
"from dummy_package import dummy; dummy.", "from dummy_package import dummy; dummy.",
path=path, path=path,
environment=environment) environment=environment,
assert len(s.complete()) >= 2 project=project,
)
if load_unsafe_extensions:
assert len(s.complete()) >= 2
else:
assert not s.complete()