From 9e86c6026a00a7bcc6c77ec8f7925a99930ee983 Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Tue, 20 Jun 2023 04:48:49 -0700 Subject: [PATCH] unittest.mock: use ParamSpec in patch (#10325) Fixes #10324 --- stdlib/unittest/mock.pyi | 5 +++-- test_cases/stdlib/check_unittest.py | 29 +++++++++++++++++++++++++++++ tests/utils.py | 4 ++-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/stdlib/unittest/mock.pyi b/stdlib/unittest/mock.pyi index 1f554da52..0ed0701cc 100644 --- a/stdlib/unittest/mock.pyi +++ b/stdlib/unittest/mock.pyi @@ -3,13 +3,14 @@ from collections.abc import Awaitable, Callable, Coroutine, Iterable, Mapping, S from contextlib import _GeneratorContextManager from types import TracebackType from typing import Any, Generic, TypeVar, overload -from typing_extensions import Final, Literal, Self, TypeAlias +from typing_extensions import Final, Literal, ParamSpec, Self, TypeAlias _T = TypeVar("_T") _TT = TypeVar("_TT", bound=type[Any]) _R = TypeVar("_R") _F = TypeVar("_F", bound=Callable[..., Any]) _AF = TypeVar("_AF", bound=Callable[..., Coroutine[Any, Any, Any]]) +_P = ParamSpec("_P") if sys.version_info >= (3, 8): __all__ = ( @@ -234,7 +235,7 @@ class _patch(Generic[_T]): @overload def __call__(self, func: _TT) -> _TT: ... @overload - def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ... + def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ... if sys.version_info >= (3, 8): def decoration_helper( self, patched: _patch[Any], args: Sequence[Any], keywargs: Any diff --git a/test_cases/stdlib/check_unittest.py b/test_cases/stdlib/check_unittest.py index 7b0de9b30..3c03b1001 100644 --- a/test_cases/stdlib/check_unittest.py +++ b/test_cases/stdlib/check_unittest.py @@ -1,9 +1,12 @@ from __future__ import annotations import unittest +from collections.abc import Callable from datetime import datetime, timedelta from decimal import Decimal from fractions import Fraction +from typing_extensions import assert_type +from unittest.mock import Mock, patch case = unittest.TestCase() @@ -86,3 +89,29 @@ case.assertGreater(datetime(1999, 1, 2), 1) # type: ignore case.assertGreater(Spam(), Eggs()) # type: ignore case.assertGreater(Ham(), Bacon()) # type: ignore case.assertGreater(Bacon(), Ham()) # type: ignore + +### +# Tests for mock.patch +### + + +@patch("sys.exit", new=Mock()) +def f(i: int) -> str: + return "asdf" + + +assert_type(f(1), str) +f("a") # type: ignore + + +@patch("sys.exit", new=Mock()) +class TestXYZ(unittest.TestCase): + attr: int = 5 + + @staticmethod + def method() -> int: + return 123 + + +assert_type(TestXYZ.attr, int) +assert_type(TestXYZ.method, Callable[[], int]) diff --git a/tests/utils.py b/tests/utils.py index 1a47dd8d3..11a590327 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -111,12 +111,12 @@ def testcase_dir_from_package_name(package_name: str) -> Path: def get_all_testcase_directories() -> list[PackageInfo]: - testcase_directories = [PackageInfo("stdlib", Path("test_cases"))] + testcase_directories: list[PackageInfo] = [] for package_name in os.listdir("stubs"): potential_testcase_dir = testcase_dir_from_package_name(package_name) if potential_testcase_dir.is_dir(): testcase_directories.append(PackageInfo(package_name, potential_testcase_dir)) - return sorted(testcase_directories) + return [PackageInfo("stdlib", Path("test_cases"))] + sorted(testcase_directories) # ====================================================================