* moved all namedexpr_test related rules to _NamedExprRule

* added valid examples
This commit is contained in:
Jarry Shaw
2019-12-14 09:37:16 +01:00
parent 776e151370
commit 89c4d959e9
2 changed files with 29 additions and 21 deletions

View File

@@ -125,19 +125,6 @@ def _get_for_stmt_definition_exprs(for_stmt):
return list(_iter_definition_exprs_from_lists(exprlist))
def _get_namedexpr(node):
"""Get assignment expression if node contains."""
namedexpr_list = list()
def fetch_namedexpr(node):
if node.type == 'namedexpr_test':
namedexpr_list.append(node)
if hasattr(node, 'children'):
[fetch_namedexpr(child) for child in node.children]
fetch_namedexpr(node)
return namedexpr_list
class _Context(object):
def __init__(self, node, add_syntax_error, parent_context=None):
@@ -979,14 +966,6 @@ class _CompForRule(_CheckAssignmentRule):
if expr_list.type != 'expr_list': # Already handled.
self._check_assignment(expr_list)
or_test = node.children[3]
expr_list = _get_namedexpr(or_test)
for expr in expr_list:
# [i+1 for i in (i := range(5))]
# [i+1 for i in (j := range(5))]
# [i+1 for i in (lambda: (j := range(5)))()]
self.add_issue(expr, message='assignment expression cannot be used in a comprehension iterable expression')
return node.parent.children[0] == 'async' \
and not self._normalizer.context.is_async_funcdef()
@@ -1040,8 +1019,24 @@ class _NamedExprRule(_CheckAssignmentRule):
# namedexpr_test: test [':=' test]
def is_issue(self, namedexpr_test):
# assigned name
first = namedexpr_test.children[0]
def search_namedexpr_in_comp_for(node):
while True:
parent = node.parent
if parent is None:
return parent
if parent.type == 'sync_comp_for' and parent.children[3] == node:
return parent
node = parent
if search_namedexpr_in_comp_for(namedexpr_test):
# [i+1 for i in (i := range(5))]
# [i+1 for i in (j := range(5))]
# [i+1 for i in (lambda: (j := range(5)))()]
self.add_issue(namedexpr_test, message='assignment expression cannot be used in a comprehension iterable expression')
# defined names
exprlist = list()

View File

@@ -293,6 +293,19 @@ def test_valid_fstrings(code):
assert not _get_error_list(code, version='3.6')
@pytest.mark.parametrize(
'code', [
'a = (b := 1)',
'[x4 := x ** 5 for x in range(7)]',
'[total := total + v for v in range(10)]',
'while chunk := file.read(2):\n pass',
'numbers = [y := math.factorial(x), y**2, y**3]',
]
)
def test_valid_namedexpr(code):
assert not _get_error_list(code, version='3.8')
@pytest.mark.parametrize(
('code', 'message'), [
("f'{1+}'", ('invalid syntax')),