Fix unittest stub issues reported by mypy stubtest (#4248)

Co-authored-by: Shantanu <>
This commit is contained in:
Milap Sheth
2020-06-21 18:09:17 -04:00
committed by GitHub
parent 79111ee8ed
commit 9b3edda33b
10 changed files with 152 additions and 48 deletions

View File

@@ -1,32 +1,15 @@
# Stubs for unittest
from typing import Iterable, List, Optional, Type, Union
from types import ModuleType
from typing import Optional
from unittest.async_case import *
from unittest.case import *
from unittest.loader import *
from unittest.main import *
from unittest.result import TestResult as TestResult
from unittest.runner import *
from unittest.signals import *
from unittest.suite import *
# not really documented
class TestProgram:
result: TestResult
def runTests(self) -> None: ... # undocumented
def main(module: Union[None, str, ModuleType] = ...,
defaultTest: Union[str, Iterable[str], None] = ...,
argv: Optional[List[str]] = ...,
testRunner: Union[Type[TestRunner], TestRunner, None] = ...,
testLoader: TestLoader = ..., exit: bool = ..., verbosity: int = ...,
failfast: Optional[bool] = ..., catchbreak: Optional[bool] = ...,
buffer: Optional[bool] = ...,
warnings: Optional[str] = ...) -> TestProgram: ...
def load_tests(loader: TestLoader, tests: TestSuite,
pattern: Optional[str]) -> TestSuite: ...

View File

@@ -17,7 +17,7 @@ if sys.version_info >= (3, 8):
def addModuleCleanup(__function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: ...
def doModuleCleanups() -> None: ...
def expectedFailure(func: _FT) -> _FT: ...
def expectedFailure(test_item: _FT) -> _FT: ...
def skip(reason: str) -> Callable[[_FT], _FT]: ...
def skipIf(condition: object, reason: str) -> Callable[[_FT], _FT]: ...
def skipUnless(condition: object, reason: str) -> Callable[[_FT], _FT]: ...
@@ -120,25 +120,25 @@ class TestCase:
level: Union[int, str, None] = ...
) -> _AssertLogsContext: ...
@overload
def assertAlmostEqual(self, first: float, second: float, places: int = ...,
msg: Any = ..., delta: float = ...) -> None: ...
def assertAlmostEqual(self, first: float, second: float, places: Optional[int] = ...,
msg: Any = ..., delta: Optional[float] = ...) -> None: ...
@overload
def assertAlmostEqual(self, first: datetime.datetime, second: datetime.datetime,
places: int = ..., msg: Any = ...,
delta: datetime.timedelta = ...) -> None: ...
places: Optional[int] = ..., msg: Any = ...,
delta: Optional[datetime.timedelta] = ...) -> None: ...
@overload
def assertNotAlmostEqual(self, first: float, second: float, *,
msg: Any = ...) -> None: ...
@overload
def assertNotAlmostEqual(self, first: float, second: float,
places: int = ..., msg: Any = ...) -> None: ...
places: Optional[int] = ..., msg: Any = ...) -> None: ...
@overload
def assertNotAlmostEqual(self, first: float, second: float, *,
msg: Any = ..., delta: float = ...) -> None: ...
msg: Any = ..., delta: Optional[float] = ...) -> None: ...
@overload
def assertNotAlmostEqual(self, first: datetime.datetime, second: datetime.datetime,
places: int = ..., msg: Any = ...,
delta: datetime.timedelta = ...) -> None: ...
places: Optional[int] = ..., msg: Any = ...,
delta: Optional[datetime.timedelta] = ...) -> None: ...
def assertRegex(self, text: AnyStr, expected_regex: Union[AnyStr, Pattern[AnyStr]],
msg: Any = ...) -> None: ...
def assertNotRegex(self, text: AnyStr, unexpected_regex: Union[AnyStr, Pattern[AnyStr]],
@@ -151,7 +151,7 @@ class TestCase:
msg: Any = ...) -> None: ...
def assertSequenceEqual(self, seq1: Sequence[Any], seq2: Sequence[Any],
msg: Any = ...,
seq_type: Type[Sequence[Any]] = ...) -> None: ...
seq_type: Optional[Type[Sequence[Any]]] = ...) -> None: ...
def assertListEqual(self, list1: List[Any], list2: List[Any],
msg: Any = ...) -> None: ...
def assertTupleEqual(self, tuple1: Tuple[Any, ...], tuple2: Tuple[Any, ...],
@@ -165,8 +165,14 @@ class TestCase:
def defaultTestResult(self) -> unittest.result.TestResult: ...
def id(self) -> str: ...
def shortDescription(self) -> Optional[str]: ...
def addCleanup(self, function: Callable[..., Any], *args: Any,
**kwargs: Any) -> None: ...
if sys.version_info >= (3, 8):
def addCleanup(self, __function: Callable[..., Any], *args: Any,
**kwargs: Any) -> None: ...
else:
def addCleanup(self, function: Callable[..., Any], *args: Any,
**kwargs: Any) -> None: ...
def doCleanups(self) -> None: ...
if sys.version_info >= (3, 8):
@classmethod
@@ -205,6 +211,8 @@ class TestCase:
delta: float = ...) -> None: ...
def assertRegexpMatches(self, text: AnyStr, regex: Union[AnyStr, Pattern[AnyStr]],
msg: Any = ...) -> None: ...
def assertNotRegexpMatches(self, text: AnyStr, regex: Union[AnyStr, Pattern[AnyStr]],
msg: Any = ...) -> None: ...
@overload
def assertRaisesRegexp(self, # type: ignore
exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
@@ -217,8 +225,8 @@ class TestCase:
expected_regex: Union[str, bytes, Pattern[str], Pattern[bytes]],
msg: Any = ...) -> _AssertRaisesContext[_E]: ...
def assertDictContainsSubset(self,
expected: Mapping[Any, Any],
actual: Mapping[Any, Any],
subset: Mapping[Any, Any],
dictionary: Mapping[Any, Any],
msg: object = ...) -> None: ...
class FunctionTestCase(TestCase):
@@ -226,6 +234,7 @@ class FunctionTestCase(TestCase):
setUp: Optional[Callable[[], None]] = ...,
tearDown: Optional[Callable[[], None]] = ...,
description: Optional[str] = ...) -> None: ...
def runTest(self) -> None: ...
class _AssertRaisesContext(Generic[_E]):
exception: _E
@@ -243,8 +252,11 @@ class _AssertWarnsContext:
exc_tb: Optional[TracebackType]) -> None: ...
class _AssertLogsContext:
LOGGING_FORMAT: str
records: List[logging.LogRecord]
output: List[str]
def __init__(self, test_case: unittest.case.TestCase,
logger_name: str, level: int) -> None: ...
def __enter__(self) -> _AssertLogsContext: ...
def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Optional[bool]: ...

