diff --git a/parso/python/parser.py b/parso/python/parser.py index be13490..7d4592a 100644 --- a/parso/python/parser.py +++ b/parso/python/parser.py @@ -182,19 +182,16 @@ class Parser(BaseParser): def current_suite(stack): # For now just discard everything that is not a suite or # file_input, if we detect an error. - one_line_suite = False for until_index, stack_node in reversed(list(enumerate(stack))): # `suite` can sometimes be only simple_stmt, not stmt. - if one_line_suite: - break - elif stack_node.nonterminal == 'file_input': + if stack_node.nonterminal == 'file_input': break elif stack_node.nonterminal == 'suite': - if len(stack_node.nodes) > 1: + # In the case where we just have a newline we don't want to + # do error recovery here. In all other cases, we want to do + # error recovery. + if len(stack_node.nodes) != 1: break - elif not stack_node.nodes: - one_line_suite = True - # `suite` without an indent are error nodes. return until_index until_index = current_suite(stack) @@ -221,9 +218,8 @@ class Parser(BaseParser): pass def _stack_removal(self, stack, start_index): - all_nodes = [] - for stack_node in stack[start_index:]: - all_nodes += stack_node.nodes + all_nodes = [node for stack_node in stack[start_index:] for node in stack_node.nodes] + if all_nodes: stack[start_index - 1].nodes.append(tree.PythonErrorNode(all_nodes)) diff --git a/test/test_error_recovery.py b/test/test_error_recovery.py index af6137b..f8dcd94 100644 --- a/test/test_error_recovery.py +++ b/test/test_error_recovery.py @@ -26,13 +26,36 @@ def test_one_line_function(each_version): assert func.children[-1] == ':' +def test_if_else(): + module = parse('if x:\n f.\nelse:\n g(') + if_stmt = module.children[0] + if_, test, colon, suite1, else_, colon, suite2 = if_stmt.children + f = suite1.children[1] + assert f.type == 'error_node' + assert f.children[0].value == 'f' + assert f.children[1].value == '.' + g = suite2.children[1] + assert g.children[0].value == 'g' + assert g.children[1].value == '(' + + def test_if_stmt(): - module = parse('if x: f.')# \nelse: g( + module = parse('if x: f.\nelse: g(') if_stmt = module.children[0] assert if_stmt.type == 'if_stmt' if_, test, colon, f = if_stmt.children assert f.type == 'error_node' assert f.children[0].value == 'f' assert f.children[1].value == '.' - #assert g.children[0].value == 'g' - #assert g.children[1].value == '(' + + assert module.children[1].type == 'newline' + assert module.children[1].value == '\n' + assert module.children[2].type == 'error_leaf' + assert module.children[2].value == 'else' + assert module.children[3].type == 'error_leaf' + assert module.children[3].value == ':' + + in_else_stmt = module.children[4] + assert in_else_stmt.type == 'error_node' + assert in_else_stmt.children[0].value == 'g' + assert in_else_stmt.children[1].value == '('