diff --git a/test/refactor.py b/test/refactor.py index 243b674e..e82c582b 100644 --- a/test/refactor.py +++ b/test/refactor.py @@ -7,8 +7,11 @@ valuable - just ignore them.** """ from __future__ import with_statement import os +import platform import re +from parso import split_lines + from functools import reduce import jedi from .helpers import test_dir @@ -24,7 +27,18 @@ class RefactoringCase(object): self._path = path self._kwargs = kwargs self.type = type_ - self.desired_result = desired_result + self._desired_result = desired_result + + def get_desired_result(self): + + if platform.system().lower() == 'windows' and self.type == 'diff': + # Windows uses backslashes to separate paths. + lines = split_lines(self._desired_result, keepends=True) + for i, line in enumerate(lines): + if re.search(' import_tree/', line): + lines[i] = line.replace('/', '\\') + return ''.join(lines) + return self._desired_result @property def refactor_type(self): @@ -58,7 +72,7 @@ def _collect_file_tests(code, path, lines_to_execute): continue until = p.group(1) index = int(p.group(2)) - type_ = p.group(3).strip() + type_ = p.group(3).strip() or 'diff' if p.group(4): kwargs = eval(p.group(4)) else: diff --git a/test/test_integration.py b/test/test_integration.py index 883622d3..3bb3ed66 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -66,15 +66,16 @@ def test_refactor(refactor_case, skip_pre_python36, environment): if sys.version_info < (3, 6): pytest.skip() + desired_result = refactor_case.get_desired_result() if refactor_case.type == 'error': with pytest.raises(RefactoringError) as e: refactor_case.refactor(environment) - assert e.value.args[0] == refactor_case.desired_result.strip() + assert e.value.args[0] == desired_result.strip() elif refactor_case.type == 'text': refactoring = refactor_case.refactor(environment) assert not refactoring.get_renames() text = ''.join(f.get_new_code() for f in refactoring.get_changed_files().values()) - assert_case_equal(refactor_case, text, refactor_case.desired_result) + assert_case_equal(refactor_case, text, desired_result) else: diff = refactor_case.refactor(environment).get_diff() - assert_case_equal(refactor_case, diff, refactor_case.desired_result) + assert_case_equal(refactor_case, diff, desired_result)