Make it possible to test refactoring outputs a bit different

This commit is contained in:
Dave Halter
2020-02-15 00:59:26 +01:00
parent 24114ba631
commit ee8cdb667d
2 changed files with 18 additions and 14 deletions

View File

@@ -16,27 +16,26 @@ from .helpers import test_dir
class RefactoringCase(object):
def __init__(self, name, code, line_nr, index, path, kwargs, is_error, desired_diff):
def __init__(self, name, code, line_nr, index, path, kwargs, type_, desired_result):
self.name = name
self._code = code
self._line_nr = line_nr
self._index = index
self._path = path
self._kwargs = kwargs
self.is_error = is_error
self.desired_diff = desired_diff
self.type = type_
self.desired_result = desired_result
@property
def refactor_type(self):
f_name = os.path.basename(self._path)
return f_name.replace('.py', '')
def calculate_diff(self):
def refactor(self):
project = jedi.Project(os.path.join(test_dir, 'refactor'))
script = jedi.Script(self._code, path=self._path, project=project)
refactor_func = getattr(script, self.refactor_type)
refactor_object = refactor_func(self._line_nr, self._index, **self._kwargs)
return refactor_object.get_diff()
return refactor_func(self._line_nr, self._index, **self._kwargs)
def __repr__(self):
return '<%s: %s:%s>' % (self.__class__.__name__,
@@ -53,13 +52,13 @@ def _collect_file_tests(code, path, lines_to_execute):
second = match.group(3)
# get the line with the position of the operation
p = re.match(r'((?:(?!#\?).)*)#\? (\d*)( error|) ?([^\n]*)', first, re.DOTALL)
p = re.match(r'((?:(?!#\?).)*)#\? (\d*)( error| text|) ?([^\n]*)', first, re.DOTALL)
if p is None:
raise Exception("Please add a test start.")
continue
until = p.group(1)
index = int(p.group(2))
is_error = bool(p.group(3))
type_ = p.group(3).strip()
if p.group(4):
kwargs = eval(p.group(4))
else:
@@ -69,7 +68,7 @@ def _collect_file_tests(code, path, lines_to_execute):
if lines_to_execute and line_nr - 1 not in lines_to_execute:
continue
yield RefactoringCase(name, first, line_nr, index, path, kwargs, is_error, second)
yield RefactoringCase(name, first, line_nr, index, path, kwargs, type_, second)
if match is None:
raise Exception("Didn't match any test")
if match.end() != len(code):

View File

@@ -62,10 +62,15 @@ def test_refactor(refactor_case, skip_python2):
:type refactor_case: :class:`.refactor.RefactoringCase`
"""
if refactor_case.is_error:
if refactor_case.type == 'error':
with pytest.raises(RefactoringError) as e:
refactor_case.calculate_diff()
assert e.value.args[0] == refactor_case.desired_diff.strip()
refactor_case.refactor()
assert e.value.args[0] == refactor_case.desired_result.strip()
elif refactor_case.type == 'text':
refactoring = refactor_case.refactor()
assert not refactoring.get_renames()
text = ''.join(f.get_new_code() for f in refactoring.get_changed_files().values())
assert_case_equal(refactor_case, text, refactor_case.desired_result)
else:
diff = refactor_case.calculate_diff()
assert_case_equal(refactor_case, diff, refactor_case.desired_diff)
diff = refactor_case.refactor().get_diff()
assert_case_equal(refactor_case, diff, refactor_case.desired_result)