diff --git a/tests/check_new_syntax.py b/tests/check_new_syntax.py index a7b23c437..ff6d303be 100755 --- a/tests/check_new_syntax.py +++ b/tests/check_new_syntax.py @@ -18,6 +18,9 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: return isinstance(node, ast.Constant) and node.s is Ellipsis class OldSyntaxFinder(ast.NodeVisitor): + def __init__(self, *, set_from_collections_abc: bool) -> None: + self.set_from_collections_abc = set_from_collections_abc + def visit_Subscript(self, node: ast.Subscript) -> None: if isinstance(node.value, ast.Name): if node.value.id == "Union" and isinstance(node.slice, ast.Tuple): @@ -29,6 +32,9 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: if node.value.id in {"List", "FrozenSet"}: new_syntax = f"{node.value.id.lower()}[{ast.unparse(node.slice)}]" errors.append(f"{path}:{node.lineno}: Use built-in generics, e.g. `{new_syntax}`") + if not self.set_from_collections_abc and node.value.id == "Set": + new_syntax = f"set[{ast.unparse(node.slice)}]" + errors.append(f"{path}:{node.lineno}: Use built-in generics, e.g. `{new_syntax}`") if node.value.id == "Deque": new_syntax = f"collections.deque[{ast.unparse(node.slice)}]" errors.append(f"{path}:{node.lineno}: Use `collections.deque` instead of `typing.Deque`, e.g. `{new_syntax}`") @@ -55,21 +61,32 @@ def check_new_syntax(tree: ast.AST, path: Path) -> list[str]: # # TODO: can use built-in generics in type aliases class AnnotationFinder(ast.NodeVisitor): + def __init__(self) -> None: + self.set_from_collections_abc = False + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module == "collections.abc": + imported_classes = node.names + if any(cls.name == "Set" for cls in imported_classes): + self.set_from_collections_abc = True + + self.generic_visit(node) + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: - OldSyntaxFinder().visit(node.annotation) + OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.annotation) def visit_arg(self, node: ast.arg) -> None: if node.annotation is not None: - OldSyntaxFinder().visit(node.annotation) + OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.annotation) def visit_FunctionDef(self, node: ast.FunctionDef) -> None: if node.returns is not None: - OldSyntaxFinder().visit(node.returns) + OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.returns) self.generic_visit(node) def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: if node.returns is not None: - OldSyntaxFinder().visit(node.returns) + OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc).visit(node.returns) self.generic_visit(node) AnnotationFinder().visit(tree)