From cf16720f287c6a0535518cfd5e797cd04b3aa99a Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Fri, 26 Aug 2022 11:09:41 +0100 Subject: [PATCH] Improve types for DiscoverRunner (#1106) * Improve types for DiscoverRunner * add types for class-level attrs * Add extra methods * Add PDBDebugResult class * black * Update django-stubs/test/runner.pyi * fix load_with_patterns return Co-authored-by: Nikita Sobolev --- django-stubs/test/runner.pyi | 37 ++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/django-stubs/test/runner.pyi b/django-stubs/test/runner.pyi index fc3adfe..694afea 100644 --- a/django-stubs/test/runner.pyi +++ b/django-stubs/test/runner.pyi @@ -1,12 +1,15 @@ import logging from argparse import ArgumentParser +from contextlib import contextmanager from io import StringIO -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type -from unittest import TestCase, TestSuite, TextTestResult +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union +from unittest import TestCase, TestLoader, TestSuite, TextTestResult, TextTestRunner from django.db.backends.base.base import BaseDatabaseWrapper from django.test.testcases import SimpleTestCase, TestCase +from django.test.utils import TimeKeeperProtocol from django.utils.datastructures import OrderedSet +from typing_extensions import Literal class DebugSQLTextTestResult(TextTestResult): buffer: bool @@ -33,6 +36,8 @@ class DebugSQLTextTestResult(TextTestResult): def addError(self, test: Any, err: Any) -> None: ... def addFailure(self, test: Any, err: Any) -> None: ... +class PDBDebugResult(TextTestResult): ... + class RemoteTestResult: events: List[Any] = ... failfast: bool = ... @@ -77,11 +82,11 @@ class ParallelTestSuite(TestSuite): def run(self, result: Any): ... # type: ignore[override] class DiscoverRunner: - test_suite: Any = ... - parallel_test_suite: Any = ... - test_runner: Any = ... - test_loader: Any = ... - reorder_by: Any = ... + test_suite: Type[TestSuite] = ... + parallel_test_suite: Type[ParallelTestSuite] = ... + test_runner: Type[TextTestRunner] = ... + test_loader: TestLoader = ... + reorder_by: Tuple[SimpleTestCase, ...] = ... pattern: Optional[str] = ... top_level: Optional[str] = ... verbosity: int = ... @@ -94,6 +99,12 @@ class DiscoverRunner: parallel: int = ... tags: Set[str] = ... exclude_tags: Set[str] = ... + pdb: bool = ... + buffer: bool = ... + test_name_patterns: Optional[Set[str]] = ... + time_keeper: TimeKeeperProtocol = ... + shuffle: Union[int, Literal[False]] = ... + logger: Optional[logging.Logger] = ... def __init__( self, pattern: Optional[str] = ..., @@ -113,22 +124,32 @@ class DiscoverRunner: buffer: bool = ..., enable_faulthandler: bool = ..., timing: bool = ..., + shuffle: Union[int, Literal[False]] = ..., + logger: Optional[logging.Logger] = ..., **kwargs: Any ) -> None: ... @classmethod def add_arguments(cls, parser: ArgumentParser) -> None: ... + @property + def shuffle_seed(self) -> Optional[int]: ... + def log(self, msg: str, level: Optional[int]) -> None: ... def setup_test_environment(self, **kwargs: Any) -> None: ... + def setup_shuffler(self) -> None: ... + @contextmanager + def load_with_patterns(self) -> Iterator[None]: ... + def load_tests_for_label(self, label: str, discover_kwargs: Dict[str, str]) -> TestSuite: ... def build_suite( self, test_labels: Sequence[str] = ..., extra_tests: Optional[List[Any]] = ..., **kwargs: Any ) -> TestSuite: ... def setup_databases(self, **kwargs: Any) -> List[Tuple[BaseDatabaseWrapper, str, bool]]: ... def get_resultclass(self) -> Optional[Type[TextTestResult]]: ... - def get_test_runner_kwargs(self) -> Dict[str, Optional[int]]: ... + def get_test_runner_kwargs(self) -> Dict[str, Any]: ... def run_checks(self, databases: Set[str]) -> None: ... def run_suite(self, suite: TestSuite, **kwargs: Any) -> TextTestResult: ... def teardown_databases(self, old_config: List[Tuple[BaseDatabaseWrapper, str, bool]], **kwargs: Any) -> None: ... def teardown_test_environment(self, **kwargs: Any) -> None: ... def suite_result(self, suite: TestSuite, result: TextTestResult, **kwargs: Any) -> int: ... + def _get_databases(self, suite: TestSuite) -> Set[str]: ... def get_databases(self, suite: TestSuite) -> Set[str]: ... def run_tests(self, test_labels: List[str], extra_tests: Optional[List[Any]] = ..., **kwargs: Any) -> int: ...