Make it possible to be able to test errors for refactorings

This commit is contained in:
Dave Halter
2020-02-14 14:15:57 +01:00
parent 0a3ff6bd70
commit dbf88f2750
3 changed files with 22 additions and 7 deletions

View File

@@ -16,13 +16,14 @@ from .helpers import test_dir
class RefactoringCase(object):
def __init__(self, name, code, line_nr, index, path, kwargs, desired_diff):
def __init__(self, name, code, line_nr, index, path, kwargs, is_error, desired_diff):
self.name = name
self._code = code
self._line_nr = line_nr
self._index = index
self._path = path
self._kwargs = kwargs
self.is_error = is_error
self.desired_diff = desired_diff
@property
@@ -52,14 +53,15 @@ def _collect_file_tests(code, path, lines_to_execute):
second = match.group(3)
# get the line with the position of the operation
p = re.match(r'((?:(?!#\?).)*)#\? (\d*) ?([^\n]*)', first, re.DOTALL)
p = re.match(r'((?:(?!#\?).)*)#\? (\d*)( error|) ?([^\n]*)', first, re.DOTALL)
if p is None:
raise Exception("Please add a test start.")
continue
until = p.group(1)
index = int(p.group(2))
if p.group(3):
kwargs = eval(p.group(3))
is_error = bool(p.group(3))
if p.group(4):
kwargs = eval(p.group(4))
else:
kwargs = {}
@@ -67,7 +69,7 @@ def _collect_file_tests(code, path, lines_to_execute):
if lines_to_execute and line_nr - 1 not in lines_to_execute:
continue
yield RefactoringCase(name, first, line_nr, index, path, kwargs, second)
yield RefactoringCase(name, first, line_nr, index, path, kwargs, is_error, second)
if match is None:
raise Exception("Didn't match any test")
if match.end() != len(code):

View File

@@ -1,3 +1,10 @@
# -------------------------------------------------- multi-equal
def test():
#? 4 error
a = b = 3
return test(100, a)
# ++++++++++++++++++++++++++++++++++++++++++++++++++
Cannot inline a statement with multiple definitions
# -------------------------------------------------- simple
def test():
#? 4

View File

@@ -4,6 +4,7 @@ import pytest
from . import helpers
from jedi.inference.utils import indent_block
from jedi import RefactoringError
def assert_case_equal(case, actual, desired):
@@ -61,5 +62,10 @@ def test_refactor(refactor_case):
:type refactor_case: :class:`.refactor.RefactoringCase`
"""
diff = refactor_case.calculate_diff()
assert_case_equal(refactor_case, diff, refactor_case.desired_diff)
if refactor_case.is_error:
with pytest.raises(RefactoringError) as e:
refactor_case.calculate_diff()
assert e.value.args[0] == refactor_case.desired_diff.strip()
else:
diff = refactor_case.calculate_diff()
assert_case_equal(refactor_case, diff, refactor_case.desired_diff)