View File

@@ -1,3 +1,4 @@
import sys
import unittest.case
import unittest.suite
import unittest.result
@@ -5,14 +6,21 @@ from types import ModuleType
from typing import Any, Callable, List, Optional, Sequence, Type
_SortComparisonMethod = Callable[[str, str], int]
_SuiteClass = Callable[[List[unittest.case.TestCase]], unittest.suite.TestSuite]
class TestLoader:
errors: List[Type[BaseException]]
testMethodPrefix: str
sortTestMethodsUsing: Callable[[str, str], bool]
suiteClass: Callable[[List[unittest.case.TestCase]], unittest.suite.TestSuite]
sortTestMethodsUsing: _SortComparisonMethod
if sys.version_info >= (3, 7):
testNamePatterns: Optional[List[str]]
suiteClass: _SuiteClass
def loadTestsFromTestCase(self,
testCaseClass: Type[unittest.case.TestCase]) -> unittest.suite.TestSuite: ...
def loadTestsFromModule(self, module: ModuleType, *, pattern: Any = ...) -> unittest.suite.TestSuite: ...
def loadTestsFromModule(self, module: ModuleType, *args: Any, pattern: Any = ...) -> unittest.suite.TestSuite: ...
def loadTestsFromName(self, name: str,
module: Optional[ModuleType] = ...) -> unittest.suite.TestSuite: ...
def loadTestsFromNames(self, names: Sequence[str],
@@ -23,3 +31,19 @@ class TestLoader:
top_level_dir: Optional[str] = ...) -> unittest.suite.TestSuite: ...
defaultTestLoader: TestLoader
if sys.version_info >= (3, 7):
def getTestCaseNames(testCaseClass: Type[unittest.case.TestCase], prefix: str,
sortUsing: _SortComparisonMethod = ...,
testNamePatterns: Optional[List[str]] = ...) -> Sequence[str]: ...
else:
def getTestCaseNames(testCaseClass: Type[unittest.case.TestCase], prefix: str,
sortUsing: _SortComparisonMethod = ...) -> Sequence[str]: ...
def makeSuite(testCaseClass: Type[unittest.case.TestCase], prefix: str = ...,
sortUsing: _SortComparisonMethod = ...,
suiteClass: _SuiteClass = ...) -> unittest.suite.TestSuite: ...
def findTestCases(module, prefix: str = ..., sortUsing: _SortComparisonMethod = ...,
suiteClass: _SuiteClass = ...) -> unittest.suite.TestSuite: ...

