diff --git a/test/refactor.py b/test/refactor.py index b36bf959..5b96bf40 100644 --- a/test/refactor.py +++ b/test/refactor.py @@ -16,27 +16,26 @@ from .helpers import test_dir class RefactoringCase(object): - def __init__(self, name, code, line_nr, index, path, kwargs, is_error, desired_diff): + def __init__(self, name, code, line_nr, index, path, kwargs, type_, desired_result): self.name = name self._code = code self._line_nr = line_nr self._index = index self._path = path self._kwargs = kwargs - self.is_error = is_error - self.desired_diff = desired_diff + self.type = type_ + self.desired_result = desired_result @property def refactor_type(self): f_name = os.path.basename(self._path) return f_name.replace('.py', '') - def calculate_diff(self): + def refactor(self): project = jedi.Project(os.path.join(test_dir, 'refactor')) script = jedi.Script(self._code, path=self._path, project=project) refactor_func = getattr(script, self.refactor_type) - refactor_object = refactor_func(self._line_nr, self._index, **self._kwargs) - return refactor_object.get_diff() + return refactor_func(self._line_nr, self._index, **self._kwargs) def __repr__(self): return '<%s: %s:%s>' % (self.__class__.__name__, @@ -53,13 +52,13 @@ def _collect_file_tests(code, path, lines_to_execute): second = match.group(3) # get the line with the position of the operation - p = re.match(r'((?:(?!#\?).)*)#\? (\d*)( error|) ?([^\n]*)', first, re.DOTALL) + p = re.match(r'((?:(?!#\?).)*)#\? (\d*)( error| text|) ?([^\n]*)', first, re.DOTALL) if p is None: raise Exception("Please add a test start.") continue until = p.group(1) index = int(p.group(2)) - is_error = bool(p.group(3)) + type_ = p.group(3).strip() if p.group(4): kwargs = eval(p.group(4)) else: @@ -69,7 +68,7 @@ def _collect_file_tests(code, path, lines_to_execute): if lines_to_execute and line_nr - 1 not in lines_to_execute: continue - yield RefactoringCase(name, first, line_nr, index, path, kwargs, is_error, second) + yield RefactoringCase(name, first, line_nr, index, path, kwargs, type_, second) if match is None: raise Exception("Didn't match any test") if match.end() != len(code): diff --git a/test/test_integration.py b/test/test_integration.py index 823dcb9c..25fe213c 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -62,10 +62,15 @@ def test_refactor(refactor_case, skip_python2): :type refactor_case: :class:`.refactor.RefactoringCase` """ - if refactor_case.is_error: + if refactor_case.type == 'error': with pytest.raises(RefactoringError) as e: - refactor_case.calculate_diff() - assert e.value.args[0] == refactor_case.desired_diff.strip() + refactor_case.refactor() + assert e.value.args[0] == refactor_case.desired_result.strip() + elif refactor_case.type == 'text': + refactoring = refactor_case.refactor() + assert not refactoring.get_renames() + text = ''.join(f.get_new_code() for f in refactoring.get_changed_files().values()) + assert_case_equal(refactor_case, text, refactor_case.desired_result) else: - diff = refactor_case.calculate_diff() - assert_case_equal(refactor_case, diff, refactor_case.desired_diff) + diff = refactor_case.refactor().get_diff() + assert_case_equal(refactor_case, diff, refactor_case.desired_result)