From 9b3edda33bac20534cf15ca8f8bed1df71bae7ac Mon Sep 17 00:00:00 2001 From: Milap Sheth Date: Sun, 21 Jun 2020 18:09:17 -0400 Subject: [PATCH] Fix unittest stub issues reported by mypy stubtest (#4248) Co-authored-by: Shantanu <> --- stdlib/3/unittest/__init__.pyi | 21 ++------------ stdlib/3/unittest/case.pyi | 40 ++++++++++++++++---------- stdlib/3/unittest/loader.pyi | 30 ++++++++++++++++++-- stdlib/3/unittest/main.pyi | 51 ++++++++++++++++++++++++++++++++++ stdlib/3/unittest/mock.pyi | 17 +++++++++++- stdlib/3/unittest/result.pyi | 8 ++++-- stdlib/3/unittest/runner.pyi | 8 ++---- stdlib/3/unittest/signals.pyi | 4 +-- stdlib/3/unittest/suite.pyi | 4 ++- third_party/2and3/mock.pyi | 17 +++++++++++- 10 files changed, 152 insertions(+), 48 deletions(-) create mode 100644 stdlib/3/unittest/main.pyi diff --git a/stdlib/3/unittest/__init__.pyi b/stdlib/3/unittest/__init__.pyi index 2580f76a3..1c8c66804 100644 --- a/stdlib/3/unittest/__init__.pyi +++ b/stdlib/3/unittest/__init__.pyi @@ -1,32 +1,15 @@ # Stubs for unittest -from typing import Iterable, List, Optional, Type, Union -from types import ModuleType +from typing import Optional from unittest.async_case import * from unittest.case import * from unittest.loader import * +from unittest.main import * from unittest.result import TestResult as TestResult from unittest.runner import * from unittest.signals import * from unittest.suite import * - -# not really documented -class TestProgram: - result: TestResult - def runTests(self) -> None: ... # undocumented - - -def main(module: Union[None, str, ModuleType] = ..., - defaultTest: Union[str, Iterable[str], None] = ..., - argv: Optional[List[str]] = ..., - testRunner: Union[Type[TestRunner], TestRunner, None] = ..., - testLoader: TestLoader = ..., exit: bool = ..., verbosity: int = ..., - failfast: Optional[bool] = ..., catchbreak: Optional[bool] = ..., - buffer: Optional[bool] = ..., - warnings: Optional[str] = ...) -> TestProgram: ... - - def load_tests(loader: TestLoader, tests: TestSuite, pattern: Optional[str]) -> TestSuite: ... diff --git a/stdlib/3/unittest/case.pyi b/stdlib/3/unittest/case.pyi index 01ae808f2..2c9de24c4 100644 --- a/stdlib/3/unittest/case.pyi +++ b/stdlib/3/unittest/case.pyi @@ -17,7 +17,7 @@ if sys.version_info >= (3, 8): def addModuleCleanup(__function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: ... def doModuleCleanups() -> None: ... -def expectedFailure(func: _FT) -> _FT: ... +def expectedFailure(test_item: _FT) -> _FT: ... def skip(reason: str) -> Callable[[_FT], _FT]: ... def skipIf(condition: object, reason: str) -> Callable[[_FT], _FT]: ... def skipUnless(condition: object, reason: str) -> Callable[[_FT], _FT]: ... @@ -120,25 +120,25 @@ class TestCase: level: Union[int, str, None] = ... ) -> _AssertLogsContext: ... @overload - def assertAlmostEqual(self, first: float, second: float, places: int = ..., - msg: Any = ..., delta: float = ...) -> None: ... + def assertAlmostEqual(self, first: float, second: float, places: Optional[int] = ..., + msg: Any = ..., delta: Optional[float] = ...) -> None: ... @overload def assertAlmostEqual(self, first: datetime.datetime, second: datetime.datetime, - places: int = ..., msg: Any = ..., - delta: datetime.timedelta = ...) -> None: ... + places: Optional[int] = ..., msg: Any = ..., + delta: Optional[datetime.timedelta] = ...) -> None: ... @overload def assertNotAlmostEqual(self, first: float, second: float, *, msg: Any = ...) -> None: ... @overload def assertNotAlmostEqual(self, first: float, second: float, - places: int = ..., msg: Any = ...) -> None: ... + places: Optional[int] = ..., msg: Any = ...) -> None: ... @overload def assertNotAlmostEqual(self, first: float, second: float, *, - msg: Any = ..., delta: float = ...) -> None: ... + msg: Any = ..., delta: Optional[float] = ...) -> None: ... @overload def assertNotAlmostEqual(self, first: datetime.datetime, second: datetime.datetime, - places: int = ..., msg: Any = ..., - delta: datetime.timedelta = ...) -> None: ... + places: Optional[int] = ..., msg: Any = ..., + delta: Optional[datetime.timedelta] = ...) -> None: ... def assertRegex(self, text: AnyStr, expected_regex: Union[AnyStr, Pattern[AnyStr]], msg: Any = ...) -> None: ... def assertNotRegex(self, text: AnyStr, unexpected_regex: Union[AnyStr, Pattern[AnyStr]], @@ -151,7 +151,7 @@ class TestCase: msg: Any = ...) -> None: ... def assertSequenceEqual(self, seq1: Sequence[Any], seq2: Sequence[Any], msg: Any = ..., - seq_type: Type[Sequence[Any]] = ...) -> None: ... + seq_type: Optional[Type[Sequence[Any]]] = ...) -> None: ... def assertListEqual(self, list1: List[Any], list2: List[Any], msg: Any = ...) -> None: ... def assertTupleEqual(self, tuple1: Tuple[Any, ...], tuple2: Tuple[Any, ...], @@ -165,8 +165,14 @@ class TestCase: def defaultTestResult(self) -> unittest.result.TestResult: ... def id(self) -> str: ... def shortDescription(self) -> Optional[str]: ... - def addCleanup(self, function: Callable[..., Any], *args: Any, - **kwargs: Any) -> None: ... + + if sys.version_info >= (3, 8): + def addCleanup(self, __function: Callable[..., Any], *args: Any, + **kwargs: Any) -> None: ... + else: + def addCleanup(self, function: Callable[..., Any], *args: Any, + **kwargs: Any) -> None: ... + def doCleanups(self) -> None: ... if sys.version_info >= (3, 8): @classmethod @@ -205,6 +211,8 @@ class TestCase: delta: float = ...) -> None: ... def assertRegexpMatches(self, text: AnyStr, regex: Union[AnyStr, Pattern[AnyStr]], msg: Any = ...) -> None: ... + def assertNotRegexpMatches(self, text: AnyStr, regex: Union[AnyStr, Pattern[AnyStr]], + msg: Any = ...) -> None: ... @overload def assertRaisesRegexp(self, # type: ignore exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], @@ -217,8 +225,8 @@ class TestCase: expected_regex: Union[str, bytes, Pattern[str], Pattern[bytes]], msg: Any = ...) -> _AssertRaisesContext[_E]: ... def assertDictContainsSubset(self, - expected: Mapping[Any, Any], - actual: Mapping[Any, Any], + subset: Mapping[Any, Any], + dictionary: Mapping[Any, Any], msg: object = ...) -> None: ... class FunctionTestCase(TestCase): @@ -226,6 +234,7 @@ class FunctionTestCase(TestCase): setUp: Optional[Callable[[], None]] = ..., tearDown: Optional[Callable[[], None]] = ..., description: Optional[str] = ...) -> None: ... + def runTest(self) -> None: ... class _AssertRaisesContext(Generic[_E]): exception: _E @@ -243,8 +252,11 @@ class _AssertWarnsContext: exc_tb: Optional[TracebackType]) -> None: ... class _AssertLogsContext: + LOGGING_FORMAT: str records: List[logging.LogRecord] output: List[str] + def __init__(self, test_case: unittest.case.TestCase, + logger_name: str, level: int) -> None: ... def __enter__(self) -> _AssertLogsContext: ... def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: ... diff --git a/stdlib/3/unittest/loader.pyi b/stdlib/3/unittest/loader.pyi index 2371778f9..b8db7e007 100644 --- a/stdlib/3/unittest/loader.pyi +++ b/stdlib/3/unittest/loader.pyi @@ -1,3 +1,4 @@ +import sys import unittest.case import unittest.suite import unittest.result @@ -5,14 +6,21 @@ from types import ModuleType from typing import Any, Callable, List, Optional, Sequence, Type +_SortComparisonMethod = Callable[[str, str], int] +_SuiteClass = Callable[[List[unittest.case.TestCase]], unittest.suite.TestSuite] + class TestLoader: errors: List[Type[BaseException]] testMethodPrefix: str - sortTestMethodsUsing: Callable[[str, str], bool] - suiteClass: Callable[[List[unittest.case.TestCase]], unittest.suite.TestSuite] + sortTestMethodsUsing: _SortComparisonMethod + + if sys.version_info >= (3, 7): + testNamePatterns: Optional[List[str]] + + suiteClass: _SuiteClass def loadTestsFromTestCase(self, testCaseClass: Type[unittest.case.TestCase]) -> unittest.suite.TestSuite: ... - def loadTestsFromModule(self, module: ModuleType, *, pattern: Any = ...) -> unittest.suite.TestSuite: ... + def loadTestsFromModule(self, module: ModuleType, *args: Any, pattern: Any = ...) -> unittest.suite.TestSuite: ... def loadTestsFromName(self, name: str, module: Optional[ModuleType] = ...) -> unittest.suite.TestSuite: ... def loadTestsFromNames(self, names: Sequence[str], @@ -23,3 +31,19 @@ class TestLoader: top_level_dir: Optional[str] = ...) -> unittest.suite.TestSuite: ... defaultTestLoader: TestLoader + + +if sys.version_info >= (3, 7): + def getTestCaseNames(testCaseClass: Type[unittest.case.TestCase], prefix: str, + sortUsing: _SortComparisonMethod = ..., + testNamePatterns: Optional[List[str]] = ...) -> Sequence[str]: ... +else: + def getTestCaseNames(testCaseClass: Type[unittest.case.TestCase], prefix: str, + sortUsing: _SortComparisonMethod = ...) -> Sequence[str]: ... + +def makeSuite(testCaseClass: Type[unittest.case.TestCase], prefix: str = ..., + sortUsing: _SortComparisonMethod = ..., + suiteClass: _SuiteClass = ...) -> unittest.suite.TestSuite: ... + +def findTestCases(module, prefix: str = ..., sortUsing: _SortComparisonMethod = ..., + suiteClass: _SuiteClass = ...) -> unittest.suite.TestSuite: ... diff --git a/stdlib/3/unittest/main.pyi b/stdlib/3/unittest/main.pyi new file mode 100644 index 000000000..d6e1ebdbe --- /dev/null +++ b/stdlib/3/unittest/main.pyi @@ -0,0 +1,51 @@ +import sys +from typing import Any, Iterable, List, Optional, Protocol, Type, Union +from types import ModuleType + +import unittest.case +import unittest.loader +import unittest.result +import unittest.suite + +class _TestRunner(Protocol): + def run( + self, test: Union[unittest.suite.TestSuite, unittest.case.TestCase] + ) -> unittest.result.TestResult: ... + + +# not really documented +class TestProgram: + result: unittest.result.TestResult + module: Union[None, str, ModuleType] + verbosity: int + failfast: Optional[bool] + catchbreak: Optional[bool] + buffer: Optional[bool] + progName: Optional[str] + warnings: Optional[str] + + if sys.version_info >= (3, 7): + testNamePatterns: Optional[List[str]] + + def __init__(self, module: Union[None, str, ModuleType] = ..., + defaultTest: Union[str, Iterable[str], None] = ..., + argv: Optional[List[str]] = ..., + testRunner: Union[Type[_TestRunner], _TestRunner, None] = ..., + testLoader: unittest.loader.TestLoader = ..., + exit: bool = ..., verbosity: int = ..., + failfast: Optional[bool] = ..., catchbreak: Optional[bool] = ..., + buffer: Optional[bool] = ..., + warnings: Optional[str] = ..., *, + tb_locals: bool = ...) -> None: ... + def usageExit(self, msg: Any = ...) -> None: ... + def parseArgs(self, argv: List[str]) -> None: ... + + if sys.version_info >= (3, 7): + def createTests(self, from_discovery: bool = ..., + Loader: Optional[unittest.loader.TestLoader] = ...) -> None: ... + else: + def createTests(self) -> None: ... + + def runTests(self) -> None: ... # undocumented + +main = TestProgram diff --git a/stdlib/3/unittest/mock.pyi b/stdlib/3/unittest/mock.pyi index 0d2cd387b..1f38f28a0 100644 --- a/stdlib/3/unittest/mock.pyi +++ b/stdlib/3/unittest/mock.pyi @@ -1,7 +1,7 @@ # Stubs for mock import sys -from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar +from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar, Union _T = TypeVar("_T") @@ -70,6 +70,19 @@ class Base: class NonCallableMock(Base, Any): # type: ignore + def __new__(cls, *args, **kw) -> NonCallableMock: ... + def __init__(self, + spec: Union[List[str], object, Type[object], None] = ..., + wraps: Optional[Any] = ..., + name: Optional[str] = ..., + spec_set: Union[List[str], object, Type[object], None] = ..., + parent: Optional[NonCallableMock] = ..., + _spec_state: Optional[Any] = ..., + _new_name: str = ..., + _new_parent: Optional[NonCallableMock] = ..., + _spec_as_instance: bool = ..., + _eat_self: Optional[bool] = ..., + unsafe: bool = ..., **kwargs) -> None: ... def __getattr__(self, name: str) -> Any: ... if sys.version_info >= (3, 8): def _calls_repr(self, prefix: str = ...) -> str: ... @@ -87,8 +100,10 @@ class NonCallableMock(Base, Any): # type: ignore def _format_mock_failure_message(self, args: Any, kwargs: Any) -> str: ... if sys.version_info >= (3, 8): + def assert_called(self) -> None: ... def assert_called_once(self) -> None: ... elif sys.version_info >= (3, 6): + def assert_called(_mock_self) -> None: ... def assert_called_once(_mock_self) -> None: ... if sys.version_info >= (3, 6): diff --git a/stdlib/3/unittest/result.pyi b/stdlib/3/unittest/result.pyi index 8fcbc472f..b6ba1c48b 100644 --- a/stdlib/3/unittest/result.pyi +++ b/stdlib/3/unittest/result.pyi @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, List, Optional, TextIO, Tuple, Type, TypeVar, Union from types import TracebackType import unittest.case @@ -25,6 +25,10 @@ class TestResult: buffer: bool failfast: bool tb_locals: bool + def __init__(self, stream: Optional[TextIO] = ..., + descriptions: Optional[bool] = ..., + verbosity: Optional[int] = ...) -> None: ... + def printErrors(self) -> None: ... def wasSuccessful(self) -> bool: ... def stop(self) -> None: ... def startTest(self, test: unittest.case.TestCase) -> None: ... @@ -39,4 +43,4 @@ class TestResult: err: _SysExcInfoType) -> None: ... def addUnexpectedSuccess(self, test: unittest.case.TestCase) -> None: ... def addSubTest(self, test: unittest.case.TestCase, subtest: unittest.case.TestCase, - outcome: Optional[_SysExcInfoType]) -> None: ... + err: Optional[_SysExcInfoType]) -> None: ... diff --git a/stdlib/3/unittest/runner.pyi b/stdlib/3/unittest/runner.pyi index ffe4e323a..21423fa99 100644 --- a/stdlib/3/unittest/runner.pyi +++ b/stdlib/3/unittest/runner.pyi @@ -21,11 +21,8 @@ class TextTestResult(unittest.result.TestResult): def printErrorList(self, flavour: str, errors: Tuple[unittest.case.TestCase, str]) -> None: ... -class TestRunner: - def run(self, test: Union[unittest.suite.TestSuite, unittest.case.TestCase]) -> unittest.result.TestResult: ... - - -class TextTestRunner(TestRunner): +class TextTestRunner(object): + resultclass: _ResultClassType def __init__( self, stream: Optional[TextIO] = ..., @@ -39,3 +36,4 @@ class TextTestRunner(TestRunner): tb_locals: bool = ..., ) -> None: ... def _makeResult(self) -> unittest.result.TestResult: ... + def run(self, test: Union[unittest.suite.TestSuite, unittest.case.TestCase]) -> unittest.result.TestResult: ... diff --git a/stdlib/3/unittest/signals.pyi b/stdlib/3/unittest/signals.pyi index a4616b779..b925d45c3 100644 --- a/stdlib/3/unittest/signals.pyi +++ b/stdlib/3/unittest/signals.pyi @@ -9,6 +9,6 @@ def installHandler() -> None: ... def registerResult(result: unittest.result.TestResult) -> None: ... def removeResult(result: unittest.result.TestResult) -> bool: ... @overload -def removeHandler() -> None: ... +def removeHandler(method: None = ...) -> None: ... @overload -def removeHandler(function: _F) -> _F: ... +def removeHandler(method: _F) -> _F: ... diff --git a/stdlib/3/unittest/suite.pyi b/stdlib/3/unittest/suite.pyi index 54e9e69e9..b31804e20 100644 --- a/stdlib/3/unittest/suite.pyi +++ b/stdlib/3/unittest/suite.pyi @@ -19,4 +19,6 @@ class BaseTestSuite(Iterable[_TestType]): def __iter__(self) -> Iterator[_TestType]: ... -class TestSuite(BaseTestSuite): ... +class TestSuite(BaseTestSuite): + def run(self, result: unittest.result.TestResult, + debug: bool = ...) -> unittest.result.TestResult: ... diff --git a/third_party/2and3/mock.pyi b/third_party/2and3/mock.pyi index 0d2cd387b..1f38f28a0 100644 --- a/third_party/2and3/mock.pyi +++ b/third_party/2and3/mock.pyi @@ -1,7 +1,7 @@ # Stubs for mock import sys -from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar +from typing import Any, List, Optional, Sequence, Text, Tuple, Type, TypeVar, Union _T = TypeVar("_T") @@ -70,6 +70,19 @@ class Base: class NonCallableMock(Base, Any): # type: ignore + def __new__(cls, *args, **kw) -> NonCallableMock: ... + def __init__(self, + spec: Union[List[str], object, Type[object], None] = ..., + wraps: Optional[Any] = ..., + name: Optional[str] = ..., + spec_set: Union[List[str], object, Type[object], None] = ..., + parent: Optional[NonCallableMock] = ..., + _spec_state: Optional[Any] = ..., + _new_name: str = ..., + _new_parent: Optional[NonCallableMock] = ..., + _spec_as_instance: bool = ..., + _eat_self: Optional[bool] = ..., + unsafe: bool = ..., **kwargs) -> None: ... def __getattr__(self, name: str) -> Any: ... if sys.version_info >= (3, 8): def _calls_repr(self, prefix: str = ...) -> str: ... @@ -87,8 +100,10 @@ class NonCallableMock(Base, Any): # type: ignore def _format_mock_failure_message(self, args: Any, kwargs: Any) -> str: ... if sys.version_info >= (3, 8): + def assert_called(self) -> None: ... def assert_called_once(self) -> None: ... elif sys.version_info >= (3, 6): + def assert_called(_mock_self) -> None: ... def assert_called_once(_mock_self) -> None: ... if sys.version_info >= (3, 6):