View File

@@ -0,0 +1,51 @@
import sys
from typing import Any, Iterable, List, Optional, Protocol, Type, Union
from types import ModuleType
import unittest.case
import unittest.loader
import unittest.result
import unittest.suite
class _TestRunner(Protocol):
def run(
self, test: Union[unittest.suite.TestSuite, unittest.case.TestCase]
) -> unittest.result.TestResult: ...
# not really documented
class TestProgram:
result: unittest.result.TestResult
module: Union[None, str, ModuleType]
verbosity: int
failfast: Optional[bool]
catchbreak: Optional[bool]
buffer: Optional[bool]
progName: Optional[str]
warnings: Optional[str]
if sys.version_info >= (3, 7):
testNamePatterns: Optional[List[str]]
def __init__(self, module: Union[None, str, ModuleType] = ...,
defaultTest: Union[str, Iterable[str], None] = ...,
argv: Optional[List[str]] = ...,
testRunner: Union[Type[_TestRunner], _TestRunner, None] = ...,
testLoader: unittest.loader.TestLoader = ...,
exit: bool = ..., verbosity: int = ...,
failfast: Optional[bool] = ..., catchbreak: Optional[bool] = ...,
buffer: Optional[bool] = ...,
warnings: Optional[str] = ..., *,
tb_locals: bool = ...) -> None: ...
def usageExit(self, msg: Any = ...) -> None: ...
def parseArgs(self, argv: List[str]) -> None: ...
if sys.version_info >= (3, 7):
def createTests(self, from_discovery: bool = ...,
Loader: Optional[unittest.loader.TestLoader] = ...) -> None: ...
else:
def createTests(self) -> None: ...
def runTests(self) -> None: ... # undocumented
main = TestProgram

View File

@@ -1,7 +1,7 @@
# Stubs for mock
import sys
from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar
from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar, Union
_T = TypeVar("_T")
@@ -70,6 +70,19 @@ class Base:
class NonCallableMock(Base, Any): # type: ignore
def __new__(cls, *args, **kw) -> NonCallableMock: ...
def __init__(self,
spec: Union[List[str], object, Type[object], None] = ...,
wraps: Optional[Any] = ...,
name: Optional[str] = ...,
spec_set: Union[List[str], object, Type[object], None] = ...,
parent: Optional[NonCallableMock] = ...,
_spec_state: Optional[Any] = ...,
_new_name: str = ...,
_new_parent: Optional[NonCallableMock] = ...,
_spec_as_instance: bool = ...,
_eat_self: Optional[bool] = ...,
unsafe: bool = ..., **kwargs) -> None: ...
def __getattr__(self, name: str) -> Any: ...
if sys.version_info >= (3, 8):
def _calls_repr(self, prefix: str = ...) -> str: ...
@@ -87,8 +100,10 @@ class NonCallableMock(Base, Any): # type: ignore
def _format_mock_failure_message(self, args: Any, kwargs: Any) -> str: ...
if sys.version_info >= (3, 8):
def assert_called(self) -> None: ...
def assert_called_once(self) -> None: ...
elif sys.version_info >= (3, 6):
def assert_called(_mock_self) -> None: ...
def assert_called_once(_mock_self) -> None: ...
if sys.version_info >= (3, 6):

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, List, Optional, TextIO, Tuple, Type, TypeVar, Union
from types import TracebackType
import unittest.case
@@ -25,6 +25,10 @@ class TestResult:
buffer: bool
failfast: bool
tb_locals: bool
def __init__(self, stream: Optional[TextIO] = ...,
descriptions: Optional[bool] = ...,
verbosity: Optional[int] = ...) -> None: ...
def printErrors(self) -> None: ...
def wasSuccessful(self) -> bool: ...
def stop(self) -> None: ...
def startTest(self, test: unittest.case.TestCase) -> None: ...
@@ -39,4 +43,4 @@ class TestResult:
err: _SysExcInfoType) -> None: ...
def addUnexpectedSuccess(self, test: unittest.case.TestCase) -> None: ...
def addSubTest(self, test: unittest.case.TestCase, subtest: unittest.case.TestCase,
outcome: Optional[_SysExcInfoType]) -> None: ...
err: Optional[_SysExcInfoType]) -> None: ...

