mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-07 20:54:28 +08:00
Fix @patch when new is missing (#10459)
This commit is contained in:
@@ -234,6 +234,8 @@ class _patch(Generic[_T]):
|
||||
def copy(self) -> _patch[_T]: ...
|
||||
@overload
|
||||
def __call__(self, func: _TT) -> _TT: ...
|
||||
# If new==DEFAULT, this should add a MagicMock parameter to the function
|
||||
# arguments. See the _patch_default_new class below for this functionality.
|
||||
@overload
|
||||
def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ...
|
||||
if sys.version_info >= (3, 8):
|
||||
@@ -257,6 +259,22 @@ class _patch(Generic[_T]):
|
||||
def start(self) -> _T: ...
|
||||
def stop(self) -> None: ...
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
_Mock: TypeAlias = MagicMock | AsyncMock
|
||||
else:
|
||||
_Mock: TypeAlias = MagicMock
|
||||
|
||||
# This class does not exist at runtime, it's a hack to make this work:
|
||||
# @patch("foo")
|
||||
# def bar(..., mock: MagicMock) -> None: ...
|
||||
class _patch_default_new(_patch[_Mock]):
|
||||
@overload
|
||||
def __call__(self, func: _TT) -> _TT: ...
|
||||
# Can't use the following as ParamSpec is only allowed as last parameter:
|
||||
# def __call__(self, func: Callable[_P, _R]) -> Callable[Concatenate[_P, MagicMock], _R]: ...
|
||||
@overload
|
||||
def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ...
|
||||
|
||||
class _patch_dict:
|
||||
in_dict: Any
|
||||
values: Any
|
||||
@@ -273,11 +291,8 @@ class _patch_dict:
|
||||
start: Any
|
||||
stop: Any
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
_Mock: TypeAlias = MagicMock | AsyncMock
|
||||
else:
|
||||
_Mock: TypeAlias = MagicMock
|
||||
|
||||
# This class does not exist at runtime, it's a hack to add methods to the
|
||||
# patch() function.
|
||||
class _patcher:
|
||||
TEST_PREFIX: str
|
||||
dict: type[_patch_dict]
|
||||
@@ -307,7 +322,7 @@ class _patcher:
|
||||
autospec: Any | None = ...,
|
||||
new_callable: Any | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> _patch[_Mock]: ...
|
||||
) -> _patch_default_new: ...
|
||||
@overload
|
||||
@staticmethod
|
||||
def object( # type: ignore[misc]
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from typing_extensions import assert_type
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
case = unittest.TestCase()
|
||||
|
||||
@@ -94,13 +94,20 @@ case.assertGreater(Bacon(), Ham()) # type: ignore
|
||||
###
|
||||
|
||||
|
||||
@patch("sys.exit", new=Mock())
|
||||
def f(i: int) -> str:
|
||||
@patch("sys.exit")
|
||||
def f_default_new(i: int, mock: MagicMock) -> str:
|
||||
return "asdf"
|
||||
|
||||
|
||||
assert_type(f(1), str)
|
||||
f("a") # type: ignore
|
||||
@patch("sys.exit", new=42)
|
||||
def f_explicit_new(i: int) -> str:
|
||||
return "asdf"
|
||||
|
||||
|
||||
assert_type(f_default_new(1), str)
|
||||
f_default_new("a") # Not an error due to ParamSpec limitations
|
||||
assert_type(f_explicit_new(1), str)
|
||||
f_explicit_new("a") # type: ignore[arg-type]
|
||||
|
||||
|
||||
@patch("sys.exit", new=Mock())
|
||||
|
||||
Reference in New Issue
Block a user