diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index b88b193b..a1620f11 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -41,8 +41,11 @@ class Refactoring(object): def get_changed_files(self): return [ - ChangedFile(self._grammar, path, next(iter(map_)).get_root_node(), map_) - for path, map_ in self._file_to_node_changes.items() + ChangedFile( + self._grammar, path, + module_node=next(iter(map_)).get_root_node(), + node_to_str_map=map_ + ) for path, map_ in self._file_to_node_changes.items() ] def get_renames(self): diff --git a/test/refactor.py b/test/refactor.py old mode 100755 new mode 100644 index 9185ca8b..de32e4fc --- a/test/refactor.py +++ b/test/refactor.py @@ -16,38 +16,32 @@ import jedi class RefactoringCase(object): def __init__(self, name, code, line_nr, index, path, - args, desired): + args, desired_diff): self.name = name self.code = code self.line_nr = line_nr self.index = index self.path = path self._args = args - self.desired = desired + self.desired_diff = desired_diff @property def refactor_type(self): f_name = os.path.basename(self.path) return f_name.replace('.py', '') - def refactor(self): + def calculate_diff(self): script = jedi.Script(self.code, path=self.path) refactor_func = getattr(script, self.refactor_type) - return refactor_func(self.line_nr, self.index, *self._args) - - def run(self): - refactor_object = self.refactor() + refactor_object = refactor_func(self.line_nr, self.index, *self._args) return refactor_object.get_diff() - def check(self): - return self.run() == self.desired - def __repr__(self): return '<%s: %s:%s>' % (self.__class__.__name__, self.name, self.line_nr - 1) -def collect_file_tests(code, path, lines_to_execute): +def _collect_file_tests(code, path, lines_to_execute): r = r'^# -{5} ?([^\n]*)\n((?:(?!\n# \+{5}).)*\n)' \ r'# \+{5}\n((?:(?!\n# -{5}).)*\n)' for match in re.finditer(r, code, re.DOTALL | re.MULTILINE): @@ -80,5 +74,5 @@ def collect_dir_tests(base_dir, test_files): path = os.path.join(base_dir, f_name) with open(path) as f: code = f.read() - for case in collect_file_tests(code, path, lines_to_execute): + for case in _collect_file_tests(code, path, lines_to_execute): yield case diff --git a/test/test_integration.py b/test/test_integration.py index c61f944b..2498a02d 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -58,5 +58,5 @@ def test_refactor(refactor_case): :type refactor_case: :class:`.refactor.RefactoringCase` """ - diff = refactor_case.run() - assert_case_equal(refactor_case, diff, refactor_case.desired) + diff = refactor_case.calculate_diff() + assert_case_equal(refactor_case, diff, refactor_case.desired_diff)