diff --git a/tests/check_consistent.py b/tests/check_consistent.py index ff8f20bcd..001e8caa8 100755 --- a/tests/check_consistent.py +++ b/tests/check_consistent.py @@ -99,7 +99,7 @@ def check_same_files(): ) -_VERSIONS_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*): [23]\.\d{1,2}-(?:[23]\.\d{1,2})?$") +_VERSIONS_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_.]*): [23]\.\d{1,2}-(?:[23]\.\d{1,2})?$") def check_versions(): @@ -115,21 +115,29 @@ def check_versions(): module = m.group(1) assert module not in versions, f"Duplicate module {module} in VERSIONS" versions.add(module) - modules = set() - for entry in os.listdir("stdlib"): - if entry == "@python2" or entry == "VERSIONS": - continue - if os.path.isfile(os.path.join("stdlib", entry)): - mod, _ = os.path.splitext(entry) - modules.add(mod) - else: - modules.add(entry) - extra = modules - versions + modules = _find_stdlib_modules() + # Sub-modules don't need to be listed in VERSIONS. + extra = {m.split(".")[0] for m in modules} - versions assert not extra, f"Modules not in versions: {extra}" extra = versions - modules assert not extra, f"Versions not in modules: {extra}" +def _find_stdlib_modules() -> set[str]: + modules = set() + for path, _, files in os.walk("stdlib"): + if "@python2" in path: + continue + for filename in files: + base_module = ".".join(os.path.normpath(path).split(os.sep)[1:]) + if filename == "__init__.pyi": + modules.add(base_module) + elif filename.endswith(".pyi"): + mod, _ = os.path.splitext(filename) + modules.add(f"{base_module}.{mod}" if base_module else mod) + return modules + + def _strip_dep_version(dependency): dep_version_pos = len(dependency) for pos, c in enumerate(dependency): diff --git a/tests/mypy_test.py b/tests/mypy_test.py index 349a61dac..49f6f4323 100755 --- a/tests/mypy_test.py +++ b/tests/mypy_test.py @@ -71,7 +71,7 @@ def match(fn, args, exclude_list): return True -_VERSION_LINE_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*): ([23]\.\d{1,2})-([23]\.\d{1,2})?$") +_VERSION_LINE_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_.]*): ([23]\.\d{1,2})-([23]\.\d{1,2})?$") def parse_versions(fname):