diff --git a/parso/python/errors.py b/parso/python/errors.py index 9ae5329..71c0dad 100644 --- a/parso/python/errors.py +++ b/parso/python/errors.py @@ -125,19 +125,6 @@ def _get_for_stmt_definition_exprs(for_stmt): return list(_iter_definition_exprs_from_lists(exprlist)) -def _get_namedexpr(node): - """Get assignment expression if node contains.""" - namedexpr_list = list() - - def fetch_namedexpr(node): - if node.type == 'namedexpr_test': - namedexpr_list.append(node) - if hasattr(node, 'children'): - [fetch_namedexpr(child) for child in node.children] - - fetch_namedexpr(node) - return namedexpr_list - class _Context(object): def __init__(self, node, add_syntax_error, parent_context=None): @@ -979,14 +966,6 @@ class _CompForRule(_CheckAssignmentRule): if expr_list.type != 'expr_list': # Already handled. self._check_assignment(expr_list) - or_test = node.children[3] - expr_list = _get_namedexpr(or_test) - for expr in expr_list: - # [i+1 for i in (i := range(5))] - # [i+1 for i in (j := range(5))] - # [i+1 for i in (lambda: (j := range(5)))()] - self.add_issue(expr, message='assignment expression cannot be used in a comprehension iterable expression') - return node.parent.children[0] == 'async' \ and not self._normalizer.context.is_async_funcdef() @@ -1040,8 +1019,24 @@ class _NamedExprRule(_CheckAssignmentRule): # namedexpr_test: test [':=' test] def is_issue(self, namedexpr_test): + # assigned name first = namedexpr_test.children[0] + def search_namedexpr_in_comp_for(node): + while True: + parent = node.parent + if parent is None: + return parent + if parent.type == 'sync_comp_for' and parent.children[3] == node: + return parent + node = parent + + if search_namedexpr_in_comp_for(namedexpr_test): + # [i+1 for i in (i := range(5))] + # [i+1 for i in (j := range(5))] + # [i+1 for i in (lambda: (j := range(5)))()] + self.add_issue(namedexpr_test, message='assignment expression cannot be used in a comprehension iterable expression') + # defined names exprlist = list() diff --git a/test/test_python_errors.py b/test/test_python_errors.py index 71a67eb..ce5bcb3 100644 --- a/test/test_python_errors.py +++ b/test/test_python_errors.py @@ -293,6 +293,19 @@ def test_valid_fstrings(code): assert not _get_error_list(code, version='3.6') +@pytest.mark.parametrize( + 'code', [ + 'a = (b := 1)', + '[x4 := x ** 5 for x in range(7)]', + '[total := total + v for v in range(10)]', + 'while chunk := file.read(2):\n pass', + 'numbers = [y := math.factorial(x), y**2, y**3]', + ] +) +def test_valid_namedexpr(code): + assert not _get_error_list(code, version='3.8') + + @pytest.mark.parametrize( ('code', 'message'), [ ("f'{1+}'", ('invalid syntax')),