View File

@@ -21,11 +21,8 @@ class TextTestResult(unittest.result.TestResult):
def printErrorList(self, flavour: str, errors: Tuple[unittest.case.TestCase, str]) -> None: ...
class TestRunner:
def run(self, test: Union[unittest.suite.TestSuite, unittest.case.TestCase]) -> unittest.result.TestResult: ...
class TextTestRunner(TestRunner):
class TextTestRunner(object):
resultclass: _ResultClassType
def __init__(
self,
stream: Optional[TextIO] = ...,
@@ -39,3 +36,4 @@ class TextTestRunner(TestRunner):
tb_locals: bool = ...,
) -> None: ...
def _makeResult(self) -> unittest.result.TestResult: ...
def run(self, test: Union[unittest.suite.TestSuite, unittest.case.TestCase]) -> unittest.result.TestResult: ...

View File

@@ -9,6 +9,6 @@ def installHandler() -> None: ...
def registerResult(result: unittest.result.TestResult) -> None: ...
def removeResult(result: unittest.result.TestResult) -> bool: ...
@overload
def removeHandler() -> None: ...
def removeHandler(method: None = ...) -> None: ...
@overload
def removeHandler(function: _F) -> _F: ...
def removeHandler(method: _F) -> _F: ...

View File

@@ -19,4 +19,6 @@ class BaseTestSuite(Iterable[_TestType]):
def __iter__(self) -> Iterator[_TestType]: ...
class TestSuite(BaseTestSuite): ...
class TestSuite(BaseTestSuite):
def run(self, result: unittest.result.TestResult,
debug: bool = ...) -> unittest.result.TestResult: ...

View File

@@ -1,7 +1,7 @@
# Stubs for mock
import sys
from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar
from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar, Union
_T = TypeVar("_T")
@@ -70,6 +70,19 @@ class Base:
class NonCallableMock(Base, Any): # type: ignore
def __new__(cls, *args, **kw) -> NonCallableMock: ...
def __init__(self,
spec: Union[List[str], object, Type[object], None] = ...,
wraps: Optional[Any] = ...,
name: Optional[str] = ...,
spec_set: Union[List[str], object, Type[object], None] = ...,
parent: Optional[NonCallableMock] = ...,
_spec_state: Optional[Any] = ...,
_new_name: str = ...,
_new_parent: Optional[NonCallableMock] = ...,
_spec_as_instance: bool = ...,
_eat_self: Optional[bool] = ...,
unsafe: bool = ..., **kwargs) -> None: ...
def __getattr__(self, name: str) -> Any: ...
if sys.version_info >= (3, 8):
def _calls_repr(self, prefix: str = ...) -> str: ...
@@ -87,8 +100,10 @@ class NonCallableMock(Base, Any): # type: ignore
def _format_mock_failure_message(self, args: Any, kwargs: Any) -> str: ...
if sys.version_info >= (3, 8):
def assert_called(self) -> None: ...
def assert_called_once(self) -> None: ...
elif sys.version_info >= (3, 6):
def assert_called(_mock_self) -> None: ...
def assert_called_once(_mock_self) -> None: ...
if sys.version_info >= (3, 6):