From 40bf7a562d5f02a68aafdca06003e325ad658c1d Mon Sep 17 00:00:00 2001 From: Xavier Francisco <98830734+XF-FW@users.noreply.github.com> Date: Wed, 13 Apr 2022 09:56:57 +0000 Subject: [PATCH] Add missing type for setup_databases (#919) * Add missing type for setup_databases `django.test.utils` was seemingly missing the type for `setup_databases`. This change resolves my issue locally. The type was copied directly from `django.test.runner`. * Fix typecheck_tests runner Co-authored-by: Xavier Francisco --- django-stubs/test/utils.pyi | 17 +++++++++++++++++ scripts/git_helpers.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/django-stubs/test/utils.pyi b/django-stubs/test/utils.pyi index 3e194f2..25df3c6 100644 --- a/django-stubs/test/utils.pyi +++ b/django-stubs/test/utils.pyi @@ -12,6 +12,7 @@ from typing import ( List, Mapping, Optional, + Protocol, Set, Tuple, Type, @@ -139,12 +140,28 @@ def tag(*tags: str): ... _Signature = str _TestDatabase = Tuple[str, List[str]] +class TimeKeeperProtocol(Protocol): + @contextmanager + def timed(self, name: Any) -> Iterator[None]: ... + def print_results(self) -> None: ... + def dependency_ordered( test_databases: Iterable[Tuple[_Signature, _TestDatabase]], dependencies: Mapping[str, List[str]] ) -> List[Tuple[_Signature, _TestDatabase]]: ... def get_unique_databases_and_mirrors( aliases: Optional[Set[str]] = ..., ) -> Tuple[Dict[_Signature, _TestDatabase], Dict[str, Any]]: ... +def setup_databases( + verbosity: int, + interactive: bool, + *, + time_keeper: Optional[TimeKeeperProtocol] = ..., + keepdb: bool = ..., + debug_sql: bool = ..., + parallel: int = ..., + aliases: Optional[Mapping[str, Any]] = ..., + **kwargs: Any +) -> List[Tuple[BaseDatabaseWrapper, str, bool]]: ... def teardown_databases( old_config: Iterable[Tuple[Any, str, bool]], verbosity: int, parallel: int = ..., keepdb: bool = ... ) -> None: ... diff --git a/scripts/git_helpers.py b/scripts/git_helpers.py index a6bee59..48f47e4 100644 --- a/scripts/git_helpers.py +++ b/scripts/git_helpers.py @@ -28,4 +28,4 @@ def checkout_django_branch(django_version: str, commit_sha: Optional[str]) -> Re ) if commit_sha and repo.head.commit.hexsha != commit_sha: repo.remote("origin").fetch(branch, progress=ProgressPrinter(), depth=100) - repo.git.checkout(commit_sha) + repo.git.checkout(branch)