Use PEP 604 syntax wherever possible (#7493)

This commit is contained in:
Alex Waygood
2022-03-16 15:01:33 +00:00
committed by GitHub
parent 15e21a8dc1
commit 3ab250eec8
174 changed files with 472 additions and 490 deletions

View File

@@ -10,11 +10,12 @@ STUBS_SUPPORTING_PYTHON_2 = frozenset(
)
def check_new_syntax(tree: ast.AST, path: Path) -> list[str]:
def check_new_syntax(tree: ast.AST, path: Path, stub: str) -> list[str]:
errors = []
sourcelines = stub.splitlines()
python_2_support_required = any(directory in path.parents for directory in STUBS_SUPPORTING_PYTHON_2)
class UnionFinder(ast.NodeVisitor):
class AnnotationUnionFinder(ast.NodeVisitor):
def visit_Subscript(self, node: ast.Subscript) -> None:
if isinstance(node.value, ast.Name):
if node.value.id == "Union" and isinstance(node.slice, ast.Tuple):
@@ -26,23 +27,48 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]:
self.generic_visit(node)
class NonAnnotationUnionFinder(ast.NodeVisitor):
def visit_Subscript(self, node: ast.Subscript) -> None:
if isinstance(node.value, ast.Name):
nodelines = sourcelines[(node.lineno - 1) : node.end_lineno]
for line in nodelines:
# A hack to workaround various PEP 604 bugs in mypy
if any(x in line for x in {"tuple[", "Callable[", "type["}):
return None
if node.value.id == "Union" and isinstance(node.slice, ast.Tuple):
new_syntax = " | ".join(ast.unparse(x) for x in node.slice.elts)
errors.append(f"{path}:{node.lineno}: Use PEP 604 syntax for Union, e.g. `{new_syntax}`")
elif node.value.id == "Optional":
new_syntax = f"{ast.unparse(node.slice)} | None"
errors.append(f"{path}:{node.lineno}: Use PEP 604 syntax for Optional, e.g. `{new_syntax}`")
self.generic_visit(node)
class OldSyntaxFinder(ast.NodeVisitor):
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
UnionFinder().visit(node.annotation)
AnnotationUnionFinder().visit(node.annotation)
def visit_arg(self, node: ast.arg) -> None:
if node.annotation is not None:
UnionFinder().visit(node.annotation)
AnnotationUnionFinder().visit(node.annotation)
def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
if node.returns is not None:
AnnotationUnionFinder().visit(node.returns)
self.generic_visit(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if node.returns is not None:
UnionFinder().visit(node.returns)
self.generic_visit(node)
self._visit_function(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
if node.returns is not None:
UnionFinder().visit(node.returns)
self.generic_visit(node)
self._visit_function(node)
def visit_Assign(self, node: ast.Assign) -> None:
NonAnnotationUnionFinder().visit(node.value)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
for base in node.bases:
NonAnnotationUnionFinder().visit(base)
class ObjectClassdefFinder(ast.NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> None:
@@ -85,8 +111,9 @@ def main() -> None:
continue
with open(path) as f:
tree = ast.parse(f.read())
errors.extend(check_new_syntax(tree, path))
stub = f.read()
tree = ast.parse(stub)
errors.extend(check_new_syntax(tree, path, stub))
if errors:
print("\n".join(errors))