Fix diff-parser: Copying parts of if else should not lead to the whole thing being copied

This commit is contained in:
Dave Halter
2018-12-30 15:25:17 +01:00
parent a64c32bb2a
commit f99fe6ad21
2 changed files with 41 additions and 3 deletions

View File

@@ -171,6 +171,7 @@ class DiffParser(object):
% (last_pos, line_length, parso.__version__, ''.join(diff)) % (last_pos, line_length, parso.__version__, ''.join(diff))
) )
#assert self._module.get_code() == ''.join(new_lines)
LOG.debug('diff parser end') LOG.debug('diff parser end')
return self._module return self._module
@@ -212,7 +213,9 @@ class DiffParser(object):
to = self._nodes_tree.parsed_until_line to = self._nodes_tree.parsed_until_line
LOG.debug('actually copy %s to %s', from_, to) LOG.debug('copy old[%s:%s] new[%s:%s]',
copied_nodes[0].start_pos[0],
copied_nodes[-1].end_pos[0] - 1, from_, to)
# Since there are potential bugs that might loop here endlessly, we # Since there are potential bugs that might loop here endlessly, we
# just stop here. # just stop here.
assert last_until_line != self._nodes_tree.parsed_until_line \ assert last_until_line != self._nodes_tree.parsed_until_line \
@@ -228,7 +231,11 @@ class DiffParser(object):
node = leaf node = leaf
while node.parent.type not in ('file_input', 'suite'): while node.parent.type not in ('file_input', 'suite'):
node = node.parent node = node.parent
return node
# Make sure that if only the `else:` line of an if statement is
# copied that not the whole thing is going to be copied.
if node.start_pos[0] >= old_line:
return node
# Must be on the same line. Otherwise we need to parse that bit. # Must be on the same line. Otherwise we need to parse that bit.
return None return None

View File

@@ -142,7 +142,7 @@ def test_if_simple(differ):
differ.initialize(src + 'a') differ.initialize(src + 'a')
differ.parse(src + else_ + "a", copies=0, parsers=1) differ.parse(src + else_ + "a", copies=0, parsers=1)
differ.parse(else_, parsers=1, expect_error_leaves=True) differ.parse(else_, parsers=1, copies=1, expect_error_leaves=True)
differ.parse(src + else_, parsers=1) differ.parse(src + else_, parsers=1)
@@ -560,3 +560,34 @@ def test_invalid_to_valid_nodes(differ):
differ.initialize(code1) differ.initialize(code1)
differ.parse(code2, parsers=1, copies=3) differ.parse(code2, parsers=1, copies=3)
def test_if_removal_and_reappearence(differ):
code1 = dedent('''\
la = 3
if foo:
latte = 3
else:
la
pass
''')
code2 = dedent('''\
la = 3
latte = 3
else:
la
pass
''')
code3 = dedent('''\
la = 3
if foo:
latte = 3
else:
la
''')
differ.initialize(code1)
differ.parse(code2, parsers=1, copies=4, expect_error_leaves=True)
differ.parse(code1, parsers=1, copies=1)
differ.parse(code3, parsers=1, copies=1)