diff --git a/jedi/__init__.py b/jedi/__init__.py index 8f6a3265..eb2e814c 100644 --- a/jedi/__init__.py +++ b/jedi/__init__.py @@ -42,7 +42,7 @@ from jedi.api.environment import find_virtualenvs, find_system_environments, \ get_default_environment, InvalidPythonEnvironment, create_environment, \ get_system_environment from jedi.api.project import Project, get_default_project -from jedi.api.exceptions import InternalError +from jedi.api.exceptions import InternalError, RefactoringError # Finally load the internal plugins. This is only internal. from jedi.plugins import registry del registry diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index 095b19d2..27374331 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -524,11 +524,17 @@ class Script(object): ] return sorted(filter(def_ref_filter, defs), key=lambda x: (x.line, x.column)) - def rename(self, line, column, new_name): + def rename(self, line=None, column=None, **kwargs): """ Returns an object that you can use to rename the variable under the cursor and its references to a different name. + + :param new_name: The variable under the cursor will be renamed to this + string. """ + return self._rename(line, column, **kwargs) + + def _rename(self, line, column, new_name): # Python 2... definitions = self.get_references(line, column, include_builtins=False) return rename(self._grammar, definitions, new_name) diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index 25760b48..13257135 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -1,9 +1,12 @@ from os.path import dirname, basename, join +import os import re import difflib from parso import split_lines +from jedi.api.exceptions import RefactoringError + class ChangedFile(object): def __init__(self, grammar, from_path, to_path, module_node, node_to_str_map): @@ -15,7 +18,7 @@ class ChangedFile(object): def get_diff(self): old_lines = split_lines(self._module_node.get_code(), keepends=True) - new_lines = split_lines(self.get_code(), keepends=True) + new_lines = split_lines(self.get_new_code(), keepends=True) diff = difflib.unified_diff( old_lines, new_lines, fromfile=self._from_path, @@ -25,12 +28,17 @@ class ChangedFile(object): # reason. return ''.join(diff).rstrip(' ') - def get_code(self): + def get_new_code(self): return self._grammar.refactor(self._module_node, self._node_to_str_map) def apply(self): + if self._from_path is None: + raise RefactoringError( + 'Cannot apply a refactoring on a Script with path=None' + ) + with open(self._from_path, 'w') as f: - f.write(self.get_code()) + f.write(self.get_new_code()) def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self._from_path) @@ -43,7 +51,13 @@ class Refactoring(object): self._file_to_node_changes = file_to_node_changes def get_changed_files(self): + """ + Returns a path to ``ChangedFile`` map. The files can be used + ``Dict[str + """ def calculate_to_path(p): + if p is None: + return p for from_, to in renames: if p.startswith(from_): p = to + p[len(from_):] @@ -76,11 +90,14 @@ class Refactoring(object): return text + ''.join(f.get_diff() for f in self.get_changed_files().values()) def apply(self): - for f in self.get_changed_files(): + """ + Applies the whole refactoring to the files, which includes renames. + """ + for f in self.get_changed_files().values(): f.apply() for old, new in self.get_renames(): - rename(old, new) + os.rename(old, new) def _calculate_rename(path, new_name): diff --git a/test/test_api/test_refactoring.py b/test/test_api/test_refactoring.py new file mode 100644 index 00000000..11c4bf1e --- /dev/null +++ b/test/test_api/test_refactoring.py @@ -0,0 +1,57 @@ +import os +from textwrap import dedent + +import pytest + +import jedi + + +@pytest.fixture() +def dir_with_content(tmpdir): + with open(os.path.join(tmpdir.strpath, 'modx.py'), 'w') as f: + f.write('import modx\nfoo\n') # self reference + return tmpdir.strpath + + +def test_rename_mod(Script, dir_with_content): + script = Script( + 'import modx; modx\n', + path=os.path.join(dir_with_content, 'some_script.py'), + project=jedi.Project(dir_with_content), + ) + refactoring = script.rename(line=1, new_name='modr') + refactoring.apply() + + p1 = os.path.join(dir_with_content, 'modx.py') + p2 = os.path.join(dir_with_content, 'modr.py') + expected_code = 'import modr\nfoo\n' + assert not os.path.exists(p1) + with open(p2) as f: + assert f.read() == expected_code + + assert refactoring.get_renames() == [(p1, p2)] + + assert refactoring.get_changed_files()[p1].get_new_code() == expected_code + + assert refactoring.get_diff() == dedent('''\ + rename from {dir}/modx.py + rename to {dir}/modr.py + --- {dir}/modx.py + +++ {dir}/modr.py + @@ -1,3 +1,3 @@ + -import modx + +import modr + foo + --- {dir}/some_script.py + +++ {dir}/some_script.py + @@ -1,2 +1,2 @@ + -import modx; modx + +import modr; modr + ''').format(dir=dir_with_content) + + +def test_rename_none_path(Script): + refactoring = Script('foo', path=None).rename(new_name='bar') + with pytest.raises(jedi.RefactoringError, match='on a Script with path=None'): + refactoring.apply() + assert refactoring