diff --git a/tests/check_new_syntax.py b/tests/check_new_syntax.py index ff6d303be..8064ca9c3 100755 --- a/tests/check_new_syntax.py +++ b/tests/check_new_syntax.py @@ -4,10 +4,23 @@ import ast import sys from itertools import chain from pathlib import Path +from typing import TypedDict + + +class ModuleImportsDict(TypedDict): + set_from_collections_abc: bool + context_manager_from_typing: bool + async_context_manager_from_typing: bool + + +STUBS_SUPPORTING_PYTHON_2 = frozenset( + {path.parent for path in Path("stubs").rglob("METADATA.toml") if "python2 = true" in path.read_text().splitlines()} +) def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: errors = [] + python_2_support_required = any(directory in path.parents for directory in STUBS_SUPPORTING_PYTHON_2) def unparse_without_tuple_parens(node: ast.AST) -> str: if isinstance(node, ast.Tuple) and node.elts: @@ -18,8 +31,8 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: return isinstance(node, ast.Constant) and node.s is Ellipsis class OldSyntaxFinder(ast.NodeVisitor): - def __init__(self, *, set_from_collections_abc: bool) -> None: - self.set_from_collections_abc = set_from_collections_abc + def __init__(self, module_imports_info_dict: ModuleImportsDict) -> None: + self.module_imports_info_dict = module_imports_info_dict def visit_Subscript(self, node: ast.Subscript) -> None: if isinstance(node.value, ast.Name): @@ -32,7 +45,7 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: if node.value.id in {"List", "FrozenSet"}: new_syntax = f"{node.value.id.lower()}[{ast.unparse(node.slice)}]" errors.append(f"{path}:{node.lineno}: Use built-in generics, e.g. `{new_syntax}`") - if not self.set_from_collections_abc and node.value.id == "Set": + if not self.module_imports_info_dict["set_from_collections_abc"] and node.value.id == "Set": new_syntax = f"set[{ast.unparse(node.slice)}]" errors.append(f"{path}:{node.lineno}: Use built-in generics, e.g. `{new_syntax}`") if node.value.id == "Deque": @@ -53,6 +66,23 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: ): new_syntax = f"tuple[{unparse_without_tuple_parens(node.slice)}]" errors.append(f"{path}:{node.lineno}: Use built-in generics, e.g. `{new_syntax}`") + if not python_2_support_required: + if self.module_imports_info_dict["context_manager_from_typing"] and node.value.id == "ContextManager": + new_syntax = f"contextlib.AbstractContextManager[{ast.unparse(node.slice)}]" + errors.append( + f"{path}:{node.lineno}: Use `contextlib.AbstractContextManager` instead of `typing.ContextManager`, " + f"e.g. `{new_syntax}`" + ) + if ( + self.module_imports_info_dict["async_context_manager_from_typing"] + and node.value.id == "AsyncContextManager" + ): + new_syntax = f"contextlib.AbstractAsyncContextManager[{ast.unparse(node.slice)}]" + errors.append( + f"{path}:{node.lineno}: " + f"Use `contextlib.AbstractAsyncContextManager` instead of `typing.AsyncContextManager`, " + f"e.g. `{new_syntax}`" + ) self.generic_visit(node) @@ -62,31 +92,47 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: # TODO: can use built-in generics in type aliases class AnnotationFinder(ast.NodeVisitor): def __init__(self) -> None: - self.set_from_collections_abc = False + self.module_imports_info_dict = ModuleImportsDict( + set_from_collections_abc=False, context_manager_from_typing=False, async_context_manager_from_typing=False + ) + + def old_syntax_finder(self) -> OldSyntaxFinder: + """Convenience method to create an `OldSyntaxFinder` method with the correct state""" + return OldSyntaxFinder(self.module_imports_info_dict) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if node.module == "collections.abc": - imported_classes = node.names - if any(cls.name == "Set" for cls in imported_classes): - self.set_from_collections_abc = True + classes_from_collections_abc = node.names + if any(cls.name == "Set" for cls in classes_from_collections_abc): + self.module_imports_info_dict["set_from_collections_abc"] = True + + elif node.module == "typing": + classes_from_typing = node.names + + for cls in classes_from_typing: + cls_name = cls.name + if cls_name == "ContextManager": + self.module_imports_info_dict["context_manager_from_typing"] = True + elif cls_name == "AsyncContextManager": + self.module_imports_info_dict["async_context_manager_from_typing"] = True self.generic_visit(node) def visit_AnnAssign(self, node: ast.AnnAssign) -> None: - OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.annotation) + self.old_syntax_finder().visit(node.annotation) def visit_arg(self, node: ast.arg) -> None: if node.annotation is not None: - OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.annotation) + self.old_syntax_finder().visit(node.annotation) def visit_FunctionDef(self, node: ast.FunctionDef) -> None: if node.returns is not None: - OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.returns) + self.old_syntax_finder().visit(node.returns) self.generic_visit(node) def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: if node.returns is not None: - OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.returns) + self.old_syntax_finder().visit(node.returns) self.generic_visit(node) AnnotationFinder().visit(tree)