diff --git a/jedi/api/project.py b/jedi/api/project.py index 8b9c9228..9b418e2a 100644 --- a/jedi/api/project.py +++ b/jedi/api/project.py @@ -60,6 +60,11 @@ class Project(object): :param path: The base path for this project. :param python_path: The Python executable path, typically the path of a virtual environment. + :param load_unsafe_extensions: Loads extensions that are not in the + sys path and in the local directories. With this option enabled, + this is potentially unsafe if you clone a git repository and + analyze it's code, because those compiled extensions will be + important and therefore have execution privileges. :param sys_path: list of str. You can override the sys path if you want. By default the ``sys.path.`` is generated from the environment (virtualenvs, etc). @@ -69,13 +74,14 @@ class Project(object): local directories. Otherwise you will have to rely on your packages being properly configured on the ``sys.path``. """ - def py2_comp(path, python_path=None, sys_path=None, - added_sys_path=(), smart_sys_path=True): + def py2_comp(path, python_path=None, load_unsafe_extensions=False, + sys_path=None, added_sys_path=(), smart_sys_path=True): self._path = os.path.abspath(path) self._python_path = python_path self._sys_path = sys_path self._smart_sys_path = smart_sys_path + self._load_unsafe_extensions = load_unsafe_extensions self._django = False self.added_sys_path = list(added_sys_path) """The sys path that is going to be added at the end of the """ @@ -83,20 +89,14 @@ class Project(object): py2_comp(path, **kwargs) @inference_state_as_method_param_cache() - def _get_base_sys_path(self, inference_state, environment=None): - if self._sys_path is not None: - return self._sys_path - + def _get_base_sys_path(self, inference_state): # The sys path has not been set explicitly. - if environment is None: - environment = self.get_environment() - - sys_path = list(environment.get_sys_path()) + sys_path = list(self.get_environment().get_sys_path()) try: sys_path.remove('') except ValueError: pass - return sys_path + self.added_sys_path + return sys_path @inference_state_as_method_param_cache() def _get_sys_path(self, inference_state, environment=None, @@ -105,10 +105,14 @@ class Project(object): Keep this method private for all users of jedi. However internally this one is used like a public method. """ - suffixed = [] + suffixed = list(self.added_sys_path) prefixed = [] - sys_path = list(self._get_base_sys_path(inference_state, environment)) + if self._sys_path is None: + sys_path = list(self._get_base_sys_path(inference_state)) + else: + sys_path = list(self._sys_path) + if self._smart_sys_path: prefixed.append(self._path) diff --git a/jedi/inference/__init__.py b/jedi/inference/__init__.py index 7606be42..91f05841 100644 --- a/jedi/inference/__init__.py +++ b/jedi/inference/__init__.py @@ -142,7 +142,7 @@ class InferenceState(object): def get_sys_path(self, **kwargs): """Convenience function""" - return self.project._get_sys_path(self, environment=self.environment, **kwargs) + return self.project._get_sys_path(self, **kwargs) def infer(self, context, name): def_ = name.get_definition(import_name_always=True) diff --git a/jedi/inference/imports.py b/jedi/inference/imports.py index bbf944ed..bc1ce0b9 100644 --- a/jedi/inference/imports.py +++ b/jedi/inference/imports.py @@ -454,8 +454,12 @@ def _load_python_module(inference_state, file_io, def _load_builtin_module(inference_state, import_names=None, sys_path=None): + project = inference_state.project if sys_path is None: sys_path = inference_state.get_sys_path() + if not project._load_unsafe_extensions: + safe_paths = project._get_base_sys_path(inference_state) + sys_path = [p for p in sys_path if p in safe_paths] dotted_name = '.'.join(import_names) assert dotted_name is not None diff --git a/test/test_api/test_project.py b/test/test_api/test_project.py index b5ef5db0..c3d016f3 100644 --- a/test/test_api/test_project.py +++ b/test/test_api/test_project.py @@ -29,7 +29,7 @@ def test_added_sys_path(inference_state): project = get_default_project() p = '/some_random_path' project.added_sys_path = [p] - assert p in project._get_base_sys_path(inference_state) + assert p in project._get_sys_path(inference_state) def test_load_save_project(tmpdir): diff --git a/test/test_inference/test_extension.py b/test/test_inference/test_extension.py index 6e6f5899..e488f4fd 100644 --- a/test/test_inference/test_extension.py +++ b/test/test_inference/test_extension.py @@ -35,8 +35,9 @@ def test_get_signatures_stdlib(Script): # Check only on linux 64 bit platform and Python3.4. @pytest.mark.skipif('sys.platform != "linux" or sys.maxsize <= 2**32 or sys.version_info[:2] != (3, 4)') +@pytest.mark.parametrize('load_unsafe_extensions', [False, True]) @cwd_at('test/examples') -def test_init_extension_module(Script): +def test_init_extension_module(Script, load_unsafe_extensions): """ ``__init__`` extension modules are also packages and Jedi should understand that. @@ -50,8 +51,25 @@ def test_init_extension_module(Script): This is also why this test only runs on certain systems (and Python 3.4). """ - s = jedi.Script('import init_extension_module as i\ni.', path='not_existing.py') - assert 'foo' in [c.name for c in s.complete()] + project = jedi.Project('.', load_unsafe_extensions=load_unsafe_extensions) + s = jedi.Script( + 'import init_extension_module as i\ni.', + path='not_existing.py', + project=project, + ) + if load_unsafe_extensions: + assert 'foo' in [c.name for c in s.complete()] + else: + assert 'foo' not in [c.name for c in s.complete()] - s = jedi.Script('from init_extension_module import foo\nfoo', path='not_existing.py') - assert ['foo'] == [c.name for c in s.complete()] + s = jedi.Script( + 'from init_extension_module import foo\nfoo', + path='not_existing.py', + project=project, + ) + c, = s.complete() + assert c.name == 'foo' + if load_unsafe_extensions: + assert c.infer() + else: + assert not c.infer() diff --git a/test/test_inference/test_pyc.py b/test/test_inference/test_pyc.py index 615fd1cd..35d51327 100644 --- a/test/test_inference/test_pyc.py +++ b/test/test_inference/test_pyc.py @@ -55,7 +55,8 @@ def pyc_project_path(tmpdir): shutil.rmtree(path) -def test_pyc(pyc_project_path, environment): +@pytest.mark.parametrize('load_unsafe_extensions', [False, True]) +def test_pyc(pyc_project_path, environment, load_unsafe_extensions): """ The list of completion must be greater than 2. """ @@ -66,8 +67,14 @@ def test_pyc(pyc_project_path, environment): # we also have the same version and it's easier to debug. environment = SameEnvironment() environment = environment + project = jedi.Project(pyc_project_path, load_unsafe_extensions=load_unsafe_extensions) s = jedi.Script( "from dummy_package import dummy; dummy.", path=path, - environment=environment) - assert len(s.complete()) >= 2 + environment=environment, + project=project, + ) + if load_unsafe_extensions: + assert len(s.complete()) >= 2 + else: + assert not s.complete()