unittest.mock: use ParamSpec in patch (#10325)

Fixes #10324
This commit is contained in:
Shantanu
2023-06-20 04:48:49 -07:00
committed by GitHub
parent 7114aecf77
commit 9e86c6026a
3 changed files with 34 additions and 4 deletions

View File

@@ -3,13 +3,14 @@ from collections.abc import Awaitable, Callable, Coroutine, Iterable, Mapping, S
from contextlib import _GeneratorContextManager
from types import TracebackType
from typing import Any, Generic, TypeVar, overload
from typing_extensions import Final, Literal, Self, TypeAlias
from typing_extensions import Final, Literal, ParamSpec, Self, TypeAlias
_T = TypeVar("_T")
_TT = TypeVar("_TT", bound=type[Any])
_R = TypeVar("_R")
_F = TypeVar("_F", bound=Callable[..., Any])
_AF = TypeVar("_AF", bound=Callable[..., Coroutine[Any, Any, Any]])
_P = ParamSpec("_P")
if sys.version_info >= (3, 8):
__all__ = (
@@ -234,7 +235,7 @@ class _patch(Generic[_T]):
@overload
def __call__(self, func: _TT) -> _TT: ...
@overload
def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ...
def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ...
if sys.version_info >= (3, 8):
def decoration_helper(
self, patched: _patch[Any], args: Sequence[Any], keywargs: Any

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
import unittest
from collections.abc import Callable
from datetime import datetime, timedelta
from decimal import Decimal
from fractions import Fraction
from typing_extensions import assert_type
from unittest.mock import Mock, patch
case = unittest.TestCase()
@@ -86,3 +89,29 @@ case.assertGreater(datetime(1999, 1, 2), 1) # type: ignore
case.assertGreater(Spam(), Eggs()) # type: ignore
case.assertGreater(Ham(), Bacon()) # type: ignore
case.assertGreater(Bacon(), Ham()) # type: ignore
###
# Tests for mock.patch
###
@patch("sys.exit", new=Mock())
def f(i: int) -> str:
return "asdf"
assert_type(f(1), str)
f("a") # type: ignore
@patch("sys.exit", new=Mock())
class TestXYZ(unittest.TestCase):
attr: int = 5
@staticmethod
def method() -> int:
return 123
assert_type(TestXYZ.attr, int)
assert_type(TestXYZ.method, Callable[[], int])

View File

@@ -111,12 +111,12 @@ def testcase_dir_from_package_name(package_name: str) -> Path:
def get_all_testcase_directories() -> list[PackageInfo]:
testcase_directories = [PackageInfo("stdlib", Path("test_cases"))]
testcase_directories: list[PackageInfo] = []
for package_name in os.listdir("stubs"):
potential_testcase_dir = testcase_dir_from_package_name(package_name)
if potential_testcase_dir.is_dir():
testcase_directories.append(PackageInfo(package_name, potential_testcase_dir))
return sorted(testcase_directories)
return [PackageInfo("stdlib", Path("test_cases"))] + sorted(testcase_directories)
# ====================================================================